use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::error::AriaError;
use crate::factor::Factor;
use crate::item::{Item, ItemRegistry, Scoreable};
use crate::selector::Selector;
use crate::signal::Signal;
use crate::state::ProfileState;
use crate::updater::{DefaultStateUpdater, StateUpdater};
/// Tunable parameters used when constructing an `Engine`.
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Handed to `Selector::new` when the engine is built.
    pub exploration_rate: f32,
    /// Handed to `DefaultStateUpdater::new` when the engine is built.
    pub alpha: f32,
}

impl Default for EngineConfig {
    /// Both `exploration_rate` and `alpha` default to `0.05`.
    fn default() -> Self {
        let exploration_rate = 0.05;
        let alpha = 0.05;
        Self {
            exploration_rate,
            alpha,
        }
    }
}
/// Engine that suggests items to users and folds their feedback into a
/// per-user `ProfileState`.
pub struct Engine {
    // Stored by `Engine::new` after its values have been used to build `selector`
    // and `updater`, but never read afterwards — hence the `dead_code` allowance.
    // NOTE(review): consider exposing an accessor or dropping the field.
    #[allow(dead_code)]
    config: EngineConfig,
    /// Registered items; the source of eligible candidates for `suggest`.
    registry: ItemRegistry,
    /// Factors handed to the selector on every `suggest` call.
    factors: Vec<Box<dyn Factor>>,
    /// Per-user profile state, keyed by the user id passed to the public methods.
    states: HashMap<String, ProfileState>,
    /// Chooses among eligible items; built from `EngineConfig::exploration_rate`.
    selector: Selector,
    /// Computes the next profile state in `feedback`; starts as a
    /// `DefaultStateUpdater` built from `EngineConfig::alpha`, replaceable via `set_updater`.
    updater: Box<dyn StateUpdater>,
}
impl Engine {
    /// Creates an engine with an empty item registry, no factors and no user state.
    ///
    /// `config.exploration_rate` is used to build the [`Selector`] and `config.alpha`
    /// to build the initial [`DefaultStateUpdater`]; the config is then stored as-is.
    pub fn new(config: EngineConfig) -> Self {
        let selector = Selector::new(config.exploration_rate);
        let updater = Box::new(DefaultStateUpdater::new(config.alpha));
        Self {
            config,
            registry: ItemRegistry::new(),
            factors: Vec::new(),
            states: HashMap::new(),
            selector,
            updater,
        }
    }

    /// Replaces the updater that [`Engine::feedback`] uses to compute new profile states.
    pub fn set_updater(&mut self, updater: Box<dyn StateUpdater>) {
        self.updater = updater;
    }

    /// Registers `items` with the engine's item registry.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ItemRegistry::register`; its failure
    /// conditions are defined in the `item` module, not here.
    pub fn add_items(&mut self, items: Vec<Item>) -> Result<(), AriaError> {
        self.registry.register(items)
    }

    /// Appends a scoring factor; all factors are handed to the selector on every
    /// [`Engine::suggest`] call.
    pub fn add_factor(&mut self, factor: Box<dyn Factor>) {
        self.factors.push(factor);
    }

    /// Picks the next item for `user_id`.
    ///
    /// If the registry and factor checks pass, a default [`ProfileState`] is created
    /// for a user the engine has not seen before, so it is visible through
    /// [`Engine::get_state`] afterwards.
    ///
    /// # Errors
    ///
    /// * [`AriaError::NoItems`] if no items have been registered.
    /// * [`AriaError::NoFactors`] if no factors have been added.
    /// * [`AriaError::NoEligibleItems`] if the registry reports nothing eligible for
    ///   the user's `resolved_set`.
    /// * Any error returned by the selector.
    /// * [`AriaError::ItemNotFound`] if the selected id cannot be looked up again;
    ///   not expected, since the candidate came from the same registry.
    pub fn suggest(&mut self, user_id: &str) -> Result<&Item, AriaError> {
        if self.registry.is_empty() {
            return Err(AriaError::NoItems);
        }
        if self.factors.is_empty() {
            return Err(AriaError::NoFactors);
        }
        let state = self.states.entry(user_id.to_string()).or_default();
        let eligible = self.registry.eligible(&state.resolved_set);
        if eligible.is_empty() {
            return Err(AriaError::NoEligibleItems);
        }
        let now = current_timestamp();
        // Copy the id out so the borrow of `eligible` (and of the selector) ends
        // before we hand back a reference into the registry.
        let item_id = {
            let selected = self.selector.select(&eligible, &self.factors, state, now)?;
            selected.id().to_string()
        };
        self.registry
            .get(&item_id)
            .ok_or_else(|| AriaError::ItemNotFound(item_id))
    }

    /// Records `signal` for `user_id` answering `item_id`.
    ///
    /// The user's state (created with defaults if absent) is replaced by the value
    /// returned from the configured [`StateUpdater`].
    ///
    /// # Errors
    ///
    /// [`AriaError::ItemNotFound`] if `item_id` is not registered; in that case no
    /// user state is created or modified.
    pub fn feedback(
        &mut self,
        user_id: &str,
        item_id: &str,
        signal: Signal,
    ) -> Result<(), AriaError> {
        // `registry`, `states` and `updater` are disjoint fields, so they can be
        // borrowed at the same time; no need to clone the item or the state.
        let item = self
            .registry
            .get(item_id)
            .ok_or_else(|| AriaError::ItemNotFound(item_id.to_string()))?;
        let state = self.states.entry(user_id.to_string()).or_default();
        let now = current_timestamp();
        let next_state = self.updater.update(&*state, item, &signal, now);
        *state = next_state;
        Ok(())
    }

