1pub mod config;
2
3pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, OnceLock};
7
8use async_trait::async_trait;
9use axum::body::Body;
10use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
11use axum::extract::{FromRequest, Request, State};
12use axum::http::{StatusCode, header};
13use axum::response::IntoResponse;
14use axum::{Router, serve};
15use camel_api::{Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage};
16use camel_component::{
17 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
18 ProducerContext,
19};
20use dashmap::DashMap;
21use futures::{SinkExt, StreamExt};
22use std::future::Future;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25use tokio::sync::{OnceCell, RwLock, mpsc};
26use tokio::task::JoinHandle;
27use tokio_tungstenite::tungstenite;
28use tokio_tungstenite::tungstenite::client::IntoClientRequest;
29use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
30use tower::Service;
31
32#[derive(Clone)]
33struct WsPathConfig {
34 max_connections: u32,
35 max_message_size: u32,
36 heartbeat_interval: std::time::Duration,
37 idle_timeout: std::time::Duration,
38 allow_origin: String,
39}
40
41impl Default for WsPathConfig {
42 fn default() -> Self {
43 let cfg = WsEndpointConfig::default();
44 Self {
45 max_connections: cfg.max_connections,
46 max_message_size: cfg.max_message_size,
47 heartbeat_interval: cfg.heartbeat_interval,
48 idle_timeout: cfg.idle_timeout,
49 allow_origin: cfg.allow_origin,
50 }
51 }
52}
53
/// Locations of the PEM certificate chain and private key that a `wss`
/// consumer hands to the listener it spawns.
#[derive(Clone)]
struct WsTlsConfig {
    /// Path of the PEM file holding the certificate chain.
    cert_path: String,
    /// Path of the PEM file holding the private key.
    key_path: String,
}
59
60type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
61
62struct ServerHandle {
63 state: WsAppState,
64 is_tls: bool,
65 _task: JoinHandle<()>,
66}
67
68pub struct ServerRegistry {
69 inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
70}
71
72impl ServerRegistry {
73 pub fn global() -> &'static Self {
74 static REG: OnceLock<ServerRegistry> = OnceLock::new();
75 REG.get_or_init(|| Self {
76 inner: Mutex::new(HashMap::new()),
77 })
78 }
79
80 pub(crate) async fn get_or_spawn(
81 &'static self,
82 host: &str,
83 port: u16,
84 tls_config: Option<WsTlsConfig>,
85 ) -> Result<WsAppState, CamelError> {
86 let wants_tls = tls_config.is_some();
87 let host_owned = host.to_string();
88
89 let cell = {
90 let mut guard = self.inner.lock().map_err(|_| {
91 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
92 })?;
93 guard
94 .entry(port)
95 .or_insert_with(|| Arc::new(OnceCell::new()))
96 .clone()
97 };
98
99 let handle = cell
100 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
101 .await?;
102
103 if wants_tls != handle.is_tls {
104 return Err(CamelError::EndpointCreationFailed(format!(
105 "Server on port {port} already running with different TLS mode"
106 )));
107 }
108
109 Ok(handle.state.clone())
110 }
111}
112
113async fn spawn_server(
114 host: &str,
115 port: u16,
116 tls_config: Option<WsTlsConfig>,
117) -> Result<ServerHandle, CamelError> {
118 let addr = format!("{host}:{port}");
119 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
120 let path_configs = Arc::new(DashMap::new());
121 let state = WsAppState {
122 dispatch: Arc::clone(&dispatch),
123 path_configs: Arc::clone(&path_configs),
124 };
125 let app = Router::new()
126 .fallback(dispatch_handler)
127 .with_state(state.clone());
128
129 let (task, is_tls) = if let Some(ref tls) = tls_config {
130 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
131 let parsed_addr = addr.parse().map_err(|e| {
132 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
133 })?;
134 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
135 let task = tokio::spawn(async move {
136 let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
137 .serve(app.into_make_service())
138 .await;
139 });
140 (task, true)
141 } else {
142 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
143 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
144 })?;
145 let task = tokio::spawn(async move {
146 let _ = serve(listener, app).await;
147 });
148 (task, false)
149 };
150
151 Ok(ServerHandle {
152 state,
153 is_tls,
154 _task: task,
155 })
156}
157
158#[derive(Clone)]
159struct WsAppState {
160 dispatch: DispatchTable,
161 path_configs: Arc<DashMap<String, WsPathConfig>>,
162}
163
164pub struct WsConnectionRegistry {
165 connections: DashMap<String, mpsc::Sender<WsMessage>>,
166}
167
168static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
169 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
170> = OnceLock::new();
171
172fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
173 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
174}
175
176impl Default for WsConnectionRegistry {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl WsConnectionRegistry {
183 pub fn new() -> Self {
184 Self {
185 connections: DashMap::new(),
186 }
187 }
188
189 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
190 self.connections.insert(key, tx);
191 }
192
193 pub fn remove(&self, key: &str) {
194 self.connections.remove(key);
195 }
196
197 pub fn len(&self) -> usize {
198 self.connections.len()
199 }
200
201 pub fn is_empty(&self) -> bool {
202 self.connections.is_empty()
203 }
204
205 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
206 self.connections.iter().map(|e| e.value().clone()).collect()
207 }
208
209 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
210 keys.iter()
211 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
212 .collect()
213 }
214}
215
216async fn dispatch_handler(
217 State(state): State<WsAppState>,
218 req: Request<Body>,
219) -> impl IntoResponse {
220 let path = req.uri().path().to_string();
221 let origin = req
222 .headers()
223 .get(header::ORIGIN)
224 .and_then(|value| value.to_str().ok())
225 .map(str::to_string);
226 let remote_addr = req
227 .extensions()
228 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
229 .map(|ci| ci.0.to_string())
230 .unwrap_or_default();
231 let table = state.dispatch.read().await;
232 if !table.contains_key(&path) {
233 return (
234 StatusCode::NOT_FOUND,
235 "no ws endpoint registered for this path",
236 )
237 .into_response();
238 }
239 drop(table);
240
241 let path_config = state
242 .path_configs
243 .get(&path)
244 .map(|entry| entry.value().clone())
245 .unwrap_or_default();
246 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
247 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
248 }
249
250 let upgrade_headers: HashMap<String, String> = req
251 .headers()
252 .iter()
253 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
254 .collect();
255
256 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
257 Ok(ws) => ws,
258 Err(_) => {
259 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
260 }
261 };
262
263 ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
264 .into_response()
265}
266
267#[allow(unused_variables)]
268async fn ws_handler(
269 socket: WebSocket,
270 state: WsAppState,
271 path: String,
272 remote_addr: String,
273 upgrade_headers: HashMap<String, String>,
274) {
275 let connection_key = uuid::Uuid::new_v4().to_string();
276 let path_config = state
277 .path_configs
278 .get(&path)
279 .map(|entry| entry.value().clone())
280 .unwrap_or_default();
281
282 let env_tx = {
283 let table = state.dispatch.read().await;
284 table.get(&path).cloned()
285 };
286 let Some(env_tx) = env_tx else {
287 return;
288 };
289
290 let (mut sink, mut stream) = socket.split();
291 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
292
293 let registry = global_registries();
294 let mut registry_key = None;
295 for entry in registry.iter() {
296 if entry.key().2 == path {
297 entry.value().insert(connection_key.clone(), out_tx.clone());
298 registry_key = Some(entry.key().clone());
299 break;
300 }
301 }
302
303 let writer = tokio::spawn(async move {
304 while let Some(msg) = out_rx.recv().await {
305 let _ = sink.send(msg).await;
306 }
307 });
308
309 let mut over_limit = false;
310 if let Some(key) = ®istry_key
311 && let Some(entry) = registry.get(key)
312 && entry.len() > path_config.max_connections as usize
313 {
314 over_limit = true;
315 }
316 if over_limit {
317 try_send_with_backpressure(
318 &out_tx,
319 WsMessage::Close(Some(CloseFrame {
320 code: CloseCode::from(1013u16),
321 reason: "max connections exceeded".into(),
322 })),
323 "max-connections-close",
324 );
325 if let Some(key) = registry_key.clone()
326 && let Some(entry) = registry.get(&key)
327 {
328 entry.remove(&connection_key);
329 }
330 drop(out_tx);
331 let _ = writer.await;
332 return;
333 }
334
335 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
336 let ping_tx = out_tx.clone();
337 let interval = path_config.heartbeat_interval;
338 Some(tokio::spawn(async move {
339 let mut ticker = tokio::time::interval(interval);
340 loop {
341 ticker.tick().await;
342 let _ = try_send_with_backpressure(
343 &ping_tx,
344 WsMessage::Ping(Vec::new().into()),
345 "heartbeat-ping",
346 );
347 }
348 }))
349 } else {
350 None
351 };
352
353 loop {
354 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
355 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
356 Ok(msg) => msg,
357 Err(_) => {
358 try_send_with_backpressure(
359 &out_tx,
360 WsMessage::Close(Some(CloseFrame {
361 code: CloseCode::from(1000u16),
362 reason: "idle timeout".into(),
363 })),
364 "idle-timeout-close",
365 );
366 break;
367 }
368 }
369 } else {
370 stream.next().await
371 };
372
373 let Some(msg) = next_msg else {
374 break;
375 };
376
377 match msg {
378 Ok(WsMessage::Text(text)) => {
379 if text.len() > path_config.max_message_size as usize {
380 try_send_with_backpressure(
381 &out_tx,
382 WsMessage::Close(Some(CloseFrame {
383 code: CloseCode::from(1009u16),
384 reason: "message too large".into(),
385 })),
386 "max-message-size-close-text",
387 );
388 break;
389 }
390
391 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
392 message.set_header(
393 "CamelWsConnectionKey",
394 serde_json::Value::String(connection_key.clone()),
395 );
396 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
397 message.set_header(
398 "CamelWsRemoteAddress",
399 serde_json::Value::String(remote_addr.clone()),
400 );
401
402 #[allow(unused_mut)]
403 let mut exchange = Exchange::new(message);
404 #[cfg(feature = "otel")]
405 {
406 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
407 }
408 if env_tx
409 .send(ExchangeEnvelope {
410 exchange,
411 reply_tx: None,
412 })
413 .await
414 .is_err()
415 {
416 break;
417 }
418 }
419 Ok(WsMessage::Binary(data)) => {
420 if data.len() > path_config.max_message_size as usize {
421 try_send_with_backpressure(
422 &out_tx,
423 WsMessage::Close(Some(CloseFrame {
424 code: CloseCode::from(1009u16),
425 reason: "message too large".into(),
426 })),
427 "max-message-size-close-binary",
428 );
429 break;
430 }
431
432 let mut message = CamelMessage::new(CamelBody::Bytes(data));
433 message.set_header(
434 "CamelWsConnectionKey",
435 serde_json::Value::String(connection_key.clone()),
436 );
437 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
438 message.set_header(
439 "CamelWsRemoteAddress",
440 serde_json::Value::String(remote_addr.clone()),
441 );
442
443 #[allow(unused_mut)]
444 let mut exchange = Exchange::new(message);
445 #[cfg(feature = "otel")]
446 {
447 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
448 }
449 if env_tx
450 .send(ExchangeEnvelope {
451 exchange,
452 reply_tx: None,
453 })
454 .await
455 .is_err()
456 {
457 break;
458 }
459 }
460 Ok(WsMessage::Close(_)) | Err(_) => break,
461 _ => {}
462 }
463 }
464
465 if let Some(task) = heartbeat_task {
466 task.abort();
467 }
468
469 if let Some(key) = registry_key
470 && let Some(entry) = registry.get(&key)
471 {
472 entry.remove(&connection_key);
473 }
474 drop(out_tx);
475 let _ = writer.await;
476}
477
478pub struct WsComponent;
479
480impl Component for WsComponent {
481 fn scheme(&self) -> &str {
482 "ws"
483 }
484
485 fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
486 let cfg = WsEndpointConfig::from_uri(uri)?;
487 Ok(Box::new(WsEndpoint {
488 uri: uri.to_string(),
489 cfg,
490 }))
491 }
492}
493
494pub struct WssComponent;
495
496impl Component for WssComponent {
497 fn scheme(&self) -> &str {
498 "wss"
499 }
500
501 fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
502 let cfg = WsEndpointConfig::from_uri(uri)?;
503 Ok(Box::new(WsEndpoint {
504 uri: uri.to_string(),
505 cfg,
506 }))
507 }
508}
509
510struct WsEndpoint {
511 uri: String,
512 cfg: WsEndpointConfig,
513}
514
515impl Endpoint for WsEndpoint {
516 fn uri(&self) -> &str {
517 &self.uri
518 }
519
520 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
521 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
522 }
523
524 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
525 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
526 }
527}
528
529pub struct WsConsumer {
530 cfg: WsServerConfig,
531 registry: Arc<WsConnectionRegistry>,
532 server_state: Option<WsAppState>,
533 registry_key: Option<(String, u16, String)>,
534 forward_task: Option<JoinHandle<()>>,
535}
536
537impl WsConsumer {
538 pub fn new(cfg: WsServerConfig) -> Self {
539 Self {
540 cfg,
541 registry: Arc::new(WsConnectionRegistry::new()),
542 server_state: None,
543 registry_key: None,
544 forward_task: None,
545 }
546 }
547}
548
549#[async_trait]
550impl Consumer for WsConsumer {
551 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
552 let tls_config = if self.cfg.inner.scheme == "wss" {
553 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
554 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
555 })?;
556 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
557 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
558 })?;
559 Some(WsTlsConfig {
560 cert_path,
561 key_path,
562 })
563 } else {
564 None
565 };
566
567 let state = ServerRegistry::global()
568 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
569 .await?;
570
571 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
572 {
573 let mut table = state.dispatch.write().await;
574 table.insert(self.cfg.inner.path.clone(), env_tx);
575 }
576
577 state.path_configs.insert(
578 self.cfg.inner.path.clone(),
579 WsPathConfig {
580 max_connections: self.cfg.inner.max_connections,
581 max_message_size: self.cfg.inner.max_message_size,
582 heartbeat_interval: self.cfg.inner.heartbeat_interval,
583 idle_timeout: self.cfg.inner.idle_timeout,
584 allow_origin: self.cfg.inner.allow_origin.clone(),
585 },
586 );
587
588 let registry_key = (
589 self.cfg.inner.canonical_host(),
590 self.cfg.inner.port,
591 self.cfg.inner.path.clone(),
592 );
593 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
594
595 let sender = ctx.sender();
596 let forward_task = tokio::spawn(async move {
597 while let Some(envelope) = env_rx.recv().await {
598 if sender.send(envelope).await.is_err() {
599 break;
600 }
601 }
602 });
603
604 self.server_state = Some(state);
605 self.registry_key = Some(registry_key);
606 self.forward_task = Some(forward_task);
607 Ok(())
608 }
609
610 async fn stop(&mut self) -> Result<(), CamelError> {
611 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
612 code: axum::extract::ws::CloseCode::from(1001u16),
613 reason: "consumer stopping".into(),
614 }));
615 for tx in self.registry.snapshot_senders() {
616 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
617 }
618
619 if let Some(state) = self.server_state.take() {
620 let mut table = state.dispatch.write().await;
621 table.remove(&self.cfg.inner.path);
622 state.path_configs.remove(&self.cfg.inner.path);
623 }
624
625 if let Some(key) = self.registry_key.take() {
626 global_registries().remove(&key);
627 }
628
629 if let Some(task) = self.forward_task.take() {
630 task.abort();
631 }
632
633 Ok(())
634 }
635
636 fn concurrency_model(&self) -> ConcurrencyModel {
637 ConcurrencyModel::Concurrent {
638 max: Some(self.cfg.inner.max_connections as usize),
639 }
640 }
641}
642
643#[derive(Clone)]
644pub struct WsProducer {
645 cfg: WsClientConfig,
646}
647
648impl WsProducer {
649 pub fn new(cfg: WsClientConfig) -> Self {
650 Self { cfg }
651 }
652}
653
654impl Service<Exchange> for WsProducer {
655 type Response = Exchange;
656 type Error = CamelError;
657 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
658
659 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
660 Poll::Ready(Ok(()))
661 }
662
663 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
664 let cfg = self.cfg.clone();
665
666 Box::pin(async move {
667 let canonical_host = cfg.inner.canonical_host();
668 let key = (
669 canonical_host.clone(),
670 cfg.inner.port,
671 cfg.inner.path.clone(),
672 );
673
674 let send_to_all = exchange
675 .input
676 .header("CamelWsSendToAll")
677 .and_then(|v| v.as_bool())
678 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
679 .unwrap_or(false);
680
681 let conn_keys_header = exchange
682 .input
683 .header("CamelWsConnectionKey")
684 .and_then(|v| v.as_str())
685 .map(str::to_string);
686
687 let local_exists = global_registries().contains_key(&key);
688 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
689
690 let message_type = exchange
691 .input
692 .header("CamelWsMessageType")
693 .and_then(|v| v.as_str())
694 .unwrap_or("text")
695 .to_ascii_lowercase();
696
697 if server_send_mode {
698 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
699 let Some(registry) = registry else {
700 return Err(CamelError::ProcessorError(format!(
701 "WebSocket local consumer not found for {}:{}{}",
702 canonical_host, cfg.inner.port, cfg.inner.path
703 )));
704 };
705
706 let out_msg = body_to_axum_ws_message(
707 std::mem::take(&mut exchange.input.body),
708 &message_type,
709 )
710 .await?;
711
712 let targets = if send_to_all {
713 registry.snapshot_senders()
714 } else if let Some(keys) = conn_keys_header {
715 let parsed: Vec<String> = keys
716 .split(',')
717 .map(str::trim)
718 .filter(|k| !k.is_empty())
719 .map(str::to_string)
720 .collect();
721 registry.get_senders_for_keys(&parsed)
722 } else {
723 registry.snapshot_senders()
724 };
725
726 for tx in targets {
727 let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
728 }
729
730 return Ok(exchange);
731 }
732
733 let url = format!(
734 "{}://{}:{}{}",
735 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
736 );
737
738 #[allow(unused_mut)]
739 let mut request = url
740 .clone()
741 .into_client_request()
742 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
743
744 #[cfg(feature = "otel")]
745 {
746 let mut otel_headers = HashMap::new();
747 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
748 for (k, v) in otel_headers {
749 if let (Ok(name), Ok(val)) = (
750 http::header::HeaderName::from_bytes(k.as_bytes()),
751 http::header::HeaderValue::from_str(&v),
752 ) {
753 request.headers_mut().insert(name, val);
754 }
755 }
756 }
757
758 let connect_future = tokio_tungstenite::connect_async(request);
759 let (mut ws_stream, _) =
760 tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
761 .await
762 .map_err(|_| {
763 CamelError::ProcessorError(format!(
764 "WebSocket connect timeout ({:?}) to {url}",
765 cfg.inner.connect_timeout
766 ))
767 })?
768 .map_err(|e| map_connect_error(e, &url))?;
769
770 let out_msg =
771 body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
772 .await?;
773
774 ws_stream
775 .send(out_msg)
776 .await
777 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
778
779 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
780 loop {
781 match ws_stream.next().await {
782 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
783 continue;
784 }
785 other => break other,
786 }
787 }
788 })
789 .await
790 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
791
792 match incoming {
793 Some(Ok(ClientWsMessage::Text(text))) => {
794 exchange.input.body = CamelBody::Text(text.to_string());
795 }
796 Some(Ok(ClientWsMessage::Binary(data))) => {
797 exchange.input.body = CamelBody::Bytes(data);
798 }
799 Some(Ok(ClientWsMessage::Close(frame))) => {
800 let normal = frame
801 .as_ref()
802 .map(|f| {
803 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
804 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
805 })
806 .unwrap_or(true);
807
808 if normal {
809 exchange.input.body = CamelBody::Empty;
810 } else {
811 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
812 return Err(CamelError::ProcessorError(format!(
813 "WebSocket peer closed: code {code}"
814 )));
815 }
816 }
817 Some(Ok(_)) | None => {
818 exchange.input.body = CamelBody::Empty;
819 }
820 Some(Err(e)) => {
821 return Err(CamelError::ProcessorError(format!(
822 "WebSocket receive failed: {e}"
823 )));
824 }
825 }
826
827 let _ = ws_stream.close(None).await;
828 Ok(exchange)
829 })
830 }
831}
832
833async fn body_to_axum_ws_message(
834 body: CamelBody,
835 message_type: &str,
836) -> Result<WsMessage, CamelError> {
837 match message_type {
838 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
839 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
840 }
841}
842
843async fn body_to_client_ws_message(
844 body: CamelBody,
845 message_type: &str,
846) -> Result<ClientWsMessage, CamelError> {
847 match message_type {
848 "binary" => Ok(ClientWsMessage::Binary(
849 body.into_bytes(10 * 1024 * 1024).await?,
850 )),
851 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
852 }
853}
854
855async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
856 Ok(match body {
857 CamelBody::Empty => String::new(),
858 CamelBody::Text(s) => s,
859 CamelBody::Xml(s) => s,
860 CamelBody::Json(v) => v.to_string(),
861 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
862 CamelBody::Stream(stream) => {
863 let bytes = CamelBody::Stream(stream)
864 .into_bytes(10 * 1024 * 1024)
865 .await?;
866 String::from_utf8_lossy(&bytes).to_string()
867 }
868 })
869}
870
/// Decides whether a WebSocket upgrade's `Origin` header is acceptable.
///
/// `allowed_origin` is a comma-separated list of origins, for example
/// `"https://a.example, https://b.example"`. A single origin remains valid.
/// Matching rules:
/// - Entries are trimmed, and blank entries are ignored.
/// - Comparison is ASCII-case-insensitive, since origin scheme and host are
///   case-insensitive (RFC 6454).
/// - An entry of `"*"` accepts every request, including requests without an
///   `Origin` header.
/// - Otherwise a request without an `Origin` header is rejected.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    let mut entries = allowed_origin
        .split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty());

    if entries.clone().any(|entry| entry == "*") {
        return true;
    }

    let Some(origin) = request_origin.map(str::trim) else {
        return false;
    };
    entries.any(|entry| entry.eq_ignore_ascii_case(origin))
}
877
878fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
879 match tx.try_send(msg) {
880 Ok(()) => true,
881 Err(error) => {
882 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
883 false
884 }
885 }
886}
887
888fn load_tls_config(
889 cert_path: &str,
890 key_path: &str,
891) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
892 use std::fs::File;
893 use std::io::BufReader;
894
895 let cert_file = File::open(cert_path)
896 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
897 let key_file = File::open(key_path)
898 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
899
900 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
901 .collect::<Result<Vec<_>, _>>()
902 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
903
904 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
905 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
906 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
907
908 tokio_rustls::rustls::ServerConfig::builder()
909 .with_no_client_auth()
910 .with_single_cert(certs, key)
911 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
912}
913
914fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
915 match err {
916 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
917 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
918 }
919 tungstenite::Error::Tls(_) => {
920 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
921 }
922 other => {
923 let msg = other.to_string();
924 if msg.to_lowercase().contains("connection refused") {
925 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
926 } else if msg.to_lowercase().contains("tls") {
927 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
928 } else {
929 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
930 }
931 }
932 }
933}
934
935#[cfg(test)]
936mod tests {
937 use super::*;
938 use std::time::Duration;
939
940 use tokio::sync::mpsc;
941 use tokio_tungstenite::connect_async;
942 use tokio_tungstenite::tungstenite::Message as ClientMessage;
943 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
944 use tokio_util::sync::CancellationToken;
945 use tower::ServiceExt;
946
947 fn free_port() -> u16 {
948 std::net::TcpListener::bind("127.0.0.1:0")
949 .unwrap()
950 .local_addr()
951 .unwrap()
952 .port()
953 }
954
955 #[test]
956 fn ws_component_scheme_is_ws() {
957 assert_eq!(WsComponent.scheme(), "ws");
958 }
959
960 #[test]
961 fn wss_component_scheme_is_wss() {
962 assert_eq!(WssComponent.scheme(), "wss");
963 }
964
965 #[test]
966 fn endpoint_config_defaults_match_spec() {
967 let cfg = WsEndpointConfig::default();
968 assert_eq!(cfg.scheme, "ws");
969 assert_eq!(cfg.host, "0.0.0.0");
970 assert_eq!(cfg.port, 8080);
971 assert_eq!(cfg.path, "/");
972 assert_eq!(cfg.max_connections, 100);
973 assert_eq!(cfg.max_message_size, 65536);
974 assert!(!cfg.send_to_all);
975 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
976 assert_eq!(cfg.idle_timeout, Duration::ZERO);
977 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
978 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
979 assert_eq!(cfg.allow_origin, "*");
980 assert_eq!(cfg.tls_cert, None);
981 assert_eq!(cfg.tls_key, None);
982 }
983
984 #[test]
985 fn endpoint_config_parses_uri_params() {
986 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
987 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
988
989 assert_eq!(cfg.scheme, "ws");
990 assert_eq!(cfg.host, "localhost");
991 assert_eq!(cfg.port, 9001);
992 assert_eq!(cfg.path, "/chat");
993 assert_eq!(cfg.max_connections, 42);
994 assert_eq!(cfg.max_message_size, 1024);
995 assert!(cfg.send_to_all);
996 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
997 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
998 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
999 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1000 assert_eq!(cfg.allow_origin, "https://example.com");
1001 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1002 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1003 }
1004
1005 #[test]
1006 fn endpoint_config_override_chain_uri_overrides_defaults() {
1007 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1008 assert_eq!(cfg.max_connections, 7);
1009 assert_eq!(cfg.max_message_size, 65536);
1010 assert!(!cfg.send_to_all);
1011 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1012 }
1013
1014 #[test]
1015 fn endpoint_trait_creates_consumer_and_producer() {
1016 let endpoint = WsComponent
1017 .create_endpoint("ws://127.0.0.1:9010/trait")
1018 .unwrap();
1019
1020 endpoint.create_consumer().unwrap();
1021 endpoint
1022 .create_producer(&ProducerContext::default())
1023 .unwrap();
1024 }
1025
1026 #[test]
1027 fn ws_consumer_concurrency_model_uses_max_connections() {
1028 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1029 let consumer = WsConsumer::new(cfg.server_config());
1030 assert_eq!(
1031 consumer.concurrency_model(),
1032 ConcurrencyModel::Concurrent { max: Some(321) }
1033 );
1034 }
1035
1036 #[tokio::test]
1037 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1038 let registry = WsConnectionRegistry::new();
1039 let (tx1, mut rx1) = mpsc::channel(8);
1040 let (tx2, mut rx2) = mpsc::channel(8);
1041
1042 registry.insert("k1".into(), tx1);
1043 registry.insert("k2".into(), tx2);
1044 assert_eq!(registry.len(), 2);
1045
1046 for tx in registry.snapshot_senders() {
1047 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1048 }
1049
1050 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1051 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1052
1053 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1054 assert_eq!(target.len(), 1);
1055 target[0]
1056 .send(WsMessage::Text("targeted".into()))
1057 .await
1058 .unwrap();
1059
1060 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1061 assert!(
1062 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1063 .await
1064 .is_err()
1065 );
1066
1067 registry.remove("k1");
1068 assert_eq!(registry.len(), 1);
1069 }
1070
1071 #[test]
1072 fn host_canonicalization_maps_local_hosts_to_loopback() {
1073 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1074 .unwrap()
1075 .canonical_host();
1076 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1077 .unwrap()
1078 .canonical_host();
1079 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1080 .unwrap()
1081 .canonical_host();
1082
1083 assert_eq!(c1, "127.0.0.1");
1084 assert_eq!(c2, "127.0.0.1");
1085 assert_eq!(c3, "127.0.0.1");
1086 }
1087
1088 #[tokio::test]
1089 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1090 let port = free_port();
1091 let uri = format!("ws://127.0.0.1:{port}/echo");
1092 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1093
1094 let mut consumer = endpoint.create_consumer().unwrap();
1095 let producer = endpoint
1096 .create_producer(&ProducerContext::default())
1097 .unwrap();
1098
1099 let (route_tx, mut route_rx) = mpsc::channel(16);
1100 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1101 consumer.start(ctx).await.unwrap();
1102
1103 let route_task = tokio::spawn(async move {
1104 if let Some(envelope) = route_rx.recv().await {
1105 let payload = envelope
1106 .exchange
1107 .input
1108 .body
1109 .as_text()
1110 .unwrap_or_default()
1111 .to_string();
1112 let key = envelope
1113 .exchange
1114 .input
1115 .header("CamelWsConnectionKey")
1116 .and_then(|v| v.as_str())
1117 .unwrap()
1118 .to_string();
1119
1120 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1121 response
1122 .input
1123 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1124 producer.oneshot(response).await.unwrap();
1125 }
1126 });
1127
1128 let url = format!("ws://127.0.0.1:{port}/echo");
1129 let (mut client, _) = loop {
1130 match connect_async(&url).await {
1131 Ok(ok) => break ok,
1132 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1133 }
1134 };
1135
1136 client
1137 .send(ClientMessage::Text("hello-ws".into()))
1138 .await
1139 .unwrap();
1140
1141 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1142 loop {
1143 match client.next().await {
1144 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1145 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1146 Some(Ok(_)) => continue,
1147 Some(Err(e)) => panic!("ws receive failed: {e}"),
1148 None => panic!("websocket closed before echo"),
1149 }
1150 }
1151 })
1152 .await
1153 .unwrap();
1154
1155 assert_eq!(incoming, "hello-ws");
1156
1157 consumer.stop().await.unwrap();
1158 route_task.await.unwrap();
1159 }
1160
1161 #[tokio::test]
1162 async fn consumer_stop_sends_close_1001() {
1163 let port = free_port();
1164 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1165 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1166
1167 let mut consumer = endpoint.create_consumer().unwrap();
1168 let (route_tx, _route_rx) = mpsc::channel(16);
1169 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1170 consumer.start(ctx).await.unwrap();
1171
1172 let url = format!("ws://127.0.0.1:{port}/shutdown");
1173 let (mut client, _) = loop {
1174 match connect_async(&url).await {
1175 Ok(ok) => break ok,
1176 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1177 }
1178 };
1179
1180 client
1181 .send(ClientMessage::Text("keepalive".into()))
1182 .await
1183 .unwrap();
1184
1185 consumer.stop().await.unwrap();
1186
1187 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1188 loop {
1189 match client.next().await {
1190 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1191 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1192 Some(Ok(_)) => continue,
1193 Some(Err(e)) => panic!("ws receive failed: {e}"),
1194 None => panic!("websocket closed without close frame"),
1195 }
1196 }
1197 })
1198 .await
1199 .unwrap();
1200
1201 assert_eq!(close_code, Some(CloseCode::Away));
1202 }
1203
1204 #[test]
1205 fn wildcard_origin_allows_anything() {
1206 assert!(is_origin_allowed("*", None));
1207 assert!(is_origin_allowed("*", Some("https://example.com")));
1208 }
1209
1210 #[test]
1211 fn exact_origin_requires_match() {
1212 assert!(is_origin_allowed(
1213 "https://example.com",
1214 Some("https://example.com")
1215 ));
1216 assert!(!is_origin_allowed(
1217 "https://example.com",
1218 Some("https://other.com")
1219 ));
1220 assert!(!is_origin_allowed("https://example.com", None));
1221 }
1222
1223 #[test]
1224 fn endpoint_config_rejects_invalid_scheme() {
1225 let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1226 assert!(result.is_err());
1227 let msg = result.unwrap_err().to_string();
1228 assert!(
1229 msg.contains("Invalid WebSocket scheme"),
1230 "expected scheme error, got: {msg}"
1231 );
1232 }
1233
1234 #[tokio::test]
1235 async fn wss_consumer_start_fails_without_tls_cert() {
1236 let port = free_port();
1237 let endpoint = WssComponent
1238 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"))
1239 .unwrap();
1240 let mut consumer = endpoint.create_consumer().unwrap();
1241 let (tx, _rx) = mpsc::channel(16);
1242 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1243 let result = consumer.start(ctx).await;
1244 assert!(result.is_err());
1245 let msg = result.unwrap_err().to_string();
1246 assert!(
1247 msg.contains("TLS cert path is required"),
1248 "expected TLS cert error, got: {msg}"
1249 );
1250 }
1251
1252 #[tokio::test]
1253 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1254 let port = free_port();
1255 let endpoint = WssComponent
1256 .create_endpoint(&format!(
1257 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1258 ))
1259 .unwrap();
1260 let mut consumer = endpoint.create_consumer().unwrap();
1261 let (tx, _rx) = mpsc::channel(16);
1262 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1263 let result = consumer.start(ctx).await;
1264 assert!(result.is_err());
1265 let msg = result.unwrap_err().to_string();
1266 assert!(
1267 msg.contains("TLS cert file error"),
1268 "expected cert file error, got: {msg}"
1269 );
1270 }
1271
1272 #[tokio::test]
1273 async fn server_registry_returns_same_state_for_same_port() {
1274 let port = free_port();
1275 let state1 = ServerRegistry::global()
1276 .get_or_spawn("127.0.0.1", port, None)
1277 .await
1278 .unwrap();
1279 let state2 = ServerRegistry::global()
1280 .get_or_spawn("127.0.0.1", port, None)
1281 .await
1282 .unwrap();
1283 assert!(
1284 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1285 "expected same dispatch table for same port"
1286 );
1287 }
1288
1289 #[tokio::test]
1290 async fn dispatch_handler_returns_404_for_unregistered_path() {
1291 let port = free_port();
1292 let state = ServerRegistry::global()
1293 .get_or_spawn("127.0.0.1", port, None)
1294 .await
1295 .unwrap();
1296 let app = Router::new().fallback(dispatch_handler).with_state(state);
1297 let response = tokio::time::timeout(
1298 Duration::from_secs(2),
1299 tower::ServiceExt::oneshot(
1300 app,
1301 axum::http::Request::builder()
1302 .uri("/nonexistent")
1303 .body(Body::empty())
1304 .unwrap(),
1305 ),
1306 )
1307 .await
1308 .unwrap()
1309 .unwrap();
1310 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1311 }
1312
1313 #[tokio::test]
1314 async fn client_mode_producer_connects_and_echoes() {
1315 let port = free_port();
1316
1317 let app = Router::new().route(
1318 "/echo",
1319 axum::routing::get(|ws: WebSocketUpgrade| async move {
1320 ws.on_upgrade(|mut socket: WebSocket| async move {
1321 while let Some(Ok(msg)) = socket.recv().await {
1322 match msg {
1323 WsMessage::Text(text) => {
1324 let _ = socket.send(WsMessage::Text(text)).await;
1325 }
1326 WsMessage::Binary(data) => {
1327 let _ = socket.send(WsMessage::Binary(data)).await;
1328 }
1329 WsMessage::Close(_) => break,
1330 _ => {}
1331 }
1332 }
1333 })
1334 }),
1335 );
1336 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
1337 .await
1338 .unwrap();
1339 let server_task = tokio::spawn(async move {
1340 let _ = serve(listener, app).await;
1341 });
1342
1343 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1344 let producer = WsProducer::new(cfg.client_config());
1345
1346 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1347 let result = loop {
1348 tokio::time::sleep(Duration::from_millis(25)).await;
1349 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1350 Ok(Ok(r)) => break r,
1351 Ok(Err(_)) => panic!("producer call failed"),
1352 Err(_) => panic!("producer call timed out"),
1353 }
1354 };
1355
1356 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1357
1358 server_task.abort();
1359 }
1360
1361 #[tokio::test]
1362 async fn max_connections_rejects_with_close_1013() {
1363 let port = free_port();
1364 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1365 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1366 let mut consumer = endpoint.create_consumer().unwrap();
1367 let (route_tx, _route_rx) = mpsc::channel(16);
1368 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1369 consumer.start(ctx).await.unwrap();
1370
1371 let url = format!("ws://127.0.0.1:{port}/limited");
1372 let (_client1, _) = loop {
1373 match connect_async(&url).await {
1374 Ok(ok) => break ok,
1375 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1376 }
1377 };
1378
1379 tokio::time::sleep(Duration::from_millis(100)).await;
1380
1381 let (mut client2, _) = connect_async(&url).await.unwrap();
1382
1383 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1384 loop {
1385 match client2.next().await {
1386 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1387 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1388 Some(Ok(ClientMessage::Text(_))) => continue,
1389 Some(Ok(_)) => continue,
1390 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1391 None => panic!("client2 closed without close frame"),
1392 }
1393 }
1394 })
1395 .await
1396 .unwrap();
1397
1398 assert_eq!(
1399 close_code,
1400 Some(CloseCode::from(1013u16)),
1401 "expected 1013 (Try Again Later) for max connections"
1402 );
1403
1404 consumer.stop().await.unwrap();
1405 }
1406
1407 #[tokio::test]
1408 async fn max_message_size_rejects_with_close_1009() {
1409 let port = free_port();
1410 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1411 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1412 let mut consumer = endpoint.create_consumer().unwrap();
1413 let (route_tx, _route_rx) = mpsc::channel(16);
1414 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1415 consumer.start(ctx).await.unwrap();
1416
1417 let url = format!("ws://127.0.0.1:{port}/sizelimit");
1418 let (mut client, _) = loop {
1419 match connect_async(&url).await {
1420 Ok(ok) => break ok,
1421 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1422 }
1423 };
1424
1425 let oversized = "x".repeat(100);
1426 client
1427 .send(ClientMessage::Text(oversized.into()))
1428 .await
1429 .unwrap();
1430
1431 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1432 loop {
1433 match client.next().await {
1434 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1435 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1436 Some(Ok(_)) => continue,
1437 Some(Err(e)) => panic!("ws receive failed: {e}"),
1438 None => panic!("websocket closed without close frame"),
1439 }
1440 }
1441 })
1442 .await
1443 .unwrap();
1444
1445 assert_eq!(
1446 close_code,
1447 Some(CloseCode::from(1009u16)),
1448 "expected 1009 (Message Too Big) for oversized message"
1449 );
1450
1451 consumer.stop().await.unwrap();
1452 }
1453
1454 #[tokio::test]
1455 async fn origin_rejection_returns_403() {
1456 let port = free_port();
1457 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1458 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1459 let mut consumer = endpoint.create_consumer().unwrap();
1460 let (route_tx, _route_rx) = mpsc::channel(16);
1461 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1462 consumer.start(ctx).await.unwrap();
1463
1464 let state = ServerRegistry::global()
1465 .get_or_spawn("127.0.0.1", port, None)
1466 .await
1467 .unwrap();
1468 let app = Router::new().fallback(dispatch_handler).with_state(state);
1469
1470 let response = tokio::time::timeout(
1471 Duration::from_secs(2),
1472 tower::ServiceExt::oneshot(
1473 app,
1474 axum::http::Request::builder()
1475 .uri("/origintest")
1476 .header("origin", "https://evil.com")
1477 .header("upgrade", "websocket")
1478 .header("connection", "Upgrade")
1479 .header("sec-websocket-version", "13")
1480 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1481 .body(Body::empty())
1482 .unwrap(),
1483 ),
1484 )
1485 .await
1486 .unwrap()
1487 .unwrap();
1488
1489 assert_eq!(
1490 response.status(),
1491 StatusCode::FORBIDDEN,
1492 "expected 403 for disallowed origin"
1493 );
1494
1495 consumer.stop().await.unwrap();
1496 }
1497
1498 #[tokio::test]
1499 async fn broadcast_sends_to_all_connected_clients() {
1500 let port = free_port();
1501 let uri = format!("ws://127.0.0.1:{port}/bc");
1502 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1503 let mut consumer = endpoint.create_consumer().unwrap();
1504 let producer = endpoint
1505 .create_producer(&ProducerContext::default())
1506 .unwrap();
1507
1508 let (route_tx, _route_rx) = mpsc::channel(16);
1509 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1510 consumer.start(ctx).await.unwrap();
1511
1512 let url = format!("ws://127.0.0.1:{port}/bc");
1513
1514 let (mut client1, _) = loop {
1515 match connect_async(&url).await {
1516 Ok(ok) => break ok,
1517 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1518 }
1519 };
1520
1521 let (mut client2, _) = connect_async(&url).await.unwrap();
1522
1523 tokio::time::sleep(Duration::from_millis(100)).await;
1524
1525 let mut response =
1526 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1527 response
1528 .input
1529 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1530 producer.oneshot(response).await.unwrap();
1531
1532 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1533 loop {
1534 match client1.next().await {
1535 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1536 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1537 _ => panic!("client1 unexpected message or close"),
1538 }
1539 }
1540 })
1541 .await
1542 .unwrap();
1543
1544 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
1545 loop {
1546 match client2.next().await {
1547 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1548 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1549 _ => panic!("client2 unexpected message or close"),
1550 }
1551 }
1552 })
1553 .await
1554 .unwrap();
1555
1556 assert_eq!(recv1, "broadcast-msg");
1557 assert_eq!(recv2, "broadcast-msg");
1558
1559 consumer.stop().await.unwrap();
1560 }
1561
1562 #[tokio::test]
1563 async fn concurrent_get_or_spawn_returns_same_state() {
1564 let port = free_port();
1565 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
1566 Arc::new(std::sync::Mutex::new(Vec::new()));
1567
1568 let mut handles = Vec::new();
1569 for _ in 0..4 {
1570 let results = results.clone();
1571 handles.push(tokio::spawn(async move {
1572 let state = ServerRegistry::global()
1573 .get_or_spawn("127.0.0.1", port, None)
1574 .await
1575 .unwrap();
1576 results.lock().unwrap().push(state);
1577 }));
1578 }
1579
1580 for h in handles {
1581 h.await.unwrap();
1582 }
1583
1584 let states = results.lock().unwrap();
1585 assert_eq!(states.len(), 4);
1586 for i in 1..states.len() {
1587 assert!(
1588 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
1589 "all concurrent callers should get the same dispatch table"
1590 );
1591 }
1592 }
1593}