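//! Item selection for `aria_core` (`selector.rs`): scores eligible items
//! through a factor pipeline, applies exploration noise, and returns the best.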

use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::error::AriaError;
use crate::factor::Factor;
use crate::item::{Item, Scoreable};
use crate::state::ProfileState;

/// Largest number of eligible items that `Selector::select` still handles
/// with a linear scan; inputs with more items than this use the heap path.
const HEAP_THRESHOLD: usize = 500;

12/// Wraps an item + score for heap ordering.
13struct ScoredItem<'a> {
14    item: &'a Item,
15    score: f32,
16}
17
18impl<'a> PartialEq for ScoredItem<'a> {
19    fn eq(&self, other: &Self) -> bool {
20        self.score == other.score
21    }
22}
23
24impl<'a> Eq for ScoredItem<'a> {}
25
26impl<'a> PartialOrd for ScoredItem<'a> {
27    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
28        Some(self.cmp(other))
29    }
30}
31
32impl<'a> Ord for ScoredItem<'a> {
33    fn cmp(&self, other: &Self) -> Ordering {
34        self.score.partial_cmp(&other.score).unwrap_or(Ordering::Equal)
35    }
36}
37
38/// Selector — scores eligible items, applies exploration noise, returns best.
39pub struct Selector {
40    /// Fraction of noise added to score. 0.0 = deterministic. Default 0.05.
41    pub exploration_rate: f32,
42    /// Seed for deterministic noise (useful in tests).
43    rng_state: u64,
44}
45
46impl Selector {
47    pub fn new(exploration_rate: f32) -> Self {
48        Self {
49            exploration_rate,
50            rng_state: 42,
51        }
52    }
53
54    /// Select best item from eligible list using factor pipeline scores.
55    /// Returns reference to winning item.
56    pub fn select<'a>(
57        &mut self,
58        eligible: &[&'a Item],
59        factors: &[Box<dyn Factor>],
60        state: &ProfileState,
61        now: u64,
62    ) -> Result<&'a Item, AriaError> {
63        if eligible.is_empty() {
64            return Err(AriaError::NoEligibleItems);
65        }
66        if factors.is_empty() {
67            return Err(AriaError::NoFactors);
68        }
69
70        if eligible.len() <= HEAP_THRESHOLD {
71            self.select_linear(eligible, factors, state, now)
72        } else {
73            self.select_heap(eligible, factors, state, now)
74        }
75    }
76
77    fn compute_score(
78        &mut self,
79        item: &dyn Scoreable,
80        factors: &[Box<dyn Factor>],
81        state: &ProfileState,
82        now: u64,
83    ) -> f32 {
84        let base: f32 = factors.iter().map(|f| f.score(item, state, now)).product();
85        let noise = self.next_noise();
86        base * (1.0 + noise)
87    }
88
89    fn select_linear<'a>(
90        &mut self,
91        eligible: &[&'a Item],
92        factors: &[Box<dyn Factor>],
93        state: &ProfileState,
94        now: u64,
95    ) -> Result<&'a Item, AriaError> {
96        let mut best_score = f32::NEG_INFINITY;
97        let mut best_item = eligible[0];
98
99        for &item in eligible {
100            let score = self.compute_score(item, factors, state, now);
101            if score > best_score {
102                best_score = score;
103                best_item = item;
104            }
105        }
106        Ok(best_item)
107    }
108
109    fn select_heap<'a>(
110        &mut self,
111        eligible: &[&'a Item],
112        factors: &[Box<dyn Factor>],
113        state: &ProfileState,
114        now: u64,
115    ) -> Result<&'a Item, AriaError> {
116        let mut heap = BinaryHeap::with_capacity(eligible.len());
117
118        for &item in eligible {
119            let score = self.compute_score(item, factors, state, now);
120            heap.push(ScoredItem { item, score });
121        }
122
123        Ok(heap.pop().unwrap().item)
124    }
125
126    /// Minimal xorshift64 PRNG — no external deps, deterministic given seed.
127    fn next_noise(&mut self) -> f32 {
128        if self.exploration_rate == 0.0 {
129            return 0.0;
130        }
131        self.rng_state ^= self.rng_state << 13;
132        self.rng_state ^= self.rng_state >> 7;
133        self.rng_state ^= self.rng_state << 17;
134        let norm = (self.rng_state as f32) / (u64::MAX as f32);
135        norm.abs() * self.exploration_rate
136    }
137
138    /// Re-seed the RNG — useful in tests for determinism.
139    pub fn seed(&mut self, seed: u64) {
140        self.rng_state = if seed == 0 { 1 } else { seed };
141    }
142}
143
144impl Default for Selector {
145    fn default() -> Self {
146        Self::new(0.05)
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::{ChallengeFactor, SpacingFactor, CoverageFactor};
    use crate::item::Item;
    use crate::state::ProfileState;

    /// Factor pipeline shared by the tests below.
    fn factors() -> Vec<Box<dyn Factor>> {
        vec![
            Box::new(ChallengeFactor::default()),
            Box::new(SpacingFactor::default()),
            Box::new(CoverageFactor),
        ]
    }

    #[test]
    fn deterministic_with_zero_exploration() {
        let mut selector = Selector::new(0.0);
        let items = vec![
            Item::new("easy", 0.1, "cat"),
            Item::new("target", 0.6, "cat"),
            Item::new("hard", 0.9, "cat"),
        ];
        let eligible: Vec<&Item> = items.iter().collect();
        let mut state = ProfileState::new();
        state.skill = 0.5;
        state.optimism_bias = 0.1;

        let first = selector.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
        let second = selector.select(&eligible, &factors(), &state, 0).unwrap().id().to_string();
        assert_eq!(first, second);
        assert_eq!(first, "target");
    }

    #[test]
    fn no_eligible_items_returns_error() {
        let mut selector = Selector::new(0.0);
        let empty: Vec<&Item> = vec![];
        let state = ProfileState::new();
        let result = selector.select(&empty, &factors(), &state, 0);
        assert_eq!(result.unwrap_err(), AriaError::NoEligibleItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let mut selector = Selector::new(0.0);
        let items = vec![Item::new("x", 0.5, "cat")];
        let eligible: Vec<&Item> = items.iter().collect();
        let state = ProfileState::new();
        let result = selector.select(&eligible, &[], &state, 0);
        assert_eq!(result.unwrap_err(), AriaError::NoFactors);
    }

    /// Inputs larger than `HEAP_THRESHOLD` take the heap path; given the same
    /// noise sequence it must choose exactly the item the linear scan would.
    #[test]
    fn heap_and_linear_paths_agree() {
        let difficulties = [0.1, 0.3, 0.5, 0.6, 0.9];
        let items: Vec<Item> = (0..HEAP_THRESHOLD + 100)
            .map(|i| Item::new("item", difficulties[i % difficulties.len()], "cat"))
            .collect();
        let eligible: Vec<&Item> = items.iter().collect();
        let mut state = ProfileState::new();
        state.skill = 0.5;
        state.optimism_bias = 0.1;

        for rate in [0.0, 0.3] {
            // Fresh selectors start from the same RNG state, so both paths
            // see identical noise.
            let mut linear = Selector::new(rate);
            let mut heap = Selector::new(rate);
            let expected = linear.select_linear(&eligible, &factors(), &state, 0).unwrap();
            let actual = heap.select_heap(&eligible, &factors(), &state, 0).unwrap();
            assert!(std::ptr::eq(expected, actual));
        }
    }
}