// aria_core/engine.rs

1use std::collections::HashMap;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use crate::error::AriaError;
5use crate::factor::Factor;
6use crate::item::{Item, ItemRegistry, Scoreable};
7use crate::selector::Selector;
8use crate::signal::Signal;
9use crate::state::ProfileState;
10use crate::updater::{DefaultStateUpdater, StateUpdater};
11
/// Tunable parameters consumed when an engine is constructed.
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Amount of exploration noise; `0.0` means deterministic selection.
    /// Defaults to `0.05`.
    pub exploration_rate: f32,
    /// Skill learning rate (alpha) used by the default state updater.
    /// Defaults to `0.05`.
    pub alpha: f32,
}

impl Default for EngineConfig {
    /// Builds the stock configuration: both rates start at `0.05`.
    fn default() -> Self {
        let exploration_rate = 0.05;
        let alpha = 0.05;
        EngineConfig {
            exploration_rate,
            alpha,
        }
    }
}
29
30/// The main engine. Owns item registry, user states, factor pipeline, selector.
31///
32/// # Usage
33/// ```rust
34/// use aria_core::{Engine, EngineConfig, Signal, Scoreable};
35/// use aria_core::item::Item;
36/// use aria_core::factor::{ChallengeFactor, SpacingFactor, CoverageFactor};
37///
38/// let mut engine = Engine::new(EngineConfig::default());
39///
40/// engine.add_factor(Box::new(ChallengeFactor::default()));
41/// engine.add_factor(Box::new(SpacingFactor::default()));
42/// engine.add_factor(Box::new(CoverageFactor));
43///
44/// engine.add_items(vec![
45///     Item::new("algebra_basics", 0.2, "algebra"),
46///     Item::new("quadratic_eq",   0.6, "algebra"),
47///     Item::new("integration",    0.9, "calculus"),
48/// ]).unwrap();
49///
50/// let item = engine.suggest("user_42").unwrap();
51/// let item_id = item.id().to_string();
52/// engine.feedback("user_42", &item_id, Signal::new(true, 0.4)).unwrap();
53/// ```
54pub struct Engine {
55    #[allow(dead_code)]
56    config: EngineConfig,
57    registry: ItemRegistry,
58    factors: Vec<Box<dyn Factor>>,
59    states: HashMap<String, ProfileState>,
60    selector: Selector,
61    updater: Box<dyn StateUpdater>,
62}
63
64impl Engine {
65    /// Create engine with config.
66    pub fn new(config: EngineConfig) -> Self {
67        let selector = Selector::new(config.exploration_rate);
68        let updater = Box::new(DefaultStateUpdater::new(config.alpha));
69        Self {
70            config,
71            registry: ItemRegistry::new(),
72            factors: Vec::new(),
73            states: HashMap::new(),
74            selector,
75            updater,
76        }
77    }
78
79    /// Replace the default state updater with a custom implementation.
80    /// Call before first interaction.
81    pub fn set_updater(&mut self, updater: Box<dyn StateUpdater>) {
82        self.updater = updater;
83    }
84
85    /// Register items. Returns Err if prerequisites form a cycle.
86    pub fn add_items(&mut self, items: Vec<Item>) -> Result<(), AriaError> {
87        self.registry.register(items)
88    }
89
90    /// Register a scoring factor. Order matters — factors are applied in
91    /// registration order; all scores are multiplied together.
92    pub fn add_factor(&mut self, factor: Box<dyn Factor>) {
93        self.factors.push(factor);
94    }
95
96    /// Suggest the best next item for a user.
97    ///
98    /// Creates a fresh ProfileState for new users automatically.
99    /// Returns a reference to the winning Item.
100    pub fn suggest(&mut self, user_id: &str) -> Result<&Item, AriaError> {
101        if self.registry.is_empty() {
102            return Err(AriaError::NoItems);
103        }
104        if self.factors.is_empty() {
105            return Err(AriaError::NoFactors);
106        }
107
108        let state = self.states.entry(user_id.to_string()).or_default();
109        let eligible = self.registry.eligible(&state.resolved_set);
110
111        if eligible.is_empty() {
112            return Err(AriaError::NoEligibleItems);
113        }
114
115        let now = current_timestamp();
116        // selector needs &mut self, so we need to borrow carefully
117        let item_id = {
118            let selected = self.selector.select(&eligible, &self.factors, state, now)?;
119            selected.id().to_string()
120        };
121
122        self.registry
123            .get(&item_id)
124            .ok_or_else(|| AriaError::ItemNotFound(item_id))
125    }
126
127    /// Report feedback for a user-item interaction.
128    /// Updates internal ProfileState for the user.
129    pub fn feedback(
130        &mut self,
131        user_id: &str,
132        item_id: &str,
133        signal: Signal,
134    ) -> Result<(), AriaError> {
135        let item = self
136            .registry
137            .get(item_id)
138            .ok_or_else(|| AriaError::ItemNotFound(item_id.to_string()))?
139            .clone();
140
141        let state = self.states.entry(user_id.to_string()).or_default().clone();
142        let now = current_timestamp();
143        let next_state = self.updater.update(&state, &item, &signal, now);
144        self.states.insert(user_id.to_string(), next_state);
145
146        Ok(())
147    }
148
149    /// Get current ProfileState for a user. Returns None if user has no interactions.
150    pub fn get_state(&self, user_id: &str) -> Option<&ProfileState> {
151        self.states.get(user_id)
152    }
153
154    /// Load a previously serialised ProfileState for a user.
155    /// Use with Serialiser::decode to restore sessions across restarts.
156    pub fn load_state(&mut self, user_id: impl Into<String>, state: ProfileState) {
157        self.states.insert(user_id.into(), state);
158    }
159
160    /// Number of registered items.
161    pub fn item_count(&self) -> usize {
162        self.registry.len()
163    }
164
165    /// Number of registered factors.
166    pub fn factor_count(&self) -> usize {
167        self.factors.len()
168    }
169
170    /// Seed the selector RNG — use in tests for determinism.
171    pub fn seed_rng(&mut self, seed: u64) {
172        self.selector.seed(seed);
173    }
174}
175
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set earlier than the epoch: report the epoch itself.
        Err(_) => 0,
    }
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::{ChallengeFactor, CoverageFactor, SpacingFactor};

    /// Engine with exploration disabled, the three stock factors, and a
    /// fixed RNG seed so suggestions are reproducible.
    fn make_engine() -> Engine {
        let mut e = Engine::new(EngineConfig {
            exploration_rate: 0.0, // deterministic
            alpha: 0.05,
        });
        e.add_factor(Box::new(ChallengeFactor::default()));
        e.add_factor(Box::new(SpacingFactor::default()));
        e.add_factor(Box::new(CoverageFactor));
        e.seed_rng(1);
        e
    }

    /// Registers a small catalogue with no prerequisites.
    fn add_items(e: &mut Engine) {
        e.add_items(vec![
            Item::new("easy",   0.1, "math"),
            Item::new("medium", 0.5, "math"),
            Item::new("hard",   0.9, "math"),
            Item::new("sci_a",  0.4, "science"),
            Item::new("sci_b",  0.6, "science"),
        ])
        .unwrap();
    }

    #[test]
    fn suggest_returns_item() {
        let mut e = make_engine();
        add_items(&mut e);
        let item = e.suggest("user1").unwrap();
        assert!(!item.id().is_empty());
    }

    #[test]
    fn no_items_returns_error() {
        let mut e = make_engine();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let mut e = Engine::new(EngineConfig::default());
        e.add_items(vec![Item::new("x", 0.5, "cat")]).unwrap();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoFactors);
    }

    #[test]
    fn feedback_updates_state() {
        let mut e = make_engine();
        add_items(&mut e);
        e.suggest("user1").unwrap();
        e.feedback("user1", "easy", Signal::new(true, 0.2)).unwrap();
        let state = e.get_state("user1").unwrap();
        assert!(state.skill > 0.0);
        // assert_eq! reports the actual count when the check fails.
        assert_eq!(state.interaction_count, 1);
    }

    #[test]
    fn feedback_unknown_item_returns_error() {
        let mut e = make_engine();
        add_items(&mut e);
        let err = e
            .feedback("user1", "missing", Signal::new(true, 0.5))
            .unwrap_err();
        assert_eq!(err, AriaError::ItemNotFound("missing".to_string()));
        // A failed lookup must not leave a state behind for the user.
        assert!(e.get_state("user1").is_none());
    }

    #[test]
    fn skill_monotone_on_all_success() {
        let mut e = make_engine();
        add_items(&mut e);
        let mut prev_skill = 0.0f32;
        for _ in 0..5 {
            let item_id = e.suggest("user1").unwrap().id().to_string();
            e.feedback("user1", &item_id, Signal::new(true, 0.5)).unwrap();
            let skill = e.get_state("user1").unwrap().skill;
            assert!(skill >= prev_skill);
            prev_skill = skill;
        }
    }

    #[test]
    fn state_roundtrip_via_load() {
        let mut e = make_engine();
        add_items(&mut e);
        e.feedback("user1", "easy", Signal::new(true, 0.3)).unwrap();
        let saved = e.get_state("user1").unwrap().clone();

        let mut e2 = make_engine();
        add_items(&mut e2);
        e2.load_state("user1", saved.clone());
        let loaded = e2.get_state("user1").unwrap();

        assert!((loaded.skill - saved.skill).abs() < 1e-6);
        assert_eq!(loaded.interaction_count, saved.interaction_count);
    }

    #[test]
    fn prereq_gating_works() {
        let mut e = make_engine();
        e.add_items(vec![
            Item::new("base", 0.3, "math"),
            Item::new("advanced", 0.8, "math").with_prereqs(vec!["base".into()]),
        ])
        .unwrap();

        // Before resolving prereq — advanced should never be suggested
        for _ in 0..10 {
            let item = e.suggest("user1").unwrap();
            assert_eq!(item.id(), "base");
        }

        // Resolve base
        e.feedback("user1", "base", Signal::new(true, 0.5)).unwrap();

        // Now advanced can appear
        let mut saw_advanced = false;
        for _ in 0..10 {
            let item = e.suggest("user1").unwrap();
            if item.id() == "advanced" {
                saw_advanced = true;
                break;
            }
        }
        assert!(saw_advanced);
    }
}