use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::error::{ConnectError, DisconnectReason};
/// Lifecycle phase of a connection as tracked by `ConnectionState`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionStatus {
    /// Initial phase; also entered via `mark_disconnected` and `reset`.
    Disconnected,
    /// Entered via `mark_connecting`.
    Connecting,
    /// Entered via `mark_connected`.
    Connected,
    /// Entered via `mark_reconnecting`.
    Reconnecting,
    /// Entered via `mark_shutting_down`.
    ShuttingDown,
}
impl std::fmt::Display for ConnectionStatus {
    /// Writes the lowercase, snake_case label of the status
    /// (e.g. `shutting_down`). Width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Disconnected => "disconnected",
            Self::Connecting => "connecting",
            Self::Connected => "connected",
            Self::Reconnecting => "reconnecting",
            Self::ShuttingDown => "shutting_down",
        };
        f.write_str(label)
    }
}
/// Connection lifecycle bookkeeping: status, connection ids, counters and
/// timestamps.
///
/// Scalar values are stored in atomics so they can be read synchronously;
/// timestamps and the last error / disconnect reason sit behind async
/// `RwLock`s.
pub struct ConnectionState {
    /// Encoded `ConnectionStatus` (one of the `STATUS_*` associated constants).
    status: AtomicU8,
    /// Raised by `mark_shutting_down`; cleared by `reset`.
    shutdown_requested: AtomicBool,
    /// Id of the most recently marked connection; `0` means none yet.
    id: AtomicU64,
    /// Source of connection ids; incremented once per `mark_connected`.
    id_counter: AtomicU64,
    /// Number of `mark_reconnecting` calls recorded.
    reconnect_count: AtomicU64,
    /// Number of `record_error` calls recorded.
    error_count: AtomicU64,
    /// Instant the current connection was marked connected.
    connected_at: RwLock<Option<Instant>>,
    /// Instant of the most recently recorded activity.
    last_activity: RwLock<Option<Instant>>,
    /// Most recently recorded error.
    last_error: RwLock<Option<ConnectError>>,
    /// Reason passed to the most recent `mark_disconnected`.
    last_disconnect: RwLock<Option<DisconnectReason>>,
}
impl ConnectionState {
    // Raw `u8` encodings of `ConnectionStatus` held in the `status` atomic;
    // decoded by `status()`.
    const STATUS_DISCONNECTED: u8 = 0;
    const STATUS_CONNECTING: u8 = 1;
    const STATUS_CONNECTED: u8 = 2;
    const STATUS_RECONNECTING: u8 = 3;
    const STATUS_SHUTTING_DOWN: u8 = 4;

    /// Creates a fresh state: `Disconnected`, id `0`, all counters zero, no
    /// shutdown request, and no recorded timestamps, error or disconnect reason.
    #[must_use]
    pub fn new() -> Self {
        Self {
            id: AtomicU64::new(0),
            id_counter: AtomicU64::new(0),
            status: AtomicU8::new(Self::STATUS_DISCONNECTED),
            reconnect_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
            shutdown_requested: AtomicBool::new(false),
            connected_at: RwLock::new(None),
            last_activity: RwLock::new(None),
            last_error: RwLock::new(None),
            last_disconnect: RwLock::new(None),
        }
    }

    /// Returns the id handed out by the most recent
    /// [`mark_connected`](Self::mark_connected), or `0` if none has happened
    /// since construction or the last [`reset`](Self::reset).
    #[must_use]
    pub fn id(&self) -> u64 {
        self.id.load(Ordering::Acquire)
    }

    /// Returns the current status. Unrecognized raw values decode as
    /// [`ConnectionStatus::Disconnected`].
    #[must_use]
    pub fn status(&self) -> ConnectionStatus {
        match self.status.load(Ordering::Acquire) {
            Self::STATUS_CONNECTING => ConnectionStatus::Connecting,
            Self::STATUS_CONNECTED => ConnectionStatus::Connected,
            Self::STATUS_RECONNECTING => ConnectionStatus::Reconnecting,
            Self::STATUS_SHUTTING_DOWN => ConnectionStatus::ShuttingDown,
            _ => ConnectionStatus::Disconnected,
        }
    }

