use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::detector::{Problem, ProblemType, SystemMetrics};
use super::learning::OutcomeTracker;
use super::strategies::{
RemediationResult, RemediationStrategy, StrategyContext, StrategyRegistry,
};
/// Policy knobs governing when and how [`RemediationEngine`] heals automatically.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealingConfig {
    /// Minimum time since the most recent recorded attempt for the same problem
    /// type before another automatic attempt is allowed.
    pub min_healing_interval: Duration,
    /// Maximum number of automatic attempts per problem type inside `attempt_window`.
    pub max_attempts_per_window: usize,
    /// Sliding window over which `max_attempts_per_window` is counted.
    pub attempt_window: Duration,
    /// Impact cap handed to strategy selection and to the strategy context for
    /// automatic runs.
    pub max_auto_heal_impact: f32,
    /// Problem types that are never healed automatically (always deferred).
    pub require_approval: Vec<ProblemType>,
    /// Strategy names that are never run automatically (always deferred).
    pub require_approval_strategies: Vec<String>,
    /// When true, outcomes update strategy weights and the outcome tracker.
    pub learning_enabled: bool,
    /// NOTE(review): not read by `RemediationEngine` in this file — confirm
    /// whether it is consumed elsewhere.
    pub failure_cooldown: Duration,
    /// When true, a successful result only counts as verified if its
    /// `improvement_pct` reaches `min_improvement_pct`.
    pub verify_improvement: bool,
    /// Threshold compared directly against `RemediationResult::improvement_pct`.
    pub min_improvement_pct: f32,
    /// Upper bound on remediations running at the same time.
    pub max_concurrent_remediations: usize,
}

impl Default for HealingConfig {
    /// Conservative defaults: paced, capped attempts, with replica promotion
    /// gated behind human approval.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        HealingConfig {
            // Pacing: 5 minutes between attempts, at most 3 per hour, 2 in flight.
            min_healing_interval: Duration::from_secs(5 * MINUTE),
            max_attempts_per_window: 3,
            attempt_window: Duration::from_secs(HOUR),
            failure_cooldown: Duration::from_secs(10 * MINUTE),
            max_concurrent_remediations: 2,
            // Safety gates.
            max_auto_heal_impact: 0.5,
            require_approval: Vec::new(),
            require_approval_strategies: vec![String::from("promote_replica")],
            // Feedback loop.
            learning_enabled: true,
            verify_improvement: true,
            min_improvement_pct: 5.0,
        }
    }
}
/// Result of a single [`RemediationEngine::heal`] / `execute_strategy` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealingOutcome {
    /// A strategy was executed; `verified` says whether the run counted as a success.
    Completed {
        problem_type: ProblemType,
        strategy: String,
        result: RemediationResult,
        verified: bool,
    },
    /// Healing was not attempted because policy requires approval or pacing
    /// limits apply; `reason` is human-readable.
    Deferred {
        reason: String,
        problem_type: ProblemType,
    },
    /// The registry had no strategy to offer for this problem.
    NoStrategy { problem_type: ProblemType },
    /// The engine is switched off.
    Disabled,
    /// The concurrent-remediation limit was reached.
    MaxConcurrent,
}

impl HealingOutcome {
    /// Renders the outcome as a JSON object.
    ///
    /// Every object carries a `"status"` tag (`"completed"`, `"deferred"`,
    /// `"no_strategy"`, `"disabled"` or `"max_concurrent"`); variants with a
    /// payload add their fields, with `problem_type` rendered via `to_string()`.
    pub fn to_json(&self) -> serde_json::Value {
        match self {
            Self::Completed {
                problem_type,
                strategy,
                result,
                verified,
            } => serde_json::json!({
                "status": "completed",
                "problem_type": problem_type.to_string(),
                "strategy": strategy,
                "result": result.to_json(),
                "verified": verified,
            }),
            Self::Deferred {
                reason,
                problem_type,
            } => serde_json::json!({
                "status": "deferred",
                "reason": reason,
                "problem_type": problem_type.to_string(),
            }),
            Self::NoStrategy { problem_type } => serde_json::json!({
                "status": "no_strategy",
                "problem_type": problem_type.to_string(),
            }),
            Self::Disabled => serde_json::json!({ "status": "disabled" }),
            Self::MaxConcurrent => serde_json::json!({ "status": "max_concurrent" }),
        }
    }
}
/// A remediation that is currently executing inside the engine.
#[derive(Debug, Clone)]
pub struct ActiveRemediation {
    /// Engine-assigned identifier, unique per engine instance.
    pub id: u64,
    /// The problem being remediated.
    pub problem: Problem,
    /// Name of the strategy that is running.
    pub strategy_name: String,
    /// Wall-clock time at which execution started.
    pub started_at: SystemTime,
    /// `started_at` plus the strategy's estimated duration.
    pub expected_completion: SystemTime,
}

