1pub mod bundle;
7pub mod config;
8
9pub use bundle::WsBundle;
10pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
11
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex, OnceLock};
14
15use async_trait::async_trait;
16use axum::body::Body;
17use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
18use axum::extract::{FromRequest, Request, State};
19use axum::http::{StatusCode, header};
20use axum::response::IntoResponse;
21use axum::{Router, serve};
22use camel_component_api::{
23 Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
24};
25use camel_component_api::{
26 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
27 ProducerContext,
28};
29use dashmap::DashMap;
30use futures::{SinkExt, StreamExt};
31use std::future::Future;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio::sync::{OnceCell, RwLock, mpsc};
35use tokio::task::JoinHandle;
36use tokio_tungstenite::tungstenite;
37use tokio_tungstenite::tungstenite::client::IntoClientRequest;
38use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
39use tower::Service;
40
41#[derive(Clone)]
42struct WsPathConfig {
43 max_connections: u32,
44 max_message_size: u32,
45 heartbeat_interval: std::time::Duration,
46 idle_timeout: std::time::Duration,
47 allow_origin: String,
48}
49
50impl Default for WsPathConfig {
51 fn default() -> Self {
52 let cfg = WsEndpointConfig::default();
53 Self {
54 max_connections: cfg.max_connections,
55 max_message_size: cfg.max_message_size,
56 heartbeat_interval: cfg.heartbeat_interval,
57 idle_timeout: cfg.idle_timeout,
58 allow_origin: cfg.allow_origin,
59 }
60 }
61}
62
/// Locations of the TLS material used when a server is started for `wss`.
#[derive(Clone)]
struct WsTlsConfig {
    /// Path of the certificate file handed to `load_tls_config`.
    cert_path: String,
    /// Path of the private-key file handed to `load_tls_config`.
    key_path: String,
}
68
/// Shared routing table: request path -> channel feeding the consumer registered for it.
type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
70
/// A running server as stored in the [`ServerRegistry`].
struct ServerHandle {
    /// State shared with the request handlers; cloned out to every consumer on the port.
    state: WsAppState,
    /// Whether the server was started with TLS.
    is_tls: bool,
    /// Task driving the accept loop; aborted when the last consumer releases the port.
    _task: JoinHandle<()>,
}

