use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::raft::OxirsNodeId;
/// Lifecycle state of a cluster node as tracked by `NodeStatusTracker`.
///
/// Allowed moves between states are defined by the transition table built in
/// `NodeStatusTracker::new`.
///
/// NOTE: variant declaration order is significant. The derived `PartialOrd`/`Ord`
/// follow it, and it therefore determines key order in `BTreeMap<NodeState, _>`
/// (e.g. the tracker's transition table). Do not reorder variants casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum NodeState {
/// State assigned when a node is registered; counted as transitional.
Initializing,
/// Counted as transitional.
Joining,
/// Counted as operational.
Active,
/// Counted as operational.
Degraded,
/// Counted as problematic; may be moved to `Failed` automatically after
/// `suspect_timeout_secs` when auto transitions are enabled.
Suspect,
/// Counted as problematic; may be moved to `Left` automatically after
/// `failed_timeout_secs` when auto transitions are enabled.
Failed,
/// Counted as transitional.
Leaving,
/// Reached from `Leaving` or `Failed`.
Left,
/// Not covered by any of the `is_*` classification helpers.
Maintenance,
}
impl NodeState {
pub fn is_operational(&self) -> bool {
matches!(self, NodeState::Active | NodeState::Degraded)
}
pub fn is_problematic(&self) -> bool {
matches!(self, NodeState::Suspect | NodeState::Failed)
}
pub fn is_transitional(&self) -> bool {
matches!(
self,
NodeState::Initializing | NodeState::Joining | NodeState::Leaving
)
}
}
/// One applied state change, as stored in the tracker's transition history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateTransition {
/// Node whose state changed.
pub node_id: OxirsNodeId,
/// State the node was in before the change.
pub from_state: NodeState,
/// State the node moved to.
pub to_state: NodeState,
/// Caller-supplied explanation for the change.
pub reason: String,
/// Wall-clock time at which the change was applied.
pub timestamp: SystemTime,
/// Time spent in `from_state` before this change. The tracker always sets
/// `Some`, falling back to zero if the clock went backwards.
pub duration: Option<Duration>,
}
/// Status record kept by the tracker for a single registered node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeStatus {
/// Identifier of the node this status describes.
pub node_id: OxirsNodeId,
/// State the node is currently in.
pub current_state: NodeState,
/// State before the most recent transition; `None` until the first transition.
pub previous_state: Option<NodeState>,
/// Time spent in `current_state`; set to zero when a transition is applied.
pub time_in_state: Duration,
/// Wall-clock time of the most recent transition (or of registration).
pub last_state_change: SystemTime,
/// Number of transitions applied since the node was (re-)registered.
pub transition_count: u64,
/// Elapsed time since `start_time`; recomputed when a transition is applied.
pub uptime: Duration,
/// Wall-clock time at which the node was registered.
pub start_time: SystemTime,
/// Free-form key/value annotations set via `NodeStatusTracker::update_metadata`.
pub metadata: BTreeMap<String, String>,
}
/// Tunables for `NodeStatusTracker`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateMachineConfig {
/// Maximum number of transitions retained in history; the oldest are dropped first.
pub max_history_size: usize,
/// When `false`, `perform_auto_checks` does nothing.
pub enable_auto_transitions: bool,
/// Seconds a node may stay `Suspect` before auto checks move it to `Failed`.
pub suspect_timeout_secs: u64,
/// Seconds a node may stay `Failed` before auto checks move it to `Left`.
pub failed_timeout_secs: u64,
/// When `true`, transitions not listed in the tracker's transition table are rejected.
pub enable_validation: bool,
}
impl Default for StateMachineConfig {
fn default() -> Self {
Self {
max_history_size: 1000,
enable_auto_transitions: true,
suspect_timeout_secs: 30,
failed_timeout_secs: 60,
enable_validation: true,
}
}
}
/// Aggregate statistics snapshot published by `NodeStatusTracker`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StateMachineStats {
/// Number of currently registered nodes.
pub total_nodes: usize,
/// Count of registered nodes per state, keyed by the state's `Debug` name.
pub nodes_by_state: BTreeMap<String, usize>,
/// Sum of `transition_count` across currently registered nodes.
pub total_transitions: u64,
/// Number of transitions rejected by validation since creation or the last `clear`.
pub invalid_transitions: u64,
/// Average time, in seconds, spent in a state before leaving it. It is computed
/// over the retained transition history and keyed by the state's `Debug` name.
pub avg_time_in_state: BTreeMap<String, f64>,
}
/// Tracks per-node lifecycle state, a bounded transition history and aggregate
/// statistics. Shared state sits behind `tokio::sync::RwLock`s so methods take `&self`.
pub struct NodeStatusTracker {
/// Behavior tunables supplied at construction.
config: StateMachineConfig,
/// Current status of every registered node.
node_statuses: Arc<RwLock<BTreeMap<OxirsNodeId, NodeStatus>>>,
/// Applied transitions, oldest first, capped at `config.max_history_size`.
transition_history: Arc<RwLock<VecDeque<StateTransition>>>,
/// Most recently published statistics snapshot.
stats: Arc<RwLock<StateMachineStats>>,
/// For each state, the states it may move to when validation is enabled.
valid_transitions: BTreeMap<NodeState, Vec<NodeState>>,
}
impl NodeStatusTracker {
    /// Creates an empty tracker using `config`.
    ///
    /// The table built here lists, for each state, the states it may move to
    /// when `config.enable_validation` is set. Staying in the same state is
    /// always permitted (see `is_valid_transition`).
    pub fn new(config: StateMachineConfig) -> Self {
        let mut valid_transitions = BTreeMap::new();
        valid_transitions.insert(
            NodeState::Initializing,
            vec![NodeState::Joining, NodeState::Failed],
        );
        valid_transitions.insert(
            NodeState::Joining,
            vec![NodeState::Active, NodeState::Failed],
        );
        valid_transitions.insert(
            NodeState::Active,
            vec![
                NodeState::Degraded,
                NodeState::Suspect,
                NodeState::Leaving,
                NodeState::Maintenance,
            ],
        );
        valid_transitions.insert(
            NodeState::Degraded,
            vec![
                NodeState::Active,
                NodeState::Suspect,
                NodeState::Failed,
                NodeState::Leaving,
            ],
        );
        valid_transitions.insert(
            NodeState::Suspect,
            vec![NodeState::Active, NodeState::Degraded, NodeState::Failed],
        );
        valid_transitions.insert(
            NodeState::Failed,
            vec![NodeState::Initializing, NodeState::Left],
        );
        valid_transitions.insert(NodeState::Leaving, vec![NodeState::Left]);
        valid_transitions.insert(NodeState::Left, vec![NodeState::Initializing]);
        valid_transitions.insert(
            NodeState::Maintenance,
            vec![NodeState::Active, NodeState::Degraded, NodeState::Failed],
        );
        Self {
            config,
            node_statuses: Arc::new(RwLock::new(BTreeMap::new())),
            transition_history: Arc::new(RwLock::new(VecDeque::new())),
            stats: Arc::new(RwLock::new(StateMachineStats::default())),
            valid_transitions,
        }
    }

