use crate::agent::AgentId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum FailoverError {
#[error("No backup available")]
NoBackupAvailable,
#[error("Failover already in progress")]
FailoverInProgress,
#[error("Agent not found: {0}")]
AgentNotFound(String),
#[error("Invalid state transition: {from:?} -> {to:?}")]
InvalidStateTransition {
from: FailoverState,
to: FailoverState,
},
#[error("Failover timeout")]
Timeout,
}
pub type FailoverResult<T> = Result<T, FailoverError>;
/// Position of a registered agent in the failover state machine.
///
/// Variant order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum FailoverState {
    /// Healthy; set on registration and after a reported recovery.
    Active,
    /// NOTE(review): not assigned by `FailoverCoordinator` in this module.
    Degraded,
    /// A failover has been decided and is awaiting `execute_failover`.
    Failing,
    /// Traffic has been handed to a backup by `execute_failover`.
    FailedOver,
    /// NOTE(review): not assigned by `FailoverCoordinator` in this module.
    Recovering,
}
/// Rule deciding when reported failures should trigger a failover.
///
/// Variant order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub enum FailoverStrategy {
    /// Fail over on the very first reported failure.
    Immediate,
    /// Fail over once the agent's failure count reaches `count`.
    ConsecutiveFailures { count: usize },
    /// Rate-based trigger: `count` failures within `window`.
    FailureRate { count: usize, window: Duration },
    /// Fail over when the failure count divided by the configured
    /// health-check interval (in seconds) reaches `threshold`.
    Threshold { threshold: f64 },
}
/// Tunables governing when and how often failovers may happen.
#[derive(Clone, Debug)]
pub struct FailoverPolicy {
    /// Rule that turns reported failures into a failover decision.
    pub strategy: FailoverStrategy,
    /// NOTE(review): not read by `FailoverCoordinator` in this module.
    pub max_failover_time: Duration,
    /// After an executed failover, further failovers are suppressed until
    /// this much time has elapsed.
    pub min_failover_interval: Duration,
    /// NOTE(review): not read by `FailoverCoordinator` in this module.
    pub auto_recover: bool,
}

