use crate::checkpoint_system::{ErrorCategory, ErrorRecovery, RecoveryLayer};
use crate::planner::{ExecutionPlan, RiskLevel};
use crate::types::Layer2Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Lifecycle state of an execution.
///
/// The same enum also serves as the per-step status stored in [`StepResult`] and as the
/// status argument passed to the `ExecutionMonitor` progress callback.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStatus {
/// Initial state and the `Default`; a freshly created `ExecutionMonitor` starts here.
#[default]
Pending,
/// Set by `ExecutionMonitor::start`.
Running,
/// Not set by any code visible in this module.
Paused,
/// Per-step status: the step's latest result is a completion (this includes steps
/// recorded as skipped/alternative/decomposed/adjusted by `apply_correction`).
StepCompleted,
/// Per-step status: the step's latest result is a failure.
StepFailed,
/// Set by `ExecutionMonitor::complete`.
Completed,
/// Not set by any code visible in this module.
Failed,
/// Set when a `UserIntervention` correction is applied.
AwaitingUserInput,
/// Not set by any code visible in this module.
Cancelled,
}
/// Latest outcome recorded for a single subtask.
#[derive(Debug, Clone)]
pub struct StepResult {
/// Id of the subtask this result belongs to (also the key it is stored under).
pub subtask_id: String,
/// `StepCompleted` or `StepFailed` as recorded by `ExecutionMonitor`.
pub status: ExecutionStatus,
/// Output text for completed steps; `None` for failures.
pub output: Option<String>,
/// Error message for failed steps; `None` for completions.
pub error: Option<String>,
/// Step duration. `ExecutionMonitor` does not measure step timing and records zero.
pub duration: Duration,
/// Retry counter associated with this result.
pub retry_count: u32,
/// Recovery layer used, if any; `ExecutionMonitor` currently always records `None`.
pub recovery_layer: Option<RecoveryLayer>,
}
/// Progress listener: receives an event key (a subtask id or an execution-level event
/// name such as `"execution_started"`) and the status associated with that event.
type ProgressCallback = Box<dyn Fn(&str, ExecutionStatus) + Send + Sync>;