/// Registry slot for one port.
struct ServerRegistryInner {
    /// Lazily initialised handle, shared so the server is spawned at most once per slot.
    cell: Arc<OnceCell<ServerHandle>>,
    /// Number of outstanding `get_or_spawn` acquisitions not yet released.
    ref_count: usize,
}
81
82pub struct ServerRegistry {
83 inner: Mutex<HashMap<u16, ServerRegistryInner>>,
84}
85
86impl ServerRegistry {
87 pub fn global() -> &'static Self {
88 static REG: OnceLock<ServerRegistry> = OnceLock::new();
89 REG.get_or_init(|| Self {
90 inner: Mutex::new(HashMap::new()),
91 })
92 }
93
94 pub(crate) async fn get_or_spawn(
95 &'static self,
96 host: &str,
97 port: u16,
98 tls_config: Option<WsTlsConfig>,
99 ) -> Result<WsAppState, CamelError> {
100 let wants_tls = tls_config.is_some();
101 let host_owned = host.to_string();
102
103 let (cell, _is_new) = {
104 let mut guard = self.inner.lock().map_err(|_| {
105 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
106 })?;
107 let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
108 cell: Arc::new(OnceCell::new()),
109 ref_count: 0,
110 });
111 entry.ref_count += 1;
112 (entry.cell.clone(), entry.ref_count == 1)
113 };
114
115 let handle = cell
116 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
117 .await?;
118
119 if wants_tls != handle.is_tls {
120 let mut guard = self.inner.lock().map_err(|_| {
122 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
123 })?;
124 if let Some(entry) = guard.get_mut(&port) {
125 entry.ref_count -= 1;
126 if entry.ref_count == 0 {
127 guard.remove(&port);
128 }
129 }
130 return Err(CamelError::EndpointCreationFailed(format!(
131 "Server on port {port} already running with different TLS mode"
132 )));
133 }
134
135 Ok(handle.state.clone())
136 }
137
138 pub(crate) fn release(&self, port: u16) {
141 let mut guard = match self.inner.lock() {
142 Ok(g) => g,
143 Err(_) => return,
144 };
145 if let Some(entry) = guard.get_mut(&port) {
146 entry.ref_count = entry.ref_count.saturating_sub(1);
147 if entry.ref_count == 0 {
148 if let Some(handle) = entry.cell.get() {
150 handle._task.abort();
151 }
152 guard.remove(&port);
153 tracing::info!(port, "WebSocket server registry entry removed");
154 }
155 }
156 }
157}
158
159async fn spawn_server(
160 host: &str,
161 port: u16,
162 tls_config: Option<WsTlsConfig>,
163) -> Result<ServerHandle, CamelError> {
164 let host_owned = host.to_string();
165 let addr = format!("{host}:{port}");
166 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
167 let path_configs = Arc::new(DashMap::new());
168 let server_error = new_atomic_false();
169 let state = WsAppState {
170 dispatch: Arc::clone(&dispatch),
171 path_configs: Arc::clone(&path_configs),
172 server_error: Arc::clone(&server_error),
173 };
174 let app = Router::new()
175 .fallback(dispatch_handler)
176 .with_state(state.clone());
177
178 let (task, is_tls) = if let Some(ref tls) = tls_config {
179 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
180 let parsed_addr = addr.parse().map_err(|e| {
181 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
182 })?;
183 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
184 let error_flag = Arc::clone(&server_error);
185 let task = tokio::spawn(async move {
186 if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
187 .serve(app.into_make_service())
188 .await
189 {
190 tracing::error!(
191 host = host_owned,
192 port = port,
193 error = %e,
194 "WebSocket server terminated with error"
195 );
196 error_flag.store(true, Ordering::Relaxed);
197 }
198 });
199 (task, true)
200 } else {
201 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
202 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
203 })?;
204 let error_flag = Arc::clone(&server_error);
205 let task = tokio::spawn(async move {
206 if let Err(e) = serve(listener, app).await {
207 tracing::error!(
208 host = host_owned,
209 port = port,
210 error = %e,
211 "WebSocket server terminated with error"
212 );
213 error_flag.store(true, Ordering::Relaxed);
214 }
215 });
216 (task, false)
217 };
218
219 tracing::info!(host, port, is_tls, "WebSocket server started");
220
221 Ok(ServerHandle {
222 state,
223 is_tls,
224 _task: task,
225 })
226}
227
/// State shared by all request handlers of one server.
#[derive(Clone)]
struct WsAppState {
    /// Path -> consumer channel routing table.
    dispatch: DispatchTable,
    /// Per-path limits registered by consumers when they start.
    path_configs: Arc<DashMap<String, WsPathConfig>>,
    /// Set once the server task has terminated with an error.
    server_error: Arc<AtomicBool>,
}
234
/// Live connections of one consumer: connection key -> outbound message channel.
pub struct WsConnectionRegistry {
    connections: DashMap<String, mpsc::Sender<WsMessage>>,
}
238
239static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
240 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
241> = OnceLock::new();
242
243fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
244 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
245}
246
247impl Default for WsConnectionRegistry {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
253impl WsConnectionRegistry {
254 pub fn new() -> Self {
255 Self {
256 connections: DashMap::new(),
257 }
258 }
259
260 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
261 self.connections.insert(key, tx);
262 }
263
264 pub fn remove(&self, key: &str) {
265 self.connections.remove(key);
266 }
267
268 pub fn len(&self) -> usize {
269 self.connections.len()
270 }
271
272 pub fn is_empty(&self) -> bool {
273 self.connections.is_empty()
274 }
275
276 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
277 self.connections.iter().map(|e| e.value().clone()).collect()
278 }
279
280 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
281 keys.iter()
282 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
283 .collect()
284 }
285}
286
287async fn dispatch_handler(
288 State(state): State<WsAppState>,
289 req: Request<Body>,
290) -> impl IntoResponse {
291 let path = req.uri().path().to_string();
292 let origin = req
293 .headers()
294 .get(header::ORIGIN)
295 .and_then(|value| value.to_str().ok())
296 .map(str::to_string);
297 let remote_addr = req
298 .extensions()
299 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
300 .map(|ci| ci.0.to_string())
301 .unwrap_or_default();
302 let table = state.dispatch.read().await;
303 if !table.contains_key(&path) {
304 return (
305 StatusCode::NOT_FOUND,
306 "no ws endpoint registered for this path",
307 )
308 .into_response();
309 }
310 drop(table);
311
312 let path_config = state
313 .path_configs
314 .get(&path)
315 .map(|entry| entry.value().clone())
316 .unwrap_or_default();
317 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
318 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
319 }
320
321 let upgrade_headers: HashMap<String, String> = req
322 .headers()
323 .iter()
324 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
325 .collect();
326
327 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
328 Ok(ws) => ws,
329 Err(_) => {
330 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
331 }
332 };
333
334 ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
335 .into_response()
336}
337
338#[allow(unused_variables)]
339async fn ws_handler(
340 socket: WebSocket,
341 state: WsAppState,
342 path: String,
343 remote_addr: String,
344 upgrade_headers: HashMap<String, String>,
345) {
346 let connection_key = uuid::Uuid::new_v4().to_string();
347 let path_config = state
348 .path_configs
349 .get(&path)
350 .map(|entry| entry.value().clone())
351 .unwrap_or_default();
352
353 let env_tx = {
354 let table = state.dispatch.read().await;
355 table.get(&path).cloned()
356 };
357 let Some(env_tx) = env_tx else {
358 return;
359 };
360
361 let (mut sink, mut stream) = socket.split();
362 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
363
364 let registry = global_registries();
365 let mut registry_key = None;
366 for entry in registry.iter() {
367 if entry.key().2 == path {
368 entry.value().insert(connection_key.clone(), out_tx.clone());
369 registry_key = Some(entry.key().clone());
370 break;
371 }
372 }
373
374 let conn_key_for_writer = connection_key.clone();
376 let path_for_writer = path.clone();
377
378 let writer = tokio::spawn(async move {
379 while let Some(msg) = out_rx.recv().await {
380 if let Err(e) = sink.send(msg).await {
381 tracing::warn!(
382 connection_key = conn_key_for_writer,
383 path = path_for_writer,
384 error = %e,
385 "WebSocket writer send error — closing connection"
386 );
387 break;
388 }
389 }
390 });
391
392 tracing::info!(
393 connection_key = connection_key,
394 path = path,
395 remote_addr = remote_addr,
396 "WebSocket connection opened"
397 );
398
399 let mut over_limit = false;
400 if let Some(key) = ®istry_key
401 && let Some(entry) = registry.get(key)
402 && entry.len() > path_config.max_connections as usize
403 {
404 over_limit = true;
405 }
406 if over_limit {
407 try_send_with_backpressure(
408 &out_tx,
409 WsMessage::Close(Some(CloseFrame {
410 code: CloseCode::from(1013u16),
411 reason: "max connections exceeded".into(),
412 })),
413 "max-connections-close",
414 );
415 if let Some(key) = registry_key.clone()
416 && let Some(entry) = registry.get(&key)
417 {
418 entry.remove(&connection_key);
419 }
420 drop(out_tx);
421 let _ = writer.await;
422 return;
423 }
424
425 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
426 let ping_tx = out_tx.clone();
427 let interval = path_config.heartbeat_interval;
428 Some(tokio::spawn(async move {
429 let mut ticker = tokio::time::interval(interval);
430 loop {
431 ticker.tick().await;
432 let _ = try_send_with_backpressure(
433 &ping_tx,
434 WsMessage::Ping(Vec::new().into()),
435 "heartbeat-ping",
436 );
437 }
438 }))
439 } else {
440 None
441 };
442
443 loop {
444 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
445 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
446 Ok(msg) => msg,
447 Err(_) => {
448 try_send_with_backpressure(
449 &out_tx,
450 WsMessage::Close(Some(CloseFrame {
451 code: CloseCode::from(1000u16),
452 reason: "idle timeout".into(),
453 })),
454 "idle-timeout-close",
455 );
456 break;
457 }
458 }
459 } else {
460 stream.next().await
461 };
462
463 let Some(msg) = next_msg else {
464 break;
465 };
466
467 match msg {
468 Ok(WsMessage::Ping(data)) => {
469 tracing::debug!(
470 connection_key = connection_key,
471 path = path,
472 "WebSocket ping received — sending pong"
473 );
474 let _ = try_send_with_backpressure(
475 &out_tx,
476 WsMessage::Pong(data),
477 "ping-pong-response",
478 );
479 }
480 Ok(WsMessage::Pong(_)) => {
481 tracing::debug!(
482 connection_key = connection_key,
483 path = path,
484 "WebSocket pong received"
485 );
486 }
487 Ok(WsMessage::Text(text)) => {
488 if text.len() > path_config.max_message_size as usize {
489 try_send_with_backpressure(
490 &out_tx,
491 WsMessage::Close(Some(CloseFrame {
492 code: CloseCode::from(1009u16),
493 reason: "message too large".into(),
494 })),
495 "max-message-size-close-text",
496 );
497 break;
498 }
499
500 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
501 message.set_header(
502 "CamelWsConnectionKey",
503 serde_json::Value::String(connection_key.clone()),
504 );
505 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
506 message.set_header(
507 "CamelWsRemoteAddress",
508 serde_json::Value::String(remote_addr.clone()),
509 );
510
511 #[allow(unused_mut)]
512 let mut exchange = Exchange::new(message);
513 #[cfg(feature = "otel")]
514 {
515 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
516 }
517 if env_tx
518 .send(ExchangeEnvelope {
519 exchange,
520 reply_tx: None,
521 })
522 .await
523 .is_err()
524 {
525 break;
526 }
527 }
528 Ok(WsMessage::Binary(data)) => {
529 if data.len() > path_config.max_message_size as usize {
530 try_send_with_backpressure(
531 &out_tx,
532 WsMessage::Close(Some(CloseFrame {
533 code: CloseCode::from(1009u16),
534 reason: "message too large".into(),
535 })),
536 "max-message-size-close-binary",
537 );
538 break;
539 }
540
541 let mut message = CamelMessage::new(CamelBody::Bytes(data));
542 message.set_header(
543 "CamelWsConnectionKey",
544 serde_json::Value::String(connection_key.clone()),
545 );
546 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
547 message.set_header(
548 "CamelWsRemoteAddress",
549 serde_json::Value::String(remote_addr.clone()),
550 );
551
552 #[allow(unused_mut)]
553 let mut exchange = Exchange::new(message);
554 #[cfg(feature = "otel")]
555 {
556 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
557 }
558 if env_tx
559 .send(ExchangeEnvelope {
560 exchange,
561 reply_tx: None,
562 })
563 .await
564 .is_err()
565 {
566 break;
567 }
568 }
569 Ok(WsMessage::Close(cf)) => {
570 let reason = cf
571 .as_ref()
572 .map(|f| f.reason.to_string())
573 .unwrap_or_default();
574 tracing::info!(
575 connection_key = connection_key,
576 path = path,
577 reason = reason,
578 "WebSocket connection closed by peer"
579 );
580 break;
581 }
582 Err(e) => {
583 tracing::warn!(
584 connection_key = connection_key,
585 path = path,
586 error = %e,
587 "WebSocket receive error"
588 );
589 break;
590 }
591 }
592 }
593
594 if let Some(task) = heartbeat_task {
595 task.abort();
596 }
597
598 if let Some(key) = registry_key
599 && let Some(entry) = registry.get(&key)
600 {
601 entry.remove(&connection_key);
602 }
603 drop(out_tx);
604 let _ = writer.await;
605
606 tracing::info!(
607 connection_key = connection_key,
608 path = path,
609 "WebSocket connection closed"
610 );
611}
612
613pub struct WsComponent {
614 pub(crate) config: WsConfig,
615}
616
617impl WsComponent {
618 pub fn new() -> Self {
619 Self {
620 config: WsConfig::default(),
621 }
622 }
623
624 pub fn with_config(config: WsConfig) -> Self {
625 Self { config }
626 }
627}
628
629impl Default for WsComponent {
630 fn default() -> Self {
631 Self::new()
632 }
633}
634
635impl Component for WsComponent {
636 fn scheme(&self) -> &str {
637 "ws"
638 }
639
640 fn create_endpoint(
641 &self,
642 uri: &str,
643 _ctx: &dyn camel_component_api::ComponentContext,
644 ) -> Result<Box<dyn Endpoint>, CamelError> {
645 let mut cfg = WsEndpointConfig::from_uri(uri)?;
646 if let Some(v) = self.config.max_connections {
647 cfg.max_connections = v;
648 }
649 if let Some(v) = self.config.max_message_size {
650 cfg.max_message_size = v;
651 }
652 if let Some(v) = self.config.heartbeat_interval_ms {
653 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
654 }
655 if let Some(v) = self.config.idle_timeout_ms {
656 cfg.idle_timeout = std::time::Duration::from_millis(v);
657 }
658 if let Some(v) = self.config.connect_timeout_ms {
659 cfg.connect_timeout = std::time::Duration::from_millis(v);
660 }
661 if let Some(v) = self.config.response_timeout_ms {
662 cfg.response_timeout = std::time::Duration::from_millis(v);
663 }
664 Ok(Box::new(WsEndpoint {
665 uri: uri.to_string(),
666 cfg,
667 }))
668 }
669}
670
671pub struct WssComponent {
672 pub(crate) config: WsConfig,
673}
674
675impl WssComponent {
676 pub fn new() -> Self {
677 Self {
678 config: WsConfig::default(),
679 }
680 }
681
682 pub fn with_config(config: WsConfig) -> Self {
683 Self { config }
684 }
685}
686
687impl Default for WssComponent {
688 fn default() -> Self {
689 Self::new()
690 }
691}
692
693impl Component for WssComponent {
694 fn scheme(&self) -> &str {
695 "wss"
696 }
697
698 fn create_endpoint(
699 &self,
700 uri: &str,
701 _ctx: &dyn camel_component_api::ComponentContext,
702 ) -> Result<Box<dyn Endpoint>, CamelError> {
703 let mut cfg = WsEndpointConfig::from_uri(uri)?;
704 if let Some(v) = self.config.max_connections {
705 cfg.max_connections = v;
706 }
707 if let Some(v) = self.config.max_message_size {
708 cfg.max_message_size = v;
709 }
710 if let Some(v) = self.config.heartbeat_interval_ms {
711 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
712 }
713 if let Some(v) = self.config.idle_timeout_ms {
714 cfg.idle_timeout = std::time::Duration::from_millis(v);
715 }
716 if let Some(v) = self.config.connect_timeout_ms {
717 cfg.connect_timeout = std::time::Duration::from_millis(v);
718 }
719 if let Some(v) = self.config.response_timeout_ms {
720 cfg.response_timeout = std::time::Duration::from_millis(v);
721 }
722 Ok(Box::new(WsEndpoint {
723 uri: uri.to_string(),
724 cfg,
725 }))
726 }
727}
728
729struct WsEndpoint {
730 uri: String,
731 cfg: WsEndpointConfig,
732}
733
734impl Endpoint for WsEndpoint {
735 fn uri(&self) -> &str {
736 &self.uri
737 }
738
739 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
740 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
741 }
742
743 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
744 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
745 }
746}
747
748pub struct WsConsumer {
749 cfg: WsServerConfig,
750 registry: Arc<WsConnectionRegistry>,
751 server_state: Option<WsAppState>,
752 registry_key: Option<(String, u16, String)>,
753 forward_task: Option<JoinHandle<()>>,
754}
755
756impl WsConsumer {
757 pub fn new(cfg: WsServerConfig) -> Self {
758 Self {
759 cfg,
760 registry: Arc::new(WsConnectionRegistry::new()),
761 server_state: None,
762 registry_key: None,
763 forward_task: None,
764 }
765 }
766}
767
768#[async_trait]
769impl Consumer for WsConsumer {
770 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
771 if self.server_state.is_some() {
773 return Err(CamelError::EndpointCreationFailed(
774 "WebSocket consumer already started".into(),
775 ));
776 }
777
778 tracing::info!(
779 host = self.cfg.inner.host,
780 port = self.cfg.inner.port,
781 path = self.cfg.inner.path,
782 scheme = self.cfg.inner.scheme,
783 "WebSocket consumer starting"
784 );
785
786 let tls_config = if self.cfg.inner.scheme == "wss" {
787 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
788 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
789 })?;
790 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
791 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
792 })?;
793 Some(WsTlsConfig {
794 cert_path,
795 key_path,
796 })
797 } else {
798 None
799 };
800
801 let state = ServerRegistry::global()
802 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
803 .await?;
804
805 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
806 {
807 let mut table = state.dispatch.write().await;
808 table.insert(self.cfg.inner.path.clone(), env_tx);
809 }
810
811 state.path_configs.insert(
812 self.cfg.inner.path.clone(),
813 WsPathConfig {
814 max_connections: self.cfg.inner.max_connections,
815 max_message_size: self.cfg.inner.max_message_size,
816 heartbeat_interval: self.cfg.inner.heartbeat_interval,
817 idle_timeout: self.cfg.inner.idle_timeout,
818 allow_origin: self.cfg.inner.allow_origin.clone(),
819 },
820 );
821
822 let registry_key = (
823 self.cfg.inner.canonical_host(),
824 self.cfg.inner.port,
825 self.cfg.inner.path.clone(),
826 );
827 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
828
829 let sender = ctx.sender();
830 let forward_task = tokio::spawn(async move {
831 while let Some(envelope) = env_rx.recv().await {
832 if sender.send(envelope).await.is_err() {
833 break;
834 }
835 }
836 });
837
838 self.server_state = Some(state);
839 self.registry_key = Some(registry_key);
840 self.forward_task = Some(forward_task);
841 Ok(())
842 }
843
844 async fn stop(&mut self) -> Result<(), CamelError> {
845 tracing::info!(
846 host = self.cfg.inner.host,
847 port = self.cfg.inner.port,
848 path = self.cfg.inner.path,
849 "WebSocket consumer stopping"
850 );
851
852 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
853 code: axum::extract::ws::CloseCode::from(1001u16),
854 reason: "consumer stopping".into(),
855 }));
856 for tx in self.registry.snapshot_senders() {
857 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
858 }
859
860 let mut had_server_error = false;
861
862 if let Some(state) = self.server_state.take() {
863 had_server_error = state.server_error.load(Ordering::Relaxed);
864 let mut table = state.dispatch.write().await;
865 table.remove(&self.cfg.inner.path);
866 state.path_configs.remove(&self.cfg.inner.path);
867 }
868
869 if let Some(key) = self.registry_key.take() {
870 global_registries().remove(&key);
871 ServerRegistry::global().release(key.1);
872 }
873
874 if let Some(task) = self.forward_task.take() {
875 task.abort();
876 }
877
878 tracing::info!(
879 host = self.cfg.inner.host,
880 port = self.cfg.inner.port,
881 path = self.cfg.inner.path,
882 "WebSocket consumer stopped"
883 );
884
885 if had_server_error {
886 tracing::warn!(
887 host = self.cfg.inner.host,
888 port = self.cfg.inner.port,
889 path = self.cfg.inner.path,
890 "WebSocket server had errors during its lifetime"
891 );
892 return Err(CamelError::ProcessorError(
893 "WebSocket server terminated with errors during its lifetime".into(),
894 ));
895 }
896
897 Ok(())
898 }
899
900 fn concurrency_model(&self) -> ConcurrencyModel {
901 ConcurrencyModel::Concurrent {
902 max: Some(self.cfg.inner.max_connections as usize),
903 }
904 }
905}
906
907use std::sync::atomic::{AtomicBool, Ordering};
908
/// Returns a new, unshared flag whose initial value is `false`.
fn new_atomic_false() -> Arc<AtomicBool> {
    let cleared = AtomicBool::new(false);
    Arc::new(cleared)
}
912
913#[derive(Clone)]
914pub struct WsProducer {
915 cfg: WsClientConfig,
916 backpressure_flag: Arc<AtomicBool>,
919}
920
921impl WsProducer {
922 pub fn new(cfg: WsClientConfig) -> Self {
923 Self {
924 cfg,
925 backpressure_flag: Arc::new(AtomicBool::new(false)),
926 }
927 }
928}
929
930impl Service<Exchange> for WsProducer {
931 type Response = Exchange;
932 type Error = CamelError;
933 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
934
935 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
936 if self.backpressure_flag.swap(false, Ordering::Relaxed) {
938 return Poll::Ready(Err(CamelError::ProcessorError(
939 "WebSocket producer backpressure: previous send was dropped due to full channel"
940 .into(),
941 )));
942 }
943 Poll::Ready(Ok(()))
944 }
945
946 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
947 let cfg = self.cfg.clone();
948 let backpressure_flag = Arc::clone(&self.backpressure_flag);
949
950 Box::pin(async move {
951 let canonical_host = cfg.inner.canonical_host();
952 let key = (
953 canonical_host.clone(),
954 cfg.inner.port,
955 cfg.inner.path.clone(),
956 );
957
958 let send_to_all = exchange
959 .input
960 .header("CamelWsSendToAll")
961 .and_then(|v| v.as_bool())
962 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
963 .unwrap_or(false);
964
965 let conn_keys_header = exchange
966 .input
967 .header("CamelWsConnectionKey")
968 .and_then(|v| v.as_str())
969 .map(str::to_string);
970
971 let local_exists = global_registries().contains_key(&key);
972 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
973
974 let message_type = exchange
975 .input
976 .header("CamelWsMessageType")
977 .and_then(|v| v.as_str())
978 .unwrap_or("text")
979 .to_ascii_lowercase();
980
981 if server_send_mode {
982 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
983 let Some(registry) = registry else {
984 return Err(CamelError::ProcessorError(format!(
985 "WebSocket local consumer not found for {}:{}{}",
986 canonical_host, cfg.inner.port, cfg.inner.path
987 )));
988 };
989
990 let out_msg = body_to_axum_ws_message(
991 std::mem::take(&mut exchange.input.body),
992 &message_type,
993 )
994 .await?;
995
996 let targets = if send_to_all {
997 registry.snapshot_senders()
998 } else if let Some(keys) = conn_keys_header {
999 let parsed: Vec<String> = keys
1000 .split(',')
1001 .map(str::trim)
1002 .filter(|k| !k.is_empty())
1003 .map(|k| k.to_string())
1004 .collect();
1005 registry.get_senders_for_keys(&parsed)
1006 } else {
1007 registry.snapshot_senders()
1008 };
1009
1010 let mut dropped = 0usize;
1011 for tx in &targets {
1012 if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1013 dropped += 1;
1014 }
1015 }
1016
1017 if dropped > 0 {
1018 tracing::warn!(
1019 host = canonical_host,
1020 port = cfg.inner.port,
1021 path = cfg.inner.path,
1022 dropped,
1023 total = targets.len(),
1024 "WebSocket producer dropped messages due to backpressure"
1025 );
1026 exchange.input.set_header(
1027 "CamelWsDeliveryDropped",
1028 serde_json::Value::Number(dropped.into()),
1029 );
1030 backpressure_flag.store(true, Ordering::Relaxed);
1032 if dropped == targets.len() {
1033 return Err(CamelError::ProcessorError(format!(
1034 "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1035 )));
1036 }
1037 }
1038
1039 tracing::debug!(
1040 host = canonical_host,
1041 port = cfg.inner.port,
1042 path = cfg.inner.path,
1043 targets = targets.len(),
1044 "WebSocket producer server-send complete"
1045 );
1046
1047 return Ok(exchange);
1048 }
1049
1050 let url = format!(
1051 "{}://{}:{}{}",
1052 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1053 );
1054
1055 tracing::debug!(url = url, "WebSocket producer connecting");
1056
1057 #[allow(unused_mut)]
1058 let mut request = url
1059 .clone()
1060 .into_client_request()
1061 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1062
1063 #[cfg(feature = "otel")]
1064 {
1065 let mut otel_headers = HashMap::new();
1066 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1067 for (k, v) in otel_headers {
1068 if let (Ok(name), Ok(val)) = (
1069 http::header::HeaderName::from_bytes(k.as_bytes()),
1070 http::header::HeaderValue::from_str(&v),
1071 ) {
1072 request.headers_mut().insert(name, val);
1073 }
1074 }
1075 }
1076
1077 let max_retries = 3usize;
1079 let mut retries_left = max_retries;
1080 let mut last_err: Option<CamelError> = None;
1081 let mut ws_stream = loop {
1082 let connect_future = tokio_tungstenite::connect_async(request.clone());
1083 match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
1084 Ok(Ok((stream, _))) => break stream,
1085 Ok(Err(e)) => {
1086 let err = map_connect_error(e, &url);
1087 let is_transient = err.to_string().contains("connection refused")
1089 || err.to_string().contains("timeout");
1090 if retries_left > 0 && is_transient {
1091 tracing::warn!(
1092 url = url,
1093 error = %err,
1094 retries_left,
1095 "WebSocket connect failed — retrying"
1096 );
1097 last_err = Some(err);
1098 retries_left -= 1;
1099 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1100 continue;
1101 }
1102 return Err(err);
1103 }
1104 Err(_) => {
1105 let err = CamelError::ProcessorError(format!(
1106 "WebSocket connect timeout ({:?}) to {url}",
1107 cfg.inner.connect_timeout
1108 ));
1109 if retries_left > 0 {
1110 tracing::warn!(
1111 url = url,
1112 retries_left,
1113 "WebSocket connect timeout — retrying"
1114 );
1115 last_err = Some(err);
1116 retries_left -= 1;
1117 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1118 continue;
1119 }
1120 return Err(err);
1121 }
1122 }
1123 };
1124 if let Some(ref _err) = last_err {
1125 tracing::info!(url = url, "WebSocket producer connected after retry");
1126 }
1127
1128 let out_msg =
1129 body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
1130 .await?;
1131
1132 ws_stream
1133 .send(out_msg)
1134 .await
1135 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1136
1137 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1138 loop {
1139 match ws_stream.next().await {
1140 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1141 continue;
1142 }
1143 other => break other,
1144 }
1145 }
1146 })
1147 .await
1148 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1149
1150 match incoming {
1151 Some(Ok(ClientWsMessage::Text(text))) => {
1152 tracing::debug!(url = url, "WebSocket producer received text response");
1153 exchange.input.body = CamelBody::Text(text.to_string());
1154 }
1155 Some(Ok(ClientWsMessage::Binary(data))) => {
1156 tracing::debug!(url = url, "WebSocket producer received binary response");
1157 exchange.input.body = CamelBody::Bytes(data);
1158 }
1159 Some(Ok(ClientWsMessage::Close(frame))) => {
1160 let normal = frame
1161 .as_ref()
1162 .map(|f| {
1163 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1164 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1165 })
1166 .unwrap_or(true);
1167
1168 if normal {
1169 tracing::debug!(url = url, "WebSocket producer received normal close");
1170 exchange.input.body = CamelBody::Empty;
1171 } else {
1172 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1173 return Err(CamelError::ProcessorError(format!(
1174 "WebSocket peer closed: code {code}"
1175 )));
1176 }
1177 }
1178 Some(Ok(_)) | None => {
1179 exchange.input.body = CamelBody::Empty;
1180 }
1181 Some(Err(e)) => {
1182 return Err(CamelError::ProcessorError(format!(
1183 "WebSocket receive failed: {e}"
1184 )));
1185 }
1186 }
1187
1188 let _ = ws_stream.close(None).await;
1189 tracing::debug!(url = url, "WebSocket producer connection closed");
1190 Ok(exchange)
1191 })
1192 }
1193}
1194
1195async fn body_to_axum_ws_message(
1196 body: CamelBody,
1197 message_type: &str,
1198) -> Result<WsMessage, CamelError> {
1199 match message_type {
1200 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1201 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1202 }
1203}
1204
1205async fn body_to_client_ws_message(
1206 body: CamelBody,
1207 message_type: &str,
1208) -> Result<ClientWsMessage, CamelError> {
1209 match message_type {
1210 "binary" => Ok(ClientWsMessage::Binary(
1211 body.into_bytes(10 * 1024 * 1024).await?,
1212 )),
1213 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1214 }
1215}
1216
1217async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1218 Ok(match body {
1219 CamelBody::Empty => String::new(),
1220 CamelBody::Text(s) => s,
1221 CamelBody::Xml(s) => s,
1222 CamelBody::Json(v) => v.to_string(),
1223 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1224 CamelBody::Stream(stream) => {
1225 let bytes = CamelBody::Stream(stream)
1226 .into_bytes(10 * 1024 * 1024)
1227 .await?;
1228 String::from_utf8_lossy(&bytes).to_string()
1229 }
1230 })
1231}
1232
/// Decides whether a request's `Origin` value satisfies the configured policy.
///
/// The wildcard policy `"*"` accepts every request, including one with no origin.
/// Any other policy requires the request origin to be present and to match the
/// configured value exactly (byte-for-byte comparison).
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    allowed_origin == "*" || request_origin == Some(allowed_origin)
}
1239
1240fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1241 match tx.try_send(msg) {
1242 Ok(()) => true,
1243 Err(error) => {
1244 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1245 false
1246 }
1247 }
1248}
1249
1250fn load_tls_config(
1251 cert_path: &str,
1252 key_path: &str,
1253) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1254 use std::fs::File;
1255 use std::io::BufReader;
1256
1257 let cert_file = File::open(cert_path)
1258 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1259 let key_file = File::open(key_path)
1260 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1261
1262 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1263 .collect::<Result<Vec<_>, _>>()
1264 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1265
1266 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1267 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1268 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1269
1270 tokio_rustls::rustls::ServerConfig::builder()
1271 .with_no_client_auth()
1272 .with_single_cert(certs, key)
1273 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1274}
1275
1276fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1277 match err {
1278 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1279 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1280 }
1281 tungstenite::Error::Tls(_) => {
1282 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1283 }
1284 other => {
1285 let msg = other.to_string();
1286 if msg.to_lowercase().contains("connection refused") {
1287 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1288 } else if msg.to_lowercase().contains("tls") {
1289 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1290 } else {
1291 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1292 }
1293 }
1294 }
1295}
1296
1297#[cfg(test)]
1298mod tests {
1299 use super::*;
1300 use camel_component_api::NoOpComponentContext;
1301 use std::time::Duration;
1302
1303 use tokio::sync::mpsc;
1304 use tokio_tungstenite::connect_async;
1305 use tokio_tungstenite::tungstenite::Message as ClientMessage;
1306 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1307 use tokio_util::sync::CancellationToken;
1308 use tower::ServiceExt;
1309
1310 fn free_port() -> u16 {
1311 std::net::TcpListener::bind("127.0.0.1:0")
1312 .unwrap()
1313 .local_addr()
1314 .unwrap()
1315 .port()
1316 }
1317
1318 #[test]
1319 fn ws_component_scheme_is_ws() {
1320 assert_eq!(WsComponent::new().scheme(), "ws");
1321 }
1322
1323 #[test]
1324 fn wss_component_scheme_is_wss() {
1325 assert_eq!(WssComponent::new().scheme(), "wss");
1326 }
1327
1328 #[test]
1329 fn endpoint_config_defaults_match_spec() {
1330 let cfg = WsEndpointConfig::default();
1331 assert_eq!(cfg.scheme, "ws");
1332 assert_eq!(cfg.host, "0.0.0.0");
1333 assert_eq!(cfg.port, 8080);
1334 assert_eq!(cfg.path, "/");
1335 assert_eq!(cfg.max_connections, 100);
1336 assert_eq!(cfg.max_message_size, 65536);
1337 assert!(!cfg.send_to_all);
1338 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1339 assert_eq!(cfg.idle_timeout, Duration::ZERO);
1340 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1341 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1342 assert_eq!(cfg.allow_origin, "*");
1343 assert_eq!(cfg.tls_cert, None);
1344 assert_eq!(cfg.tls_key, None);
1345 }
1346
1347 #[test]
1348 fn endpoint_config_parses_uri_params() {
1349 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1350 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1351
1352 assert_eq!(cfg.scheme, "ws");
1353 assert_eq!(cfg.host, "localhost");
1354 assert_eq!(cfg.port, 9001);
1355 assert_eq!(cfg.path, "/chat");
1356 assert_eq!(cfg.max_connections, 42);
1357 assert_eq!(cfg.max_message_size, 1024);
1358 assert!(cfg.send_to_all);
1359 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1360 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1361 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1362 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1363 assert_eq!(cfg.allow_origin, "https://example.com");
1364 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1365 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1366 }
1367
1368 #[test]
1369 fn endpoint_config_override_chain_uri_overrides_defaults() {
1370 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1371 assert_eq!(cfg.max_connections, 7);
1372 assert_eq!(cfg.max_message_size, 65536);
1373 assert!(!cfg.send_to_all);
1374 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1375 }
1376
1377 #[test]
1378 fn endpoint_trait_creates_consumer_and_producer() {
1379 let ctx = NoOpComponentContext;
1380 let endpoint = WsComponent::new()
1381 .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1382 .unwrap();
1383
1384 endpoint.create_consumer().unwrap();
1385 endpoint
1386 .create_producer(&ProducerContext::default())
1387 .unwrap();
1388 }
1389
1390 #[test]
1391 fn ws_consumer_concurrency_model_uses_max_connections() {
1392 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1393 let consumer = WsConsumer::new(cfg.server_config());
1394 assert_eq!(
1395 consumer.concurrency_model(),
1396 ConcurrencyModel::Concurrent { max: Some(321) }
1397 );
1398 }
1399
1400 #[tokio::test]
1401 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1402 let registry = WsConnectionRegistry::new();
1403 let (tx1, mut rx1) = mpsc::channel(8);
1404 let (tx2, mut rx2) = mpsc::channel(8);
1405
1406 registry.insert("k1".into(), tx1);
1407 registry.insert("k2".into(), tx2);
1408 assert_eq!(registry.len(), 2);
1409
1410 for tx in registry.snapshot_senders() {
1411 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1412 }
1413
1414 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1415 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1416
1417 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1418 assert_eq!(target.len(), 1);
1419 target[0]
1420 .send(WsMessage::Text("targeted".into()))
1421 .await
1422 .unwrap();
1423
1424 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1425 assert!(
1426 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1427 .await
1428 .is_err()
1429 );
1430
1431 registry.remove("k1");
1432 assert_eq!(registry.len(), 1);
1433 }
1434
1435 #[test]
1436 fn host_canonicalization_maps_local_hosts_to_loopback() {
1437 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1438 .unwrap()
1439 .canonical_host();
1440 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1441 .unwrap()
1442 .canonical_host();
1443 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1444 .unwrap()
1445 .canonical_host();
1446
1447 assert_eq!(c1, "127.0.0.1");
1448 assert_eq!(c2, "127.0.0.1");
1449 assert_eq!(c3, "127.0.0.1");
1450 }
1451
1452 #[tokio::test]
1453 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1454 let port = free_port();
1455 let uri = format!("ws://127.0.0.1:{port}/echo");
1456 let component_ctx = NoOpComponentContext;
1457 let endpoint = WsComponent::new()
1458 .create_endpoint(&uri, &component_ctx)
1459 .unwrap();
1460
1461 let mut consumer = endpoint.create_consumer().unwrap();
1462 let producer = endpoint
1463 .create_producer(&ProducerContext::default())
1464 .unwrap();
1465
1466 let (route_tx, mut route_rx) = mpsc::channel(16);
1467 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1468 consumer.start(ctx).await.unwrap();
1469
1470 let route_task = tokio::spawn(async move {
1471 if let Some(envelope) = route_rx.recv().await {
1472 let payload = envelope
1473 .exchange
1474 .input
1475 .body
1476 .as_text()
1477 .unwrap_or_default()
1478 .to_string();
1479 let key = envelope
1480 .exchange
1481 .input
1482 .header("CamelWsConnectionKey")
1483 .and_then(|v| v.as_str())
1484 .unwrap()
1485 .to_string();
1486
1487 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1488 response
1489 .input
1490 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1491 producer.oneshot(response).await.unwrap();
1492 }
1493 });
1494
1495 let url = format!("ws://127.0.0.1:{port}/echo");
1496 let (mut client, _) = loop {
1497 match connect_async(&url).await {
1498 Ok(ok) => break ok,
1499 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1500 }
1501 };
1502
1503 client
1504 .send(ClientMessage::Text("hello-ws".into()))
1505 .await
1506 .unwrap();
1507
1508 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1509 loop {
1510 match client.next().await {
1511 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1512 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1513 Some(Ok(_)) => continue,
1514 Some(Err(e)) => panic!("ws receive failed: {e}"),
1515 None => panic!("websocket closed before echo"),
1516 }
1517 }
1518 })
1519 .await
1520 .unwrap();
1521
1522 assert_eq!(incoming, "hello-ws");
1523
1524 consumer.stop().await.unwrap();
1525 route_task.await.unwrap();
1526 }
1527
1528 #[tokio::test]
1529 async fn consumer_stop_sends_close_1001() {
1530 let port = free_port();
1531 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1532 let component_ctx = NoOpComponentContext;
1533 let endpoint = WsComponent::new()
1534 .create_endpoint(&uri, &component_ctx)
1535 .unwrap();
1536
1537 let mut consumer = endpoint.create_consumer().unwrap();
1538 let (route_tx, _route_rx) = mpsc::channel(16);
1539 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1540 consumer.start(ctx).await.unwrap();
1541
1542 let url = format!("ws://127.0.0.1:{port}/shutdown");
1543 let (mut client, _) = loop {
1544 match connect_async(&url).await {
1545 Ok(ok) => break ok,
1546 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1547 }
1548 };
1549
1550 client
1551 .send(ClientMessage::Text("keepalive".into()))
1552 .await
1553 .unwrap();
1554
1555 consumer.stop().await.unwrap();
1556
1557 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1558 loop {
1559 match client.next().await {
1560 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1561 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1562 Some(Ok(_)) => continue,
1563 Some(Err(e)) => panic!("ws receive failed: {e}"),
1564 None => panic!("websocket closed without close frame"),
1565 }
1566 }
1567 })
1568 .await
1569 .unwrap();
1570
1571 assert_eq!(close_code, Some(CloseCode::Away));
1572 }
1573
1574 #[test]
1575 fn wildcard_origin_allows_anything() {
1576 assert!(is_origin_allowed("*", None));
1577 assert!(is_origin_allowed("*", Some("https://example.com")));
1578 }
1579
1580 #[test]
1581 fn exact_origin_requires_match() {
1582 assert!(is_origin_allowed(
1583 "https://example.com",
1584 Some("https://example.com")
1585 ));
1586 assert!(!is_origin_allowed(
1587 "https://example.com",
1588 Some("https://other.com")
1589 ));
1590 assert!(!is_origin_allowed("https://example.com", None));
1591 }
1592
1593 #[test]
1594 fn endpoint_config_rejects_invalid_scheme() {
1595 let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1596 assert!(result.is_err());
1597 let msg = result.unwrap_err().to_string();
1598 assert!(
1599 msg.contains("Invalid WebSocket scheme"),
1600 "expected scheme error, got: {msg}"
1601 );
1602 }
1603
1604 #[tokio::test]
1605 async fn wss_consumer_start_fails_without_tls_cert() {
1606 let port = free_port();
1607 let component_ctx = NoOpComponentContext;
1608 let endpoint = WssComponent::new()
1609 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1610 .unwrap();
1611 let mut consumer = endpoint.create_consumer().unwrap();
1612 let (tx, _rx) = mpsc::channel(16);
1613 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1614 let result = consumer.start(ctx).await;
1615 assert!(result.is_err());
1616 let msg = result.unwrap_err().to_string();
1617 assert!(
1618 msg.contains("TLS cert path is required"),
1619 "expected TLS cert error, got: {msg}"
1620 );
1621 }
1622
1623 #[tokio::test]
1624 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1625 let port = free_port();
1626 let component_ctx = NoOpComponentContext;
1627 let endpoint = WssComponent::new()
1628 .create_endpoint(&format!(
1629 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1630 ), &component_ctx)
1631 .unwrap();
1632 let mut consumer = endpoint.create_consumer().unwrap();
1633 let (tx, _rx) = mpsc::channel(16);
1634 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1635 let result = consumer.start(ctx).await;
1636 assert!(result.is_err());
1637 let msg = result.unwrap_err().to_string();
1638 assert!(
1639 msg.contains("TLS cert file error"),
1640 "expected cert file error, got: {msg}"
1641 );
1642 }
1643
1644 #[tokio::test]
1645 async fn server_registry_returns_same_state_for_same_port() {
1646 let port = free_port();
1647 let state1 = ServerRegistry::global()
1648 .get_or_spawn("127.0.0.1", port, None)
1649 .await
1650 .unwrap();
1651 let state2 = ServerRegistry::global()
1652 .get_or_spawn("127.0.0.1", port, None)
1653 .await
1654 .unwrap();
1655 assert!(
1656 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1657 "expected same dispatch table for same port"
1658 );
1659 }
1660
1661 #[tokio::test]
1662 async fn dispatch_handler_returns_404_for_unregistered_path() {
1663 let port = free_port();
1664 let state = ServerRegistry::global()
1665 .get_or_spawn("127.0.0.1", port, None)
1666 .await
1667 .unwrap();
1668 let app = Router::new().fallback(dispatch_handler).with_state(state);
1669 let response = tokio::time::timeout(
1670 Duration::from_secs(2),
1671 tower::ServiceExt::oneshot(
1672 app,
1673 axum::http::Request::builder()
1674 .uri("/nonexistent")
1675 .body(Body::empty())
1676 .unwrap(),
1677 ),
1678 )
1679 .await
1680 .unwrap()
1681 .unwrap();
1682 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1683 }
1684
1685 #[tokio::test]
1686 async fn client_mode_producer_connects_and_echoes() {
1687 let port = free_port();
1688
1689 let app = Router::new().route(
1690 "/echo",
1691 axum::routing::get(|ws: WebSocketUpgrade| async move {
1692 ws.on_upgrade(|mut socket: WebSocket| async move {
1693 while let Some(Ok(msg)) = socket.recv().await {
1694 match msg {
1695 WsMessage::Text(text) => {
1696 let _ = socket.send(WsMessage::Text(text)).await;
1697 }
1698 WsMessage::Binary(data) => {
1699 let _ = socket.send(WsMessage::Binary(data)).await;
1700 }
1701 WsMessage::Close(_) => break,
1702 _ => {}
1703 }
1704 }
1705 })
1706 }),
1707 );
1708 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
1709 .await
1710 .unwrap();
1711 let server_task = tokio::spawn(async move {
1712 let _ = serve(listener, app).await;
1713 });
1714
1715 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1716 let producer = WsProducer::new(cfg.client_config());
1717
1718 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1719 tokio::time::sleep(Duration::from_millis(25)).await;
1720 let result =
1721 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1722 Ok(Ok(r)) => r,
1723 Ok(Err(_)) => panic!("producer call failed"),
1724 Err(_) => panic!("producer call timed out"),
1725 };
1726
1727 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1728
1729 server_task.abort();
1730 }
1731
1732 #[tokio::test]
1733 async fn max_connections_rejects_with_close_1013() {
1734 let port = free_port();
1735 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1736 let component_ctx = NoOpComponentContext;
1737 let endpoint = WsComponent::new()
1738 .create_endpoint(&uri, &component_ctx)
1739 .unwrap();
1740 let mut consumer = endpoint.create_consumer().unwrap();
1741 let (route_tx, _route_rx) = mpsc::channel(16);
1742 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1743 consumer.start(ctx).await.unwrap();
1744
1745 let url = format!("ws://127.0.0.1:{port}/limited");
1746 let (_client1, _) = loop {
1747 match connect_async(&url).await {
1748 Ok(ok) => break ok,
1749 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1750 }
1751 };
1752
1753 tokio::time::sleep(Duration::from_millis(100)).await;
1754
1755 let (mut client2, _) = connect_async(&url).await.unwrap();
1756
1757 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1758 loop {
1759 match client2.next().await {
1760 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1761 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1762 Some(Ok(ClientMessage::Text(_))) => continue,
1763 Some(Ok(_)) => continue,
1764 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1765 None => panic!("client2 closed without close frame"),
1766 }
1767 }
1768 })
1769 .await
1770 .unwrap();
1771
1772 assert_eq!(
1773 close_code,
1774 Some(CloseCode::from(1013u16)),
1775 "expected 1013 (Try Again Later) for max connections"
1776 );
1777
1778 consumer.stop().await.unwrap();
1779 }
1780
1781 #[tokio::test]
1782 async fn max_message_size_rejects_with_close_1009() {
1783 let port = free_port();
1784 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1785 let component_ctx = NoOpComponentContext;
1786 let endpoint = WsComponent::new()
1787 .create_endpoint(&uri, &component_ctx)
1788 .unwrap();
1789 let mut consumer = endpoint.create_consumer().unwrap();
1790 let (route_tx, _route_rx) = mpsc::channel(16);
1791 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1792 consumer.start(ctx).await.unwrap();
1793
1794 let url = format!("ws://127.0.0.1:{port}/sizelimit");
1795 let (mut client, _) = loop {
1796 match connect_async(&url).await {
1797 Ok(ok) => break ok,
1798 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1799 }
1800 };
1801
1802 let oversized = "x".repeat(100);
1803 client
1804 .send(ClientMessage::Text(oversized.into()))
1805 .await
1806 .unwrap();
1807
1808 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1809 loop {
1810 match client.next().await {
1811 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1812 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1813 Some(Ok(_)) => continue,
1814 Some(Err(e)) => panic!("ws receive failed: {e}"),
1815 None => panic!("websocket closed without close frame"),
1816 }
1817 }
1818 })
1819 .await
1820 .unwrap();
1821
1822 assert_eq!(
1823 close_code,
1824 Some(CloseCode::from(1009u16)),
1825 "expected 1009 (Message Too Big) for oversized message"
1826 );
1827
1828 consumer.stop().await.unwrap();
1829 }
1830
1831 #[tokio::test]
1832 async fn origin_rejection_returns_403() {
1833 let port = free_port();
1834 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1835 let component_ctx = NoOpComponentContext;
1836 let endpoint = WsComponent::new()
1837 .create_endpoint(&uri, &component_ctx)
1838 .unwrap();
1839 let mut consumer = endpoint.create_consumer().unwrap();
1840 let (route_tx, _route_rx) = mpsc::channel(16);
1841 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1842 consumer.start(ctx).await.unwrap();
1843
1844 let state = ServerRegistry::global()
1845 .get_or_spawn("127.0.0.1", port, None)
1846 .await
1847 .unwrap();
1848 let app = Router::new().fallback(dispatch_handler).with_state(state);
1849
1850 let response = tokio::time::timeout(
1851 Duration::from_secs(2),
1852 tower::ServiceExt::oneshot(
1853 app,
1854 axum::http::Request::builder()
1855 .uri("/origintest")
1856 .header("origin", "https://evil.com")
1857 .header("upgrade", "websocket")
1858 .header("connection", "Upgrade")
1859 .header("sec-websocket-version", "13")
1860 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1861 .body(Body::empty())
1862 .unwrap(),
1863 ),
1864 )
1865 .await
1866 .unwrap()
1867 .unwrap();
1868
1869 assert_eq!(
1870 response.status(),
1871 StatusCode::FORBIDDEN,
1872 "expected 403 for disallowed origin"
1873 );
1874
1875 consumer.stop().await.unwrap();
1876 }
1877
1878 #[tokio::test]
1879 async fn broadcast_sends_to_all_connected_clients() {
1880 let port = free_port();
1881 let uri = format!("ws://127.0.0.1:{port}/bc");
1882 let component_ctx = NoOpComponentContext;
1883 let endpoint = WsComponent::new()
1884 .create_endpoint(&uri, &component_ctx)
1885 .unwrap();
1886 let mut consumer = endpoint.create_consumer().unwrap();
1887 let producer = endpoint
1888 .create_producer(&ProducerContext::default())
1889 .unwrap();
1890
1891 let (route_tx, _route_rx) = mpsc::channel(16);
1892 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1893 consumer.start(ctx).await.unwrap();
1894
1895 let url = format!("ws://127.0.0.1:{port}/bc");
1896
1897 let (mut client1, _) = loop {
1898 match connect_async(&url).await {
1899 Ok(ok) => break ok,
1900 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1901 }
1902 };
1903
1904 let (mut client2, _) = connect_async(&url).await.unwrap();
1905
1906 tokio::time::sleep(Duration::from_millis(100)).await;
1907
1908 let mut response =
1909 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1910 response
1911 .input
1912 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1913 producer.oneshot(response).await.unwrap();
1914
1915 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1916 loop {
1917 match client1.next().await {
1918 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1919 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1920 _ => panic!("client1 unexpected message or close"),
1921 }
1922 }
1923 })
1924 .await
1925 .unwrap();
1926
1927 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
1928 loop {
1929 match client2.next().await {
1930 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1931 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1932 _ => panic!("client2 unexpected message or close"),
1933 }
1934 }
1935 })
1936 .await
1937 .unwrap();
1938
1939 assert_eq!(recv1, "broadcast-msg");
1940 assert_eq!(recv2, "broadcast-msg");
1941
1942 consumer.stop().await.unwrap();
1943 }
1944
1945 #[tokio::test]
1946 async fn concurrent_get_or_spawn_returns_same_state() {
1947 let port = free_port();
1948 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
1949 Arc::new(std::sync::Mutex::new(Vec::new()));
1950
1951 let mut handles = Vec::new();
1952 for _ in 0..4 {
1953 let results = results.clone();
1954 handles.push(tokio::spawn(async move {
1955 let state = ServerRegistry::global()
1956 .get_or_spawn("127.0.0.1", port, None)
1957 .await
1958 .unwrap();
1959 results.lock().unwrap().push(state);
1960 }));
1961 }
1962
1963 for h in handles {
1964 h.await.unwrap();
1965 }
1966
1967 let states = results.lock().unwrap();
1968 assert_eq!(states.len(), 4);
1969 for i in 1..states.len() {
1970 assert!(
1971 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
1972 "all concurrent callers should get the same dispatch table"
1973 );
1974 }
1975 }
1976
1977 #[tokio::test]
1978 async fn body_conversion_helpers_cover_text_and_binary_paths() {
1979 let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
1980 .await
1981 .unwrap();
1982 assert!(matches!(text_msg, WsMessage::Text(_)));
1983
1984 let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
1985 .await
1986 .unwrap();
1987 assert!(matches!(bin_msg, WsMessage::Binary(_)));
1988
1989 let client_text =
1990 body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
1991 .await
1992 .unwrap();
1993 assert!(matches!(client_text, ClientWsMessage::Text(_)));
1994
1995 let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
1996 .await
1997 .unwrap();
1998 assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
1999 }
2000
2001 #[tokio::test]
2002 async fn body_to_text_handles_empty_text_json_and_bytes() {
2003 assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2004 assert_eq!(
2005 body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2006 "hello"
2007 );
2008 assert_eq!(
2009 body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2010 .await
2011 .unwrap(),
2012 "{\"n\":1}"
2013 );
2014 assert_eq!(
2015 body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2016 .await
2017 .unwrap(),
2018 "hi"
2019 );
2020 }
2021
2022 #[test]
2023 fn try_send_with_backpressure_returns_false_when_channel_full() {
2024 let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2025 assert!(try_send_with_backpressure(
2026 &tx,
2027 WsMessage::Text("first".into()),
2028 "test"
2029 ));
2030 assert!(!try_send_with_backpressure(
2031 &tx,
2032 WsMessage::Text("second".into()),
2033 "test"
2034 ));
2035 }
2036
2037 #[test]
2038 fn map_connect_error_formats_connection_refused_and_generic_errors() {
2039 let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2040 let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2041 assert!(err.to_string().contains("WebSocket connection refused"));
2042
2043 let generic = map_connect_error(
2044 tungstenite::Error::Protocol(
2045 tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2046 ),
2047 "ws://localhost:2/y",
2048 );
2049 assert!(
2050 generic
2051 .to_string()
2052 .contains("WebSocket connection failed (ws://localhost:2/y)")
2053 );
2054 }
2055
2056 #[test]
2060 fn from_uri_rejects_max_connections_zero() {
2061 let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2062 assert!(result.is_err());
2063 let msg = result.unwrap_err().to_string();
2064 assert!(
2065 msg.contains("maxConnections must be >= 1"),
2066 "expected maxConnections validation error, got: {msg}"
2067 );
2068 }
2069
2070 #[test]
2072 fn from_uri_rejects_max_message_size_zero() {
2073 let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2074 assert!(result.is_err());
2075 let msg = result.unwrap_err().to_string();
2076 assert!(
2077 msg.contains("maxMessageSize must be > 0"),
2078 "expected maxMessageSize validation error, got: {msg}"
2079 );
2080 }
2081
2082 #[test]
2084 fn from_uri_rejects_empty_allow_origin() {
2085 let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2086 assert!(result.is_err());
2087 let msg = result.unwrap_err().to_string();
2088 assert!(
2089 msg.contains("allowOrigin must not be empty"),
2090 "expected allowOrigin validation error, got: {msg}"
2091 );
2092 }
2093
2094 #[tokio::test]
2096 async fn consumer_double_start_returns_error() {
2097 let port = free_port();
2098 let uri = format!("ws://127.0.0.1:{port}/doublestart");
2099 let component_ctx = NoOpComponentContext;
2100 let endpoint = WsComponent::new()
2101 .create_endpoint(&uri, &component_ctx)
2102 .unwrap();
2103
2104 let mut consumer = endpoint.create_consumer().unwrap();
2105 let (route_tx, _route_rx) = mpsc::channel(16);
2106 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2107
2108 consumer.start(ctx).await.unwrap();
2110
2111 let (route_tx2, _route_rx2) = mpsc::channel(16);
2113 let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2114 let result = consumer.start(ctx2).await;
2115 assert!(result.is_err());
2116 let msg = result.unwrap_err().to_string();
2117 assert!(
2118 msg.contains("already started"),
2119 "expected double-start error, got: {msg}"
2120 );
2121
2122 consumer.stop().await.unwrap();
2123 }
2124
2125 #[tokio::test]
2127 async fn registry_cleanup_on_consumer_stop() {
2128 let port = free_port();
2129 let uri = format!("ws://127.0.0.1:{port}/cleanup");
2130 let component_ctx = NoOpComponentContext;
2131 let endpoint = WsComponent::new()
2132 .create_endpoint(&uri, &component_ctx)
2133 .unwrap();
2134
2135 let mut consumer = endpoint.create_consumer().unwrap();
2136 let (route_tx, _route_rx) = mpsc::channel(16);
2137 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2138 consumer.start(ctx).await.unwrap();
2139
2140 let registries = global_registries();
2142 let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2143 assert!(
2144 registries.contains_key(&key),
2145 "registry should have entry after start"
2146 );
2147
2148 consumer.stop().await.unwrap();
2150
2151 assert!(
2153 !registries.contains_key(&key),
2154 "registry should be cleaned up after stop"
2155 );
2156
2157 let server_reg = ServerRegistry::global();
2160 let guard = server_reg.inner.lock().unwrap();
2161 assert!(
2162 !guard.contains_key(&port),
2163 "ServerRegistry should remove port entry after last consumer stops"
2164 );
2165 }
2166
2167 #[tokio::test]
2169 async fn producer_server_send_returns_error_when_all_dropped() {
2170 let port = free_port();
2171 let uri = format!("ws://127.0.0.1:{port}/backpressure");
2172 let component_ctx = NoOpComponentContext;
2173 let endpoint = WsComponent::new()
2174 .create_endpoint(&uri, &component_ctx)
2175 .unwrap();
2176
2177 let mut consumer = endpoint.create_consumer().unwrap();
2178 let producer = endpoint
2179 .create_producer(&ProducerContext::default())
2180 .unwrap();
2181
2182 let (route_tx, _route_rx) = mpsc::channel(1); let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2184 consumer.start(ctx).await.unwrap();
2185
2186 let url = format!("ws://127.0.0.1:{port}/backpressure");
2188 let (mut client, _) = loop {
2189 match connect_async(&url).await {
2190 Ok(ok) => break ok,
2191 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2192 }
2193 };
2194
2195 tokio::time::sleep(Duration::from_millis(50)).await;
2197
2198 let mut all_dropped = false;
2200 for _ in 0..100 {
2201 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2202 match producer.clone().oneshot(exchange).await {
2203 Ok(_) => {}
2204 Err(e) => {
2205 if e.to_string().contains("backpressure") {
2206 all_dropped = true;
2207 break;
2208 }
2209 }
2210 }
2211 }
2212
2213 assert!(
2215 all_dropped,
2216 "producer should return error when all messages are dropped due to backpressure"
2217 );
2218
2219 let _ = client.close(None).await;
2221 consumer.stop().await.unwrap();
2222 }
2223
2224 #[tokio::test]
2226 async fn server_responds_to_client_ping_with_pong() {
2227 let port = free_port();
2228 let uri = format!("ws://127.0.0.1:{port}/pingpong");
2229 let component_ctx = NoOpComponentContext;
2230 let endpoint = WsComponent::new()
2231 .create_endpoint(&uri, &component_ctx)
2232 .unwrap();
2233
2234 let mut consumer = endpoint.create_consumer().unwrap();
2235 let (route_tx, _route_rx) = mpsc::channel(16);
2236 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2237 consumer.start(ctx).await.unwrap();
2238
2239 let url = format!("ws://127.0.0.1:{port}/pingpong");
2240 let (mut client, _) = loop {
2241 match connect_async(&url).await {
2242 Ok(ok) => break ok,
2243 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2244 }
2245 };
2246
2247 client
2249 .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2250 .await
2251 .unwrap();
2252
2253 let pong = tokio::time::timeout(Duration::from_secs(2), async {
2255 loop {
2256 match client.next().await {
2257 Some(Ok(ClientMessage::Pong(data))) => break data,
2258 Some(Ok(ClientMessage::Ping(_))) => continue,
2259 Some(Ok(_)) => continue,
2260 Some(Err(e)) => panic!("ws receive failed: {e}"),
2261 None => panic!("websocket closed before pong"),
2262 }
2263 }
2264 })
2265 .await
2266 .unwrap();
2267
2268 assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2269
2270 consumer.stop().await.unwrap();
2271 }
2272
2273 #[tokio::test]
2275 async fn producer_retries_on_connection_refused() {
2276 let port = free_port();
2278 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/retry")).unwrap();
2280 let producer = WsProducer::new(cfg.client_config());
2281
2282 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2283
2284 let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2286 assert!(
2287 result.is_ok(),
2288 "producer should complete (with error) within timeout"
2289 );
2290 let result = result.unwrap();
2291 assert!(
2292 result.is_err(),
2293 "producer should fail when nothing is listening"
2294 );
2295 let msg = result.unwrap_err().to_string();
2296 assert!(
2297 msg.contains("connection refused"),
2298 "expected connection refused error, got: {msg}"
2299 );
2300 }
2301
2302 #[tokio::test]
2304 async fn server_bind_error_is_reported() {
2305 let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2307 let port = _listener.local_addr().unwrap().port();
2308
2309 let uri = format!("ws://127.0.0.1:{port}/binderror");
2312 let component_ctx = NoOpComponentContext;
2313 let endpoint = WsComponent::new()
2314 .create_endpoint(&uri, &component_ctx)
2315 .unwrap();
2316
2317 let mut consumer = endpoint.create_consumer().unwrap();
2318 let (route_tx, _route_rx) = mpsc::channel(16);
2319 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2320
2321 let start_result = consumer.start(ctx).await;
2323 let _ = start_result;
2326
2327 consumer.stop().await.unwrap();
2328 }
2329
/// A `WsAppState` whose flag comes from `new_atomic_false()` reports
/// `server_error == false` on first read.
#[test]
fn ws_app_state_server_error_starts_false() {
    let dispatch = Arc::new(RwLock::new(HashMap::new()));
    let path_configs = Arc::new(DashMap::new());
    let state = WsAppState {
        dispatch,
        path_configs,
        server_error: new_atomic_false(),
    };

    let errored = state.server_error.load(Ordering::Relaxed);
    assert!(!errored, "server_error should start as false");
}
2342
/// The `server_error` flag of a `WsAppState` reads `false` initially and
/// reads `true` after `true` has been stored into it.
#[test]
fn ws_app_state_server_error_can_be_set() {
    let dispatch = Arc::new(RwLock::new(HashMap::new()));
    let path_configs = Arc::new(DashMap::new());
    let state = WsAppState {
        dispatch,
        path_configs,
        server_error: new_atomic_false(),
    };

    let flag = &state.server_error;
    assert!(!flag.load(Ordering::Relaxed));
    flag.store(true, Ordering::Relaxed);
    assert!(flag.load(Ordering::Relaxed));
}
2354
2355 #[tokio::test]
2356 async fn consumer_stop_returns_error_when_server_had_errors() {
2357 let port = free_port();
2358 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2359 let mut consumer = WsConsumer::new(cfg.server_config());
2360 let (route_tx, _route_rx) = mpsc::channel(16);
2361 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2362 consumer.start(ctx).await.unwrap();
2363
2364 if let Some(ref state) = consumer.server_state {
2366 state.server_error.store(true, Ordering::Relaxed);
2367 }
2368
2369 let result = consumer.stop().await;
2370 assert!(
2371 result.is_err(),
2372 "stop should return error when server had errors"
2373 );
2374 let msg = result.unwrap_err().to_string();
2375 assert!(
2376 msg.contains("terminated with errors"),
2377 "expected server error message, got: {msg}"
2378 );
2379 }
2380
2381 #[tokio::test]
2382 async fn consumer_stop_succeeds_when_server_healthy() {
2383 let port = free_port();
2384 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2385 let mut consumer = WsConsumer::new(cfg.server_config());
2386 let (route_tx, _route_rx) = mpsc::channel(16);
2387 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2388 consumer.start(ctx).await.unwrap();
2389
2390 let result = consumer.stop().await;
2391 assert!(
2392 result.is_ok(),
2393 "stop should succeed when server is healthy: {:?}",
2394 result
2395 );
2396 }
2397}