impl Default for FailoverPolicy {
    /// Three consecutive failures, 30s max failover time, 60s minimum
    /// interval between failovers, auto-recovery enabled.
    fn default() -> Self {
        let strategy = FailoverStrategy::ConsecutiveFailures { count: 3 };
        let max_failover_time = Duration::from_secs(30);
        let min_failover_interval = Duration::from_secs(60);
        Self {
            strategy,
            max_failover_time,
            min_failover_interval,
            auto_recover: true,
        }
    }
}
#[derive(Debug, Clone)]
pub struct FailoverConfig {
pub policy: FailoverPolicy,
pub max_backups: usize,
pub health_check_interval: Duration,
}
impl Default for FailoverConfig {
fn default() -> Self {
Self {
policy: FailoverPolicy::default(),
max_backups: 3,
health_check_interval: Duration::from_secs(5),
}
}
}
/// Log entry recorded each time a failover is executed.
#[derive(Clone, Debug)]
pub struct FailoverEvent {
    /// Agent that was failed over.
    pub agent_id: AgentId,
    /// Backup that took over, as supplied to `execute_failover`.
    pub backup_id: Option<AgentId>,
    /// Monotonic time at which the failover was recorded.
    pub timestamp: Instant,
    /// Human-readable cause of the failover.
    pub reason: String,
    /// Whether the failover succeeded.
    pub success: bool,
}
/// Outcome of evaluating a reported failure against the failover policy.
#[derive(Clone, Debug)]
pub struct FailoverDecision {
    /// `true` when the agent should be failed over now.
    pub should_failover: bool,
    /// Proposed replacement (the agent's first registered backup) when
    /// `should_failover` is set; `None` otherwise or when no backup exists.
    pub backup_id: Option<AgentId>,
    /// Explanation for the decision.
    pub reason: String,
    /// Fixed heuristic confidence assigned by the coordinator.
    pub confidence: f64,
}
#[derive(Debug, Clone)]
struct AgentFailoverStatus {
state: FailoverState,
failure_count: usize,
last_failure: Option<Instant>,
last_failover: Option<Instant>,
backup_agents: Vec<AgentId>,
}
pub struct FailoverCoordinator {
config: FailoverConfig,
agents: Arc<RwLock<HashMap<AgentId, AgentFailoverStatus>>>,
events: Arc<RwLock<Vec<FailoverEvent>>>,
}
impl FailoverCoordinator {
pub fn new(config: FailoverConfig) -> Self {
Self {
config,
agents: Arc::new(RwLock::new(HashMap::new())),
events: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn register_agent(&self, agent_id: AgentId, backups: Vec<AgentId>) -> FailoverResult<()> {
let mut agents = self
.agents
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
agents.insert(
agent_id,
AgentFailoverStatus {
state: FailoverState::Active,
failure_count: 0,
last_failure: None,
last_failover: None,
backup_agents: backups,
},
);
Ok(())
}
pub fn unregister_agent(&self, agent_id: &AgentId) -> FailoverResult<()> {
let mut agents = self
.agents
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
agents.remove(agent_id);
Ok(())
}
pub fn report_failure(
&self,
agent_id: &AgentId,
reason: &str,
) -> FailoverResult<FailoverDecision> {
let mut agents = self
.agents
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
let status = agents
.get_mut(agent_id)
.ok_or_else(|| FailoverError::AgentNotFound(agent_id.to_string()))?;
status.failure_count += 1;
status.last_failure = Some(Instant::now());
let decision = self.evaluate_failover(agent_id, status, reason)?;
if decision.should_failover {
status.state = FailoverState::Failing;
}
Ok(decision)
}
fn evaluate_failover(
&self,
_agent_id: &AgentId,
status: &AgentFailoverStatus,
reason: &str,
) -> FailoverResult<FailoverDecision> {
if status.state == FailoverState::Failing {
return Err(FailoverError::FailoverInProgress);
}
if let Some(last_failover) = status.last_failover {
if last_failover.elapsed() < self.config.policy.min_failover_interval {
return Ok(FailoverDecision {
should_failover: false,
backup_id: None,
reason: "Too soon since last failover".to_string(),
confidence: 0.0,
});
}
}
let should_failover = match &self.config.policy.strategy {
FailoverStrategy::Immediate => true,
FailoverStrategy::ConsecutiveFailures { count } => status.failure_count >= *count,
FailoverStrategy::FailureRate { count, window } => {
if let Some(last_failure) = status.last_failure {
last_failure.elapsed() <= *window && status.failure_count >= *count
} else {
false
}
}
FailoverStrategy::Threshold { threshold } => {
let failure_rate = status.failure_count as f64
/ self.config.health_check_interval.as_secs() as f64;
failure_rate >= *threshold
}
};
let backup_id = if should_failover {
status.backup_agents.first().cloned()
} else {
None
};
Ok(FailoverDecision {
should_failover,
backup_id,
reason: reason.to_string(),
confidence: if should_failover { 0.9 } else { 0.1 },
})
}
pub fn execute_failover(
&self,
agent_id: &AgentId,
backup_id: &AgentId,
) -> FailoverResult<FailoverEvent> {
let mut agents = self
.agents
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
let status = agents
.get_mut(agent_id)
.ok_or_else(|| FailoverError::AgentNotFound(agent_id.to_string()))?;
if status.state != FailoverState::Failing {
return Err(FailoverError::InvalidStateTransition {
from: status.state.clone(),
to: FailoverState::FailedOver,
});
}
status.state = FailoverState::FailedOver;
status.last_failover = Some(Instant::now());
status.failure_count = 0;
let event = FailoverEvent {
agent_id: *agent_id,
backup_id: Some(*backup_id),
timestamp: Instant::now(),
reason: "Automatic failover".to_string(),
success: true,
};
self.events
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?
.push(event.clone());
Ok(event)
}
pub fn report_recovery(&self, agent_id: &AgentId) -> FailoverResult<()> {
let mut agents = self
.agents
.write()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
let status = agents
.get_mut(agent_id)
.ok_or_else(|| FailoverError::AgentNotFound(agent_id.to_string()))?;
status.state = FailoverState::Active;
status.failure_count = 0;
status.last_failure = None;
Ok(())
}
pub fn get_state(&self, agent_id: &AgentId) -> FailoverResult<FailoverState> {
let agents = self
.agents
.read()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
Ok(agents
.get(agent_id)
.ok_or_else(|| FailoverError::AgentNotFound(agent_id.to_string()))?
.state
.clone())
}
pub fn get_events(&self) -> FailoverResult<Vec<FailoverEvent>> {
Ok(self
.events
.read()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?
.clone())
}
pub fn get_failure_count(&self, agent_id: &AgentId) -> FailoverResult<usize> {
let agents = self
.agents
.read()
.map_err(|_| FailoverError::AgentNotFound("Failed to acquire lock".to_string()))?;
Ok(agents
.get(agent_id)
.ok_or_else(|| FailoverError::AgentNotFound(agent_id.to_string()))?
.failure_count)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::Agent;

    /// Module bytes used to build every test agent (the WebAssembly magic `\0asm`).
    fn module_bytes() -> Vec<u8> {
        vec![0x00, 0x61, 0x73, 0x6d]
    }

    /// Coordinator whose configuration differs from the default only in `strategy`.
    fn coordinator_with(strategy: FailoverStrategy) -> FailoverCoordinator {
        let policy = FailoverPolicy {
            strategy,
            ..Default::default()
        };
        FailoverCoordinator::new(FailoverConfig {
            policy,
            ..Default::default()
        })
    }

    #[test]
    fn test_failover_coordinator_creation() {
        let _coordinator = FailoverCoordinator::new(FailoverConfig::default());
    }

    #[test]
    fn test_register_agent() {
        let coordinator = FailoverCoordinator::new(FailoverConfig::default());
        let primary = Agent::new(module_bytes());
        let standby = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![standby.id()])
            .expect("register agent");

        let state = coordinator.get_state(&primary.id()).expect("get state");
        assert_eq!(state, FailoverState::Active);
    }

    #[test]
    fn test_immediate_failover() {
        let coordinator = coordinator_with(FailoverStrategy::Immediate);
        let primary = Agent::new(module_bytes());
        let standby = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![standby.id()])
            .expect("register agent");

        let decision = coordinator
            .report_failure(&primary.id(), "test failure")
            .expect("report failure");
        assert!(decision.should_failover);
        assert_eq!(decision.backup_id, Some(standby.id()));
    }

    #[test]
    fn test_consecutive_failures() {
        let coordinator = coordinator_with(FailoverStrategy::ConsecutiveFailures { count: 3 });
        let primary = Agent::new(module_bytes());
        let standby = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![standby.id()])
            .expect("register agent");

        // Only the third failure reaches the configured count.
        for expected in [false, false, true] {
            let decision = coordinator
                .report_failure(&primary.id(), "test failure")
                .expect("report failure");
            assert_eq!(decision.should_failover, expected);
        }
    }

