#![deny(unused_must_use)]
#![deny(missing_docs)]
use crate::errors::NoValuesError;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// A piece of text paired with a numeric weight.
///
/// Items are the entries stored in a `LookUpTable`; the weight is what the
/// table consults when it draws entries at random.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Item {
/// The text carried by this item.
text: String,
/// The weight of this item; `LookUpTable::add` rejects negative values.
weight: f64,
}
impl Item {
pub fn get_text(&self) -> &String {
&self.text
}
pub fn get_weight(&self) -> f64 {
self.weight
}
}
/// A collection of weighted `Item`s that can be drawn from at random.
///
/// Besides the items themselves, the table caches the sum of their weights
/// and remembers whether every item added so far had the same weight.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LookUpTable {
/// The entries of the table, in insertion order.
items: Vec<Item>,
/// Sum of the weights of `items`; updated by `add` and recomputed by `recount`.
total: f64,
/// `true` while every item passed to `add` matched the weight of the item
/// before it; `draw_random` then picks a uniformly random index instead of
/// walking the weights.
equal_weights: bool,
}
impl LookUpTable {
    /// Creates an empty table.
    ///
    /// The new table has a total weight of zero and is considered equally
    /// weighted until an item with a differing weight is added.
    pub fn new() -> Self {
        LookUpTable {
            items: Vec::new(),
            total: 0.,
            equal_weights: true,
        }
    }

    /// Draws one item at random and returns a clone of it.
    ///
    /// While all items share the same weight a uniformly random index is
    /// chosen. Otherwise a value in `[0, total)` is rolled and the items are
    /// walked in insertion order, each one claiming a slice of that interval
    /// as wide as its weight. Items with a weight of zero are never selected
    /// by the weighted walk.
    ///
    /// # Errors
    ///
    /// Returns `NoValuesError` if the table contains no items.
    pub fn draw_random(&self, rng: &mut impl RngExt) -> Result<Item, NoValuesError> {
        if self.items.is_empty() {
            return Err(NoValuesError {});
        }
        if self.equal_weights {
            let i = rng.random_range(0..self.items.len());
            return Ok(self.items[i].clone());
        }

        let roll: f64 = rng.random_range(0f64..1f64);
        let mut draw = self.total * roll;
        for item in &self.items {
            // Strict comparison: a zero-weight item owns an empty slice and
            // must not win even when the roll is exactly zero.
            if draw < item.weight {
                return Ok(item.clone());
            }
            draw -= item.weight;
        }

        // Reaching this point does not mean the generator misbehaved:
        // rounding in the running subtraction (or a cached total that is
        // slightly larger than the real sum) can leave a sliver of `draw`
        // after the last item. Attribute it to the last item that can be
        // drawn at all instead of panicking.
        let fallback = self
            .items
            .iter()
            .rev()
            .find(|item| item.weight > 0.)
            .or_else(|| self.items.last());
        match fallback {
            Some(item) => Ok(item.clone()),
            None => Err(NoValuesError {}),
        }
    }

    /// Draws `count` items independently (with replacement) via
    /// `draw_random`.
    ///
    /// A `count` of zero yields an empty vector, even for an empty table.
    ///
    /// # Errors
    ///
    /// Returns `NoValuesError` if the table is empty and `count` is non-zero.
    pub fn draw_n_random(
        &self,
        rng: &mut impl RngExt,
        count: usize,
    ) -> Result<Vec<Item>, NoValuesError> {
        let mut result: Vec<Item> = Vec::with_capacity(count);
        for _ in 0..count {
            result.push(self.draw_random(rng)?);
        }
        Ok(result)
    }

    /// Returns a randomly permuted copy of all items; weights are ignored.
    ///
    /// The table itself is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `NoValuesError` if the table contains no items.
    pub fn shuffle(&self, rng: &mut impl RngExt) -> Result<Vec<Item>, NoValuesError> {
        if self.items.is_empty() {
            return Err(NoValuesError {});
        }
        let mut copy = self.items.clone();
        // Fisher-Yates: walk from the last index down to 1, swapping each
        // position with a random position at or before it. The range has to
        // be reversed explicitly; a range whose start exceeds its end is
        // empty and would leave the copy in its original order.
        for i in (1..copy.len()).rev() {
            let j = rng.random_range(0..=i);
            copy.swap(i, j);
        }
        Ok(copy)
    }

    /// Returns `count` items taken from consecutive shuffles of the table.
    ///
    /// The whole table is shuffled as often as necessary and the rounds are
    /// concatenated, so every item appears once per complete round; the last
    /// round is cut short once `count` items have been collected.
    ///
    /// # Errors
    ///
    /// Returns `NoValuesError` if the table contains no items.
    pub fn shuffle_draw(
        &self,
        rng: &mut impl RngExt,
        count: usize,
    ) -> Result<Vec<Item>, NoValuesError> {
        if self.items.is_empty() {
            return Err(NoValuesError {});
        }
        // Exactly `count` items are produced, so that is all the room needed.
        let mut buffer: Vec<Item> = Vec::with_capacity(count);
        while buffer.len() < count {
            let remaining = count - buffer.len();
            buffer.extend(self.shuffle(rng)?.into_iter().take(remaining));
        }
        Ok(buffer)
    }

    /// Appends `item` to the table and adds its weight to the total.
    ///
    /// # Panics
    ///
    /// Panics if the weight of `item` is negative or NaN.
    pub fn add(&mut self, item: Item) {
        let w = item.weight;
        assert!(
            w >= 0.,
            "Invalid state: item weight must be a positive real number"
        );
        // All earlier items are known to be equal to each other while the
        // flag is set, so comparing against the most recent one is enough.
        if let Some(last) = self.items.last() {
            self.equal_weights = self.equal_weights && last.weight == w;
        }
        self.total += w;
        self.items.push(item);
    }

    /// Builds an `Item` from `text` and `weight` and adds it to the table.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is negative or NaN (see `add`).
    pub fn add_item<T>(&mut self, text: T, weight: f64)
    where
        T: Into<String>,
    {
        self.add(Item {
            text: text.into(),
            weight,
        })
    }

    /// Removes every item whose text equals `text`.
    ///
    /// The total weight is recomputed afterwards. Returns `true` if at least
    /// one item was removed.
    pub fn remove_item<T>(&mut self, text: T) -> bool
    where
        T: Into<String>,
    {
        let text = text.into();
        let before = self.items.len();
        // Single pass; the relative order of the remaining items is kept.
        self.items.retain(|item| item.text != text);
        self.recount();
        self.items.len() != before
    }

    /// Recomputes the cached total weight from the items currently stored.
    fn recount(&mut self) {
        // Explicit 0.0 seed so that an emptied table reports a positive zero.
        self.total = self.items.iter().fold(0f64, |sum, item| sum + item.weight);
    }
}
#[cfg(test)]
mod unit_tests {
    use crate::data::{Item, LookUpTable};

    /// Checks that the cached total weight follows additions and removals.
    #[test]
    fn weight_check() {
        let label = "test";
        let weight = 0.5f64;

        let first = Item {
            text: label.to_string(),
            weight,
        };
        assert_eq!(first.get_weight(), weight);

        let mut table = LookUpTable::new();
        assert_eq!(table.total, 0f64);

        table.add(first);
        assert_eq!(table.total, weight);

        table.add_item("test2", weight);
        assert_eq!(table.total, weight + weight);

        // The first removal finds the item, the second has nothing left to remove.
        assert!(table.remove_item(label));
        assert!(!table.remove_item(label));
        assert_eq!(table.total, weight);
    }
}