1pub mod bundle;
7pub mod config;
8
9pub use bundle::WsBundle;
10pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
11
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex, OnceLock};
14
15use async_trait::async_trait;
16use axum::body::Body;
17use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
18use axum::extract::{FromRequest, Request, State};
19use axum::http::{StatusCode, header};
20use axum::response::IntoResponse;
21use axum::{Router, serve};
22use camel_component_api::{
23 Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
24};
25use camel_component_api::{
26 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
27 ProducerContext,
28};
29use dashmap::DashMap;
30use futures::{SinkExt, StreamExt};
31use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio::sync::{OnceCell, RwLock, mpsc};
35use tokio::task::JoinHandle;
36use tokio_tungstenite::tungstenite;
37use tokio_tungstenite::tungstenite::client::IntoClientRequest;
38use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
39use tower::Service;
40
/// Per-path limits and policies consulted while handling WebSocket traffic.
///
/// One value is stored per registered path in `WsAppState::path_configs`.
#[derive(Clone)]
struct WsPathConfig {
    /// Connection threshold: `ws_handler` closes a new connection with code
    /// 1013 when the path's connection registry holds more entries than this.
    max_connections: u32,
    /// Largest accepted text/binary payload length (`len()` of the payload);
    /// larger messages close the connection with code 1009.
    max_message_size: u32,
    /// Period between server-initiated pings; a zero duration disables them.
    heartbeat_interval: std::time::Duration,
    /// Maximum wait for the next inbound frame; a zero duration disables the
    /// idle timeout.
    idle_timeout: std::time::Duration,
    /// Origin policy string passed to `is_origin_allowed` during the upgrade.
    allow_origin: String,
}
49
50impl Default for WsPathConfig {
51 fn default() -> Self {
52 let cfg = WsEndpointConfig::default();
53 Self {
54 max_connections: cfg.max_connections,
55 max_message_size: cfg.max_message_size,
56 heartbeat_interval: cfg.heartbeat_interval,
57 idle_timeout: cfg.idle_timeout,
58 allow_origin: cfg.allow_origin,
59 }
60 }
61}
62
/// Certificate/key locations used to start a TLS (`wss`) server.
///
/// Both paths are handed to `load_tls_config` by `spawn_server`.
#[derive(Clone)]
struct WsTlsConfig {
    /// Path of the certificate file.
    cert_path: String,
    /// Path of the private-key file.
    key_path: String,
}
68
/// Shared map from request path to the channel that carries inbound exchanges
/// to the consumer registered for that path.
type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
70
/// A running server instance as tracked by `ServerRegistry`.
struct ServerHandle {
    /// Shared application state; cloned out to every caller of `get_or_spawn`.
    state: WsAppState,
    /// Whether the server was started with a TLS configuration.
    is_tls: bool,
    /// Handle of the spawned serve task; aborted by `ServerRegistry::release`
    /// when the last reference goes away.
    _task: JoinHandle<()>,
}
76
/// Registry entry for a single port.
struct ServerRegistryInner {
    /// Lazily-initialised server handle, shared so that callers can await
    /// initialisation without holding the registry mutex.
    cell: Arc<OnceCell<ServerHandle>>,
    /// Number of `get_or_spawn` acquisitions that have not been released yet.
    ref_count: usize,
}
81
82pub struct ServerRegistry {
83 inner: Mutex<HashMap<u16, ServerRegistryInner>>,
84}
85
86impl ServerRegistry {
87 pub fn global() -> &'static Self {
88 static REG: OnceLock<ServerRegistry> = OnceLock::new();
89 REG.get_or_init(|| Self {
90 inner: Mutex::new(HashMap::new()),
91 })
92 }
93
94 pub(crate) async fn get_or_spawn(
95 &'static self,
96 host: &str,
97 port: u16,
98 tls_config: Option<WsTlsConfig>,
99 ) -> Result<WsAppState, CamelError> {
100 let wants_tls = tls_config.is_some();
101 let host_owned = host.to_string();
102
103 let (cell, _is_new) = {
104 let mut guard = self.inner.lock().map_err(|_| {
105 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
106 })?;
107 let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
108 cell: Arc::new(OnceCell::new()),
109 ref_count: 0,
110 });
111 entry.ref_count += 1;
112 (entry.cell.clone(), entry.ref_count == 1)
113 };
114
115 let handle = cell
116 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
117 .await?;
118
119 if wants_tls != handle.is_tls {
120 let mut guard = self.inner.lock().map_err(|_| {
122 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
123 })?;
124 if let Some(entry) = guard.get_mut(&port) {
125 entry.ref_count -= 1;
126 if entry.ref_count == 0 {
127 guard.remove(&port);
128 }
129 }
130 return Err(CamelError::EndpointCreationFailed(format!(
131 "Server on port {port} already running with different TLS mode"
132 )));
133 }
134
135 Ok(handle.state.clone())
136 }
137
138 pub(crate) fn release(&self, port: u16) {
141 let mut guard = match self.inner.lock() {
142 Ok(g) => g,
143 Err(_) => return,
144 };
145 if let Some(entry) = guard.get_mut(&port) {
146 entry.ref_count = entry.ref_count.saturating_sub(1);
147 if entry.ref_count == 0 {
148 if let Some(handle) = entry.cell.get() {
150 handle._task.abort();
151 }
152 guard.remove(&port);
153 tracing::info!(port, "WebSocket server registry entry removed");
154 }
155 }
156 }
157}
158
159async fn spawn_server(
160 host: &str,
161 port: u16,
162 tls_config: Option<WsTlsConfig>,
163) -> Result<ServerHandle, CamelError> {
164 let host_owned = host.to_string();
165 let addr = format!("{host}:{port}");
166 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
167 let path_configs = Arc::new(DashMap::new());
168 let server_error = new_atomic_false();
169 let state = WsAppState {
170 dispatch: Arc::clone(&dispatch),
171 path_configs: Arc::clone(&path_configs),
172 server_error: Arc::clone(&server_error),
173 };
174 let app = Router::new()
175 .fallback(dispatch_handler)
176 .with_state(state.clone());
177
178 let (task, is_tls) = if let Some(ref tls) = tls_config {
179 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
180 let parsed_addr = addr.parse().map_err(|e| {
181 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
182 })?;
183 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
184 let error_flag = Arc::clone(&server_error);
185 let task = tokio::spawn(async move {
186 if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
187 .serve(app.into_make_service())
188 .await
189 {
190 tracing::error!(
191 host = host_owned,
192 port = port,
193 error = %e,
194 "WebSocket server terminated with error"
195 );
196 error_flag.store(true, Ordering::Relaxed);
197 }
198 });
199 (task, true)
200 } else {
201 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
202 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
203 })?;
204 let error_flag = Arc::clone(&server_error);
205 let task = tokio::spawn(async move {
206 if let Err(e) = serve(listener, app).await {
207 tracing::error!(
208 host = host_owned,
209 port = port,
210 error = %e,
211 "WebSocket server terminated with error"
212 );
213 error_flag.store(true, Ordering::Relaxed);
214 }
215 });
216 (task, false)
217 };
218
219 tracing::info!(host, port, is_tls, "WebSocket server started");
220
221 Ok(ServerHandle {
222 state,
223 is_tls,
224 _task: task,
225 })
226}
227
/// State shared by all request handlers of one server instance.
#[derive(Clone)]
struct WsAppState {
    /// Path -> channel of the consumer registered for that path.
    dispatch: DispatchTable,
    /// Path -> limits/policies registered by the consumer for that path.
    path_configs: Arc<DashMap<String, WsPathConfig>>,
    /// Set to `true` by the serve task when it terminates with an error;
    /// read by `WsConsumer::stop`.
    server_error: Arc<AtomicBool>,
}
234
/// Set of live server-side connections for one endpoint.
pub struct WsConnectionRegistry {
    /// Connection key -> sender feeding that connection's writer task.
    connections: DashMap<String, mpsc::Sender<WsMessage>>,
}
238
239static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
240 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
241> = OnceLock::new();
242
243fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
244 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
245}
246
247impl Default for WsConnectionRegistry {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
253impl WsConnectionRegistry {
254 pub fn new() -> Self {
255 Self {
256 connections: DashMap::new(),
257 }
258 }
259
260 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
261 self.connections.insert(key, tx);
262 }
263
264 pub fn remove(&self, key: &str) {
265 self.connections.remove(key);
266 }
267
268 pub fn len(&self) -> usize {
269 self.connections.len()
270 }
271
272 pub fn is_empty(&self) -> bool {
273 self.connections.is_empty()
274 }
275
276 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
277 self.connections.iter().map(|e| e.value().clone()).collect()
278 }
279
280 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
281 keys.iter()
282 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
283 .collect()
284 }
285}
286
287async fn dispatch_handler(
288 State(state): State<WsAppState>,
289 req: Request<Body>,
290) -> impl IntoResponse {
291 let path = req.uri().path().to_string();
292 let origin = req
293 .headers()
294 .get(header::ORIGIN)
295 .and_then(|value| value.to_str().ok())
296 .map(str::to_string);
297 let remote_addr = req
298 .extensions()
299 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
300 .map(|ci| ci.0.to_string())
301 .unwrap_or_default();
302 let table = state.dispatch.read().await;
303 if !table.contains_key(&path) {
304 return (
305 StatusCode::NOT_FOUND,
306 "no ws endpoint registered for this path",
307 )
308 .into_response();
309 }
310 drop(table);
311
312 let path_config = state
313 .path_configs
314 .get(&path)
315 .map(|entry| entry.value().clone())
316 .unwrap_or_default();
317 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
318 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
319 }
320
321 let upgrade_headers: HashMap<String, String> = req
322 .headers()
323 .iter()
324 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
325 .collect();
326
327 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
328 Ok(ws) => ws,
329 Err(_) => {
330 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
331 }
332 };
333
334 ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
335 .into_response()
336}
337
338#[allow(unused_variables)]
339async fn ws_handler(
340 socket: WebSocket,
341 state: WsAppState,
342 path: String,
343 remote_addr: String,
344 upgrade_headers: HashMap<String, String>,
345) {
346 let connection_key = uuid::Uuid::new_v4().to_string();
347 let path_config = state
348 .path_configs
349 .get(&path)
350 .map(|entry| entry.value().clone())
351 .unwrap_or_default();
352
353 let env_tx = {
354 let table = state.dispatch.read().await;
355 table.get(&path).cloned()
356 };
357 let Some(env_tx) = env_tx else {
358 return;
359 };
360
361 let (mut sink, mut stream) = socket.split();
362 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
363
364 let registry = global_registries();
365 let mut registry_key = None;
366 for entry in registry.iter() {
367 if entry.key().2 == path {
368 entry.value().insert(connection_key.clone(), out_tx.clone());
369 registry_key = Some(entry.key().clone());
370 break;
371 }
372 }
373
374 let conn_key_for_writer = connection_key.clone();
376 let path_for_writer = path.clone();
377
378 let writer = tokio::spawn(async move {
379 while let Some(msg) = out_rx.recv().await {
380 if let Err(e) = sink.send(msg).await {
381 tracing::warn!(
382 connection_key = conn_key_for_writer,
383 path = path_for_writer,
384 error = %e,
385 "WebSocket writer send error — closing connection"
386 );
387 break;
388 }
389 }
390 });
391
392 tracing::info!(
393 connection_key = connection_key,
394 path = path,
395 remote_addr = remote_addr,
396 "WebSocket connection opened"
397 );
398
399 let mut over_limit = false;
400 if let Some(key) = ®istry_key
401 && let Some(entry) = registry.get(key)
402 && entry.len() > path_config.max_connections as usize
403 {
404 over_limit = true;
405 }
406 if over_limit {
407 try_send_with_backpressure(
408 &out_tx,
409 WsMessage::Close(Some(CloseFrame {
410 code: CloseCode::from(1013u16),
411 reason: "max connections exceeded".into(),
412 })),
413 "max-connections-close",
414 );
415 if let Some(key) = registry_key.clone()
416 && let Some(entry) = registry.get(&key)
417 {
418 entry.remove(&connection_key);
419 }
420 drop(out_tx);
421 let _ = writer.await;
422 return;
423 }
424
425 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
426 let ping_tx = out_tx.clone();
427 let interval = path_config.heartbeat_interval;
428 Some(tokio::spawn(async move {
429 let mut ticker = tokio::time::interval(interval);
430 loop {
431 ticker.tick().await;
432 let _ = try_send_with_backpressure(
433 &ping_tx,
434 WsMessage::Ping(Vec::new().into()),
435 "heartbeat-ping",
436 );
437 }
438 }))
439 } else {
440 None
441 };
442
443 loop {
444 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
445 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
446 Ok(msg) => msg,
447 Err(_) => {
448 try_send_with_backpressure(
449 &out_tx,
450 WsMessage::Close(Some(CloseFrame {
451 code: CloseCode::from(1000u16),
452 reason: "idle timeout".into(),
453 })),
454 "idle-timeout-close",
455 );
456 break;
457 }
458 }
459 } else {
460 stream.next().await
461 };
462
463 let Some(msg) = next_msg else {
464 break;
465 };
466
467 match msg {
468 Ok(WsMessage::Ping(data)) => {
469 tracing::debug!(
470 connection_key = connection_key,
471 path = path,
472 "WebSocket ping received — sending pong"
473 );
474 let _ = try_send_with_backpressure(
475 &out_tx,
476 WsMessage::Pong(data),
477 "ping-pong-response",
478 );
479 }
480 Ok(WsMessage::Pong(_)) => {
481 tracing::debug!(
482 connection_key = connection_key,
483 path = path,
484 "WebSocket pong received"
485 );
486 }
487 Ok(WsMessage::Text(text)) => {
488 if text.len() > path_config.max_message_size as usize {
489 try_send_with_backpressure(
490 &out_tx,
491 WsMessage::Close(Some(CloseFrame {
492 code: CloseCode::from(1009u16),
493 reason: "message too large".into(),
494 })),
495 "max-message-size-close-text",
496 );
497 break;
498 }
499
500 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
501 message.set_header(
502 "CamelWsConnectionKey",
503 serde_json::Value::String(connection_key.clone()),
504 );
505 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
506 message.set_header(
507 "CamelWsRemoteAddress",
508 serde_json::Value::String(remote_addr.clone()),
509 );
510
511 #[allow(unused_mut)]
512 let mut exchange = Exchange::new(message);
513 #[cfg(feature = "otel")]
514 {
515 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
516 }
517 if env_tx
518 .send(ExchangeEnvelope {
519 exchange,
520 reply_tx: None,
521 })
522 .await
523 .is_err()
524 {
525 break;
526 }
527 }
528 Ok(WsMessage::Binary(data)) => {
529 if data.len() > path_config.max_message_size as usize {
530 try_send_with_backpressure(
531 &out_tx,
532 WsMessage::Close(Some(CloseFrame {
533 code: CloseCode::from(1009u16),
534 reason: "message too large".into(),
535 })),
536 "max-message-size-close-binary",
537 );
538 break;
539 }
540
541 let mut message = CamelMessage::new(CamelBody::Bytes(data));
542 message.set_header(
543 "CamelWsConnectionKey",
544 serde_json::Value::String(connection_key.clone()),
545 );
546 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
547 message.set_header(
548 "CamelWsRemoteAddress",
549 serde_json::Value::String(remote_addr.clone()),
550 );
551
552 #[allow(unused_mut)]
553 let mut exchange = Exchange::new(message);
554 #[cfg(feature = "otel")]
555 {
556 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
557 }
558 if env_tx
559 .send(ExchangeEnvelope {
560 exchange,
561 reply_tx: None,
562 })
563 .await
564 .is_err()
565 {
566 break;
567 }
568 }
569 Ok(WsMessage::Close(cf)) => {
570 let reason = cf
571 .as_ref()
572 .map(|f| f.reason.to_string())
573 .unwrap_or_default();
574 tracing::info!(
575 connection_key = connection_key,
576 path = path,
577 reason = reason,
578 "WebSocket connection closed by peer"
579 );
580 break;
581 }
582 Err(e) => {
583 tracing::warn!(
584 connection_key = connection_key,
585 path = path,
586 error = %e,
587 "WebSocket receive error"
588 );
589 break;
590 }
591 }
592 }
593
594 if let Some(task) = heartbeat_task {
595 task.abort();
596 }
597
598 if let Some(key) = registry_key
599 && let Some(entry) = registry.get(&key)
600 {
601 entry.remove(&connection_key);
602 }
603 drop(out_tx);
604 let _ = writer.await;
605
606 tracing::info!(
607 connection_key = connection_key,
608 path = path,
609 "WebSocket connection closed"
610 );
611}
612
613pub struct WsComponent {
614 pub(crate) config: WsConfig,
615}
616
617impl WsComponent {
618 pub fn new() -> Self {
619 Self {
620 config: WsConfig::default(),
621 }
622 }
623
624 pub fn with_config(config: WsConfig) -> Self {
625 Self { config }
626 }
627}
628
629impl Default for WsComponent {
630 fn default() -> Self {
631 Self::new()
632 }
633}
634
635impl Component for WsComponent {
636 fn scheme(&self) -> &str {
637 "ws"
638 }
639
640 fn create_endpoint(
641 &self,
642 uri: &str,
643 _ctx: &dyn camel_component_api::ComponentContext,
644 ) -> Result<Box<dyn Endpoint>, CamelError> {
645 self.config.validate()?;
646 let mut cfg = WsEndpointConfig::from_uri(uri)?;
647 if let Some(v) = self.config.max_connections {
648 cfg.max_connections = v;
649 }
650 if let Some(v) = self.config.max_message_size {
651 cfg.max_message_size = v;
652 }
653 if let Some(v) = self.config.heartbeat_interval_ms {
654 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
655 }
656 if let Some(v) = self.config.idle_timeout_ms {
657 cfg.idle_timeout = std::time::Duration::from_millis(v);
658 }
659 if let Some(v) = self.config.connect_timeout_ms {
660 cfg.connect_timeout = std::time::Duration::from_millis(v);
661 }
662 if let Some(v) = self.config.response_timeout_ms {
663 cfg.response_timeout = std::time::Duration::from_millis(v);
664 }
665 if let Some(v) = self.config.send_timeout_ms {
666 cfg.send_timeout = std::time::Duration::from_millis(v);
667 }
668 if let Some(v) = self.config.binary_payload {
669 cfg.binary_payload = v;
670 }
671 if let Some(ref v) = self.config.subprotocols {
672 cfg.subprotocols = v.clone();
673 }
674 Ok(Box::new(WsEndpoint {
675 uri: uri.to_string(),
676 cfg,
677 }))
678 }
679}
680
681pub struct WssComponent {
682 pub(crate) config: WsConfig,
683}
684
685impl WssComponent {
686 pub fn new() -> Self {
687 Self {
688 config: WsConfig::default(),
689 }
690 }
691
692 pub fn with_config(config: WsConfig) -> Self {
693 Self { config }
694 }
695}
696
697impl Default for WssComponent {
698 fn default() -> Self {
699 Self::new()
700 }
701}
702
703impl Component for WssComponent {
704 fn scheme(&self) -> &str {
705 "wss"
706 }
707
708 fn create_endpoint(
709 &self,
710 uri: &str,
711 _ctx: &dyn camel_component_api::ComponentContext,
712 ) -> Result<Box<dyn Endpoint>, CamelError> {
713 self.config.validate()?;
714 let mut cfg = WsEndpointConfig::from_uri(uri)?;
715 if let Some(v) = self.config.max_connections {
716 cfg.max_connections = v;
717 }
718 if let Some(v) = self.config.max_message_size {
719 cfg.max_message_size = v;
720 }
721 if let Some(v) = self.config.heartbeat_interval_ms {
722 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
723 }
724 if let Some(v) = self.config.idle_timeout_ms {
725 cfg.idle_timeout = std::time::Duration::from_millis(v);
726 }
727 if let Some(v) = self.config.connect_timeout_ms {
728 cfg.connect_timeout = std::time::Duration::from_millis(v);
729 }
730 if let Some(v) = self.config.response_timeout_ms {
731 cfg.response_timeout = std::time::Duration::from_millis(v);
732 }
733 if let Some(v) = self.config.send_timeout_ms {
734 cfg.send_timeout = std::time::Duration::from_millis(v);
735 }
736 if let Some(v) = self.config.binary_payload {
737 cfg.binary_payload = v;
738 }
739 if let Some(ref v) = self.config.subprotocols {
740 cfg.subprotocols = v.clone();
741 }
742 Ok(Box::new(WsEndpoint {
743 uri: uri.to_string(),
744 cfg,
745 }))
746 }
747}
748
749struct WsEndpoint {
750 uri: String,
751 cfg: WsEndpointConfig,
752}
753
754impl Endpoint for WsEndpoint {
755 fn uri(&self) -> &str {
756 &self.uri
757 }
758
759 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
760 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
761 }
762
763 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
764 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
765 }
766}
767
768pub struct WsConsumer {
769 cfg: WsServerConfig,
770 registry: Arc<WsConnectionRegistry>,
771 server_state: Option<WsAppState>,
772 registry_key: Option<(String, u16, String)>,
773 forward_task: Option<JoinHandle<()>>,
774}
775
776impl WsConsumer {
777 pub fn new(cfg: WsServerConfig) -> Self {
778 Self {
779 cfg,
780 registry: Arc::new(WsConnectionRegistry::new()),
781 server_state: None,
782 registry_key: None,
783 forward_task: None,
784 }
785 }
786}
787
788#[async_trait]
789impl Consumer for WsConsumer {
790 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
791 if self.server_state.is_some() {
793 return Err(CamelError::EndpointCreationFailed(
794 "WebSocket consumer already started".into(),
795 ));
796 }
797
798 tracing::info!(
799 host = self.cfg.inner.host,
800 port = self.cfg.inner.port,
801 path = self.cfg.inner.path,
802 scheme = self.cfg.inner.scheme,
803 "WebSocket consumer starting"
804 );
805
806 let tls_config = if self.cfg.inner.scheme == "wss" {
807 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
808 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
809 })?;
810 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
811 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
812 })?;
813 Some(WsTlsConfig {
814 cert_path,
815 key_path,
816 })
817 } else {
818 None
819 };
820
821 let state = ServerRegistry::global()
822 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
823 .await?;
824
825 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
826 {
827 let mut table = state.dispatch.write().await;
828 table.insert(self.cfg.inner.path.clone(), env_tx);
829 }
830
831 state.path_configs.insert(
832 self.cfg.inner.path.clone(),
833 WsPathConfig {
834 max_connections: self.cfg.inner.max_connections,
835 max_message_size: self.cfg.inner.max_message_size,
836 heartbeat_interval: self.cfg.inner.heartbeat_interval,
837 idle_timeout: self.cfg.inner.idle_timeout,
838 allow_origin: self.cfg.inner.allow_origin.clone(),
839 },
840 );
841
842 let registry_key = (
843 self.cfg.inner.canonical_host(),
844 self.cfg.inner.port,
845 self.cfg.inner.path.clone(),
846 );
847 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
848
849 let sender = ctx.sender();
850 let forward_task = tokio::spawn(async move {
851 while let Some(envelope) = env_rx.recv().await {
852 if sender.send(envelope).await.is_err() {
853 break;
854 }
855 }
856 });
857
858 self.server_state = Some(state);
859 self.registry_key = Some(registry_key);
860 self.forward_task = Some(forward_task);
861 Ok(())
862 }
863
864 async fn stop(&mut self) -> Result<(), CamelError> {
865 tracing::info!(
866 host = self.cfg.inner.host,
867 port = self.cfg.inner.port,
868 path = self.cfg.inner.path,
869 "WebSocket consumer stopping"
870 );
871
872 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
873 code: axum::extract::ws::CloseCode::from(1001u16),
874 reason: "consumer stopping".into(),
875 }));
876 for tx in self.registry.snapshot_senders() {
877 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
878 }
879
880 let mut had_server_error = false;
881
882 if let Some(state) = self.server_state.take() {
883 had_server_error = state.server_error.load(Ordering::Relaxed);
884 let mut table = state.dispatch.write().await;
885 table.remove(&self.cfg.inner.path);
886 state.path_configs.remove(&self.cfg.inner.path);
887 }
888
889 if let Some(key) = self.registry_key.take() {
890 global_registries().remove(&key);
891 ServerRegistry::global().release(key.1);
892 }
893
894 if let Some(task) = self.forward_task.take() {
895 task.abort();
896 }
897
898 tracing::info!(
899 host = self.cfg.inner.host,
900 port = self.cfg.inner.port,
901 path = self.cfg.inner.path,
902 "WebSocket consumer stopped"
903 );
904
905 if had_server_error {
906 tracing::warn!(
907 host = self.cfg.inner.host,
908 port = self.cfg.inner.port,
909 path = self.cfg.inner.path,
910 "WebSocket server had errors during its lifetime"
911 );
912 return Err(CamelError::ProcessorError(
913 "WebSocket server terminated with errors during its lifetime".into(),
914 ));
915 }
916
917 Ok(())
918 }
919
920 fn concurrency_model(&self) -> ConcurrencyModel {
921 ConcurrencyModel::Concurrent {
922 max: Some(self.cfg.inner.max_connections as usize),
923 }
924 }
925}
926
927use std::sync::atomic::{AtomicBool, Ordering};
928
/// Returns a fresh shared flag initialised to `false`.
fn new_atomic_false() -> Arc<AtomicBool> {
    // `AtomicBool::default()` is `false`.
    let flag = AtomicBool::default();
    Arc::new(flag)
}
932
933#[derive(Clone)]
934pub struct WsProducer {
935 cfg: WsClientConfig,
936 backpressure_flag: Arc<AtomicBool>,
939}
940
941impl WsProducer {
942 pub fn new(cfg: WsClientConfig) -> Self {
943 Self {
944 cfg,
945 backpressure_flag: Arc::new(AtomicBool::new(false)),
946 }
947 }
948}
949
/// `tower::Service` implementation that sends the exchange body over
/// WebSocket.
///
/// `call` works in one of two modes:
/// * **server-send** — when the exchange asks for it (`CamelWsSendToAll` /
///   `sendToAll` / `CamelWsConnectionKey` headers) or a local consumer is
///   registered for the same host/port/path, the body is pushed to the
///   selected locally registered connections;
/// * **client** — otherwise a client connection is opened to the endpoint,
///   the body is sent, and the first non-ping/pong frame received becomes the
///   new exchange body.
impl Service<Exchange> for WsProducer {
    type Response = Exchange;
    type Error = CamelError;
    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

