Skip to main content

camel_component_ws/
lib.rs

1//! WebSocket component for rust-camel — Axum-based WebSocket server and Tokio-tungstenite client for bidirectional messaging.
2//!
3//! Main types: `WsComponent`, `WsBundle`, `WsConfig`, `WsServerConfig`, `WsClientConfig`, `WsEndpointConfig`.
4//! Main modules: `bundle`, `config`, `health`.
5
6pub mod bundle;
7pub mod config;
8pub mod health;
9
10pub use bundle::WsBundle;
11pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
12pub use health::WsHealthCheck;
13
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex, OnceLock};
16
17use async_trait::async_trait;
18use axum::body::Body;
19use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
20use axum::extract::{FromRequest, Request, State};
21use axum::http::{StatusCode, header};
22use axum::response::IntoResponse;
23use axum::{Router, serve};
24use camel_api::security_policy::AuthorizationDecision;
25use camel_api::{BackoffConfig, BackoffState};
26use camel_component_api::{
27    Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
28};
29use camel_component_api::{
30    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
31    ProducerContext,
32};
33use dashmap::DashMap;
34use futures::{SinkExt, StreamExt};
35use std::future::Future;
36use std::pin::Pin;
37use std::task::{Context, Poll};
38use tokio::sync::{OnceCell, RwLock, mpsc};
39use tokio::task::JoinHandle;
40use tokio_tungstenite::tungstenite;
41use tokio_tungstenite::tungstenite::client::IntoClientRequest;
42use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
43use tower::Service;
44
/// Per-path server settings consulted when a connection is upgraded and while it is open.
///
/// `Debug` is derived so the effective settings can be logged and inspected; all fields are
/// plain limits/timeouts/origin policy, nothing secret.
#[derive(Clone, Debug)]
pub struct WsPathConfig {
    /// Maximum number of concurrent connections on this path.
    pub max_connections: u32,
    /// Maximum accepted size (in bytes) of a single inbound text or binary frame.
    pub max_message_size: u32,
    /// Interval between server-initiated pings; `Duration::ZERO` disables heartbeats.
    pub heartbeat_interval: std::time::Duration,
    /// Close the connection when no frame arrives within this time; `Duration::ZERO` disables it.
    pub idle_timeout: std::time::Duration,
    /// Origin policy checked against the upgrade request's `Origin` header.
    pub allow_origin: String,
}
53
54impl Default for WsPathConfig {
55    fn default() -> Self {
56        let cfg = WsEndpointConfig::default();
57        Self {
58            max_connections: cfg.max_connections,
59            max_message_size: cfg.max_message_size,
60            heartbeat_interval: cfg.heartbeat_interval,
61            idle_timeout: cfg.idle_timeout,
62            allow_origin: cfg.allow_origin,
63        }
64    }
65}
66
/// Locations of the TLS certificate and private key used when serving `wss`.
///
/// Only the file paths are stored (not key material), so deriving `Debug` is safe and makes
/// the configuration loggable.
#[derive(Clone, Debug)]
pub struct WsTlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private key file.
    pub key_path: String,
}
72
/// Routing table of one server: request path → sender feeding the consumer registered for
/// that path. Shared between the server's request handlers and the consumers that register.
type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
74
/// A started server for one port.
struct ServerHandle {
    // State shared with every request handled by this server.
    state: WsAppState,
    // Whether the server was started with TLS; callers requesting the other mode are rejected.
    is_tls: bool,
    // The serving task; aborted by `ServerRegistry::release` when the last reference goes.
    _task: JoinHandle<()>,
}

/// Registry slot for one port.
struct ServerRegistryInner {
    // Filled once with the spawned server. Held in an `Arc` so callers can await
    // initialisation after releasing the registry lock.
    cell: Arc<OnceCell<ServerHandle>>,
    // Number of `get_or_spawn` references on this port not yet given back via `release`.
    ref_count: usize,
}

/// Process-wide registry of WebSocket servers, keyed by port, so that endpoints on the same
/// port share one listener.
pub struct ServerRegistry {
    inner: Mutex<HashMap<u16, ServerRegistryInner>>,
}
89
90impl ServerRegistry {
91    pub fn global() -> &'static Self {
92        static REG: OnceLock<ServerRegistry> = OnceLock::new();
93        REG.get_or_init(|| Self {
94            inner: Mutex::new(HashMap::new()),
95        })
96    }
97
98    pub async fn get_or_spawn(
99        &'static self,
100        host: &str,
101        port: u16,
102        tls_config: Option<WsTlsConfig>,
103    ) -> Result<WsAppState, CamelError> {
104        let wants_tls = tls_config.is_some();
105        let host_owned = host.to_string();
106
107        let (cell, _is_new) = {
108            let mut guard = self.inner.lock().map_err(|_| {
109                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
110            })?;
111            let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
112                cell: Arc::new(OnceCell::new()),
113                ref_count: 0,
114            });
115            entry.ref_count += 1;
116            (entry.cell.clone(), entry.ref_count == 1)
117        };
118
119        let handle = cell
120            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
121            .await?;
122
123        if wants_tls != handle.is_tls {
124            // Decrement ref count since we're rejecting this caller
125            let mut guard = self.inner.lock().map_err(|_| {
126                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
127            })?;
128            if let Some(entry) = guard.get_mut(&port) {
129                entry.ref_count -= 1;
130                if entry.ref_count == 0 {
131                    guard.remove(&port);
132                }
133            }
134            return Err(CamelError::EndpointCreationFailed(format!(
135                "Server on port {port} already running with different TLS mode"
136            )));
137        }
138
139        Ok(handle.state.clone())
140    }
141
142    /// Release a reference to the server on the given port.
143    /// When the last reference is released, the server task is aborted and the entry removed. (WS-005)
144    pub(crate) fn release(&self, port: u16) {
145        let mut guard = match self.inner.lock() {
146            Ok(g) => g,
147            Err(_) => return,
148        };
149        if let Some(entry) = guard.get_mut(&port) {
150            entry.ref_count = entry.ref_count.saturating_sub(1);
151            if entry.ref_count == 0 {
152                // Abort the server task if it exists (WS-001, WS-005)
153                if let Some(handle) = entry.cell.get() {
154                    handle._task.abort();
155                }
156                guard.remove(&port);
157                tracing::info!(port, "WebSocket server registry entry removed");
158            }
159        }
160    }
161}
162
163async fn spawn_server(
164    host: &str,
165    port: u16,
166    tls_config: Option<WsTlsConfig>,
167) -> Result<ServerHandle, CamelError> {
168    let host_owned = host.to_string();
169    let addr = format!("{host}:{port}");
170    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
171    let path_configs = Arc::new(DashMap::new());
172    let path_policies = Arc::new(DashMap::new());
173    let server_error = new_atomic_false();
174    let state = WsAppState {
175        dispatch: Arc::clone(&dispatch),
176        path_configs: Arc::clone(&path_configs),
177        path_policies: Arc::clone(&path_policies),
178        server_error: Arc::clone(&server_error),
179    };
180    let app = Router::new()
181        .fallback(dispatch_handler)
182        .with_state(state.clone());
183
184    let (task, is_tls) = if let Some(ref tls) = tls_config {
185        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
186        let parsed_addr = addr.parse().map_err(|e| {
187            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
188        })?;
189        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
190        let error_flag = Arc::clone(&server_error);
191        let task = tokio::spawn(async move {
192            if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
193                .serve(app.into_make_service())
194                .await
195            {
196                tracing::error!(
197                    host = host_owned,
198                    port = port,
199                    error = %e,
200                    "WebSocket server terminated with error"
201                );
202                error_flag.store(true, Ordering::Relaxed);
203            }
204        });
205        (task, true)
206    } else {
207        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
208            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
209        })?;
210        let error_flag = Arc::clone(&server_error);
211        let task = tokio::spawn(async move {
212            if let Err(e) = serve(listener, app).await {
213                tracing::error!(
214                    host = host_owned,
215                    port = port,
216                    error = %e,
217                    "WebSocket server terminated with error"
218                );
219                error_flag.store(true, Ordering::Relaxed);
220            }
221        });
222        (task, false)
223    };
224
225    tracing::info!(host, port, is_tls, "WebSocket server started");
226
227    Ok(ServerHandle {
228        state,
229        is_tls,
230        _task: task,
231    })
232}
233
/// State shared by every request on one server. `ServerRegistry` keeps one server per port.
#[derive(Clone)]
pub struct WsAppState {
    /// Path → sender feeding the consumer registered for that path.
    pub dispatch: DispatchTable,
    /// Per-path limits, timeouts and origin policy, read on upgrade and per connection.
    pub path_configs: Arc<DashMap<String, WsPathConfig>>,
    /// Per-path security context; a path present here requires credentials on upgrade.
    pub path_policies: Arc<DashMap<String, camel_component_api::SecurityContext>>,
    /// Set by the serving task when the server terminates with an error.
    pub server_error: Arc<AtomicBool>,
}
241
/// Live connections of one endpoint, keyed by connection key.
///
/// Each key maps to the sender feeding that connection's outbound writer task, so frames can
/// be addressed to a single connection or broadcast to all of them.
pub struct WsConnectionRegistry {
    connections: DashMap<String, mpsc::Sender<WsMessage>>,
}
245
// Process-wide connection registries. The key is a `(String, u16, String)` tuple, presumably
// (host, port, path); `ws_handler` matches the third component against the request path.
static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
> = OnceLock::new();