impl ActiveRemediation {
    /// Whole seconds since the Unix epoch, with pre-epoch timestamps clamped to 0.
    ///
    /// `SystemTime` can legitimately precede the epoch (for example, a misset
    /// system clock). Clamping keeps status reporting from panicking.
    fn unix_secs(time: SystemTime) -> u64 {
        time.duration_since(UNIX_EPOCH).map_or(0, |d| d.as_secs())
    }

    /// Renders the remediation as JSON. Timestamps are Unix seconds, and
    /// `problem_type` is rendered via `to_string()`.
    pub fn to_json(&self) -> serde_json::Value {
        serde_json::json!({
            "id": self.id,
            "problem_type": self.problem.problem_type.to_string(),
            "strategy": self.strategy_name,
            "started_at": Self::unix_secs(self.started_at),
            "expected_completion": Self::unix_secs(self.expected_completion),
        })
    }
}
/// Caller-side description of a remediation request, convertible into the
/// [`StrategyContext`] that strategies consume.
#[derive(Debug, Clone)]
pub struct RemediationContext {
    pub problem: Problem,
    /// Target collection; 0 unless set via [`with_collection`](Self::with_collection).
    pub collection_id: i64,
    /// Optional tenant; not forwarded by [`to_strategy_context`](Self::to_strategy_context).
    pub tenant_id: Option<String>,
    /// Taken from `SystemMetrics::integrity_lambda` at construction time.
    pub initial_lambda: f32,
    pub target_lambda: f32,
    /// Metrics snapshot supplied at construction; not forwarded to the strategy context.
    pub initial_metrics: SystemMetrics,
    pub created_at: SystemTime,
    pub max_impact: f32,
    pub timeout: Duration,
    /// NOTE(review): initialised to 0 and not updated in this file.
    pub attempts_in_window: usize,
    /// NOTE(review): initialised to `None` and not updated in this file.
    pub last_attempt: Option<SystemTime>,
}

impl RemediationContext {
    /// Builds a context for `problem` with default targets: lambda 0.8,
    /// impact 0.5, timeout 5 minutes, collection 0, no tenant.
    pub fn new(problem: Problem, metrics: SystemMetrics) -> Self {
        // Read the starting lambda before `metrics` is moved into the struct.
        let initial_lambda = metrics.integrity_lambda;
        Self {
            problem,
            collection_id: 0,
            tenant_id: None,
            initial_lambda,
            target_lambda: 0.8,
            initial_metrics: metrics,
            created_at: SystemTime::now(),
            max_impact: 0.5,
            timeout: Duration::from_secs(5 * 60),
            attempts_in_window: 0,
            last_attempt: None,
        }
    }

    /// Builder: scopes the context to `collection_id`.
    pub fn with_collection(self, collection_id: i64) -> Self {
        Self {
            collection_id,
            ..self
        }
    }