/// Tracks the execution of an [`ExecutionPlan`]: overall status, the latest result of
/// each step, the corrections decided for failed steps, and an optional progress
/// callback.
///
/// Each piece of state lives behind its own `Arc<RwLock<_>>` (tokio's async `RwLock`),
/// which is why every mutating method only needs `&self`.
pub struct ExecutionMonitor {
    /// The plan being executed.
    plan: Arc<RwLock<ExecutionPlan>>,
    /// Overall execution status.
    status: Arc<RwLock<ExecutionStatus>>,
    /// Latest result reported for each subtask, keyed by subtask id.
    step_results: Arc<RwLock<HashMap<String, StepResult>>>,
    // Constructed in `new` but never read by the code in this module; `dead_code` is
    // allowed so the otherwise-unused field does not raise a warning.
    #[allow(dead_code)]
    error_recovery: Arc<ErrorRecovery>,
    /// Instant recorded by `start`; `None` until execution has started.
    start_time: Arc<RwLock<Option<Instant>>>,
    /// Optional listener notified of progress events.
    progress_callback: Arc<RwLock<Option<ProgressCallback>>>,
    /// Every correction decided for a failed step, in decision order.
    correction_history: Arc<RwLock<Vec<CorrectionRecord>>>,
}
/// Audit entry describing a correction decided for a failed subtask.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrectionRecord {
/// Identifier generated when the record is created.
pub id: String,
/// Id of the subtask whose failure triggered the correction.
pub failed_subtask: String,
/// Name of the `ErrorCategory` variant the error was classified as (e.g. `"Transient"`).
pub error_category: String,
/// The error message as reported.
pub original_error: String,
/// The strategy that was chosen.
pub strategy: CorrectionStrategy,
/// Whether the correction worked. `ExecutionMonitor` records `false`;
/// `SelfCorrector::record_result` only learns from records where this is `true`.
pub success: bool,
/// UTC time at which the record was created.
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// How to react to a failed subtask.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CorrectionStrategy {
/// Re-run the failed step; `max_attempts` is the retry budget carried with the decision.
Retry { max_attempts: u32 },
/// Give up on the step and move on.
Skip,
/// Replace the step with another; `replacement_subtask` names the replacement.
Alternative { replacement_subtask: String },
/// Break the step into the named `new_subtasks`.
Decompose { new_subtasks: Vec<String> },
/// Halt until the user acts; `action` describes what they should do.
UserIntervention { action: String },
/// Change the step's parameters; `new_params` is arbitrary JSON.
/// Not produced by `ExecutionMonitor::decide_correction`.
AdjustParameters { new_params: serde_json::Value },
}
impl CorrectionStrategy {
pub fn debug_name(&self) -> &'static str {
match self {
CorrectionStrategy::Retry { .. } => "Retry",
CorrectionStrategy::Skip => "Skip",
CorrectionStrategy::Alternative { .. } => "Alternative",
CorrectionStrategy::Decompose { .. } => "Decompose",
CorrectionStrategy::UserIntervention { .. } => "UserIntervention",
CorrectionStrategy::AdjustParameters { .. } => "AdjustParameters",
}
}
}
impl ExecutionMonitor {
pub fn new(plan: ExecutionPlan) -> Self {
Self {
plan: Arc::new(RwLock::new(plan)),
status: Arc::new(RwLock::new(ExecutionStatus::Pending)),
step_results: Arc::new(RwLock::new(HashMap::new())),
error_recovery: Arc::new(ErrorRecovery::new()),
start_time: Arc::new(RwLock::new(None)),
progress_callback: Arc::new(RwLock::new(None)),
correction_history: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn set_progress_callback<F>(&self, callback: F)
where
F: Fn(&str, ExecutionStatus) + Send + Sync + 'static,
{
*self.progress_callback.write().await = Some(Box::new(callback));
}
pub async fn get_status(&self) -> ExecutionStatus {
*self.status.read().await
}
pub async fn get_progress(&self) -> u32 {
let plan = self.plan.read().await;
let results = self.step_results.read().await;
if plan.subtasks.is_empty() {
return 0;
}
let completed = results
.values()
.filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
.count();
(completed as u32 * 100) / plan.subtasks.len() as u32
}
pub async fn start(&self) -> Layer2Result<()> {
let mut status = self.status.write().await;
*status = ExecutionStatus::Running;
drop(status);
*self.start_time.write().await = Some(Instant::now());
self.notify_progress("execution_started", ExecutionStatus::Running)
.await;
Ok(())
}
pub async fn report_step_completed(
&self,
subtask_id: &str,
output: String,
) -> Layer2Result<()> {
let result = StepResult {
subtask_id: subtask_id.to_string(),
status: ExecutionStatus::StepCompleted,
output: Some(output),
error: None,
duration: Duration::from_secs(0),
retry_count: 0,
recovery_layer: None,
};
self.step_results
.write()
.await
.insert(subtask_id.to_string(), result);
self.notify_progress(subtask_id, ExecutionStatus::StepCompleted)
.await;
Ok(())
}
pub async fn report_step_failed(
&self,
subtask_id: &str,
error: String,
) -> Layer2Result<CorrectionDecision> {
let category = ErrorCategory::from_error_message(&error);
let result = StepResult {
subtask_id: subtask_id.to_string(),
status: ExecutionStatus::StepFailed,
output: None,
error: Some(error.clone()),
duration: Duration::from_secs(0),
retry_count: 0,
recovery_layer: None,
};
self.step_results
.write()
.await
.insert(subtask_id.to_string(), result);
let decision = self.decide_correction(subtask_id, &category, &error).await;
self.record_correction(subtask_id, &category, &error, &decision)
.await;
self.notify_progress(subtask_id, ExecutionStatus::StepFailed)
.await;
Ok(decision)
}
async fn decide_correction(
&self,
subtask_id: &str,
category: &ErrorCategory,
error: &str,
) -> CorrectionDecision {
let plan = self.plan.read().await;
let subtask = plan.subtasks.iter().find(|s| s.id == subtask_id);
match category {
ErrorCategory::Transient => {
CorrectionDecision {
strategy: CorrectionStrategy::Retry { max_attempts: 3 },
should_continue: true,
user_message: Some("Temporary error, will retry automatically".to_string()),
}
}
ErrorCategory::Resource => {
CorrectionDecision {
strategy: CorrectionStrategy::Retry { max_attempts: 2 },
should_continue: true,
user_message: Some("Resource issue, waiting before retry".to_string()),
}
}
ErrorCategory::Logic => {
if let Some(subtask) = subtask {
if let Some(fallback) = &subtask.fallback {
CorrectionDecision {
strategy: CorrectionStrategy::Alternative {
replacement_subtask: fallback.name.clone(),
},
should_continue: true,
user_message: Some("Using fallback strategy".to_string()),
}
} else {
CorrectionDecision {
strategy: CorrectionStrategy::Decompose {
new_subtasks: vec!["simplified_step".to_string()],
},
should_continue: true,
user_message: Some("Breaking down the task".to_string()),
}
}
} else {
CorrectionDecision {
strategy: CorrectionStrategy::Skip,
should_continue: true,
user_message: Some("Skipping failed step".to_string()),
}
}
}
ErrorCategory::Configuration => {
CorrectionDecision {
strategy: CorrectionStrategy::UserIntervention {
action: "Please check your configuration".to_string(),
},
should_continue: false,
user_message: Some(format!("Configuration error: {}", error)),
}
}
ErrorCategory::UserInterrupt => {
CorrectionDecision {
strategy: CorrectionStrategy::Skip,
should_continue: false,
user_message: Some("Execution cancelled by user".to_string()),
}
}
ErrorCategory::System => {
if plan.risk_level == RiskLevel::Critical {
CorrectionDecision {
strategy: CorrectionStrategy::UserIntervention {
action: "Critical error requires manual intervention".to_string(),
},
should_continue: false,
user_message: Some(format!("Critical system error: {}", error)),
}
} else {
CorrectionDecision {
strategy: CorrectionStrategy::Retry { max_attempts: 1 },
should_continue: true,
user_message: Some("System error, attempting recovery".to_string()),
}
}
}
}
}
async fn record_correction(
&self,
subtask_id: &str,
category: &ErrorCategory,
error: &str,
decision: &CorrectionDecision,
) {
let category_str = match category {
ErrorCategory::Transient => "Transient",
ErrorCategory::Resource => "Resource",
ErrorCategory::Configuration => "Configuration",
ErrorCategory::Logic => "Logic",
ErrorCategory::System => "System",
ErrorCategory::UserInterrupt => "UserInterrupt",
};
let record = CorrectionRecord {
id: format!("correction_{}", chrono::Utc::now().timestamp()),
failed_subtask: subtask_id.to_string(),
error_category: category_str.to_string(),
original_error: error.to_string(),
strategy: decision.strategy.clone(),
success: false, timestamp: chrono::Utc::now(),
};
self.correction_history.write().await.push(record);
}
pub async fn apply_correction(
&self,
subtask_id: &str,
decision: &CorrectionDecision,
) -> Layer2Result<bool> {
match &decision.strategy {
CorrectionStrategy::Retry { max_attempts: _ } => {
Ok(true)
}
CorrectionStrategy::Skip => {
self.report_step_completed(
subtask_id,
"[SKIPPED] Step skipped due to unrecoverable error".to_string(),
)
.await?;
Ok(true)
}
CorrectionStrategy::Alternative {
replacement_subtask,
} => {
self.report_step_completed(
subtask_id,
format!("[ALTERNATIVE] Used: {}", replacement_subtask),
)
.await?;
Ok(true)
}
CorrectionStrategy::UserIntervention { action: _ } => {
let mut status = self.status.write().await;
*status = ExecutionStatus::AwaitingUserInput;
Ok(false)
}
CorrectionStrategy::Decompose { new_subtasks } => {
self.report_step_completed(
subtask_id,
format!("[DECOMPOSED] Into: {}", new_subtasks.join(", ")),
)
.await?;
Ok(true)
}
CorrectionStrategy::AdjustParameters { new_params: _ } => {
self.report_step_completed(
subtask_id,
"[ADJUSTED] Parameters modified".to_string(),
)
.await?;
Ok(true)
}
}
}
pub async fn complete(&self) -> Layer2Result<ExecutionSummary> {
let mut status = self.status.write().await;
*status = ExecutionStatus::Completed;
drop(status);
self.notify_progress("execution_completed", ExecutionStatus::Completed)
.await;
let plan = self.plan.read().await;
let results = self.step_results.read().await;
let corrections = self.correction_history.read().await;
let start_time = self.start_time.read().await;
let completed = results
.values()
.filter(|r| matches!(r.status, ExecutionStatus::StepCompleted))
.count();
let failed = results
.values()
.filter(|r| matches!(r.status, ExecutionStatus::StepFailed))
.count();
let skipped = results
.values()
.filter(|r| {
r.output
.as_ref()
.map(|o| o.starts_with("[SKIPPED]"))
.unwrap_or(false)
})
.count();
Ok(ExecutionSummary {
plan_id: plan.id.clone(),
total_steps: plan.subtasks.len(),
completed_steps: completed,
failed_steps: failed,
skipped_steps: skipped,
correction_count: corrections.len(),
duration: start_time.map(|t| t.elapsed()).unwrap_or(Duration::ZERO),
status: ExecutionStatus::Completed,
})
}
pub async fn get_correction_history(&self) -> Vec<CorrectionRecord> {
self.correction_history.read().await.clone()
}
async fn notify_progress(&self, subtask_id: &str, status: ExecutionStatus) {
if let Some(callback) = self.progress_callback.read().await.as_ref() {
callback(subtask_id, status);
}
}
}
/// Outcome of deciding how to handle a failed step, as returned by
/// `ExecutionMonitor::report_step_failed`.
#[derive(Debug, Clone)]
pub struct CorrectionDecision {
/// Strategy to apply (see `ExecutionMonitor::apply_correction`).
pub strategy: CorrectionStrategy,
/// Whether execution should continue after this failure. `decide_correction` sets it to
/// `false` for configuration errors, user interrupts, and system errors on critical plans.
pub should_continue: bool,
/// Human-readable explanation of the decision, if any.
pub user_message: Option<String>,
}
/// Summary of a finished execution, produced by `ExecutionMonitor::complete`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionSummary {
/// Id of the executed plan.
pub plan_id: String,
/// Number of subtasks in the plan.
pub total_steps: usize,
/// Step results with status `StepCompleted`; this includes skipped steps.
pub completed_steps: usize,
/// Step results whose latest status is `StepFailed`.
pub failed_steps: usize,
/// Step results whose output starts with `[SKIPPED]`.
pub skipped_steps: usize,
/// Number of corrections recorded during execution.
pub correction_count: usize,
/// Time elapsed since `start`, or zero if execution was never started.
pub duration: Duration,
/// Final status; `complete` always reports `Completed`.
pub status: ExecutionStatus,
}
/// Learns which [`CorrectionStrategy`] worked for an error signature and recommends it
/// for later errors containing that signature.
pub struct SelfCorrector {
/// Every record passed to `record_result`, in arrival order.
history: RwLock<Vec<CorrectionRecord>>,
/// Learned mapping from error signature to the strategy that fixed it.
patterns: RwLock<HashMap<String, CorrectionStrategy>>,
}
impl Default for SelfCorrector {
fn default() -> Self {
Self::new()
}
}
impl SelfCorrector {
    /// Maximum number of characters of an error message kept as its signature.
    const SIGNATURE_MAX_CHARS: usize = 50;

