// camel-component-ws 0.6.3
//
// WebSocket component for rust-camel.
use std::time::Duration;

use camel_component_api::CamelError;

/// Optional, deserializable WebSocket settings.
///
/// Every field is an `Option`, and the derived `Default` leaves all of them `None`.
/// No serde rename attribute is applied, so fields are deserialized under their
/// snake_case names as written. By their `_ms` suffix, the timing fields hold
/// millisecond values (mirroring the `…Ms` URI parameters read by
/// [`WsEndpointConfig::from_uri`]).
///
/// NOTE(review): nothing in this file converts a `WsConfig` into a
/// [`WsEndpointConfig`]; presumably these are component-level defaults applied
/// elsewhere — confirm.
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct WsConfig {
    /// Presumably corresponds to [`WsEndpointConfig::max_connections`].
    pub max_connections: Option<u32>,
    /// Presumably corresponds to [`WsEndpointConfig::max_message_size`].
    pub max_message_size: Option<u32>,
    /// Heartbeat interval in milliseconds (see [`WsEndpointConfig::heartbeat_interval`]).
    pub heartbeat_interval_ms: Option<u64>,
    /// Idle timeout in milliseconds (see [`WsEndpointConfig::idle_timeout`]).
    pub idle_timeout_ms: Option<u64>,
    /// Connect timeout in milliseconds (see [`WsEndpointConfig::connect_timeout`]).
    pub connect_timeout_ms: Option<u64>,
    /// Response timeout in milliseconds (see [`WsEndpointConfig::response_timeout`]).
    pub response_timeout_ms: Option<u64>,
}

/// Fully resolved settings for a single WebSocket endpoint.
///
/// Usually built from an endpoint URI by [`WsEndpointConfig::from_uri`]. Any field
/// not set from the URI keeps its value from the `Default` impl.
#[derive(Debug, Clone)]
pub struct WsEndpointConfig {
    /// URI scheme; `from_uri` only accepts `"ws"` or `"wss"`.
    pub scheme: String,
    /// Host part of the URI authority; `"0.0.0.0"` when the URI names no host.
    pub host: String,
    /// Port from the URI. When the URI has none, `from_uri` falls back to 80 (`ws`)
    /// or 443 (`wss`), whereas `Default` uses 8080.
    pub port: u16,
    /// Request path; `from_uri` always yields a value starting with `/` (`"/"` when
    /// the URI has no path).
    pub path: String,
    /// `maxConnections` URI parameter (default 100). Presumably caps concurrent
    /// connections — enforcement is outside this file.
    pub max_connections: u32,
    /// `maxMessageSize` URI parameter (default 65536). Presumably in bytes — TODO confirm.
    pub max_message_size: u32,
    /// `sendToAll` URI parameter (default `false`). NOTE(review): presumably
    /// broadcasts to every connected peer — confirm against the producer.
    pub send_to_all: bool,
    /// `heartbeatIntervalMs` URI parameter (default zero). NOTE(review): zero
    /// presumably disables heartbeats — confirm.
    pub heartbeat_interval: Duration,
    /// `idleTimeoutMs` URI parameter (default zero). NOTE(review): zero presumably
    /// means no idle timeout — confirm.
    pub idle_timeout: Duration,
    /// `connectTimeoutMs` URI parameter (default 10 s).
    pub connect_timeout: Duration,
    /// `responseTimeoutMs` URI parameter (default 30 s).
    pub response_timeout: Duration,
    /// `allowOrigin` URI parameter (default `"*"`). Presumably the accepted `Origin`
    /// for incoming connections — confirm.
    pub allow_origin: String,
    /// `tlsCert` URI parameter; `None` unless given. Presumably a certificate file
    /// path — TODO confirm.
    pub tls_cert: Option<String>,
    /// `tlsKey` URI parameter; `None` unless given. Presumably a private-key file
    /// path — TODO confirm.
    pub tls_key: Option<String>,
}