    /// Builder: attaches a tenant identifier.
    pub fn with_tenant(self, tenant_id: String) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }

    /// Produces a non-dry-run [`StrategyContext`] whose `start_time` is "now".
    pub fn to_strategy_context(&self) -> StrategyContext {
        let Self {
            problem,
            collection_id,
            initial_lambda,
            target_lambda,
            max_impact,
            timeout,
            ..
        } = self;
        StrategyContext {
            problem: problem.clone(),
            collection_id: *collection_id,
            initial_lambda: *initial_lambda,
            target_lambda: *target_lambda,
            max_impact: *max_impact,
            timeout: *timeout,
            start_time: SystemTime::now(),
            dry_run: false,
        }
    }
}
/// Policy-gated executor for remediation strategies.
///
/// All methods take `&self`. The engine's own mutable state lives in `RwLock`s
/// and atomics. `registry` and `tracker` are also mutated through `&self`
/// (`update_weight`, `record`), so they presumably use interior mutability
/// internally — not visible here.
pub struct RemediationEngine {
    /// Strategies available for selection (`heal`) and lookup by name (`execute_strategy`).
    pub registry: StrategyRegistry,
    /// Current healing policy; `heal` works on a cloned snapshot.
    config: RwLock<HealingConfig>,
    /// Outcome recorder, fed by `heal` when `learning_enabled` is set.
    tracker: OutcomeTracker,
    /// Remediations currently executing inside `heal`.
    active: RwLock<Vec<ActiveRemediation>>,
    /// Next id handed to an `ActiveRemediation`; starts at 1.
    next_id: AtomicU64,
    /// Per-problem-type attempt timestamps, oldest first; drives the cooldown
    /// and attempt-window checks.
    attempt_history: RwLock<HashMap<ProblemType, VecDeque<SystemTime>>>,
    /// Global on/off switch for `heal`.
    enabled: AtomicBool,
    /// Number of `heal` calls that reached strategy execution.
    total_healings: AtomicU64,
    /// Number of those executions whose result was verified.
    successful_healings: AtomicU64,
}
impl RemediationEngine {
pub fn new(registry: StrategyRegistry, config: HealingConfig, tracker: OutcomeTracker) -> Self {
Self {
registry,
config: RwLock::new(config),
tracker,
active: RwLock::new(Vec::new()),
next_id: AtomicU64::new(1),
attempt_history: RwLock::new(HashMap::new()),
enabled: AtomicBool::new(true),
total_healings: AtomicU64::new(0),
successful_healings: AtomicU64::new(0),
}
}
pub fn set_enabled(&self, enabled: bool) {
self.enabled.store(enabled, Ordering::SeqCst);
}
pub fn is_enabled(&self) -> bool {
self.enabled.load(Ordering::SeqCst)
}
pub fn update_config(&self, config: HealingConfig) {
*self.config.write() = config;
}
pub fn get_config(&self) -> HealingConfig {
self.config.read().clone()
}
pub fn active_remediations(&self) -> Vec<ActiveRemediation> {
self.active.read().clone()
}
pub fn heal(&self, problem: &Problem) -> HealingOutcome {
if !self.is_enabled() {
return HealingOutcome::Disabled;
}
let config = self.config.read().clone();
if self.active.read().len() >= config.max_concurrent_remediations {
return HealingOutcome::MaxConcurrent;
}
if !self.should_auto_heal(problem, &config) {
return HealingOutcome::Deferred {
reason: self.get_defer_reason(problem, &config),
problem_type: problem.problem_type,
};
}
let strategy = match self.registry.select(problem, config.max_auto_heal_impact) {
Some(s) => s,
None => {
return HealingOutcome::NoStrategy {
problem_type: problem.problem_type,
};
}
};
if config
.require_approval_strategies
.contains(&strategy.name().to_string())
{
return HealingOutcome::Deferred {
reason: format!("Strategy '{}' requires human approval", strategy.name()),
problem_type: problem.problem_type,
};
}
self.record_attempt(problem.problem_type);
self.total_healings.fetch_add(1, Ordering::SeqCst);
let remediation_id = self.next_id.fetch_add(1, Ordering::SeqCst);
let active_rem = ActiveRemediation {
id: remediation_id,
problem: problem.clone(),
strategy_name: strategy.name().to_string(),
started_at: SystemTime::now(),
expected_completion: SystemTime::now() + strategy.estimated_duration(),
};
self.active.write().push(active_rem);
let context = StrategyContext {
problem: problem.clone(),
collection_id: 0,
initial_lambda: 1.0,
target_lambda: 0.8,
max_impact: config.max_auto_heal_impact,
timeout: strategy.estimated_duration() * 2,
start_time: SystemTime::now(),
dry_run: false,
};
let result = self.execute_with_safeguards(&*strategy, &context);
self.active.write().retain(|r| r.id != remediation_id);
let verified = if config.verify_improvement && result.is_success() {
self.verify_improvement(&result, config.min_improvement_pct)
} else {
result.is_success()
};
if !verified && strategy.reversible() {
pgrx::log!(
"Remediation not verified, rolling back: {}",
strategy.name()
);
if let Err(e) = strategy.rollback(&context, &result) {
pgrx::warning!("Rollback failed: {}", e);
}
}
if config.learning_enabled {
self.registry
.update_weight(strategy.name(), verified, result.improvement_pct);
self.tracker
.record(problem, strategy.name(), &result, verified);
}
if verified {
self.successful_healings.fetch_add(1, Ordering::SeqCst);
}
HealingOutcome::Completed {
problem_type: problem.problem_type,
strategy: strategy.name().to_string(),
result,
verified,
}
}
fn execute_with_safeguards(
&self,
strategy: &dyn RemediationStrategy,
context: &StrategyContext,
) -> RemediationResult {
let start = std::time::Instant::now();
let mut result = strategy.execute(context);
result.duration_ms = start.elapsed().as_millis() as u64;
result
}
fn should_auto_heal(&self, problem: &Problem, config: &HealingConfig) -> bool {
if config.require_approval.contains(&problem.problem_type) {
return false;
}
if !self.is_past_cooldown(problem.problem_type, config) {
return false;
}
if self.attempts_in_window(problem.problem_type, &config.attempt_window)
>= config.max_attempts_per_window
{
return false;
}
true
}
fn get_defer_reason(&self, problem: &Problem, config: &HealingConfig) -> String {
if config.require_approval.contains(&problem.problem_type) {
return format!(
"Problem type '{:?}' requires human approval",
problem.problem_type
);
}
if !self.is_past_cooldown(problem.problem_type, config) {
return "In cooldown period after recent healing attempt".to_string();
}
if self.attempts_in_window(problem.problem_type, &config.attempt_window)
>= config.max_attempts_per_window
{
return format!(
"Exceeded maximum {} attempts per {:?}",
config.max_attempts_per_window, config.attempt_window
);
}
"Unknown reason".to_string()
}
fn is_past_cooldown(&self, problem_type: ProblemType, config: &HealingConfig) -> bool {
let history = self.attempt_history.read();
if let Some(attempts) = history.get(&problem_type) {
if let Some(last) = attempts.back() {
if let Ok(elapsed) = last.elapsed() {
return elapsed >= config.min_healing_interval;
}
}
}
true
}
fn attempts_in_window(&self, problem_type: ProblemType, window: &Duration) -> usize {
let history = self.attempt_history.read();
if let Some(attempts) = history.get(&problem_type) {
let cutoff = SystemTime::now() - *window;
attempts.iter().filter(|t| **t > cutoff).count()
} else {
0
}
}
fn record_attempt(&self, problem_type: ProblemType) {
let mut history = self.attempt_history.write();
let attempts = history.entry(problem_type).or_insert_with(VecDeque::new);
attempts.push_back(SystemTime::now());
let cutoff = SystemTime::now() - Duration::from_secs(86400); while let Some(front) = attempts.front() {
if *front < cutoff {
attempts.pop_front();
} else {
break;
}
}
}
fn verify_improvement(&self, result: &RemediationResult, min_pct: f32) -> bool {
result.improvement_pct >= min_pct
}
pub fn get_stats(&self) -> EngineStats {
let total = self.total_healings.load(Ordering::SeqCst);
let successful = self.successful_healings.load(Ordering::SeqCst);
EngineStats {
enabled: self.is_enabled(),
total_healings: total,
successful_healings: successful,
success_rate: if total > 0 {
successful as f32 / total as f32
} else {
0.0
},
active_remediations: self.active.read().len(),
strategy_weights: self.registry.get_all_weights(),
}
}
pub fn execute_strategy(
&self,
strategy_name: &str,
problem: &Problem,
dry_run: bool,
) -> Option<HealingOutcome> {
let strategy = self.registry.get_by_name(strategy_name)?;
let _config = self.config.read().clone();
let context = StrategyContext {
problem: problem.clone(),
collection_id: 0,
initial_lambda: 1.0,
target_lambda: 0.8,
max_impact: 1.0, timeout: strategy.estimated_duration() * 2,
start_time: SystemTime::now(),
dry_run,
};
let result = strategy.execute(&context);
Some(HealingOutcome::Completed {
problem_type: problem.problem_type,
strategy: strategy_name.to_string(),
result,
verified: !dry_run,
})
}
}
/// Point-in-time snapshot of engine counters, produced by
/// [`RemediationEngine::get_stats`].
#[derive(Debug, Clone)]
pub struct EngineStats {
    /// Whether automatic healing was enabled at snapshot time.
    pub enabled: bool,
    /// `heal` calls that reached strategy execution.
    pub total_healings: u64,
    /// Of those, how many were verified.
    pub successful_healings: u64,
    /// `successful_healings / total_healings`, or 0.0 when nothing ran yet.
    pub success_rate: f32,
    /// Remediations executing at snapshot time.
    pub active_remediations: usize,
    /// Current per-strategy weights as reported by the registry.
    pub strategy_weights: HashMap<String, f32>,
}

