use crate::types::{Layer2Result, SessionId};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;
/// Coarse classification of a failure, derived from the text of its message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Message mentions "timeout", "network" or "rate limit".
    Transient,
    /// Message mentions "memory", "disk" or "resource".
    Resource,
    /// Message mentions "api key", "config" or "auth".
    Configuration,
    /// Message mentions "invalid", "parameter" or "argument".
    Logic,
    /// Catch-all used when no keyword matches.
    System,
    /// Message mentions "interrupt", "cancel" or "abort".
    UserInterrupt,
}

impl ErrorCategory {
    /// Classifies an error by case-insensitive keyword search in `msg`.
    ///
    /// Rules are evaluated in a fixed order (transient, resource,
    /// configuration, logic, user interrupt) and the first rule with a
    /// matching keyword wins; a message that matches nothing is
    /// [`ErrorCategory::System`]. Matching is plain substring search, so a
    /// keyword embedded in a longer word (for example "auth" in "author")
    /// also matches.
    pub fn from_error_message(msg: &str) -> Self {
        let lowered = msg.to_lowercase();

        // Order matters: earlier rules take precedence over later ones.
        let rules: [(&[&str], ErrorCategory); 5] = [
            (
                &["timeout", "network", "rate limit"],
                ErrorCategory::Transient,
            ),
            (&["memory", "disk", "resource"], ErrorCategory::Resource),
            (&["api key", "config", "auth"], ErrorCategory::Configuration),
            (&["invalid", "parameter", "argument"], ErrorCategory::Logic),
            (
                &["interrupt", "cancel", "abort"],
                ErrorCategory::UserInterrupt,
            ),
        ];

        for (needles, category) in rules.iter() {
            if needles.iter().any(|needle| lowered.contains(*needle)) {
                return category.clone();
            }
        }
        ErrorCategory::System
    }

    /// Returns `true` for the categories that are retried automatically:
    /// [`ErrorCategory::Transient`] and [`ErrorCategory::Resource`].
    pub fn is_retryable(&self) -> bool {
        match self {
            ErrorCategory::Transient | ErrorCategory::Resource => true,
            ErrorCategory::Configuration
            | ErrorCategory::Logic
            | ErrorCategory::System
            | ErrorCategory::UserInterrupt => false,
        }
    }
}
/// Exponential-backoff parameters used when an operation is retried.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries after the first attempt.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound applied to the exponential delay (before jitter), in
    /// milliseconds.
    pub max_delay_ms: u64,
    /// Factor by which the delay grows with every attempt.
    pub multiplier: f64,
    /// Jitter amplitude as a fraction of the capped delay (`0.1` = ±10%).
    pub jitter: f64,
}

impl Default for RetryPolicy {
    /// Three retries, starting at 1 s, doubling up to 30 s, with ±10% jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            multiplier: 2.0,
            jitter: 0.1,
        }
    }
}

impl RetryPolicy {
    /// Computes how long to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `initial_delay_ms * multiplier^attempt`, capped at
    /// `max_delay_ms`, then shifted by a jitter offset of at most
    /// `jitter * capped_delay` in either direction. The offset is derived
    /// deterministically from `attempt` (no random source is used), and
    /// because it is applied after the cap the result may exceed
    /// `max_delay_ms` by up to the jitter amplitude. The result is never
    /// negative.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        // A plain `attempt as i32` wraps for large values (`usize::MAX`
        // becomes -1), which would turn the exponent negative and shrink the
        // delay. Saturate so very large attempt numbers simply hit the cap.
        let exponent = attempt.min(i32::MAX as usize) as i32;

        let base_delay = self.initial_delay_ms as f64 * self.multiplier.powi(exponent);
        let capped_delay = base_delay.min(self.max_delay_ms as f64);

