use serde::{Deserialize, Serialize};
use std::fmt;
/// Severity bucket attached to a failure.
///
/// Serde writes the variants in snake_case (`"critical"`, `"high"`,
/// `"medium"`, `"low"`) because of the `rename_all` attribute below.
/// No ordering traits are derived, so two buckets can be tested for equality
/// but not compared with `<` / `>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailureCategoryKind {
/// Highest severity bucket.
Critical,
/// Second-highest severity bucket.
High,
/// Middle severity bucket.
Medium,
/// Lowest severity bucket.
Low,
}
/// Structured description of a failure, one variant per failure mode.
///
/// Serde uses the adjacently tagged representation: the variant name is stored
/// under `"failure_mode"` and the variant's fields under `"details"`, e.g.
/// `{"failure_mode":"ProviderTimeout","details":{"timeout_ms":30000,"provider":"openai"}}`.
/// Variant names are not renamed, so they appear in serialized form exactly as
/// written here.
///
/// `Eq` and `Hash` are not derived; `CostBudgetExhausted` carries `f64`
/// fields, which rules both out.
///
/// The per-variant notes below paraphrase the variant and field names; this
/// file does not show where most variants are constructed. Field suffixes
/// (`_ms`, `_secs`, `_mb`, `_usd`) presumably denote units -- confirm at the
/// call sites.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "failure_mode", content = "details")]
pub enum AIFailureMode {
// --- Output / artifact checks ---
/// An artifact failed validation; `validator_class` names the validator.
ArtifactValidationFailed { validator_class: String },
/// What was observed did not match what was expected.
ContractViolation { expected: String, observed: String },
/// Fewer citations were found than were required.
CitationMissing {
required_count: u32,
found_count: u32,
},
/// Web sources were missing.
WebSourcesMissing,
/// A path traversal was detected; `attempted_path` is the offending path.
PathTraversalDetected { attempted_path: String },
/// Output did not have the expected structure / format.
InvalidOutputStructure { expected_format: String },
// --- Provider failures ---
/// A provider timed out.
ProviderTimeout { timeout_ms: u64, provider: String },
/// A provider rate-limited the caller; `retry_after_secs` is optional.
ProviderRateLimit {
provider: String,
retry_after_secs: Option<u64>,
},
/// The requested model was not found at the provider.
ProviderModelNotFound { model: String, provider: String },
/// A provider reported an error identified by `error_code`.
ProviderError {
provider: String,
error_code: String,
},
/// A provider stream failed.
ProviderStreamFailure { provider: String, error: String },
/// Every fallback provider was exhausted.
FallbackProvidersExhausted,
// --- Repair ---
/// The repair budget was used up (`attempts` out of `max_iterations`).
RepairBudgetExhausted { max_iterations: u32, attempts: u32 },
/// Repair appears to be looping; `cycle_count` is the number of cycles.
RepairLoopInfinite { cycle_count: u32 },
/// Repair state was invalid.
RepairStateInvalid { state: String },
/// Repair was unavailable for the given `reason`.
RepairUnavailable { reason: String },
// --- Sources / data / workspace ---
/// Access to a source was denied.
SourceAccessDenied { source: String, reason: String },
/// A source was not found.
SourceNotFound { source: String },
/// Data from `source` was corrupt.
DataCorruption { source: String, error: String },
/// A workspace was not found.
WorkspaceNotFound { workspace_id: String },
// --- Timeouts ---
/// A node exceeded its timeout (`actual_ms` versus `timeout_ms`).
NodeTimeout { timeout_ms: u64, actual_ms: u64 },
/// A session exceeded its timeout (`actual_ms` versus `timeout_ms`).
SessionTimeout { timeout_ms: u64, actual_ms: u64 },
/// A provider call timed out.
/// NOTE(review): how this differs from `ProviderTimeout` is not visible
/// in this file -- confirm with the call sites.
ProviderCallTimeout { timeout_ms: u64 },
// --- Authorization ---
/// Authorization failed for `user_id` on `resource`.
AuthorizationFailed { user_id: String, resource: String },
/// The API key for `provider` was invalid.
InvalidApiKey { provider: String },
/// Permission to `resource` was denied.
PermissionDenied { resource: String },
// --- Budgets / resources ---
/// Token budget exhausted (`used` versus `limit`).
TokenBudgetExhausted { limit: u64, used: u64 },
/// Cost budget exhausted (`used_usd` versus `limit_usd`).
CostBudgetExhausted { limit_usd: f64, used_usd: f64 },
/// Memory limit exhausted.
MemoryExhausted { limit_mb: u64 },
/// Not enough disk space (`available_mb` versus `required_mb`).
DiskSpaceExhausted { required_mb: u64, available_mb: u64 },
// --- Configuration / environment ---
/// A configuration `field` was rejected for `reason`.
ConfigurationError { field: String, reason: String },
/// A dependency was missing.
DependencyMissing { dependency: String },
/// A feature was disabled.
FeatureDisabled { feature: String },
// --- Fallback ---
/// Failure that fits no other variant; carries the raw message.
Unknown { error_message: String },
}
/// Serializable record bundling a failure mode with identifying metadata.
///
/// `PartialEq` is not derived, so records are compared field by field where
/// needed. Nothing in this type ties `category` to `failure_mode`; keeping
/// them consistent is up to whoever builds the record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureContext {
/// Identifier of the automation the failure belongs to.
pub automation_id: String,
/// Identifier of the node the failure belongs to.
pub node_id: String,
/// Identifier of the run the failure belongs to.
pub run_id: String,
/// The structured failure description.
pub failure_mode: AIFailureMode,
/// Severity bucket recorded for this failure.
pub category: FailureCategoryKind,
/// Timestamp of the failure.
/// NOTE(review): presumably milliseconds since the Unix epoch -- confirm.
pub timestamp_ms: u64,
/// Free-form error text associated with the failure.
pub error_text: String,
/// Whether recovery was attempted.
pub recovery_attempted: bool,
/// Number of repair attempts recorded.
pub repair_attempts: u32,
/// Provider name, when one is associated with the failure.
pub provider: Option<String>,
/// Free-form tags attached to the record.
pub tags: Vec<String>,
/// Human-readable description of the failure.
pub description: String,
}
/// Displays a failure mode by reusing its derived `Debug` rendering.
///
/// The output is the variant name followed by any fields, for example
/// `CitationMissing { required_count: 3, found_count: 1 }`.
impl fmt::Display for AIFailureMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format through a fresh `{:?}` spec rather than forwarding `f` to
        // `Debug::fmt`, so flags on the outer formatter (such as `#`) never
        // switch the output to the pretty-printed form.
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Displays a severity bucket as its capitalized variant name
/// (`Critical`, `High`, `Medium`, `Low`).
impl fmt::Display for FailureCategoryKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            FailureCategoryKind::Critical => "Critical",
            FailureCategoryKind::High => "High",
            FailureCategoryKind::Medium => "Medium",
            FailureCategoryKind::Low => "Low",
        };
        f.write_str(label)
    }
}
/// Heuristically maps a free-form error message onto an [`AIFailureMode`].
///
/// Matching is a case-insensitive substring search. Rules are evaluated in the
/// order below and the first match wins:
///
/// 1. `"timeout"` or `"timed out"`: `NodeTimeout` when the text also mentions
///    `"node"`, otherwise `ProviderTimeout`.
/// 2. `"rate limit"` / `"rate_limit"`: `ProviderRateLimit`.
/// 3. `"model not found"` / `"model_not_found"`: `ProviderModelNotFound`.
/// 4. `"validation failed"` / `"validator"`: `ArtifactValidationFailed`.
/// 5. `"citation"` and `"missing"`: `CitationMissing`.
/// 6. `"web source"` and `"missing"`: `WebSourcesMissing`.
/// 7. `"repair"` and `"budget"`: `RepairBudgetExhausted`.
/// 8. `"repair"` and `"loop"`: `RepairLoopInfinite`.
/// 9. `"permission"` and `"denied"`: `SourceAccessDenied`.
/// 10. `"not found"` and `"source"`: `SourceNotFound`.
/// 11. `"config"` (which also covers `"configuration"`): `ConfigurationError`.
/// 12. Anything else: `Unknown`, carrying the original message.
///
/// Only the variant is derived from the text. Numeric fields and
/// `"unknown"` strings in the returned value are fixed placeholders, not
/// values parsed from the message.
///
/// # Parameters
/// - `error_text`: the raw error message to classify.
/// - `node_context`: accepted for interface compatibility; no rule consults
///   it at present.
pub fn classify_error_text(error_text: &str, node_context: Option<&str>) -> AIFailureMode {
    // Explicitly discard the parameter so it does not raise an unused warning.
    let _ = node_context;

    let lower = error_text.to_lowercase();
    let has = |needle: &str| lower.contains(needle);
    let unknown = || "unknown".to_string();

    if has("timeout") || has("timed out") {
        // The node check has to live inside the timeout branch: as a separate
        // later rule it could never fire, because every text containing
        // "timeout" is already claimed here.
        if has("node") {
            return AIFailureMode::NodeTimeout {
                timeout_ms: 600_000,
                actual_ms: 600_001,
            };
        }
        return AIFailureMode::ProviderTimeout {
            timeout_ms: 30_000,
            provider: unknown(),
        };
    }
    if has("rate limit") || has("rate_limit") {
        return AIFailureMode::ProviderRateLimit {
            provider: unknown(),
            retry_after_secs: None,
        };
    }
    if has("model not found") || has("model_not_found") {
        return AIFailureMode::ProviderModelNotFound {
            model: unknown(),
            provider: unknown(),
        };
    }
    if has("validation failed") || has("validator") {
        return AIFailureMode::ArtifactValidationFailed {
            validator_class: unknown(),
        };
    }
    if has("citation") && has("missing") {
        return AIFailureMode::CitationMissing {
            required_count: 1,
            found_count: 0,
        };
    }
    if has("web source") && has("missing") {
        return AIFailureMode::WebSourcesMissing;
    }
    if has("repair") && has("budget") {
        return AIFailureMode::RepairBudgetExhausted {
            max_iterations: 3,
            attempts: 3,
        };
    }
    if has("repair") && has("loop") {
        return AIFailureMode::RepairLoopInfinite { cycle_count: 5 };
    }
    if has("permission") && has("denied") {
        return AIFailureMode::SourceAccessDenied {
            source: unknown(),
            reason: error_text.to_string(),
        };
    }
    if has("not found") && has("source") {
        return AIFailureMode::SourceNotFound { source: unknown() };
    }
    // "config" is a substring of "configuration", so one probe covers both.
    if has("config") {
        return AIFailureMode::ConfigurationError {
            field: unknown(),
            reason: error_text.to_string(),
        };
    }
    AIFailureMode::Unknown {
        error_message: error_text.to_string(),
    }
}
/// Reports whether a failure is one of the modes this file treats as retriable.
///
/// Returns `true` for `ProviderTimeout`, `ProviderCallTimeout`,
/// `SessionTimeout`, `ProviderRateLimit` and `ProviderStreamFailure`; every
/// other variant yields `false`.
pub fn should_retry(failure: &AIFailureMode) -> bool {
    let is_retriable_timeout = matches!(
        failure,
        AIFailureMode::ProviderTimeout { .. }
            | AIFailureMode::ProviderCallTimeout { .. }
            | AIFailureMode::SessionTimeout { .. }
    );
    let is_provider_hiccup = matches!(
        failure,
        AIFailureMode::ProviderRateLimit { .. } | AIFailureMode::ProviderStreamFailure { .. }
    );
    is_retriable_timeout || is_provider_hiccup
}
/// Looks up the severity bucket for a failure mode.
///
/// The match is exhaustive with no wildcard arm, so adding a variant to
/// [`AIFailureMode`] forces a decision here at compile time. Only the variant
/// matters; field values are ignored.
///
/// NOTE(review): `PathTraversalDetected` lands in `High` and `InvalidApiKey`
/// in `Low` -- confirm these reflect the intended severity policy.
pub fn categorize_failure(failure: &AIFailureMode) -> FailureCategoryKind {
    match failure {
        // Low: unclassified, environment / setup, and miscellaneous modes.
        AIFailureMode::Unknown { .. }
        | AIFailureMode::FeatureDisabled { .. }
        | AIFailureMode::DependencyMissing { .. }
        | AIFailureMode::WorkspaceNotFound { .. }
        | AIFailureMode::ProviderModelNotFound { .. }
        | AIFailureMode::InvalidApiKey { .. }
        | AIFailureMode::ProviderCallTimeout { .. }
        | AIFailureMode::InvalidOutputStructure { .. }
        | AIFailureMode::RepairStateInvalid { .. }
        | AIFailureMode::RepairUnavailable { .. }
        | AIFailureMode::MemoryExhausted { .. }
        | AIFailureMode::DiskSpaceExhausted { .. } => FailureCategoryKind::Low,

        // Medium: provider hiccups, missing references, data / config issues.
        AIFailureMode::ProviderTimeout { .. }
        | AIFailureMode::SessionTimeout { .. }
        | AIFailureMode::ProviderRateLimit { .. }
        | AIFailureMode::ProviderStreamFailure { .. }
        | AIFailureMode::CitationMissing { .. }
        | AIFailureMode::WebSourcesMissing
        | AIFailureMode::SourceAccessDenied { .. }
        | AIFailureMode::DataCorruption { .. }
        | AIFailureMode::ConfigurationError { .. } => FailureCategoryKind::Medium,

        // High: validation, repair budget, node timeout, cost, provider errors.
        AIFailureMode::PathTraversalDetected { .. }
        | AIFailureMode::ArtifactValidationFailed { .. }
        | AIFailureMode::RepairBudgetExhausted { .. }
        | AIFailureMode::NodeTimeout { .. }
        | AIFailureMode::CostBudgetExhausted { .. }
        | AIFailureMode::ProviderError { .. }
        | AIFailureMode::FallbackProvidersExhausted => FailureCategoryKind::High,

        // Critical: authorization, contract, repair loop, token budget, source lookup.
        AIFailureMode::AuthorizationFailed { .. }
        | AIFailureMode::PermissionDenied { .. }
        | AIFailureMode::ContractViolation { .. }
        | AIFailureMode::RepairLoopInfinite { .. }
        | AIFailureMode::TokenBudgetExhausted { .. }
        | AIFailureMode::SourceNotFound { .. } => FailureCategoryKind::Critical,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Provider-timeout value shared by several tests below.
    fn openai_timeout() -> AIFailureMode {
        AIFailureMode::ProviderTimeout {
            timeout_ms: 30000,
            provider: "openai".to_string(),
        }
    }

    // --- classify_error_text ---

    #[test]
    fn test_classify_provider_timeout() {
        assert!(matches!(
            classify_error_text("Request timed out after 30s", None),
            AIFailureMode::ProviderTimeout { .. }
        ));
    }

    #[test]
    fn test_classify_rate_limit() {
        assert!(matches!(
            classify_error_text("Rate limit exceeded", None),
            AIFailureMode::ProviderRateLimit { .. }
        ));
    }

    #[test]
    fn test_classify_validation_failure() {
        assert!(matches!(
            classify_error_text("Artifact validation failed: missing_section", None),
            AIFailureMode::ArtifactValidationFailed { .. }
        ));
    }

    #[test]
    fn test_classify_citation_missing() {
        assert!(matches!(
            classify_error_text("Citations missing from research output", None),
            AIFailureMode::CitationMissing { .. }
        ));
    }

    #[test]
    fn test_classify_unknown() {
        assert!(matches!(
            classify_error_text("Some random error that we don't understand", None),
            AIFailureMode::Unknown { .. }
        ));
    }

    // --- should_retry ---

    #[test]
    fn test_should_retry_transient() {
        assert!(should_retry(&openai_timeout()));
    }

    #[test]
    fn test_should_not_retry_deterministic() {
        let mode = AIFailureMode::ArtifactValidationFailed {
            validator_class: "citation".to_string(),
        };
        assert!(!should_retry(&mode));
    }

    // --- categorize_failure ---

    #[test]
    fn test_categorize_critical() {
        let mode = AIFailureMode::AuthorizationFailed {
            user_id: "user1".to_string(),
            resource: "automation1".to_string(),
        };
        assert_eq!(categorize_failure(&mode), FailureCategoryKind::Critical);
    }

    #[test]
    fn test_categorize_high() {
        let mode = AIFailureMode::ArtifactValidationFailed {
            validator_class: "contract".to_string(),
        };
        assert_eq!(categorize_failure(&mode), FailureCategoryKind::High);
    }

    #[test]
    fn test_categorize_medium() {
        assert_eq!(
            categorize_failure(&openai_timeout()),
            FailureCategoryKind::Medium
        );
    }

    // --- serialization / display ---

    #[test]
    fn test_failure_context_serialization() {
        let original = FailureContext {
            automation_id: "auto1".to_string(),
            node_id: "node1".to_string(),
            run_id: "run1".to_string(),
            failure_mode: openai_timeout(),
            category: FailureCategoryKind::Medium,
            timestamp_ms: 1234567890,
            error_text: "Request timed out".to_string(),
            recovery_attempted: true,
            repair_attempts: 2,
            provider: Some("openai".to_string()),
            tags: vec!["transient".to_string(), "retriable".to_string()],
            description: "Provider timeout during execution".to_string(),
        };

        // Round-trip through JSON and spot-check two fields.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: FailureContext = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.automation_id, "auto1");
        assert_eq!(decoded.repair_attempts, 2);
    }

    #[test]
    fn test_failure_mode_display() {
        let rendered = AIFailureMode::CitationMissing {
            required_count: 3,
            found_count: 1,
        }
        .to_string();
        assert!(rendered.contains("CitationMissing"));
    }
}