    /// Registers `node_id` in the `Initializing` state and refreshes stats.
    ///
    /// If the node is already tracked, its status (including metadata and the
    /// transition count) is replaced by a fresh one. Its entries in the
    /// transition history are kept.
    pub async fn register_node(&self, node_id: OxirsNodeId) {
        let now = SystemTime::now();
        let status = NodeStatus {
            node_id,
            current_state: NodeState::Initializing,
            previous_state: None,
            time_in_state: Duration::from_secs(0),
            last_state_change: now,
            transition_count: 0,
            uptime: Duration::from_secs(0),
            start_time: now,
            metadata: BTreeMap::new(),
        };
        {
            let mut statuses = self.node_statuses.write().await;
            statuses.insert(node_id, status);
        }
        info!("Registered node {} in state Initializing", node_id);
        self.update_stats().await;
    }

    /// Stops tracking `node_id` and refreshes stats.
    ///
    /// Unknown ids are ignored. The node's entries in the transition history
    /// are retained.
    pub async fn unregister_node(&self, node_id: &OxirsNodeId) {
        {
            let mut statuses = self.node_statuses.write().await;
            statuses.remove(node_id);
        }
        info!("Unregistered node {}", node_id);
        self.update_stats().await;
    }

    /// Moves `node_id` to `new_state`, appends the change to the history and
    /// refreshes stats.
    ///
    /// The current state is read, validated and replaced under a single write
    /// lock. A concurrent transition therefore cannot slip in between
    /// validation and update, and history entries are appended in the order the
    /// transitions were applied.
    ///
    /// # Errors
    /// - The node is not registered.
    /// - Validation is enabled and the transition table does not allow the move.
    ///   This case also increments `invalid_transitions`.
    pub async fn transition_state(
        &self,
        node_id: OxirsNodeId,
        new_state: NodeState,
        reason: String,
    ) -> Result<(), String> {
        let mut statuses = self.node_statuses.write().await;
        let status = statuses
            .get_mut(&node_id)
            .ok_or_else(|| format!("Node {} not found", node_id))?;
        let old_state = status.current_state;

        if self.config.enable_validation && !self.is_valid_transition(old_state, new_state) {
            // Release the node map before taking the stats lock; no method
            // holds the stats lock while waiting on another lock.
            drop(statuses);
            self.stats.write().await.invalid_transitions += 1;
            return Err(format!(
                "Invalid transition from {:?} to {:?}",
                old_state, new_state
            ));
        }

        let now = SystemTime::now();
        // Falls back to zero if the wall clock moved backwards.
        let duration = now
            .duration_since(status.last_state_change)
            .unwrap_or(Duration::from_secs(0));
        status.previous_state = Some(old_state);
        status.current_state = new_state;
        status.time_in_state = Duration::from_secs(0);
        status.last_state_change = now;
        status.transition_count += 1;
        status.uptime = now
            .duration_since(status.start_time)
            .unwrap_or(Duration::from_secs(0));

        let transition = StateTransition {
            node_id,
            from_state: old_state,
            to_state: new_state,
            reason: reason.clone(),
            timestamp: now,
            duration: Some(duration),
        };
        {
            // Lock order is node map -> history. Nothing acquires them in the
            // reverse order, so nesting here cannot deadlock.
            let mut history = self.transition_history.write().await;
            history.push_back(transition);
            while history.len() > self.config.max_history_size {
                history.pop_front();
            }
        }
        drop(statuses);

        info!(
            "Node {} transitioned from {:?} to {:?}: {}",
            node_id, old_state, new_state, reason
        );
        self.update_stats().await;
        Ok(())
    }

