// aria_core/updater.rs

1use crate::item::Scoreable;
2use crate::signal::Signal;
3use crate::state::{ProfileState, OPTIMISM_CEIL, OPTIMISM_FLOOR};
4
5/// Trait for state update logic after feedback.
6///
7/// Callers can provide their own implementation to completely control
8/// how skill and other state fields evolve after each interaction.
9///
10/// Default implementation follows the ARIA update rules from the PRD.
11pub trait StateUpdater: Send + Sync {
12    /// Produce a new ProfileState given current state, the item interacted with,
13    /// the signal reported, and the current timestamp.
14    ///
15    /// Returns an owned new state — immutable update pattern.
16    fn update(
17        &self,
18        state: &ProfileState,
19        item: &dyn Scoreable,
20        signal: &Signal,
21        now: u64,
22    ) -> ProfileState;
23}
24
/// Default ARIA state updater.
///
/// Skill update (an exponential moving average toward the observed
/// performance, which is computed by `Signal::performance`):
///
/// ```text
/// performance = success × (0.5 + 0.5 × (1 - effort))
/// skill'      = clamp(skill + alpha × (performance - skill), 0, 1)
/// ```
///
/// Optimism update:
///
/// ```text
/// success && effort < 0.4  → optimism += 0.02   (easy win → push harder)
/// !success                 → optimism -= 0.01   (failure → ease back)
/// otherwise                → unchanged
/// ```
///
/// An adjusted optimism value is clamped to `[OPTIMISM_FLOOR, OPTIMISM_CEIL]`.
///
/// Every update also records the interaction timestamp in `last_seen`,
/// increments the item's category count and `interaction_count`, and adds the
/// item to `resolved_set` when the signal reports success.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DefaultStateUpdater {
    /// Learning rate for the skill update. Defaults to [`Self::DEFAULT_ALPHA`].
    pub alpha: f32,
}

impl DefaultStateUpdater {
    /// Learning rate used by [`Default::default`].
    pub const DEFAULT_ALPHA: f32 = 0.05;

    /// Creates an updater with the given skill learning rate.
    ///
    /// `alpha` is not validated here. Values in `(0, 1]` give the usual
    /// moving-average behaviour; the resulting skill is clamped to `[0, 1]`
    /// by the update regardless.
    pub fn new(alpha: f32) -> Self {
        Self { alpha }
    }
}

impl Default for DefaultStateUpdater {
    /// Equivalent to `DefaultStateUpdater::new(DefaultStateUpdater::DEFAULT_ALPHA)`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_ALPHA)
    }
}
54
55impl StateUpdater for DefaultStateUpdater {
56    fn update(
57        &self,
58        state: &ProfileState,
59        item: &dyn Scoreable,
60        signal: &Signal,
61        now: u64,
62    ) -> ProfileState {
63        let mut next = state.clone();
64
65        // Skill update
66        let performance = signal.performance();
67        next.skill = (state.skill + self.alpha * (performance - state.skill)).clamp(0.0, 1.0);
68
69        // Optimism update
70        next.optimism_bias = if signal.success && signal.effort < 0.4 {
71            (state.optimism_bias + 0.02).clamp(OPTIMISM_FLOOR, OPTIMISM_CEIL)
72        } else if !signal.success {
73            (state.optimism_bias - 0.01).clamp(OPTIMISM_FLOOR, OPTIMISM_CEIL)
74        } else {
75            state.optimism_bias
76        };
77
78        // Mark item as seen
79        next.last_seen.insert(item.id().to_string(), now);
80
81        // Update category count
82        *next.category_count.entry(item.category().to_string()).or_insert(0) += 1;
83
84        // Mark resolved if success
85        if signal.success {
86            next.resolved_set.insert(item.id().to_string());
87        }
88
89        next.interaction_count += 1;
90
91        next
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::item::Item;
    use crate::state::{DEFAULT_OPTIMISM, DEFAULT_SKILL};

    /// Fresh state as produced by `ProfileState::new()`.
    fn base_state() -> ProfileState {
        ProfileState::new()
    }

    #[test]
    fn skill_increases_on_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.3, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.skill > DEFAULT_SKILL);
    }

    #[test]
    fn skill_stays_at_zero_on_failure_from_zero() {
        let updater = DefaultStateUpdater::default();
        // Pin the starting skill explicitly rather than relying on DEFAULT_SKILL being 0.
        let mut state = base_state();
        state.skill = 0.0;
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.8);
        let next = updater.update(&state, &item, &signal, 100);
        // performance = 0.0, skill + alpha*(0 - 0) = 0
        assert!((next.skill - 0.0).abs() < 1e-6);
    }

    #[test]
    fn skill_decreases_on_failure_from_positive() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        state.skill = 0.5;
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.skill < 0.5);
    }

    #[test]
    fn optimism_increases_on_easy_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.1, "cat");
        let signal = Signal::new(true, 0.1); // effort < 0.4
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias > DEFAULT_OPTIMISM);
    }

    #[test]
    fn optimism_unchanged_on_hard_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.8); // success, but effort >= 0.4
        let next = updater.update(&state, &item, &signal, 100);
        assert!((next.optimism_bias - state.optimism_bias).abs() < 1e-6);
    }

    #[test]
    fn optimism_decreases_on_failure() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.9);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias < DEFAULT_OPTIMISM);
    }

    #[test]
    fn optimism_never_below_floor() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        state.optimism_bias = OPTIMISM_FLOOR; // already at floor
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 1.0);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias >= OPTIMISM_FLOOR);
    }

    #[test]
    fn optimism_never_above_ceiling() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        state.optimism_bias = OPTIMISM_CEIL; // already at ceiling
        let item = Item::new("x", 0.1, "cat");
        let signal = Signal::new(true, 0.1); // easy win would push it up
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias <= OPTIMISM_CEIL);
    }

    #[test]
    fn resolved_set_updated_on_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.resolved_set.contains("x"));
    }

    #[test]
    fn resolved_set_not_updated_on_failure() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(!next.resolved_set.contains("x"));
    }

    #[test]
    fn last_seen_updated() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 9999);
        assert_eq!(next.last_seen["x"], 9999);
    }

    #[test]
    fn category_count_incremented_per_interaction() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        let once = updater.update(&state, &item, &signal, 100);
        assert_eq!(once.category_count["cat"], 1);
        let twice = updater.update(&once, &item, &signal, 200);
        assert_eq!(twice.category_count["cat"], 2);
    }

    #[test]
    fn interaction_count_incremented() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert_eq!(next.interaction_count, state.interaction_count + 1);
    }
}
182        let signal = Signal::new(true, 0.5);
183        let next = updater.update(&state, &item, &signal, 9999);
184        assert_eq!(next.last_seen["x"], 9999);
185    }
186}