Skip to main content

loadwise_core/strategy/
random.rs

1use rand::Rng;
2
3use super::{SelectionContext, Strategy};
4use crate::Weighted;
5
/// Uniform random selection.
///
/// Every candidate that is not excluded by the [`SelectionContext`] has the same chance of
/// being chosen. The strategy holds no state, so it is trivially copyable and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Random;
9
10impl Random {
11    pub fn new() -> Self {
12        Self
13    }
14}
15
16impl Default for Random {
17    fn default() -> Self {
18        Self
19    }
20}
21
22impl<N> Strategy<N> for Random {
23    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
24        let eligible: Vec<usize> = (0..candidates.len())
25            .filter(|i| !ctx.is_excluded(*i))
26            .collect();
27        if eligible.is_empty() {
28            return None;
29        }
30        Some(eligible[rand::rng().random_range(0..eligible.len())])
31    }
32}
33
/// Weighted random selection. Nodes with higher weight are proportionally more likely.
///
/// A candidate's probability of being chosen is its weight divided by the total weight of
/// all candidates not excluded by the [`SelectionContext`]; zero-weight candidates are never
/// chosen. The strategy holds no state, so it is trivially copyable and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WeightedRandom;
37
38impl WeightedRandom {
39    pub fn new() -> Self {
40        Self
41    }
42}
43
44impl Default for WeightedRandom {
45    fn default() -> Self {
46        Self
47    }
48}
49
50impl<N: Weighted> Strategy<N> for WeightedRandom {
51    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
52        if candidates.is_empty() {
53            return None;
54        }
55
56        let eligible: Vec<usize> = (0..candidates.len())
57            .filter(|i| !ctx.is_excluded(*i))
58            .collect();
59        if eligible.is_empty() {
60            return None;
61        }
62
63        let total_weight: u64 = eligible.iter().map(|&i| candidates[i].weight() as u64).sum();
64        if total_weight == 0 {
65            return None;
66        }
67
68        let mut point = rand::rng().random_range(0..total_weight);
69
70        for &i in &eligible {
71            let w = candidates[i].weight() as u64;
72            if point < w {
73                return Some(i);
74            }
75            point -= w;
76        }
77
78        eligible.last().copied()
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Weighted` node whose weight is fixed at construction.
    struct W(u32);
    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.0
        }
    }

    #[test]
    fn random_empty_returns_none() {
        let s = Random::new();
        let nodes: [i32; 0] = [];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn random_all_excluded_returns_none() {
        let s = Random::new();
        let nodes = [1, 2];
        let ctx = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }

    #[test]
    fn random_respects_exclude() {
        let s = Random::new();
        let nodes = [1, 2];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        // Can only pick index 1
        for _ in 0..20 {
            assert_eq!(s.select(&nodes, &ctx), Some(1));
        }
    }

    #[test]
    fn random_single_candidate() {
        let s = Random::new();
        let nodes = [42];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), Some(0));
    }

    #[test]
    fn random_reaches_every_eligible_index() {
        let s = Random::new();
        let nodes = [10, 20, 30, 40];
        let ctx = SelectionContext::builder().exclude(vec![2]).build();
        let mut seen = [false; 4];
        // With 3 eligible indices, missing one in 300 draws has probability ~1e-52.
        for _ in 0..300 {
            let i = s.select(&nodes, &ctx).expect("three indices are eligible");
            assert_ne!(i, 2);
            seen[i] = true;
        }
        assert_eq!(seen, [true, true, false, true]);
    }

    #[test]
    fn weighted_random_empty_returns_none() {
        let s = WeightedRandom::new();
        let nodes: [W; 0] = [];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn weighted_random_zero_weight_returns_none() {
        let s = WeightedRandom::new();
        let nodes = [W(0), W(0)];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn weighted_random_all_excluded_returns_none() {
        let s = WeightedRandom::new();
        let nodes = [W(1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }

    #[test]
    fn weighted_random_respects_exclude() {
        let s = WeightedRandom::new();
        let nodes = [W(10), W(1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        for _ in 0..20 {
            assert_eq!(s.select(&nodes, &ctx), Some(1));
        }
    }

    #[test]
    fn weighted_random_never_picks_zero_weight() {
        let s = WeightedRandom::new();
        let nodes = [W(0), W(3), W(0)];
        for _ in 0..50 {
            assert_eq!(s.select(&nodes, &SelectionContext::default()), Some(1));
        }
    }

    #[test]
    fn weighted_random_only_zero_weight_eligible_returns_none() {
        let s = WeightedRandom::new();
        let nodes = [W(5), W(0)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }
}