use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant, SystemTime};
use thiserror::Error;
use tracing::{debug, error, warn};
use uuid::Uuid;
use crate::message::{HandshakeType, Message, MessageType, ProtocolVersion};
#[derive(Debug, Error)]
pub enum StateError {
#[error("Invalid state transition from {from:?} to {to:?}")]
InvalidTransition {
from: ProtocolState,
to: ProtocolState,
},
#[error("State synchronization failed: {reason}")]
SyncFailed { reason: String },
#[error("Invalid state data: {reason}")]
InvalidData { reason: String },
#[error("State operation timed out after {timeout:?}")]
Timeout { timeout: Duration },
#[error("Session not found: {session_id}")]
SessionNotFound { session_id: Uuid },
#[error("Protocol version mismatch: expected {expected:?}, got {actual:?}")]
VersionMismatch {
expected: ProtocolVersion,
actual: ProtocolVersion,
},
}
/// Top-level state of the protocol; each phase is refined by its own sub-state enum.
#[derive(Debug, Clone, Default)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum ProtocolState {
    /// Starting state, and the [`Default`] value.
    #[default]
    Initial,
    /// Connection handshake in some phase.
    Handshake(HandshakeState),
    /// Handshake done; exchanging regular traffic.
    Active(ActiveState),
    /// State synchronization in some phase.
    Synchronizing(SyncState),
    /// An error of the given category occurred.
    Error(ErrorState),
    /// Terminal state.
    Shutdown,
}
/// Phases of the connection handshake, in the order the state machine walks them.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum HandshakeState {
    /// Entered from `Initial`; waiting for an Init message.
    Waiting,
    /// Init received; waiting for a Response.
    InProgress,
    /// Response received; waiting for Complete.
    Processing,
    /// Handshake finished; next step is `Active(Normal)`.
    Completed,
    /// Handshake aborted.
    Failed,
}
/// Sub-states of normal operation after a successful handshake.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum ActiveState {
    /// Regular operation.
    Normal,
    /// Operating under high load; entry is logged as a warning.
    HighLoad,
    /// Degraded operation; entry is logged as a warning.
    Degraded,
}
/// Sequential steps of a state synchronization; `Verifying` returns to `Active(Normal)`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum SyncState {
    /// Synchronization has been requested.
    Requesting,
    /// Synchronization data is being received.
    Receiving,
    /// Received data is being applied.
    Applying,
    /// Applied data is being verified.
    Verifying,
}
/// Category of error carried by [`ProtocolState::Error`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum ErrorState {
    /// Network failure; failed sessions are dropped on entry.
    NetworkError,
    /// Consensus failure; no automatic recovery action.
    ConsensusError,
    /// Cryptographic failure; logged as critical.
    CryptoError,
    /// Resource exhaustion; history is trimmed and expired sessions removed on entry.
    ResourceError,
    /// Internal protocol failure; logged.
    InternalError,
}
/// Book-keeping for one peer session tracked by [`ProtocolStateMachine`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier; equals the key under which the session is stored.
    pub id: Uuid,
    /// Opaque peer identifier supplied when the session was created.
    pub peer_id: Vec<u8>,
    /// Protocol version negotiated for this session.
    pub protocol_version: ProtocolVersion,
    /// The session's own state (starts at `Handshake(Waiting)`).
    pub state: ProtocolState,
    /// Wall-clock creation time.
    pub started_at: SystemTime,
    /// Wall-clock time of the most recent activity; drives session expiry.
    pub last_activity: SystemTime,
    /// Capability strings supplied at creation (NOTE(review): semantics defined by callers).
    pub capabilities: Vec<String>,
    /// Per-session traffic counters.
    pub metrics: SessionMetrics,
}
/// Traffic counters for a single session; all start at zero.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct SessionMetrics {
    /// Messages sent to the peer.
    pub messages_sent: u64,
    /// Messages received from the peer (incremented by `process_message`).
    pub messages_received: u64,
    /// Payload bytes sent to the peer.
    pub bytes_sent: u64,
    /// Payload bytes received from the peer (incremented by `process_message`).
    pub bytes_received: u64,
    /// Average response time.
    pub avg_response_time: Duration,
    /// Number of errors attributed to the session.
    pub error_count: u64,
}
/// Protocol-level state machine that also tracks the set of peer sessions.
#[derive(Debug)]
pub struct ProtocolStateMachine {
    /// State the machine is currently in.
    current_state: ProtocolState,
    /// State before the most recent successful transition, if any.
    previous_state: Option<ProtocolState>,
    /// Recent transitions, oldest first, capped by `config.max_history_size`.
    state_history: Vec<StateTransition>,
    /// Tracked sessions keyed by their id.
    sessions: HashMap<Uuid, SessionInfo>,
    /// Monotonic construction time, used for uptime.
    started_at: Instant,
    /// Local protocol version; peers must be compatible with it.
    protocol_version: ProtocolVersion,
    /// Limits and timeouts.
    config: StateMachineConfig,
}
/// One recorded move of the state machine from `from` to `to`.
#[derive(Clone)]
#[derive(Debug)]
pub struct StateTransition {
    /// Monotonic time at which the transition happened.
    pub timestamp: Instant,
    /// State that was left.
    pub from: ProtocolState,
    /// State that was entered.
    pub to: ProtocolState,
    /// Caller-supplied explanation.
    pub reason: String,
    /// Time since the previous recorded transition, or since construction if none was retained.
    pub duration: Duration,
}
/// Limits and timeouts used by [`ProtocolStateMachine`].
#[derive(Clone)]
#[derive(Debug)]
pub struct StateMachineConfig {
    /// Maximum number of concurrently tracked sessions.
    pub max_sessions: usize,
    /// Idle time after which a session counts as expired.
    pub session_timeout: Duration,
    /// Handshake time limit (NOTE(review): not consulted by the visible state-machine code).
    pub handshake_timeout: Duration,
    /// Synchronization time limit (NOTE(review): not consulted by the visible state-machine code).
    pub sync_timeout: Duration,
    /// Maximum number of retained [`StateTransition`] records.
    pub max_history_size: usize,
}
impl Default for StateMachineConfig {
    /// Defaults: 1000 sessions, 300 s session timeout, 30 s handshake timeout,
    /// 60 s sync timeout, and 1000 retained history entries.
    fn default() -> Self {
        let session_timeout = Duration::from_secs(5 * 60);
        let handshake_timeout = Duration::from_secs(30);
        let sync_timeout = Duration::from_secs(60);
        Self {
            max_sessions: 1000,
            max_history_size: 1000,
            session_timeout,
            handshake_timeout,
            sync_timeout,
        }
    }
}
impl ProtocolStateMachine {
    /// Creates a state machine in [`ProtocolState::Initial`] with the default
    /// [`StateMachineConfig`].
    pub fn new(protocol_version: ProtocolVersion) -> Self {
        Self::with_config(protocol_version, StateMachineConfig::default())
    }