    /// Always ready, but reports an error once after a previous send dropped
    /// messages; `swap(false)` clears the flag as it is read.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
        if self.backpressure_flag.swap(false, Ordering::Relaxed) {
            return Poll::Ready(Err(CamelError::ProcessorError(
                "WebSocket producer backpressure: previous send was dropped due to full channel"
                    .into(),
            )));
        }
        Poll::Ready(Ok(()))
    }

    /// Sends the exchange body; see the impl-level docs for the two modes.
    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
        // The future must be self-contained, so clone what it needs.
        let cfg = self.cfg.clone();
        let backpressure_flag = Arc::clone(&self.backpressure_flag);

        Box::pin(async move {
            let canonical_host = cfg.inner.canonical_host();
            // Same key shape the consumer uses to publish its registry.
            let key = (
                canonical_host.clone(),
                cfg.inner.port,
                cfg.inner.path.clone(),
            );

            // `CamelWsSendToAll` takes precedence over the `sendToAll` alias.
            let send_to_all = exchange
                .input
                .header("CamelWsSendToAll")
                .and_then(|v| v.as_bool())
                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
                .unwrap_or(false);

            // Comma-separated list of target connection keys, if present.
            let conn_keys_header = exchange
                .input
                .header("CamelWsConnectionKey")
                .and_then(|v| v.as_str())
                .map(str::to_string);

            let local_exists = global_registries().contains_key(&key);
            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;

            // Defaults to "text"; compared case-insensitively.
            let message_type = exchange
                .input
                .header("CamelWsMessageType")
                .and_then(|v| v.as_str())
                .unwrap_or("text")
                .to_ascii_lowercase();

            if server_send_mode {
                // Server-send was requested but no consumer is registered
                // locally for this endpoint: nothing to send to.
                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
                let Some(registry) = registry else {
                    return Err(CamelError::ProcessorError(format!(
                        "WebSocket local consumer not found for {}:{}{}",
                        canonical_host, cfg.inner.port, cfg.inner.path
                    )));
                };

                // Takes the body out of the exchange, leaving its default.
                let out_msg = body_to_axum_ws_message(
                    std::mem::take(&mut exchange.input.body),
                    &message_type,
                )
                .await?;

                // Explicit broadcast, explicit key list, or (neither header
                // set) every connection of the local consumer.
                let targets = if send_to_all {
                    registry.snapshot_senders()
                } else if let Some(keys) = conn_keys_header {
                    let parsed: Vec<String> = keys
                        .split(',')
                        .map(str::trim)
                        .filter(|k| !k.is_empty())
                        .map(|k| k.to_string())
                        .collect();
                    registry.get_senders_for_keys(&parsed)
                } else {
                    registry.snapshot_senders()
                };

                // Count the sends for which the helper reported failure.
                let mut dropped = 0usize;
                for tx in &targets {
                    if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
                        dropped += 1;
                    }
                }

                if dropped > 0 {
                    tracing::warn!(
                        host = canonical_host,
                        port = cfg.inner.port,
                        path = cfg.inner.path,
                        dropped,
                        total = targets.len(),
                        "WebSocket producer dropped messages due to backpressure"
                    );
                    exchange.input.set_header(
                        "CamelWsDeliveryDropped",
                        serde_json::Value::Number(dropped.into()),
                    );
                    // Makes the next `poll_ready` report the backpressure.
                    backpressure_flag.store(true, Ordering::Relaxed);
                    // Partial delivery is a success; total loss is an error.
                    if dropped == targets.len() {
                        return Err(CamelError::ProcessorError(format!(
                            "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
                        )));
                    }
                }

                tracing::debug!(
                    host = canonical_host,
                    port = cfg.inner.port,
                    path = cfg.inner.path,
                    targets = targets.len(),
                    "WebSocket producer server-send complete"
                );

                return Ok(exchange);
            }

            // ---- Client mode ----
            let url = format!(
                "{}://{}:{}{}",
                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
            );

            tracing::debug!(url = url, "WebSocket producer connecting");

            // Only mutated when `otel` is enabled or subprotocols are set.
            #[allow(unused_mut)]
            let mut request = url
                .clone()
                .into_client_request()
                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;

            // Propagate trace context as handshake headers; entries that are
            // not valid header names/values are skipped.
            #[cfg(feature = "otel")]
            {
                let mut otel_headers = HashMap::new();
                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
                for (k, v) in otel_headers {
                    if let (Ok(name), Ok(val)) = (
                        http::header::HeaderName::from_bytes(k.as_bytes()),
                        http::header::HeaderValue::from_str(&v),
                    ) {
                        request.headers_mut().insert(name, val);
                    }
                }
            }

            if !cfg.inner.subprotocols.is_empty() {
                // All configured subprotocols go into one comma-separated header.
                let proto_value = cfg.inner.subprotocols.join(", ");
                if let (Ok(name), Ok(val)) = (
                    http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
                    http::header::HeaderValue::from_str(&proto_value),
                ) {
                    request.headers_mut().insert(name, val);
                }
            }

            // The endpoint-level `binary_payload` flag overrides the header.
            let effective_message_type = if cfg.inner.binary_payload {
                "binary"
            } else {
                &message_type
            };

            // Retries are only attempted when reconnect is enabled.
            let max_retries = if cfg.inner.reconnect {
                cfg.inner.reconnect_max_attempts as usize
            } else {
                0
            };
            let mut attempts = 0usize;
            // Connect, retrying after `reconnect_delay_ms` up to `max_retries`
            // times on timeouts and on errors classified as transient.
            let mut ws_stream = loop {
                let connect_future = tokio_tungstenite::connect_async(request.clone());
                match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
                    Ok(Ok((stream, _))) => break stream,
                    Ok(Err(e)) => {
                        let err = map_connect_error(e, &url);
                        // Transience is decided from the mapped error's text.
                        let is_transient = err.to_string().contains("connection refused")
                            || err.to_string().contains("timeout")
                            || err.to_string().contains("connection failed");
                        if is_transient && attempts < max_retries {
                            attempts += 1;
                            tracing::warn!(
                                url = url,
                                error = %err,
                                attempt = attempts,
                                max_retries,
                                "WebSocket connect failed — retrying"
                            );
                            tokio::time::sleep(std::time::Duration::from_millis(
                                cfg.inner.reconnect_delay_ms,
                            ))
                            .await;
                            continue;
                        }
                        return Err(err);
                    }
                    Err(_) => {
                        // The connect attempt exceeded `connect_timeout`.
                        let err = CamelError::ProcessorError(format!(
                            "WebSocket connect timeout ({:?}) to {url}",
                            cfg.inner.connect_timeout
                        ));
                        if attempts < max_retries {
                            attempts += 1;
                            tracing::warn!(
                                url = url,
                                attempt = attempts,
                                max_retries,
                                "WebSocket connect timeout — retrying"
                            );
                            tokio::time::sleep(std::time::Duration::from_millis(
                                cfg.inner.reconnect_delay_ms,
                            ))
                            .await;
                            continue;
                        }
                        return Err(err);
                    }
                }
            };
            if attempts > 0 {
                tracing::info!(url = url, "WebSocket producer connected after retry");
            }

            // Takes the body out of the exchange, leaving its default.
            let out_msg = body_to_client_ws_message(
                std::mem::take(&mut exchange.input.body),
                effective_message_type,
            )
            .await?;

            ws_stream
                .send(out_msg)
                .await
                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;

            // Wait (bounded by `response_timeout`) for the first item that is
            // not a ping/pong; that includes stream end (`None`) and errors.
            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
                loop {
                    match ws_stream.next().await {
                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
                            continue;
                        }
                        other => break other,
                    }
                }
            })
            .await
            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;

            match incoming {
                Some(Ok(ClientWsMessage::Text(text))) => {
                    tracing::debug!(url = url, "WebSocket producer received text response");
                    exchange.input.body = CamelBody::Text(text.to_string());
                }
                Some(Ok(ClientWsMessage::Binary(data))) => {
                    tracing::debug!(url = url, "WebSocket producer received binary response");
                    exchange.input.body = CamelBody::Bytes(data);
                }
                Some(Ok(ClientWsMessage::Close(frame))) => {
                    // A close without a frame, or with code Normal/Away, is
                    // treated as a clean shutdown.
                    let normal = frame
                        .as_ref()
                        .map(|f| {
                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
                        })
                        .unwrap_or(true);

                    if normal {
                        tracing::debug!(url = url, "WebSocket producer received normal close");
                        exchange.input.body = CamelBody::Empty;
                    } else if cfg.inner.reconnect && attempts < max_retries {
                        // NOTE(review): this branch waits for the reconnect
                        // delay and then returns an error; no reconnect is
                        // attempted here — presumably the caller retries.
                        attempts += 1;
                        tracing::warn!(
                            url = url,
                            attempt = attempts,
                            max_retries,
                            "WebSocket closed by peer — reconnecting"
                        );
                        tokio::time::sleep(std::time::Duration::from_millis(
                            cfg.inner.reconnect_delay_ms,
                        ))
                        .await;
                        return Err(CamelError::ProcessorError(format!(
                            "WebSocket reconnect required after close: code {}",
                            frame.map(|f| u16::from(f.code)).unwrap_or_default()
                        )));
                    } else {
                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
                        return Err(CamelError::ProcessorError(format!(
                            "WebSocket peer closed: code {code}"
                        )));
                    }
                }
                // Any other frame kind, or end of stream: no response body.
                Some(Ok(_)) | None => {
                    exchange.input.body = CamelBody::Empty;
                }
                Some(Err(e)) => {
                    return Err(CamelError::ProcessorError(format!(
                        "WebSocket receive failed: {e}"
                    )));
                }
            }

            // Best-effort close; a failure here does not fail the exchange.
            let _ = ws_stream.close(None).await;
            tracing::debug!(url = url, "WebSocket producer connection closed");
            Ok(exchange)
        })
    }
}
1258
1259async fn body_to_axum_ws_message(
1260 body: CamelBody,
1261 message_type: &str,
1262) -> Result<WsMessage, CamelError> {
1263 match message_type {
1264 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1265 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1266 }
1267}
1268
1269async fn body_to_client_ws_message(
1270 body: CamelBody,
1271 message_type: &str,
1272) -> Result<ClientWsMessage, CamelError> {
1273 match message_type {
1274 "binary" => Ok(ClientWsMessage::Binary(
1275 body.into_bytes(10 * 1024 * 1024).await?,
1276 )),
1277 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1278 }
1279}
1280
1281async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1282 Ok(match body {
1283 CamelBody::Empty => String::new(),
1284 CamelBody::Text(s) => s,
1285 CamelBody::Xml(s) => s,
1286 CamelBody::Json(v) => v.to_string(),
1287 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1288 CamelBody::Stream(stream) => {
1289 let bytes = CamelBody::Stream(stream)
1290 .into_bytes(10 * 1024 * 1024)
1291 .await?;
1292 String::from_utf8_lossy(&bytes).to_string()
1293 }
1294 })
1295}
1296
/// Decides whether a request's `Origin` is acceptable for an endpoint.
///
/// * `allowed_origin == "*"` accepts every request, including requests that
///   carry no origin at all.
/// * Otherwise the request must carry an origin equal to `allowed_origin`.
///   The comparison ignores ASCII case: the scheme and host of an origin are
///   case-insensitive, and browsers send the origin lowercased, so a
///   configured value such as `https://Example.com` must still match.
///
/// Returns `false` when a specific origin is required and the request has none.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    if allowed_origin == "*" {
        return true;
    }
    request_origin.is_some_and(|origin| origin.eq_ignore_ascii_case(allowed_origin))
}
1303
1304fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1305 match tx.try_send(msg) {
1306 Ok(()) => true,
1307 Err(error) => {
1308 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1309 false
1310 }
1311 }
1312}
1313
1314fn load_tls_config(
1315 cert_path: &str,
1316 key_path: &str,
1317) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1318 use std::fs::File;
1319 use std::io::BufReader;
1320
1321 let cert_file = File::open(cert_path)
1322 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1323 let key_file = File::open(key_path)
1324 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1325
1326 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1327 .collect::<Result<Vec<_>, _>>()
1328 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1329
1330 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1331 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1332 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1333
1334 tokio_rustls::rustls::ServerConfig::builder()
1335 .with_no_client_auth()
1336 .with_single_cert(certs, key)
1337 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1338}
1339
1340fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1341 match err {
1342 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1343 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1344 }
1345 tungstenite::Error::Tls(_) => {
1346 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1347 }
1348 other => {
1349 let msg = other.to_string();
1350 if msg.to_lowercase().contains("connection refused") {
1351 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1352 } else if msg.to_lowercase().contains("tls") {
1353 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1354 } else {
1355 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1356 }
1357 }
1358 }
1359}
1360
1361#[cfg(test)]
1362mod tests {
1363 use super::*;
1364 use camel_component_api::NoOpComponentContext;
1365 use std::time::Duration;
1366
1367 use tokio::sync::mpsc;
1368 use tokio_tungstenite::connect_async;
1369 use tokio_tungstenite::tungstenite::Message as ClientMessage;
1370 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1371 use tokio_util::sync::CancellationToken;
1372 use tower::ServiceExt;
1373
1374 fn free_port() -> u16 {
1375 std::net::TcpListener::bind("127.0.0.1:0")
1376 .unwrap()
1377 .local_addr()
1378 .unwrap()
1379 .port()
1380 }
1381
1382 #[test]
1383 fn ws_component_scheme_is_ws() {
1384 assert_eq!(WsComponent::new().scheme(), "ws");
1385 }
1386
1387 #[test]
1388 fn wss_component_scheme_is_wss() {
1389 assert_eq!(WssComponent::new().scheme(), "wss");
1390 }
1391
1392 #[test]
1393 fn endpoint_config_defaults_match_spec() {
1394 let cfg = WsEndpointConfig::default();
1395 assert_eq!(cfg.scheme, "ws");
1396 assert_eq!(cfg.host, "0.0.0.0");
1397 assert_eq!(cfg.port, 8080);
1398 assert_eq!(cfg.path, "/");
1399 assert_eq!(cfg.max_connections, 100);
1400 assert_eq!(cfg.max_message_size, 65536);
1401 assert!(!cfg.send_to_all);
1402 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1403 assert_eq!(cfg.idle_timeout, Duration::ZERO);
1404 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1405 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1406 assert_eq!(cfg.allow_origin, "*");
1407 assert_eq!(cfg.tls_cert, None);
1408 assert_eq!(cfg.tls_key, None);
1409 assert!(cfg.reconnect);
1410 assert_eq!(cfg.reconnect_max_attempts, 5);
1411 assert_eq!(cfg.reconnect_delay_ms, 1000);
1412 assert_eq!(cfg.send_timeout, Duration::from_secs(30));
1413 assert!(!cfg.binary_payload);
1414 assert!(cfg.subprotocols.is_empty());
1415 }
1416
1417 #[test]
1418 fn endpoint_config_parses_uri_params() {
1419 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1420 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1421
1422 assert_eq!(cfg.scheme, "ws");
1423 assert_eq!(cfg.host, "localhost");
1424 assert_eq!(cfg.port, 9001);
1425 assert_eq!(cfg.path, "/chat");
1426 assert_eq!(cfg.max_connections, 42);
1427 assert_eq!(cfg.max_message_size, 1024);
1428 assert!(cfg.send_to_all);
1429 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1430 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1431 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1432 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1433 assert_eq!(cfg.allow_origin, "https://example.com");
1434 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1435 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1436 assert!(cfg.reconnect);
1437 assert_eq!(cfg.reconnect_max_attempts, 5);
1438 assert_eq!(cfg.reconnect_delay_ms, 1000);
1439 }
1440
1441 #[test]
1442 fn endpoint_config_parses_reconnect_uri_params() {
1443 let uri =
1444 "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
1445 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1446 assert!(!cfg.reconnect);
1447 assert_eq!(cfg.reconnect_max_attempts, 2);
1448 assert_eq!(cfg.reconnect_delay_ms, 25);
1449 }
1450
1451 #[test]
1452 fn endpoint_config_override_chain_uri_overrides_defaults() {
1453 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1454 assert_eq!(cfg.max_connections, 7);
1455 assert_eq!(cfg.max_message_size, 65536);
1456 assert!(!cfg.send_to_all);
1457 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1458 }
1459
1460 #[test]
1461 fn endpoint_trait_creates_consumer_and_producer() {
1462 let ctx = NoOpComponentContext;
1463 let endpoint = WsComponent::new()
1464 .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1465 .unwrap();
1466
1467 endpoint.create_consumer().unwrap();
1468 endpoint
1469 .create_producer(&ProducerContext::default())
1470 .unwrap();
1471 }
1472
1473 #[test]
1474 fn ws_consumer_concurrency_model_uses_max_connections() {
1475 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1476 let consumer = WsConsumer::new(cfg.server_config());
1477 assert_eq!(
1478 consumer.concurrency_model(),
1479 ConcurrencyModel::Concurrent { max: Some(321) }
1480 );
1481 }
1482
1483 #[tokio::test]
1484 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1485 let registry = WsConnectionRegistry::new();
1486 let (tx1, mut rx1) = mpsc::channel(8);
1487 let (tx2, mut rx2) = mpsc::channel(8);
1488
1489 registry.insert("k1".into(), tx1);
1490 registry.insert("k2".into(), tx2);
1491 assert_eq!(registry.len(), 2);
1492
1493 for tx in registry.snapshot_senders() {
1494 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1495 }
1496
1497 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1498 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1499
1500 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1501 assert_eq!(target.len(), 1);
1502 target[0]
1503 .send(WsMessage::Text("targeted".into()))
1504 .await
1505 .unwrap();
1506
1507 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1508 assert!(
1509 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1510 .await
1511 .is_err()
1512 );
1513
1514 registry.remove("k1");
1515 assert_eq!(registry.len(), 1);
1516 }
1517
1518 #[test]
1519 fn host_canonicalization_maps_local_hosts_to_loopback() {
1520 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1521 .unwrap()
1522 .canonical_host();
1523 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1524 .unwrap()
1525 .canonical_host();
1526 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1527 .unwrap()
1528 .canonical_host();
1529
1530 assert_eq!(c1, "127.0.0.1");
1531 assert_eq!(c2, "127.0.0.1");
1532 assert_eq!(c3, "127.0.0.1");
1533 }
1534
1535 #[tokio::test]
1536 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1537 let port = free_port();
1538 let uri = format!("ws://127.0.0.1:{port}/echo");
1539 let component_ctx = NoOpComponentContext;
1540 let endpoint = WsComponent::new()
1541 .create_endpoint(&uri, &component_ctx)
1542 .unwrap();
1543
1544 let mut consumer = endpoint.create_consumer().unwrap();
1545 let producer = endpoint
1546 .create_producer(&ProducerContext::default())
1547 .unwrap();
1548
1549 let (route_tx, mut route_rx) = mpsc::channel(16);
1550 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1551 consumer.start(ctx).await.unwrap();
1552
1553 let route_task = tokio::spawn(async move {
1554 if let Some(envelope) = route_rx.recv().await {
1555 let payload = envelope
1556 .exchange
1557 .input
1558 .body
1559 .as_text()
1560 .unwrap_or_default()
1561 .to_string();
1562 let key = envelope
1563 .exchange
1564 .input
1565 .header("CamelWsConnectionKey")
1566 .and_then(|v| v.as_str())
1567 .unwrap()
1568 .to_string();
1569
1570 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1571 response
1572 .input
1573 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1574 producer.oneshot(response).await.unwrap();
1575 }
1576 });
1577
1578 let url = format!("ws://127.0.0.1:{port}/echo");
1579 let (mut client, _) = loop {
1580 match connect_async(&url).await {
1581 Ok(ok) => break ok,
1582 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1583 }
1584 };
1585
1586 client
1587 .send(ClientMessage::Text("hello-ws".into()))
1588 .await
1589 .unwrap();
1590
1591 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1592 loop {
1593 match client.next().await {
1594 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1595 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1596 Some(Ok(_)) => continue,
1597 Some(Err(e)) => panic!("ws receive failed: {e}"),
1598 None => panic!("websocket closed before echo"),
1599 }
1600 }
1601 })
1602 .await
1603 .unwrap();
1604
1605 assert_eq!(incoming, "hello-ws");
1606
1607 consumer.stop().await.unwrap();
1608 route_task.await.unwrap();
1609 }
1610
1611 #[tokio::test]
1612 async fn consumer_stop_sends_close_1001() {
1613 let port = free_port();
1614 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1615 let component_ctx = NoOpComponentContext;
1616 let endpoint = WsComponent::new()
1617 .create_endpoint(&uri, &component_ctx)
1618 .unwrap();
1619
1620 let mut consumer = endpoint.create_consumer().unwrap();
1621 let (route_tx, _route_rx) = mpsc::channel(16);
1622 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1623 consumer.start(ctx).await.unwrap();
1624
1625 let url = format!("ws://127.0.0.1:{port}/shutdown");
1626 let (mut client, _) = loop {
1627 match connect_async(&url).await {
1628 Ok(ok) => break ok,
1629 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1630 }
1631 };
1632
1633 client
1634 .send(ClientMessage::Text("keepalive".into()))
1635 .await
1636 .unwrap();
1637
1638 consumer.stop().await.unwrap();
1639
1640 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1641 loop {
1642 match client.next().await {
1643 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1644 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1645 Some(Ok(_)) => continue,
1646 Some(Err(e)) => panic!("ws receive failed: {e}"),
1647 None => panic!("websocket closed without close frame"),
1648 }
1649 }
1650 })
1651 .await
1652 .unwrap();
1653
1654 assert_eq!(close_code, Some(CloseCode::Away));
1655 }
1656
1657 #[test]
1658 fn wildcard_origin_allows_anything() {
1659 assert!(is_origin_allowed("*", None));
1660 assert!(is_origin_allowed("*", Some("https://example.com")));
1661 }
1662
1663 #[test]
1664 fn exact_origin_requires_match() {
1665 assert!(is_origin_allowed(
1666 "https://example.com",
1667 Some("https://example.com")
1668 ));
1669 assert!(!is_origin_allowed(
1670 "https://example.com",
1671 Some("https://other.com")
1672 ));
1673 assert!(!is_origin_allowed("https://example.com", None));
1674 }
1675
1676 #[test]
1677 fn endpoint_config_rejects_invalid_scheme() {
1678 let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1679 assert!(result.is_err());
1680 let msg = result.unwrap_err().to_string();
1681 assert!(
1682 msg.contains("Invalid WebSocket scheme"),
1683 "expected scheme error, got: {msg}"
1684 );
1685 }
1686
1687 #[tokio::test]
1688 async fn wss_consumer_start_fails_without_tls_cert() {
1689 let port = free_port();
1690 let component_ctx = NoOpComponentContext;
1691 let endpoint = WssComponent::new()
1692 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1693 .unwrap();
1694 let mut consumer = endpoint.create_consumer().unwrap();
1695 let (tx, _rx) = mpsc::channel(16);
1696 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1697 let result = consumer.start(ctx).await;
1698 assert!(result.is_err());
1699 let msg = result.unwrap_err().to_string();
1700 assert!(
1701 msg.contains("TLS cert path is required"),
1702 "expected TLS cert error, got: {msg}"
1703 );
1704 }
1705
1706 #[tokio::test]
1707 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1708 let port = free_port();
1709 let component_ctx = NoOpComponentContext;
1710 let endpoint = WssComponent::new()
1711 .create_endpoint(&format!(
1712 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1713 ), &component_ctx)
1714 .unwrap();
1715 let mut consumer = endpoint.create_consumer().unwrap();
1716 let (tx, _rx) = mpsc::channel(16);
1717 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1718 let result = consumer.start(ctx).await;
1719 assert!(result.is_err());
1720 let msg = result.unwrap_err().to_string();
1721 assert!(
1722 msg.contains("TLS cert file error"),
1723 "expected cert file error, got: {msg}"
1724 );
1725 }
1726
1727 #[tokio::test]
1728 async fn server_registry_returns_same_state_for_same_port() {
1729 let port = free_port();
1730 let state1 = ServerRegistry::global()
1731 .get_or_spawn("127.0.0.1", port, None)
1732 .await
1733 .unwrap();
1734 let state2 = ServerRegistry::global()
1735 .get_or_spawn("127.0.0.1", port, None)
1736 .await
1737 .unwrap();
1738 assert!(
1739 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1740 "expected same dispatch table for same port"
1741 );
1742 }
1743
1744 #[tokio::test]
1745 async fn dispatch_handler_returns_404_for_unregistered_path() {
1746 let port = free_port();
1747 let state = ServerRegistry::global()
1748 .get_or_spawn("127.0.0.1", port, None)
1749 .await
1750 .unwrap();
1751 let app = Router::new().fallback(dispatch_handler).with_state(state);
1752 let response = tokio::time::timeout(
1753 Duration::from_secs(2),
1754 tower::ServiceExt::oneshot(
1755 app,
1756 axum::http::Request::builder()
1757 .uri("/nonexistent")
1758 .body(Body::empty())
1759 .unwrap(),
1760 ),
1761 )
1762 .await
1763 .unwrap()
1764 .unwrap();
1765 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1766 }
1767
1768 #[tokio::test]
1769 async fn client_mode_producer_connects_and_echoes() {
1770 let app = Router::new().route(
1771 "/echo",
1772 axum::routing::get(|ws: WebSocketUpgrade| async move {
1773 ws.on_upgrade(|mut socket: WebSocket| async move {
1774 while let Some(Ok(msg)) = socket.recv().await {
1775 match msg {
1776 WsMessage::Text(text) => {
1777 let _ = socket.send(WsMessage::Text(text)).await;
1778 }
1779 WsMessage::Binary(data) => {
1780 let _ = socket.send(WsMessage::Binary(data)).await;
1781 }
1782 WsMessage::Close(_) => break,
1783 _ => {}
1784 }
1785 }
1786 })
1787 }),
1788 );
1789 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1791 let port = listener.local_addr().unwrap().port();
1792 let server_task = tokio::spawn(async move {
1793 let _ = serve(listener, app).await;
1794 });
1795
1796 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1797 let producer = WsProducer::new(cfg.client_config());
1798
1799 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1800 tokio::time::sleep(Duration::from_millis(25)).await;
1801 let result =
1802 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1803 Ok(Ok(r)) => r,
1804 Ok(Err(_)) => panic!("producer call failed"),
1805 Err(_) => panic!("producer call timed out"),
1806 };
1807
1808 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1809
1810 server_task.abort();
1811 }
1812
1813 #[tokio::test]
1814 async fn max_connections_rejects_with_close_1013() {
1815 let port = free_port();
1816 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1817 let component_ctx = NoOpComponentContext;
1818 let endpoint = WsComponent::new()
1819 .create_endpoint(&uri, &component_ctx)
1820 .unwrap();
1821 let mut consumer = endpoint.create_consumer().unwrap();
1822 let (route_tx, _route_rx) = mpsc::channel(16);
1823 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1824 consumer.start(ctx).await.unwrap();
1825
1826 let url = format!("ws://127.0.0.1:{port}/limited");
1827 let (_client1, _) = loop {
1828 match connect_async(&url).await {
1829 Ok(ok) => break ok,
1830 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1831 }
1832 };
1833
1834 tokio::time::sleep(Duration::from_millis(100)).await;
1835
1836 let (mut client2, _) = connect_async(&url).await.unwrap();
1837
1838 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1839 loop {
1840 match client2.next().await {
1841 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1842 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1843 Some(Ok(ClientMessage::Text(_))) => continue,
1844 Some(Ok(_)) => continue,
1845 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1846 None => panic!("client2 closed without close frame"),
1847 }
1848 }
1849 })
1850 .await
1851 .unwrap();
1852
1853 assert_eq!(
1854 close_code,
1855 Some(CloseCode::from(1013u16)),
1856 "expected 1013 (Try Again Later) for max connections"
1857 );
1858
1859 consumer.stop().await.unwrap();
1860 }
1861
1862 #[tokio::test]
1863 async fn max_message_size_rejects_with_close_1009() {
1864 let port = free_port();
1865 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1866 let component_ctx = NoOpComponentContext;
1867 let endpoint = WsComponent::new()
1868 .create_endpoint(&uri, &component_ctx)
1869 .unwrap();
1870 let mut consumer = endpoint.create_consumer().unwrap();
1871 let (route_tx, _route_rx) = mpsc::channel(16);
1872 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1873 consumer.start(ctx).await.unwrap();
1874
1875 let url = format!("ws://127.0.0.1:{port}/sizelimit");
1876 let (mut client, _) = loop {
1877 match connect_async(&url).await {
1878 Ok(ok) => break ok,
1879 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1880 }
1881 };
1882
1883 let oversized = "x".repeat(100);
1884 client
1885 .send(ClientMessage::Text(oversized.into()))
1886 .await
1887 .unwrap();
1888
1889 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1890 loop {
1891 match client.next().await {
1892 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1893 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1894 Some(Ok(_)) => continue,
1895 Some(Err(e)) => panic!("ws receive failed: {e}"),
1896 None => panic!("websocket closed without close frame"),
1897 }
1898 }
1899 })
1900 .await
1901 .unwrap();
1902
1903 assert_eq!(
1904 close_code,
1905 Some(CloseCode::from(1009u16)),
1906 "expected 1009 (Message Too Big) for oversized message"
1907 );
1908
1909 consumer.stop().await.unwrap();
1910 }
1911
1912 #[tokio::test]
1913 async fn origin_rejection_returns_403() {
1914 let port = free_port();
1915 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1916 let component_ctx = NoOpComponentContext;
1917 let endpoint = WsComponent::new()
1918 .create_endpoint(&uri, &component_ctx)
1919 .unwrap();
1920 let mut consumer = endpoint.create_consumer().unwrap();
1921 let (route_tx, _route_rx) = mpsc::channel(16);
1922 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1923 consumer.start(ctx).await.unwrap();
1924
1925 let state = ServerRegistry::global()
1926 .get_or_spawn("127.0.0.1", port, None)
1927 .await
1928 .unwrap();
1929 let app = Router::new().fallback(dispatch_handler).with_state(state);
1930
1931 let response = tokio::time::timeout(
1932 Duration::from_secs(2),
1933 tower::ServiceExt::oneshot(
1934 app,
1935 axum::http::Request::builder()
1936 .uri("/origintest")
1937 .header("origin", "https://evil.com")
1938 .header("upgrade", "websocket")
1939 .header("connection", "Upgrade")
1940 .header("sec-websocket-version", "13")
1941 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1942 .body(Body::empty())
1943 .unwrap(),
1944 ),
1945 )
1946 .await
1947 .unwrap()
1948 .unwrap();
1949
1950 assert_eq!(
1951 response.status(),
1952 StatusCode::FORBIDDEN,
1953 "expected 403 for disallowed origin"
1954 );
1955
1956 consumer.stop().await.unwrap();
1957 }
1958
1959 #[tokio::test]
1960 async fn broadcast_sends_to_all_connected_clients() {
1961 let port = free_port();
1962 let uri = format!("ws://127.0.0.1:{port}/bc");
1963 let component_ctx = NoOpComponentContext;
1964 let endpoint = WsComponent::new()
1965 .create_endpoint(&uri, &component_ctx)
1966 .unwrap();
1967 let mut consumer = endpoint.create_consumer().unwrap();
1968 let producer = endpoint
1969 .create_producer(&ProducerContext::default())
1970 .unwrap();
1971
1972 let (route_tx, _route_rx) = mpsc::channel(16);
1973 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1974 consumer.start(ctx).await.unwrap();
1975
1976 let url = format!("ws://127.0.0.1:{port}/bc");
1977
1978 let (mut client1, _) = loop {
1979 match connect_async(&url).await {
1980 Ok(ok) => break ok,
1981 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1982 }
1983 };
1984
1985 let (mut client2, _) = connect_async(&url).await.unwrap();
1986
1987 tokio::time::sleep(Duration::from_millis(100)).await;
1988
1989 let mut response =
1990 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1991 response
1992 .input
1993 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1994 producer.oneshot(response).await.unwrap();
1995
1996 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1997 loop {
1998 match client1.next().await {
1999 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2000 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2001 _ => panic!("client1 unexpected message or close"),
2002 }
2003 }
2004 })
2005 .await
2006 .unwrap();
2007
2008 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
2009 loop {
2010 match client2.next().await {
2011 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2012 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2013 _ => panic!("client2 unexpected message or close"),
2014 }
2015 }
2016 })
2017 .await
2018 .unwrap();
2019
2020 assert_eq!(recv1, "broadcast-msg");
2021 assert_eq!(recv2, "broadcast-msg");
2022
2023 consumer.stop().await.unwrap();
2024 }
2025
2026 #[tokio::test]
2027 async fn concurrent_get_or_spawn_returns_same_state() {
2028 let port = free_port();
2029 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
2030 Arc::new(std::sync::Mutex::new(Vec::new()));
2031
2032 let mut handles = Vec::new();
2033 for _ in 0..4 {
2034 let results = results.clone();
2035 handles.push(tokio::spawn(async move {
2036 let state = ServerRegistry::global()
2037 .get_or_spawn("127.0.0.1", port, None)
2038 .await
2039 .unwrap();
2040 results.lock().unwrap().push(state);
2041 }));
2042 }
2043
2044 for h in handles {
2045 h.await.unwrap();
2046 }
2047
2048 let states = results.lock().unwrap();
2049 assert_eq!(states.len(), 4);
2050 for i in 1..states.len() {
2051 assert!(
2052 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
2053 "all concurrent callers should get the same dispatch table"
2054 );
2055 }
2056 }
2057
2058 #[tokio::test]
2059 async fn body_conversion_helpers_cover_text_and_binary_paths() {
2060 let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
2061 .await
2062 .unwrap();
2063 assert!(matches!(text_msg, WsMessage::Text(_)));
2064
2065 let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
2066 .await
2067 .unwrap();
2068 assert!(matches!(bin_msg, WsMessage::Binary(_)));
2069
2070 let client_text =
2071 body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
2072 .await
2073 .unwrap();
2074 assert!(matches!(client_text, ClientWsMessage::Text(_)));
2075
2076 let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
2077 .await
2078 .unwrap();
2079 assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
2080 }
2081
2082 #[tokio::test]
2083 async fn body_to_text_handles_empty_text_json_and_bytes() {
2084 assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2085 assert_eq!(
2086 body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2087 "hello"
2088 );
2089 assert_eq!(
2090 body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2091 .await
2092 .unwrap(),
2093 "{\"n\":1}"
2094 );
2095 assert_eq!(
2096 body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2097 .await
2098 .unwrap(),
2099 "hi"
2100 );
2101 }
2102
2103 #[test]
2104 fn try_send_with_backpressure_returns_false_when_channel_full() {
2105 let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2106 assert!(try_send_with_backpressure(
2107 &tx,
2108 WsMessage::Text("first".into()),
2109 "test"
2110 ));
2111 assert!(!try_send_with_backpressure(
2112 &tx,
2113 WsMessage::Text("second".into()),
2114 "test"
2115 ));
2116 }
2117
2118 #[test]
2119 fn map_connect_error_formats_connection_refused_and_generic_errors() {
2120 let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2121 let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2122 assert!(err.to_string().contains("WebSocket connection refused"));
2123
2124 let generic = map_connect_error(
2125 tungstenite::Error::Protocol(
2126 tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2127 ),
2128 "ws://localhost:2/y",
2129 );
2130 assert!(
2131 generic
2132 .to_string()
2133 .contains("WebSocket connection failed (ws://localhost:2/y)")
2134 );
2135 }
2136
2137 #[test]
2141 fn from_uri_rejects_max_connections_zero() {
2142 let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2143 assert!(result.is_err());
2144 let msg = result.unwrap_err().to_string();
2145 assert!(
2146 msg.contains("maxConnections must be >= 1"),
2147 "expected maxConnections validation error, got: {msg}"
2148 );
2149 }
2150
2151 #[test]
2153 fn from_uri_rejects_max_message_size_zero() {
2154 let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2155 assert!(result.is_err());
2156 let msg = result.unwrap_err().to_string();
2157 assert!(
2158 msg.contains("maxMessageSize must be > 0"),
2159 "expected maxMessageSize validation error, got: {msg}"
2160 );
2161 }
2162
2163 #[test]
2165 fn from_uri_rejects_empty_allow_origin() {
2166 let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2167 assert!(result.is_err());
2168 let msg = result.unwrap_err().to_string();
2169 assert!(
2170 msg.contains("allowOrigin must not be empty"),
2171 "expected allowOrigin validation error, got: {msg}"
2172 );
2173 }
2174
2175 #[tokio::test]
2177 async fn consumer_double_start_returns_error() {
2178 let port = free_port();
2179 let uri = format!("ws://127.0.0.1:{port}/doublestart");
2180 let component_ctx = NoOpComponentContext;
2181 let endpoint = WsComponent::new()
2182 .create_endpoint(&uri, &component_ctx)
2183 .unwrap();
2184
2185 let mut consumer = endpoint.create_consumer().unwrap();
2186 let (route_tx, _route_rx) = mpsc::channel(16);
2187 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2188
2189 consumer.start(ctx).await.unwrap();
2191
2192 let (route_tx2, _route_rx2) = mpsc::channel(16);
2194 let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2195 let result = consumer.start(ctx2).await;
2196 assert!(result.is_err());
2197 let msg = result.unwrap_err().to_string();
2198 assert!(
2199 msg.contains("already started"),
2200 "expected double-start error, got: {msg}"
2201 );
2202
2203 consumer.stop().await.unwrap();
2204 }
2205
2206 #[tokio::test]
2208 async fn registry_cleanup_on_consumer_stop() {
2209 let port = free_port();
2210 let uri = format!("ws://127.0.0.1:{port}/cleanup");
2211 let component_ctx = NoOpComponentContext;
2212 let endpoint = WsComponent::new()
2213 .create_endpoint(&uri, &component_ctx)
2214 .unwrap();
2215
2216 let mut consumer = endpoint.create_consumer().unwrap();
2217 let (route_tx, _route_rx) = mpsc::channel(16);
2218 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2219 consumer.start(ctx).await.unwrap();
2220
2221 let registries = global_registries();
2223 let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2224 assert!(
2225 registries.contains_key(&key),
2226 "registry should have entry after start"
2227 );
2228
2229 consumer.stop().await.unwrap();
2231
2232 assert!(
2234 !registries.contains_key(&key),
2235 "registry should be cleaned up after stop"
2236 );
2237
2238 let server_reg = ServerRegistry::global();
2241 let guard = server_reg.inner.lock().unwrap();
2242 assert!(
2243 !guard.contains_key(&port),
2244 "ServerRegistry should remove port entry after last consumer stops"
2245 );
2246 }
2247
2248 #[tokio::test]
2250 async fn producer_server_send_returns_error_when_all_dropped() {
2251 let port = free_port();
2252 let uri = format!("ws://127.0.0.1:{port}/backpressure");
2253 let component_ctx = NoOpComponentContext;
2254 let endpoint = WsComponent::new()
2255 .create_endpoint(&uri, &component_ctx)
2256 .unwrap();
2257
2258 let mut consumer = endpoint.create_consumer().unwrap();
2259 let producer = endpoint
2260 .create_producer(&ProducerContext::default())
2261 .unwrap();
2262
2263 let (route_tx, _route_rx) = mpsc::channel(1); let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2265 consumer.start(ctx).await.unwrap();
2266
2267 let url = format!("ws://127.0.0.1:{port}/backpressure");
2269 let (mut client, _) = loop {
2270 match connect_async(&url).await {
2271 Ok(ok) => break ok,
2272 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2273 }
2274 };
2275
2276 tokio::time::sleep(Duration::from_millis(50)).await;
2278
2279 let mut all_dropped = false;
2281 for _ in 0..100 {
2282 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2283 match producer.clone().oneshot(exchange).await {
2284 Ok(_) => {}
2285 Err(e) => {
2286 if e.to_string().contains("backpressure") {
2287 all_dropped = true;
2288 break;
2289 }
2290 }
2291 }
2292 }
2293
2294 assert!(
2296 all_dropped,
2297 "producer should return error when all messages are dropped due to backpressure"
2298 );
2299
2300 let _ = client.close(None).await;
2302 consumer.stop().await.unwrap();
2303 }
2304
2305 #[tokio::test]
2307 async fn server_responds_to_client_ping_with_pong() {
2308 let port = free_port();
2309 let uri = format!("ws://127.0.0.1:{port}/pingpong");
2310 let component_ctx = NoOpComponentContext;
2311 let endpoint = WsComponent::new()
2312 .create_endpoint(&uri, &component_ctx)
2313 .unwrap();
2314
2315 let mut consumer = endpoint.create_consumer().unwrap();
2316 let (route_tx, _route_rx) = mpsc::channel(16);
2317 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2318 consumer.start(ctx).await.unwrap();
2319
2320 let url = format!("ws://127.0.0.1:{port}/pingpong");
2321 let (mut client, _) = loop {
2322 match connect_async(&url).await {
2323 Ok(ok) => break ok,
2324 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2325 }
2326 };
2327
2328 client
2330 .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2331 .await
2332 .unwrap();
2333
2334 let pong = tokio::time::timeout(Duration::from_secs(2), async {
2336 loop {
2337 match client.next().await {
2338 Some(Ok(ClientMessage::Pong(data))) => break data,
2339 Some(Ok(ClientMessage::Ping(_))) => continue,
2340 Some(Ok(_)) => continue,
2341 Some(Err(e)) => panic!("ws receive failed: {e}"),
2342 None => panic!("websocket closed before pong"),
2343 }
2344 }
2345 })
2346 .await
2347 .unwrap();
2348
2349 assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2350
2351 consumer.stop().await.unwrap();
2352 }
2353
2354 #[tokio::test]
2356 async fn producer_retries_on_connection_refused() {
2357 let port = free_port();
2359 let cfg = WsEndpointConfig::from_uri(&format!(
2361 "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2362 ))
2363 .unwrap();
2364 let producer = WsProducer::new(cfg.client_config());
2365
2366 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2367
2368 let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2370 assert!(
2371 result.is_ok(),
2372 "producer should complete (with error) within timeout"
2373 );
2374 let result = result.unwrap();
2375 assert!(
2376 result.is_err(),
2377 "producer should fail when nothing is listening"
2378 );
2379 let msg = result.unwrap_err().to_string();
2380 assert!(
2381 msg.contains("connection refused"),
2382 "expected connection refused error, got: {msg}"
2383 );
2384 }
2385
2386 #[tokio::test]
2388 async fn server_bind_error_is_reported() {
2389 let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2391 let port = _listener.local_addr().unwrap().port();
2392
2393 let uri = format!("ws://127.0.0.1:{port}/binderror");
2396 let component_ctx = NoOpComponentContext;
2397 let endpoint = WsComponent::new()
2398 .create_endpoint(&uri, &component_ctx)
2399 .unwrap();
2400
2401 let mut consumer = endpoint.create_consumer().unwrap();
2402 let (route_tx, _route_rx) = mpsc::channel(16);
2403 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2404
2405 let start_result = consumer.start(ctx).await;
2407 let _ = start_result;
2410
2411 consumer.stop().await.unwrap();
2412 }
2413
/// A `WsAppState` built with `new_atomic_false()` must report `false` for
/// `server_error`.
#[test]
fn ws_app_state_server_error_starts_false() {
    let dispatch = Arc::new(RwLock::new(HashMap::new()));
    let path_configs = Arc::new(DashMap::new());
    let server_error = new_atomic_false();
    let state = WsAppState {
        dispatch,
        path_configs,
        server_error,
    };

    let flagged = state.server_error.load(Ordering::Relaxed);
    assert!(!flagged, "server_error should start as false");
}
2426
/// The `server_error` flag of a `WsAppState` starts `false` and reads back
/// `true` after `true` is stored into it.
#[test]
fn ws_app_state_server_error_can_be_set() {
    let dispatch = Arc::new(RwLock::new(HashMap::new()));
    let path_configs = Arc::new(DashMap::new());
    let server_error = new_atomic_false();
    let state = WsAppState {
        dispatch,
        path_configs,
        server_error,
    };

    let flag = &state.server_error;
    assert!(!flag.load(Ordering::Relaxed));
    flag.store(true, Ordering::Relaxed);
    assert!(flag.load(Ordering::Relaxed));
}
2438
2439 #[tokio::test]
2440 async fn consumer_stop_returns_error_when_server_had_errors() {
2441 let port = free_port();
2442 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2443 let mut consumer = WsConsumer::new(cfg.server_config());
2444 let (route_tx, _route_rx) = mpsc::channel(16);
2445 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2446 consumer.start(ctx).await.unwrap();
2447
2448 if let Some(ref state) = consumer.server_state {
2450 state.server_error.store(true, Ordering::Relaxed);
2451 }
2452
2453 let result = consumer.stop().await;
2454 assert!(
2455 result.is_err(),
2456 "stop should return error when server had errors"
2457 );
2458 let msg = result.unwrap_err().to_string();
2459 assert!(
2460 msg.contains("terminated with errors"),
2461 "expected server error message, got: {msg}"
2462 );
2463 }
2464
2465 #[tokio::test]
2466 async fn consumer_stop_succeeds_when_server_healthy() {
2467 let port = free_port();
2468 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2469 let mut consumer = WsConsumer::new(cfg.server_config());
2470 let (route_tx, _route_rx) = mpsc::channel(16);
2471 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2472 consumer.start(ctx).await.unwrap();
2473
2474 let result = consumer.stop().await;
2475 assert!(
2476 result.is_ok(),
2477 "stop should succeed when server is healthy: {:?}",
2478 result
2479 );
2480 }
2481
/// A comma-separated `subprotocols` query parameter is parsed into one list
/// entry per name, in the order given.
#[test]
fn endpoint_config_parses_subprotocols() {
    let uri = "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();

    let expected = vec!["graphql-ws", "graphql-transport-ws"];
    assert_eq!(parsed.subprotocols, expected);
}
2493
/// The default endpoint configuration has no subprotocols.
#[test]
fn endpoint_config_default_subprotocols_empty() {
    let defaults = WsEndpointConfig::default();
    let has_none = defaults.subprotocols.is_empty();
    assert!(has_none);
}
2499
/// `sendTimeoutMs=5000` is parsed into a 5000 ms `send_timeout`.
#[test]
fn endpoint_config_parses_send_timeout() {
    let uri = "ws://localhost:9001/chat?sendTimeoutMs=5000";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();

    let expected = Duration::from_millis(5000);
    assert_eq!(parsed.send_timeout, expected);
}
2507
/// The default endpoint configuration uses a 30-second `send_timeout`.
#[test]
fn endpoint_config_default_send_timeout() {
    let defaults = WsEndpointConfig::default();
    let thirty_seconds = Duration::from_secs(30);
    assert_eq!(defaults.send_timeout, thirty_seconds);
}
2513
/// A non-numeric `sendTimeoutMs` value is rejected, and the error text names
/// the offending parameter.
#[test]
fn endpoint_config_rejects_invalid_send_timeout() {
    let uri = "ws://localhost:9001/chat?sendTimeoutMs=abc";
    let msg = WsEndpointConfig::from_uri(uri).unwrap_err().to_string();
    assert!(msg.contains("sendTimeoutMs"));
}
2520
/// `binaryPayload=true` switches `binary_payload` on.
#[test]
fn endpoint_config_parses_binary_payload() {
    let uri = "ws://localhost:9001/chat?binaryPayload=true";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();
    let enabled = parsed.binary_payload;
    assert!(enabled);
}
2528
/// The default endpoint configuration has `binary_payload` switched off.
#[test]
fn endpoint_config_default_binary_payload_false() {
    let defaults = WsEndpointConfig::default();
    let enabled = defaults.binary_payload;
    assert!(!enabled);
}
2534
/// `binaryPayload=yes` is rejected, and the error text names the offending
/// parameter.
#[test]
fn endpoint_config_rejects_invalid_binary_payload() {
    let uri = "ws://localhost:9001/chat?binaryPayload=yes";
    let msg = WsEndpointConfig::from_uri(uri).unwrap_err().to_string();
    assert!(msg.contains("binaryPayload"));
}
2541}