    /// `true` iff the current status is [`ConnectionStatus::Connected`].
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.status() == ConnectionStatus::Connected
    }

    /// `true` once [`mark_shutting_down`](Self::mark_shutting_down) has been
    /// called; cleared by [`reset`](Self::reset).
    #[must_use]
    pub fn is_shutdown_requested(&self) -> bool {
        self.shutdown_requested.load(Ordering::Acquire)
    }

    /// Number of [`mark_reconnecting`](Self::mark_reconnecting) calls since
    /// construction or the last reset.
    #[must_use]
    pub fn reconnect_count(&self) -> u64 {
        self.reconnect_count.load(Ordering::Relaxed)
    }

    /// Number of [`record_error`](Self::record_error) calls since construction
    /// or the last reset.
    #[must_use]
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }

    /// Sets the status to `Connecting`, overwriting any previous status
    /// (including `ShuttingDown`).
    pub fn mark_connecting(&self) {
        self.status
            .store(Self::STATUS_CONNECTING, Ordering::Release);
    }

    /// Sets the status to `Reconnecting` and increments the reconnect counter.
    pub fn mark_reconnecting(&self) {
        self.status
            .store(Self::STATUS_RECONNECTING, Ordering::Release);
        self.reconnect_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Marks the connection as established and returns its newly allocated id.
    ///
    /// Ids come from an internal counter whose first value is `1`; the counter
    /// is not rewound by [`reset`](Self::reset). `connected_at` and
    /// `last_activity` are both set to the same instant.
    pub async fn mark_connected(&self) -> u64 {
        let new_id = self.id_counter.fetch_add(1, Ordering::Relaxed) + 1;
        let now = Instant::now();
        // Record the timestamps *before* publishing `Connected`. Each
        // `write().await` may suspend, and publishing the status first let a
        // concurrent `is_healthy`/`snapshot` observe `Connected` together with a
        // missing or stale `last_activity` (left over from a previous
        // connection, since `mark_disconnected` does not clear it).
        *self.connected_at.write().await = Some(now);
        *self.last_activity.write().await = Some(now);
        // Release stores: a reader that observes the new status via an Acquire
        // load also observes the new id.
        self.id.store(new_id, Ordering::Release);
        self.status.store(Self::STATUS_CONNECTED, Ordering::Release);
        new_id
    }

    /// Sets the status to `Disconnected`, clears the connection start time and
    /// records `reason` as the last disconnect reason. `last_activity` is left
    /// untouched.
    pub async fn mark_disconnected(&self, reason: DisconnectReason) {
        self.status
            .store(Self::STATUS_DISCONNECTED, Ordering::Release);
        *self.connected_at.write().await = None;
        *self.last_disconnect.write().await = Some(reason);
    }

    /// Raises the shutdown flag and sets the status to `ShuttingDown`.
    ///
    /// The flag stays raised until [`reset`](Self::reset), but later `mark_*`
    /// calls overwrite the status unconditionally.
    pub fn mark_shutting_down(&self) {
        self.shutdown_requested.store(true, Ordering::Release);
        self.status
            .store(Self::STATUS_SHUTTING_DOWN, Ordering::Release);
    }

    /// Records the current instant as the time of the last activity.
    pub async fn update_activity(&self) {
        *self.last_activity.write().await = Some(Instant::now());
    }

    /// Increments the error counter and stores `error` as the last error.
    pub async fn record_error(&self, error: ConnectError) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
        *self.last_error.write().await = Some(error);
    }

    /// Returns `true` when connected and the last recorded activity is strictly
    /// younger than `timeout`. A zero `timeout` therefore never reports healthy.
    pub async fn is_healthy(&self, timeout: Duration) -> bool {
        if !self.is_connected() {
            return false;
        }
        let last_activity = self.last_activity.read().await;
        last_activity.is_some_and(|time| time.elapsed() < timeout)
    }

    /// Time elapsed since the connection start time was recorded, or `None`
    /// when no start time is recorded (never connected, disconnected, or reset).
    pub async fn connection_duration(&self) -> Option<Duration> {
        let connected_at = self.connected_at.read().await;
        connected_at.map(|t| t.elapsed())
    }

    /// Captures a copy of every tracked value.
    ///
    /// Each field is read separately, so the snapshot is not atomic with
    /// respect to concurrent updates. `connection_duration` is measured at the
    /// moment of the snapshot.
    pub async fn snapshot(&self) -> ConnectionSnapshot {
        let connected_at = *self.connected_at.read().await;
        let last_activity = *self.last_activity.read().await;
        let last_error = self.last_error.read().await.clone();
        let last_disconnect = self.last_disconnect.read().await.clone();
        ConnectionSnapshot {
            id: self.id(),
            status: self.status(),
            connected_at,
            last_activity,
            reconnect_count: self.reconnect_count(),
            error_count: self.error_count(),
            last_error,
            last_disconnect,
            connection_duration: connected_at.map(|t| t.elapsed()),
        }
    }

    /// Restores the state produced by [`new`](Self::new), except for the
    /// internal id counter: ids allocated after a reset keep increasing.
    pub async fn reset(&self) {
        self.id.store(0, Ordering::Release);
        self.status
            .store(Self::STATUS_DISCONNECTED, Ordering::Release);
        self.reconnect_count.store(0, Ordering::Relaxed);
        self.error_count.store(0, Ordering::Relaxed);
        self.shutdown_requested.store(false, Ordering::Release);
        *self.connected_at.write().await = None;
        *self.last_activity.write().await = None;
        *self.last_error.write().await = None;
        *self.last_disconnect.write().await = None;
    }
}
impl Default for ConnectionState {
fn default() -> Self {
Self::new()
}
}
/// Owned copy of the values tracked by `ConnectionState`, as returned by
/// `ConnectionState::snapshot`.
#[derive(Clone, Debug)]
pub struct ConnectionSnapshot {
    /// Connection id at capture time (`0` if none).
    pub id: u64,
    /// Status at capture time.
    pub status: ConnectionStatus,
    /// Instant the connection was marked connected, if recorded.
    pub connected_at: Option<Instant>,
    /// Instant of the most recently recorded activity, if any.
    pub last_activity: Option<Instant>,
    /// Recorded reconnect count.
    pub reconnect_count: u64,
    /// Recorded error count.
    pub error_count: u64,
    /// Most recently recorded error, if any.
    pub last_error: Option<ConnectError>,
    /// Most recent disconnect reason, if any.
    pub last_disconnect: Option<DisconnectReason>,
    /// Time since `connected_at`, measured when the snapshot was taken.
    pub connection_duration: Option<Duration>,
}
impl ConnectionSnapshot {
    /// `true` iff the captured status was [`ConnectionStatus::Connected`].
    #[must_use]
    pub const fn is_connected(&self) -> bool {
        matches!(self.status, ConnectionStatus::Connected)
    }