    /// Returns `true` if `from -> to` is allowed by the transition table.
    /// Staying in the same state is always allowed.
    fn is_valid_transition(&self, from: NodeState, to: NodeState) -> bool {
        if from == to {
            return true;
        }
        self.valid_transitions
            .get(&from)
            .map(|valid| valid.contains(&to))
            .unwrap_or(false)
    }

    /// Clones `status`, recomputing `time_in_state` and `uptime` as of `now`.
    ///
    /// The stored values are only updated when a transition is applied, so
    /// without this they would be stale (and `time_in_state` always zero).
    fn snapshot_at(status: &NodeStatus, now: SystemTime) -> NodeStatus {
        let mut snapshot = status.clone();
        snapshot.time_in_state = now
            .duration_since(status.last_state_change)
            .unwrap_or(Duration::from_secs(0));
        snapshot.uptime = now
            .duration_since(status.start_time)
            .unwrap_or(Duration::from_secs(0));
        snapshot
    }

    /// Returns a copy of the node's status with time-derived fields refreshed,
    /// or `None` if the node is not registered.
    pub async fn get_node_status(&self, node_id: &OxirsNodeId) -> Option<NodeStatus> {
        let statuses = self.node_statuses.read().await;
        let now = SystemTime::now();
        statuses.get(node_id).map(|s| Self::snapshot_at(s, now))
    }

    /// Returns copies of all node statuses with time-derived fields refreshed.
    pub async fn get_all_statuses(&self) -> BTreeMap<OxirsNodeId, NodeStatus> {
        let statuses = self.node_statuses.read().await;
        let now = SystemTime::now();
        statuses
            .iter()
            .map(|(id, status)| (*id, Self::snapshot_at(status, now)))
            .collect()
    }

