rantz_random/
weighted_table.rs1use crate::random_traits::RandomWeightedContainer;
2use std::{marker::PhantomData, slice::IterMut};
3
/// A table of distinct values, each paired with a `u32` weight.
///
/// Entries are stored in two parallel vectors: `weights[i]` is the weight of `values[i]`, and
/// entries keep their insertion order. `total_weight` caches the sum of the weights; it is kept
/// in sync by `insert`, `remove`, `clear` and `combine`, but NOT when a weight is changed through
/// `get_weight_mut`, `get_entry_mut` or `iter_mut`.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself, so the type can be
/// named in signatures without restating `PartialEq + Clone`.
#[derive(Clone, Debug)]
pub struct WeightedTable<T> {
    /// Weight of each entry; parallel to `values`.
    pub(crate) weights: Vec<u32>,
    /// Cached sum of `weights` (see the type-level note about mutable accessors).
    pub(crate) total_weight: u32,
    /// The entries themselves; `insert` keeps them distinct.
    pub(crate) values: Vec<T>,
}
58
/// An owned `(value, weight)` pair, as returned by value-producing lookups.
pub type WeightedItem<T> = (T, u32);
/// A shared-borrow `(value, weight)` pair, as yielded by `iter`.
pub type WeightedItemRef<'a, T> = (&'a T, &'a u32);
/// A mutable-borrow `(value, weight)` pair, as yielded by `iter_mut`.
pub type WeightedItemRefMut<'a, T> = (&'a mut T, &'a mut u32);
62
63impl<T> Default for WeightedTable<T>
64where
65 T: PartialEq + Clone,
66{
67 fn default() -> Self {
68 Self {
69 weights: Vec::<u32>::new(),
70 total_weight: 0,
71 values: Vec::<T>::new(),
72 }
73 }
74}
75
76impl<T> WeightedTable<T>
77where
78 T: PartialEq + Clone,
79{
80 pub fn new() -> Self {
82 Default::default()
83 }
84
85 pub fn from_vec(vec: Vec<(T, u32)>) -> Self {
88 let mut table = Self::new();
89 for (value, weight) in vec {
90 table.insert(value, weight);
91 }
92 table
93 }
94
95 pub fn insert(&mut self, value: T, weight: u32) {
97 if let Some(index) = self.values.iter().position(|v| v == &value) {
98 let old_weight = self.weights[index];
99 self.weights[index] = weight;
100 self.total_weight += weight;
101 self.total_weight -= old_weight;
102 return;
103 }
104
105 self.weights.push(weight);
106 self.total_weight += weight;
107 self.values.push(value);
108 }
109
110 pub fn remove(&mut self, value: &T) {
112 if let Some(index) = self.values.iter().position(|v| v == value) {
113 self.total_weight -= self.weights[index];
114 self.weights.remove(index);
115 self.values.remove(index);
116 }
117 }
118
119 pub fn clear(&mut self) {
121 self.weights.clear();
122 self.total_weight = 0;
123 self.values.clear();
124 }
125
126 pub fn get_entry(&self, index: usize) -> Option<WeightedItem<T>> {
128 if index < self.values.len() {
129 Some((self.values[index].clone(), self.weights[index]))
130 } else {
131 None
132 }
133 }
134
135 pub fn get_entry_ref(&self, index: usize) -> Option<WeightedItemRef<T>> {
137 if index < self.values.len() {
138 Some((&self.values[index], &self.weights[index]))
139 } else {
140 None
141 }
142 }
143
144 pub fn get_entry_mut(&mut self, index: usize) -> Option<WeightedItemRefMut<T>> {
146 if index < self.values.len() {
147 Some((&mut self.values[index], &mut self.weights[index]))
148 } else {
149 None
150 }
151 }
152
153 pub fn get_weight(&self, value: &T) -> Option<u32> {
155 self.values
156 .iter()
157 .position(|v| v == value)
158 .map(|i| self.weights[i])
159 }
160
161 pub fn get_weight_mut(&mut self, value: &T) -> Option<&mut u32> {
163 if let Some(index) = self.values.iter().position(|v| v == value) {
164 Some(&mut self.weights[index])
165 } else {
166 None
167 }
168 }
169
170 pub fn random_with(&self, n: u32) -> WeightedItem<T> {
172 let mut n = n;
173 for (i, weight) in self.weights.iter().enumerate() {
174 if &n <= weight {
175 return self.get_entry(i).unwrap();
176 }
177 n -= weight;
178 }
179 unreachable!();
180 }
181
182 pub fn iter(&self) -> impl Iterator<Item = WeightedItemRef<T>> {
184 WeightedTableIter {
185 table: self,
186 index: 0,
187 size: self.values.len(),
188 }
189 }
190
191 pub fn iter_mut(&mut self) -> impl Iterator<Item = WeightedItemRefMut<T>> {
193 WeightedTableIterMut {
194 value_iter: self.values.iter_mut(),
195 weight_iter: self.weights.iter_mut(),
196 marker: PhantomData,
197 }
198 }
199
200 pub fn combine(&mut self, other: Self) {
202 self.total_weight += other.total_weight;
203 for (v, w) in other.iter() {
204 if let Some(index) = self.values.iter().position(|x| x == v) {
205 self.weights[index] += w;
206 } else {
207 self.weights.push(*w);
208 self.values.push(v.clone());
209 }
210 }
211 }
212}
213
214impl<T> IntoIterator for WeightedTable<T>
215where
216 T: PartialEq + Clone,
217{
218 type Item = T;
219 type IntoIter = WeightedTableTupleIntoIter<T>;
220
221 fn into_iter(self) -> Self::IntoIter {
222 let size = self.values.len();
223 WeightedTableTupleIntoIter {
224 table: self,
225 index: 0,
226 size,
227 }
228 }
229}
230
231pub struct WeightedTableTupleIntoIter<T>
232where
233 T: PartialEq + Clone,
234{
235 table: WeightedTable<T>,
236 index: usize,
237 size: usize,
238}
239
240impl<T> Iterator for WeightedTableTupleIntoIter<T>
241where
242 T: PartialEq + Clone,
243{
244 type Item = T;
245
246 fn next(&mut self) -> Option<Self::Item> {
247 if self.index < self.size {
248 let value = self.table.values[self.index].clone();
249 self.index += 1;
250 Some(value)
251 } else {
252 None
253 }
254 }
255}
256
257pub struct WeightedTableIter<'a, T>
258where
259 T: PartialEq + Clone,
260{
261 table: &'a WeightedTable<T>,
262 index: usize,
263 size: usize,
264}
265
266impl<'a, T> Iterator for WeightedTableIter<'a, T>
267where
268 T: PartialEq + Clone,
269{
270 type Item = WeightedItemRef<'a, T>;
271
272 fn next(&mut self) -> Option<Self::Item> {
273 if self.index < self.size {
274 let value = &self.table.values[self.index];
275 let weight = &self.table.weights[self.index];
276 self.index += 1;
277 Some((value, weight))
278 } else {
279 None
280 }
281 }
282}
283
/// Mutable iterator over the `(value, weight)` pairs of a [`WeightedTable`], created by
/// `WeightedTable::iter_mut`.
///
/// It walks the value and weight vectors side by side. Writing to a weight through it does not
/// touch the table's cached total weight.
pub struct WeightedTableIterMut<'a, T: PartialEq + Clone> {
    /// Mutable iterator over the table's values.
    value_iter: IterMut<'a, T>,
    /// Mutable iterator over the table's weights, advanced together with `value_iter`.
    weight_iter: IterMut<'a, u32>,
    // NOTE(review): `IterMut<'a, T>` already ties this type to `'a` and `T`, so this marker
    // looks redundant; it is kept because `WeightedTable::iter_mut` constructs it.
    marker: PhantomData<&'a mut T>,
}
292
293impl<'a, T> Iterator for WeightedTableIterMut<'a, T>
294where
295 T: PartialEq + Clone,
296{
297 type Item = WeightedItemRefMut<'a, T>;
298
299 fn next(&mut self) -> Option<Self::Item> {
300 if let (Some(value), Some(weight)) = (self.value_iter.next(), self.weight_iter.next()) {
301 Some((value, weight))
302 } else {
303 None
304 }
305 }
306}
307
308impl<T> FromIterator<(T, u32)> for WeightedTable<T>
309where
310 T: PartialEq + Clone,
311{
312 fn from_iter<I: IntoIterator<Item = (T, u32)>>(iter: I) -> Self {
313 let mut table = WeightedTable::new();
314 for (value, weight) in iter {
315 table.insert(value, weight);
316 }
317 table
318 }
319}
320
321impl<'a, T> FromIterator<(T, &'a u32)> for WeightedTable<T>
322where
323 T: PartialEq + Clone,
324{
325 fn from_iter<I: IntoIterator<Item = (T, &'a u32)>>(iter: I) -> Self {
326 let mut table = WeightedTable::new();
327 for (value, weight) in iter {
328 table.insert(value, *weight);
329 }
330 table
331 }
332}
333
334impl<T> RandomWeightedContainer<T> for WeightedTable<T>
335where
336 T: Clone + PartialEq,
337{
338 fn max_weight(&self) -> u32 {
339 self.total_weight
340 }
341
342 fn weights(&self) -> &Vec<u32> {
343 &self.weights
344 }
345
346 fn values(&self) -> &Vec<T> {
347 &self.values
348 }
349}