Skip to main content

camel_component_ws/
lib.rs

1pub mod config;
2
3pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, OnceLock};
7
8use async_trait::async_trait;
9use axum::body::Body;
10use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
11use axum::extract::{FromRequest, Request, State};
12use axum::http::{StatusCode, header};
13use axum::response::IntoResponse;
14use axum::{Router, serve};
15use camel_api::{Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage};
16use camel_component::{
17    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
18    ProducerContext,
19};
20use dashmap::DashMap;
21use futures::{SinkExt, StreamExt};
22use std::future::Future;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25use tokio::sync::{OnceCell, RwLock, mpsc};
26use tokio::task::JoinHandle;
27use tokio_tungstenite::tungstenite;
28use tokio_tungstenite::tungstenite::client::IntoClientRequest;
29use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
30use tower::Service;
31
32#[derive(Clone)]
33struct WsPathConfig {
34    max_connections: u32,
35    max_message_size: u32,
36    heartbeat_interval: std::time::Duration,
37    idle_timeout: std::time::Duration,
38    allow_origin: String,
39}
40
41impl Default for WsPathConfig {
42    fn default() -> Self {
43        let cfg = WsEndpointConfig::default();
44        Self {
45            max_connections: cfg.max_connections,
46            max_message_size: cfg.max_message_size,
47            heartbeat_interval: cfg.heartbeat_interval,
48            idle_timeout: cfg.idle_timeout,
49            allow_origin: cfg.allow_origin,
50        }
51    }
52}
53
54#[derive(Clone)]
55struct WsTlsConfig {
56    cert_path: String,
57    key_path: String,
58}
59
60type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
61
62struct ServerHandle {
63    state: WsAppState,
64    is_tls: bool,
65    _task: JoinHandle<()>,
66}
67
68pub struct ServerRegistry {
69    inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
70}
71
72impl ServerRegistry {
73    pub fn global() -> &'static Self {
74        static REG: OnceLock<ServerRegistry> = OnceLock::new();
75        REG.get_or_init(|| Self {
76            inner: Mutex::new(HashMap::new()),
77        })
78    }
79
80    pub(crate) async fn get_or_spawn(
81        &'static self,
82        host: &str,
83        port: u16,
84        tls_config: Option<WsTlsConfig>,
85    ) -> Result<WsAppState, CamelError> {
86        let wants_tls = tls_config.is_some();
87        let host_owned = host.to_string();
88
89        let cell = {
90            let mut guard = self.inner.lock().map_err(|_| {
91                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
92            })?;
93            guard
94                .entry(port)
95                .or_insert_with(|| Arc::new(OnceCell::new()))
96                .clone()
97        };
98
99        let handle = cell
100            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
101            .await?;
102
103        if wants_tls != handle.is_tls {
104            return Err(CamelError::EndpointCreationFailed(format!(
105                "Server on port {port} already running with different TLS mode"
106            )));
107        }
108
109        Ok(handle.state.clone())
110    }
111}
112
113async fn spawn_server(
114    host: &str,
115    port: u16,
116    tls_config: Option<WsTlsConfig>,
117) -> Result<ServerHandle, CamelError> {
118    let addr = format!("{host}:{port}");
119    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
120    let path_configs = Arc::new(DashMap::new());
121    let state = WsAppState {
122        dispatch: Arc::clone(&dispatch),
123        path_configs: Arc::clone(&path_configs),
124    };
125    let app = Router::new()
126        .fallback(dispatch_handler)
127        .with_state(state.clone());
128
129    let (task, is_tls) = if let Some(ref tls) = tls_config {
130        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
131        let parsed_addr = addr.parse().map_err(|e| {
132            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
133        })?;
134        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
135        let task = tokio::spawn(async move {
136            let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
137                .serve(app.into_make_service())
138                .await;
139        });
140        (task, true)
141    } else {
142        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
143            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
144        })?;
145        let task = tokio::spawn(async move {
146            let _ = serve(listener, app).await;
147        });
148        (task, false)
149    };
150
151    Ok(ServerHandle {
152        state,
153        is_tls,
154        _task: task,
155    })
156}
157
158#[derive(Clone)]
159struct WsAppState {
160    dispatch: DispatchTable,
161    path_configs: Arc<DashMap<String, WsPathConfig>>,
162}
163
164pub struct WsConnectionRegistry {
165    connections: DashMap<String, mpsc::Sender<WsMessage>>,
166}
167
168static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
169    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
170> = OnceLock::new();
171
172fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
173    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
174}
175
176impl Default for WsConnectionRegistry {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl WsConnectionRegistry {
183    pub fn new() -> Self {
184        Self {
185            connections: DashMap::new(),
186        }
187    }
188
189    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
190        self.connections.insert(key, tx);
191    }
192
193    pub fn remove(&self, key: &str) {
194        self.connections.remove(key);
195    }
196
197    pub fn len(&self) -> usize {
198        self.connections.len()
199    }
200
201    pub fn is_empty(&self) -> bool {
202        self.connections.is_empty()
203    }
204
205    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
206        self.connections.iter().map(|e| e.value().clone()).collect()
207    }
208
209    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
210        keys.iter()
211            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
212            .collect()
213    }
214}
215
216async fn dispatch_handler(
217    State(state): State<WsAppState>,
218    req: Request<Body>,
219) -> impl IntoResponse {
220    let path = req.uri().path().to_string();
221    let origin = req
222        .headers()
223        .get(header::ORIGIN)
224        .and_then(|value| value.to_str().ok())
225        .map(str::to_string);
226    let remote_addr = req
227        .extensions()
228        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
229        .map(|ci| ci.0.to_string())
230        .unwrap_or_default();
231    let table = state.dispatch.read().await;
232    if !table.contains_key(&path) {
233        return (
234            StatusCode::NOT_FOUND,
235            "no ws endpoint registered for this path",
236        )
237            .into_response();
238    }
239    drop(table);
240
241    let path_config = state
242        .path_configs
243        .get(&path)
244        .map(|entry| entry.value().clone())
245        .unwrap_or_default();
246    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
247        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
248    }
249
250    let upgrade_headers: HashMap<String, String> = req
251        .headers()
252        .iter()
253        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
254        .collect();
255
256    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
257        Ok(ws) => ws,
258        Err(_) => {
259            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
260        }
261    };
262
263    ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
264        .into_response()
265}
266
267#[allow(unused_variables)]
268async fn ws_handler(
269    socket: WebSocket,
270    state: WsAppState,
271    path: String,
272    remote_addr: String,
273    upgrade_headers: HashMap<String, String>,
274) {
275    let connection_key = uuid::Uuid::new_v4().to_string();
276    let path_config = state
277        .path_configs
278        .get(&path)
279        .map(|entry| entry.value().clone())
280        .unwrap_or_default();
281
282    let env_tx = {
283        let table = state.dispatch.read().await;
284        table.get(&path).cloned()
285    };
286    let Some(env_tx) = env_tx else {
287        return;
288    };
289
290    let (mut sink, mut stream) = socket.split();
291    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
292
293    let registry = global_registries();
294    let mut registry_key = None;
295    for entry in registry.iter() {
296        if entry.key().2 == path {
297            entry.value().insert(connection_key.clone(), out_tx.clone());
298            registry_key = Some(entry.key().clone());
299            break;
300        }
301    }
302
303    let writer = tokio::spawn(async move {
304        while let Some(msg) = out_rx.recv().await {
305            let _ = sink.send(msg).await;
306        }
307    });
308
309    let mut over_limit = false;
310    if let Some(key) = &registry_key
311        && let Some(entry) = registry.get(key)
312        && entry.len() > path_config.max_connections as usize
313    {
314        over_limit = true;
315    }
316    if over_limit {
317        try_send_with_backpressure(
318            &out_tx,
319            WsMessage::Close(Some(CloseFrame {
320                code: CloseCode::from(1013u16),
321                reason: "max connections exceeded".into(),
322            })),
323            "max-connections-close",
324        );
325        if let Some(key) = registry_key.clone()
326            && let Some(entry) = registry.get(&key)
327        {
328            entry.remove(&connection_key);
329        }
330        drop(out_tx);
331        let _ = writer.await;
332        return;
333    }
334
335    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
336        let ping_tx = out_tx.clone();
337        let interval = path_config.heartbeat_interval;
338        Some(tokio::spawn(async move {
339            let mut ticker = tokio::time::interval(interval);
340            loop {
341                ticker.tick().await;
342                let _ = try_send_with_backpressure(
343                    &ping_tx,
344                    WsMessage::Ping(Vec::new().into()),
345                    "heartbeat-ping",
346                );
347            }
348        }))
349    } else {
350        None
351    };
352
353    loop {
354        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
355            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
356                Ok(msg) => msg,
357                Err(_) => {
358                    try_send_with_backpressure(
359                        &out_tx,
360                        WsMessage::Close(Some(CloseFrame {
361                            code: CloseCode::from(1000u16),
362                            reason: "idle timeout".into(),
363                        })),
364                        "idle-timeout-close",
365                    );
366                    break;
367                }
368            }
369        } else {
370            stream.next().await
371        };
372
373        let Some(msg) = next_msg else {
374            break;
375        };
376
377        match msg {
378            Ok(WsMessage::Text(text)) => {
379                if text.len() > path_config.max_message_size as usize {
380                    try_send_with_backpressure(
381                        &out_tx,
382                        WsMessage::Close(Some(CloseFrame {
383                            code: CloseCode::from(1009u16),
384                            reason: "message too large".into(),
385                        })),
386                        "max-message-size-close-text",
387                    );
388                    break;
389                }
390
391                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
392                message.set_header(
393                    "CamelWsConnectionKey",
394                    serde_json::Value::String(connection_key.clone()),
395                );
396                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
397                message.set_header(
398                    "CamelWsRemoteAddress",
399                    serde_json::Value::String(remote_addr.clone()),
400                );
401
402                #[allow(unused_mut)]
403                let mut exchange = Exchange::new(message);
404                #[cfg(feature = "otel")]
405                {
406                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
407                }
408                if env_tx
409                    .send(ExchangeEnvelope {
410                        exchange,
411                        reply_tx: None,
412                    })
413                    .await
414                    .is_err()
415                {
416                    break;
417                }
418            }
419            Ok(WsMessage::Binary(data)) => {
420                if data.len() > path_config.max_message_size as usize {
421                    try_send_with_backpressure(
422                        &out_tx,
423                        WsMessage::Close(Some(CloseFrame {
424                            code: CloseCode::from(1009u16),
425                            reason: "message too large".into(),
426                        })),
427                        "max-message-size-close-binary",
428                    );
429                    break;
430                }
431
432                let mut message = CamelMessage::new(CamelBody::Bytes(data));
433                message.set_header(
434                    "CamelWsConnectionKey",
435                    serde_json::Value::String(connection_key.clone()),
436                );
437                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
438                message.set_header(
439                    "CamelWsRemoteAddress",
440                    serde_json::Value::String(remote_addr.clone()),
441                );
442
443                #[allow(unused_mut)]
444                let mut exchange = Exchange::new(message);
445                #[cfg(feature = "otel")]
446                {
447                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
448                }
449                if env_tx
450                    .send(ExchangeEnvelope {
451                        exchange,
452                        reply_tx: None,
453                    })
454                    .await
455                    .is_err()
456                {
457                    break;
458                }
459            }
460            Ok(WsMessage::Close(_)) | Err(_) => break,
461            _ => {}
462        }
463    }
464
465    if let Some(task) = heartbeat_task {
466        task.abort();
467    }
468
469    if let Some(key) = registry_key
470        && let Some(entry) = registry.get(&key)
471    {
472        entry.remove(&connection_key);
473    }
474    drop(out_tx);
475    let _ = writer.await;
476}
477
478pub struct WsComponent;
479
480impl Component for WsComponent {
481    fn scheme(&self) -> &str {
482        "ws"
483    }
484
485    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
486        let cfg = WsEndpointConfig::from_uri(uri)?;
487        Ok(Box::new(WsEndpoint {
488            uri: uri.to_string(),
489            cfg,
490        }))
491    }
492}
493
494pub struct WssComponent;
495
496impl Component for WssComponent {
497    fn scheme(&self) -> &str {
498        "wss"
499    }
500
501    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
502        let cfg = WsEndpointConfig::from_uri(uri)?;
503        Ok(Box::new(WsEndpoint {
504            uri: uri.to_string(),
505            cfg,
506        }))
507    }
508}
509
510struct WsEndpoint {
511    uri: String,
512    cfg: WsEndpointConfig,
513}
514
515impl Endpoint for WsEndpoint {
516    fn uri(&self) -> &str {
517        &self.uri
518    }
519
520    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
521        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
522    }
523
524    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
525        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
526    }
527}
528
529pub struct WsConsumer {
530    cfg: WsServerConfig,
531    registry: Arc<WsConnectionRegistry>,
532    server_state: Option<WsAppState>,
533    registry_key: Option<(String, u16, String)>,
534    forward_task: Option<JoinHandle<()>>,
535}
536
537impl WsConsumer {
538    pub fn new(cfg: WsServerConfig) -> Self {
539        Self {
540            cfg,
541            registry: Arc::new(WsConnectionRegistry::new()),
542            server_state: None,
543            registry_key: None,
544            forward_task: None,
545        }
546    }
547}
548
549#[async_trait]
550impl Consumer for WsConsumer {
551    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
552        let tls_config = if self.cfg.inner.scheme == "wss" {
553            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
554                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
555            })?;
556            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
557                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
558            })?;
559            Some(WsTlsConfig {
560                cert_path,
561                key_path,
562            })
563        } else {
564            None
565        };
566
567        let state = ServerRegistry::global()
568            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
569            .await?;
570
571        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
572        {
573            let mut table = state.dispatch.write().await;
574            table.insert(self.cfg.inner.path.clone(), env_tx);
575        }
576
577        state.path_configs.insert(
578            self.cfg.inner.path.clone(),
579            WsPathConfig {
580                max_connections: self.cfg.inner.max_connections,
581                max_message_size: self.cfg.inner.max_message_size,
582                heartbeat_interval: self.cfg.inner.heartbeat_interval,
583                idle_timeout: self.cfg.inner.idle_timeout,
584                allow_origin: self.cfg.inner.allow_origin.clone(),
585            },
586        );
587
588        let registry_key = (
589            self.cfg.inner.canonical_host(),
590            self.cfg.inner.port,
591            self.cfg.inner.path.clone(),
592        );
593        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
594
595        let sender = ctx.sender();
596        let forward_task = tokio::spawn(async move {
597            while let Some(envelope) = env_rx.recv().await {
598                if sender.send(envelope).await.is_err() {
599                    break;
600                }
601            }
602        });
603
604        self.server_state = Some(state);
605        self.registry_key = Some(registry_key);
606        self.forward_task = Some(forward_task);
607        Ok(())
608    }
609
610    async fn stop(&mut self) -> Result<(), CamelError> {
611        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
612            code: axum::extract::ws::CloseCode::from(1001u16),
613            reason: "consumer stopping".into(),
614        }));
615        for tx in self.registry.snapshot_senders() {
616            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
617        }
618
619        if let Some(state) = self.server_state.take() {
620            let mut table = state.dispatch.write().await;
621            table.remove(&self.cfg.inner.path);
622            state.path_configs.remove(&self.cfg.inner.path);
623        }
624
625        if let Some(key) = self.registry_key.take() {
626            global_registries().remove(&key);
627        }
628
629        if let Some(task) = self.forward_task.take() {
630            task.abort();
631        }
632
633        Ok(())
634    }
635
636    fn concurrency_model(&self) -> ConcurrencyModel {
637        ConcurrencyModel::Concurrent {
638            max: Some(self.cfg.inner.max_connections as usize),
639        }
640    }
641}
642
643#[derive(Clone)]
644pub struct WsProducer {
645    cfg: WsClientConfig,
646}
647
648impl WsProducer {
649    pub fn new(cfg: WsClientConfig) -> Self {
650        Self { cfg }
651    }
652}
653
654impl Service<Exchange> for WsProducer {
655    type Response = Exchange;
656    type Error = CamelError;
657    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
658
659    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
660        Poll::Ready(Ok(()))
661    }
662
663    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
664        let cfg = self.cfg.clone();
665
666        Box::pin(async move {
667            let canonical_host = cfg.inner.canonical_host();
668            let key = (
669                canonical_host.clone(),
670                cfg.inner.port,
671                cfg.inner.path.clone(),
672            );
673
674            let send_to_all = exchange
675                .input
676                .header("CamelWsSendToAll")
677                .and_then(|v| v.as_bool())
678                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
679                .unwrap_or(false);
680
681            let conn_keys_header = exchange
682                .input
683                .header("CamelWsConnectionKey")
684                .and_then(|v| v.as_str())
685                .map(str::to_string);
686
687            let local_exists = global_registries().contains_key(&key);
688            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
689
690            let message_type = exchange
691                .input
692                .header("CamelWsMessageType")
693                .and_then(|v| v.as_str())
694                .unwrap_or("text")
695                .to_ascii_lowercase();
696
697            if server_send_mode {
698                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
699                let Some(registry) = registry else {
700                    return Err(CamelError::ProcessorError(format!(
701                        "WebSocket local consumer not found for {}:{}{}",
702                        canonical_host, cfg.inner.port, cfg.inner.path
703                    )));
704                };
705
706                let out_msg = body_to_axum_ws_message(
707                    std::mem::take(&mut exchange.input.body),
708                    &message_type,
709                )
710                .await?;
711
712                let targets = if send_to_all {
713                    registry.snapshot_senders()
714                } else if let Some(keys) = conn_keys_header {
715                    let parsed: Vec<String> = keys
716                        .split(',')
717                        .map(str::trim)
718                        .filter(|k| !k.is_empty())
719                        .map(str::to_string)
720                        .collect();
721                    registry.get_senders_for_keys(&parsed)
722                } else {
723                    registry.snapshot_senders()
724                };
725
726                for tx in targets {
727                    let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
728                }
729
730                return Ok(exchange);
731            }
732
733            let url = format!(
734                "{}://{}:{}{}",
735                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
736            );
737
738            #[allow(unused_mut)]
739            let mut request = url
740                .clone()
741                .into_client_request()
742                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
743
744            #[cfg(feature = "otel")]
745            {
746                let mut otel_headers = HashMap::new();
747                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
748                for (k, v) in otel_headers {
749                    if let (Ok(name), Ok(val)) = (
750                        http::header::HeaderName::from_bytes(k.as_bytes()),
751                        http::header::HeaderValue::from_str(&v),
752                    ) {
753                        request.headers_mut().insert(name, val);
754                    }
755                }
756            }
757
758            let connect_future = tokio_tungstenite::connect_async(request);
759            let (mut ws_stream, _) =
760                tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
761                    .await
762                    .map_err(|_| {
763                        CamelError::ProcessorError(format!(
764                            "WebSocket connect timeout ({:?}) to {url}",
765                            cfg.inner.connect_timeout
766                        ))
767                    })?
768                    .map_err(|e| map_connect_error(e, &url))?;
769
770            let out_msg =
771                body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
772                    .await?;
773
774            ws_stream
775                .send(out_msg)
776                .await
777                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
778
779            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
780                loop {
781                    match ws_stream.next().await {
782                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
783                            continue;
784                        }
785                        other => break other,
786                    }
787                }
788            })
789            .await
790            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
791
792            match incoming {
793                Some(Ok(ClientWsMessage::Text(text))) => {
794                    exchange.input.body = CamelBody::Text(text.to_string());
795                }
796                Some(Ok(ClientWsMessage::Binary(data))) => {
797                    exchange.input.body = CamelBody::Bytes(data);
798                }
799                Some(Ok(ClientWsMessage::Close(frame))) => {
800                    let normal = frame
801                        .as_ref()
802                        .map(|f| {
803                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
804                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
805                        })
806                        .unwrap_or(true);
807
808                    if normal {
809                        exchange.input.body = CamelBody::Empty;
810                    } else {
811                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
812                        return Err(CamelError::ProcessorError(format!(
813                            "WebSocket peer closed: code {code}"
814                        )));
815                    }
816                }
817                Some(Ok(_)) | None => {
818                    exchange.input.body = CamelBody::Empty;
819                }
820                Some(Err(e)) => {
821                    return Err(CamelError::ProcessorError(format!(
822                        "WebSocket receive failed: {e}"
823                    )));
824                }
825            }
826
827            let _ = ws_stream.close(None).await;
828            Ok(exchange)
829        })
830    }
831}
832
833async fn body_to_axum_ws_message(
834    body: CamelBody,
835    message_type: &str,
836) -> Result<WsMessage, CamelError> {
837    match message_type {
838        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
839        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
840    }
841}
842
843async fn body_to_client_ws_message(
844    body: CamelBody,
845    message_type: &str,
846) -> Result<ClientWsMessage, CamelError> {
847    match message_type {
848        "binary" => Ok(ClientWsMessage::Binary(
849            body.into_bytes(10 * 1024 * 1024).await?,
850        )),
851        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
852    }
853}
854
855async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
856    Ok(match body {
857        CamelBody::Empty => String::new(),
858        CamelBody::Text(s) => s,
859        CamelBody::Xml(s) => s,
860        CamelBody::Json(v) => v.to_string(),
861        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
862        CamelBody::Stream(stream) => {
863            let bytes = CamelBody::Stream(stream)
864                .into_bytes(10 * 1024 * 1024)
865                .await?;
866            String::from_utf8_lossy(&bytes).to_string()
867        }
868    })
869}
870
/// Decides whether a WebSocket upgrade with the given `Origin` header may proceed.
///
/// `allowed_origin` is a comma-separated list of entries. Whitespace around each entry
/// is ignored and empty entries are skipped. An entry of `*` allows every request,
/// including ones without an `Origin` header. Any other entry must equal the request's
/// origin exactly (case-sensitive). A single origin is simply a one-entry list, so
/// `"*"` and `"https://app.example"` behave as before.
///
/// Returns `false` for a request without an `Origin` header unless `*` is listed.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    let entries = || {
        allowed_origin
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
    };
    if entries().any(|entry| entry == "*") {
        return true;
    }
    request_origin.is_some_and(|origin| entries().any(|entry| entry == origin))
}
877
878fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
879    match tx.try_send(msg) {
880        Ok(()) => true,
881        Err(error) => {
882            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
883            false
884        }
885    }
886}
887
888fn load_tls_config(
889    cert_path: &str,
890    key_path: &str,
891) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
892    use std::fs::File;
893    use std::io::BufReader;
894
895    let cert_file = File::open(cert_path)
896        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
897    let key_file = File::open(key_path)
898        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
899
900    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
901        .collect::<Result<Vec<_>, _>>()
902        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
903
904    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
905        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
906        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
907
908    tokio_rustls::rustls::ServerConfig::builder()
909        .with_no_client_auth()
910        .with_single_cert(certs, key)
911        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
912}
913
914fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
915    match err {
916        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
917            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
918        }
919        tungstenite::Error::Tls(_) => {
920            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
921        }
922        other => {
923            let msg = other.to_string();
924            if msg.to_lowercase().contains("connection refused") {
925                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
926            } else if msg.to_lowercase().contains("tls") {
927                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
928            } else {
929                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
930            }
931        }
932    }
933}
934
// Unit tests for configuration/helpers plus loopback integration tests that start real
// ws servers on 127.0.0.1 and talk to them with a tokio-tungstenite client.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as ClientMessage;
    // Client-side close codes. This explicit import takes precedence over the axum
    // `CloseCode` brought in by the `super::*` glob.
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    // Asks the OS for an unused port. The probe listener is dropped immediately, so another
    // process could in principle grab the port before the test binds it.
    fn free_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port()
    }

    #[test]
    fn ws_component_scheme_is_ws() {
        assert_eq!(WsComponent.scheme(), "ws");
    }

    #[test]
    fn wss_component_scheme_is_wss() {
        assert_eq!(WssComponent.scheme(), "wss");
    }

    // Pins every default of `WsEndpointConfig`.
    #[test]
    fn endpoint_config_defaults_match_spec() {
        let cfg = WsEndpointConfig::default();
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "0.0.0.0");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.path, "/");
        assert_eq!(cfg.max_connections, 100);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
        assert_eq!(cfg.idle_timeout, Duration::ZERO);
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
        assert_eq!(cfg.allow_origin, "*");
        assert_eq!(cfg.tls_cert, None);
        assert_eq!(cfg.tls_key, None);
    }

    // Every supported query parameter is parsed into the matching field.
    #[test]
    fn endpoint_config_parses_uri_params() {
        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();

        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 9001);
        assert_eq!(cfg.path, "/chat");
        assert_eq!(cfg.max_connections, 42);
        assert_eq!(cfg.max_message_size, 1024);
        assert!(cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
        assert_eq!(cfg.allow_origin, "https://example.com");
        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
    }

    // A parameter given in the URI overrides its default; the others keep their defaults.
    #[test]
    fn endpoint_config_override_chain_uri_overrides_defaults() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
        assert_eq!(cfg.max_connections, 7);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
    }

    #[test]
    fn endpoint_trait_creates_consumer_and_producer() {
        let endpoint = WsComponent
            .create_endpoint("ws://127.0.0.1:9010/trait")
            .unwrap();

        endpoint.create_consumer().unwrap();
        endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();
    }

    // The consumer advertises `maxConnections` as its concurrency limit.
    #[test]
    fn ws_consumer_concurrency_model_uses_max_connections() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
        let consumer = WsConsumer::new(cfg.server_config());
        assert_eq!(
            consumer.concurrency_model(),
            ConcurrencyModel::Concurrent { max: Some(321) }
        );
    }

    // Exercises the registry: insertion, snapshot-based broadcast, key-targeted lookup
    // (the non-targeted receiver must stay silent), and removal.
    #[tokio::test]
    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
        let registry = WsConnectionRegistry::new();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);

        registry.insert("k1".into(), tx1);
        registry.insert("k2".into(), tx2);
        assert_eq!(registry.len(), 2);

        for tx in registry.snapshot_senders() {
            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
        }

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));

        let target = registry.get_senders_for_keys(&["k1".to_string()]);
        assert_eq!(target.len(), 1);
        target[0]
            .send(WsMessage::Text("targeted".into()))
            .await
            .unwrap();

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
        assert!(
            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
                .await
                .is_err()
        );

        registry.remove("k1");
        assert_eq!(registry.len(), 1);
    }

    // 0.0.0.0, localhost and 127.0.0.1 all canonicalize to the same loopback host.
    #[test]
    fn host_canonicalization_maps_local_hosts_to_loopback() {
        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
            .unwrap()
            .canonical_host();
        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
            .unwrap()
            .canonical_host();
        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
            .unwrap()
            .canonical_host();

        assert_eq!(c1, "127.0.0.1");
        assert_eq!(c2, "127.0.0.1");
        assert_eq!(c3, "127.0.0.1");
    }

    // End to end: the client sends a frame, the consumer delivers it to a fake route, and
    // the route replies through the producer, addressed by `CamelWsConnectionKey`.
    #[tokio::test]
    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/echo");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, mut route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let route_task = tokio::spawn(async move {
            if let Some(envelope) = route_rx.recv().await {
                let payload = envelope
                    .exchange
                    .input
                    .body
                    .as_text()
                    .unwrap_or_default()
                    .to_string();
                let key = envelope
                    .exchange
                    .input
                    .header("CamelWsConnectionKey")
                    .and_then(|v| v.as_str())
                    .unwrap()
                    .to_string();

                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
                response
                    .input
                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
                producer.oneshot(response).await.unwrap();
            }
        });

        // The server starts asynchronously; retry until the upgrade succeeds.
        let url = format!("ws://127.0.0.1:{port}/echo");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("hello-ws".into()))
            .await
            .unwrap();

        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed before echo"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(incoming, "hello-ws");

        consumer.stop().await.unwrap();
        route_task.await.unwrap();
    }

    // Stopping the consumer closes open sessions with 1001 (Going Away).
    #[tokio::test]
    async fn consumer_stop_sends_close_1001() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/shutdown");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        // `_route_rx` is kept alive (not `_`) so the route channel stays open.
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/shutdown");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("keepalive".into()))
            .await
            .unwrap();

        consumer.stop().await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(close_code, Some(CloseCode::Away));
    }

    #[test]
    fn wildcard_origin_allows_anything() {
        assert!(is_origin_allowed("*", None));
        assert!(is_origin_allowed("*", Some("https://example.com")));
    }

    #[test]
    fn exact_origin_requires_match() {
        assert!(is_origin_allowed(
            "https://example.com",
            Some("https://example.com")
        ));
        assert!(!is_origin_allowed(
            "https://example.com",
            Some("https://other.com")
        ));
        assert!(!is_origin_allowed("https://example.com", None));
    }

    #[test]
    fn endpoint_config_rejects_invalid_scheme() {
        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid WebSocket scheme"),
            "expected scheme error, got: {msg}"
        );
    }

    // A wss consumer cannot start without a configured certificate.
    #[tokio::test]
    async fn wss_consumer_start_fails_without_tls_cert() {
        let port = free_port();
        let endpoint = WssComponent
            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"))
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert path is required"),
            "expected TLS cert error, got: {msg}"
        );
    }

    // A missing certificate file is reported through `load_tls_config`'s file error.
    #[tokio::test]
    async fn wss_consumer_start_fails_with_nonexistent_cert() {
        let port = free_port();
        let endpoint = WssComponent
            .create_endpoint(&format!(
                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
            ))
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert file error"),
            "expected cert file error, got: {msg}"
        );
    }

    // Repeated lookups for one host/port share a single dispatch table.
    #[tokio::test]
    async fn server_registry_returns_same_state_for_same_port() {
        let port = free_port();
        let state1 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let state2 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        assert!(
            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
            "expected same dispatch table for same port"
        );
    }

    // Drives `dispatch_handler` in-process (no socket) for a path with no registered consumer.
    #[tokio::test]
    async fn dispatch_handler_returns_404_for_unregistered_path() {
        let port = free_port();
        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);
        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/nonexistent")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // Client mode: the producer connects to an external echo server (plain axum here)
    // and returns the server's reply as the exchange body.
    #[tokio::test]
    async fn client_mode_producer_connects_and_echoes() {
        let port = free_port();

        let app = Router::new().route(
            "/echo",
            axum::routing::get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket: WebSocket| async move {
                    while let Some(Ok(msg)) = socket.recv().await {
                        match msg {
                            WsMessage::Text(text) => {
                                let _ = socket.send(WsMessage::Text(text)).await;
                            }
                            WsMessage::Binary(data) => {
                                let _ = socket.send(WsMessage::Binary(data)).await;
                            }
                            WsMessage::Close(_) => break,
                            _ => {}
                        }
                    }
                })
            }),
        );
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
            .await
            .unwrap();
        let server_task = tokio::spawn(async move {
            let _ = serve(listener, app).await;
        });

        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
        let producer = WsProducer::new(cfg.client_config());

        // The listener is bound before the server task is spawned, so connections are
        // already accepted by the OS; a single attempt bounded by a timeout is enough.
        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
        let result =
            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
                Ok(Ok(r)) => r,
                Ok(Err(e)) => panic!("producer call failed: {e:?}"),
                Err(_) => panic!("producer call timed out"),
            };

        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");

        server_task.abort();
    }

    // With maxConnections=1, a second concurrent client is closed with 1013 (Try Again Later).
    #[tokio::test]
    async fn max_connections_rejects_with_close_1013() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/limited");
        let (_client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // Give the server time to register the first connection before the second arrives.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let (mut client2, _) = connect_async(&url).await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(ClientMessage::Text(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
                    None => panic!("client2 closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1013u16)),
            "expected 1013 (Try Again Later) for max connections"
        );

        consumer.stop().await.unwrap();
    }

    // A frame larger than maxMessageSize closes the session with 1009 (Message Too Big).
    #[tokio::test]
    async fn max_message_size_rejects_with_close_1009() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/sizelimit");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let oversized = "x".repeat(100);
        client
            .send(ClientMessage::Text(oversized.into()))
            .await
            .unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1009u16)),
            "expected 1009 (Message Too Big) for oversized message"
        );

        consumer.stop().await.unwrap();
    }

    // An upgrade request from a non-allowed Origin is refused with HTTP 403. The request is
    // sent in-process against the shared dispatch state the consumer registered into.
    #[tokio::test]
    async fn origin_rejection_returns_403() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);

        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/origintest")
                    .header("origin", "https://evil.com")
                    .header("upgrade", "websocket")
                    .header("connection", "Upgrade")
                    .header("sec-websocket-version", "13")
                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();

        assert_eq!(
            response.status(),
            StatusCode::FORBIDDEN,
            "expected 403 for disallowed origin"
        );

        consumer.stop().await.unwrap();
    }

    // `CamelWsSendToAll: true` on a producer exchange reaches every connected client.
    #[tokio::test]
    async fn broadcast_sends_to_all_connected_clients() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/bc");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/bc");

        let (mut client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let (mut client2, _) = connect_async(&url).await.unwrap();

        // Let both connections register before broadcasting.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut response =
            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
        response
            .input
            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
        producer.oneshot(response).await.unwrap();

        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client1.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client1 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client2 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(recv1, "broadcast-msg");
        assert_eq!(recv2, "broadcast-msg");

        consumer.stop().await.unwrap();
    }

    // Racing `get_or_spawn` calls for the same port must all observe one shared state.
    #[tokio::test]
    async fn concurrent_get_or_spawn_returns_same_state() {
        let port = free_port();
        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let results = results.clone();
            handles.push(tokio::spawn(async move {
                let state = ServerRegistry::global()
                    .get_or_spawn("127.0.0.1", port, None)
                    .await
                    .unwrap();
                results.lock().unwrap().push(state);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let states = results.lock().unwrap();
        assert_eq!(states.len(), 4);
        for i in 1..states.len() {
            assert!(
                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
                "all concurrent callers should get the same dispatch table"
            );
        }
    }
}