// stream-tungstenite 0.6.1
//
// A streaming implementation of the Tungstenite WebSocket protocol.
//! Extension host - manages extension lifecycle and execution.

use std::sync::Arc;
use tokio::sync::RwLock;
use tungstenite::Message;

use super::traits::{BoxExtension, Extension};
use crate::context::ConnectionContext;
use crate::error::ExtensionError;

/// Extension host - manages registered extensions
///
/// Owns the list of registered extensions (kept in the order they were registered) and the
/// [`ConnectionContext`] that is handed to their callbacks. Both are stored behind
/// `Arc<RwLock<..>>` using tokio's async `RwLock`, so all access goes through the async methods
/// on this type.
pub struct ExtensionHost {
    /// Registered extensions, in the order `register` pushed them
    extensions: Arc<RwLock<Vec<Arc<BoxExtension>>>>,
    /// Connection context passed (by reference) to extension callbacks
    context: Arc<RwLock<ConnectionContext>>,
}

impl ExtensionHost {
    /// Create a new extension host
    #[must_use]
    pub fn new() -> Self {
        Self {
            extensions: Arc::new(RwLock::new(Vec::new())),
            context: Arc::new(RwLock::new(ConnectionContext::default())),
        }
    }

    /// Create a new extension host with initial connection ID
    #[must_use]
    pub fn with_connection_id(connection_id: u64) -> Self {
        Self {
            extensions: Arc::new(RwLock::new(Vec::new())),
            context: Arc::new(RwLock::new(ConnectionContext::new(connection_id))),
        }
    }

    /// Register an extension
    pub async fn register<E: Extension + 'static>(
        &self,
        extension: E,
    ) -> Result<(), ExtensionError> {
        let ext_name = extension.name().to_string();
        let boxed: BoxExtension = Box::new(extension);

        // Initialize the extension
        {
            let ctx = self.context.read().await;
            boxed
                .on_init(&ctx)
                .await
                .map_err(|e| ExtensionError::InitFailed {
                    name: ext_name.clone(),
                    message: e.to_string(),
                })?;
        }

        // Add to registry
        self.extensions.write().await.push(Arc::new(boxed));

        tracing::info!(extension = %ext_name, "Extension registered");
        Ok(())
    }

    /// Get the number of registered extensions
    pub async fn extension_count(&self) -> usize {
        self.extensions.read().await.len()
    }

    /// Get extension names
    pub async fn extension_names(&self) -> Vec<String> {
        let extensions = self.extensions.read().await;
        extensions
            .iter()
            .map(|e| e.as_ref().name().to_string())
            .collect()
    }

    /// Update the context
    pub async fn update_context(&self, connection_id: u64, reconnect_count: u64) {
        let mut ctx = self.context.write().await;
        ctx.update(connection_id, reconnect_count);
    }

    /// Add metadata to context
    pub async fn add_metadata(&self, key: impl Into<String>, value: impl Into<String>) {
        let mut ctx = self.context.write().await;
        ctx.put_metadata(key, value);
    }

    /// Notify all lifecycle-aware extensions of connection
    ///
    /// Calls `on_connect` on all registered extensions that opt into lifecycle callbacks. Continues
    /// even if individual extensions fail. Extension failures are logged but not propagated.
    pub async fn notify_connect(&self) -> Result<(), ExtensionError> {
        // Snapshot then release locks to avoid holding multiple locks across awaits
        let extensions: Vec<_> = { self.extensions.read().await.clone() };
        let ctx = { self.context.read().await.clone() };

        for ext in extensions {
            if ext.as_ref().handles_lifecycle() {
                if let Err(e) = ext.as_ref().on_connect(&ctx).await {
                    tracing::warn!(
                        extension = ext.as_ref().name(),
                        error = ?e,
                        "Extension on_connect failed"
                    );
                    // Continue with other extensions
                }
            }
        }

        Ok(())
    }

    /// Notify all lifecycle-aware extensions of disconnection
    ///
    /// Calls `on_disconnect` on all registered extensions that opt into lifecycle callbacks. Continues
    /// even if individual extensions fail. Extension failures are logged but not propagated.
    pub async fn notify_disconnect(&self) -> Result<(), ExtensionError> {
        // Snapshot then release locks to avoid holding multiple locks across awaits
        let extensions: Vec<_> = { self.extensions.read().await.clone() };
        let ctx = { self.context.read().await.clone() };

        for ext in extensions {
            if ext.as_ref().handles_lifecycle() {
                if let Err(e) = ext.as_ref().on_disconnect(&ctx).await {
                    tracing::warn!(
                        extension = ext.as_ref().name(),
                        error = ?e,
                        "Extension on_disconnect failed"
                    );
                }
            }
        }

        Ok(())
    }

    /// Process a message through all message extensions
    ///
    /// Accepts a reference to the message for zero-copy efficiency. Extensions can inspect
    /// the message without taking ownership, and may return a transformed copy if needed.
    ///
    /// Returns `Some(message)` if the message should be delivered (original or transformed),
    /// `None` if it was filtered out by an extension.
    pub async fn process_message(
        &self,
        message: &Message,
    ) -> Result<Option<Message>, ExtensionError> {
        // Snapshot then release locks to avoid holding multiple locks across awaits
        let extensions: Vec<_> = { self.extensions.read().await.clone() };
        let ctx = { self.context.read().await.clone() };

        let mut current_message: Option<Message> = None;

        for ext in extensions {
            if ext.as_ref().handles_messages() {
                // Pass reference to the message (or last transformation)
                let msg_ref = current_message.as_ref().unwrap_or(message);

                match ext.as_ref().on_message(msg_ref, &ctx).await {
                    Ok(Some(transformed)) => {
                        current_message = Some(transformed);
                    }
                    Ok(None) => {
                        // Message filtered out
                        tracing::trace!(extension = ext.as_ref().name(), "Message filtered");
                        return Ok(None);
                    }
                    Err(e) => {
                        tracing::warn!(
                            extension = ext.as_ref().name(),
                            error = ?e,
                            "Extension on_message failed"
                        );
                        return Err(e);
                    }
                }
            }
        }

        // Return transformed message or clone of original
        Ok(Some(current_message.unwrap_or_else(|| message.clone())))
    }

    /// Shutdown all extensions
    ///
    /// Calls `on_shutdown` on all registered extensions. Continues even if
    /// individual extensions fail. Extension failures are logged but not propagated.
    pub async fn shutdown(&self) -> Result<(), ExtensionError> {
        // Snapshot then release locks to avoid holding multiple locks across awaits
        let extensions: Vec<_> = { self.extensions.read().await.clone() };
        let ctx = { self.context.read().await.clone() };

        for ext in extensions {
            if let Err(e) = ext.as_ref().on_shutdown(&ctx).await {
                tracing::warn!(
                    extension = ext.as_ref().name(),
                    error = ?e,
                    "Extension on_shutdown failed"
                );
            }
        }

        Ok(())
    }
}

