stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Status viewer extensions that track WebSocket connection status, optionally with a timestamped history.

use async_trait::async_trait;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;

use crate::context::ConnectionContext;
use crate::error::ExtensionError;
use crate::extension::Extension;

/// Connection state of a WebSocket stream.
///
/// A stream is always in exactly one of these two states.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The stream currently has an open WebSocket connection.
    Connected,
    /// The stream currently has no open WebSocket connection.
    Disconnected,
}

/// Status viewer extension that tracks the current WebSocket connection status.
///
/// The status is stored behind an `Arc<Mutex<_>>` so that other tasks can share
/// it through [`StatusViewer::status_handle`].
#[derive(Debug)]
pub struct StatusViewer {
    /// Latest status recorded by the lifecycle hooks.
    current: Arc<Mutex<ConnectionStatus>>,
}

impl StatusViewer {
    /// Builds a viewer whose initial status is [`ConnectionStatus::Disconnected`].
    #[must_use]
    pub fn new() -> Self {
        let initial = ConnectionStatus::Disconnected;
        Self {
            current: Arc::new(Mutex::new(initial)),
        }
    }

    /// Returns a copy of the status as it was when the lock was acquired.
    pub async fn current_status(&self) -> ConnectionStatus {
        self.current.lock().await.clone()
    }

    /// Returns another handle to the same shared status cell.
    ///
    /// Reads and writes through the returned handle affect this viewer's status.
    #[must_use]
    pub fn status_handle(&self) -> Arc<Mutex<ConnectionStatus>> {
        Arc::clone(&self.current)
    }

    /// Returns `true` when the recorded status is [`ConnectionStatus::Connected`].
    pub async fn is_connected(&self) -> bool {
        *self.current.lock().await == ConnectionStatus::Connected
    }
}

impl Default for StatusViewer {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Extension for StatusViewer {
    /// Identifier reported for this extension.
    fn name(&self) -> &'static str {
        "status_viewer"
    }

    /// Version string reported for this extension.
    fn version(&self) -> &'static str {
        "2.0.0"
    }

    /// Human-readable summary of what this extension does.
    fn description(&self) -> &'static str {
        "Tracks current WebSocket connection status"
    }

    /// Declares interest in connect/disconnect lifecycle events.
    fn handles_lifecycle(&self) -> bool {
        true
    }

    /// Declares no interest in message traffic.
    fn handles_messages(&self) -> bool {
        false
    }

    /// Logs the connection at debug level and records `Connected`.
    async fn on_connect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            connection_id = ctx.connection_id,
            reconnect_count = ctx.reconnect_count,
            "StatusViewer: connected"
        );
        let mut slot = self.current.lock().await;
        *slot = ConnectionStatus::Connected;
        drop(slot);
        Ok(())
    }

    /// Logs the disconnection at debug level and records `Disconnected`.
    async fn on_disconnect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            connection_id = ctx.connection_id,
            "StatusViewer: disconnected"
        );
        let mut slot = self.current.lock().await;
        *slot = ConnectionStatus::Disconnected;
        drop(slot);
        Ok(())
    }
}

/// Enhanced status viewer with connection history and simple metrics.
///
/// It keeps a bounded, timestamped history of status changes, a total connect
/// counter and the start time of the current connection. All mutable bookkeeping
/// sits behind one lock, so the current status, the history and the counters are
/// always updated together. Readers never observe them out of sync.
#[derive(Debug)]
pub struct AdvancedStatusViewer {
    /// All mutable state, guarded by a single mutex.
    state: Mutex<HistoryState>,
    /// Maximum number of history entries retained. The oldest entries are
    /// evicted first, and a limit of `0` disables history retention.
    max_history: usize,
}

/// Mutable state of an [`AdvancedStatusViewer`].
#[derive(Debug)]
struct HistoryState {
    /// Most recently recorded status.
    current: ConnectionStatus,
    /// Recorded status changes, oldest first, capped at `max_history`.
    /// A `VecDeque` makes evicting the oldest entry O(1).
    history: std::collections::VecDeque<(Instant, ConnectionStatus)>,
    /// When the current connection was recorded. `None` while disconnected.
    connected_since: Option<Instant>,
    /// Total number of connects observed. History eviction does not affect it.
    connection_count: usize,
}

impl AdvancedStatusViewer {
    /// History limit used by [`AdvancedStatusViewer::new`].
    pub const DEFAULT_HISTORY_LIMIT: usize = 100;

    /// Create a new advanced status viewer with the default history limit
    /// ([`Self::DEFAULT_HISTORY_LIMIT`]).
    #[must_use]
    pub fn new() -> Self {
        Self::with_history_limit(Self::DEFAULT_HISTORY_LIMIT)
    }

