1pub mod config;
2
3pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, OnceLock};
7
8use async_trait::async_trait;
9use axum::body::Body;
10use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
11use axum::extract::{FromRequest, Request, State};
12use axum::http::{StatusCode, header};
13use axum::response::IntoResponse;
14use axum::{Router, serve};
15use camel_component_api::{
16 Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
17};
18use camel_component_api::{
19 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
20 ProducerContext,
21};
22use dashmap::DashMap;
23use futures::{SinkExt, StreamExt};
24use std::future::Future;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::sync::{OnceCell, RwLock, mpsc};
28use tokio::task::JoinHandle;
29use tokio_tungstenite::tungstenite;
30use tokio_tungstenite::tungstenite::client::IntoClientRequest;
31use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
32use tower::Service;
33
34#[derive(Clone)]
35struct WsPathConfig {
36 max_connections: u32,
37 max_message_size: u32,
38 heartbeat_interval: std::time::Duration,
39 idle_timeout: std::time::Duration,
40 allow_origin: String,
41}
42
43impl Default for WsPathConfig {
44 fn default() -> Self {
45 let cfg = WsEndpointConfig::default();
46 Self {
47 max_connections: cfg.max_connections,
48 max_message_size: cfg.max_message_size,
49 heartbeat_interval: cfg.heartbeat_interval,
50 idle_timeout: cfg.idle_timeout,
51 allow_origin: cfg.allow_origin,
52 }
53 }
54}
55
/// Filesystem locations of the PEM material used to terminate TLS for `wss` consumers.
#[derive(Clone)]
struct WsTlsConfig {
    /// Path to the PEM file holding the private key.
    key_path: String,
    /// Path to the PEM certificate chain presented to clients.
    cert_path: String,
}
61
62type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
63
64struct ServerHandle {
65 state: WsAppState,
66 is_tls: bool,
67 _task: JoinHandle<()>,
68}
69
70pub struct ServerRegistry {
71 inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
72}
73
74impl ServerRegistry {
75 pub fn global() -> &'static Self {
76 static REG: OnceLock<ServerRegistry> = OnceLock::new();
77 REG.get_or_init(|| Self {
78 inner: Mutex::new(HashMap::new()),
79 })
80 }
81
82 pub(crate) async fn get_or_spawn(
83 &'static self,
84 host: &str,
85 port: u16,
86 tls_config: Option<WsTlsConfig>,
87 ) -> Result<WsAppState, CamelError> {
88 let wants_tls = tls_config.is_some();
89 let host_owned = host.to_string();
90
91 let cell = {
92 let mut guard = self.inner.lock().map_err(|_| {
93 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
94 })?;
95 guard
96 .entry(port)
97 .or_insert_with(|| Arc::new(OnceCell::new()))
98 .clone()
99 };
100
101 let handle = cell
102 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
103 .await?;
104
105 if wants_tls != handle.is_tls {
106 return Err(CamelError::EndpointCreationFailed(format!(
107 "Server on port {port} already running with different TLS mode"
108 )));
109 }
110
111 Ok(handle.state.clone())
112 }
113}
114
115async fn spawn_server(
116 host: &str,
117 port: u16,
118 tls_config: Option<WsTlsConfig>,
119) -> Result<ServerHandle, CamelError> {
120 let addr = format!("{host}:{port}");
121 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
122 let path_configs = Arc::new(DashMap::new());
123 let state = WsAppState {
124 dispatch: Arc::clone(&dispatch),
125 path_configs: Arc::clone(&path_configs),
126 };
127 let app = Router::new()
128 .fallback(dispatch_handler)
129 .with_state(state.clone());
130
131 let (task, is_tls) = if let Some(ref tls) = tls_config {
132 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
133 let parsed_addr = addr.parse().map_err(|e| {
134 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
135 })?;
136 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
137 let task = tokio::spawn(async move {
138 let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
139 .serve(app.into_make_service())
140 .await;
141 });
142 (task, true)
143 } else {
144 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
145 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
146 })?;
147 let task = tokio::spawn(async move {
148 let _ = serve(listener, app).await;
149 });
150 (task, false)
151 };
152
153 Ok(ServerHandle {
154 state,
155 is_tls,
156 _task: task,
157 })
158}
159
160#[derive(Clone)]
161struct WsAppState {
162 dispatch: DispatchTable,
163 path_configs: Arc<DashMap<String, WsPathConfig>>,
164}
165
166pub struct WsConnectionRegistry {
167 connections: DashMap<String, mpsc::Sender<WsMessage>>,
168}
169
170static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
171 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
172> = OnceLock::new();
173
174fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
175 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
176}
177
178impl Default for WsConnectionRegistry {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184impl WsConnectionRegistry {
185 pub fn new() -> Self {
186 Self {
187 connections: DashMap::new(),
188 }
189 }
190
191 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
192 self.connections.insert(key, tx);
193 }
194
195 pub fn remove(&self, key: &str) {
196 self.connections.remove(key);
197 }
198
199 pub fn len(&self) -> usize {
200 self.connections.len()
201 }
202
203 pub fn is_empty(&self) -> bool {
204 self.connections.is_empty()
205 }
206
207 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
208 self.connections.iter().map(|e| e.value().clone()).collect()
209 }
210
211 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
212 keys.iter()
213 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
214 .collect()
215 }
216}
217
218async fn dispatch_handler(
219 State(state): State<WsAppState>,
220 req: Request<Body>,
221) -> impl IntoResponse {
222 let path = req.uri().path().to_string();
223 let origin = req
224 .headers()
225 .get(header::ORIGIN)
226 .and_then(|value| value.to_str().ok())
227 .map(str::to_string);
228 let remote_addr = req
229 .extensions()
230 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
231 .map(|ci| ci.0.to_string())
232 .unwrap_or_default();
233 let table = state.dispatch.read().await;
234 if !table.contains_key(&path) {
235 return (
236 StatusCode::NOT_FOUND,
237 "no ws endpoint registered for this path",
238 )
239 .into_response();
240 }
241 drop(table);
242
243 let path_config = state
244 .path_configs
245 .get(&path)
246 .map(|entry| entry.value().clone())
247 .unwrap_or_default();
248 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
249 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
250 }
251
252 let upgrade_headers: HashMap<String, String> = req
253 .headers()
254 .iter()
255 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
256 .collect();
257
258 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
259 Ok(ws) => ws,
260 Err(_) => {
261 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
262 }
263 };
264
265 ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
266 .into_response()
267}
268
269#[allow(unused_variables)]
270async fn ws_handler(
271 socket: WebSocket,
272 state: WsAppState,
273 path: String,
274 remote_addr: String,
275 upgrade_headers: HashMap<String, String>,
276) {
277 let connection_key = uuid::Uuid::new_v4().to_string();
278 let path_config = state
279 .path_configs
280 .get(&path)
281 .map(|entry| entry.value().clone())
282 .unwrap_or_default();
283
284 let env_tx = {
285 let table = state.dispatch.read().await;
286 table.get(&path).cloned()
287 };
288 let Some(env_tx) = env_tx else {
289 return;
290 };
291
292 let (mut sink, mut stream) = socket.split();
293 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
294
295 let registry = global_registries();
296 let mut registry_key = None;
297 for entry in registry.iter() {
298 if entry.key().2 == path {
299 entry.value().insert(connection_key.clone(), out_tx.clone());
300 registry_key = Some(entry.key().clone());
301 break;
302 }
303 }
304
305 let writer = tokio::spawn(async move {
306 while let Some(msg) = out_rx.recv().await {
307 let _ = sink.send(msg).await;
308 }
309 });
310
311 let mut over_limit = false;
312 if let Some(key) = ®istry_key
313 && let Some(entry) = registry.get(key)
314 && entry.len() > path_config.max_connections as usize
315 {
316 over_limit = true;
317 }
318 if over_limit {
319 try_send_with_backpressure(
320 &out_tx,
321 WsMessage::Close(Some(CloseFrame {
322 code: CloseCode::from(1013u16),
323 reason: "max connections exceeded".into(),
324 })),
325 "max-connections-close",
326 );
327 if let Some(key) = registry_key.clone()
328 && let Some(entry) = registry.get(&key)
329 {
330 entry.remove(&connection_key);
331 }
332 drop(out_tx);
333 let _ = writer.await;
334 return;
335 }
336
337 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
338 let ping_tx = out_tx.clone();
339 let interval = path_config.heartbeat_interval;
340 Some(tokio::spawn(async move {
341 let mut ticker = tokio::time::interval(interval);
342 loop {
343 ticker.tick().await;
344 let _ = try_send_with_backpressure(
345 &ping_tx,
346 WsMessage::Ping(Vec::new().into()),
347 "heartbeat-ping",
348 );
349 }
350 }))
351 } else {
352 None
353 };
354
355 loop {
356 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
357 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
358 Ok(msg) => msg,
359 Err(_) => {
360 try_send_with_backpressure(
361 &out_tx,
362 WsMessage::Close(Some(CloseFrame {
363 code: CloseCode::from(1000u16),
364 reason: "idle timeout".into(),
365 })),
366 "idle-timeout-close",
367 );
368 break;
369 }
370 }
371 } else {
372 stream.next().await
373 };
374
375 let Some(msg) = next_msg else {
376 break;
377 };
378
379 match msg {
380 Ok(WsMessage::Text(text)) => {
381 if text.len() > path_config.max_message_size as usize {
382 try_send_with_backpressure(
383 &out_tx,
384 WsMessage::Close(Some(CloseFrame {
385 code: CloseCode::from(1009u16),
386 reason: "message too large".into(),
387 })),
388 "max-message-size-close-text",
389 );
390 break;
391 }
392
393 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
394 message.set_header(
395 "CamelWsConnectionKey",
396 serde_json::Value::String(connection_key.clone()),
397 );
398 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
399 message.set_header(
400 "CamelWsRemoteAddress",
401 serde_json::Value::String(remote_addr.clone()),
402 );
403
404 #[allow(unused_mut)]
405 let mut exchange = Exchange::new(message);
406 #[cfg(feature = "otel")]
407 {
408 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
409 }
410 if env_tx
411 .send(ExchangeEnvelope {
412 exchange,
413 reply_tx: None,
414 })
415 .await
416 .is_err()
417 {
418 break;
419 }
420 }
421 Ok(WsMessage::Binary(data)) => {
422 if data.len() > path_config.max_message_size as usize {
423 try_send_with_backpressure(
424 &out_tx,
425 WsMessage::Close(Some(CloseFrame {
426 code: CloseCode::from(1009u16),
427 reason: "message too large".into(),
428 })),
429 "max-message-size-close-binary",
430 );
431 break;
432 }
433
434 let mut message = CamelMessage::new(CamelBody::Bytes(data));
435 message.set_header(
436 "CamelWsConnectionKey",
437 serde_json::Value::String(connection_key.clone()),
438 );
439 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
440 message.set_header(
441 "CamelWsRemoteAddress",
442 serde_json::Value::String(remote_addr.clone()),
443 );
444
445 #[allow(unused_mut)]
446 let mut exchange = Exchange::new(message);
447 #[cfg(feature = "otel")]
448 {
449 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
450 }
451 if env_tx
452 .send(ExchangeEnvelope {
453 exchange,
454 reply_tx: None,
455 })
456 .await
457 .is_err()
458 {
459 break;
460 }
461 }
462 Ok(WsMessage::Close(_)) | Err(_) => break,
463 _ => {}
464 }
465 }
466
467 if let Some(task) = heartbeat_task {
468 task.abort();
469 }
470
471 if let Some(key) = registry_key
472 && let Some(entry) = registry.get(&key)
473 {
474 entry.remove(&connection_key);
475 }
476 drop(out_tx);
477 let _ = writer.await;
478}
479
480pub struct WsComponent;
481
482impl Component for WsComponent {
483 fn scheme(&self) -> &str {
484 "ws"
485 }
486
487 fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
488 let cfg = WsEndpointConfig::from_uri(uri)?;
489 Ok(Box::new(WsEndpoint {
490 uri: uri.to_string(),
491 cfg,
492 }))
493 }
494}
495
496pub struct WssComponent;
497
498impl Component for WssComponent {
499 fn scheme(&self) -> &str {
500 "wss"
501 }
502
503 fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
504 let cfg = WsEndpointConfig::from_uri(uri)?;
505 Ok(Box::new(WsEndpoint {
506 uri: uri.to_string(),
507 cfg,
508 }))
509 }
510}
511
512struct WsEndpoint {
513 uri: String,
514 cfg: WsEndpointConfig,
515}
516
517impl Endpoint for WsEndpoint {
518 fn uri(&self) -> &str {
519 &self.uri
520 }
521
522 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
523 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
524 }
525
526 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
527 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
528 }
529}
530
531pub struct WsConsumer {
532 cfg: WsServerConfig,
533 registry: Arc<WsConnectionRegistry>,
534 server_state: Option<WsAppState>,
535 registry_key: Option<(String, u16, String)>,
536 forward_task: Option<JoinHandle<()>>,
537}
538
539impl WsConsumer {
540 pub fn new(cfg: WsServerConfig) -> Self {
541 Self {
542 cfg,
543 registry: Arc::new(WsConnectionRegistry::new()),
544 server_state: None,
545 registry_key: None,
546 forward_task: None,
547 }
548 }
549}
550
551#[async_trait]
552impl Consumer for WsConsumer {
553 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
554 let tls_config = if self.cfg.inner.scheme == "wss" {
555 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
556 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
557 })?;
558 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
559 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
560 })?;
561 Some(WsTlsConfig {
562 cert_path,
563 key_path,
564 })
565 } else {
566 None
567 };
568
569 let state = ServerRegistry::global()
570 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
571 .await?;
572
573 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
574 {
575 let mut table = state.dispatch.write().await;
576 table.insert(self.cfg.inner.path.clone(), env_tx);
577 }
578
579 state.path_configs.insert(
580 self.cfg.inner.path.clone(),
581 WsPathConfig {
582 max_connections: self.cfg.inner.max_connections,
583 max_message_size: self.cfg.inner.max_message_size,
584 heartbeat_interval: self.cfg.inner.heartbeat_interval,
585 idle_timeout: self.cfg.inner.idle_timeout,
586 allow_origin: self.cfg.inner.allow_origin.clone(),
587 },
588 );
589
590 let registry_key = (
591 self.cfg.inner.canonical_host(),
592 self.cfg.inner.port,
593 self.cfg.inner.path.clone(),
594 );
595 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
596
597 let sender = ctx.sender();
598 let forward_task = tokio::spawn(async move {
599 while let Some(envelope) = env_rx.recv().await {
600 if sender.send(envelope).await.is_err() {
601 break;
602 }
603 }
604 });
605
606 self.server_state = Some(state);
607 self.registry_key = Some(registry_key);
608 self.forward_task = Some(forward_task);
609 Ok(())
610 }
611
612 async fn stop(&mut self) -> Result<(), CamelError> {
613 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
614 code: axum::extract::ws::CloseCode::from(1001u16),
615 reason: "consumer stopping".into(),
616 }));
617 for tx in self.registry.snapshot_senders() {
618 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
619 }
620
621 if let Some(state) = self.server_state.take() {
622 let mut table = state.dispatch.write().await;
623 table.remove(&self.cfg.inner.path);
624 state.path_configs.remove(&self.cfg.inner.path);
625 }
626
627 if let Some(key) = self.registry_key.take() {
628 global_registries().remove(&key);
629 }
630
631 if let Some(task) = self.forward_task.take() {
632 task.abort();
633 }
634
635 Ok(())
636 }
637
638 fn concurrency_model(&self) -> ConcurrencyModel {
639 ConcurrencyModel::Concurrent {
640 max: Some(self.cfg.inner.max_connections as usize),
641 }
642 }
643}
644
645#[derive(Clone)]
646pub struct WsProducer {
647 cfg: WsClientConfig,
648}
649
650impl WsProducer {
651 pub fn new(cfg: WsClientConfig) -> Self {
652 Self { cfg }
653 }
654}
655
656impl Service<Exchange> for WsProducer {
657 type Response = Exchange;
658 type Error = CamelError;
659 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
660
661 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
662 Poll::Ready(Ok(()))
663 }
664
665 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
666 let cfg = self.cfg.clone();
667
668 Box::pin(async move {
669 let canonical_host = cfg.inner.canonical_host();
670 let key = (
671 canonical_host.clone(),
672 cfg.inner.port,
673 cfg.inner.path.clone(),
674 );
675
676 let send_to_all = exchange
677 .input
678 .header("CamelWsSendToAll")
679 .and_then(|v| v.as_bool())
680 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
681 .unwrap_or(false);
682
683 let conn_keys_header = exchange
684 .input
685 .header("CamelWsConnectionKey")
686 .and_then(|v| v.as_str())
687 .map(str::to_string);
688
689 let local_exists = global_registries().contains_key(&key);
690 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
691
692 let message_type = exchange
693 .input
694 .header("CamelWsMessageType")
695 .and_then(|v| v.as_str())
696 .unwrap_or("text")
697 .to_ascii_lowercase();
698
699 if server_send_mode {
700 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
701 let Some(registry) = registry else {
702 return Err(CamelError::ProcessorError(format!(
703 "WebSocket local consumer not found for {}:{}{}",
704 canonical_host, cfg.inner.port, cfg.inner.path
705 )));
706 };
707
708 let out_msg = body_to_axum_ws_message(
709 std::mem::take(&mut exchange.input.body),
710 &message_type,
711 )
712 .await?;
713
714 let targets = if send_to_all {
715 registry.snapshot_senders()
716 } else if let Some(keys) = conn_keys_header {
717 let parsed: Vec<String> = keys
718 .split(',')
719 .map(str::trim)
720 .filter(|k| !k.is_empty())
721 .map(str::to_string)
722 .collect();
723 registry.get_senders_for_keys(&parsed)
724 } else {
725 registry.snapshot_senders()
726 };
727
728 for tx in targets {
729 let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
730 }
731
732 return Ok(exchange);
733 }
734
735 let url = format!(
736 "{}://{}:{}{}",
737 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
738 );
739
740 #[allow(unused_mut)]
741 let mut request = url
742 .clone()
743 .into_client_request()
744 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
745
746 #[cfg(feature = "otel")]
747 {
748 let mut otel_headers = HashMap::new();
749 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
750 for (k, v) in otel_headers {
751 if let (Ok(name), Ok(val)) = (
752 http::header::HeaderName::from_bytes(k.as_bytes()),
753 http::header::HeaderValue::from_str(&v),
754 ) {
755 request.headers_mut().insert(name, val);
756 }
757 }
758 }
759
760 let connect_future = tokio_tungstenite::connect_async(request);
761 let (mut ws_stream, _) =
762 tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
763 .await
764 .map_err(|_| {
765 CamelError::ProcessorError(format!(
766 "WebSocket connect timeout ({:?}) to {url}",
767 cfg.inner.connect_timeout
768 ))
769 })?
770 .map_err(|e| map_connect_error(e, &url))?;
771
772 let out_msg =
773 body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
774 .await?;
775
776 ws_stream
777 .send(out_msg)
778 .await
779 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
780
781 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
782 loop {
783 match ws_stream.next().await {
784 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
785 continue;
786 }
787 other => break other,
788 }
789 }
790 })
791 .await
792 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
793
794 match incoming {
795 Some(Ok(ClientWsMessage::Text(text))) => {
796 exchange.input.body = CamelBody::Text(text.to_string());
797 }
798 Some(Ok(ClientWsMessage::Binary(data))) => {
799 exchange.input.body = CamelBody::Bytes(data);
800 }
801 Some(Ok(ClientWsMessage::Close(frame))) => {
802 let normal = frame
803 .as_ref()
804 .map(|f| {
805 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
806 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
807 })
808 .unwrap_or(true);
809
810 if normal {
811 exchange.input.body = CamelBody::Empty;
812 } else {
813 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
814 return Err(CamelError::ProcessorError(format!(
815 "WebSocket peer closed: code {code}"
816 )));
817 }
818 }
819 Some(Ok(_)) | None => {
820 exchange.input.body = CamelBody::Empty;
821 }
822 Some(Err(e)) => {
823 return Err(CamelError::ProcessorError(format!(
824 "WebSocket receive failed: {e}"
825 )));
826 }
827 }
828
829 let _ = ws_stream.close(None).await;
830 Ok(exchange)
831 })
832 }
833}
834
835async fn body_to_axum_ws_message(
836 body: CamelBody,
837 message_type: &str,
838) -> Result<WsMessage, CamelError> {
839 match message_type {
840 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
841 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
842 }
843}
844
845async fn body_to_client_ws_message(
846 body: CamelBody,
847 message_type: &str,
848) -> Result<ClientWsMessage, CamelError> {
849 match message_type {
850 "binary" => Ok(ClientWsMessage::Binary(
851 body.into_bytes(10 * 1024 * 1024).await?,
852 )),
853 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
854 }
855}
856
857async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
858 Ok(match body {
859 CamelBody::Empty => String::new(),
860 CamelBody::Text(s) => s,
861 CamelBody::Xml(s) => s,
862 CamelBody::Json(v) => v.to_string(),
863 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
864 CamelBody::Stream(stream) => {
865 let bytes = CamelBody::Stream(stream)
866 .into_bytes(10 * 1024 * 1024)
867 .await?;
868 String::from_utf8_lossy(&bytes).to_string()
869 }
870 })
871}
872
/// Decides whether a WebSocket upgrade carrying `request_origin` may proceed.
///
/// `allowed_origin` is either `"*"` or one or more comma-separated origins such as
/// `"https://a.example, https://b.example"`.
///
/// - `"*"` accepts any origin, including requests without an `Origin` header.
/// - Otherwise the request must send an `Origin` header matching one entry.
/// - Matching is ASCII case-insensitive, because scheme and host are case-insensitive.
/// - A trailing `/` on a configured entry is ignored, since an `Origin` value never
///   carries a path.
/// - Empty entries are ignored.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    let entries = || {
        allowed_origin
            .split(',')
            .map(|entry| entry.trim().trim_end_matches('/'))
            .filter(|entry| !entry.is_empty())
    };

    if entries().any(|entry| entry == "*") {
        return true;
    }
    request_origin.is_some_and(|origin| entries().any(|entry| entry.eq_ignore_ascii_case(origin)))
}
879
880fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
881 match tx.try_send(msg) {
882 Ok(()) => true,
883 Err(error) => {
884 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
885 false
886 }
887 }
888}
889
890fn load_tls_config(
891 cert_path: &str,
892 key_path: &str,
893) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
894 use std::fs::File;
895 use std::io::BufReader;
896
897 let cert_file = File::open(cert_path)
898 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
899 let key_file = File::open(key_path)
900 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
901
902 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
903 .collect::<Result<Vec<_>, _>>()
904 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
905
906 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
907 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
908 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
909
910 tokio_rustls::rustls::ServerConfig::builder()
911 .with_no_client_auth()
912 .with_single_cert(certs, key)
913 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
914}
915
916fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
917 match err {
918 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
919 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
920 }
921 tungstenite::Error::Tls(_) => {
922 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
923 }
924 other => {
925 let msg = other.to_string();
926 if msg.to_lowercase().contains("connection refused") {
927 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
928 } else if msg.to_lowercase().contains("tls") {
929 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
930 } else {
931 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
932 }
933 }
934 }
935}
936
937#[cfg(test)]
938mod tests {
939 use super::*;
940 use std::time::Duration;
941
942 use tokio::sync::mpsc;
943 use tokio_tungstenite::connect_async;
944 use tokio_tungstenite::tungstenite::Message as ClientMessage;
945 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
946 use tokio_util::sync::CancellationToken;
947 use tower::ServiceExt;
948
949 fn free_port() -> u16 {
950 std::net::TcpListener::bind("127.0.0.1:0")
951 .unwrap()
952 .local_addr()
953 .unwrap()
954 .port()
955 }
956
957 #[test]
958 fn ws_component_scheme_is_ws() {
959 assert_eq!(WsComponent.scheme(), "ws");
960 }
961
962 #[test]
963 fn wss_component_scheme_is_wss() {
964 assert_eq!(WssComponent.scheme(), "wss");
965 }
966
967 #[test]
968 fn endpoint_config_defaults_match_spec() {
969 let cfg = WsEndpointConfig::default();
970 assert_eq!(cfg.scheme, "ws");
971 assert_eq!(cfg.host, "0.0.0.0");
972 assert_eq!(cfg.port, 8080);
973 assert_eq!(cfg.path, "/");
974 assert_eq!(cfg.max_connections, 100);
975 assert_eq!(cfg.max_message_size, 65536);
976 assert!(!cfg.send_to_all);
977 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
978 assert_eq!(cfg.idle_timeout, Duration::ZERO);
979 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
980 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
981 assert_eq!(cfg.allow_origin, "*");
982 assert_eq!(cfg.tls_cert, None);
983 assert_eq!(cfg.tls_key, None);
984 }
985
986 #[test]
987 fn endpoint_config_parses_uri_params() {
988 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
989 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
990
991 assert_eq!(cfg.scheme, "ws");
992 assert_eq!(cfg.host, "localhost");
993 assert_eq!(cfg.port, 9001);
994 assert_eq!(cfg.path, "/chat");
995 assert_eq!(cfg.max_connections, 42);
996 assert_eq!(cfg.max_message_size, 1024);
997 assert!(cfg.send_to_all);
998 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
999 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1000 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1001 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1002 assert_eq!(cfg.allow_origin, "https://example.com");
1003 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1004 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1005 }
1006
1007 #[test]
1008 fn endpoint_config_override_chain_uri_overrides_defaults() {
1009 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1010 assert_eq!(cfg.max_connections, 7);
1011 assert_eq!(cfg.max_message_size, 65536);
1012 assert!(!cfg.send_to_all);
1013 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1014 }
1015
1016 #[test]
1017 fn endpoint_trait_creates_consumer_and_producer() {
1018 let endpoint = WsComponent
1019 .create_endpoint("ws://127.0.0.1:9010/trait")
1020 .unwrap();
1021
1022 endpoint.create_consumer().unwrap();
1023 endpoint
1024 .create_producer(&ProducerContext::default())
1025 .unwrap();
1026 }
1027
1028 #[test]
1029 fn ws_consumer_concurrency_model_uses_max_connections() {
1030 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1031 let consumer = WsConsumer::new(cfg.server_config());
1032 assert_eq!(
1033 consumer.concurrency_model(),
1034 ConcurrencyModel::Concurrent { max: Some(321) }
1035 );
1036 }
1037
1038 #[tokio::test]
1039 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1040 let registry = WsConnectionRegistry::new();
1041 let (tx1, mut rx1) = mpsc::channel(8);
1042 let (tx2, mut rx2) = mpsc::channel(8);
1043
1044 registry.insert("k1".into(), tx1);
1045 registry.insert("k2".into(), tx2);
1046 assert_eq!(registry.len(), 2);
1047
1048 for tx in registry.snapshot_senders() {
1049 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1050 }
1051
1052 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1053 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1054
1055 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1056 assert_eq!(target.len(), 1);
1057 target[0]
1058 .send(WsMessage::Text("targeted".into()))
1059 .await
1060 .unwrap();
1061
1062 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1063 assert!(
1064 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1065 .await
1066 .is_err()
1067 );
1068
1069 registry.remove("k1");
1070 assert_eq!(registry.len(), 1);
1071 }
1072
1073 #[test]
1074 fn host_canonicalization_maps_local_hosts_to_loopback() {
1075 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1076 .unwrap()
1077 .canonical_host();
1078 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1079 .unwrap()
1080 .canonical_host();
1081 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1082 .unwrap()
1083 .canonical_host();
1084
1085 assert_eq!(c1, "127.0.0.1");
1086 assert_eq!(c2, "127.0.0.1");
1087 assert_eq!(c3, "127.0.0.1");
1088 }
1089
1090 #[tokio::test]
1091 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1092 let port = free_port();
1093 let uri = format!("ws://127.0.0.1:{port}/echo");
1094 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1095
1096 let mut consumer = endpoint.create_consumer().unwrap();
1097 let producer = endpoint
1098 .create_producer(&ProducerContext::default())
1099 .unwrap();
1100
1101 let (route_tx, mut route_rx) = mpsc::channel(16);
1102 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1103 consumer.start(ctx).await.unwrap();
1104
1105 let route_task = tokio::spawn(async move {
1106 if let Some(envelope) = route_rx.recv().await {
1107 let payload = envelope
1108 .exchange
1109 .input
1110 .body
1111 .as_text()
1112 .unwrap_or_default()
1113 .to_string();
1114 let key = envelope
1115 .exchange
1116 .input
1117 .header("CamelWsConnectionKey")
1118 .and_then(|v| v.as_str())
1119 .unwrap()
1120 .to_string();
1121
1122 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1123 response
1124 .input
1125 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1126 producer.oneshot(response).await.unwrap();
1127 }
1128 });
1129
1130 let url = format!("ws://127.0.0.1:{port}/echo");
1131 let (mut client, _) = loop {
1132 match connect_async(&url).await {
1133 Ok(ok) => break ok,
1134 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1135 }
1136 };
1137
1138 client
1139 .send(ClientMessage::Text("hello-ws".into()))
1140 .await
1141 .unwrap();
1142
1143 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1144 loop {
1145 match client.next().await {
1146 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1147 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1148 Some(Ok(_)) => continue,
1149 Some(Err(e)) => panic!("ws receive failed: {e}"),
1150 None => panic!("websocket closed before echo"),
1151 }
1152 }
1153 })
1154 .await
1155 .unwrap();
1156
1157 assert_eq!(incoming, "hello-ws");
1158
1159 consumer.stop().await.unwrap();
1160 route_task.await.unwrap();
1161 }
1162
1163 #[tokio::test]
1164 async fn consumer_stop_sends_close_1001() {
1165 let port = free_port();
1166 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1167 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1168
1169 let mut consumer = endpoint.create_consumer().unwrap();
1170 let (route_tx, _route_rx) = mpsc::channel(16);
1171 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1172 consumer.start(ctx).await.unwrap();
1173
1174 let url = format!("ws://127.0.0.1:{port}/shutdown");
1175 let (mut client, _) = loop {
1176 match connect_async(&url).await {
1177 Ok(ok) => break ok,
1178 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1179 }
1180 };
1181
1182 client
1183 .send(ClientMessage::Text("keepalive".into()))
1184 .await
1185 .unwrap();
1186
1187 consumer.stop().await.unwrap();
1188
1189 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1190 loop {
1191 match client.next().await {
1192 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1193 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1194 Some(Ok(_)) => continue,
1195 Some(Err(e)) => panic!("ws receive failed: {e}"),
1196 None => panic!("websocket closed without close frame"),
1197 }
1198 }
1199 })
1200 .await
1201 .unwrap();
1202
1203 assert_eq!(close_code, Some(CloseCode::Away));
1204 }
1205
1206 #[test]
1207 fn wildcard_origin_allows_anything() {
1208 assert!(is_origin_allowed("*", None));
1209 assert!(is_origin_allowed("*", Some("https://example.com")));
1210 }
1211
1212 #[test]
1213 fn exact_origin_requires_match() {
1214 assert!(is_origin_allowed(
1215 "https://example.com",
1216 Some("https://example.com")
1217 ));
1218 assert!(!is_origin_allowed(
1219 "https://example.com",
1220 Some("https://other.com")
1221 ));
1222 assert!(!is_origin_allowed("https://example.com", None));
1223 }
1224
1225 #[test]
1226 fn endpoint_config_rejects_invalid_scheme() {
1227 let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1228 assert!(result.is_err());
1229 let msg = result.unwrap_err().to_string();
1230 assert!(
1231 msg.contains("Invalid WebSocket scheme"),
1232 "expected scheme error, got: {msg}"
1233 );
1234 }
1235
1236 #[tokio::test]
1237 async fn wss_consumer_start_fails_without_tls_cert() {
1238 let port = free_port();
1239 let endpoint = WssComponent
1240 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"))
1241 .unwrap();
1242 let mut consumer = endpoint.create_consumer().unwrap();
1243 let (tx, _rx) = mpsc::channel(16);
1244 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1245 let result = consumer.start(ctx).await;
1246 assert!(result.is_err());
1247 let msg = result.unwrap_err().to_string();
1248 assert!(
1249 msg.contains("TLS cert path is required"),
1250 "expected TLS cert error, got: {msg}"
1251 );
1252 }
1253
1254 #[tokio::test]
1255 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1256 let port = free_port();
1257 let endpoint = WssComponent
1258 .create_endpoint(&format!(
1259 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1260 ))
1261 .unwrap();
1262 let mut consumer = endpoint.create_consumer().unwrap();
1263 let (tx, _rx) = mpsc::channel(16);
1264 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1265 let result = consumer.start(ctx).await;
1266 assert!(result.is_err());
1267 let msg = result.unwrap_err().to_string();
1268 assert!(
1269 msg.contains("TLS cert file error"),
1270 "expected cert file error, got: {msg}"
1271 );
1272 }
1273
1274 #[tokio::test]
1275 async fn server_registry_returns_same_state_for_same_port() {
1276 let port = free_port();
1277 let state1 = ServerRegistry::global()
1278 .get_or_spawn("127.0.0.1", port, None)
1279 .await
1280 .unwrap();
1281 let state2 = ServerRegistry::global()
1282 .get_or_spawn("127.0.0.1", port, None)
1283 .await
1284 .unwrap();
1285 assert!(
1286 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1287 "expected same dispatch table for same port"
1288 );
1289 }
1290
1291 #[tokio::test]
1292 async fn dispatch_handler_returns_404_for_unregistered_path() {
1293 let port = free_port();
1294 let state = ServerRegistry::global()
1295 .get_or_spawn("127.0.0.1", port, None)
1296 .await
1297 .unwrap();
1298 let app = Router::new().fallback(dispatch_handler).with_state(state);
1299 let response = tokio::time::timeout(
1300 Duration::from_secs(2),
1301 tower::ServiceExt::oneshot(
1302 app,
1303 axum::http::Request::builder()
1304 .uri("/nonexistent")
1305 .body(Body::empty())
1306 .unwrap(),
1307 ),
1308 )
1309 .await
1310 .unwrap()
1311 .unwrap();
1312 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1313 }
1314
1315 #[tokio::test]
1316 async fn client_mode_producer_connects_and_echoes() {
1317 let port = free_port();
1318
1319 let app = Router::new().route(
1320 "/echo",
1321 axum::routing::get(|ws: WebSocketUpgrade| async move {
1322 ws.on_upgrade(|mut socket: WebSocket| async move {
1323 while let Some(Ok(msg)) = socket.recv().await {
1324 match msg {
1325 WsMessage::Text(text) => {
1326 let _ = socket.send(WsMessage::Text(text)).await;
1327 }
1328 WsMessage::Binary(data) => {
1329 let _ = socket.send(WsMessage::Binary(data)).await;
1330 }
1331 WsMessage::Close(_) => break,
1332 _ => {}
1333 }
1334 }
1335 })
1336 }),
1337 );
1338 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
1339 .await
1340 .unwrap();
1341 let server_task = tokio::spawn(async move {
1342 let _ = serve(listener, app).await;
1343 });
1344
1345 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1346 let producer = WsProducer::new(cfg.client_config());
1347
1348 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1349 tokio::time::sleep(Duration::from_millis(25)).await;
1350 let result =
1351 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1352 Ok(Ok(r)) => r,
1353 Ok(Err(_)) => panic!("producer call failed"),
1354 Err(_) => panic!("producer call timed out"),
1355 };
1356
1357 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1358
1359 server_task.abort();
1360 }
1361
1362 #[tokio::test]
1363 async fn max_connections_rejects_with_close_1013() {
1364 let port = free_port();
1365 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1366 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1367 let mut consumer = endpoint.create_consumer().unwrap();
1368 let (route_tx, _route_rx) = mpsc::channel(16);
1369 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1370 consumer.start(ctx).await.unwrap();
1371
1372 let url = format!("ws://127.0.0.1:{port}/limited");
1373 let (_client1, _) = loop {
1374 match connect_async(&url).await {
1375 Ok(ok) => break ok,
1376 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1377 }
1378 };
1379
1380 tokio::time::sleep(Duration::from_millis(100)).await;
1381
1382 let (mut client2, _) = connect_async(&url).await.unwrap();
1383
1384 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1385 loop {
1386 match client2.next().await {
1387 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1388 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1389 Some(Ok(ClientMessage::Text(_))) => continue,
1390 Some(Ok(_)) => continue,
1391 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1392 None => panic!("client2 closed without close frame"),
1393 }
1394 }
1395 })
1396 .await
1397 .unwrap();
1398
1399 assert_eq!(
1400 close_code,
1401 Some(CloseCode::from(1013u16)),
1402 "expected 1013 (Try Again Later) for max connections"
1403 );
1404
1405 consumer.stop().await.unwrap();
1406 }
1407
1408 #[tokio::test]
1409 async fn max_message_size_rejects_with_close_1009() {
1410 let port = free_port();
1411 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1412 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1413 let mut consumer = endpoint.create_consumer().unwrap();
1414 let (route_tx, _route_rx) = mpsc::channel(16);
1415 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1416 consumer.start(ctx).await.unwrap();
1417
1418 let url = format!("ws://127.0.0.1:{port}/sizelimit");
1419 let (mut client, _) = loop {
1420 match connect_async(&url).await {
1421 Ok(ok) => break ok,
1422 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1423 }
1424 };
1425
1426 let oversized = "x".repeat(100);
1427 client
1428 .send(ClientMessage::Text(oversized.into()))
1429 .await
1430 .unwrap();
1431
1432 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1433 loop {
1434 match client.next().await {
1435 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1436 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1437 Some(Ok(_)) => continue,
1438 Some(Err(e)) => panic!("ws receive failed: {e}"),
1439 None => panic!("websocket closed without close frame"),
1440 }
1441 }
1442 })
1443 .await
1444 .unwrap();
1445
1446 assert_eq!(
1447 close_code,
1448 Some(CloseCode::from(1009u16)),
1449 "expected 1009 (Message Too Big) for oversized message"
1450 );
1451
1452 consumer.stop().await.unwrap();
1453 }
1454
1455 #[tokio::test]
1456 async fn origin_rejection_returns_403() {
1457 let port = free_port();
1458 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1459 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1460 let mut consumer = endpoint.create_consumer().unwrap();
1461 let (route_tx, _route_rx) = mpsc::channel(16);
1462 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1463 consumer.start(ctx).await.unwrap();
1464
1465 let state = ServerRegistry::global()
1466 .get_or_spawn("127.0.0.1", port, None)
1467 .await
1468 .unwrap();
1469 let app = Router::new().fallback(dispatch_handler).with_state(state);
1470
1471 let response = tokio::time::timeout(
1472 Duration::from_secs(2),
1473 tower::ServiceExt::oneshot(
1474 app,
1475 axum::http::Request::builder()
1476 .uri("/origintest")
1477 .header("origin", "https://evil.com")
1478 .header("upgrade", "websocket")
1479 .header("connection", "Upgrade")
1480 .header("sec-websocket-version", "13")
1481 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1482 .body(Body::empty())
1483 .unwrap(),
1484 ),
1485 )
1486 .await
1487 .unwrap()
1488 .unwrap();
1489
1490 assert_eq!(
1491 response.status(),
1492 StatusCode::FORBIDDEN,
1493 "expected 403 for disallowed origin"
1494 );
1495
1496 consumer.stop().await.unwrap();
1497 }
1498
1499 #[tokio::test]
1500 async fn broadcast_sends_to_all_connected_clients() {
1501 let port = free_port();
1502 let uri = format!("ws://127.0.0.1:{port}/bc");
1503 let endpoint = WsComponent.create_endpoint(&uri).unwrap();
1504 let mut consumer = endpoint.create_consumer().unwrap();
1505 let producer = endpoint
1506 .create_producer(&ProducerContext::default())
1507 .unwrap();
1508
1509 let (route_tx, _route_rx) = mpsc::channel(16);
1510 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1511 consumer.start(ctx).await.unwrap();
1512
1513 let url = format!("ws://127.0.0.1:{port}/bc");
1514
1515 let (mut client1, _) = loop {
1516 match connect_async(&url).await {
1517 Ok(ok) => break ok,
1518 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1519 }
1520 };
1521
1522 let (mut client2, _) = connect_async(&url).await.unwrap();
1523
1524 tokio::time::sleep(Duration::from_millis(100)).await;
1525
1526 let mut response =
1527 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1528 response
1529 .input
1530 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1531 producer.oneshot(response).await.unwrap();
1532
1533 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1534 loop {
1535 match client1.next().await {
1536 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1537 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1538 _ => panic!("client1 unexpected message or close"),
1539 }
1540 }
1541 })
1542 .await
1543 .unwrap();
1544
1545 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
1546 loop {
1547 match client2.next().await {
1548 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1549 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1550 _ => panic!("client2 unexpected message or close"),
1551 }
1552 }
1553 })
1554 .await
1555 .unwrap();
1556
1557 assert_eq!(recv1, "broadcast-msg");
1558 assert_eq!(recv2, "broadcast-msg");
1559
1560 consumer.stop().await.unwrap();
1561 }
1562
1563 #[tokio::test]
1564 async fn concurrent_get_or_spawn_returns_same_state() {
1565 let port = free_port();
1566 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
1567 Arc::new(std::sync::Mutex::new(Vec::new()));
1568
1569 let mut handles = Vec::new();
1570 for _ in 0..4 {
1571 let results = results.clone();
1572 handles.push(tokio::spawn(async move {
1573 let state = ServerRegistry::global()
1574 .get_or_spawn("127.0.0.1", port, None)
1575 .await
1576 .unwrap();
1577 results.lock().unwrap().push(state);
1578 }));
1579 }
1580
1581 for h in handles {
1582 h.await.unwrap();
1583 }
1584
1585 let states = results.lock().unwrap();
1586 assert_eq!(states.len(), 4);
1587 for i in 1..states.len() {
1588 assert!(
1589 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
1590 "all concurrent callers should get the same dispatch table"
1591 );
1592 }
1593 }
1594}