use std::collections::BinaryHeap;
use std::cmp::Ordering;
use crate::item::{Item, Scoreable};
use crate::factor::Factor;
use crate::state::ProfileState;
use crate::error::AriaError;
/// Candidate-list size at which `Selector::select` stops using the linear
/// scan: lists with at most this many items are scanned linearly, longer
/// lists go through the heap-based path.
const HEAP_THRESHOLD: usize = 500;
/// An eligible item paired with the score it received in one selection pass.
///
/// Ordering looks only at `score` and is a lawful total order, so the type can
/// safely be stored in a [`BinaryHeap`]: NaN scores compare equal to each other
/// and below every non-NaN score (a NaN never outranks a real score), and
/// `-0.0` compares equal to `0.0`, as with `f32::partial_cmp`.
struct ScoredItem<'a> {
    item: &'a Item,
    score: f32,
}

impl<'a> ScoredItem<'a> {
    /// Total order on scores used by the `Ord` impl: NaN lowest, otherwise
    /// ordinary numeric order.
    fn cmp_scores(lhs: f32, rhs: f32) -> Ordering {
        match (lhs.is_nan(), rhs.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` always returns `Some` here.
            (false, false) => lhs.partial_cmp(&rhs).unwrap_or(Ordering::Equal),
        }
    }
}

impl<'a> PartialEq for ScoredItem<'a> {
    // Defined through `cmp` so `==` agrees with `Ord`, as `Eq` requires;
    // plain `f32 ==` would make a NaN-scored entry unequal to itself.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<'a> Eq for ScoredItem<'a> {}

impl<'a> PartialOrd for ScoredItem<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> Ord for ScoredItem<'a> {
    fn cmp(&self, other: &Self) -> Ordering {
        Self::cmp_scores(self.score, other.score)
    }
}
/// Chooses one item from an eligible set by multiplying per-factor scores and
/// optionally scaling each result by a small amount of pseudo-random noise.
#[derive(Debug)]
pub struct Selector {
    /// Scale of the per-item exploration noise: each item's score is multiplied
    /// by `1 + noise`, where `noise` is a pseudo-random value in `[0, 1]` times
    /// this rate. `0.0` disables noise entirely. The value is not validated.
    pub exploration_rate: f32,
    /// Current xorshift RNG state; kept non-zero by `new` and by `seed`.
    rng_state: u64,
}
impl Selector {
pub fn new(exploration_rate: f32) -> Self {
Self {
exploration_rate,
rng_state: 42,
}
}
pub fn select<'a>(
&mut self,
eligible: &[&'a Item],
factors: &[Box<dyn Factor>],
state: &ProfileState,
now: u64,
) -> Result<&'a Item, AriaError> {
if eligible.is_empty() {
return Err(AriaError::NoEligibleItems);
}
if factors.is_empty() {
return Err(AriaError::NoFactors);
}
if eligible.len() <= HEAP_THRESHOLD {
self.select_linear(eligible, factors, state, now)
} else {
self.select_heap(eligible, factors, state, now)
}
}
fn compute_score(
&mut self,
item: &dyn Scoreable,
factors: &[Box<dyn Factor>],
state: &ProfileState,
now: u64,
) -> f32 {
let base: f32 = factors.iter().map(|f| f.score(item, state, now)).product();
let noise = self.next_noise();
base * (1.0 + noise)
}
fn select_linear<'a>(
&mut self,
eligible: &[&'a Item],
factors: &[Box<dyn Factor>],
state: &ProfileState,
now: u64,
) -> Result<&'a Item, AriaError> {
let mut best_score = f32::NEG_INFINITY;
let mut best_item = eligible[0];
for &item in eligible {
let score = self.compute_score(item, factors, state, now);
if score > best_score {
best_score = score;
best_item = item;
}
}
Ok(best_item)
}
fn select_heap<'a>(
&mut self,
eligible: &[&'a Item],
factors: &[Box<dyn Factor>],
state: &ProfileState,
now: u64,
) -> Result<&'a Item, AriaError> {
let mut heap = BinaryHeap::with_capacity(eligible.len());
for &item in eligible {
let score = self.compute_score(item, factors, state, now);
heap.push(ScoredItem { item, score });
}
Ok(heap.pop().unwrap().item)
}
fn next_noise(&mut self) -> f32 {
if self.exploration_rate == 0.0 {
return 0.0;
}
self.rng_state ^= self.rng_state << 13;
self.rng_state ^= self.rng_state >> 7;
self.rng_state ^= self.rng_state << 17;
let norm = (self.rng_state as f32) / (u64::MAX as f32);
norm.abs() * self.exploration_rate
}
pub fn seed(&mut self, seed: u64) {
self.rng_state = if seed == 0 { 1 } else { seed };
}
}
impl Default for Selector {
fn default() -> Self {
Self::new(0.05)
}
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `Selector`: determinism and input validation.
    use super::*;
    use crate::factor::{ChallengeFactor, SpacingFactor, CoverageFactor};
    use crate::item::Item;
    use crate::state::ProfileState;

    /// Standard factor stack shared by the tests, in a fixed order:
    /// challenge, spacing, coverage.
    fn factors() -> Vec<Box<dyn Factor>> {
        let challenge: Box<dyn Factor> = Box::new(ChallengeFactor::default());
        let spacing: Box<dyn Factor> = Box::new(SpacingFactor::default());
        let coverage: Box<dyn Factor> = Box::new(CoverageFactor);
        vec![challenge, spacing, coverage]
    }

    #[test]
    fn deterministic_with_zero_exploration() {
        let catalogue = [
            Item::new("easy", 0.1, "cat"),
            Item::new("target", 0.6, "cat"),
            Item::new("hard", 0.9, "cat"),
        ];
        let eligible: Vec<&Item> = catalogue.iter().collect();
        let mut state = ProfileState::new();
        state.skill = 0.5;
        state.optimism_bias = 0.1;

        let mut selector = Selector::new(0.0);
        // With no exploration noise, repeated selections must agree.
        let mut pick = || {
            selector
                .select(&eligible, &factors(), &state, 0)
                .unwrap()
                .id()
                .to_string()
        };
        let first = pick();
        let second = pick();

        assert_eq!(first, second);
        assert_eq!(first, "target");
    }

    #[test]
    fn no_eligible_items_returns_error() {
        let state = ProfileState::new();
        let nothing: Vec<&Item> = Vec::new();
        let outcome = Selector::new(0.0).select(&nothing, &factors(), &state, 0);
        assert_eq!(outcome.unwrap_err(), AriaError::NoEligibleItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let only = [Item::new("x", 0.5, "cat")];
        let eligible: Vec<&Item> = only.iter().collect();
        let state = ProfileState::new();
        let outcome = Selector::new(0.0).select(&eligible, &[], &state, 0);
        assert_eq!(outcome.unwrap_err(), AriaError::NoFactors);
    }
}