use std::{collections::{BTreeMap, HashMap}, fmt::Debug};
use ordered_float::OrderedFloat;
use std::hash::Hash;
pub struct ProbabilityContainer<T> {
probability_total: f32,
items_total: u32,
probability_per_item: HashMap<T, f32>,
items: Vec<T>,
item_index_per_cumulative_probability: BTreeMap<OrderedFloat<f32>, usize>,
last_item_index_to_apply_to_item_index_per_cumulative_probability: usize,
last_cumulative_probability: f32
}
impl<T: Ord + Eq + Hash + Clone + Debug> ProbabilityContainer<T> {
pub fn default() -> Self {
let probability_per_item: HashMap<T, f32> = HashMap::new();
ProbabilityContainer {
probability_total: 0.0,
items_total: 0,
probability_per_item,
items: Vec::new(),
item_index_per_cumulative_probability: BTreeMap::new(),
last_item_index_to_apply_to_item_index_per_cumulative_probability: 0,
last_cumulative_probability: 0.0
}
}
#[allow(dead_code)]
pub fn new(probability_per_item: HashMap<T, f32>) -> Self {
let mut probability_total = 0.0;
let mut items_total: u32 = 0;
let mut items: Vec<T> = probability_per_item.keys().cloned().collect::<Vec<T>>();
items.sort();
for item in items.iter() {
let probability = &probability_per_item[item];
if probability != &0.0 {
probability_total += probability;
items_total += 1;
}
}
ProbabilityContainer {
probability_total,
items_total,
probability_per_item,
items,
item_index_per_cumulative_probability: BTreeMap::new(),
last_item_index_to_apply_to_item_index_per_cumulative_probability: 0,
last_cumulative_probability: 0.0
}
}
pub fn push(&mut self, item: T, probability: f32) {
self.probability_total += probability;
self.items_total += 1;
self.probability_per_item.insert(item.clone(), probability);
self.items.push(item);
}
#[allow(dead_code)]
pub fn peek_random(&mut self, random_instance: &mut fastrand::Rng) -> Option<T> {
let item_option: Option<T>;
if self.items_total == 0 {
item_option = None;
}
else if self.items_total == 1 {
item_option = Some(self.items.first().unwrap().clone());
}
else {
let random_value = random_instance.f32() * self.probability_total;
if random_value > self.last_cumulative_probability {
let mut current_item: Option<&T> = None;
while random_value > self.last_cumulative_probability {
current_item = Some(self.items.get(self.last_item_index_to_apply_to_item_index_per_cumulative_probability).unwrap());
let item_probability = self.probability_per_item.get(current_item.unwrap()).unwrap();
if item_probability != &0.0 {
self.last_cumulative_probability += item_probability;
self.item_index_per_cumulative_probability.insert(OrderedFloat(self.last_cumulative_probability), self.last_item_index_to_apply_to_item_index_per_cumulative_probability);
}
self.last_item_index_to_apply_to_item_index_per_cumulative_probability += 1;
}
let current_item = current_item.unwrap().clone();
item_option = Some(current_item.clone());
}
else {
let (_temp_key, temp_value) = self.item_index_per_cumulative_probability.range(OrderedFloat(random_value)..).next().unwrap();
item_option = Some(self.items.get(*temp_value).unwrap().clone());
}
}
item_option
}
pub fn pop_random(&mut self, random_instance: &mut fastrand::Rng) -> Option<T> {
if self.items_total == 0 {
None
}
else {
let item_option: Option<T>;
if self.items_total == 1 {
item_option = self.items.first().cloned();
self.items.clear();
self.items_total = 0;
self.probability_total = 0.0;
self.item_index_per_cumulative_probability.clear();
self.last_item_index_to_apply_to_item_index_per_cumulative_probability = 0;
self.last_cumulative_probability = 0.0;
self.probability_per_item.clear();
}
else {
let random_value = random_instance.f32() * self.probability_total;
let mut is_item_outside_random_value: bool = if self.last_item_index_to_apply_to_item_index_per_cumulative_probability as u32 == self.items_total {
false
}
else if random_value == 0.0 && self.last_item_index_to_apply_to_item_index_per_cumulative_probability == 0 {
true
}
else {
random_value > self.last_cumulative_probability
};
if is_item_outside_random_value {
let mut current_item: &T;
while is_item_outside_random_value {
current_item = self.items.get(self.last_item_index_to_apply_to_item_index_per_cumulative_probability).unwrap();
let item_probability = self.probability_per_item.get(current_item).unwrap();
if item_probability != &0.0 {
if self.last_cumulative_probability + item_probability >= random_value {
if (self.last_item_index_to_apply_to_item_index_per_cumulative_probability as u32) + 1 == self.items_total {
self.probability_total = self.last_cumulative_probability + item_probability;
}
break;
}
else {
self.last_cumulative_probability += item_probability;
self.item_index_per_cumulative_probability.insert(OrderedFloat(self.last_cumulative_probability), self.last_item_index_to_apply_to_item_index_per_cumulative_probability);
}
}
self.last_item_index_to_apply_to_item_index_per_cumulative_probability += 1;
if (self.last_item_index_to_apply_to_item_index_per_cumulative_probability as u32) == self.items_total {
self.probability_total = self.last_cumulative_probability;
self.last_item_index_to_apply_to_item_index_per_cumulative_probability -= 1;
break;
}
is_item_outside_random_value = random_value > self.last_cumulative_probability;
}
let item = self.items.remove(self.last_item_index_to_apply_to_item_index_per_cumulative_probability);
self.probability_total -= self.probability_per_item.remove(&item).unwrap();
item_option = Some(item);
self.items_total -= 1;
}
else {
let found_key: f32;
let found_index: usize;
let found_item: T;
{
let (temp_key, temp_value) = self.item_index_per_cumulative_probability.range(OrderedFloat(random_value)..).next().unwrap();
found_item = self.items.remove(*temp_value);
self.items_total -= 1;
item_option = Some(found_item.clone());
found_key = temp_key.0;
found_index = *temp_value;
}
let found_key_ordered_float = &OrderedFloat(found_key);
self.item_index_per_cumulative_probability.retain(|probability, _| probability < found_key_ordered_float);
self.last_item_index_to_apply_to_item_index_per_cumulative_probability = found_index;
let found_item_probability = self.probability_per_item.remove(&found_item).unwrap();
self.last_cumulative_probability = found_key - found_item_probability;
if (self.last_item_index_to_apply_to_item_index_per_cumulative_probability as u32) == self.items_total {
self.probability_total = self.last_cumulative_probability;
}
else {
self.probability_total -= found_item_probability;
}
}
if item_option.is_none() {
panic!("Failed to find item even though some exists.");
}
}
item_option
}
}
}