Skip to main content

camel_component_ws/
lib.rs

1//! WebSocket component for rust-camel — Axum-based WebSocket server and Tokio-tungstenite client for bidirectional messaging.
2//!
3//! Main types: `WsComponent`, `WsBundle`, `WsConfig`, `WsServerConfig`, `WsClientConfig`, `WsEndpointConfig`.
4//! Main modules: `bundle`, `config`.
5
6pub mod bundle;
7pub mod config;
8
9pub use bundle::WsBundle;
10pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
11
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex, OnceLock};
14
15use async_trait::async_trait;
16use axum::body::Body;
17use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
18use axum::extract::{FromRequest, Request, State};
19use axum::http::{StatusCode, header};
20use axum::response::IntoResponse;
21use axum::{Router, serve};
22use camel_component_api::{
23    Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
24};
25use camel_component_api::{
26    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
27    ProducerContext,
28};
29use dashmap::DashMap;
30use futures::{SinkExt, StreamExt};
31use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio::sync::{OnceCell, RwLock, mpsc};
35use tokio::task::JoinHandle;
36use tokio_tungstenite::tungstenite;
37use tokio_tungstenite::tungstenite::client::IntoClientRequest;
38use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
39use tower::Service;
40
41#[derive(Clone)]
42struct WsPathConfig {
43    max_connections: u32,
44    max_message_size: u32,
45    heartbeat_interval: std::time::Duration,
46    idle_timeout: std::time::Duration,
47    allow_origin: String,
48}
49
50impl Default for WsPathConfig {
51    fn default() -> Self {
52        let cfg = WsEndpointConfig::default();
53        Self {
54            max_connections: cfg.max_connections,
55            max_message_size: cfg.max_message_size,
56            heartbeat_interval: cfg.heartbeat_interval,
57            idle_timeout: cfg.idle_timeout,
58            allow_origin: cfg.allow_origin,
59        }
60    }
61}
62
63#[derive(Clone)]
64struct WsTlsConfig {
65    cert_path: String,
66    key_path: String,
67}
68
69type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
70
71struct ServerHandle {
72    state: WsAppState,
73    is_tls: bool,
74    _task: JoinHandle<()>,
75}
76
77struct ServerRegistryInner {
78    cell: Arc<OnceCell<ServerHandle>>,
79    ref_count: usize,
80}
81
82pub struct ServerRegistry {
83    inner: Mutex<HashMap<u16, ServerRegistryInner>>,
84}
85
86impl ServerRegistry {
87    pub fn global() -> &'static Self {
88        static REG: OnceLock<ServerRegistry> = OnceLock::new();
89        REG.get_or_init(|| Self {
90            inner: Mutex::new(HashMap::new()),
91        })
92    }
93
94    pub(crate) async fn get_or_spawn(
95        &'static self,
96        host: &str,
97        port: u16,
98        tls_config: Option<WsTlsConfig>,
99    ) -> Result<WsAppState, CamelError> {
100        let wants_tls = tls_config.is_some();
101        let host_owned = host.to_string();
102
103        let (cell, _is_new) = {
104            let mut guard = self.inner.lock().map_err(|_| {
105                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
106            })?;
107            let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
108                cell: Arc::new(OnceCell::new()),
109                ref_count: 0,
110            });
111            entry.ref_count += 1;
112            (entry.cell.clone(), entry.ref_count == 1)
113        };
114
115        let handle = cell
116            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
117            .await?;
118
119        if wants_tls != handle.is_tls {
120            // Decrement ref count since we're rejecting this caller
121            let mut guard = self.inner.lock().map_err(|_| {
122                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
123            })?;
124            if let Some(entry) = guard.get_mut(&port) {
125                entry.ref_count -= 1;
126                if entry.ref_count == 0 {
127                    guard.remove(&port);
128                }
129            }
130            return Err(CamelError::EndpointCreationFailed(format!(
131                "Server on port {port} already running with different TLS mode"
132            )));
133        }
134
135        Ok(handle.state.clone())
136    }
137
138    /// Release a reference to the server on the given port.
139    /// When the last reference is released, the server task is aborted and the entry removed. (WS-005)
140    pub(crate) fn release(&self, port: u16) {
141        let mut guard = match self.inner.lock() {
142            Ok(g) => g,
143            Err(_) => return,
144        };
145        if let Some(entry) = guard.get_mut(&port) {
146            entry.ref_count = entry.ref_count.saturating_sub(1);
147            if entry.ref_count == 0 {
148                // Abort the server task if it exists (WS-001, WS-005)
149                if let Some(handle) = entry.cell.get() {
150                    handle._task.abort();
151                }
152                guard.remove(&port);
153                tracing::info!(port, "WebSocket server registry entry removed");
154            }
155        }
156    }
157}
158
159async fn spawn_server(
160    host: &str,
161    port: u16,
162    tls_config: Option<WsTlsConfig>,
163) -> Result<ServerHandle, CamelError> {
164    let host_owned = host.to_string();
165    let addr = format!("{host}:{port}");
166    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
167    let path_configs = Arc::new(DashMap::new());
168    let server_error = new_atomic_false();
169    let state = WsAppState {
170        dispatch: Arc::clone(&dispatch),
171        path_configs: Arc::clone(&path_configs),
172        server_error: Arc::clone(&server_error),
173    };
174    let app = Router::new()
175        .fallback(dispatch_handler)
176        .with_state(state.clone());
177
178    let (task, is_tls) = if let Some(ref tls) = tls_config {
179        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
180        let parsed_addr = addr.parse().map_err(|e| {
181            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
182        })?;
183        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
184        let error_flag = Arc::clone(&server_error);
185        let task = tokio::spawn(async move {
186            if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
187                .serve(app.into_make_service())
188                .await
189            {
190                tracing::error!(
191                    host = host_owned,
192                    port = port,
193                    error = %e,
194                    "WebSocket server terminated with error"
195                );
196                error_flag.store(true, Ordering::Relaxed);
197            }
198        });
199        (task, true)
200    } else {
201        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
202            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
203        })?;
204        let error_flag = Arc::clone(&server_error);
205        let task = tokio::spawn(async move {
206            if let Err(e) = serve(listener, app).await {
207                tracing::error!(
208                    host = host_owned,
209                    port = port,
210                    error = %e,
211                    "WebSocket server terminated with error"
212                );
213                error_flag.store(true, Ordering::Relaxed);
214            }
215        });
216        (task, false)
217    };
218
219    tracing::info!(host, port, is_tls, "WebSocket server started");
220
221    Ok(ServerHandle {
222        state,
223        is_tls,
224        _task: task,
225    })
226}
227
228#[derive(Clone)]
229struct WsAppState {
230    dispatch: DispatchTable,
231    path_configs: Arc<DashMap<String, WsPathConfig>>,
232    server_error: Arc<AtomicBool>,
233}
234
235pub struct WsConnectionRegistry {
236    connections: DashMap<String, mpsc::Sender<WsMessage>>,
237}
238
239static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
240    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
241> = OnceLock::new();
242
243fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
244    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
245}
246
247impl Default for WsConnectionRegistry {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl WsConnectionRegistry {
254    pub fn new() -> Self {
255        Self {
256            connections: DashMap::new(),
257        }
258    }
259
260    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
261        self.connections.insert(key, tx);
262    }
263
264    pub fn remove(&self, key: &str) {
265        self.connections.remove(key);
266    }
267
268    pub fn len(&self) -> usize {
269        self.connections.len()
270    }
271
272    pub fn is_empty(&self) -> bool {
273        self.connections.is_empty()
274    }
275
276    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
277        self.connections.iter().map(|e| e.value().clone()).collect()
278    }
279
280    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
281        keys.iter()
282            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
283            .collect()
284    }
285}
286
287async fn dispatch_handler(
288    State(state): State<WsAppState>,
289    req: Request<Body>,
290) -> impl IntoResponse {
291    let path = req.uri().path().to_string();
292    let origin = req
293        .headers()
294        .get(header::ORIGIN)
295        .and_then(|value| value.to_str().ok())
296        .map(str::to_string);
297    let remote_addr = req
298        .extensions()
299        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
300        .map(|ci| ci.0.to_string())
301        .unwrap_or_default();
302    let table = state.dispatch.read().await;
303    if !table.contains_key(&path) {
304        return (
305            StatusCode::NOT_FOUND,
306            "no ws endpoint registered for this path",
307        )
308            .into_response();
309    }
310    drop(table);
311
312    let path_config = state
313        .path_configs
314        .get(&path)
315        .map(|entry| entry.value().clone())
316        .unwrap_or_default();
317    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
318        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
319    }
320
321    let upgrade_headers: HashMap<String, String> = req
322        .headers()
323        .iter()
324        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
325        .collect();
326
327    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
328        Ok(ws) => ws,
329        Err(_) => {
330            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
331        }
332    };
333
334    ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
335        .into_response()
336}
337
338#[allow(unused_variables)]
339async fn ws_handler(
340    socket: WebSocket,
341    state: WsAppState,
342    path: String,
343    remote_addr: String,
344    upgrade_headers: HashMap<String, String>,
345) {
346    let connection_key = uuid::Uuid::new_v4().to_string();
347    let path_config = state
348        .path_configs
349        .get(&path)
350        .map(|entry| entry.value().clone())
351        .unwrap_or_default();
352
353    let env_tx = {
354        let table = state.dispatch.read().await;
355        table.get(&path).cloned()
356    };
357    let Some(env_tx) = env_tx else {
358        return;
359    };
360
361    let (mut sink, mut stream) = socket.split();
362    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
363
364    let registry = global_registries();
365    let mut registry_key = None;
366    for entry in registry.iter() {
367        if entry.key().2 == path {
368            entry.value().insert(connection_key.clone(), out_tx.clone());
369            registry_key = Some(entry.key().clone());
370            break;
371        }
372    }
373
374    // Clone for writer closure and subsequent tracing (WS-009)
375    let conn_key_for_writer = connection_key.clone();
376    let path_for_writer = path.clone();
377
378    let writer = tokio::spawn(async move {
379        while let Some(msg) = out_rx.recv().await {
380            if let Err(e) = sink.send(msg).await {
381                tracing::warn!(
382                    connection_key = conn_key_for_writer,
383                    path = path_for_writer,
384                    error = %e,
385                    "WebSocket writer send error — closing connection"
386                );
387                break;
388            }
389        }
390    });
391
392    tracing::info!(
393        connection_key = connection_key,
394        path = path,
395        remote_addr = remote_addr,
396        "WebSocket connection opened"
397    );
398
399    let mut over_limit = false;
400    if let Some(key) = &registry_key
401        && let Some(entry) = registry.get(key)
402        && entry.len() > path_config.max_connections as usize
403    {
404        over_limit = true;
405    }
406    if over_limit {
407        try_send_with_backpressure(
408            &out_tx,
409            WsMessage::Close(Some(CloseFrame {
410                code: CloseCode::from(1013u16),
411                reason: "max connections exceeded".into(),
412            })),
413            "max-connections-close",
414        );
415        if let Some(key) = registry_key.clone()
416            && let Some(entry) = registry.get(&key)
417        {
418            entry.remove(&connection_key);
419        }
420        drop(out_tx);
421        let _ = writer.await;
422        return;
423    }
424
425    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
426        let ping_tx = out_tx.clone();
427        let interval = path_config.heartbeat_interval;
428        Some(tokio::spawn(async move {
429            let mut ticker = tokio::time::interval(interval);
430            loop {
431                ticker.tick().await;
432                let _ = try_send_with_backpressure(
433                    &ping_tx,
434                    WsMessage::Ping(Vec::new().into()),
435                    "heartbeat-ping",
436                );
437            }
438        }))
439    } else {
440        None
441    };
442
443    loop {
444        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
445            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
446                Ok(msg) => msg,
447                Err(_) => {
448                    try_send_with_backpressure(
449                        &out_tx,
450                        WsMessage::Close(Some(CloseFrame {
451                            code: CloseCode::from(1000u16),
452                            reason: "idle timeout".into(),
453                        })),
454                        "idle-timeout-close",
455                    );
456                    break;
457                }
458            }
459        } else {
460            stream.next().await
461        };
462
463        let Some(msg) = next_msg else {
464            break;
465        };
466
467        match msg {
468            Ok(WsMessage::Ping(data)) => {
469                tracing::debug!(
470                    connection_key = connection_key,
471                    path = path,
472                    "WebSocket ping received — sending pong"
473                );
474                let _ = try_send_with_backpressure(
475                    &out_tx,
476                    WsMessage::Pong(data),
477                    "ping-pong-response",
478                );
479            }
480            Ok(WsMessage::Pong(_)) => {
481                tracing::debug!(
482                    connection_key = connection_key,
483                    path = path,
484                    "WebSocket pong received"
485                );
486            }
487            Ok(WsMessage::Text(text)) => {
488                if text.len() > path_config.max_message_size as usize {
489                    try_send_with_backpressure(
490                        &out_tx,
491                        WsMessage::Close(Some(CloseFrame {
492                            code: CloseCode::from(1009u16),
493                            reason: "message too large".into(),
494                        })),
495                        "max-message-size-close-text",
496                    );
497                    break;
498                }
499
500                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
501                message.set_header(
502                    "CamelWsConnectionKey",
503                    serde_json::Value::String(connection_key.clone()),
504                );
505                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
506                message.set_header(
507                    "CamelWsRemoteAddress",
508                    serde_json::Value::String(remote_addr.clone()),
509                );
510
511                #[allow(unused_mut)]
512                let mut exchange = Exchange::new(message);
513                #[cfg(feature = "otel")]
514                {
515                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
516                }
517                if env_tx
518                    .send(ExchangeEnvelope {
519                        exchange,
520                        reply_tx: None,
521                    })
522                    .await
523                    .is_err()
524                {
525                    break;
526                }
527            }
528            Ok(WsMessage::Binary(data)) => {
529                if data.len() > path_config.max_message_size as usize {
530                    try_send_with_backpressure(
531                        &out_tx,
532                        WsMessage::Close(Some(CloseFrame {
533                            code: CloseCode::from(1009u16),
534                            reason: "message too large".into(),
535                        })),
536                        "max-message-size-close-binary",
537                    );
538                    break;
539                }
540
541                let mut message = CamelMessage::new(CamelBody::Bytes(data));
542                message.set_header(
543                    "CamelWsConnectionKey",
544                    serde_json::Value::String(connection_key.clone()),
545                );
546                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
547                message.set_header(
548                    "CamelWsRemoteAddress",
549                    serde_json::Value::String(remote_addr.clone()),
550                );
551
552                #[allow(unused_mut)]
553                let mut exchange = Exchange::new(message);
554                #[cfg(feature = "otel")]
555                {
556                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
557                }
558                if env_tx
559                    .send(ExchangeEnvelope {
560                        exchange,
561                        reply_tx: None,
562                    })
563                    .await
564                    .is_err()
565                {
566                    break;
567                }
568            }
569            Ok(WsMessage::Close(cf)) => {
570                let reason = cf
571                    .as_ref()
572                    .map(|f| f.reason.to_string())
573                    .unwrap_or_default();
574                tracing::info!(
575                    connection_key = connection_key,
576                    path = path,
577                    reason = reason,
578                    "WebSocket connection closed by peer"
579                );
580                break;
581            }
582            Err(e) => {
583                tracing::warn!(
584                    connection_key = connection_key,
585                    path = path,
586                    error = %e,
587                    "WebSocket receive error"
588                );
589                break;
590            }
591        }
592    }
593
594    if let Some(task) = heartbeat_task {
595        task.abort();
596    }
597
598    if let Some(key) = registry_key
599        && let Some(entry) = registry.get(&key)
600    {
601        entry.remove(&connection_key);
602    }
603    drop(out_tx);
604    let _ = writer.await;
605
606    tracing::info!(
607        connection_key = connection_key,
608        path = path,
609        "WebSocket connection closed"
610    );
611}
612
613pub struct WsComponent {
614    pub(crate) config: WsConfig,
615}
616
617impl WsComponent {
618    pub fn new() -> Self {
619        Self {
620            config: WsConfig::default(),
621        }
622    }
623
624    pub fn with_config(config: WsConfig) -> Self {
625        Self { config }
626    }
627}
628
629impl Default for WsComponent {
630    fn default() -> Self {
631        Self::new()
632    }
633}
634
635impl Component for WsComponent {
636    fn scheme(&self) -> &str {
637        "ws"
638    }
639
640    fn create_endpoint(
641        &self,
642        uri: &str,
643        _ctx: &dyn camel_component_api::ComponentContext,
644    ) -> Result<Box<dyn Endpoint>, CamelError> {
645        let mut cfg = WsEndpointConfig::from_uri(uri)?;
646        if let Some(v) = self.config.max_connections {
647            cfg.max_connections = v;
648        }
649        if let Some(v) = self.config.max_message_size {
650            cfg.max_message_size = v;
651        }
652        if let Some(v) = self.config.heartbeat_interval_ms {
653            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
654        }
655        if let Some(v) = self.config.idle_timeout_ms {
656            cfg.idle_timeout = std::time::Duration::from_millis(v);
657        }
658        if let Some(v) = self.config.connect_timeout_ms {
659            cfg.connect_timeout = std::time::Duration::from_millis(v);
660        }
661        if let Some(v) = self.config.response_timeout_ms {
662            cfg.response_timeout = std::time::Duration::from_millis(v);
663        }
664        Ok(Box::new(WsEndpoint {
665            uri: uri.to_string(),
666            cfg,
667        }))
668    }
669}
670
671pub struct WssComponent {
672    pub(crate) config: WsConfig,
673}
674
675impl WssComponent {
676    pub fn new() -> Self {
677        Self {
678            config: WsConfig::default(),
679        }
680    }
681
682    pub fn with_config(config: WsConfig) -> Self {
683        Self { config }
684    }
685}
686
687impl Default for WssComponent {
688    fn default() -> Self {
689        Self::new()
690    }
691}
692
693impl Component for WssComponent {
694    fn scheme(&self) -> &str {
695        "wss"
696    }
697
698    fn create_endpoint(
699        &self,
700        uri: &str,
701        _ctx: &dyn camel_component_api::ComponentContext,
702    ) -> Result<Box<dyn Endpoint>, CamelError> {
703        let mut cfg = WsEndpointConfig::from_uri(uri)?;
704        if let Some(v) = self.config.max_connections {
705            cfg.max_connections = v;
706        }
707        if let Some(v) = self.config.max_message_size {
708            cfg.max_message_size = v;
709        }
710        if let Some(v) = self.config.heartbeat_interval_ms {
711            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
712        }
713        if let Some(v) = self.config.idle_timeout_ms {
714            cfg.idle_timeout = std::time::Duration::from_millis(v);
715        }
716        if let Some(v) = self.config.connect_timeout_ms {
717            cfg.connect_timeout = std::time::Duration::from_millis(v);
718        }
719        if let Some(v) = self.config.response_timeout_ms {
720            cfg.response_timeout = std::time::Duration::from_millis(v);
721        }
722        Ok(Box::new(WsEndpoint {
723            uri: uri.to_string(),
724            cfg,
725        }))
726    }
727}
728
729struct WsEndpoint {
730    uri: String,
731    cfg: WsEndpointConfig,
732}
733
734impl Endpoint for WsEndpoint {
735    fn uri(&self) -> &str {
736        &self.uri
737    }
738
739    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
740        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
741    }
742
743    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
744        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
745    }
746}
747
748pub struct WsConsumer {
749    cfg: WsServerConfig,
750    registry: Arc<WsConnectionRegistry>,
751    server_state: Option<WsAppState>,
752    registry_key: Option<(String, u16, String)>,
753    forward_task: Option<JoinHandle<()>>,
754}
755
756impl WsConsumer {
757    pub fn new(cfg: WsServerConfig) -> Self {
758        Self {
759            cfg,
760            registry: Arc::new(WsConnectionRegistry::new()),
761            server_state: None,
762            registry_key: None,
763            forward_task: None,
764        }
765    }
766}
767
768#[async_trait]
769impl Consumer for WsConsumer {
770    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
771        // Reject double-start (WS-006)
772        if self.server_state.is_some() {
773            return Err(CamelError::EndpointCreationFailed(
774                "WebSocket consumer already started".into(),
775            ));
776        }
777
778        tracing::info!(
779            host = self.cfg.inner.host,
780            port = self.cfg.inner.port,
781            path = self.cfg.inner.path,
782            scheme = self.cfg.inner.scheme,
783            "WebSocket consumer starting"
784        );
785
786        let tls_config = if self.cfg.inner.scheme == "wss" {
787            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
788                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
789            })?;
790            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
791                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
792            })?;
793            Some(WsTlsConfig {
794                cert_path,
795                key_path,
796            })
797        } else {
798            None
799        };
800
801        let state = ServerRegistry::global()
802            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
803            .await?;
804
805        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
806        {
807            let mut table = state.dispatch.write().await;
808            table.insert(self.cfg.inner.path.clone(), env_tx);
809        }
810
811        state.path_configs.insert(
812            self.cfg.inner.path.clone(),
813            WsPathConfig {
814                max_connections: self.cfg.inner.max_connections,
815                max_message_size: self.cfg.inner.max_message_size,
816                heartbeat_interval: self.cfg.inner.heartbeat_interval,
817                idle_timeout: self.cfg.inner.idle_timeout,
818                allow_origin: self.cfg.inner.allow_origin.clone(),
819            },
820        );
821
822        let registry_key = (
823            self.cfg.inner.canonical_host(),
824            self.cfg.inner.port,
825            self.cfg.inner.path.clone(),
826        );
827        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
828
829        let sender = ctx.sender();
830        let forward_task = tokio::spawn(async move {
831            while let Some(envelope) = env_rx.recv().await {
832                if sender.send(envelope).await.is_err() {
833                    break;
834                }
835            }
836        });
837
838        self.server_state = Some(state);
839        self.registry_key = Some(registry_key);
840        self.forward_task = Some(forward_task);
841        Ok(())
842    }
843
844    async fn stop(&mut self) -> Result<(), CamelError> {
845        tracing::info!(
846            host = self.cfg.inner.host,
847            port = self.cfg.inner.port,
848            path = self.cfg.inner.path,
849            "WebSocket consumer stopping"
850        );
851
852        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
853            code: axum::extract::ws::CloseCode::from(1001u16),
854            reason: "consumer stopping".into(),
855        }));
856        for tx in self.registry.snapshot_senders() {
857            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
858        }
859
860        let mut had_server_error = false;
861
862        if let Some(state) = self.server_state.take() {
863            had_server_error = state.server_error.load(Ordering::Relaxed);
864            let mut table = state.dispatch.write().await;
865            table.remove(&self.cfg.inner.path);
866            state.path_configs.remove(&self.cfg.inner.path);
867        }
868
869        if let Some(key) = self.registry_key.take() {
870            global_registries().remove(&key);
871            ServerRegistry::global().release(key.1);
872        }
873
874        if let Some(task) = self.forward_task.take() {
875            task.abort();
876        }
877
878        tracing::info!(
879            host = self.cfg.inner.host,
880            port = self.cfg.inner.port,
881            path = self.cfg.inner.path,
882            "WebSocket consumer stopped"
883        );
884
885        if had_server_error {
886            tracing::warn!(
887                host = self.cfg.inner.host,
888                port = self.cfg.inner.port,
889                path = self.cfg.inner.path,
890                "WebSocket server had errors during its lifetime"
891            );
892            return Err(CamelError::ProcessorError(
893                "WebSocket server terminated with errors during its lifetime".into(),
894            ));
895        }
896
897        Ok(())
898    }
899
900    fn concurrency_model(&self) -> ConcurrencyModel {
901        ConcurrencyModel::Concurrent {
902            max: Some(self.cfg.inner.max_connections as usize),
903        }
904    }
905}
906
907use std::sync::atomic::{AtomicBool, Ordering};
908
/// Allocates a fresh, unshared flag that starts out cleared (`false`).
fn new_atomic_false() -> Arc<AtomicBool> {
    Arc::from(AtomicBool::default())
}
912
913#[derive(Clone)]
914pub struct WsProducer {
915    cfg: WsClientConfig,
916    /// Shared flag set by the async future when server-send hits backpressure,
917    /// so that the next `poll_ready` call can return an error. (WS-003)
918    backpressure_flag: Arc<AtomicBool>,
919}
920
921impl WsProducer {
922    pub fn new(cfg: WsClientConfig) -> Self {
923        Self {
924            cfg,
925            backpressure_flag: Arc::new(AtomicBool::new(false)),
926        }
927    }
928}
929
930impl Service<Exchange> for WsProducer {
931    type Response = Exchange;
932    type Error = CamelError;
933    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
934
935    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
936        // Return error if last server-send hit backpressure (WS-003)
937        if self.backpressure_flag.swap(false, Ordering::Relaxed) {
938            return Poll::Ready(Err(CamelError::ProcessorError(
939                "WebSocket producer backpressure: previous send was dropped due to full channel"
940                    .into(),
941            )));
942        }
943        Poll::Ready(Ok(()))
944    }
945
946    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
947        let cfg = self.cfg.clone();
948        let backpressure_flag = Arc::clone(&self.backpressure_flag);
949
950        Box::pin(async move {
951            let canonical_host = cfg.inner.canonical_host();
952            let key = (
953                canonical_host.clone(),
954                cfg.inner.port,
955                cfg.inner.path.clone(),
956            );
957
958            let send_to_all = exchange
959                .input
960                .header("CamelWsSendToAll")
961                .and_then(|v| v.as_bool())
962                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
963                .unwrap_or(false);
964
965            let conn_keys_header = exchange
966                .input
967                .header("CamelWsConnectionKey")
968                .and_then(|v| v.as_str())
969                .map(str::to_string);
970
971            let local_exists = global_registries().contains_key(&key);
972            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
973
974            let message_type = exchange
975                .input
976                .header("CamelWsMessageType")
977                .and_then(|v| v.as_str())
978                .unwrap_or("text")
979                .to_ascii_lowercase();
980
981            if server_send_mode {
982                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
983                let Some(registry) = registry else {
984                    return Err(CamelError::ProcessorError(format!(
985                        "WebSocket local consumer not found for {}:{}{}",
986                        canonical_host, cfg.inner.port, cfg.inner.path
987                    )));
988                };
989
990                let out_msg = body_to_axum_ws_message(
991                    std::mem::take(&mut exchange.input.body),
992                    &message_type,
993                )
994                .await?;
995
996                let targets = if send_to_all {
997                    registry.snapshot_senders()
998                } else if let Some(keys) = conn_keys_header {
999                    let parsed: Vec<String> = keys
1000                        .split(',')
1001                        .map(str::trim)
1002                        .filter(|k| !k.is_empty())
1003                        .map(|k| k.to_string())
1004                        .collect();
1005                    registry.get_senders_for_keys(&parsed)
1006                } else {
1007                    registry.snapshot_senders()
1008                };
1009
1010                let mut dropped = 0usize;
1011                for tx in &targets {
1012                    if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1013                        dropped += 1;
1014                    }
1015                }
1016
1017                if dropped > 0 {
1018                    tracing::warn!(
1019                        host = canonical_host,
1020                        port = cfg.inner.port,
1021                        path = cfg.inner.path,
1022                        dropped,
1023                        total = targets.len(),
1024                        "WebSocket producer dropped messages due to backpressure"
1025                    );
1026                    exchange.input.set_header(
1027                        "CamelWsDeliveryDropped",
1028                        serde_json::Value::Number(dropped.into()),
1029                    );
1030                    // Signal backpressure for next poll_ready call (WS-003)
1031                    backpressure_flag.store(true, Ordering::Relaxed);
1032                    if dropped == targets.len() {
1033                        return Err(CamelError::ProcessorError(format!(
1034                            "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1035                        )));
1036                    }
1037                }
1038
1039                tracing::debug!(
1040                    host = canonical_host,
1041                    port = cfg.inner.port,
1042                    path = cfg.inner.path,
1043                    targets = targets.len(),
1044                    "WebSocket producer server-send complete"
1045                );
1046
1047                return Ok(exchange);
1048            }
1049
1050            let url = format!(
1051                "{}://{}:{}{}",
1052                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1053            );
1054
1055            tracing::debug!(url = url, "WebSocket producer connecting");
1056
1057            #[allow(unused_mut)]
1058            let mut request = url
1059                .clone()
1060                .into_client_request()
1061                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1062
1063            #[cfg(feature = "otel")]
1064            {
1065                let mut otel_headers = HashMap::new();
1066                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1067                for (k, v) in otel_headers {
1068                    if let (Ok(name), Ok(val)) = (
1069                        http::header::HeaderName::from_bytes(k.as_bytes()),
1070                        http::header::HeaderValue::from_str(&v),
1071                    ) {
1072                        request.headers_mut().insert(name, val);
1073                    }
1074                }
1075            }
1076
1077            // Retry transient connect/open failures before write occurs (WS-008)
1078            let max_retries = 3usize;
1079            let mut retries_left = max_retries;
1080            let mut last_err: Option<CamelError> = None;
1081            let mut ws_stream = loop {
1082                let connect_future = tokio_tungstenite::connect_async(request.clone());
1083                match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
1084                    Ok(Ok((stream, _))) => break stream,
1085                    Ok(Err(e)) => {
1086                        let err = map_connect_error(e, &url);
1087                        // Only retry transient connect failures (connection refused, timeout)
1088                        let is_transient = err.to_string().contains("connection refused")
1089                            || err.to_string().contains("timeout");
1090                        if retries_left > 0 && is_transient {
1091                            tracing::warn!(
1092                                url = url,
1093                                error = %err,
1094                                retries_left,
1095                                "WebSocket connect failed — retrying"
1096                            );
1097                            last_err = Some(err);
1098                            retries_left -= 1;
1099                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1100                            continue;
1101                        }
1102                        return Err(err);
1103                    }
1104                    Err(_) => {
1105                        let err = CamelError::ProcessorError(format!(
1106                            "WebSocket connect timeout ({:?}) to {url}",
1107                            cfg.inner.connect_timeout
1108                        ));
1109                        if retries_left > 0 {
1110                            tracing::warn!(
1111                                url = url,
1112                                retries_left,
1113                                "WebSocket connect timeout — retrying"
1114                            );
1115                            last_err = Some(err);
1116                            retries_left -= 1;
1117                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1118                            continue;
1119                        }
1120                        return Err(err);
1121                    }
1122                }
1123            };
1124            if let Some(ref _err) = last_err {
1125                tracing::info!(url = url, "WebSocket producer connected after retry");
1126            }
1127
1128            let out_msg =
1129                body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
1130                    .await?;
1131
1132            ws_stream
1133                .send(out_msg)
1134                .await
1135                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1136
1137            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1138                loop {
1139                    match ws_stream.next().await {
1140                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1141                            continue;
1142                        }
1143                        other => break other,
1144                    }
1145                }
1146            })
1147            .await
1148            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1149
1150            match incoming {
1151                Some(Ok(ClientWsMessage::Text(text))) => {
1152                    tracing::debug!(url = url, "WebSocket producer received text response");
1153                    exchange.input.body = CamelBody::Text(text.to_string());
1154                }
1155                Some(Ok(ClientWsMessage::Binary(data))) => {
1156                    tracing::debug!(url = url, "WebSocket producer received binary response");
1157                    exchange.input.body = CamelBody::Bytes(data);
1158                }
1159                Some(Ok(ClientWsMessage::Close(frame))) => {
1160                    let normal = frame
1161                        .as_ref()
1162                        .map(|f| {
1163                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1164                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1165                        })
1166                        .unwrap_or(true);
1167
1168                    if normal {
1169                        tracing::debug!(url = url, "WebSocket producer received normal close");
1170                        exchange.input.body = CamelBody::Empty;
1171                    } else {
1172                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1173                        return Err(CamelError::ProcessorError(format!(
1174                            "WebSocket peer closed: code {code}"
1175                        )));
1176                    }
1177                }
1178                Some(Ok(_)) | None => {
1179                    exchange.input.body = CamelBody::Empty;
1180                }
1181                Some(Err(e)) => {
1182                    return Err(CamelError::ProcessorError(format!(
1183                        "WebSocket receive failed: {e}"
1184                    )));
1185                }
1186            }
1187
1188            let _ = ws_stream.close(None).await;
1189            tracing::debug!(url = url, "WebSocket producer connection closed");
1190            Ok(exchange)
1191        })
1192    }
1193}
1194
1195async fn body_to_axum_ws_message(
1196    body: CamelBody,
1197    message_type: &str,
1198) -> Result<WsMessage, CamelError> {
1199    match message_type {
1200        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1201        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1202    }
1203}
1204
1205async fn body_to_client_ws_message(
1206    body: CamelBody,
1207    message_type: &str,
1208) -> Result<ClientWsMessage, CamelError> {
1209    match message_type {
1210        "binary" => Ok(ClientWsMessage::Binary(
1211            body.into_bytes(10 * 1024 * 1024).await?,
1212        )),
1213        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1214    }
1215}
1216
1217async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1218    Ok(match body {
1219        CamelBody::Empty => String::new(),
1220        CamelBody::Text(s) => s,
1221        CamelBody::Xml(s) => s,
1222        CamelBody::Json(v) => v.to_string(),
1223        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1224        CamelBody::Stream(stream) => {
1225            let bytes = CamelBody::Stream(stream)
1226                .into_bytes(10 * 1024 * 1024)
1227                .await?;
1228            String::from_utf8_lossy(&bytes).to_string()
1229        }
1230    })
1231}
1232
/// Decides whether a WebSocket upgrade carrying `request_origin` is accepted.
///
/// `allowed_origin` is the endpoint's `allowOrigin` setting:
/// * `"*"` accepts every request, including ones without an `Origin` header.
/// * Otherwise it is a single origin or a comma-separated list of origins. The
///   request must carry an `Origin` header equal to one of them.
///
/// Comparison is ASCII case-insensitive, because origin schemes and hosts are
/// case-insensitive (RFC 6454). Whitespace around configured entries is ignored.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    let allowed_origin = allowed_origin.trim();
    if allowed_origin == "*" {
        return true;
    }
    let Some(origin) = request_origin else {
        return false;
    };
    let origin = origin.trim();
    allowed_origin
        .split(',')
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
        .any(|candidate| candidate.eq_ignore_ascii_case(origin))
}
1239
1240fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1241    match tx.try_send(msg) {
1242        Ok(()) => true,
1243        Err(error) => {
1244            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1245            false
1246        }
1247    }
1248}
1249
1250fn load_tls_config(
1251    cert_path: &str,
1252    key_path: &str,
1253) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1254    use std::fs::File;
1255    use std::io::BufReader;
1256
1257    let cert_file = File::open(cert_path)
1258        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1259    let key_file = File::open(key_path)
1260        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1261
1262    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1263        .collect::<Result<Vec<_>, _>>()
1264        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1265
1266    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1267        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1268        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1269
1270    tokio_rustls::rustls::ServerConfig::builder()
1271        .with_no_client_auth()
1272        .with_single_cert(certs, key)
1273        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1274}
1275
1276fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1277    match err {
1278        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1279            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1280        }
1281        tungstenite::Error::Tls(_) => {
1282            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1283        }
1284        other => {
1285            let msg = other.to_string();
1286            if msg.to_lowercase().contains("connection refused") {
1287                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1288            } else if msg.to_lowercase().contains("tls") {
1289                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1290            } else {
1291                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1292            }
1293        }
1294    }
1295}
1296
1297#[cfg(test)]
1298mod tests {
1299    use super::*;
1300    use camel_component_api::NoOpComponentContext;
1301    use std::time::Duration;
1302
1303    use tokio::sync::mpsc;
1304    use tokio_tungstenite::connect_async;
1305    use tokio_tungstenite::tungstenite::Message as ClientMessage;
1306    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1307    use tokio_util::sync::CancellationToken;
1308    use tower::ServiceExt;
1309
1310    fn free_port() -> u16 {
1311        std::net::TcpListener::bind("127.0.0.1:0")
1312            .unwrap()
1313            .local_addr()
1314            .unwrap()
1315            .port()
1316    }
1317
1318    #[test]
1319    fn ws_component_scheme_is_ws() {
1320        assert_eq!(WsComponent::new().scheme(), "ws");
1321    }
1322
1323    #[test]
1324    fn wss_component_scheme_is_wss() {
1325        assert_eq!(WssComponent::new().scheme(), "wss");
1326    }
1327
1328    #[test]
1329    fn endpoint_config_defaults_match_spec() {
1330        let cfg = WsEndpointConfig::default();
1331        assert_eq!(cfg.scheme, "ws");
1332        assert_eq!(cfg.host, "0.0.0.0");
1333        assert_eq!(cfg.port, 8080);
1334        assert_eq!(cfg.path, "/");
1335        assert_eq!(cfg.max_connections, 100);
1336        assert_eq!(cfg.max_message_size, 65536);
1337        assert!(!cfg.send_to_all);
1338        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1339        assert_eq!(cfg.idle_timeout, Duration::ZERO);
1340        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1341        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1342        assert_eq!(cfg.allow_origin, "*");
1343        assert_eq!(cfg.tls_cert, None);
1344        assert_eq!(cfg.tls_key, None);
1345    }
1346
1347    #[test]
1348    fn endpoint_config_parses_uri_params() {
1349        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1350        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1351
1352        assert_eq!(cfg.scheme, "ws");
1353        assert_eq!(cfg.host, "localhost");
1354        assert_eq!(cfg.port, 9001);
1355        assert_eq!(cfg.path, "/chat");
1356        assert_eq!(cfg.max_connections, 42);
1357        assert_eq!(cfg.max_message_size, 1024);
1358        assert!(cfg.send_to_all);
1359        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1360        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1361        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1362        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1363        assert_eq!(cfg.allow_origin, "https://example.com");
1364        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1365        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1366    }
1367
1368    #[test]
1369    fn endpoint_config_override_chain_uri_overrides_defaults() {
1370        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1371        assert_eq!(cfg.max_connections, 7);
1372        assert_eq!(cfg.max_message_size, 65536);
1373        assert!(!cfg.send_to_all);
1374        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1375    }
1376
1377    #[test]
1378    fn endpoint_trait_creates_consumer_and_producer() {
1379        let ctx = NoOpComponentContext;
1380        let endpoint = WsComponent::new()
1381            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1382            .unwrap();
1383
1384        endpoint.create_consumer().unwrap();
1385        endpoint
1386            .create_producer(&ProducerContext::default())
1387            .unwrap();
1388    }
1389
1390    #[test]
1391    fn ws_consumer_concurrency_model_uses_max_connections() {
1392        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1393        let consumer = WsConsumer::new(cfg.server_config());
1394        assert_eq!(
1395            consumer.concurrency_model(),
1396            ConcurrencyModel::Concurrent { max: Some(321) }
1397        );
1398    }
1399
1400    #[tokio::test]
1401    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1402        let registry = WsConnectionRegistry::new();
1403        let (tx1, mut rx1) = mpsc::channel(8);
1404        let (tx2, mut rx2) = mpsc::channel(8);
1405
1406        registry.insert("k1".into(), tx1);
1407        registry.insert("k2".into(), tx2);
1408        assert_eq!(registry.len(), 2);
1409
1410        for tx in registry.snapshot_senders() {
1411            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1412        }
1413
1414        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1415        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1416
1417        let target = registry.get_senders_for_keys(&["k1".to_string()]);
1418        assert_eq!(target.len(), 1);
1419        target[0]
1420            .send(WsMessage::Text("targeted".into()))
1421            .await
1422            .unwrap();
1423
1424        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1425        assert!(
1426            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1427                .await
1428                .is_err()
1429        );
1430
1431        registry.remove("k1");
1432        assert_eq!(registry.len(), 1);
1433    }
1434
1435    #[test]
1436    fn host_canonicalization_maps_local_hosts_to_loopback() {
1437        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1438            .unwrap()
1439            .canonical_host();
1440        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1441            .unwrap()
1442            .canonical_host();
1443        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1444            .unwrap()
1445            .canonical_host();
1446
1447        assert_eq!(c1, "127.0.0.1");
1448        assert_eq!(c2, "127.0.0.1");
1449        assert_eq!(c3, "127.0.0.1");
1450    }
1451
1452    #[tokio::test]
1453    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1454        let port = free_port();
1455        let uri = format!("ws://127.0.0.1:{port}/echo");
1456        let component_ctx = NoOpComponentContext;
1457        let endpoint = WsComponent::new()
1458            .create_endpoint(&uri, &component_ctx)
1459            .unwrap();
1460
1461        let mut consumer = endpoint.create_consumer().unwrap();
1462        let producer = endpoint
1463            .create_producer(&ProducerContext::default())
1464            .unwrap();
1465
1466        let (route_tx, mut route_rx) = mpsc::channel(16);
1467        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1468        consumer.start(ctx).await.unwrap();
1469
1470        let route_task = tokio::spawn(async move {
1471            if let Some(envelope) = route_rx.recv().await {
1472                let payload = envelope
1473                    .exchange
1474                    .input
1475                    .body
1476                    .as_text()
1477                    .unwrap_or_default()
1478                    .to_string();
1479                let key = envelope
1480                    .exchange
1481                    .input
1482                    .header("CamelWsConnectionKey")
1483                    .and_then(|v| v.as_str())
1484                    .unwrap()
1485                    .to_string();
1486
1487                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1488                response
1489                    .input
1490                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1491                producer.oneshot(response).await.unwrap();
1492            }
1493        });
1494
1495        let url = format!("ws://127.0.0.1:{port}/echo");
1496        let (mut client, _) = loop {
1497            match connect_async(&url).await {
1498                Ok(ok) => break ok,
1499                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1500            }
1501        };
1502
1503        client
1504            .send(ClientMessage::Text("hello-ws".into()))
1505            .await
1506            .unwrap();
1507
1508        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1509            loop {
1510                match client.next().await {
1511                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1512                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1513                    Some(Ok(_)) => continue,
1514                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1515                    None => panic!("websocket closed before echo"),
1516                }
1517            }
1518        })
1519        .await
1520        .unwrap();
1521
1522        assert_eq!(incoming, "hello-ws");
1523
1524        consumer.stop().await.unwrap();
1525        route_task.await.unwrap();
1526    }
1527
1528    #[tokio::test]
1529    async fn consumer_stop_sends_close_1001() {
1530        let port = free_port();
1531        let uri = format!("ws://127.0.0.1:{port}/shutdown");
1532        let component_ctx = NoOpComponentContext;
1533        let endpoint = WsComponent::new()
1534            .create_endpoint(&uri, &component_ctx)
1535            .unwrap();
1536
1537        let mut consumer = endpoint.create_consumer().unwrap();
1538        let (route_tx, _route_rx) = mpsc::channel(16);
1539        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1540        consumer.start(ctx).await.unwrap();
1541
1542        let url = format!("ws://127.0.0.1:{port}/shutdown");
1543        let (mut client, _) = loop {
1544            match connect_async(&url).await {
1545                Ok(ok) => break ok,
1546                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1547            }
1548        };
1549
1550        client
1551            .send(ClientMessage::Text("keepalive".into()))
1552            .await
1553            .unwrap();
1554
1555        consumer.stop().await.unwrap();
1556
1557        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1558            loop {
1559                match client.next().await {
1560                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1561                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1562                    Some(Ok(_)) => continue,
1563                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1564                    None => panic!("websocket closed without close frame"),
1565                }
1566            }
1567        })
1568        .await
1569        .unwrap();
1570
1571        assert_eq!(close_code, Some(CloseCode::Away));
1572    }
1573
1574    #[test]
1575    fn wildcard_origin_allows_anything() {
1576        assert!(is_origin_allowed("*", None));
1577        assert!(is_origin_allowed("*", Some("https://example.com")));
1578    }
1579
1580    #[test]
1581    fn exact_origin_requires_match() {
1582        assert!(is_origin_allowed(
1583            "https://example.com",
1584            Some("https://example.com")
1585        ));
1586        assert!(!is_origin_allowed(
1587            "https://example.com",
1588            Some("https://other.com")
1589        ));
1590        assert!(!is_origin_allowed("https://example.com", None));
1591    }
1592
1593    #[test]
1594    fn endpoint_config_rejects_invalid_scheme() {
1595        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1596        assert!(result.is_err());
1597        let msg = result.unwrap_err().to_string();
1598        assert!(
1599            msg.contains("Invalid WebSocket scheme"),
1600            "expected scheme error, got: {msg}"
1601        );
1602    }
1603
1604    #[tokio::test]
1605    async fn wss_consumer_start_fails_without_tls_cert() {
1606        let port = free_port();
1607        let component_ctx = NoOpComponentContext;
1608        let endpoint = WssComponent::new()
1609            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1610            .unwrap();
1611        let mut consumer = endpoint.create_consumer().unwrap();
1612        let (tx, _rx) = mpsc::channel(16);
1613        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1614        let result = consumer.start(ctx).await;
1615        assert!(result.is_err());
1616        let msg = result.unwrap_err().to_string();
1617        assert!(
1618            msg.contains("TLS cert path is required"),
1619            "expected TLS cert error, got: {msg}"
1620        );
1621    }
1622
1623    #[tokio::test]
1624    async fn wss_consumer_start_fails_with_nonexistent_cert() {
1625        let port = free_port();
1626        let component_ctx = NoOpComponentContext;
1627        let endpoint = WssComponent::new()
1628            .create_endpoint(&format!(
1629                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1630            ), &component_ctx)
1631            .unwrap();
1632        let mut consumer = endpoint.create_consumer().unwrap();
1633        let (tx, _rx) = mpsc::channel(16);
1634        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1635        let result = consumer.start(ctx).await;
1636        assert!(result.is_err());
1637        let msg = result.unwrap_err().to_string();
1638        assert!(
1639            msg.contains("TLS cert file error"),
1640            "expected cert file error, got: {msg}"
1641        );
1642    }
1643
1644    #[tokio::test]
1645    async fn server_registry_returns_same_state_for_same_port() {
1646        let port = free_port();
1647        let state1 = ServerRegistry::global()
1648            .get_or_spawn("127.0.0.1", port, None)
1649            .await
1650            .unwrap();
1651        let state2 = ServerRegistry::global()
1652            .get_or_spawn("127.0.0.1", port, None)
1653            .await
1654            .unwrap();
1655        assert!(
1656            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1657            "expected same dispatch table for same port"
1658        );
1659    }
1660
1661    #[tokio::test]
1662    async fn dispatch_handler_returns_404_for_unregistered_path() {
1663        let port = free_port();
1664        let state = ServerRegistry::global()
1665            .get_or_spawn("127.0.0.1", port, None)
1666            .await
1667            .unwrap();
1668        let app = Router::new().fallback(dispatch_handler).with_state(state);
1669        let response = tokio::time::timeout(
1670            Duration::from_secs(2),
1671            tower::ServiceExt::oneshot(
1672                app,
1673                axum::http::Request::builder()
1674                    .uri("/nonexistent")
1675                    .body(Body::empty())
1676                    .unwrap(),
1677            ),
1678        )
1679        .await
1680        .unwrap()
1681        .unwrap();
1682        assert_eq!(response.status(), StatusCode::NOT_FOUND);
1683    }
1684
1685    #[tokio::test]
1686    async fn client_mode_producer_connects_and_echoes() {
1687        let port = free_port();
1688
1689        let app = Router::new().route(
1690            "/echo",
1691            axum::routing::get(|ws: WebSocketUpgrade| async move {
1692                ws.on_upgrade(|mut socket: WebSocket| async move {
1693                    while let Some(Ok(msg)) = socket.recv().await {
1694                        match msg {
1695                            WsMessage::Text(text) => {
1696                                let _ = socket.send(WsMessage::Text(text)).await;
1697                            }
1698                            WsMessage::Binary(data) => {
1699                                let _ = socket.send(WsMessage::Binary(data)).await;
1700                            }
1701                            WsMessage::Close(_) => break,
1702                            _ => {}
1703                        }
1704                    }
1705                })
1706            }),
1707        );
1708        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
1709            .await
1710            .unwrap();
1711        let server_task = tokio::spawn(async move {
1712            let _ = serve(listener, app).await;
1713        });
1714
1715        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1716        let producer = WsProducer::new(cfg.client_config());
1717
1718        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1719        tokio::time::sleep(Duration::from_millis(25)).await;
1720        let result =
1721            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1722                Ok(Ok(r)) => r,
1723                Ok(Err(_)) => panic!("producer call failed"),
1724                Err(_) => panic!("producer call timed out"),
1725            };
1726
1727        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1728
1729        server_task.abort();
1730    }
1731
1732    #[tokio::test]
1733    async fn max_connections_rejects_with_close_1013() {
1734        let port = free_port();
1735        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1736        let component_ctx = NoOpComponentContext;
1737        let endpoint = WsComponent::new()
1738            .create_endpoint(&uri, &component_ctx)
1739            .unwrap();
1740        let mut consumer = endpoint.create_consumer().unwrap();
1741        let (route_tx, _route_rx) = mpsc::channel(16);
1742        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1743        consumer.start(ctx).await.unwrap();
1744
1745        let url = format!("ws://127.0.0.1:{port}/limited");
1746        let (_client1, _) = loop {
1747            match connect_async(&url).await {
1748                Ok(ok) => break ok,
1749                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1750            }
1751        };
1752
1753        tokio::time::sleep(Duration::from_millis(100)).await;
1754
1755        let (mut client2, _) = connect_async(&url).await.unwrap();
1756
1757        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1758            loop {
1759                match client2.next().await {
1760                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1761                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1762                    Some(Ok(ClientMessage::Text(_))) => continue,
1763                    Some(Ok(_)) => continue,
1764                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1765                    None => panic!("client2 closed without close frame"),
1766                }
1767            }
1768        })
1769        .await
1770        .unwrap();
1771
1772        assert_eq!(
1773            close_code,
1774            Some(CloseCode::from(1013u16)),
1775            "expected 1013 (Try Again Later) for max connections"
1776        );
1777
1778        consumer.stop().await.unwrap();
1779    }
1780
1781    #[tokio::test]
1782    async fn max_message_size_rejects_with_close_1009() {
1783        let port = free_port();
1784        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1785        let component_ctx = NoOpComponentContext;
1786        let endpoint = WsComponent::new()
1787            .create_endpoint(&uri, &component_ctx)
1788            .unwrap();
1789        let mut consumer = endpoint.create_consumer().unwrap();
1790        let (route_tx, _route_rx) = mpsc::channel(16);
1791        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1792        consumer.start(ctx).await.unwrap();
1793
1794        let url = format!("ws://127.0.0.1:{port}/sizelimit");
1795        let (mut client, _) = loop {
1796            match connect_async(&url).await {
1797                Ok(ok) => break ok,
1798                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1799            }
1800        };
1801
1802        let oversized = "x".repeat(100);
1803        client
1804            .send(ClientMessage::Text(oversized.into()))
1805            .await
1806            .unwrap();
1807
1808        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1809            loop {
1810                match client.next().await {
1811                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1812                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1813                    Some(Ok(_)) => continue,
1814                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1815                    None => panic!("websocket closed without close frame"),
1816                }
1817            }
1818        })
1819        .await
1820        .unwrap();
1821
1822        assert_eq!(
1823            close_code,
1824            Some(CloseCode::from(1009u16)),
1825            "expected 1009 (Message Too Big) for oversized message"
1826        );
1827
1828        consumer.stop().await.unwrap();
1829    }
1830
1831    #[tokio::test]
1832    async fn origin_rejection_returns_403() {
1833        let port = free_port();
1834        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1835        let component_ctx = NoOpComponentContext;
1836        let endpoint = WsComponent::new()
1837            .create_endpoint(&uri, &component_ctx)
1838            .unwrap();
1839        let mut consumer = endpoint.create_consumer().unwrap();
1840        let (route_tx, _route_rx) = mpsc::channel(16);
1841        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1842        consumer.start(ctx).await.unwrap();
1843
1844        let state = ServerRegistry::global()
1845            .get_or_spawn("127.0.0.1", port, None)
1846            .await
1847            .unwrap();
1848        let app = Router::new().fallback(dispatch_handler).with_state(state);
1849
1850        let response = tokio::time::timeout(
1851            Duration::from_secs(2),
1852            tower::ServiceExt::oneshot(
1853                app,
1854                axum::http::Request::builder()
1855                    .uri("/origintest")
1856                    .header("origin", "https://evil.com")
1857                    .header("upgrade", "websocket")
1858                    .header("connection", "Upgrade")
1859                    .header("sec-websocket-version", "13")
1860                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1861                    .body(Body::empty())
1862                    .unwrap(),
1863            ),
1864        )
1865        .await
1866        .unwrap()
1867        .unwrap();
1868
1869        assert_eq!(
1870            response.status(),
1871            StatusCode::FORBIDDEN,
1872            "expected 403 for disallowed origin"
1873        );
1874
1875        consumer.stop().await.unwrap();
1876    }
1877
1878    #[tokio::test]
1879    async fn broadcast_sends_to_all_connected_clients() {
1880        let port = free_port();
1881        let uri = format!("ws://127.0.0.1:{port}/bc");
1882        let component_ctx = NoOpComponentContext;
1883        let endpoint = WsComponent::new()
1884            .create_endpoint(&uri, &component_ctx)
1885            .unwrap();
1886        let mut consumer = endpoint.create_consumer().unwrap();
1887        let producer = endpoint
1888            .create_producer(&ProducerContext::default())
1889            .unwrap();
1890
1891        let (route_tx, _route_rx) = mpsc::channel(16);
1892        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1893        consumer.start(ctx).await.unwrap();
1894
1895        let url = format!("ws://127.0.0.1:{port}/bc");
1896
1897        let (mut client1, _) = loop {
1898            match connect_async(&url).await {
1899                Ok(ok) => break ok,
1900                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1901            }
1902        };
1903
1904        let (mut client2, _) = connect_async(&url).await.unwrap();
1905
1906        tokio::time::sleep(Duration::from_millis(100)).await;
1907
1908        let mut response =
1909            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1910        response
1911            .input
1912            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1913        producer.oneshot(response).await.unwrap();
1914
1915        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1916            loop {
1917                match client1.next().await {
1918                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1919                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1920                    _ => panic!("client1 unexpected message or close"),
1921                }
1922            }
1923        })
1924        .await
1925        .unwrap();
1926
1927        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
1928            loop {
1929                match client2.next().await {
1930                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1931                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1932                    _ => panic!("client2 unexpected message or close"),
1933                }
1934            }
1935        })
1936        .await
1937        .unwrap();
1938
1939        assert_eq!(recv1, "broadcast-msg");
1940        assert_eq!(recv2, "broadcast-msg");
1941
1942        consumer.stop().await.unwrap();
1943    }
1944
1945    #[tokio::test]
1946    async fn concurrent_get_or_spawn_returns_same_state() {
1947        let port = free_port();
1948        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
1949            Arc::new(std::sync::Mutex::new(Vec::new()));
1950
1951        let mut handles = Vec::new();
1952        for _ in 0..4 {
1953            let results = results.clone();
1954            handles.push(tokio::spawn(async move {
1955                let state = ServerRegistry::global()
1956                    .get_or_spawn("127.0.0.1", port, None)
1957                    .await
1958                    .unwrap();
1959                results.lock().unwrap().push(state);
1960            }));
1961        }
1962
1963        for h in handles {
1964            h.await.unwrap();
1965        }
1966
1967        let states = results.lock().unwrap();
1968        assert_eq!(states.len(), 4);
1969        for i in 1..states.len() {
1970            assert!(
1971                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
1972                "all concurrent callers should get the same dispatch table"
1973            );
1974        }
1975    }
1976
1977    #[tokio::test]
1978    async fn body_conversion_helpers_cover_text_and_binary_paths() {
1979        let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
1980            .await
1981            .unwrap();
1982        assert!(matches!(text_msg, WsMessage::Text(_)));
1983
1984        let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
1985            .await
1986            .unwrap();
1987        assert!(matches!(bin_msg, WsMessage::Binary(_)));
1988
1989        let client_text =
1990            body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
1991                .await
1992                .unwrap();
1993        assert!(matches!(client_text, ClientWsMessage::Text(_)));
1994
1995        let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
1996            .await
1997            .unwrap();
1998        assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
1999    }
2000
2001    #[tokio::test]
2002    async fn body_to_text_handles_empty_text_json_and_bytes() {
2003        assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2004        assert_eq!(
2005            body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2006            "hello"
2007        );
2008        assert_eq!(
2009            body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2010                .await
2011                .unwrap(),
2012            "{\"n\":1}"
2013        );
2014        assert_eq!(
2015            body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2016                .await
2017                .unwrap(),
2018            "hi"
2019        );
2020    }
2021
2022    #[test]
2023    fn try_send_with_backpressure_returns_false_when_channel_full() {
2024        let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2025        assert!(try_send_with_backpressure(
2026            &tx,
2027            WsMessage::Text("first".into()),
2028            "test"
2029        ));
2030        assert!(!try_send_with_backpressure(
2031            &tx,
2032            WsMessage::Text("second".into()),
2033            "test"
2034        ));
2035    }
2036
2037    #[test]
2038    fn map_connect_error_formats_connection_refused_and_generic_errors() {
2039        let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2040        let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2041        assert!(err.to_string().contains("WebSocket connection refused"));
2042
2043        let generic = map_connect_error(
2044            tungstenite::Error::Protocol(
2045                tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2046            ),
2047            "ws://localhost:2/y",
2048        );
2049        assert!(
2050            generic
2051                .to_string()
2052                .contains("WebSocket connection failed (ws://localhost:2/y)")
2053        );
2054    }
2055
2056    // === Phase B Finding Tests ===
2057
2058    // WS-015: maxConnections=0 must be rejected
2059    #[test]
2060    fn from_uri_rejects_max_connections_zero() {
2061        let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2062        assert!(result.is_err());
2063        let msg = result.unwrap_err().to_string();
2064        assert!(
2065            msg.contains("maxConnections must be >= 1"),
2066            "expected maxConnections validation error, got: {msg}"
2067        );
2068    }
2069
2070    // WS-019: maxMessageSize=0 must be rejected
2071    #[test]
2072    fn from_uri_rejects_max_message_size_zero() {
2073        let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2074        assert!(result.is_err());
2075        let msg = result.unwrap_err().to_string();
2076        assert!(
2077            msg.contains("maxMessageSize must be > 0"),
2078            "expected maxMessageSize validation error, got: {msg}"
2079        );
2080    }
2081
2082    // WS-020: allowOrigin="" must be rejected
2083    #[test]
2084    fn from_uri_rejects_empty_allow_origin() {
2085        let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2086        assert!(result.is_err());
2087        let msg = result.unwrap_err().to_string();
2088        assert!(
2089            msg.contains("allowOrigin must not be empty"),
2090            "expected allowOrigin validation error, got: {msg}"
2091        );
2092    }
2093
2094    // WS-006: Double-start must be rejected
2095    #[tokio::test]
2096    async fn consumer_double_start_returns_error() {
2097        let port = free_port();
2098        let uri = format!("ws://127.0.0.1:{port}/doublestart");
2099        let component_ctx = NoOpComponentContext;
2100        let endpoint = WsComponent::new()
2101            .create_endpoint(&uri, &component_ctx)
2102            .unwrap();
2103
2104        let mut consumer = endpoint.create_consumer().unwrap();
2105        let (route_tx, _route_rx) = mpsc::channel(16);
2106        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2107
2108        // First start should succeed
2109        consumer.start(ctx).await.unwrap();
2110
2111        // Second start should fail
2112        let (route_tx2, _route_rx2) = mpsc::channel(16);
2113        let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2114        let result = consumer.start(ctx2).await;
2115        assert!(result.is_err());
2116        let msg = result.unwrap_err().to_string();
2117        assert!(
2118            msg.contains("already started"),
2119            "expected double-start error, got: {msg}"
2120        );
2121
2122        consumer.stop().await.unwrap();
2123    }
2124
2125    // WS-005: Registry cleanup on stop + port reuse
2126    #[tokio::test]
2127    async fn registry_cleanup_on_consumer_stop() {
2128        let port = free_port();
2129        let uri = format!("ws://127.0.0.1:{port}/cleanup");
2130        let component_ctx = NoOpComponentContext;
2131        let endpoint = WsComponent::new()
2132            .create_endpoint(&uri, &component_ctx)
2133            .unwrap();
2134
2135        let mut consumer = endpoint.create_consumer().unwrap();
2136        let (route_tx, _route_rx) = mpsc::channel(16);
2137        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2138        consumer.start(ctx).await.unwrap();
2139
2140        // Verify registry entry exists
2141        let registries = global_registries();
2142        let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2143        assert!(
2144            registries.contains_key(&key),
2145            "registry should have entry after start"
2146        );
2147
2148        // Stop consumer
2149        consumer.stop().await.unwrap();
2150
2151        // Verify registry entry is removed
2152        assert!(
2153            !registries.contains_key(&key),
2154            "registry should be cleaned up after stop"
2155        );
2156
2157        // Verify server registry reference is released (port can be reused)
2158        // The ServerRegistry should have removed the entry
2159        let server_reg = ServerRegistry::global();
2160        let guard = server_reg.inner.lock().unwrap();
2161        assert!(
2162            !guard.contains_key(&port),
2163            "ServerRegistry should remove port entry after last consumer stops"
2164        );
2165    }
2166
2167    // WS-003 + WS-004: poll_ready backpressure and server-send error handling
2168    #[tokio::test]
2169    async fn producer_server_send_returns_error_when_all_dropped() {
2170        let port = free_port();
2171        let uri = format!("ws://127.0.0.1:{port}/backpressure");
2172        let component_ctx = NoOpComponentContext;
2173        let endpoint = WsComponent::new()
2174            .create_endpoint(&uri, &component_ctx)
2175            .unwrap();
2176
2177        let mut consumer = endpoint.create_consumer().unwrap();
2178        let producer = endpoint
2179            .create_producer(&ProducerContext::default())
2180            .unwrap();
2181
2182        let (route_tx, _route_rx) = mpsc::channel(1); // Tiny channel to force backpressure
2183        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2184        consumer.start(ctx).await.unwrap();
2185
2186        // Connect a client so the registry has an entry
2187        let url = format!("ws://127.0.0.1:{port}/backpressure");
2188        let (mut client, _) = loop {
2189            match connect_async(&url).await {
2190                Ok(ok) => break ok,
2191                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2192            }
2193        };
2194
2195        // Don't consume messages — let the channel fill up
2196        tokio::time::sleep(Duration::from_millis(50)).await;
2197
2198        // Flood the channel to trigger backpressure
2199        let mut all_dropped = false;
2200        for _ in 0..100 {
2201            let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2202            match producer.clone().oneshot(exchange).await {
2203                Ok(_) => {}
2204                Err(e) => {
2205                    if e.to_string().contains("backpressure") {
2206                        all_dropped = true;
2207                        break;
2208                    }
2209                }
2210            }
2211        }
2212
2213        // The producer should eventually return a backpressure error
2214        assert!(
2215            all_dropped,
2216            "producer should return error when all messages are dropped due to backpressure"
2217        );
2218
2219        // Clean up
2220        let _ = client.close(None).await;
2221        consumer.stop().await.unwrap();
2222    }
2223
2224    // WS-012: Ping/pong round-trip in server mode
2225    #[tokio::test]
2226    async fn server_responds_to_client_ping_with_pong() {
2227        let port = free_port();
2228        let uri = format!("ws://127.0.0.1:{port}/pingpong");
2229        let component_ctx = NoOpComponentContext;
2230        let endpoint = WsComponent::new()
2231            .create_endpoint(&uri, &component_ctx)
2232            .unwrap();
2233
2234        let mut consumer = endpoint.create_consumer().unwrap();
2235        let (route_tx, _route_rx) = mpsc::channel(16);
2236        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2237        consumer.start(ctx).await.unwrap();
2238
2239        let url = format!("ws://127.0.0.1:{port}/pingpong");
2240        let (mut client, _) = loop {
2241            match connect_async(&url).await {
2242                Ok(ok) => break ok,
2243                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2244            }
2245        };
2246
2247        // Send a ping
2248        client
2249            .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2250            .await
2251            .unwrap();
2252
2253        // Expect a pong with the same payload
2254        let pong = tokio::time::timeout(Duration::from_secs(2), async {
2255            loop {
2256                match client.next().await {
2257                    Some(Ok(ClientMessage::Pong(data))) => break data,
2258                    Some(Ok(ClientMessage::Ping(_))) => continue,
2259                    Some(Ok(_)) => continue,
2260                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2261                    None => panic!("websocket closed before pong"),
2262                }
2263            }
2264        })
2265        .await
2266        .unwrap();
2267
2268        assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2269
2270        consumer.stop().await.unwrap();
2271    }
2272
2273    // WS-008: Client-side retry on transient connect failures
2274    #[tokio::test]
2275    async fn producer_retries_on_connection_refused() {
2276        // Use a port that nothing is listening on
2277        let port = free_port();
2278        // Ensure nothing is on this port
2279        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/retry")).unwrap();
2280        let producer = WsProducer::new(cfg.client_config());
2281
2282        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2283
2284        // Should fail after retries (nothing listening)
2285        let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2286        assert!(
2287            result.is_ok(),
2288            "producer should complete (with error) within timeout"
2289        );
2290        let result = result.unwrap();
2291        assert!(
2292            result.is_err(),
2293            "producer should fail when nothing is listening"
2294        );
2295        let msg = result.unwrap_err().to_string();
2296        assert!(
2297            msg.contains("connection refused"),
2298            "expected connection refused error, got: {msg}"
2299        );
2300    }
2301
2302    // WS-001: Server bind error is visible (fake server-start error test)
2303    #[tokio::test]
2304    async fn server_bind_error_is_reported() {
2305        // Bind a port manually to cause a conflict
2306        let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2307        let port = _listener.local_addr().unwrap().port();
2308
2309        // Try to start a consumer on the same port — should succeed since axum binds lazily
2310        // The actual bind error happens when the server task runs
2311        let uri = format!("ws://127.0.0.1:{port}/binderror");
2312        let component_ctx = NoOpComponentContext;
2313        let endpoint = WsComponent::new()
2314            .create_endpoint(&uri, &component_ctx)
2315            .unwrap();
2316
2317        let mut consumer = endpoint.create_consumer().unwrap();
2318        let (route_tx, _route_rx) = mpsc::channel(16);
2319        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2320
2321        // Start should succeed (server spawns, but bind may fail)
2322        let start_result = consumer.start(ctx).await;
2323        // The server may or may not have bound yet — this test verifies no panic
2324        // The actual error is logged by the server task
2325        let _ = start_result;
2326
2327        consumer.stop().await.unwrap();
2328    }
2329
2330    #[test]
2331    fn ws_app_state_server_error_starts_false() {
2332        let state = WsAppState {
2333            dispatch: Arc::new(RwLock::new(HashMap::new())),
2334            path_configs: Arc::new(DashMap::new()),
2335            server_error: new_atomic_false(),
2336        };
2337        assert!(
2338            !state.server_error.load(Ordering::Relaxed),
2339            "server_error should start as false"
2340        );
2341    }
2342
2343    #[test]
2344    fn ws_app_state_server_error_can_be_set() {
2345        let state = WsAppState {
2346            dispatch: Arc::new(RwLock::new(HashMap::new())),
2347            path_configs: Arc::new(DashMap::new()),
2348            server_error: new_atomic_false(),
2349        };
2350        assert!(!state.server_error.load(Ordering::Relaxed));
2351        state.server_error.store(true, Ordering::Relaxed);
2352        assert!(state.server_error.load(Ordering::Relaxed));
2353    }
2354
2355    #[tokio::test]
2356    async fn consumer_stop_returns_error_when_server_had_errors() {
2357        let port = free_port();
2358        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2359        let mut consumer = WsConsumer::new(cfg.server_config());
2360        let (route_tx, _route_rx) = mpsc::channel(16);
2361        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2362        consumer.start(ctx).await.unwrap();
2363
2364        // Simulate server error by setting the flag directly
2365        if let Some(ref state) = consumer.server_state {
2366            state.server_error.store(true, Ordering::Relaxed);
2367        }
2368
2369        let result = consumer.stop().await;
2370        assert!(
2371            result.is_err(),
2372            "stop should return error when server had errors"
2373        );
2374        let msg = result.unwrap_err().to_string();
2375        assert!(
2376            msg.contains("terminated with errors"),
2377            "expected server error message, got: {msg}"
2378        );
2379    }
2380
2381    #[tokio::test]
2382    async fn consumer_stop_succeeds_when_server_healthy() {
2383        let port = free_port();
2384        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2385        let mut consumer = WsConsumer::new(cfg.server_config());
2386        let (route_tx, _route_rx) = mpsc::channel(16);
2387        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2388        consumer.start(ctx).await.unwrap();
2389
2390        let result = consumer.stop().await;
2391        assert!(
2392            result.is_ok(),
2393            "stop should succeed when server is healthy: {:?}",
2394            result
2395        );
2396    }
2397}