1use crate::item::Scoreable;
2use crate::signal::Signal;
3use crate::state::{ProfileState, OPTIMISM_CEIL, OPTIMISM_FLOOR};
4
/// Strategy for deriving the next [`ProfileState`] from one observed interaction.
///
/// Implementors must be `Send + Sync` (supertrait bounds on this trait).
pub trait StateUpdater: Send + Sync {
    /// Returns the profile state that follows `state` once `signal` has been
    /// observed for `item`.
    ///
    /// `state` is only borrowed; the result is handed back as a new
    /// `ProfileState` value.
    ///
    /// `now` is a caller-supplied timestamp. Its unit and epoch are not defined
    /// by this trait (NOTE(review): confirm against callers).
    fn update(
        &self,
        state: &ProfileState,
        item: &dyn Scoreable,
        signal: &Signal,
        now: u64,
    ) -> ProfileState;
}
24
/// Built-in [`StateUpdater`] configured by a single blend factor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DefaultStateUpdater {
    /// Fraction of the gap between the current skill and the observed
    /// performance that is closed on each update
    /// (`skill + alpha * (performance - skill)`).
    pub alpha: f32,
}

impl DefaultStateUpdater {
    /// Blend factor used by [`Default::default`].
    pub const DEFAULT_ALPHA: f32 = 0.05;

    /// Creates an updater with the given blend factor.
    ///
    /// `alpha` is stored as-is; no range validation is performed here.
    pub const fn new(alpha: f32) -> Self {
        Self { alpha }
    }
}

impl Default for DefaultStateUpdater {
    /// Creates an updater with `alpha` set to [`DefaultStateUpdater::DEFAULT_ALPHA`].
    fn default() -> Self {
        Self::new(Self::DEFAULT_ALPHA)
    }
}
54
55impl StateUpdater for DefaultStateUpdater {
56 fn update(
57 &self,
58 state: &ProfileState,
59 item: &dyn Scoreable,
60 signal: &Signal,
61 now: u64,
62 ) -> ProfileState {
63 let mut next = state.clone();
64
65 let performance = signal.performance();
67 next.skill = (state.skill + self.alpha * (performance - state.skill)).clamp(0.0, 1.0);
68
69 next.optimism_bias = if signal.success && signal.effort < 0.4 {
71 (state.optimism_bias + 0.02).clamp(OPTIMISM_FLOOR, OPTIMISM_CEIL)
72 } else if !signal.success {
73 (state.optimism_bias - 0.01).clamp(OPTIMISM_FLOOR, OPTIMISM_CEIL)
74 } else {
75 state.optimism_bias
76 };
77
78 next.last_seen.insert(item.id().to_string(), now);
80
81 *next.category_count.entry(item.category().to_string()).or_insert(0) += 1;
83
84 if signal.success {
86 next.resolved_set.insert(item.id().to_string());
87 }
88
89 next.interaction_count += 1;
90
91 next
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::item::Item;
    use crate::state::{DEFAULT_OPTIMISM, DEFAULT_SKILL};

    /// Fresh profile used as the starting point of every test.
    fn base_state() -> ProfileState {
        ProfileState::new()
    }

    #[test]
    fn skill_increases_on_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.3, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.skill > DEFAULT_SKILL);
    }

    #[test]
    fn skill_stays_at_zero_on_failure_from_zero() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        // Pin the precondition named by the test instead of relying on the
        // value of DEFAULT_SKILL.
        state.skill = 0.0;
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.8);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.skill.abs() < 1e-6);
    }

    #[test]
    fn optimism_increases_on_easy_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.1, "cat");
        let signal = Signal::new(true, 0.1);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias > DEFAULT_OPTIMISM);
    }

    #[test]
    fn optimism_decreases_on_failure() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.9);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias < DEFAULT_OPTIMISM);
    }

    #[test]
    fn optimism_never_below_floor() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        state.optimism_bias = OPTIMISM_FLOOR;
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 1.0);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias >= OPTIMISM_FLOOR);
    }

    #[test]
    fn optimism_never_above_ceiling() {
        let updater = DefaultStateUpdater::default();
        let mut state = base_state();
        state.optimism_bias = OPTIMISM_CEIL;
        let item = Item::new("x", 0.1, "cat");
        let signal = Signal::new(true, 0.1);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.optimism_bias <= OPTIMISM_CEIL);
    }

    #[test]
    fn resolved_set_updated_on_success() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(next.resolved_set.contains("x"));
    }

    #[test]
    fn resolved_set_not_updated_on_failure() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(false, 0.5);
        let next = updater.update(&state, &item, &signal, 100);
        assert!(!next.resolved_set.contains("x"));
    }

    #[test]
    fn last_seen_updated() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        let next = updater.update(&state, &item, &signal, 9999);
        assert_eq!(next.last_seen["x"], 9999);
    }

    #[test]
    fn category_count_incremented_per_update() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");
        let signal = Signal::new(true, 0.5);
        // Key by whatever the item reports so the test does not depend on
        // how `Item::new` maps its arguments.
        let key = item.category().to_string();
        let before = state.category_count.get(&key).copied().unwrap_or(0);

        let first = updater.update(&state, &item, &signal, 100);
        assert_eq!(first.category_count[&key], before + 1);

        let second = updater.update(&first, &item, &signal, 101);
        assert_eq!(second.category_count[&key], before + 2);
    }

    #[test]
    fn interaction_count_incremented_regardless_of_outcome() {
        let updater = DefaultStateUpdater::default();
        let state = base_state();
        let item = Item::new("x", 0.5, "cat");

        let after_success = updater.update(&state, &item, &Signal::new(true, 0.5), 100);
        assert_eq!(after_success.interaction_count, state.interaction_count + 1);

        let after_failure = updater.update(&state, &item, &Signal::new(false, 0.5), 100);
        assert_eq!(after_failure.interaction_count, state.interaction_count + 1);
    }
}