Skip to main content

camel_component_ws/
lib.rs

1//! WebSocket component for rust-camel — Axum-based WebSocket server and Tokio-tungstenite client for bidirectional messaging.
2//!
3//! Main types: `WsComponent`, `WsBundle`, `WsConfig`, `WsServerConfig`, `WsClientConfig`, `WsEndpointConfig`.
4//! Main modules: `bundle`, `config`.
5
6pub mod bundle;
7pub mod config;
8
9pub use bundle::WsBundle;
10pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
11
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex, OnceLock};
14
15use async_trait::async_trait;
16use axum::body::Body;
17use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
18use axum::extract::{FromRequest, Request, State};
19use axum::http::{StatusCode, header};
20use axum::response::IntoResponse;
21use axum::{Router, serve};
22use camel_component_api::{
23    Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
24};
25use camel_component_api::{
26    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
27    ProducerContext,
28};
29use dashmap::DashMap;
30use futures::{SinkExt, StreamExt};
31use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio::sync::{OnceCell, RwLock, mpsc};
35use tokio::task::JoinHandle;
36use tokio_tungstenite::tungstenite;
37use tokio_tungstenite::tungstenite::client::IntoClientRequest;
38use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
39use tower::Service;
40
41#[derive(Clone)]
42struct WsPathConfig {
43    max_connections: u32,
44    max_message_size: u32,
45    heartbeat_interval: std::time::Duration,
46    idle_timeout: std::time::Duration,
47    allow_origin: String,
48}
49
50impl Default for WsPathConfig {
51    fn default() -> Self {
52        let cfg = WsEndpointConfig::default();
53        Self {
54            max_connections: cfg.max_connections,
55            max_message_size: cfg.max_message_size,
56            heartbeat_interval: cfg.heartbeat_interval,
57            idle_timeout: cfg.idle_timeout,
58            allow_origin: cfg.allow_origin,
59        }
60    }
61}
62
63#[derive(Clone)]
64struct WsTlsConfig {
65    cert_path: String,
66    key_path: String,
67}
68
69type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
70
71struct ServerHandle {
72    state: WsAppState,
73    is_tls: bool,
74    _task: JoinHandle<()>,
75}
76
77struct ServerRegistryInner {
78    cell: Arc<OnceCell<ServerHandle>>,
79    ref_count: usize,
80}
81
82pub struct ServerRegistry {
83    inner: Mutex<HashMap<u16, ServerRegistryInner>>,
84}
85
86impl ServerRegistry {
87    pub fn global() -> &'static Self {
88        static REG: OnceLock<ServerRegistry> = OnceLock::new();
89        REG.get_or_init(|| Self {
90            inner: Mutex::new(HashMap::new()),
91        })
92    }
93
94    pub(crate) async fn get_or_spawn(
95        &'static self,
96        host: &str,
97        port: u16,
98        tls_config: Option<WsTlsConfig>,
99    ) -> Result<WsAppState, CamelError> {
100        let wants_tls = tls_config.is_some();
101        let host_owned = host.to_string();
102
103        let (cell, _is_new) = {
104            let mut guard = self.inner.lock().map_err(|_| {
105                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
106            })?;
107            let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
108                cell: Arc::new(OnceCell::new()),
109                ref_count: 0,
110            });
111            entry.ref_count += 1;
112            (entry.cell.clone(), entry.ref_count == 1)
113        };
114
115        let handle = cell
116            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
117            .await?;
118
119        if wants_tls != handle.is_tls {
120            // Decrement ref count since we're rejecting this caller
121            let mut guard = self.inner.lock().map_err(|_| {
122                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
123            })?;
124            if let Some(entry) = guard.get_mut(&port) {
125                entry.ref_count -= 1;
126                if entry.ref_count == 0 {
127                    guard.remove(&port);
128                }
129            }
130            return Err(CamelError::EndpointCreationFailed(format!(
131                "Server on port {port} already running with different TLS mode"
132            )));
133        }
134
135        Ok(handle.state.clone())
136    }
137
138    /// Release a reference to the server on the given port.
139    /// When the last reference is released, the server task is aborted and the entry removed. (WS-005)
140    pub(crate) fn release(&self, port: u16) {
141        let mut guard = match self.inner.lock() {
142            Ok(g) => g,
143            Err(_) => return,
144        };
145        if let Some(entry) = guard.get_mut(&port) {
146            entry.ref_count = entry.ref_count.saturating_sub(1);
147            if entry.ref_count == 0 {
148                // Abort the server task if it exists (WS-001, WS-005)
149                if let Some(handle) = entry.cell.get() {
150                    handle._task.abort();
151                }
152                guard.remove(&port);
153                tracing::info!(port, "WebSocket server registry entry removed");
154            }
155        }
156    }
157}
158
159async fn spawn_server(
160    host: &str,
161    port: u16,
162    tls_config: Option<WsTlsConfig>,
163) -> Result<ServerHandle, CamelError> {
164    let host_owned = host.to_string();
165    let addr = format!("{host}:{port}");
166    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
167    let path_configs = Arc::new(DashMap::new());
168    let server_error = new_atomic_false();
169    let state = WsAppState {
170        dispatch: Arc::clone(&dispatch),
171        path_configs: Arc::clone(&path_configs),
172        server_error: Arc::clone(&server_error),
173    };
174    let app = Router::new()
175        .fallback(dispatch_handler)
176        .with_state(state.clone());
177
178    let (task, is_tls) = if let Some(ref tls) = tls_config {
179        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
180        let parsed_addr = addr.parse().map_err(|e| {
181            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
182        })?;
183        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
184        let error_flag = Arc::clone(&server_error);
185        let task = tokio::spawn(async move {
186            if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
187                .serve(app.into_make_service())
188                .await
189            {
190                tracing::error!(
191                    host = host_owned,
192                    port = port,
193                    error = %e,
194                    "WebSocket server terminated with error"
195                );
196                error_flag.store(true, Ordering::Relaxed);
197            }
198        });
199        (task, true)
200    } else {
201        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
202            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
203        })?;
204        let error_flag = Arc::clone(&server_error);
205        let task = tokio::spawn(async move {
206            if let Err(e) = serve(listener, app).await {
207                tracing::error!(
208                    host = host_owned,
209                    port = port,
210                    error = %e,
211                    "WebSocket server terminated with error"
212                );
213                error_flag.store(true, Ordering::Relaxed);
214            }
215        });
216        (task, false)
217    };
218
219    tracing::info!(host, port, is_tls, "WebSocket server started");
220
221    Ok(ServerHandle {
222        state,
223        is_tls,
224        _task: task,
225    })
226}
227
228#[derive(Clone)]
229struct WsAppState {
230    dispatch: DispatchTable,
231    path_configs: Arc<DashMap<String, WsPathConfig>>,
232    server_error: Arc<AtomicBool>,
233}
234
235pub struct WsConnectionRegistry {
236    connections: DashMap<String, mpsc::Sender<WsMessage>>,
237}
238
239static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
240    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
241> = OnceLock::new();
242
243fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
244    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
245}
246
247impl Default for WsConnectionRegistry {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl WsConnectionRegistry {
254    pub fn new() -> Self {
255        Self {
256            connections: DashMap::new(),
257        }
258    }
259
260    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
261        self.connections.insert(key, tx);
262    }
263
264    pub fn remove(&self, key: &str) {
265        self.connections.remove(key);
266    }
267
268    pub fn len(&self) -> usize {
269        self.connections.len()
270    }
271
272    pub fn is_empty(&self) -> bool {
273        self.connections.is_empty()
274    }
275
276    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
277        self.connections.iter().map(|e| e.value().clone()).collect()
278    }
279
280    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
281        keys.iter()
282            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
283            .collect()
284    }
285}
286
287async fn dispatch_handler(
288    State(state): State<WsAppState>,
289    req: Request<Body>,
290) -> impl IntoResponse {
291    let path = req.uri().path().to_string();
292    let origin = req
293        .headers()
294        .get(header::ORIGIN)
295        .and_then(|value| value.to_str().ok())
296        .map(str::to_string);
297    let remote_addr = req
298        .extensions()
299        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
300        .map(|ci| ci.0.to_string())
301        .unwrap_or_default();
302    let table = state.dispatch.read().await;
303    if !table.contains_key(&path) {
304        return (
305            StatusCode::NOT_FOUND,
306            "no ws endpoint registered for this path",
307        )
308            .into_response();
309    }
310    drop(table);
311
312    let path_config = state
313        .path_configs
314        .get(&path)
315        .map(|entry| entry.value().clone())
316        .unwrap_or_default();
317    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
318        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
319    }
320
321    let upgrade_headers: HashMap<String, String> = req
322        .headers()
323        .iter()
324        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
325        .collect();
326
327    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
328        Ok(ws) => ws,
329        Err(_) => {
330            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
331        }
332    };
333
334    ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
335        .into_response()
336}
337
338#[allow(unused_variables)]
339async fn ws_handler(
340    socket: WebSocket,
341    state: WsAppState,
342    path: String,
343    remote_addr: String,
344    upgrade_headers: HashMap<String, String>,
345) {
346    let connection_key = uuid::Uuid::new_v4().to_string();
347    let path_config = state
348        .path_configs
349        .get(&path)
350        .map(|entry| entry.value().clone())
351        .unwrap_or_default();
352
353    let env_tx = {
354        let table = state.dispatch.read().await;
355        table.get(&path).cloned()
356    };
357    let Some(env_tx) = env_tx else {
358        return;
359    };
360
361    let (mut sink, mut stream) = socket.split();
362    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
363
364    let registry = global_registries();
365    let mut registry_key = None;
366    for entry in registry.iter() {
367        if entry.key().2 == path {
368            entry.value().insert(connection_key.clone(), out_tx.clone());
369            registry_key = Some(entry.key().clone());
370            break;
371        }
372    }
373
374    // Clone for writer closure and subsequent tracing (WS-009)
375    let conn_key_for_writer = connection_key.clone();
376    let path_for_writer = path.clone();
377
378    let writer = tokio::spawn(async move {
379        while let Some(msg) = out_rx.recv().await {
380            if let Err(e) = sink.send(msg).await {
381                tracing::warn!(
382                    connection_key = conn_key_for_writer,
383                    path = path_for_writer,
384                    error = %e,
385                    "WebSocket writer send error — closing connection"
386                );
387                break;
388            }
389        }
390    });
391
392    tracing::info!(
393        connection_key = connection_key,
394        path = path,
395        remote_addr = remote_addr,
396        "WebSocket connection opened"
397    );
398
399    let mut over_limit = false;
400    if let Some(key) = &registry_key
401        && let Some(entry) = registry.get(key)
402        && entry.len() > path_config.max_connections as usize
403    {
404        over_limit = true;
405    }
406    if over_limit {
407        try_send_with_backpressure(
408            &out_tx,
409            WsMessage::Close(Some(CloseFrame {
410                code: CloseCode::from(1013u16),
411                reason: "max connections exceeded".into(),
412            })),
413            "max-connections-close",
414        );
415        if let Some(key) = registry_key.clone()
416            && let Some(entry) = registry.get(&key)
417        {
418            entry.remove(&connection_key);
419        }
420        drop(out_tx);
421        let _ = writer.await;
422        return;
423    }
424
425    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
426        let ping_tx = out_tx.clone();
427        let interval = path_config.heartbeat_interval;
428        Some(tokio::spawn(async move {
429            let mut ticker = tokio::time::interval(interval);
430            loop {
431                ticker.tick().await;
432                let _ = try_send_with_backpressure(
433                    &ping_tx,
434                    WsMessage::Ping(Vec::new().into()),
435                    "heartbeat-ping",
436                );
437            }
438        }))
439    } else {
440        None
441    };
442
443    loop {
444        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
445            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
446                Ok(msg) => msg,
447                Err(_) => {
448                    try_send_with_backpressure(
449                        &out_tx,
450                        WsMessage::Close(Some(CloseFrame {
451                            code: CloseCode::from(1000u16),
452                            reason: "idle timeout".into(),
453                        })),
454                        "idle-timeout-close",
455                    );
456                    break;
457                }
458            }
459        } else {
460            stream.next().await
461        };
462
463        let Some(msg) = next_msg else {
464            break;
465        };
466
467        match msg {
468            Ok(WsMessage::Ping(data)) => {
469                tracing::debug!(
470                    connection_key = connection_key,
471                    path = path,
472                    "WebSocket ping received — sending pong"
473                );
474                let _ = try_send_with_backpressure(
475                    &out_tx,
476                    WsMessage::Pong(data),
477                    "ping-pong-response",
478                );
479            }
480            Ok(WsMessage::Pong(_)) => {
481                tracing::debug!(
482                    connection_key = connection_key,
483                    path = path,
484                    "WebSocket pong received"
485                );
486            }
487            Ok(WsMessage::Text(text)) => {
488                if text.len() > path_config.max_message_size as usize {
489                    try_send_with_backpressure(
490                        &out_tx,
491                        WsMessage::Close(Some(CloseFrame {
492                            code: CloseCode::from(1009u16),
493                            reason: "message too large".into(),
494                        })),
495                        "max-message-size-close-text",
496                    );
497                    break;
498                }
499
500                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
501                message.set_header(
502                    "CamelWsConnectionKey",
503                    serde_json::Value::String(connection_key.clone()),
504                );
505                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
506                message.set_header(
507                    "CamelWsRemoteAddress",
508                    serde_json::Value::String(remote_addr.clone()),
509                );
510
511                #[allow(unused_mut)]
512                let mut exchange = Exchange::new(message);
513                #[cfg(feature = "otel")]
514                {
515                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
516                }
517                if env_tx
518                    .send(ExchangeEnvelope {
519                        exchange,
520                        reply_tx: None,
521                    })
522                    .await
523                    .is_err()
524                {
525                    break;
526                }
527            }
528            Ok(WsMessage::Binary(data)) => {
529                if data.len() > path_config.max_message_size as usize {
530                    try_send_with_backpressure(
531                        &out_tx,
532                        WsMessage::Close(Some(CloseFrame {
533                            code: CloseCode::from(1009u16),
534                            reason: "message too large".into(),
535                        })),
536                        "max-message-size-close-binary",
537                    );
538                    break;
539                }
540
541                let mut message = CamelMessage::new(CamelBody::Bytes(data));
542                message.set_header(
543                    "CamelWsConnectionKey",
544                    serde_json::Value::String(connection_key.clone()),
545                );
546                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
547                message.set_header(
548                    "CamelWsRemoteAddress",
549                    serde_json::Value::String(remote_addr.clone()),
550                );
551
552                #[allow(unused_mut)]
553                let mut exchange = Exchange::new(message);
554                #[cfg(feature = "otel")]
555                {
556                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
557                }
558                if env_tx
559                    .send(ExchangeEnvelope {
560                        exchange,
561                        reply_tx: None,
562                    })
563                    .await
564                    .is_err()
565                {
566                    break;
567                }
568            }
569            Ok(WsMessage::Close(cf)) => {
570                let reason = cf
571                    .as_ref()
572                    .map(|f| f.reason.to_string())
573                    .unwrap_or_default();
574                tracing::info!(
575                    connection_key = connection_key,
576                    path = path,
577                    reason = reason,
578                    "WebSocket connection closed by peer"
579                );
580                break;
581            }
582            Err(e) => {
583                tracing::warn!(
584                    connection_key = connection_key,
585                    path = path,
586                    error = %e,
587                    "WebSocket receive error"
588                );
589                break;
590            }
591        }
592    }
593
594    if let Some(task) = heartbeat_task {
595        task.abort();
596    }
597
598    if let Some(key) = registry_key
599        && let Some(entry) = registry.get(&key)
600    {
601        entry.remove(&connection_key);
602    }
603    drop(out_tx);
604    let _ = writer.await;
605
606    tracing::info!(
607        connection_key = connection_key,
608        path = path,
609        "WebSocket connection closed"
610    );
611}
612
613pub struct WsComponent {
614    pub(crate) config: WsConfig,
615}
616
617impl WsComponent {
618    pub fn new() -> Self {
619        Self {
620            config: WsConfig::default(),
621        }
622    }
623
624    pub fn with_config(config: WsConfig) -> Self {
625        Self { config }
626    }
627}
628
629impl Default for WsComponent {
630    fn default() -> Self {
631        Self::new()
632    }
633}
634
635impl Component for WsComponent {
636    fn scheme(&self) -> &str {
637        "ws"
638    }
639
640    fn create_endpoint(
641        &self,
642        uri: &str,
643        _ctx: &dyn camel_component_api::ComponentContext,
644    ) -> Result<Box<dyn Endpoint>, CamelError> {
645        self.config.validate()?;
646        let mut cfg = WsEndpointConfig::from_uri(uri)?;
647        if let Some(v) = self.config.max_connections {
648            cfg.max_connections = v;
649        }
650        if let Some(v) = self.config.max_message_size {
651            cfg.max_message_size = v;
652        }
653        if let Some(v) = self.config.heartbeat_interval_ms {
654            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
655        }
656        if let Some(v) = self.config.idle_timeout_ms {
657            cfg.idle_timeout = std::time::Duration::from_millis(v);
658        }
659        if let Some(v) = self.config.connect_timeout_ms {
660            cfg.connect_timeout = std::time::Duration::from_millis(v);
661        }
662        if let Some(v) = self.config.response_timeout_ms {
663            cfg.response_timeout = std::time::Duration::from_millis(v);
664        }
665        if let Some(v) = self.config.send_timeout_ms {
666            cfg.send_timeout = std::time::Duration::from_millis(v);
667        }
668        if let Some(v) = self.config.binary_payload {
669            cfg.binary_payload = v;
670        }
671        if let Some(ref v) = self.config.subprotocols {
672            cfg.subprotocols = v.clone();
673        }
674        Ok(Box::new(WsEndpoint {
675            uri: uri.to_string(),
676            cfg,
677        }))
678    }
679}
680
681pub struct WssComponent {
682    pub(crate) config: WsConfig,
683}
684
685impl WssComponent {
686    pub fn new() -> Self {
687        Self {
688            config: WsConfig::default(),
689        }
690    }
691
692    pub fn with_config(config: WsConfig) -> Self {
693        Self { config }
694    }
695}
696
697impl Default for WssComponent {
698    fn default() -> Self {
699        Self::new()
700    }
701}
702
703impl Component for WssComponent {
704    fn scheme(&self) -> &str {
705        "wss"
706    }
707
708    fn create_endpoint(
709        &self,
710        uri: &str,
711        _ctx: &dyn camel_component_api::ComponentContext,
712    ) -> Result<Box<dyn Endpoint>, CamelError> {
713        self.config.validate()?;
714        let mut cfg = WsEndpointConfig::from_uri(uri)?;
715        if let Some(v) = self.config.max_connections {
716            cfg.max_connections = v;
717        }
718        if let Some(v) = self.config.max_message_size {
719            cfg.max_message_size = v;
720        }
721        if let Some(v) = self.config.heartbeat_interval_ms {
722            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
723        }
724        if let Some(v) = self.config.idle_timeout_ms {
725            cfg.idle_timeout = std::time::Duration::from_millis(v);
726        }
727        if let Some(v) = self.config.connect_timeout_ms {
728            cfg.connect_timeout = std::time::Duration::from_millis(v);
729        }
730        if let Some(v) = self.config.response_timeout_ms {
731            cfg.response_timeout = std::time::Duration::from_millis(v);
732        }
733        if let Some(v) = self.config.send_timeout_ms {
734            cfg.send_timeout = std::time::Duration::from_millis(v);
735        }
736        if let Some(v) = self.config.binary_payload {
737            cfg.binary_payload = v;
738        }
739        if let Some(ref v) = self.config.subprotocols {
740            cfg.subprotocols = v.clone();
741        }
742        Ok(Box::new(WsEndpoint {
743            uri: uri.to_string(),
744            cfg,
745        }))
746    }
747}
748
749struct WsEndpoint {
750    uri: String,
751    cfg: WsEndpointConfig,
752}
753
754impl Endpoint for WsEndpoint {
755    fn uri(&self) -> &str {
756        &self.uri
757    }
758
759    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
760        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
761    }
762
763    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
764        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
765    }
766}
767
768pub struct WsConsumer {
769    cfg: WsServerConfig,
770    registry: Arc<WsConnectionRegistry>,
771    server_state: Option<WsAppState>,
772    registry_key: Option<(String, u16, String)>,
773    forward_task: Option<JoinHandle<()>>,
774}
775
776impl WsConsumer {
777    pub fn new(cfg: WsServerConfig) -> Self {
778        Self {
779            cfg,
780            registry: Arc::new(WsConnectionRegistry::new()),
781            server_state: None,
782            registry_key: None,
783            forward_task: None,
784        }
785    }
786}
787
788#[async_trait]
789impl Consumer for WsConsumer {
790    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
791        // Reject double-start (WS-006)
792        if self.server_state.is_some() {
793            return Err(CamelError::EndpointCreationFailed(
794                "WebSocket consumer already started".into(),
795            ));
796        }
797
798        tracing::info!(
799            host = self.cfg.inner.host,
800            port = self.cfg.inner.port,
801            path = self.cfg.inner.path,
802            scheme = self.cfg.inner.scheme,
803            "WebSocket consumer starting"
804        );
805
806        let tls_config = if self.cfg.inner.scheme == "wss" {
807            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
808                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
809            })?;
810            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
811                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
812            })?;
813            Some(WsTlsConfig {
814                cert_path,
815                key_path,
816            })
817        } else {
818            None
819        };
820
821        let state = ServerRegistry::global()
822            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
823            .await?;
824
825        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
826        {
827            let mut table = state.dispatch.write().await;
828            table.insert(self.cfg.inner.path.clone(), env_tx);
829        }
830
831        state.path_configs.insert(
832            self.cfg.inner.path.clone(),
833            WsPathConfig {
834                max_connections: self.cfg.inner.max_connections,
835                max_message_size: self.cfg.inner.max_message_size,
836                heartbeat_interval: self.cfg.inner.heartbeat_interval,
837                idle_timeout: self.cfg.inner.idle_timeout,
838                allow_origin: self.cfg.inner.allow_origin.clone(),
839            },
840        );
841
842        let registry_key = (
843            self.cfg.inner.canonical_host(),
844            self.cfg.inner.port,
845            self.cfg.inner.path.clone(),
846        );
847        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
848
849        let sender = ctx.sender();
850        let forward_task = tokio::spawn(async move {
851            while let Some(envelope) = env_rx.recv().await {
852                if sender.send(envelope).await.is_err() {
853                    break;
854                }
855            }
856        });
857
858        self.server_state = Some(state);
859        self.registry_key = Some(registry_key);
860        self.forward_task = Some(forward_task);
861        Ok(())
862    }
863
864    async fn stop(&mut self) -> Result<(), CamelError> {
865        tracing::info!(
866            host = self.cfg.inner.host,
867            port = self.cfg.inner.port,
868            path = self.cfg.inner.path,
869            "WebSocket consumer stopping"
870        );
871
872        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
873            code: axum::extract::ws::CloseCode::from(1001u16),
874            reason: "consumer stopping".into(),
875        }));
876        for tx in self.registry.snapshot_senders() {
877            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
878        }
879
880        let mut had_server_error = false;
881
882        if let Some(state) = self.server_state.take() {
883            had_server_error = state.server_error.load(Ordering::Relaxed);
884            let mut table = state.dispatch.write().await;
885            table.remove(&self.cfg.inner.path);
886            state.path_configs.remove(&self.cfg.inner.path);
887        }
888
889        if let Some(key) = self.registry_key.take() {
890            global_registries().remove(&key);
891            ServerRegistry::global().release(key.1);
892        }
893
894        if let Some(task) = self.forward_task.take() {
895            task.abort();
896        }
897
898        tracing::info!(
899            host = self.cfg.inner.host,
900            port = self.cfg.inner.port,
901            path = self.cfg.inner.path,
902            "WebSocket consumer stopped"
903        );
904
905        if had_server_error {
906            tracing::warn!(
907                host = self.cfg.inner.host,
908                port = self.cfg.inner.port,
909                path = self.cfg.inner.path,
910                "WebSocket server had errors during its lifetime"
911            );
912            return Err(CamelError::ProcessorError(
913                "WebSocket server terminated with errors during its lifetime".into(),
914            ));
915        }
916
917        Ok(())
918    }
919
920    fn concurrency_model(&self) -> ConcurrencyModel {
921        ConcurrencyModel::Concurrent {
922            max: Some(self.cfg.inner.max_connections as usize),
923        }
924    }
925}
926
927use std::sync::atomic::{AtomicBool, Ordering};
928
/// Allocates a fresh, uniquely owned shared flag initialised to `false`.
fn new_atomic_false() -> Arc<AtomicBool> {
    let flag = AtomicBool::new(false);
    Arc::new(flag)
}
932
933#[derive(Clone)]
934pub struct WsProducer {
935    cfg: WsClientConfig,
936    /// Shared flag set by the async future when server-send hits backpressure,
937    /// so that the next `poll_ready` call can return an error. (WS-003)
938    backpressure_flag: Arc<AtomicBool>,
939}
940
941impl WsProducer {
942    pub fn new(cfg: WsClientConfig) -> Self {
943        Self {
944            cfg,
945            backpressure_flag: Arc::new(AtomicBool::new(false)),
946        }
947    }
948}
949
950impl Service<Exchange> for WsProducer {
951    type Response = Exchange;
952    type Error = CamelError;
953    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
954
955    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
956        // Return error if last server-send hit backpressure (WS-003)
957        if self.backpressure_flag.swap(false, Ordering::Relaxed) {
958            return Poll::Ready(Err(CamelError::ProcessorError(
959                "WebSocket producer backpressure: previous send was dropped due to full channel"
960                    .into(),
961            )));
962        }
963        Poll::Ready(Ok(()))
964    }
965
966    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
967        let cfg = self.cfg.clone();
968        let backpressure_flag = Arc::clone(&self.backpressure_flag);
969
970        Box::pin(async move {
971            let canonical_host = cfg.inner.canonical_host();
972            let key = (
973                canonical_host.clone(),
974                cfg.inner.port,
975                cfg.inner.path.clone(),
976            );
977
978            let send_to_all = exchange
979                .input
980                .header("CamelWsSendToAll")
981                .and_then(|v| v.as_bool())
982                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
983                .unwrap_or(false);
984
985            let conn_keys_header = exchange
986                .input
987                .header("CamelWsConnectionKey")
988                .and_then(|v| v.as_str())
989                .map(str::to_string);
990
991            let local_exists = global_registries().contains_key(&key);
992            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
993
994            let message_type = exchange
995                .input
996                .header("CamelWsMessageType")
997                .and_then(|v| v.as_str())
998                .unwrap_or("text")
999                .to_ascii_lowercase();
1000
1001            if server_send_mode {
1002                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
1003                let Some(registry) = registry else {
1004                    return Err(CamelError::ProcessorError(format!(
1005                        "WebSocket local consumer not found for {}:{}{}",
1006                        canonical_host, cfg.inner.port, cfg.inner.path
1007                    )));
1008                };
1009
1010                let out_msg = body_to_axum_ws_message(
1011                    std::mem::take(&mut exchange.input.body),
1012                    &message_type,
1013                )
1014                .await?;
1015
1016                let targets = if send_to_all {
1017                    registry.snapshot_senders()
1018                } else if let Some(keys) = conn_keys_header {
1019                    let parsed: Vec<String> = keys
1020                        .split(',')
1021                        .map(str::trim)
1022                        .filter(|k| !k.is_empty())
1023                        .map(|k| k.to_string())
1024                        .collect();
1025                    registry.get_senders_for_keys(&parsed)
1026                } else {
1027                    registry.snapshot_senders()
1028                };
1029
1030                let mut dropped = 0usize;
1031                for tx in &targets {
1032                    if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1033                        dropped += 1;
1034                    }
1035                }
1036
1037                if dropped > 0 {
1038                    tracing::warn!(
1039                        host = canonical_host,
1040                        port = cfg.inner.port,
1041                        path = cfg.inner.path,
1042                        dropped,
1043                        total = targets.len(),
1044                        "WebSocket producer dropped messages due to backpressure"
1045                    );
1046                    exchange.input.set_header(
1047                        "CamelWsDeliveryDropped",
1048                        serde_json::Value::Number(dropped.into()),
1049                    );
1050                    // Signal backpressure for next poll_ready call (WS-003)
1051                    backpressure_flag.store(true, Ordering::Relaxed);
1052                    if dropped == targets.len() {
1053                        return Err(CamelError::ProcessorError(format!(
1054                            "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1055                        )));
1056                    }
1057                }
1058
1059                tracing::debug!(
1060                    host = canonical_host,
1061                    port = cfg.inner.port,
1062                    path = cfg.inner.path,
1063                    targets = targets.len(),
1064                    "WebSocket producer server-send complete"
1065                );
1066
1067                return Ok(exchange);
1068            }
1069
1070            let url = format!(
1071                "{}://{}:{}{}",
1072                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1073            );
1074
1075            tracing::debug!(url = url, "WebSocket producer connecting");
1076
1077            #[allow(unused_mut)]
1078            let mut request = url
1079                .clone()
1080                .into_client_request()
1081                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1082
1083            #[cfg(feature = "otel")]
1084            {
1085                let mut otel_headers = HashMap::new();
1086                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1087                for (k, v) in otel_headers {
1088                    if let (Ok(name), Ok(val)) = (
1089                        http::header::HeaderName::from_bytes(k.as_bytes()),
1090                        http::header::HeaderValue::from_str(&v),
1091                    ) {
1092                        request.headers_mut().insert(name, val);
1093                    }
1094                }
1095            }
1096
1097            // Add Sec-WebSocket-Protocol header if subprotocols configured (WS-007)
1098            if !cfg.inner.subprotocols.is_empty() {
1099                let proto_value = cfg.inner.subprotocols.join(", ");
1100                if let (Ok(name), Ok(val)) = (
1101                    http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
1102                    http::header::HeaderValue::from_str(&proto_value),
1103                ) {
1104                    request.headers_mut().insert(name, val);
1105                }
1106            }
1107
1108            // Determine message type: respect binary_payload config (WS-018)
1109            let effective_message_type = if cfg.inner.binary_payload {
1110                "binary"
1111            } else {
1112                &message_type
1113            };
1114
1115            let max_retries = if cfg.inner.reconnect {
1116                cfg.inner.reconnect_max_attempts as usize
1117            } else {
1118                0
1119            };
1120            let mut attempts = 0usize;
1121            let mut ws_stream = loop {
1122                let connect_future = tokio_tungstenite::connect_async(request.clone());
1123                match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
1124                    Ok(Ok((stream, _))) => break stream,
1125                    Ok(Err(e)) => {
1126                        let err = map_connect_error(e, &url);
1127                        let is_transient = err.to_string().contains("connection refused")
1128                            || err.to_string().contains("timeout")
1129                            || err.to_string().contains("connection failed");
1130                        if is_transient && attempts < max_retries {
1131                            attempts += 1;
1132                            tracing::warn!(
1133                                url = url,
1134                                error = %err,
1135                                attempt = attempts,
1136                                max_retries,
1137                                "WebSocket connect failed — retrying"
1138                            );
1139                            tokio::time::sleep(std::time::Duration::from_millis(
1140                                cfg.inner.reconnect_delay_ms,
1141                            ))
1142                            .await;
1143                            continue;
1144                        }
1145                        return Err(err);
1146                    }
1147                    Err(_) => {
1148                        let err = CamelError::ProcessorError(format!(
1149                            "WebSocket connect timeout ({:?}) to {url}",
1150                            cfg.inner.connect_timeout
1151                        ));
1152                        if attempts < max_retries {
1153                            attempts += 1;
1154                            tracing::warn!(
1155                                url = url,
1156                                attempt = attempts,
1157                                max_retries,
1158                                "WebSocket connect timeout — retrying"
1159                            );
1160                            tokio::time::sleep(std::time::Duration::from_millis(
1161                                cfg.inner.reconnect_delay_ms,
1162                            ))
1163                            .await;
1164                            continue;
1165                        }
1166                        return Err(err);
1167                    }
1168                }
1169            };
1170            if attempts > 0 {
1171                tracing::info!(url = url, "WebSocket producer connected after retry");
1172            }
1173
1174            let out_msg = body_to_client_ws_message(
1175                std::mem::take(&mut exchange.input.body),
1176                effective_message_type,
1177            )
1178            .await?;
1179
1180            ws_stream
1181                .send(out_msg)
1182                .await
1183                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1184
1185            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1186                loop {
1187                    match ws_stream.next().await {
1188                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1189                            continue;
1190                        }
1191                        other => break other,
1192                    }
1193                }
1194            })
1195            .await
1196            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1197
1198            match incoming {
1199                Some(Ok(ClientWsMessage::Text(text))) => {
1200                    tracing::debug!(url = url, "WebSocket producer received text response");
1201                    exchange.input.body = CamelBody::Text(text.to_string());
1202                }
1203                Some(Ok(ClientWsMessage::Binary(data))) => {
1204                    tracing::debug!(url = url, "WebSocket producer received binary response");
1205                    exchange.input.body = CamelBody::Bytes(data);
1206                }
1207                Some(Ok(ClientWsMessage::Close(frame))) => {
1208                    let normal = frame
1209                        .as_ref()
1210                        .map(|f| {
1211                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1212                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1213                        })
1214                        .unwrap_or(true);
1215
1216                    if normal {
1217                        tracing::debug!(url = url, "WebSocket producer received normal close");
1218                        exchange.input.body = CamelBody::Empty;
1219                    } else if cfg.inner.reconnect && attempts < max_retries {
1220                        attempts += 1;
1221                        tracing::warn!(
1222                            url = url,
1223                            attempt = attempts,
1224                            max_retries,
1225                            "WebSocket closed by peer — reconnecting"
1226                        );
1227                        tokio::time::sleep(std::time::Duration::from_millis(
1228                            cfg.inner.reconnect_delay_ms,
1229                        ))
1230                        .await;
1231                        return Err(CamelError::ProcessorError(format!(
1232                            "WebSocket reconnect required after close: code {}",
1233                            frame.map(|f| u16::from(f.code)).unwrap_or_default()
1234                        )));
1235                    } else {
1236                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1237                        return Err(CamelError::ProcessorError(format!(
1238                            "WebSocket peer closed: code {code}"
1239                        )));
1240                    }
1241                }
1242                Some(Ok(_)) | None => {
1243                    exchange.input.body = CamelBody::Empty;
1244                }
1245                Some(Err(e)) => {
1246                    return Err(CamelError::ProcessorError(format!(
1247                        "WebSocket receive failed: {e}"
1248                    )));
1249                }
1250            }
1251
1252            let _ = ws_stream.close(None).await;
1253            tracing::debug!(url = url, "WebSocket producer connection closed");
1254            Ok(exchange)
1255        })
1256    }
1257}
1258
1259async fn body_to_axum_ws_message(
1260    body: CamelBody,
1261    message_type: &str,
1262) -> Result<WsMessage, CamelError> {
1263    match message_type {
1264        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1265        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1266    }
1267}
1268
1269async fn body_to_client_ws_message(
1270    body: CamelBody,
1271    message_type: &str,
1272) -> Result<ClientWsMessage, CamelError> {
1273    match message_type {
1274        "binary" => Ok(ClientWsMessage::Binary(
1275            body.into_bytes(10 * 1024 * 1024).await?,
1276        )),
1277        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1278    }
1279}
1280
1281async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1282    Ok(match body {
1283        CamelBody::Empty => String::new(),
1284        CamelBody::Text(s) => s,
1285        CamelBody::Xml(s) => s,
1286        CamelBody::Json(v) => v.to_string(),
1287        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1288        CamelBody::Stream(stream) => {
1289            let bytes = CamelBody::Stream(stream)
1290                .into_bytes(10 * 1024 * 1024)
1291                .await?;
1292            String::from_utf8_lossy(&bytes).to_string()
1293        }
1294    })
1295}
1296
/// Checks a request's `Origin` header against the configured `allowOrigin`.
///
/// `"*"` allows every request, including ones without an `Origin` header.
/// Otherwise the header must be present and equal to `allowed_origin`.
/// The comparison is ASCII case-insensitive: browsers serialise the scheme and
/// host of an origin in lowercase (RFC 6454), so a configured value such as
/// `https://Example.com` would otherwise never match.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    if allowed_origin == "*" {
        return true;
    }
    request_origin.is_some_and(|origin| origin.eq_ignore_ascii_case(allowed_origin))
}
1303
1304fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1305    match tx.try_send(msg) {
1306        Ok(()) => true,
1307        Err(error) => {
1308            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1309            false
1310        }
1311    }
1312}
1313
1314fn load_tls_config(
1315    cert_path: &str,
1316    key_path: &str,
1317) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1318    use std::fs::File;
1319    use std::io::BufReader;
1320
1321    let cert_file = File::open(cert_path)
1322        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1323    let key_file = File::open(key_path)
1324        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1325
1326    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1327        .collect::<Result<Vec<_>, _>>()
1328        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1329
1330    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1331        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1332        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1333
1334    tokio_rustls::rustls::ServerConfig::builder()
1335        .with_no_client_auth()
1336        .with_single_cert(certs, key)
1337        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1338}
1339
1340fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1341    match err {
1342        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1343            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1344        }
1345        tungstenite::Error::Tls(_) => {
1346            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1347        }
1348        other => {
1349            let msg = other.to_string();
1350            if msg.to_lowercase().contains("connection refused") {
1351                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1352            } else if msg.to_lowercase().contains("tls") {
1353                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1354            } else {
1355                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1356            }
1357        }
1358    }
1359}
1360
1361#[cfg(test)]
1362mod tests {
1363    use super::*;
1364    use camel_component_api::NoOpComponentContext;
1365    use std::time::Duration;
1366
1367    use tokio::sync::mpsc;
1368    use tokio_tungstenite::connect_async;
1369    use tokio_tungstenite::tungstenite::Message as ClientMessage;
1370    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1371    use tokio_util::sync::CancellationToken;
1372    use tower::ServiceExt;
1373
1374    fn free_port() -> u16 {
1375        std::net::TcpListener::bind("127.0.0.1:0")
1376            .unwrap()
1377            .local_addr()
1378            .unwrap()
1379            .port()
1380    }
1381
1382    #[test]
1383    fn ws_component_scheme_is_ws() {
1384        assert_eq!(WsComponent::new().scheme(), "ws");
1385    }
1386
1387    #[test]
1388    fn wss_component_scheme_is_wss() {
1389        assert_eq!(WssComponent::new().scheme(), "wss");
1390    }
1391
1392    #[test]
1393    fn endpoint_config_defaults_match_spec() {
1394        let cfg = WsEndpointConfig::default();
1395        assert_eq!(cfg.scheme, "ws");
1396        assert_eq!(cfg.host, "0.0.0.0");
1397        assert_eq!(cfg.port, 8080);
1398        assert_eq!(cfg.path, "/");
1399        assert_eq!(cfg.max_connections, 100);
1400        assert_eq!(cfg.max_message_size, 65536);
1401        assert!(!cfg.send_to_all);
1402        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1403        assert_eq!(cfg.idle_timeout, Duration::ZERO);
1404        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1405        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1406        assert_eq!(cfg.allow_origin, "*");
1407        assert_eq!(cfg.tls_cert, None);
1408        assert_eq!(cfg.tls_key, None);
1409        assert!(cfg.reconnect);
1410        assert_eq!(cfg.reconnect_max_attempts, 5);
1411        assert_eq!(cfg.reconnect_delay_ms, 1000);
1412        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
1413        assert!(!cfg.binary_payload);
1414        assert!(cfg.subprotocols.is_empty());
1415    }
1416
1417    #[test]
1418    fn endpoint_config_parses_uri_params() {
1419        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1420        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1421
1422        assert_eq!(cfg.scheme, "ws");
1423        assert_eq!(cfg.host, "localhost");
1424        assert_eq!(cfg.port, 9001);
1425        assert_eq!(cfg.path, "/chat");
1426        assert_eq!(cfg.max_connections, 42);
1427        assert_eq!(cfg.max_message_size, 1024);
1428        assert!(cfg.send_to_all);
1429        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1430        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1431        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1432        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1433        assert_eq!(cfg.allow_origin, "https://example.com");
1434        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1435        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1436        assert!(cfg.reconnect);
1437        assert_eq!(cfg.reconnect_max_attempts, 5);
1438        assert_eq!(cfg.reconnect_delay_ms, 1000);
1439    }
1440
1441    #[test]
1442    fn endpoint_config_parses_reconnect_uri_params() {
1443        let uri =
1444            "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
1445        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1446        assert!(!cfg.reconnect);
1447        assert_eq!(cfg.reconnect_max_attempts, 2);
1448        assert_eq!(cfg.reconnect_delay_ms, 25);
1449    }
1450
1451    #[test]
1452    fn endpoint_config_override_chain_uri_overrides_defaults() {
1453        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1454        assert_eq!(cfg.max_connections, 7);
1455        assert_eq!(cfg.max_message_size, 65536);
1456        assert!(!cfg.send_to_all);
1457        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1458    }
1459
1460    #[test]
1461    fn endpoint_trait_creates_consumer_and_producer() {
1462        let ctx = NoOpComponentContext;
1463        let endpoint = WsComponent::new()
1464            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1465            .unwrap();
1466
1467        endpoint.create_consumer().unwrap();
1468        endpoint
1469            .create_producer(&ProducerContext::default())
1470            .unwrap();
1471    }
1472
1473    #[test]
1474    fn ws_consumer_concurrency_model_uses_max_connections() {
1475        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1476        let consumer = WsConsumer::new(cfg.server_config());
1477        assert_eq!(
1478            consumer.concurrency_model(),
1479            ConcurrencyModel::Concurrent { max: Some(321) }
1480        );
1481    }
1482
1483    #[tokio::test]
1484    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1485        let registry = WsConnectionRegistry::new();
1486        let (tx1, mut rx1) = mpsc::channel(8);
1487        let (tx2, mut rx2) = mpsc::channel(8);
1488
1489        registry.insert("k1".into(), tx1);
1490        registry.insert("k2".into(), tx2);
1491        assert_eq!(registry.len(), 2);
1492
1493        for tx in registry.snapshot_senders() {
1494            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1495        }
1496
1497        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1498        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1499
1500        let target = registry.get_senders_for_keys(&["k1".to_string()]);
1501        assert_eq!(target.len(), 1);
1502        target[0]
1503            .send(WsMessage::Text("targeted".into()))
1504            .await
1505            .unwrap();
1506
1507        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1508        assert!(
1509            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1510                .await
1511                .is_err()
1512        );
1513
1514        registry.remove("k1");
1515        assert_eq!(registry.len(), 1);
1516    }
1517
1518    #[test]
1519    fn host_canonicalization_maps_local_hosts_to_loopback() {
1520        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1521            .unwrap()
1522            .canonical_host();
1523        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1524            .unwrap()
1525            .canonical_host();
1526        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1527            .unwrap()
1528            .canonical_host();
1529
1530        assert_eq!(c1, "127.0.0.1");
1531        assert_eq!(c2, "127.0.0.1");
1532        assert_eq!(c3, "127.0.0.1");
1533    }
1534
1535    #[tokio::test]
1536    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1537        let port = free_port();
1538        let uri = format!("ws://127.0.0.1:{port}/echo");
1539        let component_ctx = NoOpComponentContext;
1540        let endpoint = WsComponent::new()
1541            .create_endpoint(&uri, &component_ctx)
1542            .unwrap();
1543
1544        let mut consumer = endpoint.create_consumer().unwrap();
1545        let producer = endpoint
1546            .create_producer(&ProducerContext::default())
1547            .unwrap();
1548
1549        let (route_tx, mut route_rx) = mpsc::channel(16);
1550        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1551        consumer.start(ctx).await.unwrap();
1552
1553        let route_task = tokio::spawn(async move {
1554            if let Some(envelope) = route_rx.recv().await {
1555                let payload = envelope
1556                    .exchange
1557                    .input
1558                    .body
1559                    .as_text()
1560                    .unwrap_or_default()
1561                    .to_string();
1562                let key = envelope
1563                    .exchange
1564                    .input
1565                    .header("CamelWsConnectionKey")
1566                    .and_then(|v| v.as_str())
1567                    .unwrap()
1568                    .to_string();
1569
1570                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1571                response
1572                    .input
1573                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1574                producer.oneshot(response).await.unwrap();
1575            }
1576        });
1577
1578        let url = format!("ws://127.0.0.1:{port}/echo");
1579        let (mut client, _) = loop {
1580            match connect_async(&url).await {
1581                Ok(ok) => break ok,
1582                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1583            }
1584        };
1585
1586        client
1587            .send(ClientMessage::Text("hello-ws".into()))
1588            .await
1589            .unwrap();
1590
1591        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1592            loop {
1593                match client.next().await {
1594                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1595                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1596                    Some(Ok(_)) => continue,
1597                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1598                    None => panic!("websocket closed before echo"),
1599                }
1600            }
1601        })
1602        .await
1603        .unwrap();
1604
1605        assert_eq!(incoming, "hello-ws");
1606
1607        consumer.stop().await.unwrap();
1608        route_task.await.unwrap();
1609    }
1610
1611    #[tokio::test]
1612    async fn consumer_stop_sends_close_1001() {
1613        let port = free_port();
1614        let uri = format!("ws://127.0.0.1:{port}/shutdown");
1615        let component_ctx = NoOpComponentContext;
1616        let endpoint = WsComponent::new()
1617            .create_endpoint(&uri, &component_ctx)
1618            .unwrap();
1619
1620        let mut consumer = endpoint.create_consumer().unwrap();
1621        let (route_tx, _route_rx) = mpsc::channel(16);
1622        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1623        consumer.start(ctx).await.unwrap();
1624
1625        let url = format!("ws://127.0.0.1:{port}/shutdown");
1626        let (mut client, _) = loop {
1627            match connect_async(&url).await {
1628                Ok(ok) => break ok,
1629                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1630            }
1631        };
1632
1633        client
1634            .send(ClientMessage::Text("keepalive".into()))
1635            .await
1636            .unwrap();
1637
1638        consumer.stop().await.unwrap();
1639
1640        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1641            loop {
1642                match client.next().await {
1643                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1644                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1645                    Some(Ok(_)) => continue,
1646                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1647                    None => panic!("websocket closed without close frame"),
1648                }
1649            }
1650        })
1651        .await
1652        .unwrap();
1653
1654        assert_eq!(close_code, Some(CloseCode::Away));
1655    }
1656
1657    #[test]
1658    fn wildcard_origin_allows_anything() {
1659        assert!(is_origin_allowed("*", None));
1660        assert!(is_origin_allowed("*", Some("https://example.com")));
1661    }
1662
1663    #[test]
1664    fn exact_origin_requires_match() {
1665        assert!(is_origin_allowed(
1666            "https://example.com",
1667            Some("https://example.com")
1668        ));
1669        assert!(!is_origin_allowed(
1670            "https://example.com",
1671            Some("https://other.com")
1672        ));
1673        assert!(!is_origin_allowed("https://example.com", None));
1674    }
1675
1676    #[test]
1677    fn endpoint_config_rejects_invalid_scheme() {
1678        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1679        assert!(result.is_err());
1680        let msg = result.unwrap_err().to_string();
1681        assert!(
1682            msg.contains("Invalid WebSocket scheme"),
1683            "expected scheme error, got: {msg}"
1684        );
1685    }
1686
1687    #[tokio::test]
1688    async fn wss_consumer_start_fails_without_tls_cert() {
1689        let port = free_port();
1690        let component_ctx = NoOpComponentContext;
1691        let endpoint = WssComponent::new()
1692            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1693            .unwrap();
1694        let mut consumer = endpoint.create_consumer().unwrap();
1695        let (tx, _rx) = mpsc::channel(16);
1696        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1697        let result = consumer.start(ctx).await;
1698        assert!(result.is_err());
1699        let msg = result.unwrap_err().to_string();
1700        assert!(
1701            msg.contains("TLS cert path is required"),
1702            "expected TLS cert error, got: {msg}"
1703        );
1704    }
1705
1706    #[tokio::test]
1707    async fn wss_consumer_start_fails_with_nonexistent_cert() {
1708        let port = free_port();
1709        let component_ctx = NoOpComponentContext;
1710        let endpoint = WssComponent::new()
1711            .create_endpoint(&format!(
1712                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1713            ), &component_ctx)
1714            .unwrap();
1715        let mut consumer = endpoint.create_consumer().unwrap();
1716        let (tx, _rx) = mpsc::channel(16);
1717        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1718        let result = consumer.start(ctx).await;
1719        assert!(result.is_err());
1720        let msg = result.unwrap_err().to_string();
1721        assert!(
1722            msg.contains("TLS cert file error"),
1723            "expected cert file error, got: {msg}"
1724        );
1725    }
1726
1727    #[tokio::test]
1728    async fn server_registry_returns_same_state_for_same_port() {
1729        let port = free_port();
1730        let state1 = ServerRegistry::global()
1731            .get_or_spawn("127.0.0.1", port, None)
1732            .await
1733            .unwrap();
1734        let state2 = ServerRegistry::global()
1735            .get_or_spawn("127.0.0.1", port, None)
1736            .await
1737            .unwrap();
1738        assert!(
1739            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1740            "expected same dispatch table for same port"
1741        );
1742    }
1743
1744    #[tokio::test]
1745    async fn dispatch_handler_returns_404_for_unregistered_path() {
1746        let port = free_port();
1747        let state = ServerRegistry::global()
1748            .get_or_spawn("127.0.0.1", port, None)
1749            .await
1750            .unwrap();
1751        let app = Router::new().fallback(dispatch_handler).with_state(state);
1752        let response = tokio::time::timeout(
1753            Duration::from_secs(2),
1754            tower::ServiceExt::oneshot(
1755                app,
1756                axum::http::Request::builder()
1757                    .uri("/nonexistent")
1758                    .body(Body::empty())
1759                    .unwrap(),
1760            ),
1761        )
1762        .await
1763        .unwrap()
1764        .unwrap();
1765        assert_eq!(response.status(), StatusCode::NOT_FOUND);
1766    }
1767
1768    #[tokio::test]
1769    async fn client_mode_producer_connects_and_echoes() {
1770        let app = Router::new().route(
1771            "/echo",
1772            axum::routing::get(|ws: WebSocketUpgrade| async move {
1773                ws.on_upgrade(|mut socket: WebSocket| async move {
1774                    while let Some(Ok(msg)) = socket.recv().await {
1775                        match msg {
1776                            WsMessage::Text(text) => {
1777                                let _ = socket.send(WsMessage::Text(text)).await;
1778                            }
1779                            WsMessage::Binary(data) => {
1780                                let _ = socket.send(WsMessage::Binary(data)).await;
1781                            }
1782                            WsMessage::Close(_) => break,
1783                            _ => {}
1784                        }
1785                    }
1786                })
1787            }),
1788        );
1789        // Bind to port 0 directly to avoid TOCTOU race with free_port() + re-bind
1790        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1791        let port = listener.local_addr().unwrap().port();
1792        let server_task = tokio::spawn(async move {
1793            let _ = serve(listener, app).await;
1794        });
1795
1796        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1797        let producer = WsProducer::new(cfg.client_config());
1798
1799        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1800        tokio::time::sleep(Duration::from_millis(25)).await;
1801        let result =
1802            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1803                Ok(Ok(r)) => r,
1804                Ok(Err(_)) => panic!("producer call failed"),
1805                Err(_) => panic!("producer call timed out"),
1806            };
1807
1808        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1809
1810        server_task.abort();
1811    }
1812
1813    #[tokio::test]
1814    async fn max_connections_rejects_with_close_1013() {
1815        let port = free_port();
1816        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1817        let component_ctx = NoOpComponentContext;
1818        let endpoint = WsComponent::new()
1819            .create_endpoint(&uri, &component_ctx)
1820            .unwrap();
1821        let mut consumer = endpoint.create_consumer().unwrap();
1822        let (route_tx, _route_rx) = mpsc::channel(16);
1823        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1824        consumer.start(ctx).await.unwrap();
1825
1826        let url = format!("ws://127.0.0.1:{port}/limited");
1827        let (_client1, _) = loop {
1828            match connect_async(&url).await {
1829                Ok(ok) => break ok,
1830                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1831            }
1832        };
1833
1834        tokio::time::sleep(Duration::from_millis(100)).await;
1835
1836        let (mut client2, _) = connect_async(&url).await.unwrap();
1837
1838        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1839            loop {
1840                match client2.next().await {
1841                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1842                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1843                    Some(Ok(ClientMessage::Text(_))) => continue,
1844                    Some(Ok(_)) => continue,
1845                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1846                    None => panic!("client2 closed without close frame"),
1847                }
1848            }
1849        })
1850        .await
1851        .unwrap();
1852
1853        assert_eq!(
1854            close_code,
1855            Some(CloseCode::from(1013u16)),
1856            "expected 1013 (Try Again Later) for max connections"
1857        );
1858
1859        consumer.stop().await.unwrap();
1860    }
1861
1862    #[tokio::test]
1863    async fn max_message_size_rejects_with_close_1009() {
1864        let port = free_port();
1865        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1866        let component_ctx = NoOpComponentContext;
1867        let endpoint = WsComponent::new()
1868            .create_endpoint(&uri, &component_ctx)
1869            .unwrap();
1870        let mut consumer = endpoint.create_consumer().unwrap();
1871        let (route_tx, _route_rx) = mpsc::channel(16);
1872        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1873        consumer.start(ctx).await.unwrap();
1874
1875        let url = format!("ws://127.0.0.1:{port}/sizelimit");
1876        let (mut client, _) = loop {
1877            match connect_async(&url).await {
1878                Ok(ok) => break ok,
1879                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1880            }
1881        };
1882
1883        let oversized = "x".repeat(100);
1884        client
1885            .send(ClientMessage::Text(oversized.into()))
1886            .await
1887            .unwrap();
1888
1889        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1890            loop {
1891                match client.next().await {
1892                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1893                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1894                    Some(Ok(_)) => continue,
1895                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1896                    None => panic!("websocket closed without close frame"),
1897                }
1898            }
1899        })
1900        .await
1901        .unwrap();
1902
1903        assert_eq!(
1904            close_code,
1905            Some(CloseCode::from(1009u16)),
1906            "expected 1009 (Message Too Big) for oversized message"
1907        );
1908
1909        consumer.stop().await.unwrap();
1910    }
1911
1912    #[tokio::test]
1913    async fn origin_rejection_returns_403() {
1914        let port = free_port();
1915        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1916        let component_ctx = NoOpComponentContext;
1917        let endpoint = WsComponent::new()
1918            .create_endpoint(&uri, &component_ctx)
1919            .unwrap();
1920        let mut consumer = endpoint.create_consumer().unwrap();
1921        let (route_tx, _route_rx) = mpsc::channel(16);
1922        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1923        consumer.start(ctx).await.unwrap();
1924
1925        let state = ServerRegistry::global()
1926            .get_or_spawn("127.0.0.1", port, None)
1927            .await
1928            .unwrap();
1929        let app = Router::new().fallback(dispatch_handler).with_state(state);
1930
1931        let response = tokio::time::timeout(
1932            Duration::from_secs(2),
1933            tower::ServiceExt::oneshot(
1934                app,
1935                axum::http::Request::builder()
1936                    .uri("/origintest")
1937                    .header("origin", "https://evil.com")
1938                    .header("upgrade", "websocket")
1939                    .header("connection", "Upgrade")
1940                    .header("sec-websocket-version", "13")
1941                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1942                    .body(Body::empty())
1943                    .unwrap(),
1944            ),
1945        )
1946        .await
1947        .unwrap()
1948        .unwrap();
1949
1950        assert_eq!(
1951            response.status(),
1952            StatusCode::FORBIDDEN,
1953            "expected 403 for disallowed origin"
1954        );
1955
1956        consumer.stop().await.unwrap();
1957    }
1958
1959    #[tokio::test]
1960    async fn broadcast_sends_to_all_connected_clients() {
1961        let port = free_port();
1962        let uri = format!("ws://127.0.0.1:{port}/bc");
1963        let component_ctx = NoOpComponentContext;
1964        let endpoint = WsComponent::new()
1965            .create_endpoint(&uri, &component_ctx)
1966            .unwrap();
1967        let mut consumer = endpoint.create_consumer().unwrap();
1968        let producer = endpoint
1969            .create_producer(&ProducerContext::default())
1970            .unwrap();
1971
1972        let (route_tx, _route_rx) = mpsc::channel(16);
1973        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1974        consumer.start(ctx).await.unwrap();
1975
1976        let url = format!("ws://127.0.0.1:{port}/bc");
1977
1978        let (mut client1, _) = loop {
1979            match connect_async(&url).await {
1980                Ok(ok) => break ok,
1981                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1982            }
1983        };
1984
1985        let (mut client2, _) = connect_async(&url).await.unwrap();
1986
1987        tokio::time::sleep(Duration::from_millis(100)).await;
1988
1989        let mut response =
1990            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1991        response
1992            .input
1993            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1994        producer.oneshot(response).await.unwrap();
1995
1996        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1997            loop {
1998                match client1.next().await {
1999                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2000                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2001                    _ => panic!("client1 unexpected message or close"),
2002                }
2003            }
2004        })
2005        .await
2006        .unwrap();
2007
2008        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
2009            loop {
2010                match client2.next().await {
2011                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2012                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2013                    _ => panic!("client2 unexpected message or close"),
2014                }
2015            }
2016        })
2017        .await
2018        .unwrap();
2019
2020        assert_eq!(recv1, "broadcast-msg");
2021        assert_eq!(recv2, "broadcast-msg");
2022
2023        consumer.stop().await.unwrap();
2024    }
2025
2026    #[tokio::test]
2027    async fn concurrent_get_or_spawn_returns_same_state() {
2028        let port = free_port();
2029        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
2030            Arc::new(std::sync::Mutex::new(Vec::new()));
2031
2032        let mut handles = Vec::new();
2033        for _ in 0..4 {
2034            let results = results.clone();
2035            handles.push(tokio::spawn(async move {
2036                let state = ServerRegistry::global()
2037                    .get_or_spawn("127.0.0.1", port, None)
2038                    .await
2039                    .unwrap();
2040                results.lock().unwrap().push(state);
2041            }));
2042        }
2043
2044        for h in handles {
2045            h.await.unwrap();
2046        }
2047
2048        let states = results.lock().unwrap();
2049        assert_eq!(states.len(), 4);
2050        for i in 1..states.len() {
2051            assert!(
2052                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
2053                "all concurrent callers should get the same dispatch table"
2054            );
2055        }
2056    }
2057
2058    #[tokio::test]
2059    async fn body_conversion_helpers_cover_text_and_binary_paths() {
2060        let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
2061            .await
2062            .unwrap();
2063        assert!(matches!(text_msg, WsMessage::Text(_)));
2064
2065        let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
2066            .await
2067            .unwrap();
2068        assert!(matches!(bin_msg, WsMessage::Binary(_)));
2069
2070        let client_text =
2071            body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
2072                .await
2073                .unwrap();
2074        assert!(matches!(client_text, ClientWsMessage::Text(_)));
2075
2076        let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
2077            .await
2078            .unwrap();
2079        assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
2080    }
2081
2082    #[tokio::test]
2083    async fn body_to_text_handles_empty_text_json_and_bytes() {
2084        assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2085        assert_eq!(
2086            body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2087            "hello"
2088        );
2089        assert_eq!(
2090            body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2091                .await
2092                .unwrap(),
2093            "{\"n\":1}"
2094        );
2095        assert_eq!(
2096            body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2097                .await
2098                .unwrap(),
2099            "hi"
2100        );
2101    }
2102
2103    #[test]
2104    fn try_send_with_backpressure_returns_false_when_channel_full() {
2105        let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2106        assert!(try_send_with_backpressure(
2107            &tx,
2108            WsMessage::Text("first".into()),
2109            "test"
2110        ));
2111        assert!(!try_send_with_backpressure(
2112            &tx,
2113            WsMessage::Text("second".into()),
2114            "test"
2115        ));
2116    }
2117
2118    #[test]
2119    fn map_connect_error_formats_connection_refused_and_generic_errors() {
2120        let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2121        let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2122        assert!(err.to_string().contains("WebSocket connection refused"));
2123
2124        let generic = map_connect_error(
2125            tungstenite::Error::Protocol(
2126                tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2127            ),
2128            "ws://localhost:2/y",
2129        );
2130        assert!(
2131            generic
2132                .to_string()
2133                .contains("WebSocket connection failed (ws://localhost:2/y)")
2134        );
2135    }
2136
2137    // === Phase B Finding Tests ===
2138
2139    // WS-015: maxConnections=0 must be rejected
2140    #[test]
2141    fn from_uri_rejects_max_connections_zero() {
2142        let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2143        assert!(result.is_err());
2144        let msg = result.unwrap_err().to_string();
2145        assert!(
2146            msg.contains("maxConnections must be >= 1"),
2147            "expected maxConnections validation error, got: {msg}"
2148        );
2149    }
2150
2151    // WS-019: maxMessageSize=0 must be rejected
2152    #[test]
2153    fn from_uri_rejects_max_message_size_zero() {
2154        let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2155        assert!(result.is_err());
2156        let msg = result.unwrap_err().to_string();
2157        assert!(
2158            msg.contains("maxMessageSize must be > 0"),
2159            "expected maxMessageSize validation error, got: {msg}"
2160        );
2161    }
2162
2163    // WS-020: allowOrigin="" must be rejected
2164    #[test]
2165    fn from_uri_rejects_empty_allow_origin() {
2166        let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2167        assert!(result.is_err());
2168        let msg = result.unwrap_err().to_string();
2169        assert!(
2170            msg.contains("allowOrigin must not be empty"),
2171            "expected allowOrigin validation error, got: {msg}"
2172        );
2173    }
2174
2175    // WS-006: Double-start must be rejected
2176    #[tokio::test]
2177    async fn consumer_double_start_returns_error() {
2178        let port = free_port();
2179        let uri = format!("ws://127.0.0.1:{port}/doublestart");
2180        let component_ctx = NoOpComponentContext;
2181        let endpoint = WsComponent::new()
2182            .create_endpoint(&uri, &component_ctx)
2183            .unwrap();
2184
2185        let mut consumer = endpoint.create_consumer().unwrap();
2186        let (route_tx, _route_rx) = mpsc::channel(16);
2187        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2188
2189        // First start should succeed
2190        consumer.start(ctx).await.unwrap();
2191
2192        // Second start should fail
2193        let (route_tx2, _route_rx2) = mpsc::channel(16);
2194        let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2195        let result = consumer.start(ctx2).await;
2196        assert!(result.is_err());
2197        let msg = result.unwrap_err().to_string();
2198        assert!(
2199            msg.contains("already started"),
2200            "expected double-start error, got: {msg}"
2201        );
2202
2203        consumer.stop().await.unwrap();
2204    }
2205
2206    // WS-005: Registry cleanup on stop + port reuse
2207    #[tokio::test]
2208    async fn registry_cleanup_on_consumer_stop() {
2209        let port = free_port();
2210        let uri = format!("ws://127.0.0.1:{port}/cleanup");
2211        let component_ctx = NoOpComponentContext;
2212        let endpoint = WsComponent::new()
2213            .create_endpoint(&uri, &component_ctx)
2214            .unwrap();
2215
2216        let mut consumer = endpoint.create_consumer().unwrap();
2217        let (route_tx, _route_rx) = mpsc::channel(16);
2218        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2219        consumer.start(ctx).await.unwrap();
2220
2221        // Verify registry entry exists
2222        let registries = global_registries();
2223        let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2224        assert!(
2225            registries.contains_key(&key),
2226            "registry should have entry after start"
2227        );
2228
2229        // Stop consumer
2230        consumer.stop().await.unwrap();
2231
2232        // Verify registry entry is removed
2233        assert!(
2234            !registries.contains_key(&key),
2235            "registry should be cleaned up after stop"
2236        );
2237
2238        // Verify server registry reference is released (port can be reused)
2239        // The ServerRegistry should have removed the entry
2240        let server_reg = ServerRegistry::global();
2241        let guard = server_reg.inner.lock().unwrap();
2242        assert!(
2243            !guard.contains_key(&port),
2244            "ServerRegistry should remove port entry after last consumer stops"
2245        );
2246    }
2247
2248    // WS-003 + WS-004: poll_ready backpressure and server-send error handling
2249    #[tokio::test]
2250    async fn producer_server_send_returns_error_when_all_dropped() {
2251        let port = free_port();
2252        let uri = format!("ws://127.0.0.1:{port}/backpressure");
2253        let component_ctx = NoOpComponentContext;
2254        let endpoint = WsComponent::new()
2255            .create_endpoint(&uri, &component_ctx)
2256            .unwrap();
2257
2258        let mut consumer = endpoint.create_consumer().unwrap();
2259        let producer = endpoint
2260            .create_producer(&ProducerContext::default())
2261            .unwrap();
2262
2263        let (route_tx, _route_rx) = mpsc::channel(1); // Tiny channel to force backpressure
2264        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2265        consumer.start(ctx).await.unwrap();
2266
2267        // Connect a client so the registry has an entry
2268        let url = format!("ws://127.0.0.1:{port}/backpressure");
2269        let (mut client, _) = loop {
2270            match connect_async(&url).await {
2271                Ok(ok) => break ok,
2272                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2273            }
2274        };
2275
2276        // Don't consume messages — let the channel fill up
2277        tokio::time::sleep(Duration::from_millis(50)).await;
2278
2279        // Flood the channel to trigger backpressure
2280        let mut all_dropped = false;
2281        for _ in 0..100 {
2282            let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2283            match producer.clone().oneshot(exchange).await {
2284                Ok(_) => {}
2285                Err(e) => {
2286                    if e.to_string().contains("backpressure") {
2287                        all_dropped = true;
2288                        break;
2289                    }
2290                }
2291            }
2292        }
2293
2294        // The producer should eventually return a backpressure error
2295        assert!(
2296            all_dropped,
2297            "producer should return error when all messages are dropped due to backpressure"
2298        );
2299
2300        // Clean up
2301        let _ = client.close(None).await;
2302        consumer.stop().await.unwrap();
2303    }
2304
2305    // WS-012: Ping/pong round-trip in server mode
2306    #[tokio::test]
2307    async fn server_responds_to_client_ping_with_pong() {
2308        let port = free_port();
2309        let uri = format!("ws://127.0.0.1:{port}/pingpong");
2310        let component_ctx = NoOpComponentContext;
2311        let endpoint = WsComponent::new()
2312            .create_endpoint(&uri, &component_ctx)
2313            .unwrap();
2314
2315        let mut consumer = endpoint.create_consumer().unwrap();
2316        let (route_tx, _route_rx) = mpsc::channel(16);
2317        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2318        consumer.start(ctx).await.unwrap();
2319
2320        let url = format!("ws://127.0.0.1:{port}/pingpong");
2321        let (mut client, _) = loop {
2322            match connect_async(&url).await {
2323                Ok(ok) => break ok,
2324                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2325            }
2326        };
2327
2328        // Send a ping
2329        client
2330            .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2331            .await
2332            .unwrap();
2333
2334        // Expect a pong with the same payload
2335        let pong = tokio::time::timeout(Duration::from_secs(2), async {
2336            loop {
2337                match client.next().await {
2338                    Some(Ok(ClientMessage::Pong(data))) => break data,
2339                    Some(Ok(ClientMessage::Ping(_))) => continue,
2340                    Some(Ok(_)) => continue,
2341                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2342                    None => panic!("websocket closed before pong"),
2343                }
2344            }
2345        })
2346        .await
2347        .unwrap();
2348
2349        assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2350
2351        consumer.stop().await.unwrap();
2352    }
2353
2354    // WS-008: Client-side retry on transient connect failures
2355    #[tokio::test]
2356    async fn producer_retries_on_connection_refused() {
2357        // Use a port that nothing is listening on
2358        let port = free_port();
2359        // Ensure nothing is on this port
2360        let cfg = WsEndpointConfig::from_uri(&format!(
2361            "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2362        ))
2363        .unwrap();
2364        let producer = WsProducer::new(cfg.client_config());
2365
2366        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2367
2368        // Should fail after retries (nothing listening)
2369        let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2370        assert!(
2371            result.is_ok(),
2372            "producer should complete (with error) within timeout"
2373        );
2374        let result = result.unwrap();
2375        assert!(
2376            result.is_err(),
2377            "producer should fail when nothing is listening"
2378        );
2379        let msg = result.unwrap_err().to_string();
2380        assert!(
2381            msg.contains("connection refused"),
2382            "expected connection refused error, got: {msg}"
2383        );
2384    }
2385
2386    // WS-001: Server bind error is visible (fake server-start error test)
2387    #[tokio::test]
2388    async fn server_bind_error_is_reported() {
2389        // Bind a port manually to cause a conflict
2390        let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2391        let port = _listener.local_addr().unwrap().port();
2392
2393        // Try to start a consumer on the same port — should succeed since axum binds lazily
2394        // The actual bind error happens when the server task runs
2395        let uri = format!("ws://127.0.0.1:{port}/binderror");
2396        let component_ctx = NoOpComponentContext;
2397        let endpoint = WsComponent::new()
2398            .create_endpoint(&uri, &component_ctx)
2399            .unwrap();
2400
2401        let mut consumer = endpoint.create_consumer().unwrap();
2402        let (route_tx, _route_rx) = mpsc::channel(16);
2403        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2404
2405        // Start should succeed (server spawns, but bind may fail)
2406        let start_result = consumer.start(ctx).await;
2407        // The server may or may not have bound yet — this test verifies no panic
2408        // The actual error is logged by the server task
2409        let _ = start_result;
2410
2411        consumer.stop().await.unwrap();
2412    }
2413
2414    #[test]
2415    fn ws_app_state_server_error_starts_false() {
2416        let state = WsAppState {
2417            dispatch: Arc::new(RwLock::new(HashMap::new())),
2418            path_configs: Arc::new(DashMap::new()),
2419            server_error: new_atomic_false(),
2420        };
2421        assert!(
2422            !state.server_error.load(Ordering::Relaxed),
2423            "server_error should start as false"
2424        );
2425    }
2426
2427    #[test]
2428    fn ws_app_state_server_error_can_be_set() {
2429        let state = WsAppState {
2430            dispatch: Arc::new(RwLock::new(HashMap::new())),
2431            path_configs: Arc::new(DashMap::new()),
2432            server_error: new_atomic_false(),
2433        };
2434        assert!(!state.server_error.load(Ordering::Relaxed));
2435        state.server_error.store(true, Ordering::Relaxed);
2436        assert!(state.server_error.load(Ordering::Relaxed));
2437    }
2438
2439    #[tokio::test]
2440    async fn consumer_stop_returns_error_when_server_had_errors() {
2441        let port = free_port();
2442        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2443        let mut consumer = WsConsumer::new(cfg.server_config());
2444        let (route_tx, _route_rx) = mpsc::channel(16);
2445        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2446        consumer.start(ctx).await.unwrap();
2447
2448        // Simulate server error by setting the flag directly
2449        if let Some(ref state) = consumer.server_state {
2450            state.server_error.store(true, Ordering::Relaxed);
2451        }
2452
2453        let result = consumer.stop().await;
2454        assert!(
2455            result.is_err(),
2456            "stop should return error when server had errors"
2457        );
2458        let msg = result.unwrap_err().to_string();
2459        assert!(
2460            msg.contains("terminated with errors"),
2461            "expected server error message, got: {msg}"
2462        );
2463    }
2464
2465    #[tokio::test]
2466    async fn consumer_stop_succeeds_when_server_healthy() {
2467        let port = free_port();
2468        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2469        let mut consumer = WsConsumer::new(cfg.server_config());
2470        let (route_tx, _route_rx) = mpsc::channel(16);
2471        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2472        consumer.start(ctx).await.unwrap();
2473
2474        let result = consumer.stop().await;
2475        assert!(
2476            result.is_ok(),
2477            "stop should succeed when server is healthy: {:?}",
2478            result
2479        );
2480    }
2481
2482    // === H-10 Finding Tests ===
2483
2484    // WS-007: subprotocol negotiation support
2485    #[test]
2486    fn endpoint_config_parses_subprotocols() {
2487        let cfg = WsEndpointConfig::from_uri(
2488            "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws",
2489        )
2490        .unwrap();
2491        assert_eq!(cfg.subprotocols, vec!["graphql-ws", "graphql-transport-ws"]);
2492    }
2493
2494    #[test]
2495    fn endpoint_config_default_subprotocols_empty() {
2496        let cfg = WsEndpointConfig::default();
2497        assert!(cfg.subprotocols.is_empty());
2498    }
2499
2500    // WS-017: sendTimeoutMs URI option
2501    #[test]
2502    fn endpoint_config_parses_send_timeout() {
2503        let cfg =
2504            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=5000").unwrap();
2505        assert_eq!(cfg.send_timeout, Duration::from_millis(5000));
2506    }
2507
2508    #[test]
2509    fn endpoint_config_default_send_timeout() {
2510        let cfg = WsEndpointConfig::default();
2511        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
2512    }
2513
2514    #[test]
2515    fn endpoint_config_rejects_invalid_send_timeout() {
2516        let err =
2517            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=abc").unwrap_err();
2518        assert!(err.to_string().contains("sendTimeoutMs"));
2519    }
2520
2521    // WS-018: binaryPayload URI option
2522    #[test]
2523    fn endpoint_config_parses_binary_payload() {
2524        let cfg =
2525            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=true").unwrap();
2526        assert!(cfg.binary_payload);
2527    }
2528
2529    #[test]
2530    fn endpoint_config_default_binary_payload_false() {
2531        let cfg = WsEndpointConfig::default();
2532        assert!(!cfg.binary_payload);
2533    }
2534
2535    #[test]
2536    fn endpoint_config_rejects_invalid_binary_payload() {
2537        let err =
2538            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=yes").unwrap_err();
2539        assert!(err.to_string().contains("binaryPayload"));
2540    }
2541}