Skip to main content

agent_reflection/
lib.rs

/*!
agent-reflection: a self-evaluation loop for LLM agent outputs.

```rust
use agent_reflection::{ReflectionResult, ReflectionStore};

let mut store = ReflectionStore::new();
let id = store.record("original output", "my critique", 0.7);
let r = store.get(id).unwrap();
assert!((r.score - 0.7).abs() < 1e-9);
```
*/
13
/// One self-evaluation of an agent output.
#[derive(Debug, Clone)]
pub struct ReflectionResult {
    /// Identifier of this entry.
    pub id: usize,
    /// The output text that was evaluated.
    pub output: String,
    /// The self-evaluation (critique) written about `output`.
    pub critique: String,
    /// Quality score, nominally in `[0.0, 1.0]`; the type itself does not enforce the range.
    pub score: f64,
    /// Optional suggestion for improving the output.
    pub suggestion: Option<String>,
}

impl ReflectionResult {
    /// Reports whether this entry's score reaches `threshold` (inclusive).
    ///
    /// If either the score or the threshold is NaN the comparison is false,
    /// so the entry does not pass.
    pub fn passes(&self, threshold: f64) -> bool {
        threshold <= self.score
    }
}
33
34/// Stores reflection results across turns.
35#[derive(Default)]
36pub struct ReflectionStore {
37    entries: Vec<ReflectionResult>,
38}
39
40impl ReflectionStore {
41    pub fn new() -> Self { Self::default() }
42
43    /// Record a reflection and return its id.
44    pub fn record(&mut self, output: &str, critique: &str, score: f64) -> usize {
45        let id = self.entries.len();
46        self.entries.push(ReflectionResult { id, output: output.to_string(), critique: critique.to_string(), score, suggestion: None });
47        id
48    }
49
50    /// Record with a suggestion.
51    pub fn record_with_suggestion(&mut self, output: &str, critique: &str, score: f64, suggestion: &str) -> usize {
52        let id = self.record(output, critique, score);
53        self.entries[id].suggestion = Some(suggestion.to_string());
54        id
55    }
56
57    pub fn get(&self, id: usize) -> Option<&ReflectionResult> {
58        self.entries.get(id)
59    }
60
61    pub fn len(&self) -> usize { self.entries.len() }
62    pub fn is_empty(&self) -> bool { self.entries.is_empty() }
63
64    /// Average score across all entries.
65    pub fn avg_score(&self) -> Option<f64> {
66        if self.entries.is_empty() { return None; }
67        Some(self.entries.iter().map(|e| e.score).sum::<f64>() / self.entries.len() as f64)
68    }
69
70    /// Entries that pass a given threshold.
71    pub fn passing(&self, threshold: f64) -> Vec<&ReflectionResult> {
72        self.entries.iter().filter(|e| e.passes(threshold)).collect()
73    }
74
75    /// Entries below threshold.
76    pub fn failing(&self, threshold: f64) -> Vec<&ReflectionResult> {
77        self.entries.iter().filter(|e| !e.passes(threshold)).collect()
78    }
79
80    /// Best-scoring entry.
81    pub fn best(&self) -> Option<&ReflectionResult> {
82        self.entries.iter().max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
83    }
84
85    pub fn clear(&mut self) { self.entries.clear(); }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point score comparisons.
    const EPS: f64 = 1e-9;

    /// True when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Builds a store holding one entry per `(output, critique, score)` row, in order.
    fn store_of(rows: &[(&str, &str, f64)]) -> ReflectionStore {
        let mut store = ReflectionStore::new();
        for &(output, critique, score) in rows {
            store.record(output, critique, score);
        }
        store
    }

    #[test]
    fn record_and_get() {
        let mut store = ReflectionStore::new();
        let id = store.record("output", "good", 0.8);
        let entry = store.get(id).unwrap();
        assert_eq!(entry.output, "output");
        assert!(close(entry.score, 0.8));
    }

    #[test]
    fn passes_threshold() {
        let entry = ReflectionResult {
            id: 0,
            output: "x".into(),
            critique: "ok".into(),
            score: 0.75,
            suggestion: None,
        };
        assert!(entry.passes(0.7));
        assert!(!entry.passes(0.8));
    }

    #[test]
    fn avg_score() {
        let store = store_of(&[("a", "ok", 0.6), ("b", "ok", 0.8)]);
        let mean = store.avg_score().unwrap();
        assert!(close(mean, 0.7));
    }

    #[test]
    fn avg_score_empty() {
        let store = ReflectionStore::new();
        assert!(store.avg_score().is_none());
    }

    #[test]
    fn passing_and_failing() {
        let store = store_of(&[("a", "bad", 0.4), ("b", "ok", 0.9), ("c", "meh", 0.7)]);
        assert_eq!(store.passing(0.7).len(), 2);
        assert_eq!(store.failing(0.7).len(), 1);
    }

    #[test]
    fn best() {
        let store = store_of(&[("a", "c", 0.5), ("b", "c", 0.9), ("c", "c", 0.7)]);
        let top = store.best().unwrap();
        assert!(close(top.score, 0.9));
    }

    #[test]
    fn record_with_suggestion() {
        let mut store = ReflectionStore::new();
        let id = store.record_with_suggestion("out", "bad", 0.3, "try being more concise");
        let entry = store.get(id).unwrap();
        assert_eq!(entry.suggestion.as_deref(), Some("try being more concise"));
    }

    #[test]
    fn len_and_empty() {
        let mut store = ReflectionStore::new();
        assert!(store.is_empty());
        store.record("x", "y", 0.5);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn clear() {
        let mut store = store_of(&[("x", "y", 0.5)]);
        store.clear();
        assert!(store.is_empty());
    }

    #[test]
    fn ids_sequential() {
        let mut store = ReflectionStore::new();
        let first = store.record("x", "y", 0.5);
        let second = store.record("p", "q", 0.6);
        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }

    #[test]
    fn missing_id_returns_none() {
        let store = ReflectionStore::new();
        assert!(store.get(99).is_none());
    }

    #[test]
    fn no_suggestion_is_none() {
        let mut store = ReflectionStore::new();
        let id = store.record("x", "y", 0.5);
        let entry = store.get(id).unwrap();
        assert!(entry.suggestion.is_none());
    }

    #[test]
    fn passing_all() {
        let store = store_of(&[("a", "c", 1.0), ("b", "c", 1.0)]);
        assert_eq!(store.passing(0.9).len(), 2);
    }
}