impl Default for ExtensionHost {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Extension that opts into every callback but relies on the trait's default behavior.
    struct TestExtension {
        name: &'static str,
    }

    #[async_trait]
    impl Extension for TestExtension {
        fn name(&self) -> &'static str {
            self.name
        }

        fn handles_lifecycle(&self) -> bool {
            true
        }

        fn handles_messages(&self) -> bool {
            true
        }
    }

    /// Message extension that filters out every message it sees.
    struct DropAll;

    #[async_trait]
    impl Extension for DropAll {
        fn name(&self) -> &'static str {
            "drop-all"
        }

        fn handles_lifecycle(&self) -> bool {
            false
        }

        fn handles_messages(&self) -> bool {
            true
        }

        async fn on_message(
            &self,
            _message: &Message,
            _ctx: &ConnectionContext,
        ) -> Result<Option<Message>, ExtensionError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn test_extension_registration() {
        let host = ExtensionHost::new();

        host.register(TestExtension { name: "test" }).await.unwrap();

        assert_eq!(host.extension_count().await, 1);
        assert_eq!(host.extension_names().await, vec!["test"]);
    }

    #[tokio::test]
    async fn test_registration_preserves_order() {
        let host = ExtensionHost::new();

        host.register(TestExtension { name: "first" }).await.unwrap();
        host.register(TestExtension { name: "second" }).await.unwrap();

        assert_eq!(host.extension_count().await, 2);
        assert_eq!(host.extension_names().await, vec!["first", "second"]);
    }

    #[tokio::test]
    async fn test_context_update() {
        let host = ExtensionHost::new();
        host.update_context(42, 5).await;

        // The guard is released at the end of the test body.
        let ctx = host.context.read().await;
        assert_eq!(ctx.connection_id, 42);
        assert_eq!(ctx.reconnect_count, 5);
        assert!(ctx.is_reconnection);
    }

    #[tokio::test]
    async fn test_lifecycle_notifications_are_best_effort() {
        let host = ExtensionHost::new();
        host.register(TestExtension { name: "test" }).await.unwrap();

        assert!(host.notify_connect().await.is_ok());
        assert!(host.notify_disconnect().await.is_ok());
        assert!(host.shutdown().await.is_ok());
    }

    #[tokio::test]
    async fn test_process_message_without_extensions_passes_through() {
        let host = ExtensionHost::new();
        let msg = Message::text("hello");

        assert_eq!(host.process_message(&msg).await.unwrap(), Some(msg));
    }

    #[tokio::test]
    async fn test_process_message_filtered() {
        let host = ExtensionHost::new();
        host.register(DropAll).await.unwrap();
        let msg = Message::text("hello");

        assert_eq!(host.process_message(&msg).await.unwrap(), None);
    }
}