Skip to main content

camel_component_ws/
lib.rs

1pub mod config;
2
3pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, OnceLock};
7
8use async_trait::async_trait;
9use axum::body::Body;
10use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
11use axum::extract::{FromRequest, Request, State};
12use axum::http::{StatusCode, header};
13use axum::response::IntoResponse;
14use axum::{Router, serve};
15use camel_component_api::{
16    Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
17};
18use camel_component_api::{
19    Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
20    ProducerContext,
21};
22use dashmap::DashMap;
23use futures::{SinkExt, StreamExt};
24use std::future::Future;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::sync::{OnceCell, RwLock, mpsc};
28use tokio::task::JoinHandle;
29use tokio_tungstenite::tungstenite;
30use tokio_tungstenite::tungstenite::client::IntoClientRequest;
31use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
32use tower::Service;
33
34#[derive(Clone)]
35struct WsPathConfig {
36    max_connections: u32,
37    max_message_size: u32,
38    heartbeat_interval: std::time::Duration,
39    idle_timeout: std::time::Duration,
40    allow_origin: String,
41}
42
43impl Default for WsPathConfig {
44    fn default() -> Self {
45        let cfg = WsEndpointConfig::default();
46        Self {
47            max_connections: cfg.max_connections,
48            max_message_size: cfg.max_message_size,
49            heartbeat_interval: cfg.heartbeat_interval,
50            idle_timeout: cfg.idle_timeout,
51            allow_origin: cfg.allow_origin,
52        }
53    }
54}
55
56#[derive(Clone)]
57struct WsTlsConfig {
58    cert_path: String,
59    key_path: String,
60}
61
62type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
63
64struct ServerHandle {
65    state: WsAppState,
66    is_tls: bool,
67    _task: JoinHandle<()>,
68}
69
70pub struct ServerRegistry {
71    inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
72}
73
74impl ServerRegistry {
75    pub fn global() -> &'static Self {
76        static REG: OnceLock<ServerRegistry> = OnceLock::new();
77        REG.get_or_init(|| Self {
78            inner: Mutex::new(HashMap::new()),
79        })
80    }
81
82    pub(crate) async fn get_or_spawn(
83        &'static self,
84        host: &str,
85        port: u16,
86        tls_config: Option<WsTlsConfig>,
87    ) -> Result<WsAppState, CamelError> {
88        let wants_tls = tls_config.is_some();
89        let host_owned = host.to_string();
90
91        let cell = {
92            let mut guard = self.inner.lock().map_err(|_| {
93                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
94            })?;
95            guard
96                .entry(port)
97                .or_insert_with(|| Arc::new(OnceCell::new()))
98                .clone()
99        };
100
101        let handle = cell
102            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
103            .await?;
104
105        if wants_tls != handle.is_tls {
106            return Err(CamelError::EndpointCreationFailed(format!(
107                "Server on port {port} already running with different TLS mode"
108            )));
109        }
110
111        Ok(handle.state.clone())
112    }
113}
114
115async fn spawn_server(
116    host: &str,
117    port: u16,
118    tls_config: Option<WsTlsConfig>,
119) -> Result<ServerHandle, CamelError> {
120    let addr = format!("{host}:{port}");
121    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
122    let path_configs = Arc::new(DashMap::new());
123    let state = WsAppState {
124        dispatch: Arc::clone(&dispatch),
125        path_configs: Arc::clone(&path_configs),
126    };
127    let app = Router::new()
128        .fallback(dispatch_handler)
129        .with_state(state.clone());
130
131    let (task, is_tls) = if let Some(ref tls) = tls_config {
132        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
133        let parsed_addr = addr.parse().map_err(|e| {
134            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
135        })?;
136        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
137        let task = tokio::spawn(async move {
138            let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
139                .serve(app.into_make_service())
140                .await;
141        });
142        (task, true)
143    } else {
144        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
145            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
146        })?;
147        let task = tokio::spawn(async move {
148            let _ = serve(listener, app).await;
149        });
150        (task, false)
151    };
152
153    Ok(ServerHandle {
154        state,
155        is_tls,
156        _task: task,
157    })
158}
159
160#[derive(Clone)]
161struct WsAppState {
162    dispatch: DispatchTable,
163    path_configs: Arc<DashMap<String, WsPathConfig>>,
164}
165
166pub struct WsConnectionRegistry {
167    connections: DashMap<String, mpsc::Sender<WsMessage>>,
168}
169
170static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
171    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
172> = OnceLock::new();
173
174fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
175    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
176}
177
178impl Default for WsConnectionRegistry {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl WsConnectionRegistry {
185    pub fn new() -> Self {
186        Self {
187            connections: DashMap::new(),
188        }
189    }
190
191    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
192        self.connections.insert(key, tx);
193    }
194
195    pub fn remove(&self, key: &str) {
196        self.connections.remove(key);
197    }
198
199    pub fn len(&self) -> usize {
200        self.connections.len()
201    }
202
203    pub fn is_empty(&self) -> bool {
204        self.connections.is_empty()
205    }
206
207    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
208        self.connections.iter().map(|e| e.value().clone()).collect()
209    }
210
211    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
212        keys.iter()
213            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
214            .collect()
215    }
216}
217
218async fn dispatch_handler(
219    State(state): State<WsAppState>,
220    req: Request<Body>,
221) -> impl IntoResponse {
222    let path = req.uri().path().to_string();
223    let origin = req
224        .headers()
225        .get(header::ORIGIN)
226        .and_then(|value| value.to_str().ok())
227        .map(str::to_string);
228    let remote_addr = req
229        .extensions()
230        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
231        .map(|ci| ci.0.to_string())
232        .unwrap_or_default();
233    let table = state.dispatch.read().await;
234    if !table.contains_key(&path) {
235        return (
236            StatusCode::NOT_FOUND,
237            "no ws endpoint registered for this path",
238        )
239            .into_response();
240    }
241    drop(table);
242
243    let path_config = state
244        .path_configs
245        .get(&path)
246        .map(|entry| entry.value().clone())
247        .unwrap_or_default();
248    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
249        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
250    }
251
252    let upgrade_headers: HashMap<String, String> = req
253        .headers()
254        .iter()
255        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
256        .collect();
257
258    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
259        Ok(ws) => ws,
260        Err(_) => {
261            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
262        }
263    };
264
265    ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
266        .into_response()
267}
268
269#[allow(unused_variables)]
270async fn ws_handler(
271    socket: WebSocket,
272    state: WsAppState,
273    path: String,
274    remote_addr: String,
275    upgrade_headers: HashMap<String, String>,
276) {
277    let connection_key = uuid::Uuid::new_v4().to_string();
278    let path_config = state
279        .path_configs
280        .get(&path)
281        .map(|entry| entry.value().clone())
282        .unwrap_or_default();
283
284    let env_tx = {
285        let table = state.dispatch.read().await;
286        table.get(&path).cloned()
287    };
288    let Some(env_tx) = env_tx else {
289        return;
290    };
291
292    let (mut sink, mut stream) = socket.split();
293    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
294
295    let registry = global_registries();
296    let mut registry_key = None;
297    for entry in registry.iter() {
298        if entry.key().2 == path {
299            entry.value().insert(connection_key.clone(), out_tx.clone());
300            registry_key = Some(entry.key().clone());
301            break;
302        }
303    }
304
305    let writer = tokio::spawn(async move {
306        while let Some(msg) = out_rx.recv().await {
307            let _ = sink.send(msg).await;
308        }
309    });
310
311    let mut over_limit = false;
312    if let Some(key) = &registry_key
313        && let Some(entry) = registry.get(key)
314        && entry.len() > path_config.max_connections as usize
315    {
316        over_limit = true;
317    }
318    if over_limit {
319        try_send_with_backpressure(
320            &out_tx,
321            WsMessage::Close(Some(CloseFrame {
322                code: CloseCode::from(1013u16),
323                reason: "max connections exceeded".into(),
324            })),
325            "max-connections-close",
326        );
327        if let Some(key) = registry_key.clone()
328            && let Some(entry) = registry.get(&key)
329        {
330            entry.remove(&connection_key);
331        }
332        drop(out_tx);
333        let _ = writer.await;
334        return;
335    }
336
337    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
338        let ping_tx = out_tx.clone();
339        let interval = path_config.heartbeat_interval;
340        Some(tokio::spawn(async move {
341            let mut ticker = tokio::time::interval(interval);
342            loop {
343                ticker.tick().await;
344                let _ = try_send_with_backpressure(
345                    &ping_tx,
346                    WsMessage::Ping(Vec::new().into()),
347                    "heartbeat-ping",
348                );
349            }
350        }))
351    } else {
352        None
353    };
354
355    loop {
356        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
357            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
358                Ok(msg) => msg,
359                Err(_) => {
360                    try_send_with_backpressure(
361                        &out_tx,
362                        WsMessage::Close(Some(CloseFrame {
363                            code: CloseCode::from(1000u16),
364                            reason: "idle timeout".into(),
365                        })),
366                        "idle-timeout-close",
367                    );
368                    break;
369                }
370            }
371        } else {
372            stream.next().await
373        };
374
375        let Some(msg) = next_msg else {
376            break;
377        };
378
379        match msg {
380            Ok(WsMessage::Text(text)) => {
381                if text.len() > path_config.max_message_size as usize {
382                    try_send_with_backpressure(
383                        &out_tx,
384                        WsMessage::Close(Some(CloseFrame {
385                            code: CloseCode::from(1009u16),
386                            reason: "message too large".into(),
387                        })),
388                        "max-message-size-close-text",
389                    );
390                    break;
391                }
392
393                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
394                message.set_header(
395                    "CamelWsConnectionKey",
396                    serde_json::Value::String(connection_key.clone()),
397                );
398                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
399                message.set_header(
400                    "CamelWsRemoteAddress",
401                    serde_json::Value::String(remote_addr.clone()),
402                );
403
404                #[allow(unused_mut)]
405                let mut exchange = Exchange::new(message);
406                #[cfg(feature = "otel")]
407                {
408                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
409                }
410                if env_tx
411                    .send(ExchangeEnvelope {
412                        exchange,
413                        reply_tx: None,
414                    })
415                    .await
416                    .is_err()
417                {
418                    break;
419                }
420            }
421            Ok(WsMessage::Binary(data)) => {
422                if data.len() > path_config.max_message_size as usize {
423                    try_send_with_backpressure(
424                        &out_tx,
425                        WsMessage::Close(Some(CloseFrame {
426                            code: CloseCode::from(1009u16),
427                            reason: "message too large".into(),
428                        })),
429                        "max-message-size-close-binary",
430                    );
431                    break;
432                }
433
434                let mut message = CamelMessage::new(CamelBody::Bytes(data));
435                message.set_header(
436                    "CamelWsConnectionKey",
437                    serde_json::Value::String(connection_key.clone()),
438                );
439                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
440                message.set_header(
441                    "CamelWsRemoteAddress",
442                    serde_json::Value::String(remote_addr.clone()),
443                );
444
445                #[allow(unused_mut)]
446                let mut exchange = Exchange::new(message);
447                #[cfg(feature = "otel")]
448                {
449                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
450                }
451                if env_tx
452                    .send(ExchangeEnvelope {
453                        exchange,
454                        reply_tx: None,
455                    })
456                    .await
457                    .is_err()
458                {
459                    break;
460                }
461            }
462            Ok(WsMessage::Close(_)) | Err(_) => break,
463            _ => {}
464        }
465    }
466
467    if let Some(task) = heartbeat_task {
468        task.abort();
469    }
470
471    if let Some(key) = registry_key
472        && let Some(entry) = registry.get(&key)
473    {
474        entry.remove(&connection_key);
475    }
476    drop(out_tx);
477    let _ = writer.await;
478}
479
480pub struct WsComponent;
481
482impl Component for WsComponent {
483    fn scheme(&self) -> &str {
484        "ws"
485    }
486
487    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
488        let cfg = WsEndpointConfig::from_uri(uri)?;
489        Ok(Box::new(WsEndpoint {
490            uri: uri.to_string(),
491            cfg,
492        }))
493    }
494}
495
496pub struct WssComponent;
497
498impl Component for WssComponent {
499    fn scheme(&self) -> &str {
500        "wss"
501    }
502
503    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
504        let cfg = WsEndpointConfig::from_uri(uri)?;
505        Ok(Box::new(WsEndpoint {
506            uri: uri.to_string(),
507            cfg,
508        }))
509    }
510}
511
512struct WsEndpoint {
513    uri: String,
514    cfg: WsEndpointConfig,
515}
516
517impl Endpoint for WsEndpoint {
518    fn uri(&self) -> &str {
519        &self.uri
520    }
521
522    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
523        Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
524    }
525
526    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
527        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
528    }
529}
530
531pub struct WsConsumer {
532    cfg: WsServerConfig,
533    registry: Arc<WsConnectionRegistry>,
534    server_state: Option<WsAppState>,
535    registry_key: Option<(String, u16, String)>,
536    forward_task: Option<JoinHandle<()>>,
537}
538
539impl WsConsumer {
540    pub fn new(cfg: WsServerConfig) -> Self {
541        Self {
542            cfg,
543            registry: Arc::new(WsConnectionRegistry::new()),
544            server_state: None,
545            registry_key: None,
546            forward_task: None,
547        }
548    }
549}
550
551#[async_trait]
552impl Consumer for WsConsumer {
553    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
554        let tls_config = if self.cfg.inner.scheme == "wss" {
555            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
556                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
557            })?;
558            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
559                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
560            })?;
561            Some(WsTlsConfig {
562                cert_path,
563                key_path,
564            })
565        } else {
566            None
567        };
568
569        let state = ServerRegistry::global()
570            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
571            .await?;
572
573        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
574        {
575            let mut table = state.dispatch.write().await;
576            table.insert(self.cfg.inner.path.clone(), env_tx);
577        }
578
579        state.path_configs.insert(
580            self.cfg.inner.path.clone(),
581            WsPathConfig {
582                max_connections: self.cfg.inner.max_connections,
583                max_message_size: self.cfg.inner.max_message_size,
584                heartbeat_interval: self.cfg.inner.heartbeat_interval,
585                idle_timeout: self.cfg.inner.idle_timeout,
586                allow_origin: self.cfg.inner.allow_origin.clone(),
587            },
588        );
589
590        let registry_key = (
591            self.cfg.inner.canonical_host(),
592            self.cfg.inner.port,
593            self.cfg.inner.path.clone(),
594        );
595        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
596
597        let sender = ctx.sender();
598        let forward_task = tokio::spawn(async move {
599            while let Some(envelope) = env_rx.recv().await {
600                if sender.send(envelope).await.is_err() {
601                    break;
602                }
603            }
604        });
605
606        self.server_state = Some(state);
607        self.registry_key = Some(registry_key);
608        self.forward_task = Some(forward_task);
609        Ok(())
610    }
611
612    async fn stop(&mut self) -> Result<(), CamelError> {
613        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
614            code: axum::extract::ws::CloseCode::from(1001u16),
615            reason: "consumer stopping".into(),
616        }));
617        for tx in self.registry.snapshot_senders() {
618            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
619        }
620
621        if let Some(state) = self.server_state.take() {
622            let mut table = state.dispatch.write().await;
623            table.remove(&self.cfg.inner.path);
624            state.path_configs.remove(&self.cfg.inner.path);
625        }
626
627        if let Some(key) = self.registry_key.take() {
628            global_registries().remove(&key);
629        }
630
631        if let Some(task) = self.forward_task.take() {
632            task.abort();
633        }
634
635        Ok(())
636    }
637
638    fn concurrency_model(&self) -> ConcurrencyModel {
639        ConcurrencyModel::Concurrent {
640            max: Some(self.cfg.inner.max_connections as usize),
641        }
642    }
643}
644
645#[derive(Clone)]
646pub struct WsProducer {
647    cfg: WsClientConfig,
648}
649
650impl WsProducer {
651    pub fn new(cfg: WsClientConfig) -> Self {
652        Self { cfg }
653    }
654}
655
656impl Service<Exchange> for WsProducer {
657    type Response = Exchange;
658    type Error = CamelError;
659    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
660
661    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
662        Poll::Ready(Ok(()))
663    }
664
665    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
666        let cfg = self.cfg.clone();
667
668        Box::pin(async move {
669            let canonical_host = cfg.inner.canonical_host();
670            let key = (
671                canonical_host.clone(),
672                cfg.inner.port,
673                cfg.inner.path.clone(),
674            );
675
676            let send_to_all = exchange
677                .input
678                .header("CamelWsSendToAll")
679                .and_then(|v| v.as_bool())
680                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
681                .unwrap_or(false);
682
683            let conn_keys_header = exchange
684                .input
685                .header("CamelWsConnectionKey")
686                .and_then(|v| v.as_str())
687                .map(str::to_string);
688
689            let local_exists = global_registries().contains_key(&key);
690            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
691
692            let message_type = exchange
693                .input
694                .header("CamelWsMessageType")
695                .and_then(|v| v.as_str())
696                .unwrap_or("text")
697                .to_ascii_lowercase();
698
699            if server_send_mode {
700                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
701                let Some(registry) = registry else {
702                    return Err(CamelError::ProcessorError(format!(
703                        "WebSocket local consumer not found for {}:{}{}",
704                        canonical_host, cfg.inner.port, cfg.inner.path
705                    )));
706                };
707
708                let out_msg = body_to_axum_ws_message(
709                    std::mem::take(&mut exchange.input.body),
710                    &message_type,
711                )
712                .await?;
713
714                let targets = if send_to_all {
715                    registry.snapshot_senders()
716                } else if let Some(keys) = conn_keys_header {
717                    let parsed: Vec<String> = keys
718                        .split(',')
719                        .map(str::trim)
720                        .filter(|k| !k.is_empty())
721                        .map(str::to_string)
722                        .collect();
723                    registry.get_senders_for_keys(&parsed)
724                } else {
725                    registry.snapshot_senders()
726                };
727
728                for tx in targets {
729                    let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
730                }
731
732                return Ok(exchange);
733            }
734
735            let url = format!(
736                "{}://{}:{}{}",
737                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
738            );
739
740            #[allow(unused_mut)]
741            let mut request = url
742                .clone()
743                .into_client_request()
744                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
745
746            #[cfg(feature = "otel")]
747            {
748                let mut otel_headers = HashMap::new();
749                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
750                for (k, v) in otel_headers {
751                    if let (Ok(name), Ok(val)) = (
752                        http::header::HeaderName::from_bytes(k.as_bytes()),
753                        http::header::HeaderValue::from_str(&v),
754                    ) {
755                        request.headers_mut().insert(name, val);
756                    }
757                }
758            }
759
760            let connect_future = tokio_tungstenite::connect_async(request);
761            let (mut ws_stream, _) =
762                tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
763                    .await
764                    .map_err(|_| {
765                        CamelError::ProcessorError(format!(
766                            "WebSocket connect timeout ({:?}) to {url}",
767                            cfg.inner.connect_timeout
768                        ))
769                    })?
770                    .map_err(|e| map_connect_error(e, &url))?;
771
772            let out_msg =
773                body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
774                    .await?;
775
776            ws_stream
777                .send(out_msg)
778                .await
779                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
780
781            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
782                loop {
783                    match ws_stream.next().await {
784                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
785                            continue;
786                        }
787                        other => break other,
788                    }
789                }
790            })
791            .await
792            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
793
794            match incoming {
795                Some(Ok(ClientWsMessage::Text(text))) => {
796                    exchange.input.body = CamelBody::Text(text.to_string());
797                }
798                Some(Ok(ClientWsMessage::Binary(data))) => {
799                    exchange.input.body = CamelBody::Bytes(data);
800                }
801                Some(Ok(ClientWsMessage::Close(frame))) => {
802                    let normal = frame
803                        .as_ref()
804                        .map(|f| {
805                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
806                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
807                        })
808                        .unwrap_or(true);
809
810                    if normal {
811                        exchange.input.body = CamelBody::Empty;
812                    } else {
813                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
814                        return Err(CamelError::ProcessorError(format!(
815                            "WebSocket peer closed: code {code}"
816                        )));
817                    }
818                }
819                Some(Ok(_)) | None => {
820                    exchange.input.body = CamelBody::Empty;
821                }
822                Some(Err(e)) => {
823                    return Err(CamelError::ProcessorError(format!(
824                        "WebSocket receive failed: {e}"
825                    )));
826                }
827            }
828
829            let _ = ws_stream.close(None).await;
830            Ok(exchange)
831        })
832    }
833}
834
835async fn body_to_axum_ws_message(
836    body: CamelBody,
837    message_type: &str,
838) -> Result<WsMessage, CamelError> {
839    match message_type {
840        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
841        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
842    }
843}
844
845async fn body_to_client_ws_message(
846    body: CamelBody,
847    message_type: &str,
848) -> Result<ClientWsMessage, CamelError> {
849    match message_type {
850        "binary" => Ok(ClientWsMessage::Binary(
851            body.into_bytes(10 * 1024 * 1024).await?,
852        )),
853        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
854    }
855}
856
857async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
858    Ok(match body {
859        CamelBody::Empty => String::new(),
860        CamelBody::Text(s) => s,
861        CamelBody::Xml(s) => s,
862        CamelBody::Json(v) => v.to_string(),
863        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
864        CamelBody::Stream(stream) => {
865            let bytes = CamelBody::Stream(stream)
866                .into_bytes(10 * 1024 * 1024)
867                .await?;
868            String::from_utf8_lossy(&bytes).to_string()
869        }
870    })
871}
872
/// Decides whether a handshake's `Origin` header satisfies the path policy.
///
/// `allowed_origin` is a comma-separated list of entries, with surrounding
/// whitespace ignored; e.g. `"https://a.example, https://b.example"`. An entry
/// of `"*"` accepts every request, including one without an `Origin` header.
/// Any other entry must equal the request origin exactly. A single-origin or
/// `"*"` policy behaves as before.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    allowed_origin
        .split(',')
        .map(str::trim)
        .any(|allowed| allowed == "*" || request_origin == Some(allowed))
}
879
880fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
881    match tx.try_send(msg) {
882        Ok(()) => true,
883        Err(error) => {
884            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
885            false
886        }
887    }
888}
889
890fn load_tls_config(
891    cert_path: &str,
892    key_path: &str,
893) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
894    use std::fs::File;
895    use std::io::BufReader;
896
897    let cert_file = File::open(cert_path)
898        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
899    let key_file = File::open(key_path)
900        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
901
902    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
903        .collect::<Result<Vec<_>, _>>()
904        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
905
906    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
907        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
908        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
909
910    tokio_rustls::rustls::ServerConfig::builder()
911        .with_no_client_auth()
912        .with_single_cert(certs, key)
913        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
914}
915
916fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
917    match err {
918        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
919            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
920        }
921        tungstenite::Error::Tls(_) => {
922            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
923        }
924        other => {
925            let msg = other.to_string();
926            if msg.to_lowercase().contains("connection refused") {
927                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
928            } else if msg.to_lowercase().contains("tls") {
929                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
930            } else {
931                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
932            }
933        }
934    }
935}
936
// Tests for the WebSocket component: unit tests for config parsing, the connection
// registry and small helpers, plus integration tests that run real servers and clients
// over 127.0.0.1 on ports obtained from `free_port`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::Message as ClientMessage;
    // This explicit import takes precedence over the axum `CloseCode` pulled in by the
    // `super::*` glob, so `CloseCode` in these tests is tungstenite's client-side enum.
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    // Returns a port the OS reported as free by binding port 0. The listener is dropped
    // before returning, so another process could claim the port before the test binds it.
    // That small race is accepted for test simplicity.
    fn free_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port()
    }

    #[test]
    fn ws_component_scheme_is_ws() {
        assert_eq!(WsComponent.scheme(), "ws");
    }

    #[test]
    fn wss_component_scheme_is_wss() {
        assert_eq!(WssComponent.scheme(), "wss");
    }

    // Pins every default of `WsEndpointConfig` so an accidental change to a default fails here.
    #[test]
    fn endpoint_config_defaults_match_spec() {
        let cfg = WsEndpointConfig::default();
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "0.0.0.0");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.path, "/");
        assert_eq!(cfg.max_connections, 100);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
        assert_eq!(cfg.idle_timeout, Duration::ZERO);
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
        assert_eq!(cfg.allow_origin, "*");
        assert_eq!(cfg.tls_cert, None);
        assert_eq!(cfg.tls_key, None);
    }

    // Exercises every query parameter in one URI; the `*Ms` parameters are expected to be
    // interpreted as milliseconds.
    #[test]
    fn endpoint_config_parses_uri_params() {
        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();

        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 9001);
        assert_eq!(cfg.path, "/chat");
        assert_eq!(cfg.max_connections, 42);
        assert_eq!(cfg.max_message_size, 1024);
        assert!(cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
        assert_eq!(cfg.allow_origin, "https://example.com");
        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
    }

    // A parameter given in the URI overrides its default while unspecified ones keep theirs.
    #[test]
    fn endpoint_config_override_chain_uri_overrides_defaults() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
        assert_eq!(cfg.max_connections, 7);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
    }

    // Smoke test: building a consumer and a producer from an endpoint must succeed.
    // Nothing is started, so no socket is bound.
    #[test]
    fn endpoint_trait_creates_consumer_and_producer() {
        let endpoint = WsComponent
            .create_endpoint("ws://127.0.0.1:9010/trait")
            .unwrap();

        endpoint.create_consumer().unwrap();
        endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();
    }

    // The consumer's advertised concurrency cap must mirror `maxConnections`.
    #[test]
    fn ws_consumer_concurrency_model_uses_max_connections() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
        let consumer = WsConsumer::new(cfg.server_config());
        assert_eq!(
            consumer.concurrency_model(),
            ConcurrencyModel::Concurrent { max: Some(321) }
        );
    }

    // Registry semantics: insert/len, broadcast via `snapshot_senders`, targeted delivery via
    // `get_senders_for_keys` (the non-target must receive nothing), and `remove`.
    #[tokio::test]
    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
        let registry = WsConnectionRegistry::new();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);

        registry.insert("k1".into(), tx1);
        registry.insert("k2".into(), tx2);
        assert_eq!(registry.len(), 2);

        for tx in registry.snapshot_senders() {
            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
        }

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));

        let target = registry.get_senders_for_keys(&["k1".to_string()]);
        assert_eq!(target.len(), 1);
        target[0]
            .send(WsMessage::Text("targeted".into()))
            .await
            .unwrap();

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
        // Timing out (Err) is the success case: k2 must not have received the targeted message.
        assert!(
            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
                .await
                .is_err()
        );

        registry.remove("k1");
        assert_eq!(registry.len(), 1);
    }

    // `0.0.0.0`, `localhost` and `127.0.0.1` must all canonicalize to the same host string.
    #[test]
    fn host_canonicalization_maps_local_hosts_to_loopback() {
        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
            .unwrap()
            .canonical_host();
        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
            .unwrap()
            .canonical_host();
        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
            .unwrap()
            .canonical_host();

        assert_eq!(c1, "127.0.0.1");
        assert_eq!(c2, "127.0.0.1");
        assert_eq!(c3, "127.0.0.1");
    }

    // End-to-end: client -> consumer -> route -> producer -> same client.
    #[tokio::test]
    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/echo");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, mut route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // Stand-in for a route. It takes the first inbound exchange and copies its text body
        // and `CamelWsConnectionKey` header into a fresh exchange. It then sends that through
        // the producer, addressed back to the originating connection.
        let route_task = tokio::spawn(async move {
            if let Some(envelope) = route_rx.recv().await {
                let payload = envelope
                    .exchange
                    .input
                    .body
                    .as_text()
                    .unwrap_or_default()
                    .to_string();
                let key = envelope
                    .exchange
                    .input
                    .header("CamelWsConnectionKey")
                    .and_then(|v| v.as_str())
                    .unwrap()
                    .to_string();

                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
                response
                    .input
                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
                producer.oneshot(response).await.unwrap();
            }
        });

        // Retry until the server accepts the connection. There is no retry cap, so a server
        // that never comes up makes the test hang rather than fail.
        let url = format!("ws://127.0.0.1:{port}/echo");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("hello-ws".into()))
            .await
            .unwrap();

        // Skip control frames; the first text frame must be the echoed payload.
        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed before echo"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(incoming, "hello-ws");

        consumer.stop().await.unwrap();
        route_task.await.unwrap();
    }

    // Stopping the consumer must close open connections with code 1001 (`CloseCode::Away`).
    #[tokio::test]
    async fn consumer_stop_sends_close_1001() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/shutdown");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/shutdown");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // Presumably gives the server time to finish setting up the connection before stop.
        client
            .send(ClientMessage::Text("keepalive".into()))
            .await
            .unwrap();

        consumer.stop().await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(close_code, Some(CloseCode::Away));
    }

    #[test]
    fn wildcard_origin_allows_anything() {
        assert!(is_origin_allowed("*", None));
        assert!(is_origin_allowed("*", Some("https://example.com")));
    }

    // A non-wildcard origin must match, and a missing Origin header is rejected.
    #[test]
    fn exact_origin_requires_match() {
        assert!(is_origin_allowed(
            "https://example.com",
            Some("https://example.com")
        ));
        assert!(!is_origin_allowed(
            "https://example.com",
            Some("https://other.com")
        ));
        assert!(!is_origin_allowed("https://example.com", None));
    }

    #[test]
    fn endpoint_config_rejects_invalid_scheme() {
        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid WebSocket scheme"),
            "expected scheme error, got: {msg}"
        );
    }

    // A `wss` consumer without `tlsCert` must fail at `start`, not at endpoint creation.
    #[tokio::test]
    async fn wss_consumer_start_fails_without_tls_cert() {
        let port = free_port();
        let endpoint = WssComponent
            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"))
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert path is required"),
            "expected TLS cert error, got: {msg}"
        );
    }

    // Unreadable certificate paths must surface as a cert file error from `start`.
    #[tokio::test]
    async fn wss_consumer_start_fails_with_nonexistent_cert() {
        let port = free_port();
        let endpoint = WssComponent
            .create_endpoint(&format!(
                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
            ))
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert file error"),
            "expected cert file error, got: {msg}"
        );
    }

    // Two sequential lookups for the same host/port must share one dispatch table
    // (pointer equality on the `Arc`).
    #[tokio::test]
    async fn server_registry_returns_same_state_for_same_port() {
        let port = free_port();
        let state1 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let state2 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        assert!(
            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
            "expected same dispatch table for same port"
        );
    }

    // Drives `dispatch_handler` in-process with `oneshot` (no socket). A path with no
    // registered consumer must yield 404.
    #[tokio::test]
    async fn dispatch_handler_returns_404_for_unregistered_path() {
        let port = free_port();
        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);
        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/nonexistent")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // Client mode: the producer connects to an external server (a standalone axum echo
    // server defined here) and the echoed reply becomes the exchange result.
    #[tokio::test]
    async fn client_mode_producer_connects_and_echoes() {
        let port = free_port();

        let app = Router::new().route(
            "/echo",
            axum::routing::get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket: WebSocket| async move {
                    while let Some(Ok(msg)) = socket.recv().await {
                        match msg {
                            WsMessage::Text(text) => {
                                let _ = socket.send(WsMessage::Text(text)).await;
                            }
                            WsMessage::Binary(data) => {
                                let _ = socket.send(WsMessage::Binary(data)).await;
                            }
                            WsMessage::Close(_) => break,
                            _ => {}
                        }
                    }
                })
            }),
        );
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
            .await
            .unwrap();
        let server_task = tokio::spawn(async move {
            let _ = serve(listener, app).await;
        });

        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
        let producer = WsProducer::new(cfg.client_config());

        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
        // The listener is already bound above, so this sleep only gives the spawned
        // `serve` task a head start.
        tokio::time::sleep(Duration::from_millis(25)).await;
        let result =
            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
                Ok(Ok(r)) => r,
                Ok(Err(_)) => panic!("producer call failed"),
                Err(_) => panic!("producer call timed out"),
            };

        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");

        server_task.abort();
    }

    // With `maxConnections=1`, a second client still completes the HTTP upgrade but must
    // then be closed with 1013 (Try Again Later).
    #[tokio::test]
    async fn max_connections_rejects_with_close_1013() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // `_client1` is bound (not `_`) so it stays alive and occupies the only slot.
        let url = format!("ws://127.0.0.1:{port}/limited");
        let (_client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // Presumably lets the server register client1 before client2 arrives.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let (mut client2, _) = connect_async(&url).await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(ClientMessage::Text(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
                    None => panic!("client2 closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1013u16)),
            "expected 1013 (Try Again Later) for max connections"
        );

        consumer.stop().await.unwrap();
    }

    // A 100-byte text frame against `maxMessageSize=10` must be answered with close 1009
    // (Message Too Big).
    #[tokio::test]
    async fn max_message_size_rejects_with_close_1009() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/sizelimit");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let oversized = "x".repeat(100);
        client
            .send(ClientMessage::Text(oversized.into()))
            .await
            .unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1009u16)),
            "expected 1009 (Message Too Big) for oversized message"
        );

        consumer.stop().await.unwrap();
    }

    // Starting the consumer registers `/origintest` with `allowOrigin=https://allowed.com`.
    // The test then fetches the shared state for the same host/port and sends a hand-built
    // upgrade request from a disallowed origin through `dispatch_handler` in-process,
    // expecting 403. The `sec-websocket-key` is the sample nonce from RFC 6455.
    #[tokio::test]
    async fn origin_rejection_returns_403() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);

        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/origintest")
                    .header("origin", "https://evil.com")
                    .header("upgrade", "websocket")
                    .header("connection", "Upgrade")
                    .header("sec-websocket-version", "13")
                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();

        assert_eq!(
            response.status(),
            StatusCode::FORBIDDEN,
            "expected 403 for disallowed origin"
        );

        consumer.stop().await.unwrap();
    }

    // An exchange sent through the producer with `CamelWsSendToAll=true` (and no connection
    // key) must reach every connected client.
    #[tokio::test]
    async fn broadcast_sends_to_all_connected_clients() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/bc");
        let endpoint = WsComponent.create_endpoint(&uri).unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/bc");

        let (mut client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let (mut client2, _) = connect_async(&url).await.unwrap();

        // Presumably lets the server register both connections before the broadcast.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut response =
            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
        response
            .input
            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
        producer.oneshot(response).await.unwrap();

        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client1.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client1 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client2 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(recv1, "broadcast-msg");
        assert_eq!(recv2, "broadcast-msg");

        consumer.stop().await.unwrap();
    }

    // Four tasks race `get_or_spawn` for the same port. All of them must end up with the
    // same dispatch table, i.e. the registry must not create duplicate servers under
    // contention.
    #[tokio::test]
    async fn concurrent_get_or_spawn_returns_same_state() {
        let port = free_port();
        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let results = results.clone();
            handles.push(tokio::spawn(async move {
                let state = ServerRegistry::global()
                    .get_or_spawn("127.0.0.1", port, None)
                    .await
                    .unwrap();
                results.lock().unwrap().push(state);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let states = results.lock().unwrap();
        assert_eq!(states.len(), 4);
        for i in 1..states.len() {
            assert!(
                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
                "all concurrent callers should get the same dispatch table"
            );
        }
    }
}