    /// Returns the ids of all registered nodes currently in `state`.
    pub async fn get_nodes_in_state(&self, state: NodeState) -> Vec<OxirsNodeId> {
        let statuses = self.node_statuses.read().await;
        statuses
            .iter()
            .filter(|(_, status)| status.current_state == state)
            .map(|(id, _)| *id)
            .collect()
    }

    /// Returns the retained transition history, oldest first.
    pub async fn get_transition_history(&self) -> Vec<StateTransition> {
        self.transition_history
            .read()
            .await
            .iter()
            .cloned()
            .collect()
    }

    /// Returns the retained transitions for `node_id`, oldest first.
    pub async fn get_node_transition_history(&self, node_id: &OxirsNodeId) -> Vec<StateTransition> {
        let history = self.transition_history.read().await;
        history
            .iter()
            .filter(|t| &t.node_id == node_id)
            .cloned()
            .collect()
    }

    /// Sets metadata `key` to `value` for `node_id`, overwriting any prior value.
    ///
    /// # Errors
    /// Returns an error if the node is not registered.
    pub async fn update_metadata(
        &self,
        node_id: &OxirsNodeId,
        key: String,
        value: String,
    ) -> Result<(), String> {
        let mut statuses = self.node_statuses.write().await;
        let status = statuses
            .get_mut(node_id)
            .ok_or_else(|| format!("Node {} not found", node_id))?;
        status.metadata.insert(key, value);
        Ok(())
    }

    /// Returns a copy of the most recently published statistics.
    pub async fn get_stats(&self) -> StateMachineStats {
        self.stats.read().await.clone()
    }

    /// Recomputes and publishes statistics from the node map and history.
    ///
    /// `invalid_transitions` is maintained incrementally elsewhere. It is
    /// carried over while holding the same write lock that publishes the new
    /// snapshot, so concurrent increments are not lost.
    async fn update_stats(&self) {
        let mut fresh = StateMachineStats::default();
        {
            let statuses = self.node_statuses.read().await;
            fresh.total_nodes = statuses.len();
            for status in statuses.values() {
                let state_name = format!("{:?}", status.current_state);
                *fresh.nodes_by_state.entry(state_name).or_insert(0) += 1;
                fresh.total_transitions += status.transition_count;
            }
        }
        {
            let history = self.transition_history.read().await;
            // Per source state: (total seconds, number of samples). Fractional
            // seconds are kept so sub-second stays do not average to zero.
            let mut totals: BTreeMap<String, (f64, usize)> = BTreeMap::new();
            for transition in history.iter() {
                if let Some(duration) = transition.duration {
                    let entry = totals
                        .entry(format!("{:?}", transition.from_state))
                        .or_insert((0.0, 0));
                    entry.0 += duration.as_secs_f64();
                    entry.1 += 1;
                }
            }
            // Every entry has at least one sample, so the division is safe.
            fresh.avg_time_in_state = totals
                .into_iter()
                .map(|(state, (total, count))| (state, total / count as f64))
                .collect();
        }
        let mut stats = self.stats.write().await;
        fresh.invalid_transitions = stats.invalid_transitions;
        *stats = fresh;
    }

    /// Applies time-based transitions when `enable_auto_transitions` is set:
    /// - `Suspect` for longer than `suspect_timeout_secs` becomes `Failed`;
    /// - `Failed` for longer than `failed_timeout_secs` becomes `Left`.
    ///
    /// Candidates are collected under a read lock and then applied through
    /// `transition_state`, which re-validates against the node's state at apply
    /// time. Failures are logged with `warn!` rather than returned.
    pub async fn perform_auto_checks(&self) {
        if !self.config.enable_auto_transitions {
            return;
        }
        let transitions: Vec<(OxirsNodeId, NodeState, String)> = {
            let statuses = self.node_statuses.read().await;
            let now = SystemTime::now();
            statuses
                .iter()
                .filter_map(|(node_id, status)| {
                    let elapsed_secs = now
                        .duration_since(status.last_state_change)
                        .unwrap_or(Duration::from_secs(0))
                        .as_secs();
                    match status.current_state {
                        NodeState::Suspect if elapsed_secs > self.config.suspect_timeout_secs => {
                            Some((
                                *node_id,
                                NodeState::Failed,
                                "Suspect timeout exceeded".to_string(),
                            ))
                        }
                        NodeState::Failed if elapsed_secs > self.config.failed_timeout_secs => {
                            Some((
                                *node_id,
                                NodeState::Left,
                                "Failed timeout exceeded".to_string(),
                            ))
                        }
                        _ => None,
                    }
                })
                .collect()
        };
        for (node_id, new_state, reason) in transitions {
            if let Err(e) = self.transition_state(node_id, new_state, reason).await {
                warn!("Auto-transition failed for node {}: {}", node_id, e);
            }
        }
    }