    /// Returns the stored profile state for `user_id`, or `None` if no state has
    /// been created (by `suggest`/`feedback`) or loaded for that user.
    pub fn get_state(&self, user_id: &str) -> Option<&ProfileState> {
        self.states.get(user_id)
    }

    /// Stores `state` for `user_id`, replacing any existing state.
    pub fn load_state(&mut self, user_id: impl Into<String>, state: ProfileState) {
        self.states.insert(user_id.into(), state);
    }

    /// Number of items in the registry.
    pub fn item_count(&self) -> usize {
        self.registry.len()
    }

    /// Number of factors added so far.
    pub fn factor_count(&self) -> usize {
        self.factors.len()
    }

    /// Seeds the selector's random number generator (presumably to make
    /// suggestions reproducible, which the tests rely on — confirm in `selector`).
    pub fn seed_rng(&mut self, seed: u64) {
        self.selector.seed(seed);
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` instead of panicking when the system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set before 1970: fall back to the epoch itself.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::{ChallengeFactor, CoverageFactor, SpacingFactor};

    /// Engine with exploration disabled, the three built-in factors, and RNG seed 1
    /// (presumably so that suggestions are reproducible between runs).
    fn make_engine() -> Engine {
        let mut e = Engine::new(EngineConfig {
            exploration_rate: 0.0,
            alpha: 0.05,
        });
        e.add_factor(Box::new(ChallengeFactor::default()));
        e.add_factor(Box::new(SpacingFactor::default()));
        e.add_factor(Box::new(CoverageFactor));
        e.seed_rng(1);
        e
    }

    /// Registers five items: three in "math" and two in "science".
    fn add_items(e: &mut Engine) {
        e.add_items(vec![
            Item::new("easy", 0.1, "math"),
            Item::new("medium", 0.5, "math"),
            Item::new("hard", 0.9, "math"),
            Item::new("sci_a", 0.4, "science"),
            Item::new("sci_b", 0.6, "science"),
        ])
        .unwrap();
    }

    #[test]
    fn suggest_returns_item() {
        let mut e = make_engine();
        add_items(&mut e);
        let item = e.suggest("user1").unwrap();
        assert!(!item.id().is_empty());
    }

    #[test]
    fn no_items_returns_error() {
        let mut e = make_engine();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let mut e = Engine::new(EngineConfig::default());
        e.add_items(vec![Item::new("x", 0.5, "cat")]).unwrap();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoFactors);
    }

    #[test]
    fn feedback_updates_state() {
        let mut e = make_engine();
        add_items(&mut e);
        e.suggest("user1").unwrap();
        e.feedback("user1", "easy", Signal::new(true, 0.2)).unwrap();
        let state = e.get_state("user1").unwrap();
        assert!(state.skill > 0.0);
        // assert_eq! prints both sides on failure, unlike assert!(a == b).
        assert_eq!(state.interaction_count, 1);
    }

    #[test]
    fn skill_monotone_on_all_success() {
        let mut e = make_engine();
        add_items(&mut e);
        let mut prev_skill = 0.0f32;
        for _ in 0..5 {
            let item_id = e.suggest("user1").unwrap().id().to_string();
            e.feedback("user1", &item_id, Signal::new(true, 0.5)).unwrap();
            let skill = e.get_state("user1").unwrap().skill;
            assert!(skill >= prev_skill);
            prev_skill = skill;
        }
    }

    #[test]
    fn state_roundtrip_via_load() {
        let mut e = make_engine();
        add_items(&mut e);
        e.feedback("user1", "easy", Signal::new(true, 0.3)).unwrap();
        let saved = e.get_state("user1").unwrap().clone();
        let mut e2 = make_engine();
        add_items(&mut e2);
        e2.load_state("user1", saved.clone());
        let loaded = e2.get_state("user1").unwrap();
        assert!((loaded.skill - saved.skill).abs() < 1e-6);
        assert_eq!(loaded.interaction_count, saved.interaction_count);
    }

    #[test]
    fn prereq_gating_works() {
        let mut e = make_engine();
        e.add_items(vec![
            Item::new("base", 0.3, "math"),
            Item::new("advanced", 0.8, "math").with_prereqs(vec!["base".into()]),
        ])
        .unwrap();
        // Before any feedback on "base", only "base" may be suggested.
        for _ in 0..10 {
            assert_eq!(e.suggest("user1").unwrap().id(), "base");
        }
        e.feedback("user1", "base", Signal::new(true, 0.5)).unwrap();
        // `any` stops at the first hit, matching an early `break`.
        let saw_advanced = (0..10).any(|_| e.suggest("user1").unwrap().id() == "advanced");
        assert!(
            saw_advanced,
            "\"advanced\" was never suggested after succeeding on \"base\""
        );
    }
}