use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
#[cfg_attr(not(test), allow(unused_imports))]
use super::detector::{Problem, ProblemType, Severity};
/// Final classification of a single remediation attempt.
///
/// Stored in [`RemediationResult::outcome`] and rendered in lowercase by
/// `RemediationResult::to_json`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RemediationOutcome {
/// The strategy ran to completion; set by `RemediationResult::success`.
Success,
/// Some work was done but an error was also recorded; set by
/// `RemediationResult::partial`, which always attaches an error message.
Partial,
/// Nothing useful was achieved; set by `RemediationResult::failure`.
Failure,
/// Nothing was changed; set by `RemediationResult::noop`, which the
/// strategies in this file use for dry runs and for "nothing to do" cases.
NoOp,
}
/// Report returned by a remediation strategy after `execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemediationResult {
    /// Overall classification of the run.
    pub outcome: RemediationOutcome,
    /// Number of individual actions the strategy performed.
    pub actions_taken: usize,
    /// Improvement attributed to the run, expressed as a percentage.
    pub improvement_pct: f32,
    /// Human-readable error text; present for partial and failed runs.
    pub error_message: Option<String>,
    /// Free-form, strategy-specific details (defaults to an empty JSON object).
    pub metadata: serde_json::Value,
    /// Wall-clock duration of the run in milliseconds (0 unless set).
    pub duration_ms: u64,
    /// JSON descriptors a strategy's `rollback` reads to undo the run.
    pub rollback_actions: Vec<serde_json::Value>,
}

impl RemediationResult {
    /// Builds a result with the given outcome and every other field at its
    /// neutral value; the public constructors override what they need.
    fn blank(outcome: RemediationOutcome) -> Self {
        Self {
            outcome,
            actions_taken: 0,
            improvement_pct: 0.0,
            error_message: None,
            metadata: serde_json::json!({}),
            duration_ms: 0,
            rollback_actions: Vec::new(),
        }
    }

    /// Successful run that performed `actions_taken` actions.
    pub fn success(actions_taken: usize, improvement_pct: f32) -> Self {
        Self {
            actions_taken,
            improvement_pct,
            ..Self::blank(RemediationOutcome::Success)
        }
    }

    /// Run that made some progress but also hit the problem described by `message`.
    pub fn partial(actions_taken: usize, improvement_pct: f32, message: &str) -> Self {
        Self {
            actions_taken,
            improvement_pct,
            error_message: Some(String::from(message)),
            ..Self::blank(RemediationOutcome::Partial)
        }
    }

    /// Failed run; `message` becomes the error message.
    pub fn failure(message: &str) -> Self {
        Self {
            error_message: Some(String::from(message)),
            ..Self::blank(RemediationOutcome::Failure)
        }
    }

    /// Run that changed nothing.
    pub fn noop() -> Self {
        Self::blank(RemediationOutcome::NoOp)
    }

    /// Replaces the metadata payload.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }

    /// Records how long the run took, in milliseconds.
    pub fn with_duration(self, duration_ms: u64) -> Self {
        Self {
            duration_ms,
            ..self
        }
    }

    /// Replaces the list of rollback descriptors.
    pub fn with_rollback(self, actions: Vec<serde_json::Value>) -> Self {
        Self {
            rollback_actions: actions,
            ..self
        }
    }

    /// `true` only for [`RemediationOutcome::Success`]; partial runs do not count.
    pub fn is_success(&self) -> bool {
        self.outcome == RemediationOutcome::Success
    }

    /// Renders the result as a JSON object.
    ///
    /// The outcome is the lowercased `Debug` name of the variant. Rollback
    /// actions are not part of this representation.
    pub fn to_json(&self) -> serde_json::Value {
        let outcome = format!("{:?}", self.outcome).to_lowercase();
        serde_json::json!({
            "outcome": outcome,
            "actions_taken": self.actions_taken,
            "improvement_pct": self.improvement_pct,
            "error_message": self.error_message,
            "metadata": self.metadata,
            "duration_ms": self.duration_ms,
        })
    }
}
/// Inputs and limits handed to a strategy for one remediation attempt.
#[derive(Debug, Clone)]
pub struct StrategyContext {
    /// The detected problem being remediated.
    pub problem: Problem,
    /// Identifier of the affected collection (defaults to 0).
    pub collection_id: i64,
    /// Lambda value at the start of the attempt (defaults to 1.0).
    pub initial_lambda: f32,
    /// Lambda value the attempt aims for (defaults to 0.8).
    pub target_lambda: f32,
    /// Impact ceiling for the attempt (defaults to 0.5).
    pub max_impact: f32,
    /// Time budget, measured from `start_time` (defaults to 300 seconds).
    pub timeout: Duration,
    /// Moment the context was created.
    pub start_time: SystemTime,
    /// When `true`, strategies report what they would do without doing it.
    pub dry_run: bool,
}