    /// Removes all nodes and history and resets every statistic, including
    /// `invalid_transitions`. The three stores are cleared one after another,
    /// not atomically.
    pub async fn clear(&self) {
        self.node_statuses.write().await.clear();
        self.transition_history.write().await.clear();
        *self.stats.write().await = StateMachineStats::default();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker with the default configuration.
    fn new_tracker() -> NodeStatusTracker {
        NodeStatusTracker::new(StateMachineConfig::default())
    }

    #[tokio::test]
    async fn test_node_status_tracker_creation() {
        let tracker = new_tracker();
        let stats = tracker.get_stats().await;
        assert_eq!(stats.total_nodes, 0);
    }

    #[tokio::test]
    async fn test_register_node() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        let status = tracker
            .get_node_status(&1)
            .await
            .expect("node 1 was just registered");
        assert_eq!(status.current_state, NodeState::Initializing);
        assert_eq!(status.transition_count, 0);
    }

    #[tokio::test]
    async fn test_valid_transition() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        let result = tracker
            .transition_state(1, NodeState::Joining, "Starting join process".to_string())
            .await;
        assert!(result.is_ok());
        let status = tracker.get_node_status(&1).await.unwrap();
        assert_eq!(status.current_state, NodeState::Joining);
        assert_eq!(status.previous_state, Some(NodeState::Initializing));
        assert_eq!(status.transition_count, 1);
    }

    #[tokio::test]
    async fn test_invalid_transition() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        let result = tracker
            .transition_state(1, NodeState::Active, "Invalid transition".to_string())
            .await;
        assert!(result.is_err());
        let stats = tracker.get_stats().await;
        assert_eq!(stats.invalid_transitions, 1);
    }

    #[tokio::test]
    async fn test_invalid_transition_leaves_state_unchanged() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        let result = tracker
            .transition_state(1, NodeState::Left, "skip ahead".to_string())
            .await;
        assert!(result.is_err());
        let status = tracker.get_node_status(&1).await.unwrap();
        assert_eq!(status.current_state, NodeState::Initializing);
        assert_eq!(status.previous_state, None);
        assert_eq!(status.transition_count, 0);
        assert!(tracker.get_transition_history().await.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_node_transition_fails() {
        let tracker = new_tracker();
        let result = tracker
            .transition_state(42, NodeState::Joining, "ghost".to_string())
            .await;
        assert!(result.is_err());
        // A missing node is not counted as an invalid transition.
        assert_eq!(tracker.get_stats().await.invalid_transitions, 0);
    }

    #[tokio::test]
    async fn test_validation_disabled_allows_any_transition() {
        let config = StateMachineConfig {
            enable_validation: false,
            ..StateMachineConfig::default()
        };
        let tracker = NodeStatusTracker::new(config);
        tracker.register_node(1).await;
        tracker
            .transition_state(1, NodeState::Active, "skip join".to_string())
            .await
            .unwrap();
        let status = tracker.get_node_status(&1).await.unwrap();
        assert_eq!(status.current_state, NodeState::Active);
        assert_eq!(tracker.get_stats().await.invalid_transitions, 0);
    }

    #[tokio::test]
    async fn test_state_machine_flow() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker
            .transition_state(1, NodeState::Joining, "Joining cluster".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Active, "Joined successfully".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Degraded, "Performance degraded".to_string())
            .await
            .unwrap();
        let status = tracker.get_node_status(&1).await.unwrap();
        assert_eq!(status.current_state, NodeState::Degraded);
        assert_eq!(status.transition_count, 3);
    }

    #[tokio::test]
    async fn test_get_nodes_in_state() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker.register_node(2).await;
        tracker.register_node(3).await;
        tracker
            .transition_state(1, NodeState::Joining, "test".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(2, NodeState::Joining, "test".to_string())
            .await
            .unwrap();
        let joining_nodes = tracker.get_nodes_in_state(NodeState::Joining).await;
        assert_eq!(joining_nodes.len(), 2);
        assert!(joining_nodes.contains(&1));
        assert!(joining_nodes.contains(&2));
    }

    #[tokio::test]
    async fn test_transition_history() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker
            .transition_state(1, NodeState::Joining, "test1".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Active, "test2".to_string())
            .await
            .unwrap();
        let history = tracker.get_node_transition_history(&1).await;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].from_state, NodeState::Initializing);
        assert_eq!(history[0].to_state, NodeState::Joining);
        assert_eq!(history[1].from_state, NodeState::Joining);
        assert_eq!(history[1].to_state, NodeState::Active);
    }

    #[tokio::test]
    async fn test_history_is_bounded() {
        let config = StateMachineConfig {
            max_history_size: 2,
            ..StateMachineConfig::default()
        };
        let tracker = NodeStatusTracker::new(config);
        tracker.register_node(1).await;
        tracker
            .transition_state(1, NodeState::Joining, "a".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Active, "b".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Degraded, "c".to_string())
            .await
            .unwrap();
        let history = tracker.get_transition_history().await;
        // The oldest entry (Initializing -> Joining) was evicted.
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].from_state, NodeState::Joining);
        assert_eq!(history[1].to_state, NodeState::Degraded);
    }

    #[tokio::test]
    async fn test_metadata() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker
            .update_metadata(&1, "version".to_string(), "1.0.0".to_string())
            .await
            .unwrap();
        tracker
            .update_metadata(&1, "region".to_string(), "us-west".to_string())
            .await
            .unwrap();
        let status = tracker.get_node_status(&1).await.unwrap();
        assert_eq!(status.metadata.get("version"), Some(&"1.0.0".to_string()));
        assert_eq!(status.metadata.get("region"), Some(&"us-west".to_string()));
    }

    #[tokio::test]
    async fn test_stats() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker.register_node(2).await;
        tracker.register_node(3).await;
        tracker
            .transition_state(1, NodeState::Joining, "test".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(2, NodeState::Joining, "test".to_string())
            .await
            .unwrap();
        tracker
            .transition_state(1, NodeState::Active, "test".to_string())
            .await
            .unwrap();
        let stats = tracker.get_stats().await;
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.total_transitions, 3);
    }

    #[tokio::test]
    async fn test_unregister_node() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        assert!(tracker.get_node_status(&1).await.is_some());
        tracker.unregister_node(&1).await;
        assert!(tracker.get_node_status(&1).await.is_none());
    }

    // Pure, synchronous predicates: no async runtime is needed.
    #[test]
    fn test_node_state_helpers() {
        assert!(NodeState::Active.is_operational());
        assert!(NodeState::Degraded.is_operational());
        assert!(!NodeState::Failed.is_operational());
        assert!(NodeState::Suspect.is_problematic());
        assert!(NodeState::Failed.is_problematic());
        assert!(!NodeState::Active.is_problematic());
        assert!(NodeState::Initializing.is_transitional());
        assert!(NodeState::Joining.is_transitional());
        assert!(!NodeState::Active.is_transitional());
    }

    #[tokio::test]
    async fn test_clear() {
        let tracker = new_tracker();
        tracker.register_node(1).await;
        tracker.register_node(2).await;
        tracker
            .transition_state(1, NodeState::Joining, "test".to_string())
            .await
            .unwrap();
        tracker.clear().await;
        let stats = tracker.get_stats().await;
        assert_eq!(stats.total_nodes, 0);
        assert!(tracker.get_all_statuses().await.is_empty());
        assert!(tracker.get_transition_history().await.is_empty());
    }
}