    /// Fraction of the time elapsed since `since` that is covered by the
    /// captured `connection_duration`, clamped to `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no time has elapsed since `since` (on current Rust
    /// versions `Instant::elapsed` saturates to zero for future instants) or
    /// when the snapshot has no connection duration.
    ///
    /// Note that `connection_duration` was measured when the snapshot was
    /// taken, while the denominator is measured now.
    #[must_use]
    pub fn uptime_ratio(&self, since: Instant) -> f64 {
        let total_duration = since.elapsed();
        if total_duration.is_zero() {
            return 0.0;
        }
        let connected_duration = self.connection_duration.unwrap_or(Duration::ZERO);
        // A connection that began before `since` would otherwise yield a
        // "ratio" above 1.0; clamp so the result is always a fraction.
        (connected_duration.as_secs_f64() / total_duration.as_secs_f64()).min(1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_connection_state_lifecycle() {
        let state = ConnectionState::new();
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());
        state.mark_connecting();
        assert_eq!(state.status(), ConnectionStatus::Connecting);
        let id = state.mark_connected().await;
        assert_eq!(id, 1);
        assert_eq!(state.status(), ConnectionStatus::Connected);
        assert!(state.is_connected());
        state.mark_disconnected(DisconnectReason::Normal).await;
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());
    }

