stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Connection state management.

use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

use crate::error::{ConnectError, DisconnectReason};

/// High-level lifecycle state of a connection.
///
/// Rendered by `Display` as a lowercase `snake_case` label
/// (for example `shutting_down`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// No connection is open.
    Disconnected,
    /// An initial connection attempt is in progress.
    Connecting,
    /// The connection is established and usable.
    Connected,
    /// A new connection is being attempted after a disconnect.
    Reconnecting,
    /// Shutdown has been requested and is in progress.
    ShuttingDown,
}

impl std::fmt::Display for ConnectionStatus {
    /// Writes the status label.
    ///
    /// Uses [`std::fmt::Formatter::pad`] so width, fill, alignment and
    /// precision flags (e.g. `{:>12}`) are honoured, which `write!` with a
    /// bare literal would silently ignore.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Disconnected => "disconnected",
            Self::Connecting => "connecting",
            Self::Connected => "connected",
            Self::Reconnecting => "reconnecting",
            Self::ShuttingDown => "shutting_down",
        };
        f.pad(label)
    }
}

/// Internal, shareable record of a connection's lifecycle.
///
/// Status, counters and the shutdown flag are atomics and can be read
/// synchronously. Timestamps and the last error/disconnect reason sit
/// behind async `RwLock`s and require `.await` to access.
pub struct ConnectionState {
    /// ID of the current connection (0 before the first connection or after a reset).
    id: AtomicU64,
    /// Counter from which new connection IDs are drawn.
    id_counter: AtomicU64,
    /// Current `ConnectionStatus`, stored in its `u8` encoding.
    status: AtomicU8,
    /// How many times a reconnection has been started.
    reconnect_count: AtomicU64,
    /// Running total of recorded errors.
    error_count: AtomicU64,
    /// Whether a shutdown has been requested.
    shutdown_requested: AtomicBool,

    /// Instant the current connection was established, if any.
    connected_at: RwLock<Option<Instant>>,
    /// Instant of the most recent send/receive activity, if any.
    last_activity: RwLock<Option<Instant>>,
    /// Most recently recorded error, if any.
    last_error: RwLock<Option<ConnectError>>,
    /// Reason for the most recent disconnect, if any.
    last_disconnect: RwLock<Option<DisconnectReason>>,
}

impl ConnectionState {
    // Encoded status values stored in `status` (avoid magic numbers).
    const STATUS_DISCONNECTED: u8 = 0;
    const STATUS_CONNECTING: u8 = 1;
    const STATUS_CONNECTED: u8 = 2;
    const STATUS_RECONNECTING: u8 = 3;
    const STATUS_SHUTTING_DOWN: u8 = 4;

    /// Encode a [`ConnectionStatus`] as the `u8` stored in `status`.
    ///
    /// The match is deliberately exhaustive: adding a `ConnectionStatus`
    /// variant fails to compile until it is given an encoding here.
    const fn encode_status(status: ConnectionStatus) -> u8 {
        match status {
            ConnectionStatus::Disconnected => Self::STATUS_DISCONNECTED,
            ConnectionStatus::Connecting => Self::STATUS_CONNECTING,
            ConnectionStatus::Connected => Self::STATUS_CONNECTED,
            ConnectionStatus::Reconnecting => Self::STATUS_RECONNECTING,
            ConnectionStatus::ShuttingDown => Self::STATUS_SHUTTING_DOWN,
        }
    }

    /// Publish a new status.
    ///
    /// The `Release` store pairs with the `Acquire` load in [`Self::status`].
    fn set_status(&self, status: ConnectionStatus) {
        self.status
            .store(Self::encode_status(status), Ordering::Release);
    }

    /// Create a new connection state: disconnected, zeroed counters, no timestamps.
    #[must_use]
    pub fn new() -> Self {
        Self {
            id: AtomicU64::new(0),
            id_counter: AtomicU64::new(0),
            status: AtomicU8::new(Self::STATUS_DISCONNECTED),
            reconnect_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
            shutdown_requested: AtomicBool::new(false),
            connected_at: RwLock::new(None),
            last_activity: RwLock::new(None),
            last_error: RwLock::new(None),
            last_disconnect: RwLock::new(None),
        }
    }

    /// Get the current connection ID (0 before the first connection or after [`Self::reset`]).
    #[must_use]
    pub fn id(&self) -> u64 {
        self.id.load(Ordering::Acquire)
    }

