capybara_util/
weighted.rs

/// A set of resources from which one can be drawn at random, each with probability proportional
/// to its weight.
///
/// Each stored entry pairs a resource with the *cumulative* weight of all entries up to and
/// including it, so the last entry holds the total weight. Construct instances with
/// `WeightedResource::builder()`.
#[derive(Debug, Clone)]
pub struct WeightedResource<T>(Vec<(u64, T)>);
2
3impl<T> WeightedResource<T> {
4    pub fn builder() -> WeightedResourceBuilder<T> {
5        WeightedResourceBuilder { inner: vec![] }
6    }
7
8    pub fn next(&self) -> Option<&T> {
9        match self.0.last() {
10            None => None,
11            Some((max, _)) => {
12                let seed = {
13                    use rand::prelude::*;
14
15                    let mut rng = thread_rng();
16                    rng.gen_range(0..*max)
17                };
18
19                let idx = match self.0.binary_search_by_key(&seed, |&(weight, _)| weight) {
20                    Ok(i) => i + 1,
21                    Err(i) => i,
22                };
23
24                self.0.get(idx).map(|(_, u)| u)
25            }
26        }
27    }
28}
29
/// Incrementally assembles a `WeightedResource`.
///
/// Obtain one from `WeightedResource::builder()`, add entries with `push`, then call `build`.
#[must_use = "a WeightedResourceBuilder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct WeightedResourceBuilder<T> {
    /// Each resource paired with the cumulative weight of all entries up to and including it.
    inner: Vec<(u64, T)>,
}
33
34impl<T> WeightedResourceBuilder<T> {
35    pub fn push(self, weight: u32, resource: T) -> Self {
36        let Self { mut inner } = self;
37
38        if weight != 0 {
39            let w = weight as u64 + inner.last().map(|(w, _)| *w).unwrap_or_default();
40            inner.push((w, resource));
41        }
42
43        Self { inner }
44    }
45
46    pub fn build(self) -> WeightedResource<T> {
47        WeightedResource(self.inner)
48    }
49}
50
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    fn init() {
        pretty_env_logger::try_init_timed().ok();
    }

    /// Draws many samples and checks that each resource's share matches its weight.
    #[test]
    fn test_weighted_resource() {
        init();

        let wr = WeightedResource::<&str>::builder()
            .push(2, "foo")
            .push(3, "bar")
            .push(5, "qux")
            .build();

        const N: usize = 1_000_000;
        let mut cnts: HashMap<&str, usize> = HashMap::new();
        for _ in 0..N {
            *cnts.entry(*wr.next().unwrap()).or_default() += 1;
        }

        assert_eq!(cnts.len(), 3, "unexpected resources drawn: {:?}", cnts);

        // Expected share = weight / total weight (10). With N = 1e6, the standard deviation of
        // each share is at most 0.05 percentage points, so a 1-point tolerance is ~20 sigma.
        for &(k, expected) in &[("foo", 20.0f64), ("bar", 30.0), ("qux", 50.0)] {
            let count = cnts.get(&k).copied().unwrap_or(0);
            let pct = 100f64 * (count as f64) / (N as f64);
            println!("{}: {:.2}%", k, pct);
            assert!(
                (pct - expected).abs() < 1.0,
                "{}: got {:.2}%, expected {:.2}%",
                k,
                pct,
                expected
            );
        }
    }

    /// Empty and all-zero sets yield nothing; zero-weight entries are never drawn.
    #[test]
    fn test_empty_and_zero_weights() {
        init();

        let empty = WeightedResource::<&str>::builder().build();
        assert!(empty.next().is_none());

        let only_zero = WeightedResource::<&str>::builder().push(0, "never").build();
        assert!(only_zero.next().is_none());

        let wr = WeightedResource::<&str>::builder()
            .push(0, "never")
            .push(1, "always")
            .push(0, "never")
            .build();
        for _ in 0..1000 {
            assert_eq!(wr.next(), Some(&"always"));
        }
    }
}