1use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, VecDeque};
14use std::sync::RwLock;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum SkillOutcome {
20 Success,
22 Failure,
24 Partial,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct SkillFeedback {
31 pub skill_name: String,
33 pub outcome: SkillOutcome,
35 pub score_delta: f32,
37 pub reason: String,
39 pub timestamp: i64,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SkillScore {
46 pub skill_name: String,
48 pub score: f32,
50 pub feedback_count: usize,
52 pub disabled: bool,
54}
55
56pub trait SkillScorer: Send + Sync {
61 fn record(&self, feedback: SkillFeedback);
63
64 fn score(&self, skill_name: &str) -> f32;
66
67 fn should_disable(&self, skill_name: &str) -> bool;
69
70 fn all_scores(&self) -> Vec<SkillScore>;
72}
73
74pub struct DefaultSkillScorer {
76 history: RwLock<HashMap<String, VecDeque<SkillFeedback>>>,
78 pub window_size: usize,
80 pub disable_threshold: f32,
82 pub min_feedback_count: usize,
84}
85
86impl Default for DefaultSkillScorer {
87 fn default() -> Self {
88 Self {
89 history: RwLock::new(HashMap::new()),
90 window_size: 20,
91 disable_threshold: 0.3,
92 min_feedback_count: 3,
93 }
94 }
95}
96
97impl DefaultSkillScorer {
98 pub fn new(window_size: usize, disable_threshold: f32, min_feedback_count: usize) -> Self {
100 Self {
101 history: RwLock::new(HashMap::new()),
102 window_size,
103 disable_threshold,
104 min_feedback_count,
105 }
106 }
107
108 fn compute_score(entries: &VecDeque<SkillFeedback>) -> f32 {
111 if entries.is_empty() {
112 return 1.0; }
114
115 let n = entries.len() as f32;
116 let mut weighted_sum = 0.0f32;
117 let mut weight_total = 0.0f32;
118
119 for (i, entry) in entries.iter().enumerate() {
120 let weight = (i as f32 + 1.0) / n;
122 let normalized = (entry.score_delta + 1.0) / 2.0;
124 weighted_sum += normalized * weight;
125 weight_total += weight;
126 }
127
128 if weight_total == 0.0 {
129 return 1.0;
130 }
131
132 (weighted_sum / weight_total).clamp(0.0, 1.0)
133 }
134}
135
136impl SkillScorer for DefaultSkillScorer {
137 fn record(&self, feedback: SkillFeedback) {
138 let mut history = self.history.write().unwrap();
139 let entries = history.entry(feedback.skill_name.clone()).or_default();
140
141 entries.push_back(feedback);
142
143 while entries.len() > self.window_size {
145 entries.pop_front();
146 }
147 }
148
149 fn score(&self, skill_name: &str) -> f32 {
150 let history = self.history.read().unwrap();
151 match history.get(skill_name) {
152 Some(entries) => Self::compute_score(entries),
153 None => 1.0, }
155 }
156
157 fn should_disable(&self, skill_name: &str) -> bool {
158 let history = self.history.read().unwrap();
159 match history.get(skill_name) {
160 Some(entries) => {
161 if entries.len() < self.min_feedback_count {
162 return false; }
164 Self::compute_score(entries) < self.disable_threshold
165 }
166 None => false,
167 }
168 }
169
170 fn all_scores(&self) -> Vec<SkillScore> {
171 let history = self.history.read().unwrap();
172 history
173 .iter()
174 .map(|(name, entries)| {
175 let score = Self::compute_score(entries);
176 SkillScore {
177 skill_name: name.clone(),
178 score,
179 feedback_count: entries.len(),
180 disabled: entries.len() >= self.min_feedback_count
181 && score < self.disable_threshold,
182 }
183 })
184 .collect()
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_ms() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as i64
    }

    /// Builds a feedback entry for `skill` with a fixed `"test"` reason and the current time.
    fn make_feedback(skill: &str, outcome: SkillOutcome, delta: f32) -> SkillFeedback {
        SkillFeedback {
            skill_name: skill.to_string(),
            outcome,
            score_delta: delta,
            reason: "test".to_string(),
            timestamp: now_ms(),
        }
    }

    // A skill with no recorded history scores 1.0.
    #[test]
    fn test_unknown_skill_score_is_1() {
        let scorer = DefaultSkillScorer::default();
        assert_eq!(scorer.score("nonexistent"), 1.0);
    }

    // Deltas of +1.0 normalize to 1.0, so the weighted average stays near the top.
    #[test]
    fn test_all_success_high_score() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("good-skill", SkillOutcome::Success, 1.0));
        }
        let score = scorer.score("good-skill");
        assert!(score > 0.9, "Expected high score, got {}", score);
    }

    // Deltas of -1.0 normalize to 0.0, so the weighted average stays near the bottom.
    #[test]
    fn test_all_failure_low_score() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("bad-skill", SkillOutcome::Failure, -1.0));
        }
        let score = scorer.score("bad-skill");
        assert!(score < 0.1, "Expected low score, got {}", score);
    }

    // Relative recency weights 1 : 2 : 3 over normalized values 1, 0, 1 give 4/6 ≈ 0.667.
    #[test]
    fn test_mixed_feedback_moderate_score() {
        let scorer = DefaultSkillScorer::default();
        scorer.record(make_feedback("mixed", SkillOutcome::Success, 1.0));
        scorer.record(make_feedback("mixed", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("mixed", SkillOutcome::Success, 1.0));
        let score = scorer.score("mixed");
        assert!(
            score > 0.4 && score < 0.8,
            "Expected moderate score, got {}",
            score
        );
    }

    // A skill with no history is never disabled.
    #[test]
    fn test_should_not_disable_unknown() {
        let scorer = DefaultSkillScorer::default();
        assert!(!scorer.should_disable("unknown"));
    }

    // Two failures are below the default `min_feedback_count` of 3, so no disable yet.
    #[test]
    fn test_should_not_disable_insufficient_data() {
        let scorer = DefaultSkillScorer::default();
        scorer.record(make_feedback("new-skill", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("new-skill", SkillOutcome::Failure, -1.0));
        assert!(!scorer.should_disable("new-skill"));
    }

    // Five failures meet the count requirement and score below the default 0.3 threshold.
    #[test]
    fn test_should_disable_consistently_bad() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("terrible", SkillOutcome::Failure, -1.0));
        }
        assert!(scorer.should_disable("terrible"));
    }

    // Enough feedback, but the score is far above the threshold.
    #[test]
    fn test_should_not_disable_good_skill() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("great", SkillOutcome::Success, 1.0));
        }
        assert!(!scorer.should_disable("great"));
    }

    // With a 5-entry window, five successes fully evict the earlier five failures.
    #[test]
    fn test_window_trimming() {
        let scorer = DefaultSkillScorer::new(5, 0.3, 3);
        for _ in 0..5 {
            scorer.record(make_feedback("recover", SkillOutcome::Failure, -1.0));
        }
        assert!(scorer.should_disable("recover"));

        for _ in 0..5 {
            scorer.record(make_feedback("recover", SkillOutcome::Success, 1.0));
        }
        assert!(!scorer.should_disable("recover"));
        assert!(scorer.score("recover") > 0.9);
    }

    // No feedback recorded means no snapshots.
    #[test]
    fn test_all_scores_empty() {
        let scorer = DefaultSkillScorer::default();
        assert!(scorer.all_scores().is_empty());
    }

    // Snapshots are looked up by name, so the test does not depend on result ordering.
    #[test]
    fn test_all_scores_multiple_skills() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..3 {
            scorer.record(make_feedback("skill-a", SkillOutcome::Success, 1.0));
            scorer.record(make_feedback("skill-b", SkillOutcome::Failure, -1.0));
        }

        let scores = scorer.all_scores();
        assert_eq!(scores.len(), 2);

        let a = scores.iter().find(|s| s.skill_name == "skill-a").unwrap();
        let b = scores.iter().find(|s| s.skill_name == "skill-b").unwrap();

        assert!(a.score > 0.9);
        assert!(!a.disabled);
        assert_eq!(a.feedback_count, 3);

        assert!(b.score < 0.1);
        assert!(b.disabled);
        assert_eq!(b.feedback_count, 3);
    }

    // A neutral delta of 0.0 normalizes to 0.5, which is below the custom 0.8 threshold.
    #[test]
    fn test_custom_threshold() {
        let scorer = DefaultSkillScorer::new(20, 0.8, 3);
        for _ in 0..5 {
            scorer.record(make_feedback("mediocre", SkillOutcome::Partial, 0.0));
        }
        assert!(scorer.should_disable("mediocre"));
    }

    // `rename_all = "lowercase"` determines the JSON form of outcomes in both directions.
    #[test]
    fn test_outcome_serialization() {
        let json = serde_json::to_string(&SkillOutcome::Success).unwrap();
        assert_eq!(json, "\"success\"");

        let parsed: SkillOutcome = serde_json::from_str("\"failure\"").unwrap();
        assert_eq!(parsed, SkillOutcome::Failure);
    }

    // Round-trips a feedback entry through JSON and checks the field names on the wire.
    #[test]
    fn test_feedback_serialization() {
        let fb = make_feedback("test", SkillOutcome::Success, 0.8);
        let json = serde_json::to_string(&fb).unwrap();
        assert!(json.contains("\"skill_name\":\"test\""));
        assert!(json.contains("\"outcome\":\"success\""));

        let parsed: SkillFeedback = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.skill_name, "test");
        assert_eq!(parsed.outcome, SkillOutcome::Success);
    }

    // An empty window falls back to 1.0.
    #[test]
    fn test_compute_score_empty() {
        let empty = VecDeque::new();
        assert_eq!(DefaultSkillScorer::compute_score(&empty), 1.0);
    }

    // A single +1.0 delta normalizes to exactly 1.0.
    #[test]
    fn test_compute_score_single_entry() {
        let mut entries = VecDeque::new();
        entries.push_back(make_feedback("s", SkillOutcome::Success, 1.0));
        let score = DefaultSkillScorer::compute_score(&entries);
        assert!((score - 1.0).abs() < f32::EPSILON);
    }
}
390}