    /// Get the current status.
    ///
    /// Any unrecognised encoded value is reported as `Disconnected`.
    #[must_use]
    pub fn status(&self) -> ConnectionStatus {
        match self.status.load(Ordering::Acquire) {
            Self::STATUS_CONNECTING => ConnectionStatus::Connecting,
            Self::STATUS_CONNECTED => ConnectionStatus::Connected,
            Self::STATUS_RECONNECTING => ConnectionStatus::Reconnecting,
            Self::STATUS_SHUTTING_DOWN => ConnectionStatus::ShuttingDown,
            // STATUS_DISCONNECTED and any invalid value => Disconnected
            _ => ConnectionStatus::Disconnected,
        }
    }

    /// Check if currently connected.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.status() == ConnectionStatus::Connected
    }

    /// Check if shutdown was requested.
    #[must_use]
    pub fn is_shutdown_requested(&self) -> bool {
        self.shutdown_requested.load(Ordering::Acquire)
    }

    /// Get the number of reconnection attempts started so far.
    #[must_use]
    pub fn reconnect_count(&self) -> u64 {
        self.reconnect_count.load(Ordering::Relaxed)
    }

    /// Get the number of errors recorded so far.
    #[must_use]
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }

    /// Mark as connecting.
    pub fn mark_connecting(&self) {
        self.set_status(ConnectionStatus::Connecting);
    }

    /// Mark as reconnecting and bump the reconnect counter.
    pub fn mark_reconnecting(&self) {
        self.set_status(ConnectionStatus::Reconnecting);
        self.reconnect_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Mark as connected and assign a fresh connection ID.
    ///
    /// The timestamps are recorded *before* the new ID and the `Connected`
    /// status are published. `last_activity` is not cleared on disconnect.
    /// If `Connected` were published first, a concurrent [`Self::is_healthy`]
    /// could judge the new connection by the previous connection's stale
    /// activity time. A concurrent [`Self::snapshot`] could also report
    /// `Connected` with no `connected_at`.
    ///
    /// Returns the new ID. IDs come from an incrementing counter that starts
    /// at 1 and is not reset by [`Self::reset`].
    pub async fn mark_connected(&self) -> u64 {
        let now = Instant::now();
        *self.connected_at.write().await = Some(now);
        *self.last_activity.write().await = Some(now);

        let new_id = self.id_counter.fetch_add(1, Ordering::Relaxed) + 1;
        self.id.store(new_id, Ordering::Release);
        self.set_status(ConnectionStatus::Connected);

        new_id
    }

    /// Mark as disconnected, clear `connected_at`, and remember `reason`.
    ///
    /// The status is withdrawn first so readers stop treating the connection
    /// as live before its bookkeeping is torn down.
    pub async fn mark_disconnected(&self, reason: DisconnectReason) {
        self.set_status(ConnectionStatus::Disconnected);
        *self.connected_at.write().await = None;
        *self.last_disconnect.write().await = Some(reason);
    }

    /// Mark as shutting down and set the shutdown-requested flag.
    pub fn mark_shutting_down(&self) {
        // Set the flag before publishing the status so that a reader who
        // observes `ShuttingDown` also observes the flag as set.
        self.shutdown_requested.store(true, Ordering::Release);
        self.set_status(ConnectionStatus::ShuttingDown);
    }

    /// Update the last activity time to now.
    pub async fn update_activity(&self) {
        *self.last_activity.write().await = Some(Instant::now());
    }

    /// Record an error: increments the error count and replaces the last error.
    pub async fn record_error(&self, error: ConnectError) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
        *self.last_error.write().await = Some(error);
    }

    /// Check if the connection is healthy.
    ///
    /// Returns `true` only when the status is `Connected` and the last
    /// activity happened strictly less than `timeout` ago.
    pub async fn is_healthy(&self, timeout: Duration) -> bool {
        if !self.is_connected() {
            return false;
        }

        let last_activity = self.last_activity.read().await;
        last_activity.is_some_and(|time| time.elapsed() < timeout)
    }

    /// Get how long the current connection has been up, or `None` if not connected.
    pub async fn connection_duration(&self) -> Option<Duration> {
        let connected_at = self.connected_at.read().await;
        connected_at.map(|t| t.elapsed())
    }

    /// Get a snapshot of the current state.
    ///
    /// Each field is read independently, so under concurrent updates the
    /// snapshot is not guaranteed to be a single consistent cut.
    pub async fn snapshot(&self) -> ConnectionSnapshot {
        let connected_at = *self.connected_at.read().await;
        let last_activity = *self.last_activity.read().await;
        let last_error = self.last_error.read().await.clone();
        let last_disconnect = self.last_disconnect.read().await.clone();

        ConnectionSnapshot {
            id: self.id(),
            status: self.status(),
            connected_at,
            last_activity,
            reconnect_count: self.reconnect_count(),
            error_count: self.error_count(),
            last_error,
            last_disconnect,
            connection_duration: connected_at.map(|t| t.elapsed()),
        }
    }

    /// Reset state (for testing).
    ///
    /// `id_counter` is intentionally left untouched. IDs assigned after a
    /// reset continue from where they left off instead of restarting at 1.
    pub async fn reset(&self) {
        self.id.store(0, Ordering::Release);
        self.set_status(ConnectionStatus::Disconnected);
        self.reconnect_count.store(0, Ordering::Relaxed);
        self.error_count.store(0, Ordering::Relaxed);
        self.shutdown_requested.store(false, Ordering::Release);
        *self.connected_at.write().await = None;
        *self.last_activity.write().await = None;
        *self.last_error.write().await = None;
        *self.last_disconnect.write().await = None;
    }
}