impl Default for WsEndpointConfig {
    /// Returns a plain `ws` endpoint on `0.0.0.0:8080` at path `/`.
    ///
    /// Limits: 100 connections and 64 KiB (65536) maximum message size. Broadcasting
    /// is off. Heartbeat and idle timeout are zero, with a 10 s connect timeout and a
    /// 30 s response timeout. Any origin (`"*"`) is allowed, and no TLS material is set.
    fn default() -> Self {
        let zero = Duration::ZERO;
        Self {
            // Addressing.
            scheme: String::from("ws"),
            host: String::from("0.0.0.0"),
            port: 8080,
            path: String::from("/"),
            // Limits and fan-out.
            max_connections: 100,
            max_message_size: 64 * 1024,
            send_to_all: false,
            // Timing.
            heartbeat_interval: zero,
            idle_timeout: zero,
            connect_timeout: Duration::from_millis(10_000),
            response_timeout: Duration::from_millis(30_000),
            // Access control and TLS.
            allow_origin: String::from("*"),
            tls_cert: None,
            tls_key: None,
        }
    }
}

/// Server-role wrapper around an endpoint configuration.
///
/// Produced by [`WsEndpointConfig::server_config`], which stores a clone of the
/// endpoint settings in `inner`.
#[derive(Debug, Clone)]
pub struct WsServerConfig {
    /// The endpoint settings this server configuration was created from.
    pub inner: WsEndpointConfig,
}

/// Client-role wrapper around an endpoint configuration.
///
/// Produced by [`WsEndpointConfig::client_config`], which stores a clone of the
/// endpoint settings in `inner`.
#[derive(Debug, Clone)]
pub struct WsClientConfig {
    /// The endpoint settings this client configuration was created from.
    pub inner: WsEndpointConfig,
}

impl WsEndpointConfig {
    /// Builds an endpoint configuration from a `ws://` or `wss://` URI.
    ///
    /// The part after the scheme is split into host, port and path:
    /// - **host**: the authority before the port separator. An empty host becomes
    ///   `0.0.0.0`. Bracketed IPv6 literals such as `[::1]` are accepted and keep
    ///   their brackets.
    /// - **port**: the explicit port if present, otherwise `80` for `ws` and `443`
    ///   for `wss`.
    /// - **path**: everything from the first `/` after the authority, or `/`.
    ///
    /// URI parameters (as returned by `parse_uri`) override the values from
    /// [`WsEndpointConfig::default`]:
    /// - `maxConnections` and `maxMessageSize` (`u32`), and `sendToAll` (`bool`).
    /// - `heartbeatIntervalMs`, `idleTimeoutMs`, `connectTimeoutMs` and
    ///   `responseTimeoutMs` (milliseconds, `u64`).
    /// - The string parameters `allowOrigin`, `tlsCert` and `tlsKey`.
    ///
    /// A numeric or boolean parameter whose value does not parse is ignored, and the
    /// default is kept.
    ///
    /// # Errors
    ///
    /// Returns [`CamelError::EndpointCreationFailed`] in any of these cases:
    /// - the URI cannot be parsed;
    /// - the scheme is not `ws`/`wss`;
    /// - an explicit port is not a valid `u16` (e.g. `host:abc`, `host:`, `host:70000`).
    pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
        let parsed = camel_component_api::parse_uri(uri)
            .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;

        let scheme = parsed.scheme;
        if scheme != "ws" && scheme != "wss" {
            return Err(CamelError::EndpointCreationFailed(format!(
                "Invalid WebSocket scheme: {scheme}"
            )));
        }

        // Tolerate the remainder with or without a leading "//".
        let host_port_path = parsed.path;
        let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
        let (host_port, path) = match host_port_path.split_once('/') {
            Some((hp, p)) => (hp, format!("/{p}")),
            None => (host_port_path, "/".to_string()),
        };