    #[test]
    fn test_execute_failover() {
        let coordinator = FailoverCoordinator::new(FailoverConfig::default());
        let primary = Agent::new(module_bytes());
        let standby = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![standby.id()])
            .expect("register agent");

        // The default policy needs three failures; the first two stay below it.
        for _ in 0..2 {
            coordinator
                .report_failure(&primary.id(), "test failure")
                .expect("report failure");
        }
        let decision = coordinator
            .report_failure(&primary.id(), "test failure")
            .expect("report failure");
        assert!(decision.should_failover);

        let event = coordinator
            .execute_failover(&primary.id(), &standby.id())
            .expect("execute failover");
        assert!(event.success);
        assert_eq!(event.backup_id, Some(standby.id()));

        let state = coordinator.get_state(&primary.id()).expect("get state");
        assert_eq!(state, FailoverState::FailedOver);
    }

    #[test]
    fn test_recovery() {
        let coordinator = FailoverCoordinator::new(FailoverConfig::default());
        let primary = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![])
            .expect("register agent");
        coordinator
            .report_failure(&primary.id(), "test failure")
            .expect("report failure");
        coordinator
            .report_recovery(&primary.id())
            .expect("report recovery");

        let state = coordinator.get_state(&primary.id()).expect("get state");
        assert_eq!(state, FailoverState::Active);
        let failures = coordinator
            .get_failure_count(&primary.id())
            .expect("failure count");
        assert_eq!(failures, 0);
    }

    #[test]
    fn test_unregister_agent() {
        let coordinator = FailoverCoordinator::new(FailoverConfig::default());
        let primary = Agent::new(module_bytes());
        coordinator
            .register_agent(primary.id(), vec![])
            .expect("register agent");
        coordinator
            .unregister_agent(&primary.id())
            .expect("unregister agent");

        assert!(coordinator.get_state(&primary.id()).is_err());
    }
}