impl Default for ConnectionState {
    fn default() -> Self {
        Self::new()
    }
}

/// Owned, point-in-time copy of a connection's state for callers outside the module.
///
/// Built by [`ConnectionState::snapshot`]; holds no locks or atomics.
#[derive(Clone, Debug)]
pub struct ConnectionSnapshot {
    /// Connection ID at snapshot time.
    pub id: u64,
    /// Status at snapshot time.
    pub status: ConnectionStatus,
    /// Instant the connection was established, if it was.
    pub connected_at: Option<Instant>,
    /// Instant of the most recent activity, if any.
    pub last_activity: Option<Instant>,
    /// Reconnection attempts started so far.
    pub reconnect_count: u64,
    /// Errors recorded so far.
    pub error_count: u64,
    /// Most recently recorded error, if any.
    pub last_error: Option<ConnectError>,
    /// Reason for the most recent disconnect, if any.
    pub last_disconnect: Option<DisconnectReason>,
    /// How long the connection had been up when the snapshot was taken.
    pub connection_duration: Option<Duration>,
}

impl ConnectionSnapshot {
    /// Check if the snapshot was taken while connected.
    #[must_use]
    pub const fn is_connected(&self) -> bool {
        matches!(self.status, ConnectionStatus::Connected)
    }

    /// Fraction (`0.0..=1.0`) of the time since `since` that the current
    /// connection has been up.
    ///
    /// Only the current connection's `connection_duration` is counted, and
    /// it is the value captured when the snapshot was taken. Earlier
    /// connections are not included. The connected time is clamped to the
    /// observation window, so a connection that started before `since`
    /// yields `1.0` rather than a ratio above one.
    ///
    /// Returns `0.0` when no time has elapsed since `since`.
    #[must_use]
    pub fn uptime_ratio(&self, since: Instant) -> f64 {
        let total_duration = since.elapsed();
        if total_duration.is_zero() {
            return 0.0;
        }

        // A connection cannot have been up for longer than the window itself.
        let connected_duration = self
            .connection_duration
            .unwrap_or(Duration::ZERO)
            .min(total_duration);
        connected_duration.as_secs_f64() / total_duration.as_secs_f64()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_connection_state_lifecycle() {
        let state = ConnectionState::new();

        // A fresh state has no connection.
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());

        // An attempt is underway.
        state.mark_connecting();
        assert_eq!(state.status(), ConnectionStatus::Connecting);

        // The first established connection gets ID 1.
        let first_id = state.mark_connected().await;
        assert_eq!(first_id, 1);
        assert_eq!(state.status(), ConnectionStatus::Connected);
        assert!(state.is_connected());

        // A normal disconnect brings the state back to Disconnected.
        state.mark_disconnected(DisconnectReason::Normal).await;
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());
    }

    #[tokio::test]
    async fn test_connection_state_snapshot() {
        let state = ConnectionState::new();
        state.mark_connected().await;

        let snap = state.snapshot().await;
        assert!(snap.is_connected());
        assert_eq!(snap.id, 1);
        assert!(snap.connected_at.is_some());
    }

    #[tokio::test]
    async fn test_reconnect_counting() {
        let state = ConnectionState::new();

        // Each reconnect attempt bumps the counter by exactly one.
        for expected in 1..=2 {
            state.mark_reconnecting();
            assert_eq!(state.reconnect_count(), expected);
        }
    }
}