impl StrategyContext {
    /// Creates a context for `problem` with default limits, starting the clock now.
    pub fn new(problem: Problem) -> Self {
        let start_time = SystemTime::now();
        let timeout = Duration::from_secs(300);
        Self {
            problem,
            collection_id: 0,
            initial_lambda: 1.0,
            target_lambda: 0.8,
            max_impact: 0.5,
            timeout,
            start_time,
            dry_run: false,
        }
    }

    /// Time since `start_time`; zero if the system clock reports an error
    /// (for example because it moved backwards).
    pub fn elapsed(&self) -> Duration {
        match self.start_time.elapsed() {
            Ok(spent) => spent,
            Err(_) => Duration::ZERO,
        }
    }

    /// `true` once strictly more than `timeout` has elapsed.
    pub fn is_timed_out(&self) -> bool {
        self.timeout < self.elapsed()
    }
}
/// Interface implemented by every remediation strategy.
///
/// Implementations must be `Send + Sync`; `StrategyRegistry` stores them as
/// `Arc<dyn RemediationStrategy>`.
pub trait RemediationStrategy: Send + Sync {
/// Identifier of the strategy; `StrategyRegistry` uses it as the key for
/// weights and for lookup by name.
fn name(&self) -> &str;
/// Human-readable summary of what the strategy does.
fn description(&self) -> &str;
/// Problem types this strategy can address; `StrategyRegistry::select`
/// only considers strategies whose list contains the problem's type.
fn handles(&self) -> Vec<ProblemType>;
/// Impact score of running the strategy; `StrategyRegistry::select` skips
/// strategies whose impact exceeds the caller's `max_impact`.
// NOTE(review): implementations in this file return values between 0.3 and
// 0.7, which suggests a 0.0..=1.0 scale — confirm.
fn impact(&self) -> f32;
/// Estimate of how long `execute` is expected to take.
fn estimated_duration(&self) -> Duration;
/// Whether the strategy declares that its effects can be undone via `rollback`.
fn reversible(&self) -> bool;
/// Runs the strategy against `context` and reports what happened.
/// The implementations in this file return a `NoOp` result without making
/// changes when `context.dry_run` is set.
fn execute(&self, context: &StrategyContext) -> RemediationResult;
/// Attempts to undo a previous `execute`, using the `result` it returned.
/// Returns `Err` with a description when the rollback cannot be performed.
fn rollback(&self, context: &StrategyContext, result: &RemediationResult)
-> Result<(), String>;
}
/// Strategy that reindexes the partitions named by a problem.
pub struct ReindexPartition {
    /// Upper bound on the number of partitions handled per run.
    max_partitions: usize,
    /// Whether reindexing is requested in concurrent mode.
    concurrent: bool,
}

impl ReindexPartition {
    /// Default configuration: at most 3 partitions, concurrent mode on.
    pub fn new() -> Self {
        Self::with_settings(3, true)
    }

    /// Custom partition limit and concurrency mode.
    pub fn with_settings(max_partitions: usize, concurrent: bool) -> Self {
        Self {
            max_partitions,
            concurrent,
        }
    }

    /// Reindexes one partition. The current implementation only writes a log
    /// line (mentioning concurrent mode when requested) and always succeeds.
    fn reindex_partition(&self, partition_id: i64, concurrent: bool) -> Result<(), String> {
        if !concurrent {
            pgrx::log!("Reindexing partition {}", partition_id);
            return Ok(());
        }
        pgrx::log!("Reindexing partition {} concurrently", partition_id);
        Ok(())
    }
}

impl Default for ReindexPartition {
    /// Same as [`ReindexPartition::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RemediationStrategy for ReindexPartition {
    fn name(&self) -> &str {
        "reindex_partition"
    }

    fn description(&self) -> &str {
        "Rebuild degraded index partition to restore search performance"
    }

    fn handles(&self) -> Vec<ProblemType> {
        vec![ProblemType::IndexDegradation]
    }

    /// Concurrent reindexing is rated lower impact (0.3) than non-concurrent (0.7).
    fn impact(&self) -> f32 {
        if self.concurrent {
            0.3
        } else {
            0.7
        }
    }

    /// One minute per partition, up to the configured partition limit.
    fn estimated_duration(&self) -> Duration {
        Duration::from_secs(60 * self.max_partitions as u64)
    }

    fn reversible(&self) -> bool {
        false
    }