    /// Creates a corrector with an empty history and no learned patterns.
    pub fn new() -> Self {
        Self {
            history: RwLock::new(Vec::new()),
            patterns: RwLock::new(HashMap::new()),
        }
    }

    /// Associates `strategy` with `error_signature`, replacing any strategy previously
    /// learned for the same signature.
    ///
    /// Signatures are stored lower-cased so that lookups in
    /// [`get_recommended_strategy`](Self::get_recommended_strategy) are case-insensitive.
    pub async fn learn_pattern(&self, error_signature: &str, strategy: CorrectionStrategy) {
        self.patterns
            .write()
            .await
            .insert(error_signature.to_lowercase(), strategy);
    }

    /// Returns the strategy learned for the most specific signature contained in `error`.
    ///
    /// Matching is a case-insensitive substring search. Blank signatures are ignored,
    /// because they would otherwise match every error. When several signatures match,
    /// the longest wins, and ties are broken lexicographically, so the result does not
    /// depend on `HashMap` iteration order. Returns `None` when nothing matches.
    pub async fn get_recommended_strategy(&self, error: &str) -> Option<CorrectionStrategy> {
        let haystack = error.to_lowercase();
        let patterns = self.patterns.read().await;
        patterns
            .iter()
            .filter(|(signature, _)| {
                !signature.trim().is_empty() && haystack.contains(signature.as_str())
            })
            .max_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| b.cmp(a)))
            .map(|(_, strategy)| strategy.clone())
    }

    /// Appends `record` to the history.
    ///
    /// If the correction succeeded, its strategy is also learned under the signature of
    /// the original error. A blank signature is not learned.
    pub async fn record_result(&self, record: CorrectionRecord) {
        if record.success {
            let signature = Self::extract_signature(&record.original_error);
            if !signature.is_empty() {
                self.learn_pattern(&signature, record.strategy.clone())
                    .await;
            }
        }
        self.history.write().await.push(record);
    }

    /// Builds a signature from the first [`Self::SIGNATURE_MAX_CHARS`] characters of the
    /// trimmed, lower-cased error message.
    ///
    /// Truncation counts `char`s, not bytes. Slicing a `String` at a fixed byte index
    /// panics when that index falls inside a multi-byte UTF-8 character.
    fn extract_signature(error: &str) -> String {
        error
            .trim()
            .to_lowercase()
            .chars()
            .take(Self::SIGNATURE_MAX_CHARS)
            .collect()
    }

    /// Fraction (0.0–1.0) of recorded results marked successful; 0.0 when nothing has
    /// been recorded.
    pub async fn get_success_rate(&self) -> f32 {
        let history = self.history.read().await;
        if history.is_empty() {
            return 0.0;
        }
        let success_count = history.iter().filter(|r| r.success).count();
        success_count as f32 / history.len() as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::{ExecutionPlan, SubTask};

    #[tokio::test]
    async fn test_execution_monitor_creation() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);
        let status = monitor.get_status().await;
        assert_eq!(status, ExecutionStatus::Pending);
    }

    #[tokio::test]
    async fn test_progress_calculation() {
        let mut plan = ExecutionPlan::new("Test task");
        plan.add_subtask(SubTask::new("s1", "Step 1", "First"));
        plan.add_subtask(SubTask::new("s2", "Step 2", "Second"));
        plan.compute_execution_order().unwrap();
        let monitor = ExecutionMonitor::new(plan);
        monitor.start().await.unwrap();
        assert_eq!(monitor.get_status().await, ExecutionStatus::Running);
        assert_eq!(monitor.get_progress().await, 0);
        monitor
            .report_step_completed("s1", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 50);
        monitor
            .report_step_completed("s2", "Done".to_string())
            .await
            .unwrap();
        assert_eq!(monitor.get_progress().await, 100);
    }

    #[tokio::test]
    async fn test_correction_decision() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);
        let decision = monitor
            .decide_correction("test_subtask", &ErrorCategory::Transient, "Network timeout")
            .await;
        assert!(decision.should_continue);
        // `matches!` only evaluates to a bool; without `assert!` the check is a no-op.
        assert!(matches!(
            decision.strategy,
            CorrectionStrategy::Retry { max_attempts: 3 }
        ));
    }

    #[tokio::test]
    async fn test_self_corrector_learning() {
        let corrector = SelfCorrector::new();
        corrector
            .learn_pattern("timeout", CorrectionStrategy::Retry { max_attempts: 3 })
            .await;
        let strategy = corrector
            .get_recommended_strategy("Connection timeout occurred")
            .await;
        assert!(matches!(
            strategy,
            Some(CorrectionStrategy::Retry { max_attempts: 3 })
        ));
        assert!(corrector
            .get_recommended_strategy("disk full")
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_correction_history() {
        let plan = ExecutionPlan::new("Test task");
        let monitor = ExecutionMonitor::new(plan);
        monitor
            .report_step_failed("s1", "Error occurred".to_string())
            .await
            .unwrap();
        let history = monitor.get_correction_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].failed_subtask, "s1");
        assert_eq!(history[0].original_error, "Error occurred");
        assert!(!history[0].success);
    }
}