/// Returns the process-wide connection registries, creating the (empty) map on first use.
fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
}
253
254impl Default for WsConnectionRegistry {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260impl WsConnectionRegistry {
261    pub fn new() -> Self {
262        Self {
263            connections: DashMap::new(),
264        }
265    }
266
267    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
268        self.connections.insert(key, tx);
269    }
270
271    pub fn remove(&self, key: &str) {
272        self.connections.remove(key);
273    }
274
275    pub fn len(&self) -> usize {
276        self.connections.len()
277    }
278
279    pub fn is_empty(&self) -> bool {
280        self.connections.is_empty()
281    }
282
283    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
284        self.connections.iter().map(|e| e.value().clone()).collect()
285    }
286
287    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
288        keys.iter()
289            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
290            .collect()
291    }
292}
293
/// Fallback handler for every request on a WebSocket server. Validates the request and, if it
/// is accepted, upgrades it and hands the socket to [`ws_handler`].
///
/// Checks, in order:
/// 1. a consumer is registered for the request path in the dispatch table, else `404`;
/// 2. the `Origin` header passes the path's `allow_origin` setting, else `403`;
/// 3. if the path has a security context:
///    - a credential is found, else `401` with `WWW-Authenticate: Bearer`;
///    - it authenticates, else `401`, `503` or `500` depending on the error;
///    - the path's policy grants access, else `403`, or `500` on evaluation error;
/// 4. the request is a valid WebSocket upgrade, else `400`.
pub async fn dispatch_handler(
    State(state): State<WsAppState>,
    req: Request<Body>,
) -> impl IntoResponse {
    let path = req.uri().path().to_string();
    let origin = req
        .headers()
        .get(header::ORIGIN)
        .and_then(|value| value.to_str().ok())
        .map(str::to_string);
    // Empty unless the serving make-service inserted `ConnectInfo<SocketAddr>` into the
    // request extensions.
    let remote_addr = req
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|ci| ci.0.to_string())
        .unwrap_or_default();
    // Existence check only; the read lock is dropped before any further awaits.
    let table = state.dispatch.read().await;
    if !table.contains_key(&path) {
        return (
            StatusCode::NOT_FOUND,
            "no ws endpoint registered for this path",
        )
            .into_response();
    }
    drop(table);

    // Paths without a stored config fall back to `WsPathConfig::default()`.
    let path_config = state
        .path_configs
        .get(&path)
        .map(|entry| entry.value().clone())
        .unwrap_or_default();
    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
    }

    // Set only when the path requires authentication and the upgrade was authorized.
    let mut principal_opt: Option<camel_api::security_policy::Principal> = None;
    // NOTE(review): `sec_ctx` is a DashMap read guard that stays alive across the `.await`s
    // below. dashmap warns that holding a reference while the same map is modified can
    // deadlock. Consider cloning the context out first if `SecurityContext: Clone` — confirm.
    if let Some(sec_ctx) = state.path_policies.get(&path) {
        let extracted =
            camel_auth::extract_token_multi(req.headers(), req.uri(), &sec_ctx.credential_sources);

        match extracted {
            Some(extracted) => {
                // Token came from the query string: only ever log a redacted URI.
                if matches!(
                    extracted.source,
                    camel_auth::CredentialSource::QueryParam { .. }
                ) {
                    let redacted =
                        camel_auth::redact_query_params(req.uri(), &["access_token", "token"]);
                    tracing::debug!(path = %redacted, "WS upgrade with query token (redacted)");
                }
                match sec_ctx
                    .authenticator
                    .authenticate_bearer(&extracted.token)
                    .await
                {
                    Ok(principal) => {
                        // Throwaway exchange, used only to give the policy the principal's
                        // properties to evaluate.
                        let mut exchange = camel_api::Exchange::new(camel_api::Message::new(
                            camel_api::Body::Empty,
                        ));
                        camel_api::store_principal_properties(&mut exchange, &principal);
                        match sec_ctx.policy.evaluate(&mut exchange).await {
                            // The principal returned by the policy is ignored; the
                            // authenticated one is kept for the connection.
                            Ok(AuthorizationDecision::Granted { principal: _p }) => {
                                tracing::debug!(path = %path, subject = %principal.subject, "WS upgrade authorized");
                                principal_opt = Some(principal);
                            }
                            Ok(AuthorizationDecision::Denied { reason, .. }) => {
                                tracing::warn!(path = %path, reason = %reason, "WS upgrade denied");
                                return (StatusCode::FORBIDDEN, "Forbidden").into_response();
                            }
                            Err(e) => {
                                tracing::error!(path = %path, error = %e, "Policy evaluation error during WS upgrade");
                                return (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    "Internal Server Error",
                                )
                                    .into_response();
                            }
                        }
                    }
                    Err(e) => {
                        // Map authentication failures to HTTP statuses. "Provider unavailable"
                        // is recognised by its message text.
                        let (status, body) = match &e {
                            camel_api::CamelError::Unauthenticated(_) => {
                                (StatusCode::UNAUTHORIZED, "Unauthorized")
                            }
                            camel_api::CamelError::ProcessorError(msg)
                                if msg.contains("auth provider unavailable") =>
                            {
                                (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
                            }
                            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
                        };
                        tracing::warn!(path = %path, error = %e, "WS upgrade authentication failed");
                        return (status, body).into_response();
                    }
                }
            }
            None => {
                tracing::warn!(path = %path, "WS upgrade rejected: no credential found in any source");
                return (
                    StatusCode::UNAUTHORIZED,
                    [("WWW-Authenticate", "Bearer".to_string())],
                    "Unauthorized",
                )
                    .into_response();
            }
        }
    }

    // Captured before `req` is consumed by the upgrade extractor. Header names are lowercased;
    // values for which `to_str` fails are skipped.
    let upgrade_headers: HashMap<String, String> = req
        .headers()
        .iter()
        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
        .collect();

    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
        Ok(ws) => ws,
        Err(_) => {
            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        ws_handler(
            socket,
            state,
            path,
            remote_addr,
            upgrade_headers,
            principal_opt,
        )
    })
    .into_response()
}
426
427#[allow(unused_variables)]
428async fn ws_handler(
429    socket: WebSocket,
430    state: WsAppState,
431    path: String,
432    remote_addr: String,
433    upgrade_headers: HashMap<String, String>,
434    principal: Option<camel_api::security_policy::Principal>,
435) {
436    let connection_key = uuid::Uuid::new_v4().to_string();
437    let path_config = state
438        .path_configs
439        .get(&path)
440        .map(|entry| entry.value().clone())
441        .unwrap_or_default();
442
443    let env_tx = {
444        let table = state.dispatch.read().await;
445        table.get(&path).cloned()
446    };
447    let Some(env_tx) = env_tx else {
448        return;
449    };
450
451    let (mut sink, mut stream) = socket.split();
452    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
453
454    let registry = global_registries();
455    let mut registry_key = None;
456    for entry in registry.iter() {
457        if entry.key().2 == path {
458            entry.value().insert(connection_key.clone(), out_tx.clone());
459            registry_key = Some(entry.key().clone());
460            break;
461        }
462    }
463
464    // Clone for writer closure and subsequent tracing (WS-009)
465    let conn_key_for_writer = connection_key.clone();
466    let path_for_writer = path.clone();
467
468    let writer = tokio::spawn(async move {
469        while let Some(msg) = out_rx.recv().await {
470            if let Err(e) = sink.send(msg).await {
471                tracing::warn!(
472                    connection_key = conn_key_for_writer,
473                    path = path_for_writer,
474                    error = %e,
475                    "WebSocket writer send error — closing connection"
476                );
477                break;
478            }
479        }
480    });
481
482    tracing::info!(
483        connection_key = connection_key,
484        path = path,
485        remote_addr = remote_addr,
486        "WebSocket connection opened"
487    );
488
489    let mut over_limit = false;
490    if let Some(key) = &registry_key
491        && let Some(entry) = registry.get(key)
492        && entry.len() > path_config.max_connections as usize
493    {
494        over_limit = true;
495    }
496    if over_limit {
497        try_send_with_backpressure(
498            &out_tx,
499            WsMessage::Close(Some(CloseFrame {
500                code: CloseCode::from(1013u16),
501                reason: "max connections exceeded".into(),
502            })),
503            "max-connections-close",
504        );
505        if let Some(key) = registry_key.clone()
506            && let Some(entry) = registry.get(&key)
507        {
508            entry.remove(&connection_key);
509        }
510        drop(out_tx);
511        let _ = writer.await;
512        return;
513    }
514
515    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
516        let ping_tx = out_tx.clone();
517        let interval = path_config.heartbeat_interval;
518        Some(tokio::spawn(async move {
519            let mut ticker = tokio::time::interval(interval);
520            loop {
521                ticker.tick().await;
522                let _ = try_send_with_backpressure(
523                    &ping_tx,
524                    WsMessage::Ping(Vec::new().into()),
525                    "heartbeat-ping",
526                );
527            }
528        }))
529    } else {
530        None
531    };
532
533    loop {
534        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
535            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
536                Ok(msg) => msg,
537                Err(_) => {
538                    try_send_with_backpressure(
539                        &out_tx,
540                        WsMessage::Close(Some(CloseFrame {
541                            code: CloseCode::from(1000u16),
542                            reason: "idle timeout".into(),
543                        })),
544                        "idle-timeout-close",
545                    );
546                    break;
547                }
548            }
549        } else {
550            stream.next().await
551        };
552
553        let Some(msg) = next_msg else {
554            break;
555        };
556
557        match msg {
558            Ok(WsMessage::Ping(data)) => {
559                tracing::debug!(
560                    connection_key = connection_key,
561                    path = path,
562                    "WebSocket ping received — sending pong"
563                );
564                let _ = try_send_with_backpressure(
565                    &out_tx,
566                    WsMessage::Pong(data),
567                    "ping-pong-response",
568                );
569            }
570            Ok(WsMessage::Pong(_)) => {
571                tracing::debug!(
572                    connection_key = connection_key,
573                    path = path,
574                    "WebSocket pong received"
575                );
576            }
577            Ok(WsMessage::Text(text)) => {
578                if text.len() > path_config.max_message_size as usize {
579                    try_send_with_backpressure(
580                        &out_tx,
581                        WsMessage::Close(Some(CloseFrame {
582                            code: CloseCode::from(1009u16),
583                            reason: "message too large".into(),
584                        })),
585                        "max-message-size-close-text",
586                    );
587                    break;
588                }
589
590                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
591                message.set_header(
592                    "CamelWsConnectionKey",
593                    serde_json::Value::String(connection_key.clone()),
594                );
595                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
596                message.set_header(
597                    "CamelWsRemoteAddress",
598                    serde_json::Value::String(remote_addr.clone()),
599                );
600
601                #[allow(unused_mut)]
602                let mut exchange = Exchange::new(message);
603                if let Some(ref p) = principal {
604                    camel_api::store_principal_properties(&mut exchange, p);
605                }
606                #[cfg(feature = "otel")]
607                {
608                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
609                }
610                if env_tx
611                    .send(ExchangeEnvelope {
612                        exchange,
613                        reply_tx: None,
614                    })
615                    .await
616                    .is_err()
617                {
618                    break;
619                }
620            }
621            Ok(WsMessage::Binary(data)) => {
622                if data.len() > path_config.max_message_size as usize {
623                    try_send_with_backpressure(
624                        &out_tx,
625                        WsMessage::Close(Some(CloseFrame {
626                            code: CloseCode::from(1009u16),
627                            reason: "message too large".into(),
628                        })),
629                        "max-message-size-close-binary",
630                    );
631                    break;
632                }
633
634                let mut message = CamelMessage::new(CamelBody::Bytes(data));
635                message.set_header(
636                    "CamelWsConnectionKey",
637                    serde_json::Value::String(connection_key.clone()),
638                );
639                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
640                message.set_header(
641                    "CamelWsRemoteAddress",
642                    serde_json::Value::String(remote_addr.clone()),
643                );
644
645                #[allow(unused_mut)]
646                let mut exchange = Exchange::new(message);
647                if let Some(ref p) = principal {
648                    camel_api::store_principal_properties(&mut exchange, p);
649                }
650                #[cfg(feature = "otel")]
651                {
652                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
653                }
654                if env_tx
655                    .send(ExchangeEnvelope {
656                        exchange,
657                        reply_tx: None,
658                    })
659                    .await
660                    .is_err()
661                {
662                    break;
663                }
664            }
665            Ok(WsMessage::Close(cf)) => {
666                let reason = cf
667                    .as_ref()
668                    .map(|f| f.reason.to_string())
669                    .unwrap_or_default();
670                tracing::info!(
671                    connection_key = connection_key,
672                    path = path,
673                    reason = reason,
674                    "WebSocket connection closed by peer"
675                );
676                break;
677            }
678            Err(e) => {
679                tracing::warn!(
680                    connection_key = connection_key,
681                    path = path,
682                    error = %e,
683                    "WebSocket receive error"
684                );
685                break;
686            }
687        }
688    }
689
690    if let Some(task) = heartbeat_task {
691        task.abort();
692    }
693
694    if let Some(key) = registry_key
695        && let Some(entry) = registry.get(&key)
696    {
697        entry.remove(&connection_key);
698    }
699    drop(out_tx);
700    let _ = writer.await;
701
702    tracing::info!(
703        connection_key = connection_key,
704        path = path,
705        "WebSocket connection closed"
706    );
707}
708
709pub struct WsComponent {
710    pub(crate) config: WsConfig,
711}
712
713impl WsComponent {
714    pub fn new() -> Self {
715        Self {
716            config: WsConfig::default(),
717        }
718    }
719
720    pub fn with_config(config: WsConfig) -> Self {
721        Self { config }
722    }
723}
724
725impl Default for WsComponent {
726    fn default() -> Self {
727        Self::new()
728    }
729}
730
731impl Component for WsComponent {
732    fn scheme(&self) -> &str {
733        "ws"
734    }
735
736    fn create_endpoint(
737        &self,
738        uri: &str,
739        ctx: &dyn camel_component_api::ComponentContext,
740    ) -> Result<Box<dyn Endpoint>, CamelError> {
741        self.config.validate()?;
742        let mut cfg = WsEndpointConfig::from_uri(uri)?;
743        if let Some(v) = self.config.max_connections {
744            cfg.max_connections = v;
745        }
746        if let Some(v) = self.config.max_message_size {
747            cfg.max_message_size = v;
748        }
749        if let Some(v) = self.config.heartbeat_interval_ms {
750            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
751        }
752        if let Some(v) = self.config.idle_timeout_ms {
753            cfg.idle_timeout = std::time::Duration::from_millis(v);
754        }
755        if let Some(v) = self.config.connect_timeout_ms {
756            cfg.connect_timeout = std::time::Duration::from_millis(v);
757        }
758        if let Some(v) = self.config.response_timeout_ms {
759            cfg.response_timeout = std::time::Duration::from_millis(v);
760        }
761        if let Some(v) = self.config.send_timeout_ms {
762            cfg.send_timeout = std::time::Duration::from_millis(v);
763        }
764        if let Some(v) = self.config.binary_payload {
765            cfg.binary_payload = v;
766        }
767        if let Some(ref v) = self.config.subprotocols {
768            cfg.subprotocols = v.clone();
769        }
770        let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
771        ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
772        Ok(Box::new(WsEndpoint {
773            uri: uri.to_string(),
774            cfg,
775        }))
776    }
777}
778
779pub struct WssComponent {
780    pub(crate) config: WsConfig,
781}
782
783impl WssComponent {
784    pub fn new() -> Self {
785        Self {
786            config: WsConfig::default(),
787        }
788    }
789
790    pub fn with_config(config: WsConfig) -> Self {
791        Self { config }
792    }
793}
794
795impl Default for WssComponent {
796    fn default() -> Self {
797        Self::new()
798    }
799}
800
801impl Component for WssComponent {
802    fn scheme(&self) -> &str {
803        "wss"
804    }
805
806    fn create_endpoint(
807        &self,
808        uri: &str,
809        ctx: &dyn camel_component_api::ComponentContext,
810    ) -> Result<Box<dyn Endpoint>, CamelError> {
811        self.config.validate()?;
812        let mut cfg = WsEndpointConfig::from_uri(uri)?;
813        if let Some(v) = self.config.max_connections {
814            cfg.max_connections = v;
815        }
816        if let Some(v) = self.config.max_message_size {
817            cfg.max_message_size = v;
818        }
819        if let Some(v) = self.config.heartbeat_interval_ms {
820            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
821        }
822        if let Some(v) = self.config.idle_timeout_ms {
823            cfg.idle_timeout = std::time::Duration::from_millis(v);
824        }
825        if let Some(v) = self.config.connect_timeout_ms {
826            cfg.connect_timeout = std::time::Duration::from_millis(v);
827        }
828        if let Some(v) = self.config.response_timeout_ms {
829            cfg.response_timeout = std::time::Duration::from_millis(v);
830        }
831        if let Some(v) = self.config.send_timeout_ms {
832            cfg.send_timeout = std::time::Duration::from_millis(v);
833        }
834        if let Some(v) = self.config.binary_payload {
835            cfg.binary_payload = v;
836        }
837        if let Some(ref v) = self.config.subprotocols {
838            cfg.subprotocols = v.clone();
839        }
840        let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
841        ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
842        Ok(Box::new(WsEndpoint {
843            uri: uri.to_string(),
844            cfg,
845        }))
846    }
847}
848
849struct WsEndpoint {
850    uri: String,
851    cfg: WsEndpointConfig,
852}
853
854impl Endpoint for WsEndpoint {
855    fn uri(&self) -> &str {
856        &self.uri
857    }
858
859    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
860        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
861    }
862
863    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
864        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
865    }
866}
867
868pub struct WsConsumer {
869    cfg: WsServerConfig,
870    registry: Arc<WsConnectionRegistry>,
871    server_state: Option<WsAppState>,
872    registry_key: Option<(String, u16, String)>,
873    forward_task: Option<JoinHandle<Result<(), CamelError>>>,
874    security_ctx: Option<camel_component_api::SecurityContext>,
875}
876
877impl WsConsumer {
878    pub fn new(cfg: WsServerConfig) -> Self {
879        Self {
880            cfg,
881            registry: Arc::new(WsConnectionRegistry::new()),
882            server_state: None,
883            registry_key: None,
884            forward_task: None,
885            security_ctx: None,
886        }
887    }
888}
889
// Consumer lifecycle: `start` attaches this consumer's path to a server obtained from
// `ServerRegistry::global()` (by host/port) and spawns a forwarding task; `stop` removes the
// path's registrations and releases the server for the port.
#[async_trait]
impl Consumer for WsConsumer {
    /// Registers this consumer's path on the (possibly shared) server and starts forwarding
    /// received envelopes to the route via `ctx.sender()`.
    ///
    /// # Errors
    /// `EndpointCreationFailed` if the consumer is already started, or if the scheme is
    /// `wss` and the TLS cert or key path is missing. Errors from `get_or_spawn` propagate.
    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
        // Reject double-start (WS-006)
        if self.server_state.is_some() {
            return Err(CamelError::EndpointCreationFailed(
                "WebSocket consumer already started".into(),
            ));
        }