    /// Reindexes up to `max_partitions` of the problem's affected partitions.
    ///
    /// Outcomes:
    /// * dry run → `NoOp` with the number of affected partitions;
    /// * nothing reindexed and no error → `NoOp` (matching how the other
    ///   strategies in this file report "nothing to do");
    /// * nothing reindexed but errors or a timeout → `Failure`;
    /// * some reindexed and errors or a timeout → `Partial`;
    /// * everything planned reindexed → `Success`.
    fn execute(&self, context: &StrategyContext) -> RemediationResult {
        let start = std::time::Instant::now();
        let affected = &context.problem.affected_partitions;

        if context.dry_run {
            return RemediationResult::noop().with_metadata(serde_json::json!({
                "dry_run": true,
                "would_reindex": affected.len(),
            }));
        }

        let planned = affected.len().min(self.max_partitions);
        let mut reindexed: usize = 0;
        let mut errors: Vec<String> = Vec::new();

        for (attempted, partition_id) in affected.iter().take(planned).enumerate() {
            if context.is_timed_out() {
                // Record the truncation so a run cut short by the deadline is
                // not reported as a full success.
                errors.push(format!(
                    "Timed out with {} of {} partition(s) not attempted",
                    planned - attempted,
                    planned
                ));
                break;
            }
            match self.reindex_partition(*partition_id, self.concurrent) {
                Ok(()) => reindexed += 1,
                Err(e) => errors.push(format!("Partition {}: {}", partition_id, e)),
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;

        if errors.is_empty() {
            if reindexed == 0 {
                return RemediationResult::noop()
                    .with_duration(duration_ms)
                    .with_metadata(serde_json::json!({
                        "message": "No partitions to reindex",
                    }));
            }
            return RemediationResult::success(reindexed, 0.0)
                .with_duration(duration_ms)
                .with_metadata(serde_json::json!({
                    "reindexed_partitions": reindexed,
                    "concurrent": self.concurrent,
                }));
        }

        let message = errors.join("; ");
        if reindexed == 0 {
            RemediationResult::failure(&message).with_duration(duration_ms)
        } else {
            RemediationResult::partial(reindexed, 0.0, &message).with_duration(duration_ms)
        }
    }

    /// Reindexing has nothing to undo, so this is a no-op that always succeeds.
    fn rollback(
        &self,
        _context: &StrategyContext,
        _result: &RemediationResult,
    ) -> Result<(), String> {
        Ok(())
    }
}
/// Strategy that promotes a replica to primary.
pub struct PromoteReplica {
    /// How long `execute` waits before performing the promotion.
    grace_period: Duration,
}

impl PromoteReplica {
    /// Default configuration with a 30 second grace period.
    pub fn new() -> Self {
        Self::with_grace_period(Duration::from_secs(30))
    }

    /// Custom grace period.
    pub fn with_grace_period(grace_period: Duration) -> Self {
        Self { grace_period }
    }

    /// Picks the replica to promote. The current implementation always
    /// returns the fixed id `"replica_1"`.
    fn find_best_replica(&self) -> Option<String> {
        Some(String::from("replica_1"))
    }

    /// Promotes `replica_id`. The current implementation only logs the
    /// promotion and always succeeds.
    fn promote_replica(&self, replica_id: &str) -> Result<(), String> {
        pgrx::log!("Promoting replica {} to primary", replica_id);
        Ok(())
    }
}

impl Default for PromoteReplica {
    /// Same as [`PromoteReplica::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RemediationStrategy for PromoteReplica {
    fn name(&self) -> &str {
        "promote_replica"
    }

    fn description(&self) -> &str {
        "Failover to healthy replica when primary is experiencing issues"
    }

    fn handles(&self) -> Vec<ProblemType> {
        vec![ProblemType::ReplicaLag, ProblemType::IntegrityViolation]
    }

    fn impact(&self) -> f32 {
        0.6
    }

    /// The grace period plus a fixed 30 seconds.
    fn estimated_duration(&self) -> Duration {
        self.grace_period + Duration::from_secs(30)
    }

    fn reversible(&self) -> bool {
        true
    }

    /// Waits for the grace period, then promotes the best replica.
    ///
    /// Returns `NoOp` for dry runs, `Failure` when no replica is available,
    /// when the context deadline passed during the grace period, or when the
    /// promotion itself fails, and `Success` (with a `demote` rollback
    /// descriptor) otherwise. Every non-dry-run result carries the duration.
    fn execute(&self, context: &StrategyContext) -> RemediationResult {
        let start = std::time::Instant::now();
        let elapsed_ms = || start.elapsed().as_millis() as u64;

        if context.dry_run {
            return RemediationResult::noop().with_metadata(serde_json::json!({
                "dry_run": true,
                "candidate_replica": self.find_best_replica(),
            }));
        }

        let replica_id = match self.find_best_replica() {
            Some(id) => id,
            None => {
                return RemediationResult::failure("No healthy replica found")
                    .with_duration(elapsed_ms());
            }
        };

        std::thread::sleep(self.grace_period);

        // The grace period can outlast the context's time budget; do not start
        // a high-impact promotion after the deadline has already passed.
        if context.is_timed_out() {
            return RemediationResult::failure("Timed out during grace period before promotion")
                .with_duration(elapsed_ms());
        }

        match self.promote_replica(&replica_id) {
            Ok(()) => RemediationResult::success(1, 0.0)
                .with_duration(elapsed_ms())
                .with_metadata(serde_json::json!({
                    "promoted_replica": replica_id,
                }))
                .with_rollback(vec![serde_json::json!({
                    "action": "demote",
                    "replica_id": replica_id,
                })]),
            Err(e) => RemediationResult::failure(&e).with_duration(elapsed_ms()),
        }
    }

    /// Processes every `demote` descriptor in `result.rollback_actions`.
    ///
    /// The current implementation only logs the demotion. Returns `Err` when
    /// a `demote` descriptor has no string `replica_id`.
    fn rollback(
        &self,
        _context: &StrategyContext,
        result: &RemediationResult,
    ) -> Result<(), String> {
        for action in &result.rollback_actions {
            if action.get("action").and_then(|v| v.as_str()) != Some("demote") {
                continue;
            }
            let replica_id = action
                .get("replica_id")
                .and_then(|v| v.as_str())
                .ok_or("Missing replica_id in rollback action")?;
            pgrx::log!("Rolling back: demoting {}", replica_id);
        }
        Ok(())
    }
}
/// Strategy that moves cold vectors to a lower storage tier.
pub struct TierEviction {
    /// Value reported as the improvement percentage after a successful run.
    target_free_pct: f32,
    /// Maximum number of candidates requested per eviction batch.
    batch_size: usize,
}

impl TierEviction {
    /// Default configuration: 20% target, batches of 10000.
    pub fn new() -> Self {
        Self::with_settings(20.0, 10000)
    }

    /// Custom target percentage and batch size.
    pub fn with_settings(target_free_pct: f32, batch_size: usize) -> Self {
        Self {
            target_free_pct,
            batch_size,
        }
    }

    /// Looks up eviction candidates. The current implementation ignores the
    /// limit and always returns an empty list.
    fn find_cold_candidates(&self, _limit: usize) -> Vec<i64> {
        Vec::new()
    }

    /// Evicts the given vectors and returns how many were moved. The current
    /// implementation only logs the batch size and reports every id as evicted.
    fn evict_to_cold_tier(&self, vector_ids: &[i64]) -> Result<usize, String> {
        let batch_len = vector_ids.len();
        pgrx::log!("Evicting {} vectors to cold tier", batch_len);
        Ok(batch_len)
    }
}

impl Default for TierEviction {
    /// Same as [`TierEviction::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RemediationStrategy for TierEviction {
    fn name(&self) -> &str {
        "tier_eviction"
    }

    fn description(&self) -> &str {
        "Move cold data to lower storage tier to free up space"
    }

    fn handles(&self) -> Vec<ProblemType> {
        vec![ProblemType::StorageExhaustion, ProblemType::MemoryPressure]
    }

    fn impact(&self) -> f32 {
        0.4
    }

    fn estimated_duration(&self) -> Duration {
        Duration::from_secs(120)
    }

    fn reversible(&self) -> bool {
        true
    }

    /// Evicts cold vectors batch by batch until no candidates remain, a batch
    /// makes no progress, an eviction fails, or the context times out.
    ///
    /// Outcomes:
    /// * dry run → `NoOp` with the number of candidates found;
    /// * eviction error before anything was evicted → `Failure`;
    /// * eviction error after some batches succeeded → `Partial`, including
    ///   the rollback descriptor for the batches already evicted;
    /// * at least one vector evicted, no error → `Success`;
    /// * nothing to evict → `NoOp`.
    fn execute(&self, context: &StrategyContext) -> RemediationResult {
        let start = std::time::Instant::now();

        if context.dry_run {
            let candidates = self.find_cold_candidates(self.batch_size);
            return RemediationResult::noop().with_metadata(serde_json::json!({
                "dry_run": true,
                "candidates_found": candidates.len(),
            }));
        }

        let mut total_evicted: usize = 0;
        let mut evicted_ids: Vec<i64> = Vec::new();
        let mut eviction_error: Option<String> = None;

        while !context.is_timed_out() {
            let candidates = self.find_cold_candidates(self.batch_size);
            if candidates.is_empty() {
                break;
            }
            match self.evict_to_cold_tier(&candidates) {
                // A batch that moved nothing would otherwise make this loop
                // spin on the same candidates until the timeout.
                Ok(0) => break,
                Ok(count) => {
                    total_evicted += count;
                    evicted_ids.extend(candidates);
                }
                Err(e) => {
                    eviction_error = Some(e);
                    break;
                }
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;
        let evicted_metadata = serde_json::json!({
            "evicted_count": total_evicted,
        });
        let rollback = vec![serde_json::json!({
            "action": "restore_from_cold",
            "vector_ids": evicted_ids,
        })];

        match eviction_error {
            Some(message) if total_evicted == 0 => {
                RemediationResult::failure(&message).with_duration(duration_ms)
            }
            // Earlier batches were already moved: keep the rollback descriptor
            // so they can still be restored.
            Some(message) => RemediationResult::partial(total_evicted, 0.0, &message)
                .with_duration(duration_ms)
                .with_metadata(evicted_metadata)
                .with_rollback(rollback),
            None if total_evicted > 0 => {
                RemediationResult::success(total_evicted, self.target_free_pct)
                    .with_duration(duration_ms)
                    .with_metadata(evicted_metadata)
                    .with_rollback(rollback)
            }
            None => RemediationResult::noop()
                .with_duration(duration_ms)
                .with_metadata(serde_json::json!({
                    "message": "No cold data candidates found",
                })),
        }
    }

    /// Processes every `restore_from_cold` descriptor in
    /// `result.rollback_actions`. The current implementation only logs that a
    /// rollback is happening and always succeeds.
    fn rollback(
        &self,
        _context: &StrategyContext,
        result: &RemediationResult,
    ) -> Result<(), String> {
        for action in &result.rollback_actions {
            if action.get("action").and_then(|v| v.as_str()) == Some("restore_from_cold") {
                pgrx::log!("Rolling back tier eviction");
            }
        }
        Ok(())
    }
}
/// Strategy that blocks problematic query patterns.
pub struct QueryCircuitBreaker {
    /// Block duration reported in the result metadata of a successful run.
    block_duration: Duration,
    /// Patterns currently blocked; each pattern appears at most once.
    blocked_patterns: RwLock<Vec<String>>,
}

impl QueryCircuitBreaker {
    /// Default configuration with a 300 second block duration.
    pub fn new() -> Self {
        Self::with_duration(Duration::from_secs(300))
    }

    /// Custom block duration; starts with no blocked patterns.
    pub fn with_duration(block_duration: Duration) -> Self {
        Self {
            block_duration,
            blocked_patterns: RwLock::new(Vec::new()),
        }
    }

    /// Identifies query patterns to block. The current implementation always
    /// returns an empty list.
    fn find_problematic_queries(&self) -> Vec<String> {
        Vec::new()
    }

    /// Adds `pattern` to the blocked list and logs it.
    ///
    /// Blocking an already-blocked pattern leaves the list unchanged, so
    /// repeated remediation runs cannot grow it without bound.
    fn block_pattern(&self, pattern: &str) -> Result<(), String> {
        {
            // Scope the write guard so the lock is released before logging.
            let mut blocked = self.blocked_patterns.write();
            if !blocked.iter().any(|existing| existing == pattern) {
                blocked.push(pattern.to_string());
            }
        }
        pgrx::log!("Blocking query pattern: {}", pattern);
        Ok(())
    }

    /// Removes `pattern` from the blocked list (if present) and logs it.
    fn unblock_pattern(&self, pattern: &str) -> Result<(), String> {
        self.blocked_patterns.write().retain(|p| p != pattern);
        pgrx::log!("Unblocking query pattern: {}", pattern);
        Ok(())
    }
}

impl Default for QueryCircuitBreaker {
    /// Same as [`QueryCircuitBreaker::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RemediationStrategy for QueryCircuitBreaker {
    fn name(&self) -> &str {
        "query_circuit_breaker"
    }

    fn description(&self) -> &str {
        "Block problematic queries causing excessive timeouts"
    }

    fn handles(&self) -> Vec<ProblemType> {
        vec![ProblemType::QueryTimeout, ProblemType::ConnectionExhaustion]
    }

    fn impact(&self) -> f32 {
        0.5
    }

    fn estimated_duration(&self) -> Duration {
        Duration::from_secs(10)
    }

    fn reversible(&self) -> bool {
        true
    }

    /// Blocks every problematic query pattern that can be identified.
    ///
    /// Outcomes:
    /// * dry run → `NoOp` listing the patterns that would be blocked;
    /// * no patterns identified → `NoOp`;
    /// * every block attempt failed → `Failure` with the collected errors;
    /// * some blocked, some failed → `Partial` (with an `unblock` rollback
    ///   descriptor for the ones that were blocked);
    /// * all blocked → `Success` with the same rollback descriptor.
    fn execute(&self, context: &StrategyContext) -> RemediationResult {
        let start = std::time::Instant::now();
        let problematic = self.find_problematic_queries();

        if context.dry_run {
            return RemediationResult::noop().with_metadata(serde_json::json!({
                "dry_run": true,
                "would_block": problematic,
            }));
        }

        let mut blocked: Vec<String> = Vec::new();
        let mut errors: Vec<String> = Vec::new();
        for pattern in &problematic {
            // Keep the error text instead of discarding it, so a run where
            // blocking fails is not mistaken for "nothing to block".
            match self.block_pattern(pattern) {
                Ok(()) => blocked.push(pattern.clone()),
                Err(e) => errors.push(format!("Pattern {}: {}", pattern, e)),
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;

        if blocked.is_empty() {
            if errors.is_empty() {
                return RemediationResult::noop().with_metadata(serde_json::json!({
                    "message": "No problematic query patterns identified",
                }));
            }
            return RemediationResult::failure(&errors.join("; ")).with_duration(duration_ms);
        }

        let metadata = serde_json::json!({
            "blocked_patterns": blocked,
            "block_duration_secs": self.block_duration.as_secs(),
        });
        let rollback = vec![serde_json::json!({
            "action": "unblock",
            "patterns": blocked,
        })];

        let outcome = if errors.is_empty() {
            RemediationResult::success(blocked.len(), 0.0)
        } else {
            RemediationResult::partial(blocked.len(), 0.0, &errors.join("; "))
        };
        outcome
            .with_duration(duration_ms)
            .with_metadata(metadata)
            .with_rollback(rollback)
    }

    /// Unblocks every pattern listed in the `unblock` descriptors of
    /// `result.rollback_actions`. Non-string entries are skipped; the first
    /// unblock error aborts the rollback.
    fn rollback(
        &self,
        _context: &StrategyContext,
        result: &RemediationResult,
    ) -> Result<(), String> {
        for action in &result.rollback_actions {
            if action.get("action").and_then(|v| v.as_str()) != Some("unblock") {
                continue;
            }
            let patterns = action.get("patterns").and_then(|v| v.as_array());
            for pattern in patterns.into_iter().flatten().filter_map(|v| v.as_str()) {
                self.unblock_pattern(pattern)?;
            }
        }
        Ok(())
    }
}
/// Strategy that repairs witness edges after an integrity problem.
pub struct IntegrityRecovery {
    /// Upper bound on the number of edges repaired per run.
    max_edges: usize,
    /// Whether integrity is re-measured after repairing.
    verify_after: bool,
}

impl IntegrityRecovery {
    /// Default configuration: up to 1000 edges, verification enabled.
    pub fn new() -> Self {
        Self::with_settings(1000, true)
    }

    /// Custom edge limit and verification flag.
    pub fn with_settings(max_edges: usize, verify_after: bool) -> Self {
        Self {
            max_edges,
            verify_after,
        }
    }

    /// Lists the `(from, to)` edges to repair. The current implementation
    /// always returns an empty list.
    fn get_witness_edges(&self) -> Vec<(i64, i64)> {
        Vec::new()
    }

    /// Repairs a single edge. The current implementation only logs the edge
    /// and always succeeds.
    fn repair_edge(&self, from: i64, to: i64) -> Result<(), String> {
        pgrx::log!("Repairing edge {} -> {}", from, to);
        Ok(())
    }

    /// Measures the lambda value after repair. The current implementation
    /// always reports 1.0.
    fn verify_integrity(&self) -> Result<f32, String> {
        Ok(1.0)
    }
}

impl Default for IntegrityRecovery {
    /// Same as [`IntegrityRecovery::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RemediationStrategy for IntegrityRecovery {
    fn name(&self) -> &str {
        "integrity_recovery"
    }

    fn description(&self) -> &str {
        "Repair contracted graph when integrity violations are detected"
    }

    fn handles(&self) -> Vec<ProblemType> {
        vec![
            ProblemType::IntegrityViolation,
            ProblemType::IndexDegradation,
        ]
    }

    fn impact(&self) -> f32 {
        0.4
    }

    fn estimated_duration(&self) -> Duration {
        Duration::from_secs(60)
    }

    fn reversible(&self) -> bool {
        false
    }

    /// Repairs up to `max_edges` witness edges and optionally re-verifies.
    ///
    /// Outcomes:
    /// * dry run → `NoOp` with the number of witness edges found;
    /// * nothing repaired but errors occurred → `Failure`;
    /// * some repaired and some errors → `Partial` carrying the error text;
    /// * some repaired, no errors → `Success`;
    /// * nothing to repair → `NoOp`.
    ///
    /// `improvement_pct` is the relative lambda change versus
    /// `context.initial_lambda`, floored at 0 and forced to 0 when it is not
    /// finite (for example when the initial lambda is 0). The `new_lambda`
    /// metadata is the value returned by verification, or the initial lambda
    /// when verification was skipped or failed.
    fn execute(&self, context: &StrategyContext) -> RemediationResult {
        let start = std::time::Instant::now();
        let witness_edges = self.get_witness_edges();

        if context.dry_run {
            return RemediationResult::noop().with_metadata(serde_json::json!({
                "dry_run": true,
                "witness_edges_found": witness_edges.len(),
            }));
        }

        let mut repaired: usize = 0;
        let mut errors: Vec<String> = Vec::new();
        for &(from, to) in witness_edges.iter().take(self.max_edges) {
            if context.is_timed_out() {
                break;
            }
            match self.repair_edge(from, to) {
                Ok(()) => repaired += 1,
                Err(e) => errors.push(e),
            }
        }

        // Keep the measured lambda itself: reconstructing it from the
        // percentage is only correct when the initial lambda is exactly 1.0.
        let verified_lambda = if self.verify_after && repaired > 0 {
            self.verify_integrity().ok()
        } else {
            None
        };
        let improvement = verified_lambda
            .map(|new_lambda| {
                let pct =
                    (new_lambda - context.initial_lambda) / context.initial_lambda * 100.0;
                // A zero initial lambda makes the ratio infinite or NaN.
                if pct.is_finite() {
                    pct.max(0.0)
                } else {
                    0.0
                }
            })
            .unwrap_or(0.0);

        let duration_ms = start.elapsed().as_millis() as u64;

        if repaired == 0 {
            if errors.is_empty() {
                return RemediationResult::noop().with_metadata(serde_json::json!({
                    "message": "No witness edges to repair",
                }));
            }
            return RemediationResult::failure(&errors.join("; ")).with_duration(duration_ms);
        }

        let metadata = serde_json::json!({
            "edges_repaired": repaired,
            "new_lambda": verified_lambda.unwrap_or(context.initial_lambda),
        });
        let outcome = if errors.is_empty() {
            RemediationResult::success(repaired, improvement)
        } else {
            // Some edges could not be repaired: surface the errors rather
            // than reporting an unqualified success.
            RemediationResult::partial(repaired, improvement, &errors.join("; "))
        };
        outcome.with_duration(duration_ms).with_metadata(metadata)
    }

    /// Always fails: this strategy declares itself non-reversible.
    fn rollback(
        &self,
        _context: &StrategyContext,
        _result: &RemediationResult,
    ) -> Result<(), String> {
        Err("Integrity recovery cannot be rolled back".to_string())
    }
}
/// Collection of remediation strategies with a learned weight per strategy.
pub struct StrategyRegistry {
    /// Registered strategies, in registration order.
    strategies: Vec<Arc<dyn RemediationStrategy>>,
    /// Selection weight per strategy name.
    weights: RwLock<HashMap<String, f32>>,
}

impl StrategyRegistry {
    /// Weight assumed for a strategy that has no recorded weight.
    const DEFAULT_WEIGHT: f32 = 1.0;
    /// Lower bound applied to every updated weight.
    const MIN_WEIGHT: f32 = 0.1;
    /// Upper bound applied to every updated weight.
    const MAX_WEIGHT: f32 = 2.0;

    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            strategies: Vec::new(),
            weights: RwLock::new(HashMap::new()),
        }
    }

    /// Creates a registry holding the five strategies defined in this file,
    /// each with its default settings.
    pub fn new_with_defaults() -> Self {
        let mut registry = Self::new();
        registry.register(Arc::new(ReindexPartition::new()));
        registry.register(Arc::new(PromoteReplica::new()));
        registry.register(Arc::new(TierEviction::new()));
        registry.register(Arc::new(QueryCircuitBreaker::new()));
        registry.register(Arc::new(IntegrityRecovery::new()));
        registry
    }

    /// Adds `strategy` and sets the weight stored under its name to the default.
    pub fn register(&mut self, strategy: Arc<dyn RemediationStrategy>) {
        let name = strategy.name().to_string();
        self.strategies.push(strategy);
        self.weights.write().insert(name, Self::DEFAULT_WEIGHT);
    }

    /// All registered strategies, in registration order.
    pub fn all_strategies(&self) -> &[Arc<dyn RemediationStrategy>] {
        &self.strategies
    }

    /// First registered strategy whose name equals `name`, if any.
    pub fn get_by_name(&self, name: &str) -> Option<Arc<dyn RemediationStrategy>> {
        self.strategies.iter().find(|s| s.name() == name).cloned()
    }

    /// Picks the highest-weighted strategy that handles `problem` and whose
    /// impact does not exceed `max_impact`.
    ///
    /// Returns `None` when no strategy qualifies. When several candidates
    /// share the highest weight, the one registered last is returned.
    pub fn select(
        &self,
        problem: &Problem,
        max_impact: f32,
    ) -> Option<Arc<dyn RemediationStrategy>> {
        let weights = self.weights.read();
        let weight_of = |s: &Arc<dyn RemediationStrategy>| {
            weights
                .get(s.name())
                .copied()
                .unwrap_or(Self::DEFAULT_WEIGHT)
        };
        self.strategies
            .iter()
            .filter(|s| s.handles().contains(&problem.problem_type))
            .filter(|s| s.impact() <= max_impact)
            // `total_cmp` is a total order on floats, so selection can never
            // panic the way `partial_cmp(..).unwrap()` would on a NaN weight.
            .max_by(|a, b| weight_of(a).total_cmp(&weight_of(b)))
            .cloned()
    }

    /// Adjusts the weight of `strategy_name` after a remediation attempt.
    ///
    /// A success adds 0.1 plus a bonus of `improvement / 100`, limited to the
    /// range 0.0 to 0.2; a failure subtracts 0.1. The result is kept within
    /// 0.1 to 2.0. A name with no recorded weight starts from the default 1.0.
    pub fn update_weight(&self, strategy_name: &str, success: bool, improvement: f32) {
        let adjustment = if success {
            // `max`/`min` (not `clamp`) so a NaN improvement yields a zero
            // bonus; the floor keeps a successful run from being penalised
            // for a negative improvement figure.
            0.1 + (improvement / 100.0).max(0.0).min(0.2)
        } else {
            -0.1
        };
        let mut weights = self.weights.write();
        let weight = weights
            .entry(strategy_name.to_string())
            .or_insert(Self::DEFAULT_WEIGHT);
        *weight = (*weight + adjustment)
            .max(Self::MIN_WEIGHT)
            .min(Self::MAX_WEIGHT);
    }

    /// Current weight of `strategy_name`, or the default 1.0 if none is recorded.
    pub fn get_weight(&self, strategy_name: &str) -> f32 {
        self.weights
            .read()
            .get(strategy_name)
            .copied()
            .unwrap_or(Self::DEFAULT_WEIGHT)
    }

    /// Snapshot of every recorded weight, keyed by strategy name.
    pub fn get_all_weights(&self) -> HashMap<String, f32> {
        self.weights.read().clone()
    }
}

impl Default for StrategyRegistry {
    /// Same as [`StrategyRegistry::new_with_defaults`].
    fn default() -> Self {
        Self::new_with_defaults()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remediation_result_success() {
        let result = RemediationResult::success(5, 15.0);
        assert!(result.is_success());
        assert_eq!(result.actions_taken, 5);
        assert_eq!(result.improvement_pct, 15.0);
    }

    #[test]
    fn test_remediation_result_failure() {
        let result = RemediationResult::failure("test error");
        assert!(!result.is_success());
        assert_eq!(result.error_message, Some("test error".to_string()));
    }

    #[test]
    fn test_strategy_registry_defaults() {
        let registry = StrategyRegistry::new_with_defaults();
        assert_eq!(registry.all_strategies().len(), 5);
    }

    #[test]
    fn test_strategy_selection() {
        let registry = StrategyRegistry::new_with_defaults();
        let problem = Problem::new(ProblemType::IndexDegradation, Severity::Medium);
        let strategy = registry.select(&problem, 1.0);
        assert!(strategy.is_some());
        assert!(strategy
            .unwrap()
            .handles()
            .contains(&ProblemType::IndexDegradation));
    }

    #[test]
    fn test_strategy_selection_with_impact_filter() {
        let registry = StrategyRegistry::new_with_defaults();
        let problem = Problem::new(ProblemType::ReplicaLag, Severity::High);
        // The only default strategy handling ReplicaLag is PromoteReplica,
        // whose impact (0.6) exceeds the 0.5 ceiling, so nothing qualifies.
        let strategy = registry.select(&problem, 0.5);
        assert!(strategy.is_none());
    }

    #[test]
    fn test_weight_updates() {
        let registry = StrategyRegistry::new_with_defaults();
        assert_eq!(registry.get_weight("reindex_partition"), 1.0);

        // Success with 20% improvement: +0.1 base +0.2 bonus.
        registry.update_weight("reindex_partition", true, 20.0);
        let boosted = registry.get_weight("reindex_partition");
        assert!(boosted > 1.0);
        assert!((boosted - 1.3).abs() < 1e-5);

        // Failure: -0.1. Compare with a tolerance rather than relying on how
        // the f32 arithmetic happens to round.
        registry.update_weight("reindex_partition", false, 0.0);
        let weight = registry.get_weight("reindex_partition");
        assert!(weight < boosted);
        assert!((weight - 1.2).abs() < 1e-5);
    }

    #[test]
    fn test_reindex_partition_handles() {
        let strategy = ReindexPartition::new();
        assert!(strategy.handles().contains(&ProblemType::IndexDegradation));
        assert!(!strategy.handles().contains(&ProblemType::ReplicaLag));
    }

    #[test]
    fn test_promote_replica_handles() {
        let strategy = PromoteReplica::new();
        assert!(strategy.handles().contains(&ProblemType::ReplicaLag));
        assert!(strategy
            .handles()
            .contains(&ProblemType::IntegrityViolation));
    }

    #[test]
    fn test_tier_eviction_handles() {
        let strategy = TierEviction::new();
        assert!(strategy.handles().contains(&ProblemType::StorageExhaustion));
        assert!(strategy.handles().contains(&ProblemType::MemoryPressure));
    }

    #[test]
    fn test_circuit_breaker_handles() {
        let strategy = QueryCircuitBreaker::new();
        assert!(strategy.handles().contains(&ProblemType::QueryTimeout));
        assert!(strategy
            .handles()
            .contains(&ProblemType::ConnectionExhaustion));
    }

    #[test]
    fn test_integrity_recovery_handles() {
        let strategy = IntegrityRecovery::new();
        assert!(strategy
            .handles()
            .contains(&ProblemType::IntegrityViolation));
        assert!(strategy.handles().contains(&ProblemType::IndexDegradation));
    }

    #[test]
    fn test_dry_run() {
        let strategy = ReindexPartition::new();
        let mut context = StrategyContext::new(Problem::new(
            ProblemType::IndexDegradation,
            Severity::Medium,
        ));
        context.dry_run = true;
        let result = strategy.execute(&context);
        assert_eq!(result.outcome, RemediationOutcome::NoOp);
        assert_eq!(
            result.metadata.get("dry_run"),
            Some(&serde_json::json!(true))
        );
    }
}