1use std::collections::HashMap;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use crate::error::AriaError;
5use crate::factor::Factor;
6use crate::item::{Item, ItemRegistry, Scoreable};
7use crate::selector::Selector;
8use crate::signal::Signal;
9use crate::state::ProfileState;
10use crate::updater::{DefaultStateUpdater, StateUpdater};
11
/// Tuning parameters for [`Engine`].
///
/// Both values are consumed once, when [`Engine::new`] builds the selector and
/// the default state updater.
#[derive(Debug, Clone, PartialEq)]
pub struct EngineConfig {
    /// Passed to `Selector::new` when the engine is constructed.
    pub exploration_rate: f32,
    /// Passed to `DefaultStateUpdater::new` when the engine is constructed.
    /// A replacement updater installed via `Engine::set_updater` does not
    /// read this value.
    pub alpha: f32,
}
20
21impl Default for EngineConfig {
22 fn default() -> Self {
23 Self {
24 exploration_rate: 0.05,
25 alpha: 0.05,
26 }
27 }
28}
29
/// Top-level engine state: the registered items, the scoring factors, one
/// [`ProfileState`] per user id, and the selector/updater pair used by
/// `suggest` and `feedback`.
pub struct Engine {
    // Stored at construction; no method in this module reads it afterwards,
    // which is why the dead-code lint is silenced here.
    #[allow(dead_code)]
    config: EngineConfig,
    /// Registered items, looked up by id.
    registry: ItemRegistry,
    /// Scoring factors handed to the selector, kept in insertion order.
    factors: Vec<Box<dyn Factor>>,
    /// Per-user state, keyed by the `user_id` strings passed to the engine.
    states: HashMap<String, ProfileState>,
    /// Picks one eligible item on each `suggest` call.
    selector: Selector,
    /// Computes the next user state on each `feedback` call.
    updater: Box<dyn StateUpdater>,
}
63
64impl Engine {
65 pub fn new(config: EngineConfig) -> Self {
67 let selector = Selector::new(config.exploration_rate);
68 let updater = Box::new(DefaultStateUpdater::new(config.alpha));
69 Self {
70 config,
71 registry: ItemRegistry::new(),
72 factors: Vec::new(),
73 states: HashMap::new(),
74 selector,
75 updater,
76 }
77 }
78
79 pub fn set_updater(&mut self, updater: Box<dyn StateUpdater>) {
82 self.updater = updater;
83 }
84
85 pub fn add_items(&mut self, items: Vec<Item>) -> Result<(), AriaError> {
87 self.registry.register(items)
88 }
89
90 pub fn add_factor(&mut self, factor: Box<dyn Factor>) {
93 self.factors.push(factor);
94 }
95
96 pub fn suggest(&mut self, user_id: &str) -> Result<&Item, AriaError> {
101 if self.registry.is_empty() {
102 return Err(AriaError::NoItems);
103 }
104 if self.factors.is_empty() {
105 return Err(AriaError::NoFactors);
106 }
107
108 let state = self.states.entry(user_id.to_string()).or_default();
109 let eligible = self.registry.eligible(&state.resolved_set);
110
111 if eligible.is_empty() {
112 return Err(AriaError::NoEligibleItems);
113 }
114
115 let now = current_timestamp();
116 let item_id = {
118 let selected = self.selector.select(&eligible, &self.factors, state, now)?;
119 selected.id().to_string()
120 };
121
122 self.registry
123 .get(&item_id)
124 .ok_or_else(|| AriaError::ItemNotFound(item_id))
125 }
126
127 pub fn feedback(
130 &mut self,
131 user_id: &str,
132 item_id: &str,
133 signal: Signal,
134 ) -> Result<(), AriaError> {
135 let item = self
136 .registry
137 .get(item_id)
138 .ok_or_else(|| AriaError::ItemNotFound(item_id.to_string()))?
139 .clone();
140
141 let state = self.states.entry(user_id.to_string()).or_default().clone();
142 let now = current_timestamp();
143 let next_state = self.updater.update(&state, &item, &signal, now);
144 self.states.insert(user_id.to_string(), next_state);
145
146 Ok(())
147 }
148
149 pub fn get_state(&self, user_id: &str) -> Option<&ProfileState> {
151 self.states.get(user_id)
152 }
153
154 pub fn load_state(&mut self, user_id: impl Into<String>, state: ProfileState) {
157 self.states.insert(user_id.into(), state);
158 }
159
160 pub fn item_count(&self) -> usize {
162 self.registry.len()
163 }
164
165 pub fn factor_count(&self) -> usize {
167 self.factors.len()
168 }
169
170 pub fn seed_rng(&mut self, seed: u64) {
172 self.selector.seed(seed);
173 }
174}
175
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Falls back to 0 when the system clock reads earlier than the epoch, rather
/// than failing.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::{ChallengeFactor, CoverageFactor, SpacingFactor};

    /// Builds an engine with an exploration rate of 0.0, the three built-in
    /// factors and a fixed RNG seed.
    fn make_engine() -> Engine {
        let mut e = Engine::new(EngineConfig {
            exploration_rate: 0.0,
            alpha: 0.05,
        });
        e.add_factor(Box::new(ChallengeFactor::default()));
        e.add_factor(Box::new(SpacingFactor::default()));
        e.add_factor(Box::new(CoverageFactor));
        e.seed_rng(1);
        e
    }

    /// Registers five items across two categories.
    fn add_items(e: &mut Engine) {
        e.add_items(vec![
            Item::new("easy", 0.1, "math"),
            Item::new("medium", 0.5, "math"),
            Item::new("hard", 0.9, "math"),
            Item::new("sci_a", 0.4, "science"),
            Item::new("sci_b", 0.6, "science"),
        ])
        .unwrap();
    }

    #[test]
    fn suggest_returns_item() {
        let mut e = make_engine();
        add_items(&mut e);
        let item = e.suggest("user1").unwrap();
        assert!(!item.id().is_empty());
    }

    #[test]
    fn no_items_returns_error() {
        let mut e = make_engine();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoItems);
    }

    #[test]
    fn no_factors_returns_error() {
        let mut e = Engine::new(EngineConfig::default());
        e.add_items(vec![Item::new("x", 0.5, "cat")]).unwrap();
        let err = e.suggest("user1").unwrap_err();
        assert_eq!(err, AriaError::NoFactors);
    }

    #[test]
    fn feedback_updates_state() {
        let mut e = make_engine();
        add_items(&mut e);
        e.suggest("user1").unwrap();
        e.feedback("user1", "easy", Signal::new(true, 0.2)).unwrap();
        let state = e.get_state("user1").unwrap();
        assert!(state.skill > 0.0, "skill should rise after a success: {}", state.skill);
        assert_eq!(state.interaction_count, 1);
    }

    #[test]
    fn feedback_on_unknown_item_is_rejected() {
        let mut e = make_engine();
        add_items(&mut e);
        let err = e
            .feedback("user1", "missing", Signal::new(true, 0.5))
            .unwrap_err();
        assert_eq!(err, AriaError::ItemNotFound("missing".to_string()));
        // The error path must not create a state for the user.
        assert!(e.get_state("user1").is_none());
    }

    #[test]
    fn skill_monotone_on_all_success() {
        let mut e = make_engine();
        add_items(&mut e);
        let mut prev_skill = 0.0f32;
        for _ in 0..5 {
            let item_id = e.suggest("user1").unwrap().id().to_string();
            e.feedback("user1", &item_id, Signal::new(true, 0.5)).unwrap();
            let skill = e.get_state("user1").unwrap().skill;
            assert!(
                skill >= prev_skill,
                "skill decreased on success: {} -> {}",
                prev_skill,
                skill
            );
            prev_skill = skill;
        }
    }

    #[test]
    fn state_roundtrip_via_load() {
        let mut e = make_engine();
        add_items(&mut e);
        e.feedback("user1", "easy", Signal::new(true, 0.3)).unwrap();
        let saved = e.get_state("user1").unwrap().clone();

        let mut e2 = make_engine();
        add_items(&mut e2);
        e2.load_state("user1", saved.clone());
        let loaded = e2.get_state("user1").unwrap();

        assert!((loaded.skill - saved.skill).abs() < 1e-6);
        assert_eq!(loaded.interaction_count, saved.interaction_count);
    }

    #[test]
    fn prereq_gating_works() {
        let mut e = make_engine();
        e.add_items(vec![
            Item::new("base", 0.3, "math"),
            Item::new("advanced", 0.8, "math").with_prereqs(vec!["base".into()]),
        ])
        .unwrap();

        // Before any feedback, "advanced" is gated behind "base".
        for _ in 0..10 {
            let item = e.suggest("user1").unwrap();
            assert_eq!(item.id(), "base");
        }

        e.feedback("user1", "base", Signal::new(true, 0.5)).unwrap();

        // After succeeding on "base", "advanced" should become selectable.
        let mut saw_advanced = false;
        for _ in 0..10 {
            let item = e.suggest("user1").unwrap();
            if item.id() == "advanced" {
                saw_advanced = true;
                break;
            }
        }
        assert!(saw_advanced, "\"advanced\" was never suggested after its prereq");
    }
}
304}