        tracing::info!(
            host = self.cfg.inner.host,
            port = self.cfg.inner.port,
            path = self.cfg.inner.path,
            scheme = self.cfg.inner.scheme,
            "WebSocket consumer starting"
        );

        // TLS settings are only needed (and then mandatory) for the `wss` scheme.
        let tls_config = if self.cfg.inner.scheme == "wss" {
            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
            })?;
            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
            })?;
            Some(WsTlsConfig {
                cert_path,
                key_path,
            })
        } else {
            None
        };

        let state = ServerRegistry::global()
            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
            .await?;

        // Envelopes dispatched to this path arrive on `env_rx`. The write guard is scoped so
        // it is released before the remaining registrations.
        // NOTE(review): `insert` presumably replaces any existing entry for the same path —
        // confirm whether duplicate consumers on one path should be rejected instead.
        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
        {
            let mut table = state.dispatch.write().await;
            table.insert(self.cfg.inner.path.clone(), env_tx);
        }

        // Per-path connection limits and policies copied from this consumer's config.
        state.path_configs.insert(
            self.cfg.inner.path.clone(),
            WsPathConfig {
                max_connections: self.cfg.inner.max_connections,
                max_message_size: self.cfg.inner.max_message_size,
                heartbeat_interval: self.cfg.inner.heartbeat_interval,
                idle_timeout: self.cfg.inner.idle_timeout,
                allow_origin: self.cfg.inner.allow_origin.clone(),
            },
        );

        if let Some(ref sec_ctx) = self.security_ctx {
            let path = self.cfg.inner.path.clone();
            state.path_policies.insert(path, sec_ctx.clone());
        }

        // Publish this consumer's connection registry under the canonical host so that a
        // producer targeting the same host/port/path can find it.
        let registry_key = (
            self.cfg.inner.canonical_host(),
            self.cfg.inner.port,
            self.cfg.inner.path.clone(),
        );
        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));

        // Forward envelopes to the route until `env_rx` closes (every sender dropped) or the
        // route's receiving side is gone.
        let sender = ctx.sender();
        let forward_task: JoinHandle<Result<(), CamelError>> = tokio::spawn(async move {
            while let Some(envelope) = env_rx.recv().await {
                if sender.send(envelope).await.is_err() {
                    break;
                }
            }
            Ok(())
        });

        self.server_state = Some(state);
        self.registry_key = Some(registry_key);
        self.forward_task = Some(forward_task);
        Ok(())
    }

    /// Detaches this consumer from the server and releases it.
    ///
    /// Cleanup always runs to completion. An error is returned afterwards only if the
    /// server's `server_error` flag was set.
    ///
    /// # Errors
    /// `ProcessorError` if the server recorded errors during its lifetime.
    async fn stop(&mut self) -> Result<(), CamelError> {
        tracing::info!(
            host = self.cfg.inner.host,
            port = self.cfg.inner.port,
            path = self.cfg.inner.path,
            "WebSocket consumer stopping"
        );

        // Best-effort close frame (code 1001) to every connection in this registry; a full
        // outbound queue simply drops it.
        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
            code: axum::extract::ws::CloseCode::from(1001u16),
            reason: "consumer stopping".into(),
        }));
        for tx in self.registry.snapshot_senders() {
            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
        }

        let mut had_server_error = false;

        // Read the server error flag, then remove this path's policy, dispatch entry and config.
        if let Some(state) = self.server_state.take() {
            had_server_error = state.server_error.load(Ordering::Relaxed);
            state.path_policies.remove(&self.cfg.inner.path);
            let mut table = state.dispatch.write().await;
            table.remove(&self.cfg.inner.path);
            state.path_configs.remove(&self.cfg.inner.path);
        }

        // Unpublish the registry and release the server held for this port.
        if let Some(key) = self.registry_key.take() {
            global_registries().remove(&key);
            ServerRegistry::global().release(key.1);
        }

        // Only present if `background_task_handle` has not taken the handle. Otherwise the
        // task presumably ends once the dispatch sender removed above is dropped.
        if let Some(task) = self.forward_task.take() {
            task.abort();
        }

        tracing::info!(
            host = self.cfg.inner.host,
            port = self.cfg.inner.port,
            path = self.cfg.inner.path,
            "WebSocket consumer stopped"
        );

        if had_server_error {
            tracing::warn!(
                host = self.cfg.inner.host,
                port = self.cfg.inner.port,
                path = self.cfg.inner.path,
                "WebSocket server had errors during its lifetime"
            );
            return Err(CamelError::ProcessorError(
                "WebSocket server terminated with errors during its lifetime".into(),
            ));
        }

        Ok(())
    }

    /// Concurrent processing, capped at the configured maximum number of connections.
    fn concurrency_model(&self) -> ConcurrencyModel {
        ConcurrencyModel::Concurrent {
            max: Some(self.cfg.inner.max_connections as usize),
        }
    }

    /// Hands out the forwarding task handle (at most once). After this, `stop` can no longer
    /// abort it.
    fn background_task_handle(&mut self) -> Option<JoinHandle<Result<(), CamelError>>> {
        self.forward_task.take()
    }

    /// Stores the security context that `start` installs as this path's policy.
    fn set_security_context(&mut self, ctx: camel_component_api::SecurityContext) {
        self.security_ctx = Some(ctx);
    }
}
1043
1044use std::sync::atomic::{AtomicBool, Ordering};
1045
/// Returns a new, unshared boolean flag initialised to `false`.
fn new_atomic_false() -> Arc<AtomicBool> {
    let flag = AtomicBool::default();
    Arc::new(flag)
}
1049
1050#[derive(Clone)]
1051pub struct WsProducer {
1052    cfg: WsClientConfig,
1053    /// Shared flag set by the async future when server-send hits backpressure,
1054    /// so that the next `poll_ready` call can return an error. (WS-003)
1055    backpressure_flag: Arc<AtomicBool>,
1056}
1057
1058impl WsProducer {
1059    pub fn new(cfg: WsClientConfig) -> Self {
1060        Self {
1061            cfg,
1062            backpressure_flag: Arc::new(AtomicBool::new(false)),
1063        }
1064    }
1065}
1066
1067impl Service<Exchange> for WsProducer {
1068    type Response = Exchange;
1069    type Error = CamelError;
1070    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
1071
1072    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
1073        // Return error if last server-send hit backpressure (WS-003)
1074        if self.backpressure_flag.swap(false, Ordering::Relaxed) {
1075            return Poll::Ready(Err(CamelError::ProcessorError(
1076                "WebSocket producer backpressure: previous send was dropped due to full channel"
1077                    .into(),
1078            )));
1079        }
1080        Poll::Ready(Ok(()))
1081    }
1082
1083    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
1084        let cfg = self.cfg.clone();
1085        let backpressure_flag = Arc::clone(&self.backpressure_flag);
1086
1087        Box::pin(async move {
1088            let canonical_host = cfg.inner.canonical_host();
1089            let key = (
1090                canonical_host.clone(),
1091                cfg.inner.port,
1092                cfg.inner.path.clone(),
1093            );
1094
1095            let send_to_all = exchange
1096                .input
1097                .header("CamelWsSendToAll")
1098                .and_then(|v| v.as_bool())
1099                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
1100                .unwrap_or(false);
1101
1102            let conn_keys_header = exchange
1103                .input
1104                .header("CamelWsConnectionKey")
1105                .and_then(|v| v.as_str())
1106                .map(str::to_string);
1107
1108            let local_exists = global_registries().contains_key(&key);
1109            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
1110
1111            let message_type = exchange
1112                .input
1113                .header("CamelWsMessageType")
1114                .and_then(|v| v.as_str())
1115                .unwrap_or("text")
1116                .to_ascii_lowercase();
1117
1118            if server_send_mode {
1119                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
1120                let Some(registry) = registry else {
1121                    return Err(CamelError::ProcessorError(format!(
1122                        "WebSocket local consumer not found for {}:{}{}",
1123                        canonical_host, cfg.inner.port, cfg.inner.path
1124                    )));
1125                };
1126
1127                let out_msg = body_to_axum_ws_message(
1128                    std::mem::take(&mut exchange.input.body),
1129                    &message_type,
1130                )
1131                .await?;
1132
1133                let targets = if send_to_all {
1134                    registry.snapshot_senders()
1135                } else if let Some(keys) = conn_keys_header {
1136                    let parsed: Vec<String> = keys
1137                        .split(',')
1138                        .map(str::trim)
1139                        .filter(|k| !k.is_empty())
1140                        .map(|k| k.to_string())
1141                        .collect();
1142                    registry.get_senders_for_keys(&parsed)
1143                } else {
1144                    registry.snapshot_senders()
1145                };
1146
1147                let mut dropped = 0usize;
1148                for tx in &targets {
1149                    if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1150                        dropped += 1;
1151                    }
1152                }
1153
1154                if dropped > 0 {
1155                    tracing::warn!(
1156                        host = canonical_host,
1157                        port = cfg.inner.port,
1158                        path = cfg.inner.path,
1159                        dropped,
1160                        total = targets.len(),
1161                        "WebSocket producer dropped messages due to backpressure"
1162                    );
1163                    exchange.input.set_header(
1164                        "CamelWsDeliveryDropped",
1165                        serde_json::Value::Number(dropped.into()),
1166                    );
1167                    // Signal backpressure for next poll_ready call (WS-003)
1168                    backpressure_flag.store(true, Ordering::Relaxed);
1169                    if dropped == targets.len() {
1170                        return Err(CamelError::ProcessorError(format!(
1171                            "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1172                        )));
1173                    }
1174                }
1175
1176                tracing::debug!(
1177                    host = canonical_host,
1178                    port = cfg.inner.port,
1179                    path = cfg.inner.path,
1180                    targets = targets.len(),
1181                    "WebSocket producer server-send complete"
1182                );
1183
1184                return Ok(exchange);
1185            }
1186
1187            let url = format!(
1188                "{}://{}:{}{}",
1189                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1190            );
1191
1192            tracing::debug!(url = url, "WebSocket producer connecting");
1193
1194            #[allow(unused_mut)]
1195            let mut request = url
1196                .clone()
1197                .into_client_request()
1198                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1199
1200            #[cfg(feature = "otel")]
1201            {
1202                let mut otel_headers = HashMap::new();
1203                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1204                for (k, v) in otel_headers {
1205                    if let (Ok(name), Ok(val)) = (
1206                        http::header::HeaderName::from_bytes(k.as_bytes()),
1207                        http::header::HeaderValue::from_str(&v),
1208                    ) {
1209                        request.headers_mut().insert(name, val);
1210                    }
1211                }
1212            }
1213
1214            // Add Sec-WebSocket-Protocol header if subprotocols configured (WS-007)
1215            if !cfg.inner.subprotocols.is_empty() {
1216                let proto_value = cfg.inner.subprotocols.join(", ");
1217                if let (Ok(name), Ok(val)) = (
1218                    http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
1219                    http::header::HeaderValue::from_str(&proto_value),
1220                ) {
1221                    request.headers_mut().insert(name, val);
1222                }
1223            }
1224
1225            // Determine message type: respect binary_payload config (WS-018)
1226            let effective_message_type = if cfg.inner.binary_payload {
1227                "binary"
1228            } else {
1229                &message_type
1230            };
1231
1232            let max_retries = if cfg.inner.reconnect {
1233                cfg.inner.reconnect_max_attempts as usize
1234            } else {
1235                0
1236            };
1237            let mut backoff = BackoffState::new(BackoffConfig {
1238                initial_delay: std::time::Duration::from_millis(cfg.inner.reconnect_delay_ms),
1239                multiplier: 2.0,
1240                max_delay: std::time::Duration::from_secs(30),
1241            });
1242            let mut attempts = 0usize;
1243            let mut ws_stream = loop {
1244                let connect_future = tokio_tungstenite::connect_async(request.clone());
1245                match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
1246                    Ok(Ok((stream, _))) => break stream,
1247                    Ok(Err(e)) => {
1248                        let err = map_connect_error(e, &url);
1249                        let is_transient = err.to_string().contains("connection refused")
1250                            || err.to_string().contains("timeout")
1251                            || err.to_string().contains("connection failed");
1252                        if is_transient && attempts < max_retries {
1253                            attempts += 1;
1254                            tracing::warn!(
1255                                url = url,
1256                                error = %err,
1257                                attempt = attempts,
1258                                max_retries,
1259                                "WebSocket connect failed — retrying"
1260                            );
1261                            tokio::time::sleep(backoff.next_delay()).await;
1262                            continue;
1263                        }
1264                        return Err(err);
1265                    }
1266                    Err(_) => {
1267                        let err = CamelError::ProcessorError(format!(
1268                            "WebSocket connect timeout ({:?}) to {url}",
1269                            cfg.inner.connect_timeout
1270                        ));
1271                        if attempts < max_retries {
1272                            attempts += 1;
1273                            tracing::warn!(
1274                                url = url,
1275                                attempt = attempts,
1276                                max_retries,
1277                                "WebSocket connect timeout — retrying"
1278                            );
1279                            tokio::time::sleep(backoff.next_delay()).await;
1280                            continue;
1281                        }
1282                        return Err(err);
1283                    }
1284                }
1285            };
1286            if attempts > 0 {
1287                tracing::info!(url = url, "WebSocket producer connected after retry");
1288            }
1289
1290            let out_msg = body_to_client_ws_message(
1291                std::mem::take(&mut exchange.input.body),
1292                effective_message_type,
1293            )
1294            .await?;
1295
1296            ws_stream
1297                .send(out_msg)
1298                .await
1299                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1300
1301            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1302                loop {
1303                    match ws_stream.next().await {
1304                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1305                            continue;
1306                        }
1307                        other => break other,
1308                    }
1309                }
1310            })
1311            .await
1312            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1313
1314            match incoming {
1315                Some(Ok(ClientWsMessage::Text(text))) => {
1316                    tracing::debug!(url = url, "WebSocket producer received text response");
1317                    exchange.input.body = CamelBody::Text(text.to_string());
1318                }
1319                Some(Ok(ClientWsMessage::Binary(data))) => {
1320                    tracing::debug!(url = url, "WebSocket producer received binary response");
1321                    exchange.input.body = CamelBody::Bytes(data);
1322                }
1323                Some(Ok(ClientWsMessage::Close(frame))) => {
1324                    let normal = frame
1325                        .as_ref()
1326                        .map(|f| {
1327                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1328                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1329                        })
1330                        .unwrap_or(true);
1331
1332                    if normal {
1333                        tracing::debug!(url = url, "WebSocket producer received normal close");
1334                        exchange.input.body = CamelBody::Empty;
1335                    } else if cfg.inner.reconnect && attempts < max_retries {
1336                        backoff.reset();
1337                        attempts += 1;
1338                        tracing::warn!(
1339                            url = url,
1340                            attempt = attempts,
1341                            max_retries,
1342                            "WebSocket closed by peer — reconnecting"
1343                        );
1344                        tokio::time::sleep(backoff.next_delay()).await;
1345                        return Err(CamelError::ProcessorError(format!(
1346                            "WebSocket reconnect required after close: code {}",
1347                            frame.map(|f| u16::from(f.code)).unwrap_or_default()
1348                        )));
1349                    } else {
1350                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1351                        return Err(CamelError::ProcessorError(format!(
1352                            "WebSocket peer closed: code {code}"
1353                        )));
1354                    }
1355                }
1356                Some(Ok(_)) | None => {
1357                    exchange.input.body = CamelBody::Empty;
1358                }
1359                Some(Err(e)) => {
1360                    return Err(CamelError::ProcessorError(format!(
1361                        "WebSocket receive failed: {e}"
1362                    )));
1363                }
1364            }
1365
1366            let _ = ws_stream.close(None).await;
1367            tracing::debug!(url = url, "WebSocket producer connection closed");
1368            Ok(exchange)
1369        })
1370    }
1371}
1372
1373async fn body_to_axum_ws_message(
1374    body: CamelBody,
1375    message_type: &str,
1376) -> Result<WsMessage, CamelError> {
1377    match message_type {
1378        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1379        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1380    }
1381}
1382
1383async fn body_to_client_ws_message(
1384    body: CamelBody,
1385    message_type: &str,
1386) -> Result<ClientWsMessage, CamelError> {
1387    match message_type {
1388        "binary" => Ok(ClientWsMessage::Binary(
1389            body.into_bytes(10 * 1024 * 1024).await?,
1390        )),
1391        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1392    }
1393}
1394
1395async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1396    Ok(match body {
1397        CamelBody::Empty => String::new(),
1398        CamelBody::Text(s) => s,
1399        CamelBody::Xml(s) => s,
1400        CamelBody::Json(v) => v.to_string(),
1401        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1402        CamelBody::Stream(stream) => {
1403            let bytes = CamelBody::Stream(stream)
1404                .into_bytes(10 * 1024 * 1024)
1405                .await?;
1406            String::from_utf8_lossy(&bytes).to_string()
1407        }
1408    })
1409}
1410
/// Decides whether a WebSocket upgrade request's `Origin` is acceptable.
///
/// `allowed_origin` is either `"*"` (accept everything, including requests that carry no
/// `Origin` header) or a comma-separated list such as
/// `"https://a.example, https://b.example"`. List entries are trimmed and empty entries are
/// ignored. An entry of `"*"` also acts as a wildcard.
///
/// Other entries are compared ASCII case-insensitively, because the scheme and host of an
/// origin are case-insensitive. When a specific origin is required, a request without an
/// `Origin` header is rejected.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    allowed_origin
        .split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .any(|entry| {
            entry == "*"
                || request_origin.is_some_and(|origin| entry.eq_ignore_ascii_case(origin))
        })
}
1417
1418fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1419    match tx.try_send(msg) {
1420        Ok(()) => true,
1421        Err(error) => {
1422            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1423            false
1424        }
1425    }
1426}
1427
1428fn load_tls_config(
1429    cert_path: &str,
1430    key_path: &str,
1431) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1432    use std::fs::File;
1433    use std::io::BufReader;
1434
1435    let cert_file = File::open(cert_path)
1436        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1437    let key_file = File::open(key_path)
1438        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1439
1440    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1441        .collect::<Result<Vec<_>, _>>()
1442        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1443
1444    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1445        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1446        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1447
1448    tokio_rustls::rustls::ServerConfig::builder()
1449        .with_no_client_auth()
1450        .with_single_cert(certs, key)
1451        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1452}
1453
1454fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1455    match err {
1456        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1457            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1458        }
1459        tungstenite::Error::Tls(_) => {
1460            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1461        }
1462        other => {
1463            let msg = other.to_string();
1464            if msg.to_lowercase().contains("connection refused") {
1465                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1466            } else if msg.to_lowercase().contains("tls") {
1467                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1468            } else {
1469                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1470            }
1471        }
1472    }
1473}
1474
1475#[cfg(test)]
1476mod tests {
1477    use super::*;
1478    use camel_component_api::NoOpComponentContext;
1479    use std::time::Duration;
1480
1481    use tokio::sync::mpsc;
1482    use tokio_tungstenite::connect_async;
1483    use tokio_tungstenite::tungstenite::Message as ClientMessage;
1484    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1485    use tokio_util::sync::CancellationToken;
1486    use tower::ServiceExt;
1487
1488    fn free_port() -> u16 {
1489        std::net::TcpListener::bind("127.0.0.1:0")
1490            .unwrap()
1491            .local_addr()
1492            .unwrap()
1493            .port()
1494    }
1495
1496    #[test]
1497    fn ws_component_scheme_is_ws() {
1498        assert_eq!(WsComponent::new().scheme(), "ws");
1499    }
1500
1501    #[test]
1502    fn wss_component_scheme_is_wss() {
1503        assert_eq!(WssComponent::new().scheme(), "wss");
1504    }
1505
1506    #[test]
1507    fn endpoint_config_defaults_match_spec() {
1508        let cfg = WsEndpointConfig::default();
1509        assert_eq!(cfg.scheme, "ws");
1510        assert_eq!(cfg.host, "0.0.0.0");
1511        assert_eq!(cfg.port, 8080);
1512        assert_eq!(cfg.path, "/");
1513        assert_eq!(cfg.max_connections, 100);
1514        assert_eq!(cfg.max_message_size, 65536);
1515        assert!(!cfg.send_to_all);
1516        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1517        assert_eq!(cfg.idle_timeout, Duration::ZERO);
1518        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1519        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1520        assert_eq!(cfg.allow_origin, "*");
1521        assert_eq!(cfg.tls_cert, None);
1522        assert_eq!(cfg.tls_key, None);
1523        assert!(cfg.reconnect);
1524        assert_eq!(cfg.reconnect_max_attempts, 5);
1525        assert_eq!(cfg.reconnect_delay_ms, 1000);
1526        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
1527        assert!(!cfg.binary_payload);
1528        assert!(cfg.subprotocols.is_empty());
1529    }
1530
1531    #[test]
1532    fn endpoint_config_parses_uri_params() {
1533        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1534        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1535
1536        assert_eq!(cfg.scheme, "ws");
1537        assert_eq!(cfg.host, "localhost");
1538        assert_eq!(cfg.port, 9001);
1539        assert_eq!(cfg.path, "/chat");
1540        assert_eq!(cfg.max_connections, 42);
1541        assert_eq!(cfg.max_message_size, 1024);
1542        assert!(cfg.send_to_all);
1543        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1544        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1545        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1546        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1547        assert_eq!(cfg.allow_origin, "https://example.com");
1548        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1549        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1550        assert!(cfg.reconnect);
1551        assert_eq!(cfg.reconnect_max_attempts, 5);
1552        assert_eq!(cfg.reconnect_delay_ms, 1000);
1553    }
1554
1555    #[test]
1556    fn endpoint_config_parses_reconnect_uri_params() {
1557        let uri =
1558            "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
1559        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1560        assert!(!cfg.reconnect);
1561        assert_eq!(cfg.reconnect_max_attempts, 2);
1562        assert_eq!(cfg.reconnect_delay_ms, 25);
1563    }
1564
1565    #[test]
1566    fn endpoint_config_override_chain_uri_overrides_defaults() {
1567        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1568        assert_eq!(cfg.max_connections, 7);
1569        assert_eq!(cfg.max_message_size, 65536);
1570        assert!(!cfg.send_to_all);
1571        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1572    }
1573
1574    #[test]
1575    fn endpoint_trait_creates_consumer_and_producer() {
1576        let ctx = NoOpComponentContext;
1577        let endpoint = WsComponent::new()
1578            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1579            .unwrap();
1580
1581        endpoint.create_consumer().unwrap();
1582        endpoint
1583            .create_producer(&ProducerContext::default())
1584            .unwrap();
1585    }
1586
1587    #[test]
1588    fn ws_consumer_concurrency_model_uses_max_connections() {
1589        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1590        let consumer = WsConsumer::new(cfg.server_config());
1591        assert_eq!(
1592            consumer.concurrency_model(),
1593            ConcurrencyModel::Concurrent { max: Some(321) }
1594        );
1595    }
1596
1597    #[tokio::test]
1598    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1599        let registry = WsConnectionRegistry::new();
1600        let (tx1, mut rx1) = mpsc::channel(8);
1601        let (tx2, mut rx2) = mpsc::channel(8);
1602
1603        registry.insert("k1".into(), tx1);
1604        registry.insert("k2".into(), tx2);
1605        assert_eq!(registry.len(), 2);
1606
1607        for tx in registry.snapshot_senders() {
1608            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1609        }
1610
1611        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1612        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1613
1614        let target = registry.get_senders_for_keys(&["k1".to_string()]);
1615        assert_eq!(target.len(), 1);
1616        target[0]
1617            .send(WsMessage::Text("targeted".into()))
1618            .await
1619            .unwrap();
1620
1621        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1622        assert!(
1623            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1624                .await
1625                .is_err()
1626        );
1627
1628        registry.remove("k1");
1629        assert_eq!(registry.len(), 1);
1630    }
1631
1632    #[test]
1633    fn host_canonicalization_maps_local_hosts_to_loopback() {
1634        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1635            .unwrap()
1636            .canonical_host();
1637        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1638            .unwrap()
1639            .canonical_host();
1640        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1641            .unwrap()
1642            .canonical_host();
1643
1644        assert_eq!(c1, "127.0.0.1");
1645        assert_eq!(c2, "127.0.0.1");
1646        assert_eq!(c3, "127.0.0.1");
1647    }
1648
1649    #[tokio::test]
1650    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1651        let port = free_port();
1652        let uri = format!("ws://127.0.0.1:{port}/echo");
1653        let component_ctx = NoOpComponentContext;
1654        let endpoint = WsComponent::new()
1655            .create_endpoint(&uri, &component_ctx)
1656            .unwrap();
1657
1658        let mut consumer = endpoint.create_consumer().unwrap();
1659        let producer = endpoint
1660            .create_producer(&ProducerContext::default())
1661            .unwrap();
1662
1663        let (route_tx, mut route_rx) = mpsc::channel(16);
1664        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1665        consumer.start(ctx).await.unwrap();
1666
1667        let route_task = tokio::spawn(async move {
1668            if let Some(envelope) = route_rx.recv().await {
1669                let payload = envelope
1670                    .exchange
1671                    .input
1672                    .body
1673                    .as_text()
1674                    .unwrap_or_default()
1675                    .to_string();
1676                let key = envelope
1677                    .exchange
1678                    .input
1679                    .header("CamelWsConnectionKey")
1680                    .and_then(|v| v.as_str())
1681                    .unwrap()
1682                    .to_string();
1683
1684                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1685                response
1686                    .input
1687                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1688                producer.oneshot(response).await.unwrap();
1689            }
1690        });
1691
1692        let url = format!("ws://127.0.0.1:{port}/echo");
1693        let (mut client, _) = loop {
1694            match connect_async(&url).await {
1695                Ok(ok) => break ok,
1696                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1697            }
1698        };
1699
1700        client
1701            .send(ClientMessage::Text("hello-ws".into()))
1702            .await
1703            .unwrap();
1704
1705        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1706            loop {
1707                match client.next().await {
1708                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1709                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1710                    Some(Ok(_)) => continue,
1711                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1712                    None => panic!("websocket closed before echo"),
1713                }
1714            }
1715        })
1716        .await
1717        .unwrap();
1718
1719        assert_eq!(incoming, "hello-ws");
1720
1721        consumer.stop().await.unwrap();
1722        route_task.await.unwrap();
1723    }
1724
1725    #[tokio::test]
1726    async fn consumer_stop_sends_close_1001() {
1727        let port = free_port();
1728        let uri = format!("ws://127.0.0.1:{port}/shutdown");
1729        let component_ctx = NoOpComponentContext;
1730        let endpoint = WsComponent::new()
1731            .create_endpoint(&uri, &component_ctx)
1732            .unwrap();
1733
1734        let mut consumer = endpoint.create_consumer().unwrap();
1735        let (route_tx, _route_rx) = mpsc::channel(16);
1736        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1737        consumer.start(ctx).await.unwrap();
1738
1739        let url = format!("ws://127.0.0.1:{port}/shutdown");
1740        let (mut client, _) = loop {
1741            match connect_async(&url).await {
1742                Ok(ok) => break ok,
1743                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1744            }
1745        };
1746
1747        client
1748            .send(ClientMessage::Text("keepalive".into()))
1749            .await
1750            .unwrap();
1751
1752        consumer.stop().await.unwrap();
1753
1754        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1755            loop {
1756                match client.next().await {
1757                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1758                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1759                    Some(Ok(_)) => continue,
1760                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1761                    None => panic!("websocket closed without close frame"),
1762                }
1763            }
1764        })
1765        .await
1766        .unwrap();
1767
1768        assert_eq!(close_code, Some(CloseCode::Away));
1769    }
1770
1771    #[test]
1772    fn wildcard_origin_allows_anything() {
1773        assert!(is_origin_allowed("*", None));
1774        assert!(is_origin_allowed("*", Some("https://example.com")));
1775    }
1776
1777    #[test]
1778    fn exact_origin_requires_match() {
1779        assert!(is_origin_allowed(
1780            "https://example.com",
1781            Some("https://example.com")
1782        ));
1783        assert!(!is_origin_allowed(
1784            "https://example.com",
1785            Some("https://other.com")
1786        ));
1787        assert!(!is_origin_allowed("https://example.com", None));
1788    }
1789
1790    #[test]
1791    fn endpoint_config_rejects_invalid_scheme() {
1792        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1793        assert!(result.is_err());
1794        let msg = result.unwrap_err().to_string();
1795        assert!(
1796            msg.contains("Invalid WebSocket scheme"),
1797            "expected scheme error, got: {msg}"
1798        );
1799    }
1800
1801    #[tokio::test]
1802    async fn wss_consumer_start_fails_without_tls_cert() {
1803        let port = free_port();
1804        let component_ctx = NoOpComponentContext;
1805        let endpoint = WssComponent::new()
1806            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1807            .unwrap();
1808        let mut consumer = endpoint.create_consumer().unwrap();
1809        let (tx, _rx) = mpsc::channel(16);
1810        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1811        let result = consumer.start(ctx).await;
1812        assert!(result.is_err());
1813        let msg = result.unwrap_err().to_string();
1814        assert!(
1815            msg.contains("TLS cert path is required"),
1816            "expected TLS cert error, got: {msg}"
1817        );
1818    }
1819
1820    #[tokio::test]
1821    async fn wss_consumer_start_fails_with_nonexistent_cert() {
1822        let port = free_port();
1823        let component_ctx = NoOpComponentContext;
1824        let endpoint = WssComponent::new()
1825            .create_endpoint(&format!(
1826                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1827            ), &component_ctx)
1828            .unwrap();
1829        let mut consumer = endpoint.create_consumer().unwrap();
1830        let (tx, _rx) = mpsc::channel(16);
1831        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1832        let result = consumer.start(ctx).await;
1833        assert!(result.is_err());
1834        let msg = result.unwrap_err().to_string();
1835        assert!(
1836            msg.contains("TLS cert file error"),
1837            "expected cert file error, got: {msg}"
1838        );
1839    }
1840
1841    #[tokio::test]
1842    async fn server_registry_returns_same_state_for_same_port() {
1843        let port = free_port();
1844        let state1 = ServerRegistry::global()
1845            .get_or_spawn("127.0.0.1", port, None)
1846            .await
1847            .unwrap();
1848        let state2 = ServerRegistry::global()
1849            .get_or_spawn("127.0.0.1", port, None)
1850            .await
1851            .unwrap();
1852        assert!(
1853            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1854            "expected same dispatch table for same port"
1855        );
1856    }
1857
1858    #[tokio::test]
1859    async fn dispatch_handler_returns_404_for_unregistered_path() {
1860        let port = free_port();
1861        let state = ServerRegistry::global()
1862            .get_or_spawn("127.0.0.1", port, None)
1863            .await
1864            .unwrap();
1865        let app = Router::new().fallback(dispatch_handler).with_state(state);
1866        let response = tokio::time::timeout(
1867            Duration::from_secs(2),
1868            tower::ServiceExt::oneshot(
1869                app,
1870                axum::http::Request::builder()
1871                    .uri("/nonexistent")
1872                    .body(Body::empty())
1873                    .unwrap(),
1874            ),
1875        )
1876        .await
1877        .unwrap()
1878        .unwrap();
1879        assert_eq!(response.status(), StatusCode::NOT_FOUND);
1880    }
1881
1882    #[tokio::test]
1883    async fn client_mode_producer_connects_and_echoes() {
1884        let app = Router::new().route(
1885            "/echo",
1886            axum::routing::get(|ws: WebSocketUpgrade| async move {
1887                ws.on_upgrade(|mut socket: WebSocket| async move {
1888                    while let Some(Ok(msg)) = socket.recv().await {
1889                        match msg {
1890                            WsMessage::Text(text) => {
1891                                let _ = socket.send(WsMessage::Text(text)).await;
1892                            }
1893                            WsMessage::Binary(data) => {
1894                                let _ = socket.send(WsMessage::Binary(data)).await;
1895                            }
1896                            WsMessage::Close(_) => break,
1897                            _ => {}
1898                        }
1899                    }
1900                })
1901            }),
1902        );
1903        // Bind to port 0 directly to avoid TOCTOU race with free_port() + re-bind
1904        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1905        let port = listener.local_addr().unwrap().port();
1906        let server_task = tokio::spawn(async move {
1907            let _ = serve(listener, app).await;
1908        });
1909
1910        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1911        let producer = WsProducer::new(cfg.client_config());
1912
1913        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1914        tokio::time::sleep(Duration::from_millis(25)).await;
1915        let result =
1916            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1917                Ok(Ok(r)) => r,
1918                Ok(Err(_)) => panic!("producer call failed"),
1919                Err(_) => panic!("producer call timed out"),
1920            };
1921
1922        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1923
1924        server_task.abort();
1925    }
1926
1927    #[tokio::test]
1928    async fn max_connections_rejects_with_close_1013() {
1929        let port = free_port();
1930        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1931        let component_ctx = NoOpComponentContext;
1932        let endpoint = WsComponent::new()
1933            .create_endpoint(&uri, &component_ctx)
1934            .unwrap();
1935        let mut consumer = endpoint.create_consumer().unwrap();
1936        let (route_tx, _route_rx) = mpsc::channel(16);
1937        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1938        consumer.start(ctx).await.unwrap();
1939
1940        let url = format!("ws://127.0.0.1:{port}/limited");
1941        let (_client1, _) = loop {
1942            match connect_async(&url).await {
1943                Ok(ok) => break ok,
1944                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1945            }
1946        };
1947
1948        tokio::time::sleep(Duration::from_millis(100)).await;
1949
1950        let (mut client2, _) = connect_async(&url).await.unwrap();
1951
1952        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1953            loop {
1954                match client2.next().await {
1955                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1956                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1957                    Some(Ok(ClientMessage::Text(_))) => continue,
1958                    Some(Ok(_)) => continue,
1959                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1960                    None => panic!("client2 closed without close frame"),
1961                }
1962            }
1963        })
1964        .await
1965        .unwrap();
1966
1967        assert_eq!(
1968            close_code,
1969            Some(CloseCode::from(1013u16)),
1970            "expected 1013 (Try Again Later) for max connections"
1971        );
1972
1973        consumer.stop().await.unwrap();
1974    }
1975
1976    #[tokio::test]
1977    async fn max_message_size_rejects_with_close_1009() {
1978        let port = free_port();
1979        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1980        let component_ctx = NoOpComponentContext;
1981        let endpoint = WsComponent::new()
1982            .create_endpoint(&uri, &component_ctx)
1983            .unwrap();
1984        let mut consumer = endpoint.create_consumer().unwrap();
1985        let (route_tx, _route_rx) = mpsc::channel(16);
1986        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1987        consumer.start(ctx).await.unwrap();
1988
1989        let url = format!("ws://127.0.0.1:{port}/sizelimit");
1990        let (mut client, _) = loop {
1991            match connect_async(&url).await {
1992                Ok(ok) => break ok,
1993                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1994            }
1995        };
1996
1997        let oversized = "x".repeat(100);
1998        client
1999            .send(ClientMessage::Text(oversized.into()))
2000            .await
2001            .unwrap();
2002
2003        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
2004            loop {
2005                match client.next().await {
2006                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
2007                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2008                    Some(Ok(_)) => continue,
2009                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2010                    None => panic!("websocket closed without close frame"),
2011                }
2012            }
2013        })
2014        .await
2015        .unwrap();
2016
2017        assert_eq!(
2018            close_code,
2019            Some(CloseCode::from(1009u16)),
2020            "expected 1009 (Message Too Big) for oversized message"
2021        );
2022
2023        consumer.stop().await.unwrap();
2024    }
2025
2026    #[tokio::test]
2027    async fn origin_rejection_returns_403() {
2028        let port = free_port();
2029        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
2030        let component_ctx = NoOpComponentContext;
2031        let endpoint = WsComponent::new()
2032            .create_endpoint(&uri, &component_ctx)
2033            .unwrap();
2034        let mut consumer = endpoint.create_consumer().unwrap();
2035        let (route_tx, _route_rx) = mpsc::channel(16);
2036        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2037        consumer.start(ctx).await.unwrap();
2038
2039        let state = ServerRegistry::global()
2040            .get_or_spawn("127.0.0.1", port, None)
2041            .await
2042            .unwrap();
2043        let app = Router::new().fallback(dispatch_handler).with_state(state);
2044
2045        let response = tokio::time::timeout(
2046            Duration::from_secs(2),
2047            tower::ServiceExt::oneshot(
2048                app,
2049                axum::http::Request::builder()
2050                    .uri("/origintest")
2051                    .header("origin", "https://evil.com")
2052                    .header("upgrade", "websocket")
2053                    .header("connection", "Upgrade")
2054                    .header("sec-websocket-version", "13")
2055                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
2056                    .body(Body::empty())
2057                    .unwrap(),
2058            ),
2059        )
2060        .await
2061        .unwrap()
2062        .unwrap();
2063
2064        assert_eq!(
2065            response.status(),
2066            StatusCode::FORBIDDEN,
2067            "expected 403 for disallowed origin"
2068        );
2069
2070        consumer.stop().await.unwrap();
2071    }
2072
2073    #[tokio::test]
2074    async fn broadcast_sends_to_all_connected_clients() {
2075        let port = free_port();
2076        let uri = format!("ws://127.0.0.1:{port}/bc");
2077        let component_ctx = NoOpComponentContext;
2078        let endpoint = WsComponent::new()
2079            .create_endpoint(&uri, &component_ctx)
2080            .unwrap();
2081        let mut consumer = endpoint.create_consumer().unwrap();
2082        let producer = endpoint
2083            .create_producer(&ProducerContext::default())
2084            .unwrap();
2085
2086        let (route_tx, _route_rx) = mpsc::channel(16);
2087        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2088        consumer.start(ctx).await.unwrap();
2089
2090        let url = format!("ws://127.0.0.1:{port}/bc");
2091
2092        let (mut client1, _) = loop {
2093            match connect_async(&url).await {
2094                Ok(ok) => break ok,
2095                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2096            }
2097        };
2098
2099        let (mut client2, _) = connect_async(&url).await.unwrap();
2100
2101        tokio::time::sleep(Duration::from_millis(100)).await;
2102
2103        let mut response =
2104            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
2105        response
2106            .input
2107            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
2108        producer.oneshot(response).await.unwrap();
2109
2110        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
2111            loop {
2112                match client1.next().await {
2113                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2114                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2115                    _ => panic!("client1 unexpected message or close"),
2116                }
2117            }
2118        })
2119        .await
2120        .unwrap();
2121
2122        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
2123            loop {
2124                match client2.next().await {
2125                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2126                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2127                    _ => panic!("client2 unexpected message or close"),
2128                }
2129            }
2130        })
2131        .await
2132        .unwrap();
2133
2134        assert_eq!(recv1, "broadcast-msg");
2135        assert_eq!(recv2, "broadcast-msg");
2136
2137        consumer.stop().await.unwrap();
2138    }
2139
2140    #[tokio::test]
2141    async fn concurrent_get_or_spawn_returns_same_state() {
2142        let port = free_port();
2143        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
2144            Arc::new(std::sync::Mutex::new(Vec::new()));
2145
2146        let mut handles = Vec::new();
2147        for _ in 0..4 {
2148            let results = results.clone();
2149            handles.push(tokio::spawn(async move {
2150                let state = ServerRegistry::global()
2151                    .get_or_spawn("127.0.0.1", port, None)
2152                    .await
2153                    .unwrap();
2154                results.lock().unwrap().push(state);
2155            }));
2156        }
2157
2158        for h in handles {
2159            h.await.unwrap();
2160        }
2161
2162        let states = results.lock().unwrap();
2163        assert_eq!(states.len(), 4);
2164        for i in 1..states.len() {
2165            assert!(
2166                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
2167                "all concurrent callers should get the same dispatch table"
2168            );
2169        }
2170    }
2171
2172    #[tokio::test]
2173    async fn body_conversion_helpers_cover_text_and_binary_paths() {
2174        let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
2175            .await
2176            .unwrap();
2177        assert!(matches!(text_msg, WsMessage::Text(_)));
2178
2179        let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
2180            .await
2181            .unwrap();
2182        assert!(matches!(bin_msg, WsMessage::Binary(_)));
2183
2184        let client_text =
2185            body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
2186                .await
2187                .unwrap();
2188        assert!(matches!(client_text, ClientWsMessage::Text(_)));
2189
2190        let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
2191            .await
2192            .unwrap();
2193        assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
2194    }
2195
2196    #[tokio::test]
2197    async fn body_to_text_handles_empty_text_json_and_bytes() {
2198        assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2199        assert_eq!(
2200            body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2201            "hello"
2202        );
2203        assert_eq!(
2204            body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2205                .await
2206                .unwrap(),
2207            "{\"n\":1}"
2208        );
2209        assert_eq!(
2210            body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2211                .await
2212                .unwrap(),
2213            "hi"
2214        );
2215    }
2216
2217    #[test]
2218    fn try_send_with_backpressure_returns_false_when_channel_full() {
2219        let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2220        assert!(try_send_with_backpressure(
2221            &tx,
2222            WsMessage::Text("first".into()),
2223            "test"
2224        ));
2225        assert!(!try_send_with_backpressure(
2226            &tx,
2227            WsMessage::Text("second".into()),
2228            "test"
2229        ));
2230    }
2231
2232    #[test]
2233    fn map_connect_error_formats_connection_refused_and_generic_errors() {
2234        let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2235        let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2236        assert!(err.to_string().contains("WebSocket connection refused"));
2237
2238        let generic = map_connect_error(
2239            tungstenite::Error::Protocol(
2240                tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2241            ),
2242            "ws://localhost:2/y",
2243        );
2244        assert!(
2245            generic
2246                .to_string()
2247                .contains("WebSocket connection failed (ws://localhost:2/y)")
2248        );
2249    }
2250
2251    // === Phase B Finding Tests ===
2252
2253    // WS-015: maxConnections=0 must be rejected
2254    #[test]
2255    fn from_uri_rejects_max_connections_zero() {
2256        let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2257        assert!(result.is_err());
2258        let msg = result.unwrap_err().to_string();
2259        assert!(
2260            msg.contains("maxConnections must be >= 1"),
2261            "expected maxConnections validation error, got: {msg}"
2262        );
2263    }
2264
2265    // WS-019: maxMessageSize=0 must be rejected
2266    #[test]
2267    fn from_uri_rejects_max_message_size_zero() {
2268        let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2269        assert!(result.is_err());
2270        let msg = result.unwrap_err().to_string();
2271        assert!(
2272            msg.contains("maxMessageSize must be > 0"),
2273            "expected maxMessageSize validation error, got: {msg}"
2274        );
2275    }
2276
2277    // WS-020: allowOrigin="" must be rejected
2278    #[test]
2279    fn from_uri_rejects_empty_allow_origin() {
2280        let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2281        assert!(result.is_err());
2282        let msg = result.unwrap_err().to_string();
2283        assert!(
2284            msg.contains("allowOrigin must not be empty"),
2285            "expected allowOrigin validation error, got: {msg}"
2286        );
2287    }
2288
2289    // WS-006: Double-start must be rejected
2290    #[tokio::test]
2291    async fn consumer_double_start_returns_error() {
2292        let port = free_port();
2293        let uri = format!("ws://127.0.0.1:{port}/doublestart");
2294        let component_ctx = NoOpComponentContext;
2295        let endpoint = WsComponent::new()
2296            .create_endpoint(&uri, &component_ctx)
2297            .unwrap();
2298
2299        let mut consumer = endpoint.create_consumer().unwrap();
2300        let (route_tx, _route_rx) = mpsc::channel(16);
2301        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2302
2303        // First start should succeed
2304        consumer.start(ctx).await.unwrap();
2305
2306        // Second start should fail
2307        let (route_tx2, _route_rx2) = mpsc::channel(16);
2308        let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2309        let result = consumer.start(ctx2).await;
2310        assert!(result.is_err());
2311        let msg = result.unwrap_err().to_string();
2312        assert!(
2313            msg.contains("already started"),
2314            "expected double-start error, got: {msg}"
2315        );
2316
2317        consumer.stop().await.unwrap();
2318    }
2319
2320    // WS-005: Registry cleanup on stop + port reuse
2321    #[tokio::test]
2322    async fn registry_cleanup_on_consumer_stop() {
2323        let port = free_port();
2324        let uri = format!("ws://127.0.0.1:{port}/cleanup");
2325        let component_ctx = NoOpComponentContext;
2326        let endpoint = WsComponent::new()
2327            .create_endpoint(&uri, &component_ctx)
2328            .unwrap();
2329
2330        let mut consumer = endpoint.create_consumer().unwrap();
2331        let (route_tx, _route_rx) = mpsc::channel(16);
2332        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2333        consumer.start(ctx).await.unwrap();
2334
2335        // Verify registry entry exists
2336        let registries = global_registries();
2337        let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2338        assert!(
2339            registries.contains_key(&key),
2340            "registry should have entry after start"
2341        );
2342
2343        // Stop consumer
2344        consumer.stop().await.unwrap();
2345
2346        // Verify registry entry is removed
2347        assert!(
2348            !registries.contains_key(&key),
2349            "registry should be cleaned up after stop"
2350        );
2351
2352        // Verify server registry reference is released (port can be reused)
2353        // The ServerRegistry should have removed the entry
2354        let server_reg = ServerRegistry::global();
2355        let guard = server_reg.inner.lock().unwrap();
2356        assert!(
2357            !guard.contains_key(&port),
2358            "ServerRegistry should remove port entry after last consumer stops"
2359        );
2360    }
2361
2362    // WS-003 + WS-004: poll_ready backpressure and server-send error handling
2363    #[tokio::test]
2364    async fn producer_server_send_returns_error_when_all_dropped() {
2365        let port = free_port();
2366        let uri = format!("ws://127.0.0.1:{port}/backpressure");
2367        let component_ctx = NoOpComponentContext;
2368        let endpoint = WsComponent::new()
2369            .create_endpoint(&uri, &component_ctx)
2370            .unwrap();
2371
2372        let mut consumer = endpoint.create_consumer().unwrap();
2373        let producer = endpoint
2374            .create_producer(&ProducerContext::default())
2375            .unwrap();
2376
2377        let (route_tx, _route_rx) = mpsc::channel(1); // Tiny channel to force backpressure
2378        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2379        consumer.start(ctx).await.unwrap();
2380
2381        // Connect a client so the registry has an entry
2382        let url = format!("ws://127.0.0.1:{port}/backpressure");
2383        let (mut client, _) = loop {
2384            match connect_async(&url).await {
2385                Ok(ok) => break ok,
2386                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2387            }
2388        };
2389
2390        // Don't consume messages — let the channel fill up
2391        tokio::time::sleep(Duration::from_millis(50)).await;
2392
2393        // Flood the channel to trigger backpressure
2394        let mut all_dropped = false;
2395        for _ in 0..100 {
2396            let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2397            match producer.clone().oneshot(exchange).await {
2398                Ok(_) => {}
2399                Err(e) => {
2400                    if e.to_string().contains("backpressure") {
2401                        all_dropped = true;
2402                        break;
2403                    }
2404                }
2405            }
2406        }
2407
2408        // The producer should eventually return a backpressure error
2409        assert!(
2410            all_dropped,
2411            "producer should return error when all messages are dropped due to backpressure"
2412        );
2413
2414        // Clean up
2415        let _ = client.close(None).await;
2416        consumer.stop().await.unwrap();
2417    }
2418
2419    // WS-012: Ping/pong round-trip in server mode
2420    #[tokio::test]
2421    async fn server_responds_to_client_ping_with_pong() {
2422        let port = free_port();
2423        let uri = format!("ws://127.0.0.1:{port}/pingpong");
2424        let component_ctx = NoOpComponentContext;
2425        let endpoint = WsComponent::new()
2426            .create_endpoint(&uri, &component_ctx)
2427            .unwrap();
2428
2429        let mut consumer = endpoint.create_consumer().unwrap();
2430        let (route_tx, _route_rx) = mpsc::channel(16);
2431        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2432        consumer.start(ctx).await.unwrap();
2433
2434        let url = format!("ws://127.0.0.1:{port}/pingpong");
2435        let (mut client, _) = loop {
2436            match connect_async(&url).await {
2437                Ok(ok) => break ok,
2438                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2439            }
2440        };
2441
2442        // Send a ping
2443        client
2444            .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2445            .await
2446            .unwrap();
2447
2448        // Expect a pong with the same payload
2449        let pong = tokio::time::timeout(Duration::from_secs(2), async {
2450            loop {
2451                match client.next().await {
2452                    Some(Ok(ClientMessage::Pong(data))) => break data,
2453                    Some(Ok(ClientMessage::Ping(_))) => continue,
2454                    Some(Ok(_)) => continue,
2455                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2456                    None => panic!("websocket closed before pong"),
2457                }
2458            }
2459        })
2460        .await
2461        .unwrap();
2462
2463        assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2464
2465        consumer.stop().await.unwrap();
2466    }
2467
2468    // WS-008: Client-side retry on transient connect failures
2469    #[tokio::test]
2470    async fn producer_retries_on_connection_refused() {
2471        // Use a port that nothing is listening on
2472        let port = free_port();
2473        // Ensure nothing is on this port
2474        let cfg = WsEndpointConfig::from_uri(&format!(
2475            "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2476        ))
2477        .unwrap();
2478        let producer = WsProducer::new(cfg.client_config());
2479
2480        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2481
2482        // Should fail after retries (nothing listening)
2483        let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2484        assert!(
2485            result.is_ok(),
2486            "producer should complete (with error) within timeout"
2487        );
2488        let result = result.unwrap();
2489        assert!(
2490            result.is_err(),
2491            "producer should fail when nothing is listening"
2492        );
2493        let msg = result.unwrap_err().to_string();
2494        assert!(
2495            msg.contains("connection refused"),
2496            "expected connection refused error, got: {msg}"
2497        );
2498    }
2499
2500    // WS-001: Server bind error is visible (fake server-start error test)
2501    #[tokio::test]
2502    async fn server_bind_error_is_reported() {
2503        // Bind a port manually to cause a conflict
2504        let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2505        let port = _listener.local_addr().unwrap().port();
2506
2507        // Try to start a consumer on the same port — should succeed since axum binds lazily
2508        // The actual bind error happens when the server task runs
2509        let uri = format!("ws://127.0.0.1:{port}/binderror");
2510        let component_ctx = NoOpComponentContext;
2511        let endpoint = WsComponent::new()
2512            .create_endpoint(&uri, &component_ctx)
2513            .unwrap();
2514
2515        let mut consumer = endpoint.create_consumer().unwrap();
2516        let (route_tx, _route_rx) = mpsc::channel(16);
2517        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2518
2519        // Start should succeed (server spawns, but bind may fail)
2520        let start_result = consumer.start(ctx).await;
2521        // The server may or may not have bound yet — this test verifies no panic
2522        // The actual error is logged by the server task
2523        let _ = start_result;
2524
2525        consumer.stop().await.unwrap();
2526    }
2527
2528    #[test]
2529    fn ws_app_state_server_error_starts_false() {
2530        let state = WsAppState {
2531            dispatch: Arc::new(RwLock::new(HashMap::new())),
2532            path_configs: Arc::new(DashMap::new()),
2533            path_policies: Arc::new(DashMap::new()),
2534            server_error: new_atomic_false(),
2535        };
2536        assert!(
2537            !state.server_error.load(Ordering::Relaxed),
2538            "server_error should start as false"
2539        );
2540    }
2541
2542    #[test]
2543    fn ws_app_state_server_error_can_be_set() {
2544        let state = WsAppState {
2545            dispatch: Arc::new(RwLock::new(HashMap::new())),
2546            path_configs: Arc::new(DashMap::new()),
2547            path_policies: Arc::new(DashMap::new()),
2548            server_error: new_atomic_false(),
2549        };
2550        assert!(!state.server_error.load(Ordering::Relaxed));
2551        state.server_error.store(true, Ordering::Relaxed);
2552        assert!(state.server_error.load(Ordering::Relaxed));
2553    }
2554
2555    #[tokio::test]
2556    async fn consumer_stop_returns_error_when_server_had_errors() {
2557        let port = free_port();
2558        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2559        let mut consumer = WsConsumer::new(cfg.server_config());
2560        let (route_tx, _route_rx) = mpsc::channel(16);
2561        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2562        consumer.start(ctx).await.unwrap();
2563
2564        // Simulate server error by setting the flag directly
2565        if let Some(ref state) = consumer.server_state {
2566            state.server_error.store(true, Ordering::Relaxed);
2567        }
2568
2569        let result = consumer.stop().await;
2570        assert!(
2571            result.is_err(),
2572            "stop should return error when server had errors"
2573        );
2574        let msg = result.unwrap_err().to_string();
2575        assert!(
2576            msg.contains("terminated with errors"),
2577            "expected server error message, got: {msg}"
2578        );
2579    }
2580
2581    #[tokio::test]
2582    async fn consumer_stop_succeeds_when_server_healthy() {
2583        let port = free_port();
2584        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2585        let mut consumer = WsConsumer::new(cfg.server_config());
2586        let (route_tx, _route_rx) = mpsc::channel(16);
2587        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2588        consumer.start(ctx).await.unwrap();
2589
2590        let result = consumer.stop().await;
2591        assert!(
2592            result.is_ok(),
2593            "stop should succeed when server is healthy: {:?}",
2594            result
2595        );
2596    }
2597
2598    // === H-10 Finding Tests ===
2599
2600    // WS-007: subprotocol negotiation support
2601    #[test]
2602    fn endpoint_config_parses_subprotocols() {
2603        let cfg = WsEndpointConfig::from_uri(
2604            "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws",
2605        )
2606        .unwrap();
2607        assert_eq!(cfg.subprotocols, vec!["graphql-ws", "graphql-transport-ws"]);
2608    }
2609
2610    #[test]
2611    fn endpoint_config_default_subprotocols_empty() {
2612        let cfg = WsEndpointConfig::default();
2613        assert!(cfg.subprotocols.is_empty());
2614    }
2615
2616    // WS-017: sendTimeoutMs URI option
2617    #[test]
2618    fn endpoint_config_parses_send_timeout() {
2619        let cfg =
2620            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=5000").unwrap();
2621        assert_eq!(cfg.send_timeout, Duration::from_millis(5000));
2622    }
2623
2624    #[test]
2625    fn endpoint_config_default_send_timeout() {
2626        let cfg = WsEndpointConfig::default();
2627        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
2628    }
2629
2630    #[test]
2631    fn endpoint_config_rejects_invalid_send_timeout() {
2632        let err =
2633            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=abc").unwrap_err();
2634        assert!(err.to_string().contains("sendTimeoutMs"));
2635    }
2636
2637    // WS-018: binaryPayload URI option
2638    #[test]
2639    fn endpoint_config_parses_binary_payload() {
2640        let cfg =
2641            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=true").unwrap();
2642        assert!(cfg.binary_payload);
2643    }
2644
2645    #[test]
2646    fn endpoint_config_default_binary_payload_false() {
2647        let cfg = WsEndpointConfig::default();
2648        assert!(!cfg.binary_payload);
2649    }
2650
2651    #[test]
2652    fn endpoint_config_rejects_invalid_binary_payload() {
2653        let err =
2654            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=yes").unwrap_err();
2655        assert!(err.to_string().contains("binaryPayload"));
2656    }
2657}