1use std::collections::BinaryHeap;
2use std::cmp::Ordering;
3
4use crate::item::{Item, Scoreable};
5use crate::factor::Factor;
6use crate::state::ProfileState;
7use crate::error::AriaError;
8
9const HEAP_THRESHOLD: usize = 500;
11
12struct ScoredItem<'a> {
14 item: &'a Item,
15 score: f32,
16}
17
18impl<'a> PartialEq for ScoredItem<'a> {
19 fn eq(&self, other: &Self) -> bool {
20 self.score == other.score
21 }
22}
23
24impl<'a> Eq for ScoredItem<'a> {}
25
26impl<'a> PartialOrd for ScoredItem<'a> {
27 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
28 Some(self.cmp(other))
29 }
30}
31
32impl<'a> Ord for ScoredItem<'a> {
33 fn cmp(&self, other: &Self) -> Ordering {
34 self.score.partial_cmp(&other.score).unwrap_or(Ordering::Equal)
35 }
36}
37
/// Chooses one item from an eligible set by multiplying per-factor scores together
/// and, optionally, perturbing each product with pseudo-random noise.
pub struct Selector {
    /// Xorshift generator state that drives the noise. It starts at 42 and `seed`
    /// never stores zero.
    rng_state: u64,
    /// Scale of the multiplicative noise: each score is multiplied by
    /// `1 + u * exploration_rate` with `u` in `(0, 1]`. `0.0` disables noise entirely.
    pub exploration_rate: f32,
}
45
46impl Selector {
47 pub fn new(exploration_rate: f32) -> Self {
48 Self {
49 exploration_rate,
50 rng_state: 42,
51 }
52 }
53
54 pub fn select<'a>(
57 &mut self,
58 eligible: &[&'a Item],
59 factors: &[Box<dyn Factor>],
60 state: &ProfileState,
61 now: u64,
62 ) -> Result<&'a Item, AriaError> {
63 if eligible.is_empty() {
64 return Err(AriaError::NoEligibleItems);
65 }
66 if factors.is_empty() {
67 return Err(AriaError::NoFactors);
68 }
69
70 if eligible.len() <= HEAP_THRESHOLD {
71 self.select_linear(eligible, factors, state, now)
72 } else {
73 self.select_heap(eligible, factors, state, now)
74 }
75 }
76
77 fn compute_score(
78 &mut self,
79 item: &dyn Scoreable,
80 factors: &[Box<dyn Factor>],
81 state: &ProfileState,
82 now: u64,
83 ) -> f32 {
84 let base: f32 = factors.iter().map(|f| f.score(item, state, now)).product();
85 let noise = self.next_noise();
86 base * (1.0 + noise)
87 }
88
89 fn select_linear<'a>(
90 &mut self,
91 eligible: &[&'a Item],
92 factors: &[Box<dyn Factor>],
93 state: &ProfileState,
94 now: u64,
95 ) -> Result<&'a Item, AriaError> {
96 let mut best_score = f32::NEG_INFINITY;
97 let mut best_item = eligible[0];
98
99 for &item in eligible {
100 let score = self.compute_score(item, factors, state, now);
101 if score > best_score {
102 best_score = score;
103 best_item = item;
104 }
105 }
106 Ok(best_item)
107 }
108
109 fn select_heap<'a>(
110 &mut self,
111 eligible: &[&'a Item],
112 factors: &[Box<dyn Factor>],
113 state: &ProfileState,
114 now: u64,
115 ) -> Result<&'a Item, AriaError> {
116 let mut heap = BinaryHeap::with_capacity(eligible.len());
117
118 for &item in eligible {
119 let score = self.compute_score(item, factors, state, now);
120 heap.push(ScoredItem { item, score });
121 }
122
123 Ok(heap.pop().unwrap().item)
124 }
125
126 fn next_noise(&mut self) -> f32 {
128 if self.exploration_rate == 0.0 {
129 return 0.0;
130 }
131 self.rng_state ^= self.rng_state << 13;
132 self.rng_state ^= self.rng_state >> 7;
133 self.rng_state ^= self.rng_state << 17;
134 let norm = (self.rng_state as f32) / (u64::MAX as f32);
135 norm.abs() * self.exploration_rate
136 }
137
138 pub fn seed(&mut self, seed: u64) {
140 self.rng_state = if seed == 0 { 1 } else { seed };
141 }
142}
143
144impl Default for Selector {
145 fn default() -> Self {
146 Self::new(0.05)
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::{ChallengeFactor, SpacingFactor, CoverageFactor};
    use crate::item::Item;
    use crate::state::ProfileState;

    /// The full factor set shared by every test.
    fn factors() -> Vec<Box<dyn Factor>> {
        vec![
            Box::new(ChallengeFactor::default()),
            Box::new(SpacingFactor::default()),
            Box::new(CoverageFactor),
        ]
    }

    /// Three items of increasing difficulty, all in the same category.
    fn three_items() -> Vec<Item> {
        vec![
            Item::new("easy", 0.1, "cat"),
            Item::new("target", 0.6, "cat"),
            Item::new("hard", 0.9, "cat"),
        ]
    }

    #[test]
    fn deterministic_with_zero_exploration() {
        let mut selector = Selector::new(0.0);
        let items = three_items();
        let eligible: Vec<&Item> = items.iter().collect();
        let mut state = ProfileState::new();
        state.skill = 0.5;
        state.optimism_bias = 0.1;

        let first = selector.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
        let second = selector.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
        assert_eq!(first, second);
        assert_eq!(first, "target");
    }

    #[test]
    fn no_eligible_items_returns_error() {
        let mut selector = Selector::new(0.0);
        let empty: Vec<&Item> = vec![];
        let state = ProfileState::new();
        let result = selector.select(&empty, &factors(), &state, 0);
        assert_eq!(result.unwrap_err(), AriaError::NoEligibleItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let mut selector = Selector::new(0.0);
        let items = vec![Item::new("x", 0.5, "cat")];
        let eligible: Vec<&Item> = items.iter().collect();
        let state = ProfileState::new();
        let result = selector.select(&eligible, &[], &state, 0);
        assert_eq!(result.unwrap_err(), AriaError::NoFactors);
    }

    #[test]
    fn large_eligible_set_takes_heap_path_and_returns_an_item() {
        let mut selector = Selector::new(0.0);
        let items: Vec<Item> = (0..=HEAP_THRESHOLD)
            .map(|_| Item::new("dup", 0.5, "cat"))
            .collect();
        let eligible: Vec<&Item> = items.iter().collect();
        // Make sure this test really exercises the path above the threshold.
        assert!(eligible.len() > HEAP_THRESHOLD);
        let state = ProfileState::new();

        let picked = selector.select(&eligible, &factors(), &state, 0).unwrap();
        assert_eq!(picked.id().to_string(), "dup");
    }

    #[test]
    fn same_seed_gives_same_choices_with_exploration() {
        let items = three_items();
        let eligible: Vec<&Item> = items.iter().collect();
        let state = ProfileState::new();

        let mut a = Selector::new(0.5);
        let mut b = Selector::new(0.5);
        a.seed(7);
        b.seed(7);
        for _ in 0..5 {
            let from_a = a.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
            let from_b = b.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
            assert_eq!(from_a, from_b);
        }
    }

    #[test]
    fn seed_zero_is_remapped_to_nonzero_state() {
        let mut selector = Selector::new(0.1);
        selector.seed(0);
        assert_eq!(selector.rng_state, 1);
        selector.seed(99);
        assert_eq!(selector.rng_state, 99);
    }

    #[test]
    fn noise_stays_within_exploration_rate() {
        let mut selector = Selector::new(0.2);
        for _ in 0..1_000 {
            let noise = selector.next_noise();
            assert!((0.0..=0.2).contains(&noise), "noise {} out of range", noise);
        }

        // With exploration disabled the RNG must not advance at all.
        let mut silent = Selector::new(0.0);
        assert_eq!(silent.next_noise(), 0.0);
        assert_eq!(silent.rng_state, 42);
    }
}