    #[tokio::test]
    async fn test_connection_state_snapshot() {
        let state = ConnectionState::new();
        state.mark_connected().await;
        let snapshot = state.snapshot().await;
        assert!(snapshot.is_connected());
        assert_eq!(snapshot.id, 1);
        assert!(snapshot.connected_at.is_some());
    }

    #[tokio::test]
    async fn test_reconnect_counting() {
        let state = ConnectionState::new();
        state.mark_reconnecting();
        assert_eq!(state.reconnect_count(), 1);
        state.mark_reconnecting();
        assert_eq!(state.reconnect_count(), 2);
    }

    /// Each `mark_connected` allocates the next id, and `id()` tracks it.
    #[tokio::test]
    async fn test_ids_increase_per_connection() {
        let state = ConnectionState::new();
        assert_eq!(state.id(), 0);
        let first = state.mark_connected().await;
        state.mark_disconnected(DisconnectReason::Normal).await;
        let second = state.mark_connected().await;
        assert_eq!((first, second), (1, 2));
        assert_eq!(state.id(), 2);
    }

    /// Disconnecting clears the start time and records the reason.
    #[tokio::test]
    async fn test_disconnect_clears_connection_start() {
        let state = ConnectionState::new();
        state.mark_connected().await;
        assert!(state.connection_duration().await.is_some());
        state.mark_disconnected(DisconnectReason::Normal).await;
        assert!(state.connection_duration().await.is_none());
        assert!(state.snapshot().await.last_disconnect.is_some());
    }

    /// Health requires the connected status and activity younger than the timeout.
    #[tokio::test]
    async fn test_health_requires_connection_and_recent_activity() {
        let state = ConnectionState::new();
        assert!(!state.is_healthy(Duration::from_secs(60)).await);
        state.mark_connected().await;
        assert!(state.is_healthy(Duration::from_secs(60)).await);
        // `elapsed() < 0` can never hold, so a zero timeout is never healthy.
        assert!(!state.is_healthy(Duration::ZERO).await);
        state.mark_disconnected(DisconnectReason::Normal).await;
        assert!(!state.is_healthy(Duration::from_secs(60)).await);
    }

    #[tokio::test]
    async fn test_shutdown_sets_flag_and_status() {
        let state = ConnectionState::new();
        assert!(!state.is_shutdown_requested());
        state.mark_shutting_down();
        assert!(state.is_shutdown_requested());
        assert_eq!(state.status(), ConnectionStatus::ShuttingDown);
        assert!(!state.is_connected());
    }

    #[tokio::test]
    async fn test_reset_restores_initial_state() {
        let state = ConnectionState::new();
        state.mark_reconnecting();
        state.mark_connected().await;
        state.mark_shutting_down();
        state.reset().await;
        assert_eq!(state.id(), 0);
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert_eq!(state.reconnect_count(), 0);
        assert_eq!(state.error_count(), 0);
        assert!(!state.is_shutdown_requested());
        assert!(state.connection_duration().await.is_none());
        let snapshot = state.snapshot().await;
        assert!(snapshot.last_activity.is_none());
        assert!(snapshot.last_disconnect.is_none());
        assert!(snapshot.last_error.is_none());
    }
}