    /// Create a viewer that retains at most `max_history` history entries.
    ///
    /// The viewer starts out [`ConnectionStatus::Disconnected`] with an empty
    /// history and a connection count of zero.
    #[must_use]
    pub fn with_history_limit(max_history: usize) -> Self {
        Self {
            state: Mutex::new(HistoryState {
                current: ConnectionStatus::Disconnected,
                // Not pre-sized: callers may pass a very large limit.
                history: std::collections::VecDeque::new(),
                connected_since: None,
                connection_count: 0,
            }),
            max_history,
        }
    }

    /// Get current connection status.
    pub async fn current_status(&self) -> ConnectionStatus {
        self.state.lock().await.current.clone()
    }

    /// Get a snapshot of the retained connection history, oldest entry first.
    pub async fn get_history(&self) -> Vec<(Instant, ConnectionStatus)> {
        self.state.lock().await.history.iter().cloned().collect()
    }

    /// Get how long the current connection has been up.
    ///
    /// Returns `None` while disconnected, including before the first connect.
    /// This does not depend on the history limit.
    pub async fn get_uptime(&self) -> Option<std::time::Duration> {
        self.state
            .lock()
            .await
            .connected_since
            .map(|since| since.elapsed())
    }

    /// Get the total number of connects observed since the viewer was created.
    ///
    /// Unlike the history, this count is not capped by the history limit.
    pub async fn get_connection_count(&self) -> usize {
        self.state.lock().await.connection_count
    }

    /// Check if currently connected.
    pub async fn is_connected(&self) -> bool {
        matches!(self.current_status().await, ConnectionStatus::Connected)
    }

    /// Record a status change.
    ///
    /// This updates the current status and metrics and appends a timestamped
    /// history entry, evicting the oldest entries beyond the limit. Everything
    /// happens under a single lock acquisition.
    async fn set_status(&self, status: ConnectionStatus) {
        let now = Instant::now();
        let mut state = self.state.lock().await;

        match status {
            ConnectionStatus::Connected => {
                // Every connect, including a reconnect, starts a new uptime window.
                state.connected_since = Some(now);
                state.connection_count = state.connection_count.saturating_add(1);
            }
            ConnectionStatus::Disconnected => state.connected_since = None,
        }

        state.current = status.clone();
        state.history.push_back((now, status));
        // `while` rather than `if` keeps the cap correct even for a limit of 0.
        while state.history.len() > self.max_history {
            state.history.pop_front();
        }
    }
}

impl Default for AdvancedStatusViewer {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Extension for AdvancedStatusViewer {
    /// Identifier reported for this extension.
    fn name(&self) -> &'static str {
        "advanced_status_viewer"
    }

    /// Version string reported for this extension.
    fn version(&self) -> &'static str {
        "2.0.0"
    }

    /// Human-readable summary of what this extension does.
    fn description(&self) -> &'static str {
        "Advanced status viewer with connection history and metrics"
    }

    /// Declares interest in connect/disconnect lifecycle events.
    fn handles_lifecycle(&self) -> bool {
        true
    }

    /// Declares no interest in message traffic.
    fn handles_messages(&self) -> bool {
        false
    }

    /// Logs the connection at info level, then records `Connected`.
    async fn on_connect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::info!(
            connection_id = ctx.connection_id,
            reconnect_count = ctx.reconnect_count,
            "AdvancedStatusViewer: connected"
        );
        let next = ConnectionStatus::Connected;
        self.set_status(next).await;
        Ok(())
    }

    /// Logs the disconnection at info level, then records `Disconnected`.
    async fn on_disconnect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::info!(
            connection_id = ctx.connection_id,
            "AdvancedStatusViewer: disconnected"
        );
        let next = ConnectionStatus::Disconnected;
        self.set_status(next).await;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_status_viewer() {
        let viewer = StatusViewer::new();
        // A freshly created viewer reports the disconnected state.
        assert_eq!(viewer.current_status().await, ConnectionStatus::Disconnected);

        let ctx = ConnectionContext::new(1);

        viewer.on_connect(&ctx).await.expect("on_connect failed");
        assert_eq!(viewer.current_status().await, ConnectionStatus::Connected);

        viewer.on_disconnect(&ctx).await.expect("on_disconnect failed");
        assert_eq!(viewer.current_status().await, ConnectionStatus::Disconnected);
    }

    #[tokio::test]
    async fn test_advanced_status_viewer() {
        let viewer = AdvancedStatusViewer::new();
        let ctx = ConnectionContext::new(1);

        // Sequence under test: connect -> disconnect -> connect.
        viewer.on_connect(&ctx).await.expect("first on_connect failed");
        assert!(viewer.is_connected().await);
        assert_eq!(viewer.get_connection_count().await, 1);

        viewer.on_disconnect(&ctx).await.expect("on_disconnect failed");
        assert!(!viewer.is_connected().await);

        viewer.on_connect(&ctx).await.expect("second on_connect failed");
        assert_eq!(viewer.get_connection_count().await, 2);
        // Each of the three lifecycle events leaves one history entry.
        assert_eq!(viewer.get_history().await.len(), 3);
    }
}