        // Map the fractional part of `attempt * 0.3` from [0, 1) to [-1, 1)
        // and scale it by the jitter amplitude.
        let jitter_range = capped_delay * self.jitter;
        let jitter_offset = ((attempt as f64 * 0.3).fract() - 0.5) * 2.0 * jitter_range;

        // Float-to-int `as` saturates, so out-of-range values cannot wrap.
        let final_delay = (capped_delay + jitter_offset).max(0.0) as u64;
        Duration::from_millis(final_delay)
    }
}
/// Strategy consulted by `ErrorRecovery` after automatic retries have failed.
///
/// In this file the fallback layer only reports which strategy applies; it
/// does not itself contact a backup service, read a cache, or re-run the
/// operation.
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
/// No fallback configured; the fallback layer reports failure.
None,
/// Reported as a switch to the backup identified by `endpoint`.
BackupService { endpoint: String },
/// Reported as using cached data at most `max_age_seconds` old.
UseCache { max_age_seconds: u64 },
/// Reported as running in the simplified mode named by `mode`.
Simplified { mode: String },
/// The failed operation is reported as skipped (counted as a success).
Skip,
}
/// Outcome reported by one of the recovery layers.
#[derive(Debug, Clone)]
pub struct RecoveryResult {
/// Whether the layer considers the operation recovered.
pub success: bool,
/// Which recovery layer produced this result.
pub layer_used: RecoveryLayer,
/// Attempt count as reported by the layer that produced this result.
pub attempts: usize,
/// Error text describing the failure, when one is available.
pub error_message: Option<String>,
/// Human-readable note describing what was done, or a hint for the user.
pub user_action: Option<String>,
}
/// Identifies the recovery layer that produced a `RecoveryResult`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryLayer {
/// The automatic retry layer.
Automatic,
/// The configured `FallbackStrategy` layer.
Fallback,
/// The layer that asks the registered user callback what to do.
UserIntervention,
}
/// Action a user callback can return when asked how to handle a failure.
///
/// Within this file only `Retry`, `Skip` and `Abort` are offered to the
/// callback; `ModifyConfig` and `SwitchBackup` are reported as an unknown
/// action if a callback returns them.
#[derive(Debug, Clone)]
pub enum RecoveryAction {
/// Try the operation again.
Retry,
/// Give up on the operation and treat it as handled.
Skip,
/// Abort the operation; reported as a failure.
Abort,
/// Change the configuration entry `key` to `value`.
ModifyConfig { key: String, value: String },
/// Switch to the backup named by `service`.
SwitchBackup { service: String },
}
/// Shared, thread-safe callback used to ask the user how to proceed.
///
/// It receives a prompt message and the list of offered actions, and returns
/// the action that was chosen.
pub type UserConfirmationCallback =
Arc<dyn Fn(&str, Vec<RecoveryAction>) -> RecoveryAction + Send + Sync>;
/// Layered error-recovery coordinator: automatic retry, then fallback, then
/// user intervention.
pub struct ErrorRecovery {
// Backoff parameters for the automatic retry layer.
retry_policy: RetryPolicy,
// Strategy reported by the fallback layer.
fallback_strategy: FallbackStrategy,
// Optional callback consulted by the user-intervention layer.
user_callback: RwLock<Option<UserConfirmationCallback>>,
// Running counters describing how executions were resolved.
stats: RwLock<RecoveryStats>,
}
/// Counters maintained by `ErrorRecovery::execute_with_recovery`.
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
/// Incremented once for every `execute_with_recovery` call.
pub total_errors: usize,
/// Calls resolved by the automatic retry layer.
pub auto_recovered: usize,
/// Calls resolved by the fallback layer.
pub fallback_recovered: usize,
/// Calls resolved by the user-intervention layer.
pub user_interventions: usize,
/// Calls that no layer could resolve.
pub unrecovered: usize,
}
impl Default for ErrorRecovery {
fn default() -> Self {
Self::new()
}
}
impl ErrorRecovery {
    /// Creates a coordinator with [`RetryPolicy::default`], no fallback
    /// strategy, no user callback and zeroed statistics.
    pub fn new() -> Self {
        Self {
            retry_policy: RetryPolicy::default(),
            fallback_strategy: FallbackStrategy::None,
            user_callback: RwLock::new(None),
            stats: RwLock::new(RecoveryStats::default()),
        }
    }

