use super::confidence::{ConfidenceChecker, ConfidenceConfig, ConfidenceLevel, WeakPoint};
use super::error_recovery::{ErrorPatternLearner, ErrorPatternLearnerConfig, RecoverySuggestion};
use super::perspectives::{CritiqueResult, Perspective, UnifiedCritique};
use crate::claude_flow::{AgentType, Verdict};
use crate::error::{Result, RuvLLMError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Tuning parameters for a `ReflectiveAgent`'s reflect-and-retry loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionConfig {
/// Hard cap on execution attempts before reporting failure.
pub max_reflection_attempts: u32,
/// Time budget for a reflection pass, in milliseconds (NOTE(review): not
/// enforced within this module — confirm where it is applied).
pub reflection_timeout_ms: u64,
/// When true, successful recoveries are fed back to the error-pattern learner.
pub learn_from_recovery: bool,
/// Minimum acceptable quality score (used by the trajectory strategy).
pub min_quality_threshold: f32,
/// Whether execution trajectories should be recorded (NOTE(review): not read
/// in this module — confirm downstream usage).
pub record_trajectories: bool,
/// Configuration forwarded to the `ConfidenceChecker`.
pub confidence_config: ConfidenceConfig,
/// Configuration forwarded to the `ErrorPatternLearner`.
pub error_learner_config: ErrorPatternLearnerConfig,
}
impl Default for ReflectionConfig {
fn default() -> Self {
Self {
max_reflection_attempts: 3,
reflection_timeout_ms: 30000, learn_from_recovery: true,
min_quality_threshold: 0.7,
record_trajectories: true,
confidence_config: ConfidenceConfig::default(),
error_learner_config: ErrorPatternLearnerConfig::default(),
}
}
}
/// Settings for the plain retry strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Maximum number of retries (NOTE(review): the execution loop in this
/// module is bounded by `ReflectionConfig::max_reflection_attempts` instead;
/// confirm where this field is consumed).
pub max_retries: u32,
/// Multiplier applied to the delay after each retry (not read in this module).
pub backoff_multiplier: f32,
/// Delay before the first retry, in milliseconds (not read in this module).
pub initial_delay_ms: u64,
/// When true, prior-attempt counts are added to the reflection's insights.
pub include_error_context: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
backoff_multiplier: 2.0,
initial_delay_ms: 100,
include_error_context: true,
}
}
}
/// How a `ReflectiveAgent` decides when (and how) to reflect on an attempt.
///
/// Serialization note: `checker` and `perspectives` hold live objects and are
/// `#[serde(skip)]`ped, so they come back empty after deserialization and
/// must be re-attached by the caller.
#[derive(Clone, Serialize, Deserialize)]
pub enum ReflectionStrategy {
/// Retry on hard errors only; never reflect on successful output.
Retry(RetryConfig),
/// Reflect when estimated confidence is below `threshold`, at most
/// `revision_budget` times.
IfOrElse {
#[serde(skip)]
checker: Option<Arc<ConfidenceChecker>>,
threshold: f32,
revision_budget: u32,
},
/// Reflect when the fraction of passing perspective critiques falls
/// below `min_agreement`.
MultiPerspective {
#[serde(skip)]
perspectives: Vec<Arc<dyn Perspective + Send + Sync>>,
min_agreement: f32,
},
/// Reflect when mean quality over the last `window_size` attempts is low.
/// (`use_sona` is not consumed in this module — confirm downstream.)
TrajectoryReflection {
window_size: usize,
use_sona: bool,
},
}
impl std::fmt::Debug for ReflectionStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Retry(config) => f.debug_tuple("Retry").field(config).finish(),
Self::IfOrElse {
threshold,
revision_budget,
..
} => f
.debug_struct("IfOrElse")
.field("threshold", threshold)
.field("revision_budget", revision_budget)
.field("checker", &"<ConfidenceChecker>")
.finish(),
Self::MultiPerspective {
min_agreement,
perspectives,
} => f
.debug_struct("MultiPerspective")
.field("min_agreement", min_agreement)
.field("perspectives_count", &perspectives.len())
.finish(),
Self::TrajectoryReflection {
window_size,
use_sona,
} => f
.debug_struct("TrajectoryReflection")
.field("window_size", window_size)
.field("use_sona", use_sona)
.finish(),
}
}
}
impl Default for ReflectionStrategy {
fn default() -> Self {
Self::Retry(RetryConfig::default())
}
}
/// Everything an agent needs to run one task, including accumulated attempt
/// history and reflection context from earlier tries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
/// Human-readable task description.
pub task: String,
/// Which kind of agent should run this task.
pub agent_type: AgentType,
/// Input/prompt handed to the agent; grows with reflection context on retries.
pub input: String,
/// Attempts made so far (most recent last).
pub previous_attempts: Vec<PreviousAttempt>,
/// Free-form string key/value annotations.
pub metadata: HashMap<String, String>,
/// Optional session this execution belongs to.
pub session_id: Option<String>,
/// Optional parent task identifier (for sub-tasks).
pub parent_task: Option<String>,
}
impl ExecutionContext {
pub fn new(task: impl Into<String>, agent_type: AgentType, input: impl Into<String>) -> Self {
Self {
task: task.into(),
agent_type,
input: input.into(),
previous_attempts: Vec::new(),
metadata: HashMap::new(),
session_id: None,
parent_task: None,
}
}
pub fn with_previous_attempt(mut self, attempt: PreviousAttempt) -> Self {
self.previous_attempts.push(attempt);
self
}
pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
self.session_id = Some(session_id.into());
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
/// Record of a single (possibly failed) execution attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreviousAttempt {
/// 1-based attempt index.
pub attempt_number: u32,
/// Raw output produced (empty when the attempt errored).
pub output: String,
/// Error message when the attempt failed.
pub error: Option<String>,
/// Quality/confidence score assigned to this attempt, if any.
pub quality_score: Option<f32>,
/// Wall-clock duration of the attempt, in milliseconds.
pub duration_ms: u64,
/// Reflection generated after this attempt, if any.
pub reflection: Option<Reflection>,
}
/// Output of a reflection pass: what was observed and what to try next.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reflection {
/// Name of the strategy that produced this reflection (e.g. "retry").
pub strategy: String,
/// Short human-readable summary of why the reflection was generated.
pub context: String,
/// Observations derived from the attempt.
pub insights: Vec<String>,
/// Concrete actions suggested for the next attempt.
pub suggestions: Vec<String>,
/// Confidence in the assessed output, in [0, 1] (constructor default 0.5).
pub confidence: f32,
/// Weak spots identified by the confidence checker.
pub weak_points: Vec<WeakPoint>,
/// Per-perspective critiques (multi-perspective strategy).
pub critique_results: Vec<CritiqueResult>,
/// Time spent generating this reflection, in milliseconds.
pub reflection_time_ms: u64,
}
impl Reflection {
pub fn new(strategy: impl Into<String>, context: impl Into<String>) -> Self {
Self {
strategy: strategy.into(),
context: context.into(),
insights: Vec::new(),
suggestions: Vec::new(),
confidence: 0.5,
weak_points: Vec::new(),
critique_results: Vec::new(),
reflection_time_ms: 0,
}
}
pub fn with_insight(mut self, insight: impl Into<String>) -> Self {
self.insights.push(insight.into());
self
}
pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
self.suggestions.push(suggestion.into());
self
}
pub fn with_confidence(mut self, confidence: f32) -> Self {
self.confidence = confidence.clamp(0.0, 1.0);
self
}
pub fn with_weak_points(mut self, weak_points: Vec<WeakPoint>) -> Self {
self.weak_points = weak_points;
self
}
pub fn with_critiques(mut self, critiques: Vec<CritiqueResult>) -> Self {
self.critique_results = critiques;
self
}
}
/// Final outcome of `execute_with_reflection`, including full attempt history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
/// The accepted output (empty on failure).
pub output: String,
/// True when success required at least one reflection-driven retry.
pub recovered_via_reflection: bool,
/// Total attempts made.
pub attempts: u32,
/// Wall-clock duration of the whole loop, in milliseconds.
pub total_duration_ms: u64,
/// Quality of the final output in [0, 1] (1.0 clean success, 0.0 failure).
pub quality_score: f32,
/// Structured verdict (success / recovered / failure).
pub verdict: Verdict,
/// The final reflection, when recovery was involved.
pub reflection: Option<Reflection>,
/// Every attempt made, in order.
pub attempt_history: Vec<PreviousAttempt>,
/// Recovery suggestions (confidence > 0.5) that were applied.
pub applied_suggestions: Vec<RecoverySuggestion>,
}
impl ExecutionResult {
pub fn success(output: impl Into<String>, attempts: u32, duration_ms: u64) -> Self {
Self {
output: output.into(),
recovered_via_reflection: false,
attempts,
total_duration_ms: duration_ms,
quality_score: 1.0,
verdict: Verdict::Success {
reason: "Task completed successfully".to_string(),
},
reflection: None,
attempt_history: Vec::new(),
applied_suggestions: Vec::new(),
}
}
pub fn recovered(
output: impl Into<String>,
original_error: impl Into<String>,
recovery_strategy: impl Into<String>,
attempts: u32,
duration_ms: u64,
reflection: Reflection,
) -> Self {
Self {
output: output.into(),
recovered_via_reflection: true,
attempts,
total_duration_ms: duration_ms,
quality_score: reflection.confidence,
verdict: Verdict::RecoveredViaReflection {
original_error: original_error.into(),
recovery_strategy: recovery_strategy.into(),
attempts,
},
reflection: Some(reflection),
attempt_history: Vec::new(),
applied_suggestions: Vec::new(),
}
}
pub fn failure(error: impl Into<String>, attempts: u32, duration_ms: u64) -> Self {
Self {
output: String::new(),
recovered_via_reflection: false,
attempts,
total_duration_ms: duration_ms,
quality_score: 0.0,
verdict: Verdict::Failure {
reason: error.into(),
error_code: None,
},
reflection: None,
attempt_history: Vec::new(),
applied_suggestions: Vec::new(),
}
}
pub fn with_history(mut self, history: Vec<PreviousAttempt>) -> Self {
self.attempt_history = history;
self
}
}
/// Minimal interface a concrete agent must provide to participate in
/// reflective execution. Implementations must be thread-safe (`Send + Sync`).
pub trait BaseAgent: Send + Sync {
    /// Runs the agent once against `context`, returning its raw output.
    fn execute(&self, context: &ExecutionContext) -> Result<String>;
    /// The kind of agent this is.
    fn agent_type(&self) -> AgentType;
    /// Cheap heuristic confidence estimate for `output`, in [0, 1].
    ///
    /// Default scoring: 0.3 for non-empty output, 0.3 for apparent structure
    /// (contains a newline or is longer than 100 chars), 0.4 when the output
    /// does not mention "error"/"failed" (case-insensitive). Override for a
    /// smarter estimate.
    fn estimate_confidence(&self, output: &str, _context: &ExecutionContext) -> f32 {
        let has_content = !output.is_empty();
        let has_structure = output.contains('\n') || output.len() > 100;
        let output_lower = output.to_lowercase();
        let not_error = !output_lower.contains("error") && !output_lower.contains("failed");
        // Weighted sum of the boolean signals; returned directly instead of
        // binding-then-returning (clippy::let_and_return).
        (has_content as u8 as f32 * 0.3)
            + (has_structure as u8 as f32 * 0.3)
            + (not_error as u8 as f32 * 0.4)
    }
}
/// Decorator that adds reflect-and-retry behavior around any `BaseAgent`.
pub struct ReflectiveAgent<A: BaseAgent> {
// The wrapped agent that actually executes tasks.
base_agent: A,
// Strategy deciding when/how to reflect.
strategy: ReflectionStrategy,
// Loop tuning parameters.
config: ReflectionConfig,
// Learns error -> recovery patterns across executions.
error_learner: ErrorPatternLearner,
// Identifies weak points in produced output (if_or_else strategy).
confidence_checker: ConfidenceChecker,
// Running execution statistics.
stats: ReflectiveAgentStats,
}
/// Aggregate counters describing a `ReflectiveAgent`'s execution history.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReflectiveAgentStats {
/// Executions that reached a terminal result (success or failure).
pub total_executions: u64,
/// Executions that succeeded on the first attempt.
pub first_try_successes: u64,
/// Executions that succeeded only after reflection/retry.
pub recovered_count: u64,
/// Executions that exhausted the attempt budget.
pub failed_count: u64,
/// Cumulative time spent generating reflections, in milliseconds.
pub total_reflection_time_ms: u64,
/// Running mean of attempts per successful execution.
pub avg_attempts: f32,
/// recovered / (recovered + failed): share of troubled runs that recovered.
pub recovery_rate: f32,
}
impl<A: BaseAgent> ReflectiveAgent<A> {
pub fn new(base_agent: A, strategy: ReflectionStrategy) -> Self {
let config = ReflectionConfig::default();
let error_learner = ErrorPatternLearner::new(config.error_learner_config.clone());
let confidence_checker = ConfidenceChecker::new(config.confidence_config.clone());
Self {
base_agent,
strategy,
config,
error_learner,
confidence_checker,
stats: ReflectiveAgentStats::default(),
}
}
pub fn with_config(
base_agent: A,
strategy: ReflectionStrategy,
config: ReflectionConfig,
) -> Self {
let error_learner = ErrorPatternLearner::new(config.error_learner_config.clone());
let confidence_checker = ConfidenceChecker::new(config.confidence_config.clone());
Self {
base_agent,
strategy,
config,
error_learner,
confidence_checker,
stats: ReflectiveAgentStats::default(),
}
}
pub fn execute_with_reflection(
&mut self,
context: &ExecutionContext,
) -> Result<ExecutionResult> {
let start = Instant::now();
let mut attempts = 0u32;
let mut attempt_history = Vec::new();
let mut last_error: Option<String> = None;
let mut last_reflection: Option<Reflection> = None;
let mut applied_suggestions = Vec::new();
let mut current_context = context.clone();
loop {
attempts += 1;
let attempt_start = Instant::now();
if attempts > self.config.max_reflection_attempts {
self.stats.failed_count += 1;
self.stats.total_executions += 1;
return Ok(ExecutionResult::failure(
last_error.unwrap_or_else(|| "Max reflection attempts exceeded".to_string()),
attempts - 1,
start.elapsed().as_millis() as u64,
)
.with_history(attempt_history));
}
let result = self.base_agent.execute(¤t_context);
match result {
Ok(output) => {
let duration_ms = attempt_start.elapsed().as_millis() as u64;
let should_reflect = self.should_reflect(&output, ¤t_context);
if !should_reflect {
self.stats.total_executions += 1;
if attempts == 1 {
self.stats.first_try_successes += 1;
} else {
self.stats.recovered_count += 1;
}
self.update_avg_attempts(attempts);
if self.config.learn_from_recovery && last_error.is_some() {
if let Some(ref error) = last_error {
self.error_learner.learn_from_recovery(
error,
&output,
last_reflection.as_ref(),
);
}
}
let mut exec_result = if attempts > 1 && last_error.is_some() {
ExecutionResult::recovered(
output,
last_error.unwrap(),
self.strategy_name(),
attempts,
start.elapsed().as_millis() as u64,
last_reflection.unwrap_or_else(|| {
Reflection::new("retry", "Recovered on retry")
}),
)
} else {
ExecutionResult::success(
output,
attempts,
start.elapsed().as_millis() as u64,
)
};
exec_result.attempt_history = attempt_history;
exec_result.applied_suggestions = applied_suggestions;
return Ok(exec_result);
}
let reflection_start = Instant::now();
let reflection = self.generate_reflection(&output, ¤t_context, None)?;
self.stats.total_reflection_time_ms +=
reflection_start.elapsed().as_millis() as u64;
attempt_history.push(PreviousAttempt {
attempt_number: attempts,
output: output.clone(),
error: None,
quality_score: Some(reflection.confidence),
duration_ms,
reflection: Some(reflection.clone()),
});
current_context =
self.retry_with_context(¤t_context, Some(&output), None, &reflection);
last_reflection = Some(reflection);
}
Err(e) => {
let duration_ms = attempt_start.elapsed().as_millis() as u64;
let error_msg = e.to_string();
let suggestions = self.error_learner.suggest_recovery(&error_msg);
let reflection_start = Instant::now();
let reflection =
self.generate_reflection("", ¤t_context, Some(&error_msg))?;
self.stats.total_reflection_time_ms +=
reflection_start.elapsed().as_millis() as u64;
attempt_history.push(PreviousAttempt {
attempt_number: attempts,
output: String::new(),
error: Some(error_msg.clone()),
quality_score: Some(0.0),
duration_ms,
reflection: Some(reflection.clone()),
});
for suggestion in &suggestions {
if suggestion.confidence > 0.5 {
applied_suggestions.push(suggestion.clone());
}
}
current_context = self.retry_with_context(
¤t_context,
None,
Some(&error_msg),
&reflection,
);
last_error = Some(error_msg);
last_reflection = Some(reflection);
}
}
}
}
fn should_reflect(&self, output: &str, context: &ExecutionContext) -> bool {
match &self.strategy {
ReflectionStrategy::Retry(_) => {
false
}
ReflectionStrategy::IfOrElse {
threshold,
revision_budget,
..
} => {
let confidence = self.base_agent.estimate_confidence(output, context);
let attempts = context.previous_attempts.len() as u32;
confidence < *threshold && attempts < *revision_budget
}
ReflectionStrategy::MultiPerspective {
min_agreement,
perspectives,
} => {
if perspectives.is_empty() {
return false;
}
let mut agreements = 0;
for perspective in perspectives {
let critique = perspective.critique(output, context);
if critique.passed {
agreements += 1;
}
}
let agreement_ratio = agreements as f32 / perspectives.len() as f32;
agreement_ratio < *min_agreement
}
ReflectionStrategy::TrajectoryReflection { window_size, .. } => {
let recent_quality: f32 = context
.previous_attempts
.iter()
.rev()
.take(*window_size)
.filter_map(|a| a.quality_score)
.sum::<f32>()
/ context.previous_attempts.len().min(*window_size).max(1) as f32;
recent_quality < self.config.min_quality_threshold
}
}
}
pub fn generate_reflection(
&self,
output: &str,
context: &ExecutionContext,
error: Option<&str>,
) -> Result<Reflection> {
let start = Instant::now();
let mut reflection = match &self.strategy {
ReflectionStrategy::Retry(config) => {
let mut r = Reflection::new("retry", "Retry with accumulated context");
if let Some(e) = error {
r.insights.push(format!("Error encountered: {}", e));
r.suggestions
.push("Review error and adjust approach".to_string());
}
if config.include_error_context && !context.previous_attempts.is_empty() {
r.insights.push(format!(
"Previous {} attempts failed",
context.previous_attempts.len()
));
}
r
}
ReflectionStrategy::IfOrElse { threshold, .. } => {
let confidence = self.base_agent.estimate_confidence(output, context);
let weak_points = self
.confidence_checker
.identify_weak_points(output, context);
let mut r = Reflection::new(
"if_or_else",
format!(
"Confidence {} ({:.2}) threshold {:.2}",
if confidence < *threshold {
"below"
} else {
"meets"
},
confidence,
threshold
),
);
r.confidence = confidence;
r.weak_points = weak_points.clone();
for wp in &weak_points {
r.insights.push(format!(
"{}: {} (severity: {:.2})",
wp.location, wp.description, wp.severity
));
r.suggestions.push(wp.suggestion.clone());
}
r
}
ReflectionStrategy::MultiPerspective { perspectives, .. } => {
let mut r = Reflection::new("multi_perspective", "Multi-angle critique");
let mut critiques = Vec::new();
for perspective in perspectives {
let critique = perspective.critique(output, context);
r.insights.push(format!(
"[{}] {}: {}",
critique.perspective_name,
if critique.passed { "PASS" } else { "FAIL" },
critique.summary
));
for issue in &critique.issues {
r.suggestions.push(format!(
"[{}] {}",
critique.perspective_name, issue.suggestion
));
}
critiques.push(critique);
}
let avg_score: f32 =
critiques.iter().map(|c| c.score).sum::<f32>() / critiques.len().max(1) as f32;
r.confidence = avg_score;
r.critique_results = critiques;
r
}
ReflectionStrategy::TrajectoryReflection { window_size, .. } => {
let mut r =
Reflection::new("trajectory", "Trajectory analysis over execution history");
let recent: Vec<_> = context
.previous_attempts
.iter()
.rev()
.take(*window_size)
.collect();
if !recent.is_empty() {
let error_count = recent.iter().filter(|a| a.error.is_some()).count();
if error_count > 0 {
r.insights.push(format!(
"{} errors in last {} attempts",
error_count,
recent.len()
));
}
let qualities: Vec<f32> =
recent.iter().filter_map(|a| a.quality_score).collect();
if qualities.len() >= 2 {
let trend = qualities[0] - qualities[qualities.len() - 1];
if trend > 0.1 {
r.insights.push("Quality improving".to_string());
} else if trend < -0.1 {
r.insights
.push("Quality declining - consider strategy change".to_string());
r.suggestions
.push("Try different approach or break task down".to_string());
}
}
let avg_quality = qualities.iter().sum::<f32>() / qualities.len().max(1) as f32;
r.confidence = avg_quality;
}
r
}
};
reflection.reflection_time_ms = start.elapsed().as_millis() as u64;
Ok(reflection)
}
pub fn retry_with_context(
&self,
original: &ExecutionContext,
previous_output: Option<&str>,
error: Option<&str>,
reflection: &Reflection,
) -> ExecutionContext {
let mut context = original.clone();
let attempt_number = context.previous_attempts.len() as u32 + 1;
context.previous_attempts.push(PreviousAttempt {
attempt_number,
output: previous_output.unwrap_or("").to_string(),
error: error.map(String::from),
quality_score: Some(reflection.confidence),
duration_ms: 0,
reflection: Some(reflection.clone()),
});
let mut augmented_input = context.input.clone();
augmented_input.push_str("\n\n--- Reflection Context ---\n");
if let Some(e) = error {
augmented_input.push_str(&format!("Previous error: {}\n", e));
}
if !reflection.insights.is_empty() {
augmented_input.push_str("Insights:\n");
for insight in &reflection.insights {
augmented_input.push_str(&format!("- {}\n", insight));
}
}
if !reflection.suggestions.is_empty() {
augmented_input.push_str("Suggestions:\n");
for suggestion in &reflection.suggestions {
augmented_input.push_str(&format!("- {}\n", suggestion));
}
}
context.input = augmented_input;
context
}
fn strategy_name(&self) -> String {
match &self.strategy {
ReflectionStrategy::Retry(_) => "retry".to_string(),
ReflectionStrategy::IfOrElse { .. } => "if_or_else".to_string(),
ReflectionStrategy::MultiPerspective { .. } => "multi_perspective".to_string(),
ReflectionStrategy::TrajectoryReflection { .. } => "trajectory".to_string(),
}
}
fn update_avg_attempts(&mut self, attempts: u32) {
let n = self.stats.total_executions as f32;
self.stats.avg_attempts =
(self.stats.avg_attempts * (n - 1.0) + attempts as f32) / n.max(1.0);
let total =
self.stats.first_try_successes + self.stats.recovered_count + self.stats.failed_count;
if total > 0 {
self.stats.recovery_rate = self.stats.recovered_count as f32
/ (self.stats.recovered_count + self.stats.failed_count).max(1) as f32;
}
}
pub fn stats(&self) -> &ReflectiveAgentStats {
&self.stats
}
pub fn error_learner(&self) -> &ErrorPatternLearner {
&self.error_learner
}
pub fn error_learner_mut(&mut self) -> &mut ErrorPatternLearner {
&mut self.error_learner
}
pub fn confidence_checker(&self) -> &ConfidenceChecker {
&self.confidence_checker
}
pub fn base_agent(&self) -> &A {
&self.base_agent
}
pub fn base_agent_mut(&mut self) -> &mut A {
&mut self.base_agent
}
pub fn set_strategy(&mut self, strategy: ReflectionStrategy) {
self.strategy = strategy;
}
pub fn strategy(&self) -> &ReflectionStrategy {
&self.strategy
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Agent that fails a fixed number of times before succeeding.
    struct TestAgent {
        agent_type: AgentType,
        fail_count: AtomicU32,
        max_fails: u32,
    }

    impl TestAgent {
        fn new(max_fails: u32) -> Self {
            TestAgent {
                agent_type: AgentType::Coder,
                fail_count: AtomicU32::new(0),
                max_fails,
            }
        }
    }

    impl BaseAgent for TestAgent {
        fn execute(&self, context: &ExecutionContext) -> Result<String> {
            let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
            if count < self.max_fails {
                Err(RuvLLMError::InvalidOperation(format!(
                    "Simulated failure {}",
                    count + 1
                )))
            } else {
                Ok(format!(
                    "Success after {} failures for: {}",
                    count, context.task
                ))
            }
        }

        fn agent_type(&self) -> AgentType {
            self.agent_type
        }
    }

    #[test]
    fn test_reflective_agent_retry_success() {
        // Two failures then success: expect recovery on the third attempt.
        let failing_twice = TestAgent::new(2);
        let mut agent = ReflectiveAgent::new(
            failing_twice,
            ReflectionStrategy::Retry(RetryConfig::default()),
        );
        let context = ExecutionContext::new("test task", AgentType::Coder, "test input");
        let result = agent.execute_with_reflection(&context).unwrap();
        assert!(result.recovered_via_reflection);
        assert_eq!(result.attempts, 3);
        assert!(result.output.contains("Success"));
    }

    #[test]
    fn test_reflective_agent_max_attempts() {
        // More failures than the budget allows: expect a Failure verdict.
        let always_failing = TestAgent::new(10);
        let config = ReflectionConfig {
            max_reflection_attempts: 3,
            ..Default::default()
        };
        let mut agent = ReflectiveAgent::with_config(
            always_failing,
            ReflectionStrategy::Retry(RetryConfig::default()),
            config,
        );
        let context = ExecutionContext::new("test task", AgentType::Coder, "test input");
        let result = agent.execute_with_reflection(&context).unwrap();
        assert!(!result.recovered_via_reflection);
        assert!(matches!(result.verdict, Verdict::Failure { .. }));
    }

    #[test]
    fn test_reflection_generation() {
        let base = TestAgent::new(0);
        let agent = ReflectiveAgent::new(base, ReflectionStrategy::Retry(RetryConfig::default()));
        let context = ExecutionContext::new("test", AgentType::Coder, "input");
        let reflection = agent
            .generate_reflection("output", &context, Some("test error"))
            .unwrap();
        assert_eq!(reflection.strategy, "retry");
        assert!(!reflection.insights.is_empty());
    }

    #[test]
    fn test_execution_context_builder() {
        let context = ExecutionContext::new("task", AgentType::Researcher, "input")
            .with_session("session-123")
            .with_metadata("key", "value");
        assert_eq!(context.session_id, Some("session-123".to_string()));
        assert_eq!(context.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_execution_result_variants() {
        let success = ExecutionResult::success("output", 1, 100);
        assert!(matches!(success.verdict, Verdict::Success { .. }));

        let recovered = ExecutionResult::recovered(
            "output",
            "error",
            "retry",
            2,
            200,
            Reflection::new("retry", "context"),
        );
        assert!(matches!(
            recovered.verdict,
            Verdict::RecoveredViaReflection { .. }
        ));
        assert!(recovered.recovered_via_reflection);

        let failure = ExecutionResult::failure("error", 3, 300);
        assert!(matches!(failure.verdict, Verdict::Failure { .. }));
    }

    #[test]
    fn test_stats_tracking() {
        // One failure then success: counts as a single recovered execution.
        let base = TestAgent::new(1);
        let mut agent =
            ReflectiveAgent::new(base, ReflectionStrategy::Retry(RetryConfig::default()));
        let context = ExecutionContext::new("test", AgentType::Coder, "input");
        let _ = agent.execute_with_reflection(&context);
        let stats = agent.stats();
        assert_eq!(stats.total_executions, 1);
        assert_eq!(stats.recovered_count, 1);
    }
}