    /// Creates a state machine in [`ProtocolState::Initial`] with an explicit `config`.
    ///
    /// The uptime clock ([`Self::uptime`]) starts at construction.
    pub fn with_config(protocol_version: ProtocolVersion, config: StateMachineConfig) -> Self {
        Self {
            current_state: ProtocolState::Initial,
            previous_state: None,
            state_history: Vec::new(),
            sessions: HashMap::new(),
            started_at: Instant::now(),
            protocol_version,
            config,
        }
    }

    /// Returns the machine's current top-level state.
    pub fn current_state(&self) -> &ProtocolState {
        &self.current_state
    }

    /// Returns the number of tracked sessions, whatever state each one is in.
    pub fn active_sessions(&self) -> usize {
        self.sessions.len()
    }

    /// Returns the protocol version this machine was created with.
    pub fn protocol_version(&self) -> &ProtocolVersion {
        &self.protocol_version
    }

    /// Returns the time elapsed since the machine was constructed.
    pub fn uptime(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Moves the machine to `new_state`, records the transition, and runs entry actions.
    ///
    /// The recorded [`StateTransition::duration`] is the time since the last recorded
    /// transition, or since construction when no history is retained. History is trimmed to
    /// the newest `config.max_history_size` entries.
    ///
    /// # Errors
    ///
    /// * [`StateError::InvalidTransition`] if the transition graph forbids the move. The
    ///   machine is left unchanged in that case.
    /// * Any error returned by the entry actions. These run *after* the state has changed, so
    ///   the new state is kept even when an error is returned.
    pub fn transition_to(
        &mut self,
        new_state: ProtocolState,
        reason: String,
    ) -> Result<(), StateError> {
        if !self.is_valid_transition(&self.current_state, &new_state) {
            return Err(StateError::InvalidTransition {
                from: self.current_state.clone(),
                to: new_state,
            });
        }

        let now = Instant::now();
        // Time spent in the state being left.
        let duration = now.duration_since(
            self.state_history
                .last()
                .map_or(self.started_at, |last| last.timestamp),
        );

        debug!(
            "State transition: {:?} -> {:?} ({})",
            self.current_state, new_state, reason
        );

        let from = std::mem::replace(&mut self.current_state, new_state.clone());
        self.state_history.push(StateTransition {
            timestamp: now,
            from: from.clone(),
            to: new_state,
            reason: reason.clone(),
            duration,
        });
        self.previous_state = Some(from);

        // Drop the oldest entries beyond the configured cap.
        let overflow = self
            .state_history
            .len()
            .saturating_sub(self.config.max_history_size);
        if overflow > 0 {
            self.state_history.drain(..overflow);
        }

        self.on_state_entry(&reason)
    }

    /// Returns `true` when the transition graph allows moving from `from` to `to`.
    ///
    /// Self-transitions are allowed for every state except [`ProtocolState::Shutdown`], which
    /// is terminal. Any pair not listed below is rejected. Does not read `self`; it is used
    /// for both the machine and individual sessions.
    fn is_valid_transition(&self, from: &ProtocolState, to: &ProtocolState) -> bool {
        // Glob-import the variants so the table stays readable. Every bare name in a pattern
        // below must resolve to a variant: a misspelling would silently become a catch-all
        // binding.
        use ActiveState::*;
        use ErrorState::*;
        use HandshakeState::*;
        use ProtocolState::*;
        use SyncState::*;
        match (from, to) {
            // Initial: begin a handshake, fail, or stop.
            (Initial, Handshake(Waiting)) => true,
            (Initial, Error(_)) => true,
            (Initial, Shutdown) => true,
            // Handshake: Waiting -> InProgress -> Processing -> Completed -> Active(Normal).
            (Handshake(Waiting), Handshake(InProgress)) => true,
            (Handshake(InProgress), Handshake(Processing)) => true,
            (Handshake(InProgress), Handshake(Failed)) => true,
            (Handshake(Processing), Handshake(Completed)) => true,
            (Handshake(Processing), Handshake(Failed)) => true,
            (Handshake(Completed), Active(Normal)) => true,
            (Handshake(Failed), Error(NetworkError)) => true,
            (Handshake(_), Shutdown) => true,
            // Active: load changes, start of synchronization, errors, shutdown.
            (Active(Normal), Active(HighLoad)) => true,
            (Active(Normal), Active(Degraded)) => true,
            (Active(Normal), Synchronizing(Requesting)) => true,
            (Active(HighLoad), Active(Normal)) => true,
            (Active(HighLoad), Active(Degraded)) => true,
            (Active(Degraded), Active(Normal)) => true,
            (Active(Degraded), Synchronizing(Requesting)) => true,
            (Active(_), Error(_)) => true,
            (Active(_), Shutdown) => true,
            // Synchronizing: Requesting -> Receiving -> Applying -> Verifying -> Active(Normal).
            (Synchronizing(Requesting), Synchronizing(Receiving)) => true,
            (Synchronizing(Requesting), Error(NetworkError)) => true,
            (Synchronizing(Receiving), Synchronizing(Applying)) => true,
            (Synchronizing(Receiving), Error(NetworkError)) => true,
            (Synchronizing(Applying), Synchronizing(Verifying)) => true,
            (Synchronizing(Applying), Error(InternalError)) => true,
            (Synchronizing(Verifying), Active(Normal)) => true,
            (Synchronizing(Verifying), Error(InternalError)) => true,
            (Synchronizing(_), Shutdown) => true,
            // Error: restart from scratch or stop.
            (Error(_), Initial) => true,
            (Error(_), Shutdown) => true,
            // Shutdown is terminal; this arm must precede the self-transition arm.
            (Shutdown, _) => false,
            (a, b) if a == b => true,
            _ => false,
        }
    }

    /// Runs the entry actions for the state the machine has just entered.
    ///
    /// * `Initial` drops every tracked session.
    /// * `Handshake(Failed)` removes failed sessions.
    /// * `Error(_)` delegates to `handle_error_state`.
    /// * `Shutdown` marks every session as shut down.
    ///
    /// All other states only log.
    fn on_state_entry(&mut self, reason: &str) -> Result<(), StateError> {
        // Work on a copy so `&mut self` helpers can be called while matching on the state.
        let current_state = self.current_state.clone();
        match &current_state {
            ProtocolState::Initial => {
                debug!("Entered Initial state: {}", reason);
                self.sessions.clear();
            }
            ProtocolState::Handshake(handshake_state) => {
                debug!("Entered Handshake state {:?}: {}", handshake_state, reason);
                if matches!(handshake_state, HandshakeState::Failed) {
                    warn!("Handshake failed: {}", reason);
                    self.cleanup_failed_sessions();
                }
            }
            ProtocolState::Active(active_state) => {
                debug!("Entered Active state {:?}: {}", active_state, reason);
                match active_state {
                    // Only logged here; this method performs no shedding itself.
                    ActiveState::HighLoad => {
                        warn!("Entering high load state, implementing load shedding");
                    }
                    ActiveState::Degraded => {
                        warn!("Entering degraded state: {}", reason);
                    }
                    ActiveState::Normal => {}
                }
            }
            ProtocolState::Synchronizing(sync_state) => {
                debug!("Entered Synchronizing state {:?}: {}", sync_state, reason);
            }
            ProtocolState::Error(error_state) => {
                error!("Entered Error state {:?}: {}", error_state, reason);
                self.handle_error_state(error_state, reason)?;
            }
            ProtocolState::Shutdown => {
                debug!("Entered Shutdown state: {}", reason);
                self.begin_shutdown();
            }
        }
        Ok(())
    }

    /// Performs recovery work specific to `error_state`.
    ///
    /// * Network errors drop failed sessions.
    /// * Resource errors trim history and remove expired sessions.
    /// * Crypto and internal errors are logged.
    /// * Consensus errors take no action here.
    ///
    /// Currently always returns `Ok(())`.
    fn handle_error_state(
        &mut self,
        error_state: &ErrorState,
        reason: &str,
    ) -> Result<(), StateError> {
        match error_state {
            ErrorState::NetworkError => self.cleanup_failed_sessions(),
            ErrorState::ConsensusError => {}
            ErrorState::CryptoError => {
                error!("Critical cryptographic error: {}", reason);
            }
            ErrorState::ResourceError => self.cleanup_resources(),
            ErrorState::InternalError => {
                error!("Internal protocol error: {}", reason);
            }
        }
        Ok(())
    }

    /// Marks every tracked session as [`ProtocolState::Shutdown`] without removing it.
    fn begin_shutdown(&mut self) {
        debug!("Beginning graceful shutdown");
        for (session_id, session) in self.sessions.iter_mut() {
            debug!("Closing session: {}", session_id);
            session.state = ProtocolState::Shutdown;
        }
    }

    /// Removes sessions whose state is `Error(_)` or `Handshake(Failed)`.
    fn cleanup_failed_sessions(&mut self) {
        self.sessions.retain(|session_id, session| {
            let failed = matches!(
                session.state,
                ProtocolState::Error(_) | ProtocolState::Handshake(HandshakeState::Failed)
            );
            if failed {
                debug!("Cleaning up failed session: {}", session_id);
            }
            !failed
        });
    }

    /// Trims history to half of `max_history_size` (keeping the newest entries) and removes
    /// expired sessions.
    fn cleanup_resources(&mut self) {
        debug!("Cleaning up resources");
        let keep = self.config.max_history_size / 2;
        let excess = self.state_history.len().saturating_sub(keep);
        if excess > 0 {
            self.state_history.drain(..excess);
        }
        self.cleanup_timed_out_sessions();
    }

    /// Removes sessions idle for longer than `config.session_timeout`.
    fn cleanup_timed_out_sessions(&mut self) {
        let now = SystemTime::now();
        let timeout = self.config.session_timeout;
        self.sessions.retain(|session_id, session| {
            // A wall clock that moved backwards makes `duration_since` fail; treat that as
            // "not idle" rather than expiring the session.
            let idle = now
                .duration_since(session.last_activity)
                .unwrap_or(Duration::ZERO);
            let expired = idle > timeout;
            if expired {
                debug!("Removing timed out session: {}", session_id);
            }
            !expired
        });
    }

    /// Registers a new session in `Handshake(Waiting)` and returns its id.
    ///
    /// # Errors
    ///
    /// * [`StateError::InvalidData`] when `config.max_sessions` sessions are already tracked.
    /// * [`StateError::VersionMismatch`] when `protocol_version` is not compatible with the
    ///   machine's version.
    pub fn create_session(
        &mut self,
        peer_id: Vec<u8>,
        protocol_version: ProtocolVersion,
        capabilities: Vec<String>,
    ) -> Result<Uuid, StateError> {
        if self.sessions.len() >= self.config.max_sessions {
            return Err(StateError::InvalidData {
                reason: "Maximum number of sessions reached".to_string(),
            });
        }
        if !self.protocol_version.is_compatible(&protocol_version) {
            return Err(StateError::VersionMismatch {
                expected: self.protocol_version.clone(),
                actual: protocol_version,
            });
        }

        let session_id = Uuid::new_v4();
        let now = SystemTime::now();
        let session = SessionInfo {
            id: session_id,
            peer_id,
            protocol_version,
            state: ProtocolState::Handshake(HandshakeState::Waiting),
            started_at: now,
            last_activity: now,
            capabilities,
            metrics: SessionMetrics::default(),
        };
        self.sessions.insert(session_id, session);
        debug!("Created new session: {}", session_id);
        Ok(session_id)
    }

    /// Moves one session to `new_state` and refreshes its `last_activity`.
    ///
    /// Session transitions are validated against the same graph as the machine itself.
    /// Entry actions and history recording are not applied to sessions.
    ///
    /// # Errors
    ///
    /// * [`StateError::SessionNotFound`] if `session_id` is not tracked.
    /// * [`StateError::InvalidTransition`] if the graph forbids the move. The session is
    ///   left unchanged in that case.
    pub fn update_session_state(
        &mut self,
        session_id: Uuid,
        new_state: ProtocolState,
    ) -> Result<(), StateError> {
        let current_session_state = self
            .sessions
            .get(&session_id)
            .map(|session| session.state.clone())
            .ok_or(StateError::SessionNotFound { session_id })?;

        if !self.is_valid_transition(&current_session_state, &new_state) {
            return Err(StateError::InvalidTransition {
                from: current_session_state,
                to: new_state,
            });
        }

        // Borrow mutably only after the `&self` validation above has finished.
        let session = self
            .sessions
            .get_mut(&session_id)
            .ok_or(StateError::SessionNotFound { session_id })?;
        session.state = new_state;
        session.last_activity = SystemTime::now();
        Ok(())
    }

    /// Looks up a tracked session.
    pub fn get_session(&self, session_id: &Uuid) -> Option<&SessionInfo> {
        self.sessions.get(session_id)
    }

    /// Stops tracking a session and returns it, if it existed.
    pub fn remove_session(&mut self, session_id: &Uuid) -> Option<SessionInfo> {
        self.sessions.remove(session_id)
    }

    /// Records receipt of `message` and dispatches it according to its type.
    ///
    /// When `session_id` names a tracked session, that session's `last_activity`,
    /// `messages_received` and `bytes_received` are updated first. Unknown ids are ignored.
    /// Handshake messages may drive state transitions. Control and other messages are only
    /// logged here.
    ///
    /// # Errors
    ///
    /// Propagates transition errors raised while handling handshake messages.
    pub fn process_message(
        &mut self,
        message: &Message,
        session_id: Option<Uuid>,
    ) -> Result<(), StateError> {
        if let Some(id) = session_id {
            if let Some(session) = self.sessions.get_mut(&id) {
                session.last_activity = SystemTime::now();
                session.metrics.messages_received += 1;
                session.metrics.bytes_received += message.payload.len() as u64;
            }
        }

        match &message.msg_type {
            MessageType::Handshake(handshake_type) => {
                self.process_handshake_message(handshake_type, message, session_id)?;
            }
            MessageType::Control(_) => {
                if !matches!(self.current_state, ProtocolState::Shutdown) {
                    debug!(
                        "Processing control message in state {:?}",
                        self.current_state
                    );
                }
            }
            _ => match &self.current_state {
                ProtocolState::Active(_) => {
                    debug!("Processing message in active state");
                }
                _ => {
                    warn!(
                        "Received message in non-active state: {:?}",
                        self.current_state
                    );
                }
            },
        }
        Ok(())
    }

    /// Advances the handshake sub-machine in response to a handshake message.
    ///
    /// * `Init`: from `Initial` the machine first passes through `Handshake(Waiting)`. From
    ///   any handshake state it then requests `Handshake(InProgress)`. Out-of-order sources
    ///   such as `Processing` are rejected by the transition graph.
    /// * `Response`: `Handshake(InProgress)` -> `Handshake(Processing)`.
    /// * `Complete`: `Handshake(Processing)` -> `Handshake(Completed)` -> `Active(Normal)`.
    /// * `VersionNegotiation`: logged only.
    ///
    /// `Response` and `Complete` received in any other state are ignored. Afterwards, if
    /// `session_id` names a tracked session, its state is overwritten with the machine's
    /// current state.
    fn process_handshake_message(
        &mut self,
        handshake_type: &HandshakeType,
        _message: &Message,
        session_id: Option<Uuid>,
    ) -> Result<(), StateError> {
        match handshake_type {
            HandshakeType::Init => {
                // `Initial` may only advance to `Handshake(Waiting)`. Jumping straight to
                // `InProgress` would be rejected by `is_valid_transition`, so step through it.
                if matches!(self.current_state, ProtocolState::Initial) {
                    self.transition_to(
                        ProtocolState::Handshake(HandshakeState::Waiting),
                        "Received handshake init".to_string(),
                    )?;
                }
                if matches!(self.current_state, ProtocolState::Handshake(_)) {
                    self.transition_to(
                        ProtocolState::Handshake(HandshakeState::InProgress),
                        "Received handshake init".to_string(),
                    )?;
                }
            }
            HandshakeType::Response => {
                if matches!(
                    self.current_state,
                    ProtocolState::Handshake(HandshakeState::InProgress)
                ) {
                    self.transition_to(
                        ProtocolState::Handshake(HandshakeState::Processing),
                        "Received handshake response".to_string(),
                    )?;
                }
            }
            HandshakeType::Complete => {
                if matches!(
                    self.current_state,
                    ProtocolState::Handshake(HandshakeState::Processing)
                ) {
                    self.transition_to(
                        ProtocolState::Handshake(HandshakeState::Completed),
                        "Handshake completed".to_string(),
                    )?;
                    self.transition_to(
                        ProtocolState::Active(ActiveState::Normal),
                        "Handshake successful, entering active state".to_string(),
                    )?;
                }
            }
            HandshakeType::VersionNegotiation => {
                debug!("Processing version negotiation");
            }
        }

        if let Some(id) = session_id {
            if let Some(session) = self.sessions.get_mut(&id) {
                session.state = self.current_state.clone();
            }
        }
        Ok(())
    }

    /// Returns the retained transition history, oldest first.
    pub fn get_state_history(&self) -> &[StateTransition] {
        &self.state_history
    }

    /// Returns all tracked sessions keyed by id.
    pub fn get_sessions(&self) -> &HashMap<Uuid, SessionInfo> {
        &self.sessions
    }

    /// Returns `false` while the machine is in an `Error(_)` state or has shut down.
    pub fn is_healthy(&self) -> bool {
        !matches!(
            self.current_state,
            ProtocolState::Error(_) | ProtocolState::Shutdown
        )
    }

    /// Builds a snapshot of machine-level metrics plus per-session totals.
    ///
    /// `total_state_transitions` counts only the *retained* history entries, so it is capped
    /// by `config.max_history_size`.
    pub fn get_metrics(&self) -> StateMachineMetrics {
        let mut metrics = StateMachineMetrics {
            current_state: self.current_state.clone(),
            uptime: self.uptime(),
            active_sessions: self.sessions.len(),
            total_state_transitions: self.state_history.len(),
            ..StateMachineMetrics::default()
        };
        for session in self.sessions.values() {
            let m = &session.metrics;
            metrics.total_messages_sent += m.messages_sent;
            metrics.total_messages_received += m.messages_received;
            metrics.total_bytes_sent += m.bytes_sent;
            metrics.total_bytes_received += m.bytes_received;
            metrics.total_errors += m.error_count;
        }
        metrics
    }
}
/// Point-in-time metrics snapshot produced by `ProtocolStateMachine::get_metrics`.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct StateMachineMetrics {
    /// Machine state at snapshot time.
    pub current_state: ProtocolState,
    /// Time since the machine was constructed.
    pub uptime: Duration,
    /// Number of tracked sessions.
    pub active_sessions: usize,
    /// Number of retained transition-history entries.
    pub total_state_transitions: usize,
    /// Sum of `messages_sent` across tracked sessions.
    pub total_messages_sent: u64,
    /// Sum of `messages_received` across tracked sessions.
    pub total_messages_received: u64,
    /// Sum of `bytes_sent` across tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` across tracked sessions.
    pub total_bytes_received: u64,
    /// Sum of `error_count` across tracked sessions.
    pub total_errors: u64,
}
/// Minimal state-management interface over a [`ProtocolStateMachine`].
pub trait StateManager {
    /// Constructs a fresh state machine.
    fn init() -> Result<ProtocolStateMachine, StateError>;

    /// Requests a transition to `new_state`.
    fn transition(&mut self, new_state: ProtocolState) -> Result<(), StateError>;

    /// Returns the current state.
    fn get_state(&self) -> &ProtocolState;

    /// Reports whether moving from the current state to `new_state` would be allowed.
    fn validate_transition(&self, new_state: &ProtocolState) -> bool;
}
impl StateManager for ProtocolStateMachine {
    /// Builds a machine for `ProtocolVersion::CURRENT` with the default configuration.
    fn init() -> Result<ProtocolStateMachine, StateError> {
        let machine = Self::new(ProtocolVersion::CURRENT);
        Ok(machine)
    }

    /// Delegates to `transition_to` with a fixed reason string.
    fn transition(&mut self, new_state: ProtocolState) -> Result<(), StateError> {
        let reason = String::from("Manual transition");
        self.transition_to(new_state, reason)
    }

    fn get_state(&self) -> &ProtocolState {
        self.current_state()
    }

    /// Checks `current -> new_state` against the machine's transition graph.
    fn validate_transition(&self, new_state: &ProtocolState) -> bool {
        let current = self.current_state();
        self.is_valid_transition(current, new_state)
    }
}