Skip to main content

a3s_code_core/skills/
feedback.rs

1//! Skill Feedback Loop
2//!
3//! Provides a scoring mechanism for skills based on usage feedback.
4//! Skills that consistently perform poorly are automatically disabled
5//! from the system prompt (soft-disable, not deleted).
6//!
7//! ## Extension Point
8//!
9//! `SkillScorer` is a trait — consumers can replace `DefaultSkillScorer`
10//! with a persistent implementation (e.g., database-backed, LLM-evaluated).
11
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, VecDeque};
14use std::sync::RwLock;
15
16/// Outcome of a skill usage
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum SkillOutcome {
20    /// Skill contributed to a successful result
21    Success,
22    /// Skill did not help or caused failure
23    Failure,
24    /// Skill partially helped
25    Partial,
26}
27
28/// Feedback record for a single skill usage
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct SkillFeedback {
31    /// Skill name
32    pub skill_name: String,
33    /// Outcome of the usage
34    pub outcome: SkillOutcome,
35    /// Score adjustment (-1.0 to 1.0)
36    pub score_delta: f32,
37    /// Human-readable reason
38    pub reason: String,
39    /// Timestamp (Unix milliseconds)
40    pub timestamp: i64,
41}
42
43/// Skill score summary
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SkillScore {
46    /// Skill name
47    pub skill_name: String,
48    /// Current weighted score (0.0 to 1.0)
49    pub score: f32,
50    /// Total feedback count
51    pub feedback_count: usize,
52    /// Whether the skill is currently disabled
53    pub disabled: bool,
54}
55
56/// Skill scorer trait (extension point)
57///
58/// Records feedback and computes scores for skills. Implementations
59/// can use any scoring algorithm and persistence strategy.
60pub trait SkillScorer: Send + Sync {
61    /// Record a feedback entry for a skill
62    fn record(&self, feedback: SkillFeedback);
63
64    /// Get the current score for a skill (0.0 to 1.0, default 1.0 for unknown)
65    fn score(&self, skill_name: &str) -> f32;
66
67    /// Check if a skill should be disabled based on its score
68    fn should_disable(&self, skill_name: &str) -> bool;
69
70    /// Get score summaries for all tracked skills
71    fn all_scores(&self) -> Vec<SkillScore>;
72}
73
74/// Default skill scorer using sliding window weighted average
75pub struct DefaultSkillScorer {
76    /// Feedback history per skill
77    history: RwLock<HashMap<String, VecDeque<SkillFeedback>>>,
78    /// Maximum feedback entries to keep per skill
79    pub window_size: usize,
80    /// Score threshold below which a skill is disabled
81    pub disable_threshold: f32,
82    /// Minimum feedback count before disabling is considered
83    pub min_feedback_count: usize,
84}
85
86impl Default for DefaultSkillScorer {
87    fn default() -> Self {
88        Self {
89            history: RwLock::new(HashMap::new()),
90            window_size: 20,
91            disable_threshold: 0.3,
92            min_feedback_count: 3,
93        }
94    }
95}
96
97impl DefaultSkillScorer {
98    /// Create a new scorer with custom parameters
99    pub fn new(window_size: usize, disable_threshold: f32, min_feedback_count: usize) -> Self {
100        Self {
101            history: RwLock::new(HashMap::new()),
102            window_size,
103            disable_threshold,
104            min_feedback_count,
105        }
106    }
107
108    /// Compute weighted average score for a feedback window.
109    /// More recent entries have higher weight (linear decay).
110    fn compute_score(entries: &VecDeque<SkillFeedback>) -> f32 {
111        if entries.is_empty() {
112            return 1.0; // Default: trust unknown skills
113        }
114
115        let n = entries.len() as f32;
116        let mut weighted_sum = 0.0f32;
117        let mut weight_total = 0.0f32;
118
119        for (i, entry) in entries.iter().enumerate() {
120            // Linear weight: older entries get lower weight
121            let weight = (i as f32 + 1.0) / n;
122            // Convert score_delta from [-1, 1] to [0, 1]
123            let normalized = (entry.score_delta + 1.0) / 2.0;
124            weighted_sum += normalized * weight;
125            weight_total += weight;
126        }
127
128        if weight_total == 0.0 {
129            return 1.0;
130        }
131
132        (weighted_sum / weight_total).clamp(0.0, 1.0)
133    }
134}
135
136impl SkillScorer for DefaultSkillScorer {
137    fn record(&self, feedback: SkillFeedback) {
138        let mut history = self.history.write().unwrap();
139        let entries = history.entry(feedback.skill_name.clone()).or_default();
140
141        entries.push_back(feedback);
142
143        // Trim to window size
144        while entries.len() > self.window_size {
145            entries.pop_front();
146        }
147    }
148
149    fn score(&self, skill_name: &str) -> f32 {
150        let history = self.history.read().unwrap();
151        match history.get(skill_name) {
152            Some(entries) => Self::compute_score(entries),
153            None => 1.0, // Unknown skill = full trust
154        }
155    }
156
157    fn should_disable(&self, skill_name: &str) -> bool {
158        let history = self.history.read().unwrap();
159        match history.get(skill_name) {
160            Some(entries) => {
161                if entries.len() < self.min_feedback_count {
162                    return false; // Not enough data
163                }
164                Self::compute_score(entries) < self.disable_threshold
165            }
166            None => false,
167        }
168    }
169
170    fn all_scores(&self) -> Vec<SkillScore> {
171        let history = self.history.read().unwrap();
172        history
173            .iter()
174            .map(|(name, entries)| {
175                let score = Self::compute_score(entries);
176                SkillScore {
177                    skill_name: name.clone(),
178                    score,
179                    feedback_count: entries.len(),
180                    disabled: entries.len() >= self.min_feedback_count
181                        && score < self.disable_threshold,
182                }
183            })
184            .collect()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in Unix milliseconds.
    fn now_ms() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as i64
    }

    /// Build a feedback entry with a fixed reason and the current timestamp.
    fn make_feedback(skill: &str, outcome: SkillOutcome, delta: f32) -> SkillFeedback {
        SkillFeedback {
            skill_name: skill.to_string(),
            outcome,
            score_delta: delta,
            reason: "test".to_string(),
            timestamp: now_ms(),
        }
    }

    // --- Basic scoring ---

    #[test]
    fn test_unknown_skill_score_is_1() {
        let scorer = DefaultSkillScorer::default();
        assert_eq!(scorer.score("nonexistent"), 1.0);
    }

    #[test]
    fn test_all_success_high_score() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("good-skill", SkillOutcome::Success, 1.0));
        }
        let score = scorer.score("good-skill");
        assert!(score > 0.9, "Expected high score, got {}", score);
    }

    #[test]
    fn test_all_failure_low_score() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("bad-skill", SkillOutcome::Failure, -1.0));
        }
        let score = scorer.score("bad-skill");
        assert!(score < 0.1, "Expected low score, got {}", score);
    }

    #[test]
    fn test_mixed_feedback_moderate_score() {
        let scorer = DefaultSkillScorer::default();
        scorer.record(make_feedback("mixed", SkillOutcome::Success, 1.0));
        scorer.record(make_feedback("mixed", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("mixed", SkillOutcome::Success, 1.0));
        let score = scorer.score("mixed");
        // Linear weights 1:2:3 over normalized values 1, 0, 1 give
        // (1 + 0 + 3) / 6 = 2/3. (The unweighted mean is also 2/3, so
        // recency weighting is covered by `test_recent_feedback_weighs_more`.)
        assert!(
            (score - 2.0 / 3.0).abs() < 1e-5,
            "Expected score of 2/3, got {}",
            score
        );
    }

    #[test]
    fn test_recent_feedback_weighs_more() {
        let scorer = DefaultSkillScorer::default();
        scorer.record(make_feedback("improving", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("improving", SkillOutcome::Success, 1.0));
        scorer.record(make_feedback("declining", SkillOutcome::Success, 1.0));
        scorer.record(make_feedback("declining", SkillOutcome::Failure, -1.0));

        // Same multiset of outcomes, opposite order: the newer entry dominates.
        let improving = scorer.score("improving");
        let declining = scorer.score("declining");
        assert!(
            improving > 0.5 && declining < 0.5,
            "Expected improving > 0.5 > declining, got {} and {}",
            improving,
            declining
        );
    }

    // --- Disable logic ---

    #[test]
    fn test_should_not_disable_unknown() {
        let scorer = DefaultSkillScorer::default();
        assert!(!scorer.should_disable("unknown"));
    }

    #[test]
    fn test_should_not_disable_insufficient_data() {
        let scorer = DefaultSkillScorer::default();
        // Only 2 entries, min is 3
        scorer.record(make_feedback("new-skill", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("new-skill", SkillOutcome::Failure, -1.0));
        assert!(!scorer.should_disable("new-skill"));
    }

    #[test]
    fn test_should_disable_consistently_bad() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("terrible", SkillOutcome::Failure, -1.0));
        }
        assert!(scorer.should_disable("terrible"));
    }

    #[test]
    fn test_should_not_disable_good_skill() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..5 {
            scorer.record(make_feedback("great", SkillOutcome::Success, 1.0));
        }
        assert!(!scorer.should_disable("great"));
    }

    // --- Window trimming ---

    #[test]
    fn test_window_trimming() {
        let scorer = DefaultSkillScorer::new(5, 0.3, 3);
        // Fill with failures
        for _ in 0..5 {
            scorer.record(make_feedback("recover", SkillOutcome::Failure, -1.0));
        }
        assert!(scorer.should_disable("recover"));

        // Now add successes — old failures should be trimmed
        for _ in 0..5 {
            scorer.record(make_feedback("recover", SkillOutcome::Success, 1.0));
        }
        assert!(!scorer.should_disable("recover"));
        assert!(scorer.score("recover") > 0.9);
    }

    // --- all_scores ---

    #[test]
    fn test_all_scores_empty() {
        let scorer = DefaultSkillScorer::default();
        assert!(scorer.all_scores().is_empty());
    }

    #[test]
    fn test_all_scores_multiple_skills() {
        let scorer = DefaultSkillScorer::default();
        for _ in 0..3 {
            scorer.record(make_feedback("skill-a", SkillOutcome::Success, 1.0));
            scorer.record(make_feedback("skill-b", SkillOutcome::Failure, -1.0));
        }

        let scores = scorer.all_scores();
        assert_eq!(scores.len(), 2);

        let a = scores.iter().find(|s| s.skill_name == "skill-a").unwrap();
        let b = scores.iter().find(|s| s.skill_name == "skill-b").unwrap();

        assert!(a.score > 0.9);
        assert!(!a.disabled);
        assert_eq!(a.feedback_count, 3);

        assert!(b.score < 0.1);
        assert!(b.disabled);
        assert_eq!(b.feedback_count, 3);
    }

    #[test]
    fn test_all_scores_respects_min_feedback_count() {
        let scorer = DefaultSkillScorer::default();
        // Low score, but only 2 entries (min is 3) → not reported as disabled.
        scorer.record(make_feedback("young", SkillOutcome::Failure, -1.0));
        scorer.record(make_feedback("young", SkillOutcome::Failure, -1.0));

        let scores = scorer.all_scores();
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].skill_name, "young");
        assert_eq!(scores[0].feedback_count, 2);
        assert!(scores[0].score < 0.1);
        assert!(!scores[0].disabled);
    }

    // --- Custom parameters ---

    #[test]
    fn test_custom_threshold() {
        let scorer = DefaultSkillScorer::new(20, 0.8, 3);
        // Partial feedback (delta=0.0 → normalized 0.5)
        for _ in 0..5 {
            scorer.record(make_feedback("mediocre", SkillOutcome::Partial, 0.0));
        }
        // Score ~0.5, threshold 0.8 → should disable
        assert!(scorer.should_disable("mediocre"));
    }

    // --- Outcome serialization ---

    #[test]
    fn test_outcome_serialization() {
        let json = serde_json::to_string(&SkillOutcome::Success).unwrap();
        assert_eq!(json, "\"success\"");

        let parsed: SkillOutcome = serde_json::from_str("\"failure\"").unwrap();
        assert_eq!(parsed, SkillOutcome::Failure);
    }

    // --- SkillFeedback serialization ---

    #[test]
    fn test_feedback_serialization() {
        let fb = make_feedback("test", SkillOutcome::Success, 0.8);
        let json = serde_json::to_string(&fb).unwrap();
        assert!(json.contains("\"skill_name\":\"test\""));
        assert!(json.contains("\"outcome\":\"success\""));

        let parsed: SkillFeedback = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.skill_name, "test");
        assert_eq!(parsed.outcome, SkillOutcome::Success);
    }

    // --- compute_score edge cases ---

    #[test]
    fn test_compute_score_empty() {
        let empty = VecDeque::new();
        assert_eq!(DefaultSkillScorer::compute_score(&empty), 1.0);
    }

    #[test]
    fn test_compute_score_single_entry() {
        let mut entries = VecDeque::new();
        entries.push_back(make_feedback("s", SkillOutcome::Success, 1.0));
        let score = DefaultSkillScorer::compute_score(&entries);
        assert!((score - 1.0).abs() < f32::EPSILON);
    }
}