Skip to main content

camel_component_ws/
lib.rs

1pub mod bundle;
2pub mod config;
3
4pub use bundle::WsBundle;
5pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use async_trait::async_trait;
11use axum::body::Body;
12use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
13use axum::extract::{FromRequest, Request, State};
14use axum::http::{StatusCode, header};
15use axum::response::IntoResponse;
16use axum::{Router, serve};
17use camel_component_api::{
18    Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
19};
20use camel_component_api::{
21    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
22    ProducerContext,
23};
24use dashmap::DashMap;
25use futures::{SinkExt, StreamExt};
26use std::future::Future;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29use tokio::sync::{OnceCell, RwLock, mpsc};
30use tokio::task::JoinHandle;
31use tokio_tungstenite::tungstenite;
32use tokio_tungstenite::tungstenite::client::IntoClientRequest;
33use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
34use tower::Service;
35
36#[derive(Clone)]
37struct WsPathConfig {
38    max_connections: u32,
39    max_message_size: u32,
40    heartbeat_interval: std::time::Duration,
41    idle_timeout: std::time::Duration,
42    allow_origin: String,
43}
44
45impl Default for WsPathConfig {
46    fn default() -> Self {
47        let cfg = WsEndpointConfig::default();
48        Self {
49            max_connections: cfg.max_connections,
50            max_message_size: cfg.max_message_size,
51            heartbeat_interval: cfg.heartbeat_interval,
52            idle_timeout: cfg.idle_timeout,
53            allow_origin: cfg.allow_origin,
54        }
55    }
56}
57
58#[derive(Clone)]
59struct WsTlsConfig {
60    cert_path: String,
61    key_path: String,
62}
63
64type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
65
66struct ServerHandle {
67    state: WsAppState,
68    is_tls: bool,
69    _task: JoinHandle<()>,
70}
71
72pub struct ServerRegistry {
73    inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
74}
75
76impl ServerRegistry {
77    pub fn global() -> &'static Self {
78        static REG: OnceLock<ServerRegistry> = OnceLock::new();
79        REG.get_or_init(|| Self {
80            inner: Mutex::new(HashMap::new()),
81        })
82    }
83
84    pub(crate) async fn get_or_spawn(
85        &'static self,
86        host: &str,
87        port: u16,
88        tls_config: Option<WsTlsConfig>,
89    ) -> Result<WsAppState, CamelError> {
90        let wants_tls = tls_config.is_some();
91        let host_owned = host.to_string();
92
93        let cell = {
94            let mut guard = self.inner.lock().map_err(|_| {
95                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
96            })?;
97            guard
98                .entry(port)
99                .or_insert_with(|| Arc::new(OnceCell::new()))
100                .clone()
101        };
102
103        let handle = cell
104            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
105            .await?;
106
107        if wants_tls != handle.is_tls {
108            return Err(CamelError::EndpointCreationFailed(format!(
109                "Server on port {port} already running with different TLS mode"
110            )));
111        }
112
113        Ok(handle.state.clone())
114    }
115}
116
117async fn spawn_server(
118    host: &str,
119    port: u16,
120    tls_config: Option<WsTlsConfig>,
121) -> Result<ServerHandle, CamelError> {
122    let addr = format!("{host}:{port}");
123    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
124    let path_configs = Arc::new(DashMap::new());
125    let state = WsAppState {
126        dispatch: Arc::clone(&dispatch),
127        path_configs: Arc::clone(&path_configs),
128    };
129    let app = Router::new()
130        .fallback(dispatch_handler)
131        .with_state(state.clone());
132
133    let (task, is_tls) = if let Some(ref tls) = tls_config {
134        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
135        let parsed_addr = addr.parse().map_err(|e| {
136            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
137        })?;
138        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
139        let task = tokio::spawn(async move {
140            let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
141                .serve(app.into_make_service())
142                .await;
143        });
144        (task, true)
145    } else {
146        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
147            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
148        })?;
149        let task = tokio::spawn(async move {
150            let _ = serve(listener, app).await;
151        });
152        (task, false)
153    };
154
155    Ok(ServerHandle {
156        state,
157        is_tls,
158        _task: task,
159    })
160}
161
162#[derive(Clone)]
163struct WsAppState {
164    dispatch: DispatchTable,
165    path_configs: Arc<DashMap<String, WsPathConfig>>,
166}
167
168pub struct WsConnectionRegistry {
169    connections: DashMap<String, mpsc::Sender<WsMessage>>,
170}
171
172static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
173    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
174> = OnceLock::new();
175
176fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
177    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
178}
179
180impl Default for WsConnectionRegistry {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186impl WsConnectionRegistry {
187    pub fn new() -> Self {
188        Self {
189            connections: DashMap::new(),
190        }
191    }
192
193    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
194        self.connections.insert(key, tx);
195    }
196
197    pub fn remove(&self, key: &str) {
198        self.connections.remove(key);
199    }
200
201    pub fn len(&self) -> usize {
202        self.connections.len()
203    }
204
205    pub fn is_empty(&self) -> bool {
206        self.connections.is_empty()
207    }
208
209    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
210        self.connections.iter().map(|e| e.value().clone()).collect()
211    }
212
213    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
214        keys.iter()
215            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
216            .collect()
217    }
218}
219
220async fn dispatch_handler(
221    State(state): State<WsAppState>,
222    req: Request<Body>,
223) -> impl IntoResponse {
224    let path = req.uri().path().to_string();
225    let origin = req
226        .headers()
227        .get(header::ORIGIN)
228        .and_then(|value| value.to_str().ok())
229        .map(str::to_string);
230    let remote_addr = req
231        .extensions()
232        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
233        .map(|ci| ci.0.to_string())
234        .unwrap_or_default();
235    let table = state.dispatch.read().await;
236    if !table.contains_key(&path) {
237        return (
238            StatusCode::NOT_FOUND,
239            "no ws endpoint registered for this path",
240        )
241            .into_response();
242    }
243    drop(table);
244
245    let path_config = state
246        .path_configs
247        .get(&path)
248        .map(|entry| entry.value().clone())
249        .unwrap_or_default();
250    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
251        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
252    }
253
254    let upgrade_headers: HashMap<String, String> = req
255        .headers()
256        .iter()
257        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
258        .collect();
259
260    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
261        Ok(ws) => ws,
262        Err(_) => {
263            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
264        }
265    };
266
267    ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
268        .into_response()
269}
270
271#[allow(unused_variables)]
272async fn ws_handler(
273    socket: WebSocket,
274    state: WsAppState,
275    path: String,
276    remote_addr: String,
277    upgrade_headers: HashMap<String, String>,
278) {
279    let connection_key = uuid::Uuid::new_v4().to_string();
280    let path_config = state
281        .path_configs
282        .get(&path)
283        .map(|entry| entry.value().clone())
284        .unwrap_or_default();
285
286    let env_tx = {
287        let table = state.dispatch.read().await;
288        table.get(&path).cloned()
289    };
290    let Some(env_tx) = env_tx else {
291        return;
292    };
293
294    let (mut sink, mut stream) = socket.split();
295    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
296
297    let registry = global_registries();
298    let mut registry_key = None;
299    for entry in registry.iter() {
300        if entry.key().2 == path {
301            entry.value().insert(connection_key.clone(), out_tx.clone());
302            registry_key = Some(entry.key().clone());
303            break;
304        }
305    }
306
307    let writer = tokio::spawn(async move {
308        while let Some(msg) = out_rx.recv().await {
309            let _ = sink.send(msg).await;
310        }
311    });
312
313    let mut over_limit = false;
314    if let Some(key) = &registry_key
315        && let Some(entry) = registry.get(key)
316        && entry.len() > path_config.max_connections as usize
317    {
318        over_limit = true;
319    }
320    if over_limit {
321        try_send_with_backpressure(
322            &out_tx,
323            WsMessage::Close(Some(CloseFrame {
324                code: CloseCode::from(1013u16),
325                reason: "max connections exceeded".into(),
326            })),
327            "max-connections-close",
328        );
329        if let Some(key) = registry_key.clone()
330            && let Some(entry) = registry.get(&key)
331        {
332            entry.remove(&connection_key);
333        }
334        drop(out_tx);
335        let _ = writer.await;
336        return;
337    }
338
339    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
340        let ping_tx = out_tx.clone();
341        let interval = path_config.heartbeat_interval;
342        Some(tokio::spawn(async move {
343            let mut ticker = tokio::time::interval(interval);
344            loop {
345                ticker.tick().await;
346                let _ = try_send_with_backpressure(
347                    &ping_tx,
348                    WsMessage::Ping(Vec::new().into()),
349                    "heartbeat-ping",
350                );
351            }
352        }))
353    } else {
354        None
355    };
356
357    loop {
358        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
359            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
360                Ok(msg) => msg,
361                Err(_) => {
362                    try_send_with_backpressure(
363                        &out_tx,
364                        WsMessage::Close(Some(CloseFrame {
365                            code: CloseCode::from(1000u16),
366                            reason: "idle timeout".into(),
367                        })),
368                        "idle-timeout-close",
369                    );
370                    break;
371                }
372            }
373        } else {
374            stream.next().await
375        };
376
377        let Some(msg) = next_msg else {
378            break;
379        };
380
381        match msg {
382            Ok(WsMessage::Text(text)) => {
383                if text.len() > path_config.max_message_size as usize {
384                    try_send_with_backpressure(
385                        &out_tx,
386                        WsMessage::Close(Some(CloseFrame {
387                            code: CloseCode::from(1009u16),
388                            reason: "message too large".into(),
389                        })),
390                        "max-message-size-close-text",
391                    );
392                    break;
393                }
394
395                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
396                message.set_header(
397                    "CamelWsConnectionKey",
398                    serde_json::Value::String(connection_key.clone()),
399                );
400                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
401                message.set_header(
402                    "CamelWsRemoteAddress",
403                    serde_json::Value::String(remote_addr.clone()),
404                );
405
406                #[allow(unused_mut)]
407                let mut exchange = Exchange::new(message);
408                #[cfg(feature = "otel")]
409                {
410                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
411                }
412                if env_tx
413                    .send(ExchangeEnvelope {
414                        exchange,
415                        reply_tx: None,
416                    })
417                    .await
418                    .is_err()
419                {
420                    break;
421                }
422            }
423            Ok(WsMessage::Binary(data)) => {
424                if data.len() > path_config.max_message_size as usize {
425                    try_send_with_backpressure(
426                        &out_tx,
427                        WsMessage::Close(Some(CloseFrame {
428                            code: CloseCode::from(1009u16),
429                            reason: "message too large".into(),
430                        })),
431                        "max-message-size-close-binary",
432                    );
433                    break;
434                }
435
436                let mut message = CamelMessage::new(CamelBody::Bytes(data));
437                message.set_header(
438                    "CamelWsConnectionKey",
439                    serde_json::Value::String(connection_key.clone()),
440                );
441                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
442                message.set_header(
443                    "CamelWsRemoteAddress",
444                    serde_json::Value::String(remote_addr.clone()),
445                );
446
447                #[allow(unused_mut)]
448                let mut exchange = Exchange::new(message);
449                #[cfg(feature = "otel")]
450                {
451                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
452                }
453                if env_tx
454                    .send(ExchangeEnvelope {
455                        exchange,
456                        reply_tx: None,
457                    })
458                    .await
459                    .is_err()
460                {
461                    break;
462                }
463            }
464            Ok(WsMessage::Close(_)) | Err(_) => break,
465            _ => {}
466        }
467    }
468
469    if let Some(task) = heartbeat_task {
470        task.abort();
471    }
472
473    if let Some(key) = registry_key
474        && let Some(entry) = registry.get(&key)
475    {
476        entry.remove(&connection_key);
477    }
478    drop(out_tx);
479    let _ = writer.await;
480}
481
482pub struct WsComponent {
483    pub(crate) config: WsConfig,
484}
485
486impl WsComponent {
487    pub fn new() -> Self {
488        Self {
489            config: WsConfig::default(),
490        }
491    }
492
493    pub fn with_config(config: WsConfig) -> Self {
494        Self { config }
495    }
496}
497
498impl Default for WsComponent {
499    fn default() -> Self {
500        Self::new()
501    }
502}
503
504impl Component for WsComponent {
505    fn scheme(&self) -> &str {
506        "ws"
507    }
508
509    fn create_endpoint(
510        &self,
511        uri: &str,
512        _ctx: &dyn camel_component_api::ComponentContext,
513    ) -> Result<Box<dyn Endpoint>, CamelError> {
514        let mut cfg = WsEndpointConfig::from_uri(uri)?;
515        if let Some(v) = self.config.max_connections {
516            cfg.max_connections = v;
517        }
518        if let Some(v) = self.config.max_message_size {
519            cfg.max_message_size = v;
520        }
521        if let Some(v) = self.config.heartbeat_interval_ms {
522            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
523        }
524        if let Some(v) = self.config.idle_timeout_ms {
525            cfg.idle_timeout = std::time::Duration::from_millis(v);
526        }
527        if let Some(v) = self.config.connect_timeout_ms {
528            cfg.connect_timeout = std::time::Duration::from_millis(v);
529        }
530        if let Some(v) = self.config.response_timeout_ms {
531            cfg.response_timeout = std::time::Duration::from_millis(v);
532        }
533        Ok(Box::new(WsEndpoint {
534            uri: uri.to_string(),
535            cfg,
536        }))
537    }
538}
539
540pub struct WssComponent {
541    pub(crate) config: WsConfig,
542}
543
544impl WssComponent {
545    pub fn new() -> Self {
546        Self {
547            config: WsConfig::default(),
548        }
549    }
550
551    pub fn with_config(config: WsConfig) -> Self {
552        Self { config }
553    }
554}
555
556impl Default for WssComponent {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
562impl Component for WssComponent {
563    fn scheme(&self) -> &str {
564        "wss"
565    }
566
567    fn create_endpoint(
568        &self,
569        uri: &str,
570        _ctx: &dyn camel_component_api::ComponentContext,
571    ) -> Result<Box<dyn Endpoint>, CamelError> {
572        let mut cfg = WsEndpointConfig::from_uri(uri)?;
573        if let Some(v) = self.config.max_connections {
574            cfg.max_connections = v;
575        }
576        if let Some(v) = self.config.max_message_size {
577            cfg.max_message_size = v;
578        }
579        if let Some(v) = self.config.heartbeat_interval_ms {
580            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
581        }
582        if let Some(v) = self.config.idle_timeout_ms {
583            cfg.idle_timeout = std::time::Duration::from_millis(v);
584        }
585        if let Some(v) = self.config.connect_timeout_ms {
586            cfg.connect_timeout = std::time::Duration::from_millis(v);
587        }
588        if let Some(v) = self.config.response_timeout_ms {
589            cfg.response_timeout = std::time::Duration::from_millis(v);
590        }
591        Ok(Box::new(WsEndpoint {
592            uri: uri.to_string(),
593            cfg,
594        }))
595    }
596}
597
598struct WsEndpoint {
599    uri: String,
600    cfg: WsEndpointConfig,
601}
602
603impl Endpoint for WsEndpoint {
604    fn uri(&self) -> &str {
605        &self.uri
606    }
607
608    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
609        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
610    }
611
612    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
613        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
614    }
615}
616
617pub struct WsConsumer {
618    cfg: WsServerConfig,
619    registry: Arc<WsConnectionRegistry>,
620    server_state: Option<WsAppState>,
621    registry_key: Option<(String, u16, String)>,
622    forward_task: Option<JoinHandle<()>>,
623}
624
625impl WsConsumer {
626    pub fn new(cfg: WsServerConfig) -> Self {
627        Self {
628            cfg,
629            registry: Arc::new(WsConnectionRegistry::new()),
630            server_state: None,
631            registry_key: None,
632            forward_task: None,
633        }
634    }
635}
636
637#[async_trait]
638impl Consumer for WsConsumer {
639    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
640        let tls_config = if self.cfg.inner.scheme == "wss" {
641            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
642                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
643            })?;
644            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
645                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
646            })?;
647            Some(WsTlsConfig {
648                cert_path,
649                key_path,
650            })
651        } else {
652            None
653        };
654
655        let state = ServerRegistry::global()
656            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
657            .await?;
658
659        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
660        {
661            let mut table = state.dispatch.write().await;
662            table.insert(self.cfg.inner.path.clone(), env_tx);
663        }
664
665        state.path_configs.insert(
666            self.cfg.inner.path.clone(),
667            WsPathConfig {
668                max_connections: self.cfg.inner.max_connections,
669                max_message_size: self.cfg.inner.max_message_size,
670                heartbeat_interval: self.cfg.inner.heartbeat_interval,
671                idle_timeout: self.cfg.inner.idle_timeout,
672                allow_origin: self.cfg.inner.allow_origin.clone(),
673            },
674        );
675
676        let registry_key = (
677            self.cfg.inner.canonical_host(),
678            self.cfg.inner.port,
679            self.cfg.inner.path.clone(),
680        );
681        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
682
683        let sender = ctx.sender();
684        let forward_task = tokio::spawn(async move {
685            while let Some(envelope) = env_rx.recv().await {
686                if sender.send(envelope).await.is_err() {
687                    break;
688                }
689            }
690        });
691
692        self.server_state = Some(state);
693        self.registry_key = Some(registry_key);
694        self.forward_task = Some(forward_task);
695        Ok(())
696    }
697
698    async fn stop(&mut self) -> Result<(), CamelError> {
699        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
700            code: axum::extract::ws::CloseCode::from(1001u16),
701            reason: "consumer stopping".into(),
702        }));
703        for tx in self.registry.snapshot_senders() {
704            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
705        }
706
707        if let Some(state) = self.server_state.take() {
708            let mut table = state.dispatch.write().await;
709            table.remove(&self.cfg.inner.path);
710            state.path_configs.remove(&self.cfg.inner.path);
711        }
712
713        if let Some(key) = self.registry_key.take() {
714            global_registries().remove(&key);
715        }
716
717        if let Some(task) = self.forward_task.take() {
718            task.abort();
719        }
720
721        Ok(())
722    }
723
724    fn concurrency_model(&self) -> ConcurrencyModel {
725        ConcurrencyModel::Concurrent {
726            max: Some(self.cfg.inner.max_connections as usize),
727        }
728    }
729}
730
731#[derive(Clone)]
732pub struct WsProducer {
733    cfg: WsClientConfig,
734}
735
736impl WsProducer {
737    pub fn new(cfg: WsClientConfig) -> Self {
738        Self { cfg }
739    }
740}
741
742impl Service<Exchange> for WsProducer {
743    type Response = Exchange;
744    type Error = CamelError;
745    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
746
747    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
748        Poll::Ready(Ok(()))
749    }
750
751    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
752        let cfg = self.cfg.clone();
753
754        Box::pin(async move {
755            let canonical_host = cfg.inner.canonical_host();
756            let key = (
757                canonical_host.clone(),
758                cfg.inner.port,
759                cfg.inner.path.clone(),
760            );
761
762            let send_to_all = exchange
763                .input
764                .header("CamelWsSendToAll")
765                .and_then(|v| v.as_bool())
766                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
767                .unwrap_or(false);
768
769            let conn_keys_header = exchange
770                .input
771                .header("CamelWsConnectionKey")
772                .and_then(|v| v.as_str())
773                .map(str::to_string);
774
775            let local_exists = global_registries().contains_key(&key);
776            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
777
778            let message_type = exchange
779                .input
780                .header("CamelWsMessageType")
781                .and_then(|v| v.as_str())
782                .unwrap_or("text")
783                .to_ascii_lowercase();
784
785            if server_send_mode {
786                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
787                let Some(registry) = registry else {
788                    return Err(CamelError::ProcessorError(format!(
789                        "WebSocket local consumer not found for {}:{}{}",
790                        canonical_host, cfg.inner.port, cfg.inner.path
791                    )));
792                };
793
794                let out_msg = body_to_axum_ws_message(
795                    std::mem::take(&mut exchange.input.body),
796                    &message_type,
797                )
798                .await?;
799
800                let targets = if send_to_all {
801                    registry.snapshot_senders()
802                } else if let Some(keys) = conn_keys_header {
803                    let parsed: Vec<String> = keys
804                        .split(',')
805                        .map(str::trim)
806                        .filter(|k| !k.is_empty())
807                        .map(str::to_string)
808                        .collect();
809                    registry.get_senders_for_keys(&parsed)
810                } else {
811                    registry.snapshot_senders()
812                };
813
814                for tx in targets {
815                    let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
816                }
817
818                return Ok(exchange);
819            }
820
821            let url = format!(
822                "{}://{}:{}{}",
823                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
824            );
825
826            #[allow(unused_mut)]
827            let mut request = url
828                .clone()
829                .into_client_request()
830                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
831
832            #[cfg(feature = "otel")]
833            {
834                let mut otel_headers = HashMap::new();
835                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
836                for (k, v) in otel_headers {
837                    if let (Ok(name), Ok(val)) = (
838                        http::header::HeaderName::from_bytes(k.as_bytes()),
839                        http::header::HeaderValue::from_str(&v),
840                    ) {
841                        request.headers_mut().insert(name, val);
842                    }
843                }
844            }
845
846            let connect_future = tokio_tungstenite::connect_async(request);
847            let (mut ws_stream, _) =
848                tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
849                    .await
850                    .map_err(|_| {
851                        CamelError::ProcessorError(format!(
852                            "WebSocket connect timeout ({:?}) to {url}",
853                            cfg.inner.connect_timeout
854                        ))
855                    })?
856                    .map_err(|e| map_connect_error(e, &url))?;
857
858            let out_msg =
859                body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
860                    .await?;
861
862            ws_stream
863                .send(out_msg)
864                .await
865                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
866
867            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
868                loop {
869                    match ws_stream.next().await {
870                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
871                            continue;
872                        }
873                        other => break other,
874                    }
875                }
876            })
877            .await
878            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
879
880            match incoming {
881                Some(Ok(ClientWsMessage::Text(text))) => {
882                    exchange.input.body = CamelBody::Text(text.to_string());
883                }
884                Some(Ok(ClientWsMessage::Binary(data))) => {
885                    exchange.input.body = CamelBody::Bytes(data);
886                }
887                Some(Ok(ClientWsMessage::Close(frame))) => {
888                    let normal = frame
889                        .as_ref()
890                        .map(|f| {
891                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
892                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
893                        })
894                        .unwrap_or(true);
895
896                    if normal {
897                        exchange.input.body = CamelBody::Empty;
898                    } else {
899                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
900                        return Err(CamelError::ProcessorError(format!(
901                            "WebSocket peer closed: code {code}"
902                        )));
903                    }
904                }
905                Some(Ok(_)) | None => {
906                    exchange.input.body = CamelBody::Empty;
907                }
908                Some(Err(e)) => {
909                    return Err(CamelError::ProcessorError(format!(
910                        "WebSocket receive failed: {e}"
911                    )));
912                }
913            }
914
915            let _ = ws_stream.close(None).await;
916            Ok(exchange)
917        })
918    }
919}
920
921async fn body_to_axum_ws_message(
922    body: CamelBody,
923    message_type: &str,
924) -> Result<WsMessage, CamelError> {
925    match message_type {
926        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
927        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
928    }
929}
930
931async fn body_to_client_ws_message(
932    body: CamelBody,
933    message_type: &str,
934) -> Result<ClientWsMessage, CamelError> {
935    match message_type {
936        "binary" => Ok(ClientWsMessage::Binary(
937            body.into_bytes(10 * 1024 * 1024).await?,
938        )),
939        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
940    }
941}
942
943async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
944    Ok(match body {
945        CamelBody::Empty => String::new(),
946        CamelBody::Text(s) => s,
947        CamelBody::Xml(s) => s,
948        CamelBody::Json(v) => v.to_string(),
949        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
950        CamelBody::Stream(stream) => {
951            let bytes = CamelBody::Stream(stream)
952                .into_bytes(10 * 1024 * 1024)
953                .await?;
954            String::from_utf8_lossy(&bytes).to_string()
955        }
956    })
957}
958
/// Returns `true` when a WebSocket upgrade request's `Origin` header is permitted.
///
/// * `allowed_origin` is the configured `allowOrigin` value. `"*"` allows every request,
///   including ones that carry no `Origin` header at all.
/// * `request_origin` is the `Origin` header value, if the request had one.
///
/// For any other configured value, the request must carry an `Origin` equal to it.
/// Scheme and host of an origin are case-insensitive, so the comparison is ASCII
/// case-insensitive. A serialized origin never contains a path, so a trailing `/` on
/// the configured value is ignored. Surrounding whitespace is ignored on both sides.
/// As a result, `https://Example.com/` in configuration still matches a browser's
/// `https://example.com`, while `https://example.com.evil.com` does not.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    let allowed = allowed_origin.trim();
    if allowed == "*" {
        return true;
    }
    let allowed = allowed.trim_end_matches('/');
    request_origin.is_some_and(|origin| origin.trim().eq_ignore_ascii_case(allowed))
}
965
966fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
967    match tx.try_send(msg) {
968        Ok(()) => true,
969        Err(error) => {
970            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
971            false
972        }
973    }
974}
975
976fn load_tls_config(
977    cert_path: &str,
978    key_path: &str,
979) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
980    use std::fs::File;
981    use std::io::BufReader;
982
983    let cert_file = File::open(cert_path)
984        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
985    let key_file = File::open(key_path)
986        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
987
988    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
989        .collect::<Result<Vec<_>, _>>()
990        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
991
992    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
993        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
994        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
995
996    tokio_rustls::rustls::ServerConfig::builder()
997        .with_no_client_auth()
998        .with_single_cert(certs, key)
999        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1000}
1001
1002fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1003    match err {
1004        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1005            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1006        }
1007        tungstenite::Error::Tls(_) => {
1008            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1009        }
1010        other => {
1011            let msg = other.to_string();
1012            if msg.to_lowercase().contains("connection refused") {
1013                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1014            } else if msg.to_lowercase().contains("tls") {
1015                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1016            } else {
1017                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1018            }
1019        }
1020    }
1021}
1022
// Tests for the WebSocket component.
//
// The plain `#[test]` cases are unit tests of URI/config parsing, scheme names, the
// consumer's concurrency model and the origin check. The `#[tokio::test]` cases are
// integration tests: they start real servers on 127.0.0.1 (through a started consumer,
// `ServerRegistry::get_or_spawn`, or `axum::serve` directly) and drive them with a
// `tokio-tungstenite` client to exercise upgrade, dispatch and close-code behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use camel_component_api::NoOpComponentContext;
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as ClientMessage;
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Asks the OS for an unused loopback port.
    ///
    /// The probe listener is dropped before returning, so another process could in
    /// principle take the port before the test binds it; accepted as a test-only race.
    fn free_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port()
    }

    #[test]
    fn ws_component_scheme_is_ws() {
        assert_eq!(WsComponent::new().scheme(), "ws");
    }

    #[test]
    fn wss_component_scheme_is_wss() {
        assert_eq!(WssComponent::new().scheme(), "wss");
    }

    /// Pins every default of `WsEndpointConfig` so an accidental change is caught.
    #[test]
    fn endpoint_config_defaults_match_spec() {
        let cfg = WsEndpointConfig::default();
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "0.0.0.0");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.path, "/");
        assert_eq!(cfg.max_connections, 100);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
        assert_eq!(cfg.idle_timeout, Duration::ZERO);
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
        assert_eq!(cfg.allow_origin, "*");
        assert_eq!(cfg.tls_cert, None);
        assert_eq!(cfg.tls_key, None);
    }

    /// Every supported query parameter is parsed into the matching config field.
    #[test]
    fn endpoint_config_parses_uri_params() {
        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();

        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 9001);
        assert_eq!(cfg.path, "/chat");
        assert_eq!(cfg.max_connections, 42);
        assert_eq!(cfg.max_message_size, 1024);
        assert!(cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
        assert_eq!(cfg.allow_origin, "https://example.com");
        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
    }

    /// A URI parameter overrides its default while untouched fields keep their defaults.
    #[test]
    fn endpoint_config_override_chain_uri_overrides_defaults() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
        assert_eq!(cfg.max_connections, 7);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
    }

    #[test]
    fn endpoint_trait_creates_consumer_and_producer() {
        let ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
            .unwrap();

        endpoint.create_consumer().unwrap();
        endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();
    }

    #[test]
    fn ws_consumer_concurrency_model_uses_max_connections() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
        let consumer = WsConsumer::new(cfg.server_config());
        assert_eq!(
            consumer.concurrency_model(),
            ConcurrencyModel::Concurrent { max: Some(321) }
        );
    }

    /// Exercises the registry directly: insertion, broadcast via `snapshot_senders`,
    /// targeted lookup via `get_senders_for_keys`, and removal.
    #[tokio::test]
    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
        let registry = WsConnectionRegistry::new();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);

        registry.insert("k1".into(), tx1);
        registry.insert("k2".into(), tx2);
        assert_eq!(registry.len(), 2);

        for tx in registry.snapshot_senders() {
            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
        }

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));

        let target = registry.get_senders_for_keys(&["k1".to_string()]);
        assert_eq!(target.len(), 1);
        target[0]
            .send(WsMessage::Text("targeted".into()))
            .await
            .unwrap();

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
        // k2 must not receive the targeted message; a short timeout proves its absence.
        assert!(
            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
                .await
                .is_err()
        );

        registry.remove("k1");
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn host_canonicalization_maps_local_hosts_to_loopback() {
        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
            .unwrap()
            .canonical_host();
        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
            .unwrap()
            .canonical_host();
        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
            .unwrap()
            .canonical_host();

        assert_eq!(c1, "127.0.0.1");
        assert_eq!(c2, "127.0.0.1");
        assert_eq!(c3, "127.0.0.1");
    }

    /// End-to-end echo: the client sends a text frame to the consumer. A stand-in
    /// "route" task replies through the producer, addressing the reply to the
    /// originating connection via the `CamelWsConnectionKey` header.
    #[tokio::test]
    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/echo");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, mut route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let route_task = tokio::spawn(async move {
            if let Some(envelope) = route_rx.recv().await {
                let payload = envelope
                    .exchange
                    .input
                    .body
                    .as_text()
                    .unwrap_or_default()
                    .to_string();
                let key = envelope
                    .exchange
                    .input
                    .header("CamelWsConnectionKey")
                    .and_then(|v| v.as_str())
                    .unwrap()
                    .to_string();

                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
                response
                    .input
                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
                producer.oneshot(response).await.unwrap();
            }
        });

        // The server starts asynchronously, so retry until the upgrade succeeds.
        let url = format!("ws://127.0.0.1:{port}/echo");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("hello-ws".into()))
            .await
            .unwrap();

        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed before echo"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(incoming, "hello-ws");

        consumer.stop().await.unwrap();
        route_task.await.unwrap();
    }

    /// Stopping the consumer must close live connections with 1001 (Going Away).
    #[tokio::test]
    async fn consumer_stop_sends_close_1001() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/shutdown");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/shutdown");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("keepalive".into()))
            .await
            .unwrap();

        consumer.stop().await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(close_code, Some(CloseCode::Away));
    }

    #[test]
    fn wildcard_origin_allows_anything() {
        assert!(is_origin_allowed("*", None));
        assert!(is_origin_allowed("*", Some("https://example.com")));
    }

    #[test]
    fn exact_origin_requires_match() {
        assert!(is_origin_allowed(
            "https://example.com",
            Some("https://example.com")
        ));
        assert!(!is_origin_allowed(
            "https://example.com",
            Some("https://other.com")
        ));
        assert!(!is_origin_allowed("https://example.com", None));
    }

    #[test]
    fn endpoint_config_rejects_invalid_scheme() {
        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid WebSocket scheme"),
            "expected scheme error, got: {msg}"
        );
    }

    /// A `wss` consumer with no `tlsCert` must fail at start, not at first handshake.
    #[tokio::test]
    async fn wss_consumer_start_fails_without_tls_cert() {
        let port = free_port();
        let component_ctx = NoOpComponentContext;
        let endpoint = WssComponent::new()
            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert path is required"),
            "expected TLS cert error, got: {msg}"
        );
    }

    /// Unreadable cert/key paths surface the `load_tls_config` file-open error.
    #[tokio::test]
    async fn wss_consumer_start_fails_with_nonexistent_cert() {
        let port = free_port();
        let component_ctx = NoOpComponentContext;
        let endpoint = WssComponent::new()
            .create_endpoint(&format!(
                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
            ), &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert file error"),
            "expected cert file error, got: {msg}"
        );
    }

    /// Two lookups for the same host/port must share one dispatch table.
    #[tokio::test]
    async fn server_registry_returns_same_state_for_same_port() {
        let port = free_port();
        let state1 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let state2 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        assert!(
            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
            "expected same dispatch table for same port"
        );
    }

    /// Drives `dispatch_handler` in-process (no socket) for a path nobody registered.
    #[tokio::test]
    async fn dispatch_handler_returns_404_for_unregistered_path() {
        let port = free_port();
        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);
        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/nonexistent")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    /// Client mode: the producer connects to an external echo server (a bare axum app),
    /// sends the body, and returns the echoed reply as the exchange body.
    #[tokio::test]
    async fn client_mode_producer_connects_and_echoes() {
        let port = free_port();

        let app = Router::new().route(
            "/echo",
            axum::routing::get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket: WebSocket| async move {
                    while let Some(Ok(msg)) = socket.recv().await {
                        match msg {
                            WsMessage::Text(text) => {
                                let _ = socket.send(WsMessage::Text(text)).await;
                            }
                            WsMessage::Binary(data) => {
                                let _ = socket.send(WsMessage::Binary(data)).await;
                            }
                            WsMessage::Close(_) => break,
                            _ => {}
                        }
                    }
                })
            }),
        );
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
            .await
            .unwrap();
        let server_task = tokio::spawn(async move {
            let _ = serve(listener, app).await;
        });

        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
        let producer = WsProducer::new(cfg.client_config());

        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
        tokio::time::sleep(Duration::from_millis(25)).await;
        let result =
            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
                Ok(Ok(r)) => r,
                // Surface the producer's error; it is the only clue when this test fails.
                Ok(Err(e)) => panic!("producer call failed: {e}"),
                Err(_) => panic!("producer call timed out"),
            };

        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");

        server_task.abort();
    }

    /// With `maxConnections=1`, a second client is accepted at the HTTP level but then
    /// closed with 1013 (Try Again Later).
    #[tokio::test]
    async fn max_connections_rejects_with_close_1013() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/limited");
        let (_client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // Give the server time to register client1 before the second connection arrives.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let (mut client2, _) = connect_async(&url).await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
                    None => panic!("client2 closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1013u16)),
            "expected 1013 (Try Again Later) for max connections"
        );

        consumer.stop().await.unwrap();
    }

    /// A frame larger than `maxMessageSize` must close the connection with 1009.
    #[tokio::test]
    async fn max_message_size_rejects_with_close_1009() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/sizelimit");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let oversized = "x".repeat(100);
        client
            .send(ClientMessage::Text(oversized.into()))
            .await
            .unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1009u16)),
            "expected 1009 (Message Too Big) for oversized message"
        );

        consumer.stop().await.unwrap();
    }

    /// An upgrade request from a non-allowed `Origin` is refused with HTTP 403. The
    /// request is driven in-process through `dispatch_handler` using the shared state.
    #[tokio::test]
    async fn origin_rejection_returns_403() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);

        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/origintest")
                    .header("origin", "https://evil.com")
                    .header("upgrade", "websocket")
                    .header("connection", "Upgrade")
                    .header("sec-websocket-version", "13")
                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();

        assert_eq!(
            response.status(),
            StatusCode::FORBIDDEN,
            "expected 403 for disallowed origin"
        );

        consumer.stop().await.unwrap();
    }

    /// `CamelWsSendToAll=true` on a producer exchange reaches every connected client.
    #[tokio::test]
    async fn broadcast_sends_to_all_connected_clients() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/bc");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/bc");

        let (mut client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let (mut client2, _) = connect_async(&url).await.unwrap();

        // Let both connections register before broadcasting.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut response =
            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
        response
            .input
            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
        producer.oneshot(response).await.unwrap();

        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client1.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client1 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client2 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(recv1, "broadcast-msg");
        assert_eq!(recv2, "broadcast-msg");

        consumer.stop().await.unwrap();
    }

    /// Racing `get_or_spawn` calls for one port must all observe the same server state.
    #[tokio::test]
    async fn concurrent_get_or_spawn_returns_same_state() {
        let port = free_port();
        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let results = results.clone();
            handles.push(tokio::spawn(async move {
                let state = ServerRegistry::global()
                    .get_or_spawn("127.0.0.1", port, None)
                    .await
                    .unwrap();
                results.lock().unwrap().push(state);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let states = results.lock().unwrap();
        assert_eq!(states.len(), 4);
        for i in 1..states.len() {
            assert!(
                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
                "all concurrent callers should get the same dispatch table"
            );
        }
    }
}