use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// One recorded attempt at a task, with its outcome and any critiques of it.
///
/// Derives serde `Serialize`/`Deserialize`; `goal_embedding` is marked
/// `#[serde(skip)]` and therefore takes no part in (de)serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
/// Identifier of the episode; `ReflexionEpisode::new` fills it with `Uuid::new_v4()`.
pub id: Uuid,
/// Task category label. `ReflexionMemory::get_by_type` compares it exactly, and
/// `ReflexionMemory::find_similar_failures` matches it case-insensitively against keywords.
pub task_type: String,
/// Goal text; `ReflexionMemory::find_similar_failures` matches it case-insensitively
/// against keywords.
pub goal: String,
/// Text of the attempt (presumably what was tried or produced — confirm with callers).
pub attempt: String,
/// Outcome flag; episodes with `false` here are the ones the failure queries return.
pub succeeded: bool,
/// Critiques attached to this episode, in the order they were added.
pub critiques: Vec<Critique>,
/// Retry counter; `new` starts it at 0 and `with_retry_count` overwrites it.
pub retry_count: u32,
/// UTC timestamp; `new` sets it to `Utc::now()`. `ReflexionMemory::recent_failures`
/// orders by it.
pub timestamp: DateTime<Utc>,
/// Optional embedding vector for the goal. Skipped by serde; `new` leaves it `None`
/// and `with_embedding` sets it.
#[serde(skip)]
pub goal_embedding: Option<Vec<f32>>,
}
impl ReflexionEpisode {
pub fn new(
task_type: impl Into<String>,
goal: impl Into<String>,
attempt: impl Into<String>,
succeeded: bool,
) -> Self {
Self {
id: Uuid::new_v4(),
task_type: task_type.into(),
goal: goal.into(),
attempt: attempt.into(),
succeeded,
critiques: Vec::new(),
retry_count: 0,
timestamp: Utc::now(),
goal_embedding: None,
}
}
pub fn with_critique(mut self, critique: Critique) -> Self {
self.critiques.push(critique);
self
}
pub fn with_retry_count(mut self, count: u32) -> Self {
self.retry_count = count;
self
}
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.goal_embedding = Some(embedding);
self
}
pub fn critique_summary(&self) -> String {
self.critiques
.iter()
.map(|c| format!("[{}] {}: {}", c.critique_type, c.issue, c.suggestion))
.collect::<Vec<_>>()
.join("\n")
}
}
/// A single critique attached to a `ReflexionEpisode`: a categorized issue plus a
/// suggestion, with a confidence value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Critique {
/// Category of the critique.
pub critique_type: CritiqueType,
/// Description of the issue.
pub issue: String,
/// Suggested remedy; rendered after the issue by `ReflexionEpisode::critique_summary`.
pub suggestion: String,
/// Confidence value. `Critique::new` sets it to 1.0 and `with_confidence` clamps
/// its argument to `0.0..=1.0`; direct field writes are not range-checked.
pub confidence: f32,
}
impl Critique {
pub fn new(
critique_type: CritiqueType,
issue: impl Into<String>,
suggestion: impl Into<String>,
) -> Self {
Self {
critique_type,
issue: issue.into(),
suggestion: suggestion.into(),
confidence: 1.0,
}
}
pub fn with_confidence(mut self, confidence: f32) -> Self {
self.confidence = confidence.clamp(0.0, 1.0);
self
}
}
/// Category tag carried by a `Critique`.
///
/// The type is `Copy` and comparable with `==`. Its `Display` implementation in
/// this file prints each variant as its identifier (for example `LogicError`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CritiqueType {
LogicError,
MissingStep,
SyntaxError,
DesignFlaw,
EdgeCase,
Performance,
Security,
Misunderstanding,
WrongApproach,
// Catch-all variant, judging by its name — confirm intended usage with callers.
Other,
}
impl std::fmt::Display for CritiqueType {
    /// Writes the variant's identifier (e.g. `LogicError`) to the formatter.
    ///
    /// The text is written verbatim; width and alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CritiqueType::LogicError => "LogicError",
            CritiqueType::MissingStep => "MissingStep",
            CritiqueType::SyntaxError => "SyntaxError",
            CritiqueType::DesignFlaw => "DesignFlaw",
            CritiqueType::EdgeCase => "EdgeCase",
            CritiqueType::Performance => "Performance",
            CritiqueType::Security => "Security",
            CritiqueType::Misunderstanding => "Misunderstanding",
            CritiqueType::WrongApproach => "WrongApproach",
            CritiqueType::Other => "Other",
        };
        f.write_str(label)
    }
}
/// In-memory collection of [`ReflexionEpisode`]s.
///
/// Episodes are kept in insertion order. `Debug` and `Clone` are derived so the
/// store can be logged, snapshotted, and embedded in other `#[derive(Debug)]` types.
#[derive(Debug, Clone)]
pub struct ReflexionMemory {
    /// Stored episodes, oldest insertion first.
    episodes: Vec<ReflexionEpisode>,
}
impl ReflexionMemory {
pub fn new() -> Self {
Self {
episodes: Vec::new(),
}
}
pub fn add_episode(&mut self, episode: ReflexionEpisode) {
self.episodes.push(episode);
}
pub fn find_similar_failures(&self, task: &str, limit: usize) -> Vec<ReflexionEpisode> {
let task_lower = task.to_lowercase();
let keywords: Vec<&str> = task_lower.split_whitespace().collect();
let mut scored: Vec<(f32, &ReflexionEpisode)> = self
.episodes
.iter()
.filter(|e| !e.succeeded) .map(|e| {
let goal_lower = e.goal.to_lowercase();
let type_lower = e.task_type.to_lowercase();
let score: f32 = keywords
.iter()
.map(|k| {
if goal_lower.contains(k) || type_lower.contains(k) {
1.0
} else {
0.0
}
})
.sum();
(score, e)
})
.filter(|(score, _)| *score > 0.0)
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
scored
.into_iter()
.take(limit)
.map(|(_, e)| e.clone())
.collect()
}
pub fn get_by_type(&self, task_type: &str) -> Vec<&ReflexionEpisode> {
self.episodes
.iter()
.filter(|e| e.task_type == task_type)
.collect()
}
pub fn recent_failures(&self, limit: usize) -> Vec<&ReflexionEpisode> {
let mut failures: Vec<_> = self.episodes.iter().filter(|e| !e.succeeded).collect();
failures.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
failures.into_iter().take(limit).collect()
}
pub fn len(&self) -> usize {
self.episodes.len()
}
pub fn is_empty(&self) -> bool {
self.episodes.is_empty()
}
pub fn failure_count(&self) -> usize {
self.episodes.iter().filter(|e| !e.succeeded).count()
}
pub fn success_rate(&self) -> f32 {
if self.episodes.is_empty() {
return 0.0;
}
let successes = self.episodes.iter().filter(|e| e.succeeded).count();
successes as f32 / self.episodes.len() as f32
}
}
impl Default for ReflexionMemory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods populate type, outcome, critiques and retry count.
    #[test]
    fn test_episode_creation() {
        let critique = Critique::new(
            CritiqueType::Security,
            "Missed SQL injection vulnerability in user input",
            "Always check for unsanitized inputs in database queries",
        );

        let episode = ReflexionEpisode::new(
            "code_review",
            "Review pull request #123",
            "Approved without noticing the SQL injection",
            false,
        )
        .with_retry_count(2)
        .with_critique(critique);

        assert_eq!(episode.task_type, "code_review");
        assert!(!episode.succeeded);
        assert_eq!(episode.critiques.len(), 1);
        assert_eq!(episode.retry_count, 2);
    }

    /// Keyword lookup finds the failed episode whose goal or type matches the task.
    #[test]
    fn test_find_similar_failures() {
        let sql_failure = ReflexionEpisode::new(
            "sql_query",
            "Write SQL query for user search",
            "SELECT * FROM users WHERE name = '{input}'",
            false,
        )
        .with_critique(Critique::new(
            CritiqueType::Security,
            "SQL injection possible",
            "Use parameterized queries",
        ));

        let api_failure = ReflexionEpisode::new(
            "api_design",
            "Design REST API for payments",
            "POST /pay without authentication",
            false,
        )
        .with_critique(Critique::new(
            CritiqueType::Security,
            "No auth on sensitive endpoint",
            "Add authentication middleware",
        ));

        let mut memory = ReflexionMemory::new();
        memory.add_episode(sql_failure);
        memory.add_episode(api_failure);

        let sql_hits = memory.find_similar_failures("SQL query for orders", 5);
        assert!(!sql_hits.is_empty());
        assert!(sql_hits.iter().any(|e| e.task_type == "sql_query"));

        let api_hits = memory.find_similar_failures("REST API endpoint", 5);
        assert!(!api_hits.is_empty());
        assert!(api_hits.iter().any(|e| e.task_type == "api_design"));
    }

    /// Three failures and one success give a 25% success rate.
    #[test]
    fn test_success_rate() {
        let mut memory = ReflexionMemory::new();
        let outcomes = [false, false, false, true];
        for succeeded in outcomes.iter().copied() {
            memory.add_episode(ReflexionEpisode::new("test", "goal", "attempt", succeeded));
        }

        assert_eq!(memory.len(), 4);
        assert_eq!(memory.failure_count(), 3);
        assert!((memory.success_rate() - 0.25).abs() < 0.01);
    }
}