    /// Builder-style setter replacing the retry policy.
    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = policy;
        self
    }

    /// Builder-style setter replacing the fallback strategy.
    pub fn with_fallback(mut self, strategy: FallbackStrategy) -> Self {
        self.fallback_strategy = strategy;
        self
    }

    /// Registers (or replaces) the callback consulted by the
    /// user-intervention layer.
    pub async fn set_user_callback(&self, callback: UserConfirmationCallback) {
        *self.user_callback.write().await = Some(callback);
    }

    /// Runs `operation` through the three recovery layers in order:
    /// automatic retry, fallback, then user intervention.
    ///
    /// The first layer that reports success determines the returned
    /// [`RecoveryResult`]; if none does, the user-intervention result is
    /// returned. `total_errors` is incremented once per call, and exactly one
    /// of the remaining counters is incremented according to the layer that
    /// resolved the call (or `unrecovered` if none did).
    pub async fn execute_with_recovery<F, Fut, T>(&self, operation: F) -> RecoveryResult
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
        T: Send,
    {
        // Each statement takes the write lock only for the increment, so the
        // lock is never held while the operation or the callback runs.
        self.stats.write().await.total_errors += 1;

        let retry_result = self.try_with_retry(&operation).await;
        if retry_result.success {
            self.stats.write().await.auto_recovered += 1;
            return retry_result;
        }

        let fallback_result = self.try_with_fallback(&operation).await;
        if fallback_result.success {
            self.stats.write().await.fallback_recovered += 1;
            return fallback_result;
        }

        let user_result = self.try_with_user_intervention(&operation).await;
        if user_result.success {
            self.stats.write().await.user_interventions += 1;
        } else {
            self.stats.write().await.unrecovered += 1;
        }
        user_result
    }

    /// Automatic layer: invokes `operation` up to `max_retries + 1` times.
    ///
    /// Stops at the first success, or at the first error whose category is
    /// not retryable (in which case `user_action` carries a hint). Between
    /// retryable failures it sleeps for the policy's backoff delay; no sleep
    /// follows the final attempt. `attempts` in the result is the number of
    /// times `operation` was actually invoked on every path.
    async fn try_with_retry<F, Fut, T>(&self, operation: &F) -> RecoveryResult
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
        T: Send,
    {
        let max_retries = self.retry_policy.max_retries;
        let mut last_error: Option<String> = None;

        for attempt in 0..=max_retries {
            // `attempt` is a zero-based index; report a one-based count so it
            // agrees with the exhausted case below.
            let invocations = attempt.saturating_add(1);

            match operation().await {
                Ok(_) => {
                    return RecoveryResult {
                        success: true,
                        layer_used: RecoveryLayer::Automatic,
                        attempts: invocations,
                        error_message: None,
                        user_action: None,
                    };
                }
                Err(e) => {
                    let error_msg = e.to_string();
                    let category = ErrorCategory::from_error_message(&error_msg);
                    if !category.is_retryable() {
                        return RecoveryResult {
                            success: false,
                            layer_used: RecoveryLayer::Automatic,
                            attempts: invocations,
                            error_message: Some(error_msg),
                            user_action: Some(self.get_user_hint(&category)),
                        };
                    }
                    last_error = Some(error_msg);
                    if attempt < max_retries {
                        sleep(self.retry_policy.delay_for_attempt(attempt)).await;
                    }
                }
            }
        }

        RecoveryResult {
            success: false,
            layer_used: RecoveryLayer::Automatic,
            attempts: max_retries.saturating_add(1),
            error_message: last_error,
            user_action: None,
        }
    }

    /// Fallback layer: reports the outcome implied by the configured
    /// [`FallbackStrategy`].
    ///
    /// The operation is not invoked here. `FallbackStrategy::None` yields a
    /// failure; every other strategy yields a success whose `user_action`
    /// describes the strategy that was applied.
    async fn try_with_fallback<F, Fut, T>(&self, _operation: &F) -> RecoveryResult
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
        T: Send,
    {
        match &self.fallback_strategy {
            FallbackStrategy::None => RecoveryResult {
                success: false,
                layer_used: RecoveryLayer::Fallback,
                attempts: 0,
                error_message: Some("No fallback strategy configured".to_string()),
                user_action: None,
            },
            FallbackStrategy::Skip => RecoveryResult {
                success: true,
                layer_used: RecoveryLayer::Fallback,
                attempts: 1,
                error_message: None,
                user_action: Some("Operation skipped due to fallback policy".to_string()),
            },
            FallbackStrategy::BackupService { endpoint } => RecoveryResult {
                success: true,
                layer_used: RecoveryLayer::Fallback,
                attempts: 1,
                error_message: None,
                user_action: Some(format!("Switched to backup: {}", endpoint)),
            },
            FallbackStrategy::UseCache { max_age_seconds } => RecoveryResult {
                success: true,
                layer_used: RecoveryLayer::Fallback,
                attempts: 1,
                error_message: None,
                user_action: Some(format!("Using cached data (max {}s old)", max_age_seconds)),
            },
            FallbackStrategy::Simplified { mode } => RecoveryResult {
                success: true,
                layer_used: RecoveryLayer::Fallback,
                attempts: 1,
                error_message: None,
                user_action: Some(format!("Using simplified mode: {}", mode)),
            },
        }
    }

    /// User-intervention layer: asks the registered callback how to proceed.
    ///
    /// * No callback registered: failure with a hint to configure one.
    /// * `Retry`: `operation` is invoked once more; the result reflects
    ///   whether that invocation succeeded.
    /// * `Skip`: reported as a success without running the operation.
    /// * `Abort`: reported as a failure.
    /// * `ModifyConfig` / `SwitchBackup`: not offered to the callback and
    ///   reported as an unknown action.
    async fn try_with_user_intervention<F, Fut, T>(&self, operation: &F) -> RecoveryResult
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
        T: Send,
    {
        // Clone the `Arc` out so the read lock is released before any user
        // code (the callback or the operation) runs.
        let callback = self.user_callback.read().await.as_ref().cloned();
        let cb = match callback {
            Some(cb) => cb,
            None => {
                return RecoveryResult {
                    success: false,
                    layer_used: RecoveryLayer::UserIntervention,
                    attempts: 0,
                    error_message: Some("No user callback set".to_string()),
                    user_action: Some(
                        "Please configure user callback for intervention".to_string(),
                    ),
                };
            }
        };

        let actions = vec![
            RecoveryAction::Retry,
            RecoveryAction::Skip,
            RecoveryAction::Abort,
        ];

        match cb("Operation failed. Choose action:", actions) {
            // Honour the user's choice by actually re-running the operation
            // instead of reporting a failure without trying.
            RecoveryAction::Retry => match operation().await {
                Ok(_) => RecoveryResult {
                    success: true,
                    layer_used: RecoveryLayer::UserIntervention,
                    attempts: 1,
                    error_message: None,
                    user_action: Some("User requested retry".to_string()),
                },
                Err(e) => RecoveryResult {
                    success: false,
                    layer_used: RecoveryLayer::UserIntervention,
                    attempts: 1,
                    error_message: Some(e.to_string()),
                    user_action: Some("User requested retry".to_string()),
                },
            },
            RecoveryAction::Skip => RecoveryResult {
                success: true,
                layer_used: RecoveryLayer::UserIntervention,
                attempts: 1,
                error_message: None,
                user_action: Some("User chose to skip".to_string()),
            },
            RecoveryAction::Abort => RecoveryResult {
                success: false,
                layer_used: RecoveryLayer::UserIntervention,
                attempts: 1,
                error_message: Some("User aborted operation".to_string()),
                user_action: Some("User aborted".to_string()),
            },
            // Listed explicitly so that adding a new `RecoveryAction` variant
            // forces this match to be revisited.
            RecoveryAction::ModifyConfig { .. } | RecoveryAction::SwitchBackup { .. } => {
                RecoveryResult {
                    success: false,
                    layer_used: RecoveryLayer::UserIntervention,
                    attempts: 1,
                    error_message: Some("Unknown action".to_string()),
                    user_action: None,
                }
            }
        }
    }

    /// Returns the user-facing hint associated with an error category.
    fn get_user_hint(&self, category: &ErrorCategory) -> String {
        match category {
            ErrorCategory::Configuration => "Check your API key and configuration".to_string(),
            ErrorCategory::Logic => "Verify your input parameters".to_string(),
            ErrorCategory::UserInterrupt => "Operation was cancelled".to_string(),
            ErrorCategory::Transient => "Temporary issue, will retry automatically".to_string(),
            ErrorCategory::Resource => {
                "System resource issue, consider freeing up memory/disk".to_string()
            }
            ErrorCategory::System => "Unknown error occurred".to_string(),
        }
    }

    /// Returns a snapshot (clone) of the current recovery statistics.
    pub async fn get_stats(&self) -> RecoveryStats {
        self.stats.read().await.clone()
    }
}
/// Finds sessions on disk that were left in a running state.
pub struct SessionRecovery {
// Directory scanned for per-session subdirectories containing `state.json`.
storage_path: std::path::PathBuf,
}
impl SessionRecovery {
    /// Creates a scanner rooted at `storage_path`; the path is copied and is
    /// not touched until a scan is requested.
    pub fn new(storage_path: impl AsRef<std::path::Path>) -> Self {
        let root: &std::path::Path = storage_path.as_ref();
        Self {
            storage_path: root.to_owned(),
        }
    }

