stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Logging extension - logs connection and message events.

use async_trait::async_trait;
use tungstenite::Message;

use crate::context::ConnectionContext;
use crate::error::ExtensionError;
use crate::extension::Extension;

/// Logging level for the extension.
///
/// Variants are declared from least to most verbose, so the derived
/// `PartialOrd`/`Ord` rank them by verbosity:
/// `Error < Warn < Info < Debug < Trace`. Each level includes the output of
/// every level that compares less than it, so callers can test inclusion with
/// a comparison such as `level >= LogLevel::Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    /// Log only errors
    Error,
    /// Log warnings and errors
    Warn,
    /// Log info, warnings, and errors
    Info,
    /// Log debug and above
    Debug,
    /// Log everything
    Trace,
}

/// Configuration for the logging extension.
///
/// Start from [`LoggingConfig::new`] (equivalently [`Default::default`]) and
/// refine it with the chainable `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LoggingConfig {
    /// Configured verbosity level.
    ///
    /// NOTE(review): this field is not read anywhere in this module; confirm
    /// whether the extension is meant to gate its output on it.
    pub level: LogLevel,
    /// When `true`, individual WebSocket messages are logged as well.
    pub log_messages: bool,
    /// Longest text payload (in bytes) logged verbatim; longer text is truncated.
    pub max_message_len: usize,
    /// Prefix attached to the log events emitted by the extension.
    pub prefix: String,
}

impl Default for LoggingConfig {
    /// Info level, message logging disabled, 200-byte truncation, `"ws"` prefix.
    fn default() -> Self {
        let prefix = String::from("ws");
        Self {
            prefix,
            max_message_len: 200,
            log_messages: false,
            level: LogLevel::Info,
        }
    }
}

impl LoggingConfig {
    /// Returns a configuration populated with the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the configured log level.
    #[must_use]
    pub const fn with_level(mut self, level: LogLevel) -> Self {
        self.level = level;
        self
    }

    /// Turns on logging of individual messages.
    #[must_use]
    pub const fn with_messages(mut self) -> Self {
        self.log_messages = true;
        self
    }

    /// Sets the byte length beyond which logged text payloads are truncated.
    #[must_use]
    pub const fn with_max_len(mut self, len: usize) -> Self {
        self.max_message_len = len;
        self
    }

    /// Replaces the prefix attached to log events.
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = prefix.into();
        self
    }
}

/// Logging extension - logs connection lifecycle and optionally messages.
///
/// Behavior is driven by a [`LoggingConfig`]; [`LoggingExtension::verbose`]
/// is a preset that also logs individual messages.
#[derive(Debug, Clone)]
pub struct LoggingExtension {
    config: LoggingConfig,
}

impl LoggingExtension {
    /// Create a new logging extension with the default [`LoggingConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: LoggingConfig::default(),
        }
    }

    /// Create a logging extension with a custom config.
    #[must_use]
    pub const fn with_config(config: LoggingConfig) -> Self {
        Self { config }
    }

    /// Create a verbose logging extension: [`LogLevel::Debug`] with message
    /// logging enabled.
    #[must_use]
    pub fn verbose() -> Self {
        Self::with_config(
            LoggingConfig::new()
                .with_level(LogLevel::Debug)
                .with_messages(),
        )
    }

    /// Shorten `msg` to at most `config.max_message_len` bytes, appending
    /// `"..."` when anything was cut off.
    ///
    /// The cut point is moved back to the nearest UTF-8 character boundary, so
    /// non-ASCII text never triggers a slicing panic. The kept prefix may
    /// therefore be a few bytes shorter than `max_message_len`.
    fn truncate_message(&self, msg: &str) -> String {
        let max = self.config.max_message_len;
        if msg.len() <= max {
            return msg.to_string();
        }
        // Here `max < msg.len()`. Index 0 is always a char boundary, so this
        // loop terminates.
        let mut end = max;
        while !msg.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &msg[..end])
    }
}