        let default_port = if scheme == "wss" { 443 } else { 80 };
        let (host, port) = Self::split_host_port(host_port, default_port)?;

        let mut cfg = Self {
            scheme,
            host: if host.is_empty() {
                "0.0.0.0".to_string()
            } else {
                host
            },
            port,
            path,
            ..Self::default()
        };

        // Lenient parsing: a value that does not parse is skipped so the default stays.
        let params = parsed.params;
        let u32_param = |key: &str| params.get(key).and_then(|v| v.parse::<u32>().ok());
        let millis_param = |key: &str| {
            params
                .get(key)
                .and_then(|v| v.parse::<u64>().ok())
                .map(Duration::from_millis)
        };

        if let Some(v) = u32_param("maxConnections") {
            cfg.max_connections = v;
        }
        if let Some(v) = u32_param("maxMessageSize") {
            cfg.max_message_size = v;
        }
        if let Some(v) = params.get("sendToAll").and_then(|v| v.parse::<bool>().ok()) {
            cfg.send_to_all = v;
        }
        if let Some(d) = millis_param("heartbeatIntervalMs") {
            cfg.heartbeat_interval = d;
        }
        if let Some(d) = millis_param("idleTimeoutMs") {
            cfg.idle_timeout = d;
        }
        if let Some(d) = millis_param("connectTimeoutMs") {
            cfg.connect_timeout = d;
        }
        if let Some(d) = millis_param("responseTimeoutMs") {
            cfg.response_timeout = d;
        }
        if let Some(v) = params.get("allowOrigin") {
            cfg.allow_origin = v.to_string();
        }
        if let Some(v) = params.get("tlsCert") {
            cfg.tls_cert = Some(v.to_string());
        }
        if let Some(v) = params.get("tlsKey") {
            cfg.tls_key = Some(v.to_string());
        }

        Ok(cfg)
    }

    /// Splits an authority of the form `host`, `host:port`, `[v6]` or `[v6]:port`.
    ///
    /// Returns `(host, port)`, using `default_port` when no port separator is present.
    /// Brackets around an IPv6 literal are kept in the returned host.
    ///
    /// # Errors
    ///
    /// Returns [`CamelError::EndpointCreationFailed`] when a port separator is
    /// present but the text after it is not a valid `u16`.
    fn split_host_port(host_port: &str, default_port: u16) -> Result<(String, u16), CamelError> {
        // For a bracketed IPv6 literal only a ':' directly after the closing ']' can
        // start the port; the colons inside the brackets belong to the address.
        let split = match host_port.rfind(']') {
            Some(close) => host_port[close + 1..]
                .strip_prefix(':')
                .map(|raw_port| (&host_port[..=close], raw_port)),
            None => host_port.rsplit_once(':'),
        };

        match split {
            Some((host, raw_port)) => {
                let port = raw_port.parse::<u16>().map_err(|_| {
                    CamelError::EndpointCreationFailed(format!(
                        "Invalid WebSocket port '{raw_port}' in authority '{host_port}'"
                    ))
                })?;
                Ok((host.to_string(), port))
            }
            None => Ok((host_port.to_string(), default_port)),
        }
    }

    /// Returns a [`WsServerConfig`] holding a clone of this configuration.
    pub fn server_config(&self) -> WsServerConfig {
        WsServerConfig {
            inner: self.clone(),
        }
    }

    /// Returns a [`WsClientConfig`] holding a clone of this configuration.
    pub fn client_config(&self) -> WsClientConfig {
        WsClientConfig {
            inner: self.clone(),
        }
    }

    /// Returns the host with `"0.0.0.0"` and `"localhost"` rewritten to `"127.0.0.1"`.
    ///
    /// All other hosts are returned unchanged. NOTE(review): presumably used so that
    /// server and client endpoints naming the local machine map to the same key —
    /// confirm with callers.
    pub fn canonical_host(&self) -> String {
        match self.host.as_str() {
            "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
            h => h.to_string(),
        }
    }
}