impl EngineStats {
    /// Renders every field under a key of the same name.
    pub fn to_json(&self) -> serde_json::Value {
        let EngineStats {
            enabled,
            total_healings,
            successful_healings,
            success_rate,
            active_remediations,
            strategy_weights,
        } = self;
        serde_json::json!({
            "enabled": enabled,
            "total_healings": total_healings,
            "successful_healings": successful_healings,
            "success_rate": success_rate,
            "active_remediations": active_remediations,
            "strategy_weights": strategy_weights,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::healing::detector::Severity;

    /// Engine with the default strategy registry and the supplied policy.
    fn engine_with(config: HealingConfig) -> RemediationEngine {
        let registry = StrategyRegistry::new_with_defaults();
        let tracker = OutcomeTracker::new();
        RemediationEngine::new(registry, config, tracker)
    }

    /// Engine with default strategies and the default policy.
    fn create_engine() -> RemediationEngine {
        engine_with(HealingConfig::default())
    }

    /// Medium-severity index degradation, the problem most tests heal.
    fn index_problem() -> Problem {
        Problem::new(ProblemType::IndexDegradation, Severity::Medium)
    }

    #[test]
    fn test_engine_creation() {
        let engine = create_engine();
        assert!(engine.is_enabled());
        assert_eq!(engine.active_remediations().len(), 0);
    }

    #[test]
    fn test_engine_enable_disable() {
        let engine = create_engine();
        engine.set_enabled(false);
        assert!(!engine.is_enabled());
        assert!(matches!(
            engine.heal(&index_problem()),
            HealingOutcome::Disabled
        ));
        engine.set_enabled(true);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_heal_index_degradation() {
        let engine = create_engine();
        if let HealingOutcome::Completed { strategy, .. } = engine.heal(&index_problem()) {
            assert!(["reindex", "integrity"]
                .iter()
                .any(|&needle| strategy.contains(needle)));
        } else {
            panic!("Expected Completed outcome");
        }
    }

    #[test]
    fn test_cooldown_enforcement() {
        let engine = engine_with(HealingConfig {
            min_healing_interval: Duration::from_secs(60),
            ..HealingConfig::default()
        });
        let problem = index_problem();
        assert!(matches!(
            engine.heal(&problem),
            HealingOutcome::Completed { .. }
        ));
        assert!(matches!(
            engine.heal(&problem),
            HealingOutcome::Deferred { .. }
        ));
    }

    #[test]
    fn test_max_attempts_enforcement() {
        let engine = engine_with(HealingConfig {
            max_attempts_per_window: 2,
            min_healing_interval: Duration::from_millis(1),
            ..HealingConfig::default()
        });
        let problem = index_problem();
        // Use up the budget, sleeping past the 1 ms cooldown after each attempt.
        for _ in 0..2 {
            engine.heal(&problem);
            std::thread::sleep(Duration::from_millis(2));
        }
        assert!(matches!(
            engine.heal(&problem),
            HealingOutcome::Deferred { .. }
        ));
    }

    #[test]
    fn test_approval_requirement() {
        let mut config = HealingConfig::default();
        config.require_approval.push(ProblemType::ReplicaLag);
        let engine = engine_with(config);
        let outcome = engine.heal(&Problem::new(ProblemType::ReplicaLag, Severity::High));
        assert!(matches!(outcome, HealingOutcome::Deferred { .. }));
    }

    #[test]
    fn test_strategy_approval_requirement() {
        let mut config = HealingConfig::default();
        config
            .require_approval_strategies
            .push("promote_replica".to_string());
        config.max_auto_heal_impact = 1.0;
        let engine = engine_with(config);
        let outcome = engine.heal(&Problem::new(ProblemType::ReplicaLag, Severity::High));
        assert!(matches!(outcome, HealingOutcome::Deferred { .. }));
    }

    #[test]
    fn test_no_strategy() {
        // An empty registry can never offer a strategy.
        let engine = RemediationEngine::new(
            StrategyRegistry::new(),
            HealingConfig::default(),
            OutcomeTracker::new(),
        );
        assert!(matches!(
            engine.heal(&index_problem()),
            HealingOutcome::NoStrategy { .. }
        ));
    }

    #[test]
    fn test_manual_execution() {
        let engine = create_engine();
        let outcome = engine.execute_strategy("reindex_partition", &index_problem(), true);
        assert!(outcome.is_some());
        if let Some(HealingOutcome::Completed { result, .. }) = outcome {
            assert_eq!(
                result.metadata.get("dry_run"),
                Some(&serde_json::json!(true))
            );
        }
    }

    #[test]
    fn test_engine_stats() {
        let stats = create_engine().get_stats();
        assert!(stats.enabled);
        assert_eq!(stats.total_healings, 0);
        assert_eq!(stats.active_remediations, 0);
    }
}