use serde::{Deserialize, Serialize};
use tracing::{info, instrument, warn};
use crate::core::error::Result;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Diagnosis {
pub original_error: String,
pub root_cause: String,
pub confidence: f64,
pub severity: DiagnosisSeverity,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DiagnosisSeverity {
Recoverable,
Degraded,
Fatal,
}
impl std::fmt::Display for DiagnosisSeverity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DiagnosisSeverity::Recoverable => write!(f, "RECOVERABLE"),
DiagnosisSeverity::Degraded => write!(f, "DEGRADED"),
DiagnosisSeverity::Fatal => write!(f, "FATAL"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RepairAction {
SwapModel {
replacement: String,
reason: String,
},
AdjustSchema {
change: String,
},
RetryWithParams {
adjustments: String,
},
SkipAndContinue {
skip_count: usize,
reason: String,
},
Escalate {
explanation: String,
},
}
impl std::fmt::Display for RepairAction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RepairAction::SwapModel {
replacement,
reason,
} => write!(f, "SwapModel → {replacement}: {reason}"),
RepairAction::AdjustSchema { change } => write!(f, "AdjustSchema: {change}"),
RepairAction::RetryWithParams { adjustments } => {
write!(f, "RetryWithParams: {adjustments}")
}
RepairAction::SkipAndContinue { skip_count, reason } => {
write!(f, "SkipAndContinue ({skip_count} rows): {reason}")
}
RepairAction::Escalate { explanation } => write!(f, "Escalate → User: {explanation}"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepairReport {
pub diagnosis: Diagnosis,
pub action: RepairAction,
pub applied: bool,
}
impl RepairReport {
pub fn summary(&self) -> String {
let status = if self.applied {
"✓ Applied"
} else {
"⚠ Pending"
};
format!(
"═══ Self-Repair Report ═══\n\
Severity: {}\n\
Root Cause: {}\n\
Confidence: {:.0}%\n\
Action: {}\n\
Status: {}",
self.diagnosis.severity,
self.diagnosis.root_cause,
self.diagnosis.confidence * 100.0,
self.action,
status
)
}
}
#[derive(Debug)]
pub struct SelfRepairAgent {
available_models: Vec<String>,
}
impl SelfRepairAgent {
pub fn new() -> Self {
Self {
available_models: Vec::new(),
}
}
pub fn register_available_models(&mut self, models: Vec<String>) {
self.available_models = models;
}
#[instrument(skip(self))]
pub fn diagnose_and_repair(
&self,
error_msg: &str,
operator_name: &str,
context: &str,
) -> Result<RepairReport> {
info!(
error = error_msg,
operator = operator_name,
"self-repair agent triggered"
);
let diagnosis = self.review(error_msg, operator_name, context)?;
let action = self.rewrite(&diagnosis, operator_name)?;
let report = RepairReport {
diagnosis,
action,
applied: false,
};
info!(report = %report.summary(), "self-repair report generated");
Ok(report)
}
fn review(&self, error_msg: &str, operator_name: &str, context: &str) -> Result<Diagnosis> {
let error_lower = error_msg.to_lowercase();
let (root_cause, severity, confidence) = if error_lower.contains("dimension")
|| error_lower.contains("shape")
{
(
format!(
"Input tensor shape mismatch in operator '{operator_name}'. \
The model expects a different number of features than provided."
),
DiagnosisSeverity::Recoverable,
0.9,
)
} else if error_lower.contains("unsupported")
|| error_lower.contains("format")
|| error_lower.contains("codec")
{
(
format!(
"Unsupported data format encountered by operator '{operator_name}'. \
The input data contains a type or encoding this operator cannot process."
),
DiagnosisSeverity::Degraded,
0.85,
)
} else if error_lower.contains("null")
|| error_lower.contains("missing")
|| error_lower.contains("none")
{
(
format!(
"Null or missing values detected in input to '{operator_name}'. \
{context}"
),
DiagnosisSeverity::Recoverable,
0.8,
)
} else if error_lower.contains("timeout")
|| error_lower.contains("deadline")
|| error_lower.contains("exceeded")
{
(
format!(
"Operator '{operator_name}' exceeded its execution time budget. \
Consider swapping to a faster model variant."
),
DiagnosisSeverity::Recoverable,
0.95,
)
} else if error_lower.contains("memory")
|| error_lower.contains("oom")
|| error_lower.contains("allocation")
{
(
format!(
"Out-of-memory condition in operator '{operator_name}'. \
The input batch may be too large for the current device."
),
DiagnosisSeverity::Degraded,
0.9,
)
} else {
(
format!("Unrecognized structural error in operator '{operator_name}': {error_msg}"),
DiagnosisSeverity::Fatal,
0.5,
)
};
Ok(Diagnosis {
original_error: error_msg.to_string(),
root_cause,
confidence,
severity,
})
}
fn rewrite(&self, diagnosis: &Diagnosis, operator_name: &str) -> Result<RepairAction> {
match diagnosis.severity {
DiagnosisSeverity::Recoverable => {
if let Some(alt) = self
.available_models
.iter()
.find(|m| m.as_str() != operator_name)
{
Ok(RepairAction::SwapModel {
replacement: alt.clone(),
reason: format!(
"Swapping from '{}' to '{}' to bypass: {}",
operator_name, alt, diagnosis.root_cause
),
})
} else {
Ok(RepairAction::RetryWithParams {
adjustments: "Reduce batch size and retry.".into(),
})
}
}
DiagnosisSeverity::Degraded => Ok(RepairAction::SkipAndContinue {
skip_count: 0,
reason: format!(
"Continuing in degraded mode. Unsupported rows will be skipped. \
Root cause: {}",
diagnosis.root_cause
),
}),
DiagnosisSeverity::Fatal => {
warn!(
error = %diagnosis.original_error,
"self-repair escalating to user"
);
Ok(RepairAction::Escalate {
explanation: format!(
"Cannot auto-repair: {}. Please review the operator configuration \
and input data manually.",
diagnosis.root_cause
),
})
}
}
}
}
impl Default for SelfRepairAgent {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn diagnose_dimension_mismatch() {
let agent = SelfRepairAgent::new();
let report = agent
.diagnose_and_repair(
"dimension mismatch: expected 3, got 5",
"fraud_detector",
"input batch has 5 columns",
)
.unwrap();
assert_eq!(report.diagnosis.severity, DiagnosisSeverity::Recoverable);
assert!(report.diagnosis.root_cause.contains("shape mismatch"));
}
#[test]
fn diagnose_timeout_with_swap() {
let mut agent = SelfRepairAgent::new();
agent.register_available_models(vec!["fraud_detector".into(), "fraud_fast".into()]);
let report = agent
.diagnose_and_repair(
"operator exceeded deadline of 50ms",
"fraud_detector",
"latency constraint violated",
)
.unwrap();
assert_eq!(report.diagnosis.severity, DiagnosisSeverity::Recoverable);
match &report.action {
RepairAction::SwapModel { replacement, .. } => {
assert_eq!(replacement, "fraud_fast");
}
other => panic!("expected SwapModel, got {other:?}"),
}
}
#[test]
fn diagnose_fatal_escalates() {
let agent = SelfRepairAgent::new();
let report = agent
.diagnose_and_repair("some unknown error xyz", "op1", "")
.unwrap();
assert_eq!(report.diagnosis.severity, DiagnosisSeverity::Fatal);
assert!(matches!(report.action, RepairAction::Escalate { .. }));
}
}