    /// Scans every subdirectory of the storage path for a `state.json` and
    /// returns the sessions that are still marked running and not completed,
    /// most recently active first.
    ///
    /// A missing storage directory yields an empty list. State files that
    /// cannot be read or parsed are skipped silently (best effort).
    ///
    /// # Errors
    ///
    /// Propagates failures from listing the storage directory or reading one
    /// of its entries.
    pub fn detect_interrupted_sessions(&self) -> Layer2Result<Vec<InterruptedSession>> {
        if !self.storage_path.exists() {
            return Ok(Vec::new());
        }

        let mut found = Vec::new();
        for dir_entry in std::fs::read_dir(&self.storage_path)? {
            let candidate = dir_entry?.path();
            if !candidate.is_dir() {
                continue;
            }

            let state_path = candidate.join("state.json");
            if !state_path.exists() {
                continue;
            }

            // Unreadable or malformed state files collapse to `None` and are
            // ignored rather than failing the whole scan.
            let parsed = std::fs::read_to_string(&state_path)
                .ok()
                .and_then(|raw| serde_json::from_str::<SessionState>(&raw).ok());

            match parsed {
                Some(state) if state.status == SessionStatus::Running && !state.completed => {
                    found.push(InterruptedSession {
                        session_id: state.session_id,
                        last_iteration: state.iteration,
                        last_activity: state.last_updated,
                        task_description: state.task_description,
                    });
                }
                _ => {}
            }
        }

        // Stable sort, newest activity first.
        found.sort_by(|a, b| b.last_activity.cmp(&a.last_activity));
        Ok(found)
    }

