Skip to main content

converge_knowledge/agentic/
reflexion.rs

1//! Reflexion Episodes - Self-Critique Memory
2//!
3//! Implements the Reflexion pattern from "Reflexion: Language Agents with Verbal Reinforcement Learning"
4//! (Shinn et al., 2023). This allows agents to:
5//!
6//! 1. Attempt a task
7//! 2. If failed, generate a critique of what went wrong
8//! 3. Store the critique for future reference
9//! 4. When attempting similar tasks, retrieve past critiques to avoid mistakes
10//!
11//! # Example
12//!
13//! ```rust,no_run
14//! use converge_knowledge::agentic::{ReflexionEpisode, Critique, CritiqueType};
15//!
16//! // Agent tried to implement sorting but made an error
17//! let episode = ReflexionEpisode::new(
18//!     "algorithm_implementation",
19//!     "Implement quicksort",
20//!     "fn quicksort(arr: &mut [i32]) { /* buggy code */ }",
21//!     false, // failed
22//! )
23//! .with_critique(Critique::new(
24//!     CritiqueType::LogicError,
25//!     "Partition function doesn't handle equal elements",
26//!     "Use <= instead of < in comparison",
27//! ));
28//! ```
29
30use chrono::{DateTime, Utc};
31use serde::{Deserialize, Serialize};
32use uuid::Uuid;
33
34/// A reflexion episode capturing a task attempt and self-critique.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ReflexionEpisode {
37    /// Unique identifier.
38    pub id: Uuid,
39
40    /// Type of task attempted.
41    pub task_type: String,
42
43    /// Original goal or instruction.
44    pub goal: String,
45
46    /// What the agent actually did/produced.
47    pub attempt: String,
48
49    /// Whether the attempt succeeded.
50    pub succeeded: bool,
51
52    /// Self-critiques identifying what went wrong.
53    pub critiques: Vec<Critique>,
54
55    /// Retry count (how many times this was attempted).
56    pub retry_count: u32,
57
58    /// When this episode occurred.
59    pub timestamp: DateTime<Utc>,
60
61    /// Embedding of the goal for similarity search.
62    #[serde(skip)]
63    pub goal_embedding: Option<Vec<f32>>,
64}
65
66impl ReflexionEpisode {
67    /// Create a new reflexion episode.
68    pub fn new(
69        task_type: impl Into<String>,
70        goal: impl Into<String>,
71        attempt: impl Into<String>,
72        succeeded: bool,
73    ) -> Self {
74        Self {
75            id: Uuid::new_v4(),
76            task_type: task_type.into(),
77            goal: goal.into(),
78            attempt: attempt.into(),
79            succeeded,
80            critiques: Vec::new(),
81            retry_count: 0,
82            timestamp: Utc::now(),
83            goal_embedding: None,
84        }
85    }
86
87    /// Add a self-critique.
88    pub fn with_critique(mut self, critique: Critique) -> Self {
89        self.critiques.push(critique);
90        self
91    }
92
93    /// Set retry count.
94    pub fn with_retry_count(mut self, count: u32) -> Self {
95        self.retry_count = count;
96        self
97    }
98
99    /// Set the goal embedding for similarity search.
100    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
101        self.goal_embedding = Some(embedding);
102        self
103    }
104
105    /// Get a summary of all critiques.
106    pub fn critique_summary(&self) -> String {
107        self.critiques
108            .iter()
109            .map(|c| format!("[{}] {}: {}", c.critique_type, c.issue, c.suggestion))
110            .collect::<Vec<_>>()
111            .join("\n")
112    }
113}
114
115/// A self-critique identifying a specific issue.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct Critique {
118    /// Type of issue identified.
119    pub critique_type: CritiqueType,
120
121    /// Description of what went wrong.
122    pub issue: String,
123
124    /// Suggested fix or improvement.
125    pub suggestion: String,
126
127    /// Confidence in this critique (0.0 to 1.0).
128    pub confidence: f32,
129}
130
131impl Critique {
132    /// Create a new critique.
133    pub fn new(
134        critique_type: CritiqueType,
135        issue: impl Into<String>,
136        suggestion: impl Into<String>,
137    ) -> Self {
138        Self {
139            critique_type,
140            issue: issue.into(),
141            suggestion: suggestion.into(),
142            confidence: 1.0,
143        }
144    }
145
146    /// Set confidence level.
147    pub fn with_confidence(mut self, confidence: f32) -> Self {
148        self.confidence = confidence.clamp(0.0, 1.0);
149        self
150    }
151}
152
153/// Types of critiques that can be identified.
154#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
155pub enum CritiqueType {
156    /// Logical error in reasoning or code.
157    LogicError,
158
159    /// Missing step in the process.
160    MissingStep,
161
162    /// Syntax or formatting error.
163    SyntaxError,
164
165    /// Design or architectural flaw.
166    DesignFlaw,
167
168    /// Edge case not handled.
169    EdgeCase,
170
171    /// Performance issue.
172    Performance,
173
174    /// Security vulnerability.
175    Security,
176
177    /// Misunderstood requirements.
178    Misunderstanding,
179
180    /// Wrong tool or approach used.
181    WrongApproach,
182
183    /// Other issue type.
184    Other,
185}
186
187impl std::fmt::Display for CritiqueType {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            CritiqueType::LogicError => write!(f, "LogicError"),
191            CritiqueType::MissingStep => write!(f, "MissingStep"),
192            CritiqueType::SyntaxError => write!(f, "SyntaxError"),
193            CritiqueType::DesignFlaw => write!(f, "DesignFlaw"),
194            CritiqueType::EdgeCase => write!(f, "EdgeCase"),
195            CritiqueType::Performance => write!(f, "Performance"),
196            CritiqueType::Security => write!(f, "Security"),
197            CritiqueType::Misunderstanding => write!(f, "Misunderstanding"),
198            CritiqueType::WrongApproach => write!(f, "WrongApproach"),
199            CritiqueType::Other => write!(f, "Other"),
200        }
201    }
202}
203
204/// Memory store for reflexion episodes.
205pub struct ReflexionMemory {
206    episodes: Vec<ReflexionEpisode>,
207}
208
209impl ReflexionMemory {
210    /// Create a new reflexion memory.
211    pub fn new() -> Self {
212        Self {
213            episodes: Vec::new(),
214        }
215    }
216
217    /// Add an episode to memory.
218    pub fn add_episode(&mut self, episode: ReflexionEpisode) {
219        self.episodes.push(episode);
220    }
221
222    /// Find similar past failures for a given task.
223    ///
224    /// This uses simple keyword matching. In production, use embedding similarity.
225    pub fn find_similar_failures(&self, task: &str, limit: usize) -> Vec<ReflexionEpisode> {
226        let task_lower = task.to_lowercase();
227        let keywords: Vec<&str> = task_lower.split_whitespace().collect();
228
229        let mut scored: Vec<(f32, &ReflexionEpisode)> = self
230            .episodes
231            .iter()
232            .filter(|e| !e.succeeded) // Only failures
233            .map(|e| {
234                let goal_lower = e.goal.to_lowercase();
235                let type_lower = e.task_type.to_lowercase();
236
237                // Simple keyword matching score
238                let score: f32 = keywords
239                    .iter()
240                    .map(|k| {
241                        if goal_lower.contains(k) || type_lower.contains(k) {
242                            1.0
243                        } else {
244                            0.0
245                        }
246                    })
247                    .sum();
248
249                (score, e)
250            })
251            .filter(|(score, _)| *score > 0.0)
252            .collect();
253
254        // Sort by score descending
255        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
256
257        scored
258            .into_iter()
259            .take(limit)
260            .map(|(_, e)| e.clone())
261            .collect()
262    }
263
264    /// Get all episodes of a specific type.
265    pub fn get_by_type(&self, task_type: &str) -> Vec<&ReflexionEpisode> {
266        self.episodes
267            .iter()
268            .filter(|e| e.task_type == task_type)
269            .collect()
270    }
271
272    /// Get recent failures.
273    pub fn recent_failures(&self, limit: usize) -> Vec<&ReflexionEpisode> {
274        let mut failures: Vec<_> = self.episodes.iter().filter(|e| !e.succeeded).collect();
275        failures.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
276        failures.into_iter().take(limit).collect()
277    }
278
279    /// Total episode count.
280    pub fn len(&self) -> usize {
281        self.episodes.len()
282    }
283
284    /// Check if empty.
285    pub fn is_empty(&self) -> bool {
286        self.episodes.is_empty()
287    }
288
289    /// Count of failed episodes.
290    pub fn failure_count(&self) -> usize {
291        self.episodes.iter().filter(|e| !e.succeeded).count()
292    }
293
294    /// Success rate.
295    pub fn success_rate(&self) -> f32 {
296        if self.episodes.is_empty() {
297            return 0.0;
298        }
299        let successes = self.episodes.iter().filter(|e| e.succeeded).count();
300        successes as f32 / self.episodes.len() as f32
301    }
302}
303
304impl Default for ReflexionMemory {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Test: Basic reflexion episode creation.
    ///
    /// What happens:
    /// 1. Create an episode describing a failed attempt
    /// 2. Add critiques explaining what went wrong
    /// 3. The episode captures the full context for future learning, and
    ///    `critique_summary` renders it as `[Type] issue: suggestion`
    #[test]
    fn test_episode_creation() {
        let episode = ReflexionEpisode::new(
            "code_review",
            "Review pull request #123",
            "Approved without noticing the SQL injection",
            false,
        )
        .with_critique(Critique::new(
            CritiqueType::Security,
            "Missed SQL injection vulnerability in user input",
            "Always check for unsanitized inputs in database queries",
        ))
        .with_retry_count(2);

        assert_eq!(episode.task_type, "code_review");
        assert!(!episode.succeeded);
        assert_eq!(episode.critiques.len(), 1);
        assert_eq!(episode.retry_count, 2);
        assert_eq!(
            episode.critique_summary(),
            "[Security] Missed SQL injection vulnerability in user input: \
             Always check for unsanitized inputs in database queries"
        );
    }

    /// Test: Finding similar failures.
    ///
    /// What happens:
    /// 1. Store multiple failed episodes (and one success)
    /// 2. Search for episodes similar to a new task
    /// 3. Return only failures that match keywords, best match first
    /// 4. Agent can learn from past mistakes
    #[test]
    fn test_find_similar_failures() {
        let mut memory = ReflexionMemory::new();

        // Add some failures
        memory.add_episode(
            ReflexionEpisode::new(
                "sql_query",
                "Write SQL query for user search",
                "SELECT * FROM users WHERE name = '{input}'",
                false,
            )
            .with_critique(Critique::new(
                CritiqueType::Security,
                "SQL injection possible",
                "Use parameterized queries",
            )),
        );

        memory.add_episode(
            ReflexionEpisode::new(
                "api_design",
                "Design REST API for payments",
                "POST /pay without authentication",
                false,
            )
            .with_critique(Critique::new(
                CritiqueType::Security,
                "No auth on sensitive endpoint",
                "Add authentication middleware",
            )),
        );

        // A success must never be returned as a past failure.
        memory.add_episode(ReflexionEpisode::new(
            "sql_query",
            "Write SQL query for orders",
            "SELECT * FROM orders WHERE id = $1",
            true,
        ));

        // Search for SQL-related failures: the SQL episode matches "sql", "query"
        // and "for", the API episode only "for", so SQL must be the top result.
        let similar = memory.find_similar_failures("SQL query for orders", 5);
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].task_type, "sql_query");
        assert_eq!(similar[1].task_type, "api_design");
        assert!(similar.iter().all(|e| !e.succeeded));

        // `limit` caps the number of results.
        assert_eq!(memory.find_similar_failures("SQL query for orders", 1).len(), 1);

        // Search for API-related failures: only the API episode matches.
        let similar = memory.find_similar_failures("REST API endpoint", 5);
        assert_eq!(similar.len(), 1);
        assert_eq!(similar[0].task_type, "api_design");

        // No keyword overlap at all yields no results.
        assert!(memory.find_similar_failures("zzz", 5).is_empty());
    }

    /// Test: Success rate tracking.
    ///
    /// What happens:
    /// 1. Track both successes and failures
    /// 2. Calculate success rate
    /// 3. Agent can see improvement over time
    #[test]
    fn test_success_rate() {
        let mut memory = ReflexionMemory::new();

        // Add 3 failures
        for _ in 0..3 {
            memory.add_episode(ReflexionEpisode::new("test", "goal", "attempt", false));
        }

        // Add 1 success
        memory.add_episode(ReflexionEpisode::new("test", "goal", "attempt", true));

        assert_eq!(memory.len(), 4);
        assert_eq!(memory.failure_count(), 3);
        assert!((memory.success_rate() - 0.25).abs() < 0.01);
    }

    /// Test: Critique confidence stays within `0.0..=1.0`.
    #[test]
    fn test_critique_confidence_is_clamped() {
        let base = || Critique::new(CritiqueType::EdgeCase, "issue", "fix");
        assert_eq!(base().confidence, 1.0);
        assert_eq!(base().with_confidence(0.25).confidence, 0.25);
        assert_eq!(base().with_confidence(1.5).confidence, 1.0);
        assert_eq!(base().with_confidence(-0.5).confidence, 0.0);
    }

    /// Test: An empty memory answers every query without panicking.
    #[test]
    fn test_empty_memory() {
        let memory = ReflexionMemory::default();
        assert!(memory.is_empty());
        assert_eq!(memory.len(), 0);
        assert_eq!(memory.failure_count(), 0);
        assert_eq!(memory.success_rate(), 0.0);
        assert!(memory.find_similar_failures("anything", 5).is_empty());
        assert!(memory.recent_failures(5).is_empty());
        assert!(memory.get_by_type("test").is_empty());
    }
}