rantz_random/
weighted_table.rs

1use crate::random_traits::RandomWeightedContainer;
2use std::{marker::PhantomData, slice::IterMut};
3
/// `WeightedTable`
///
/// A simple weighted table, designed to be used with a random number generator.
/// Every stored value is paired with a `u32` weight. It can be used as follows:
///
/// ```rust
/// use rantz_random::prelude::{WeightedTable, RandomContainer};
///
/// let mut table = WeightedTable::new();
///
/// table.insert("Bob", 10); // Bob has a weight of 10
/// table.insert("Alice", 20); // Alice has a weight of 20
/// table.remove(&"Bob"); // Bob is removed
///
/// table.random(); // Returns a random element from the table based on the weights
/// ```
///
/// The table is iterable: iterating it by value consumes the table and yields the
/// owned values (without their weights).
/// For `(value, weight)` references, use the [`iter`](WeightedTable::iter) method.
/// For mutable `(value, weight)` references, use the [`iter_mut`](WeightedTable::iter_mut)
/// method.
///
/// ```rust
/// use rantz_random::prelude::WeightedTable;
///
/// let mut table = WeightedTable::new();
///
/// table.insert("Bob".to_string(), 10); // Bob has a weight of 10
/// table.insert("Alice".to_string(), 20); // Alice has a weight of 20
///
/// for (value, weight) in table.iter() {
///   println!("Reference {} has a weight of {}", value, weight);
/// }
///
/// for (value, weight) in table.iter_mut() {
///   *value += " Test";
///   *weight += 10;
///   println!("Reference {} has a weight of {}", value, weight);
/// }
///
/// for value in table { // Must be done last as consumes the table
///   println!("Owned {}", value);
/// }
/// ```
///
/// Inserting a value that already exists replaces the weight of the existing entry;
/// the entry keeps its original position.
#[derive(Clone, Debug)]
pub struct WeightedTable<T>
where
    T: PartialEq + Clone,
{
    /// Weight of each entry; `weights[i]` belongs to `values[i]`.
    pub(crate) weights: Vec<u32>,
    /// Cached sum of `weights`, updated by the table's own mutating methods
    /// (not by changes made through handed-out `&mut` weight references).
    pub(crate) total_weight: u32,
    /// The stored values, in insertion order.
    pub(crate) values: Vec<T>,
}
58
/// An owned `(value, weight)` entry of a `WeightedTable`.
pub type WeightedItem<T> = (T, u32);
/// A borrowed `(&value, &weight)` entry of a `WeightedTable`.
pub type WeightedItemRef<'a, T> = (&'a T, &'a u32);
/// A mutably borrowed `(&mut value, &mut weight)` entry of a `WeightedTable`.
pub type WeightedItemRefMut<'a, T> = (&'a mut T, &'a mut u32);
62
63impl<T> Default for WeightedTable<T>
64where
65    T: PartialEq + Clone,
66{
67    fn default() -> Self {
68        Self {
69            weights: Vec::<u32>::new(),
70            total_weight: 0,
71            values: Vec::<T>::new(),
72        }
73    }
74}
75
76impl<T> WeightedTable<T>
77where
78    T: PartialEq + Clone,
79{
80    /// Creates a new empty `WeightedTable`
81    pub fn new() -> Self {
82        Default::default()
83    }
84
85    /// Creates a new `WeightedTable` from a vector of tuples
86    /// of (value, weight)
87    pub fn from_vec(vec: Vec<(T, u32)>) -> Self {
88        let mut table = Self::new();
89        for (value, weight) in vec {
90            table.insert(value, weight);
91        }
92        table
93    }
94
95    /// Inserts a new element into the table
96    pub fn insert(&mut self, value: T, weight: u32) {
97        if let Some(index) = self.values.iter().position(|v| v == &value) {
98            let old_weight = self.weights[index];
99            self.weights[index] = weight;
100            self.total_weight += weight;
101            self.total_weight -= old_weight;
102            return;
103        }
104
105        self.weights.push(weight);
106        self.total_weight += weight;
107        self.values.push(value);
108    }
109
110    /// Removes an element from the table
111    pub fn remove(&mut self, value: &T) {
112        if let Some(index) = self.values.iter().position(|v| v == value) {
113            self.total_weight -= self.weights[index];
114            self.weights.remove(index);
115            self.values.remove(index);
116        }
117    }
118
119    /// Clears the table
120    pub fn clear(&mut self) {
121        self.weights.clear();
122        self.total_weight = 0;
123        self.values.clear();
124    }
125
126    /// Returns the entry at the specified index
127    pub fn get_entry(&self, index: usize) -> Option<WeightedItem<T>> {
128        if index < self.values.len() {
129            Some((self.values[index].clone(), self.weights[index]))
130        } else {
131            None
132        }
133    }
134
135    /// Returns the entry at the specified index as a reference
136    pub fn get_entry_ref(&self, index: usize) -> Option<WeightedItemRef<T>> {
137        if index < self.values.len() {
138            Some((&self.values[index], &self.weights[index]))
139        } else {
140            None
141        }
142    }
143
144    /// Returns the entry at the specified index as a mutable reference
145    pub fn get_entry_mut(&mut self, index: usize) -> Option<WeightedItemRefMut<T>> {
146        if index < self.values.len() {
147            Some((&mut self.values[index], &mut self.weights[index]))
148        } else {
149            None
150        }
151    }
152
153    /// Returns the weight of the specified value
154    pub fn get_weight(&self, value: &T) -> Option<u32> {
155        self.values
156            .iter()
157            .position(|v| v == value)
158            .map(|i| self.weights[i])
159    }
160
161    /// Returns the weight of the specified value as a mutable reference
162    pub fn get_weight_mut(&mut self, value: &T) -> Option<&mut u32> {
163        if let Some(index) = self.values.iter().position(|v| v == value) {
164            Some(&mut self.weights[index])
165        } else {
166            None
167        }
168    }
169
170    /// Returns a random element with the specified weight
171    pub fn random_with(&self, n: u32) -> WeightedItem<T> {
172        let mut n = n;
173        for (i, weight) in self.weights.iter().enumerate() {
174            if &n <= weight {
175                return self.get_entry(i).unwrap();
176            }
177            n -= weight;
178        }
179        unreachable!();
180    }
181
182    /// Returns an iterator over the table
183    pub fn iter(&self) -> impl Iterator<Item = WeightedItemRef<T>> {
184        WeightedTableIter {
185            table: self,
186            index: 0,
187            size: self.values.len(),
188        }
189    }
190
191    /// Returns an iterator over the table mutably
192    pub fn iter_mut(&mut self) -> impl Iterator<Item = WeightedItemRefMut<T>> {
193        WeightedTableIterMut {
194            value_iter: self.values.iter_mut(),
195            weight_iter: self.weights.iter_mut(),
196            marker: PhantomData,
197        }
198    }
199
200    /// Combines two tables (in place)
201    pub fn combine(&mut self, other: Self) {
202        self.total_weight += other.total_weight;
203        for (v, w) in other.iter() {
204            if let Some(index) = self.values.iter().position(|x| x == v) {
205                self.weights[index] += w;
206            } else {
207                self.weights.push(*w);
208                self.values.push(v.clone());
209            }
210        }
211    }
212}
213
214impl<T> IntoIterator for WeightedTable<T>
215where
216    T: PartialEq + Clone,
217{
218    type Item = T;
219    type IntoIter = WeightedTableTupleIntoIter<T>;
220
221    fn into_iter(self) -> Self::IntoIter {
222        let size = self.values.len();
223        WeightedTableTupleIntoIter {
224            table: self,
225            index: 0,
226            size,
227        }
228    }
229}
230
231pub struct WeightedTableTupleIntoIter<T>
232where
233    T: PartialEq + Clone,
234{
235    table: WeightedTable<T>,
236    index: usize,
237    size: usize,
238}
239
240impl<T> Iterator for WeightedTableTupleIntoIter<T>
241where
242    T: PartialEq + Clone,
243{
244    type Item = T;
245
246    fn next(&mut self) -> Option<Self::Item> {
247        if self.index < self.size {
248            let value = self.table.values[self.index].clone();
249            self.index += 1;
250            Some(value)
251        } else {
252            None
253        }
254    }
255}
256
257pub struct WeightedTableIter<'a, T>
258where
259    T: PartialEq + Clone,
260{
261    table: &'a WeightedTable<T>,
262    index: usize,
263    size: usize,
264}
265
266impl<'a, T> Iterator for WeightedTableIter<'a, T>
267where
268    T: PartialEq + Clone,
269{
270    type Item = WeightedItemRef<'a, T>;
271
272    fn next(&mut self) -> Option<Self::Item> {
273        if self.index < self.size {
274            let value = &self.table.values[self.index];
275            let weight = &self.table.weights[self.index];
276            self.index += 1;
277            Some((value, weight))
278        } else {
279            None
280        }
281    }
282}
283
284pub struct WeightedTableIterMut<'a, T>
285where
286    T: PartialEq + Clone,
287{
288    value_iter: IterMut<'a, T>,
289    weight_iter: IterMut<'a, u32>,
290    marker: PhantomData<&'a mut T>,
291}
292
293impl<'a, T> Iterator for WeightedTableIterMut<'a, T>
294where
295    T: PartialEq + Clone,
296{
297    type Item = WeightedItemRefMut<'a, T>;
298
299    fn next(&mut self) -> Option<Self::Item> {
300        if let (Some(value), Some(weight)) = (self.value_iter.next(), self.weight_iter.next()) {
301            Some((value, weight))
302        } else {
303            None
304        }
305    }
306}
307
308impl<T> FromIterator<(T, u32)> for WeightedTable<T>
309where
310    T: PartialEq + Clone,
311{
312    fn from_iter<I: IntoIterator<Item = (T, u32)>>(iter: I) -> Self {
313        let mut table = WeightedTable::new();
314        for (value, weight) in iter {
315            table.insert(value, weight);
316        }
317        table
318    }
319}
320
321impl<'a, T> FromIterator<(T, &'a u32)> for WeightedTable<T>
322where
323    T: PartialEq + Clone,
324{
325    fn from_iter<I: IntoIterator<Item = (T, &'a u32)>>(iter: I) -> Self {
326        let mut table = WeightedTable::new();
327        for (value, weight) in iter {
328            table.insert(value, *weight);
329        }
330        table
331    }
332}
333
334impl<T> RandomWeightedContainer<T> for WeightedTable<T>
335where
336    T: Clone + PartialEq,
337{
338    fn max_weight(&self) -> u32 {
339        self.total_weight
340    }
341
342    fn weights(&self) -> &Vec<u32> {
343        &self.weights
344    }
345
346    fn values(&self) -> &Vec<T> {
347        &self.values
348    }
349}