impl Default for LoggingExtension {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Extension for LoggingExtension {
    fn name(&self) -> &'static str {
        "logging"
    }

    fn version(&self) -> &'static str {
        "1.0.0"
    }

    fn description(&self) -> &'static str {
        "Logs connection lifecycle and messages"
    }

    fn handles_lifecycle(&self) -> bool {
        true
    }

    fn handles_messages(&self) -> bool {
        self.config.log_messages
    }

    async fn on_init(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            prefix = %self.config.prefix,
            connection_id = ctx.connection_id,
            "Logging extension initialized"
        );
        Ok(())
    }

    async fn on_connect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        if ctx.is_reconnection {
            tracing::info!(
                prefix = %self.config.prefix,
                connection_id = ctx.connection_id,
                reconnect_count = ctx.reconnect_count,
                "Reconnected"
            );
        } else {
            tracing::info!(
                prefix = %self.config.prefix,
                connection_id = ctx.connection_id,
                "Connected"
            );
        }
        Ok(())
    }

    async fn on_disconnect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::info!(
            prefix = %self.config.prefix,
            connection_id = ctx.connection_id,
            "Disconnected"
        );
        Ok(())
    }

    async fn on_shutdown(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            prefix = %self.config.prefix,
            connection_id = ctx.connection_id,
            "Logging extension shutdown"
        );
        Ok(())
    }

    async fn on_message(
        &self,
        message: &Message,
        ctx: &ConnectionContext,
    ) -> Result<Option<Message>, ExtensionError> {
        let msg_type = match message {
            Message::Text(_) => "text",
            Message::Binary(_) => "binary",
            Message::Ping(_) => "ping",
            Message::Pong(_) => "pong",
            Message::Close(_) => "close",
            Message::Frame(_) => "frame",
        };

        let content = match message {
            Message::Text(t) => self.truncate_message(t.as_ref()),
            Message::Binary(b) => format!("<{} bytes>", b.len()),
            Message::Ping(d) | Message::Pong(d) => format!("<{} bytes>", d.len()),
            Message::Close(cf) => cf.as_ref().map_or_else(
                || "no reason".to_string(),
                |c| format!("{}: {}", c.code, c.reason),
            ),
            Message::Frame(_) => "<frame>".to_string(),
        };

        tracing::debug!(
            prefix = %self.config.prefix,
            connection_id = ctx.connection_id,
            msg_type = msg_type,
            content = %content,
            "Message received"
        );

        // Pass through unchanged (no need to clone unless modifying)
        Ok(Some(message.clone()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_logging_config() {
        let config = LoggingConfig::new()
            .with_level(LogLevel::Debug)
            .with_messages()
            .with_prefix("test");

        assert_eq!(config.level, LogLevel::Debug);
        assert!(config.log_messages);
        assert_eq!(config.prefix, "test");
    }

    /// The documented defaults: Info level, no message logging, 200-byte cap, "ws" prefix.
    #[test]
    fn test_default_config() {
        let config = LoggingConfig::default();

        assert_eq!(config.level, LogLevel::Info);
        assert!(!config.log_messages);
        assert_eq!(config.max_message_len, 200);
        assert_eq!(config.prefix, "ws");
    }

    /// `verbose()` must enable debug level and message logging together.
    #[test]
    fn test_verbose_preset() {
        let ext = LoggingExtension::verbose();

        assert_eq!(ext.config.level, LogLevel::Debug);
        assert!(ext.config.log_messages);
    }

    #[test]
    fn test_message_truncation() {
        let ext = LoggingExtension::with_config(LoggingConfig::new().with_max_len(10));

        assert_eq!(ext.truncate_message("short"), "short");
        assert_eq!(
            ext.truncate_message("this is a long message"),
            "this is a ..."
        );
        // A message of exactly `max_message_len` bytes is kept whole.
        assert_eq!(ext.truncate_message("0123456789"), "0123456789");
        // Empty input is passed through untouched.
        assert_eq!(ext.truncate_message(""), "");
    }
}