    /// Renders the result of [`Self::detect_interrupted_sessions`] as a
    /// human-readable, numbered listing, or as a one-line message when there
    /// is nothing to show or the scan failed.
    pub fn render_interrupted(&self) -> String {
        let sessions = match self.detect_interrupted_sessions() {
            Ok(found) => found,
            Err(e) => return format!("Error detecting sessions: {}", e),
        };

        if sessions.is_empty() {
            return "No interrupted sessions found.".to_string();
        }

        let listing: String = sessions
            .iter()
            .enumerate()
            .map(|(idx, session)| {
                format!(
                    "{}. Session: {}\n Task: {}\n Iteration: {}\n Last activity: {}\n\n",
                    idx + 1,
                    session.session_id,
                    session.task_description.as_deref().unwrap_or("Unknown"),
                    session.last_iteration,
                    session.last_activity.format("%Y-%m-%d %H:%M:%S")
                )
            })
            .collect();

        let mut output = format!("Found {} interrupted session(s):\n\n", sessions.len());
        output.push_str(&listing);
        output.push_str("Use 'continuum session resume <id>' to continue.");
        output
    }
}
/// Summary of a session found in a running, not-completed state on disk.
#[derive(Debug, Clone)]
pub struct InterruptedSession {
/// Identifier taken from the session's state file.
pub session_id: SessionId,
/// Iteration number recorded in the state file.
pub last_iteration: i32,
/// `last_updated` timestamp recorded in the state file.
pub last_activity: chrono::DateTime<chrono::Utc>,
/// Task description from the state file, if one was recorded.
pub task_description: Option<String>,
}
/// Shape of a session's `state.json` as deserialized by `SessionRecovery`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SessionState {
// Identifier of the session this state belongs to.
session_id: SessionId,
// Lifecycle status; only `Running` is treated as interrupted.
status: SessionStatus,
// A session is only reported as interrupted while this is false.
completed: bool,
// Iteration counter, surfaced as `InterruptedSession::last_iteration`.
iteration: i32,
// Timestamp surfaced as `InterruptedSession::last_activity`.
last_updated: chrono::DateTime<chrono::Utc>,
// Optional free-text description of the task.
task_description: Option<String>,
}
/// Lifecycle status stored in a session's state file.
///
/// Only `Running` is inspected in this file (to detect interrupted
/// sessions); the other variants are carried for (de)serialization.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
enum SessionStatus {
Running,
Paused,
Completed,
Error,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Keyword-based classification picks the expected category.
    #[test]
    fn test_error_category_analysis() {
        let cat = ErrorCategory::from_error_message("network timeout");
        assert_eq!(cat, ErrorCategory::Transient);
        let cat = ErrorCategory::from_error_message("invalid parameter");
        assert_eq!(cat, ErrorCategory::Logic);
    }

    /// The first delay is the initial delay, give or take the 10% jitter.
    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy::default();
        let delay = policy.delay_for_attempt(0);
        assert!(delay.as_millis() >= 900);
        assert!(delay.as_millis() <= 1100);
    }

    /// Late attempts are capped at `max_delay_ms` (plus at most the jitter).
    #[test]
    fn test_retry_policy_max_delay() {
        let policy = RetryPolicy {
            max_delay_ms: 5000,
            ..Default::default()
        };
        let delay = policy.delay_for_attempt(10);
        assert!(delay.as_millis() <= 5500);
    }

    /// A freshly created coordinator starts with zeroed statistics.
    #[tokio::test]
    async fn test_error_recovery_creation() {
        let recovery = ErrorRecovery::new();
        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_errors, 0);
    }

    /// The `Skip` strategy can be constructed and matched.
    #[test]
    fn test_fallback_strategy() {
        let strategy = FallbackStrategy::Skip;
        // `matches!` only evaluates to a bool; without `assert!` the result
        // was discarded and this test could never fail.
        assert!(matches!(strategy, FallbackStrategy::Skip));
    }
}