1pub mod bundle;
7pub mod config;
8pub mod health;
9
10pub use bundle::WsBundle;
11pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
12pub use health::WsHealthCheck;
13
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex, OnceLock};
16
17use async_trait::async_trait;
18use axum::body::Body;
19use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
20use axum::extract::{FromRequest, Request, State};
21use axum::http::{StatusCode, header};
22use axum::response::IntoResponse;
23use axum::{Router, serve};
24use camel_api::security_policy::AuthorizationDecision;
25use camel_component_api::{
26 Body as CamelBody, BoxProcessor, CamelError, Component, ConcurrencyModel, Consumer,
27 ConsumerContext, Endpoint, Exchange, ExchangeEnvelope, Message as CamelMessage,
28 NetworkRetryPolicy, ProducerContext, retry_async,
29};
30use dashmap::DashMap;
31use futures::{SinkExt, StreamExt};
32use std::future::Future;
33use std::pin::Pin;
34use std::task::{Context, Poll};
35use tokio::sync::{OnceCell, RwLock, mpsc};
36use tokio::task::JoinHandle;
37use tokio_tungstenite::tungstenite;
38use tokio_tungstenite::tungstenite::client::IntoClientRequest;
39use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
40use tower::Service;
41
42#[derive(Clone)]
43pub struct WsPathConfig {
44 pub max_connections: u32,
45 pub max_message_size: u32,
46 pub heartbeat_interval: std::time::Duration,
47 pub idle_timeout: std::time::Duration,
48 pub allow_origin: String,
49}
50
51impl Default for WsPathConfig {
52 fn default() -> Self {
53 let cfg = WsEndpointConfig::default();
54 Self {
55 max_connections: cfg.max_connections,
56 max_message_size: cfg.max_message_size,
57 heartbeat_interval: cfg.heartbeat_interval,
58 idle_timeout: cfg.idle_timeout,
59 allow_origin: cfg.allow_origin,
60 }
61 }
62}
63
64#[derive(Clone)]
65pub struct WsTlsConfig {
66 pub cert_path: String,
67 pub key_path: String,
68}
69
70type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
71
72struct ServerHandle {
73 state: WsAppState,
74 is_tls: bool,
75 _task: JoinHandle<()>,
76}
77
78struct ServerRegistryInner {
79 cell: Arc<OnceCell<ServerHandle>>,
80 ref_count: usize,
81}
82
83pub struct ServerRegistry {
84 inner: Mutex<HashMap<u16, ServerRegistryInner>>,
85}
86
87impl ServerRegistry {
88 pub fn global() -> &'static Self {
89 static REG: OnceLock<ServerRegistry> = OnceLock::new();
90 REG.get_or_init(|| Self {
91 inner: Mutex::new(HashMap::new()),
92 })
93 }
94
95 pub async fn get_or_spawn(
96 &'static self,
97 host: &str,
98 port: u16,
99 tls_config: Option<WsTlsConfig>,
100 ) -> Result<WsAppState, CamelError> {
101 let wants_tls = tls_config.is_some();
102 let host_owned = host.to_string();
103
104 let (cell, _is_new) = {
105 let mut guard = self.inner.lock().map_err(|_| {
106 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
107 })?;
108 let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
109 cell: Arc::new(OnceCell::new()),
110 ref_count: 0,
111 });
112 entry.ref_count += 1;
113 (entry.cell.clone(), entry.ref_count == 1)
114 };
115
116 let handle = cell
117 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
118 .await?;
119
120 if wants_tls != handle.is_tls {
121 let mut guard = self.inner.lock().map_err(|_| {
123 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
124 })?;
125 if let Some(entry) = guard.get_mut(&port) {
126 entry.ref_count -= 1;
127 if entry.ref_count == 0 {
128 guard.remove(&port);
129 }
130 }
131 return Err(CamelError::EndpointCreationFailed(format!(
132 "Server on port {port} already running with different TLS mode"
133 )));
134 }
135
136 Ok(handle.state.clone())
137 }
138
139 pub(crate) fn release(&self, port: u16) {
142 let mut guard = match self.inner.lock() {
143 Ok(g) => g,
144 Err(_) => return,
145 };
146 if let Some(entry) = guard.get_mut(&port) {
147 entry.ref_count = entry.ref_count.saturating_sub(1);
148 if entry.ref_count == 0 {
149 if let Some(handle) = entry.cell.get() {
151 handle._task.abort();
152 }
153 guard.remove(&port);
154 tracing::info!(port, "WebSocket server registry entry removed");
155 }
156 }
157 }
158}
159
160async fn spawn_server(
161 host: &str,
162 port: u16,
163 tls_config: Option<WsTlsConfig>,
164) -> Result<ServerHandle, CamelError> {
165 let host_owned = host.to_string();
166 let addr = format!("{host}:{port}");
167 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
168 let path_configs = Arc::new(DashMap::new());
169 let path_policies = Arc::new(DashMap::new());
170 let server_error = new_atomic_false();
171 let state = WsAppState {
172 dispatch: Arc::clone(&dispatch),
173 path_configs: Arc::clone(&path_configs),
174 path_policies: Arc::clone(&path_policies),
175 server_error: Arc::clone(&server_error),
176 };
177 let app = Router::new()
178 .fallback(dispatch_handler)
179 .with_state(state.clone());
180
181 let (task, is_tls) = if let Some(ref tls) = tls_config {
182 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
183 let parsed_addr = addr.parse().map_err(|e| {
184 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
185 })?;
186 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
187 let error_flag = Arc::clone(&server_error);
188 let task = tokio::spawn(async move {
189 if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
190 .serve(app.into_make_service())
191 .await
192 {
193 tracing::error!(
194 host = host_owned,
195 port = port,
196 error = %e,
197 "WebSocket server terminated with error"
198 );
199 error_flag.store(true, Ordering::Relaxed);
200 }
201 });
202 (task, true)
203 } else {
204 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
205 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
206 })?;
207 let error_flag = Arc::clone(&server_error);
208 let task = tokio::spawn(async move {
209 if let Err(e) = serve(listener, app).await {
210 tracing::error!(
211 host = host_owned,
212 port = port,
213 error = %e,
214 "WebSocket server terminated with error"
215 );
216 error_flag.store(true, Ordering::Relaxed);
217 }
218 });
219 (task, false)
220 };
221
222 tracing::info!(host, port, is_tls, "WebSocket server started");
223
224 Ok(ServerHandle {
225 state,
226 is_tls,
227 _task: task,
228 })
229}
230
231#[derive(Clone)]
232pub struct WsAppState {
233 pub dispatch: DispatchTable,
234 pub path_configs: Arc<DashMap<String, WsPathConfig>>,
235 pub path_policies: Arc<DashMap<String, camel_component_api::SecurityContext>>,
236 pub server_error: Arc<AtomicBool>,
237}
238
239pub struct WsConnectionRegistry {
240 connections: DashMap<String, mpsc::Sender<WsMessage>>,
241}
242
243static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
244 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
245> = OnceLock::new();
246
247fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
248 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
249}
250
251impl Default for WsConnectionRegistry {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
257impl WsConnectionRegistry {
258 pub fn new() -> Self {
259 Self {
260 connections: DashMap::new(),
261 }
262 }
263
264 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
265 self.connections.insert(key, tx);
266 }
267
268 pub fn remove(&self, key: &str) {
269 self.connections.remove(key);
270 }
271
272 pub fn len(&self) -> usize {
273 self.connections.len()
274 }
275
276 pub fn is_empty(&self) -> bool {
277 self.connections.is_empty()
278 }
279
280 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
281 self.connections.iter().map(|e| e.value().clone()).collect()
282 }
283
284 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
285 keys.iter()
286 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
287 .collect()
288 }
289}
290
291pub async fn dispatch_handler(
292 State(state): State<WsAppState>,
293 req: Request<Body>,
294) -> impl IntoResponse {
295 let path = req.uri().path().to_string();
296 let origin = req
297 .headers()
298 .get(header::ORIGIN)
299 .and_then(|value| value.to_str().ok())
300 .map(str::to_string);
301 let remote_addr = req
302 .extensions()
303 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
304 .map(|ci| ci.0.to_string())
305 .unwrap_or_default();
306 let table = state.dispatch.read().await;
307 if !table.contains_key(&path) {
308 return (
309 StatusCode::NOT_FOUND,
310 "no ws endpoint registered for this path",
311 )
312 .into_response();
313 }
314 drop(table);
315
316 let path_config = state
317 .path_configs
318 .get(&path)
319 .map(|entry| entry.value().clone())
320 .unwrap_or_default();
321 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
322 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
323 }
324
325 let mut principal_opt: Option<camel_api::security_policy::Principal> = None;
326 if let Some(sec_ctx) = state.path_policies.get(&path) {
327 let extracted =
328 camel_auth::extract_token_multi(req.headers(), req.uri(), &sec_ctx.credential_sources);
329
330 match extracted {
331 Some(extracted) => {
332 if matches!(
333 extracted.source,
334 camel_auth::CredentialSource::QueryParam { .. }
335 ) {
336 let redacted =
337 camel_auth::redact_query_params(req.uri(), &["access_token", "token"]);
338 tracing::debug!(path = %redacted, "WS upgrade with query token (redacted)");
339 }
340 match sec_ctx
341 .authenticator
342 .authenticate_bearer(&extracted.token)
343 .await
344 {
345 Ok(principal) => {
346 let mut exchange = camel_api::Exchange::new(camel_api::Message::new(
347 camel_api::Body::Empty,
348 ));
349 camel_api::store_principal_properties(&mut exchange, &principal);
350 match sec_ctx.policy.evaluate(&mut exchange).await {
351 Ok(AuthorizationDecision::Granted { principal: _p }) => {
352 tracing::debug!(path = %path, subject = %principal.subject, "WS upgrade authorized");
353 principal_opt = Some(principal);
354 }
355 Ok(AuthorizationDecision::Denied { reason, .. }) => {
356 tracing::warn!(path = %path, reason = %reason, "WS upgrade denied");
357 return (StatusCode::FORBIDDEN, "Forbidden").into_response();
358 }
359 Err(e) => {
360 tracing::error!(path = %path, error = %e, "Policy evaluation error during WS upgrade");
361 return (
362 StatusCode::INTERNAL_SERVER_ERROR,
363 "Internal Server Error",
364 )
365 .into_response();
366 }
367 }
368 }
369 Err(e) => {
370 let (status, body) = match &e {
371 camel_api::CamelError::Unauthenticated(_) => {
372 (StatusCode::UNAUTHORIZED, "Unauthorized")
373 }
374 camel_api::CamelError::ProcessorError(msg)
375 if msg.contains("auth provider unavailable") =>
376 {
377 (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
378 }
379 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
380 };
381 tracing::warn!(path = %path, error = %e, "WS upgrade authentication failed");
382 return (status, body).into_response();
383 }
384 }
385 }
386 None => {
387 tracing::warn!(path = %path, "WS upgrade rejected: no credential found in any source");
388 return (
389 StatusCode::UNAUTHORIZED,
390 [("WWW-Authenticate", "Bearer".to_string())],
391 "Unauthorized",
392 )
393 .into_response();
394 }
395 }
396 }
397
398 let upgrade_headers: HashMap<String, String> = req
399 .headers()
400 .iter()
401 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
402 .collect();
403
404 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
405 Ok(ws) => ws,
406 Err(_) => {
407 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
408 }
409 };
410
411 ws.on_upgrade(move |socket| {
412 ws_handler(
413 socket,
414 state,
415 path,
416 remote_addr,
417 upgrade_headers,
418 principal_opt,
419 )
420 })
421 .into_response()
422}
423
424#[allow(unused_variables)]
425async fn ws_handler(
426 socket: WebSocket,
427 state: WsAppState,
428 path: String,
429 remote_addr: String,
430 upgrade_headers: HashMap<String, String>,
431 principal: Option<camel_api::security_policy::Principal>,
432) {
433 let connection_key = uuid::Uuid::new_v4().to_string();
434 let path_config = state
435 .path_configs
436 .get(&path)
437 .map(|entry| entry.value().clone())
438 .unwrap_or_default();
439
440 let env_tx = {
441 let table = state.dispatch.read().await;
442 table.get(&path).cloned()
443 };
444 let Some(env_tx) = env_tx else {
445 return;
446 };
447
448 let (mut sink, mut stream) = socket.split();
449 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
450
451 let registry = global_registries();
452 let mut registry_key = None;
453 for entry in registry.iter() {
454 if entry.key().2 == path {
455 entry.value().insert(connection_key.clone(), out_tx.clone());
456 registry_key = Some(entry.key().clone());
457 break;
458 }
459 }
460
461 let conn_key_for_writer = connection_key.clone();
463 let path_for_writer = path.clone();
464
465 let writer = tokio::spawn(async move {
466 while let Some(msg) = out_rx.recv().await {
467 if let Err(e) = sink.send(msg).await {
468 tracing::warn!(
469 connection_key = conn_key_for_writer,
470 path = path_for_writer,
471 error = %e,
472 "WebSocket writer send error — closing connection"
473 );
474 break;
475 }
476 }
477 });
478
479 tracing::info!(
480 connection_key = connection_key,
481 path = path,
482 remote_addr = remote_addr,
483 "WebSocket connection opened"
484 );
485
486 let mut over_limit = false;
487 if let Some(key) = ®istry_key
488 && let Some(entry) = registry.get(key)
489 && entry.len() > path_config.max_connections as usize
490 {
491 over_limit = true;
492 }
493 if over_limit {
494 try_send_with_backpressure(
495 &out_tx,
496 WsMessage::Close(Some(CloseFrame {
497 code: CloseCode::from(1013u16),
498 reason: "max connections exceeded".into(),
499 })),
500 "max-connections-close",
501 );
502 if let Some(key) = registry_key.clone()
503 && let Some(entry) = registry.get(&key)
504 {
505 entry.remove(&connection_key);
506 }
507 drop(out_tx);
508 let _ = writer.await;
509 return;
510 }
511
512 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
513 let ping_tx = out_tx.clone();
514 let interval = path_config.heartbeat_interval;
515 Some(tokio::spawn(async move {
516 let mut ticker = tokio::time::interval(interval);
517 loop {
518 ticker.tick().await;
519 let _ = try_send_with_backpressure(
520 &ping_tx,
521 WsMessage::Ping(Vec::new().into()),
522 "heartbeat-ping",
523 );
524 }
525 }))
526 } else {
527 None
528 };
529
530 loop {
531 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
532 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
533 Ok(msg) => msg,
534 Err(_) => {
535 try_send_with_backpressure(
536 &out_tx,
537 WsMessage::Close(Some(CloseFrame {
538 code: CloseCode::from(1000u16),
539 reason: "idle timeout".into(),
540 })),
541 "idle-timeout-close",
542 );
543 break;
544 }
545 }
546 } else {
547 stream.next().await
548 };
549
550 let Some(msg) = next_msg else {
551 break;
552 };
553
554 match msg {
555 Ok(WsMessage::Ping(data)) => {
556 tracing::debug!(
557 connection_key = connection_key,
558 path = path,
559 "WebSocket ping received — sending pong"
560 );
561 let _ = try_send_with_backpressure(
562 &out_tx,
563 WsMessage::Pong(data),
564 "ping-pong-response",
565 );
566 }
567 Ok(WsMessage::Pong(_)) => {
568 tracing::debug!(
569 connection_key = connection_key,
570 path = path,
571 "WebSocket pong received"
572 );
573 }
574 Ok(WsMessage::Text(text)) => {
575 if text.len() > path_config.max_message_size as usize {
576 try_send_with_backpressure(
577 &out_tx,
578 WsMessage::Close(Some(CloseFrame {
579 code: CloseCode::from(1009u16),
580 reason: "message too large".into(),
581 })),
582 "max-message-size-close-text",
583 );
584 break;
585 }
586
587 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
588 message.set_header(
589 "CamelWsConnectionKey",
590 serde_json::Value::String(connection_key.clone()),
591 );
592 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
593 message.set_header(
594 "CamelWsRemoteAddress",
595 serde_json::Value::String(remote_addr.clone()),
596 );
597
598 #[allow(unused_mut)]
599 let mut exchange = Exchange::new(message);
600 if let Some(ref p) = principal {
601 camel_api::store_principal_properties(&mut exchange, p);
602 }
603 #[cfg(feature = "otel")]
604 {
605 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
606 }
607 if env_tx
608 .send(ExchangeEnvelope {
609 exchange,
610 reply_tx: None,
611 })
612 .await
613 .is_err()
614 {
615 break;
616 }
617 }
618 Ok(WsMessage::Binary(data)) => {
619 if data.len() > path_config.max_message_size as usize {
620 try_send_with_backpressure(
621 &out_tx,
622 WsMessage::Close(Some(CloseFrame {
623 code: CloseCode::from(1009u16),
624 reason: "message too large".into(),
625 })),
626 "max-message-size-close-binary",
627 );
628 break;
629 }
630
631 let mut message = CamelMessage::new(CamelBody::Bytes(data));
632 message.set_header(
633 "CamelWsConnectionKey",
634 serde_json::Value::String(connection_key.clone()),
635 );
636 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
637 message.set_header(
638 "CamelWsRemoteAddress",
639 serde_json::Value::String(remote_addr.clone()),
640 );
641
642 #[allow(unused_mut)]
643 let mut exchange = Exchange::new(message);
644 if let Some(ref p) = principal {
645 camel_api::store_principal_properties(&mut exchange, p);
646 }
647 #[cfg(feature = "otel")]
648 {
649 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
650 }
651 if env_tx
652 .send(ExchangeEnvelope {
653 exchange,
654 reply_tx: None,
655 })
656 .await
657 .is_err()
658 {
659 break;
660 }
661 }
662 Ok(WsMessage::Close(cf)) => {
663 let reason = cf
664 .as_ref()
665 .map(|f| f.reason.to_string())
666 .unwrap_or_default();
667 tracing::info!(
668 connection_key = connection_key,
669 path = path,
670 reason = reason,
671 "WebSocket connection closed by peer"
672 );
673 break;
674 }
675 Err(e) => {
676 tracing::warn!(
677 connection_key = connection_key,
678 path = path,
679 error = %e,
680 "WebSocket receive error"
681 );
682 break;
683 }
684 }
685 }
686
687 if let Some(task) = heartbeat_task {
688 task.abort();
689 }
690
691 if let Some(key) = registry_key
692 && let Some(entry) = registry.get(&key)
693 {
694 entry.remove(&connection_key);
695 }
696 drop(out_tx);
697 let _ = writer.await;
698
699 tracing::info!(
700 connection_key = connection_key,
701 path = path,
702 "WebSocket connection closed"
703 );
704}
705
706pub struct WsComponent {
707 pub(crate) config: WsConfig,
708}
709
710impl WsComponent {
711 pub fn new() -> Self {
712 Self {
713 config: WsConfig::default(),
714 }
715 }
716
717 pub fn with_config(config: WsConfig) -> Self {
718 Self { config }
719 }
720}
721
722impl Default for WsComponent {
723 fn default() -> Self {
724 Self::new()
725 }
726}
727
728impl Component for WsComponent {
729 fn scheme(&self) -> &str {
730 "ws"
731 }
732
733 fn create_endpoint(
734 &self,
735 uri: &str,
736 ctx: &dyn camel_component_api::ComponentContext,
737 ) -> Result<Box<dyn Endpoint>, CamelError> {
738 self.config.validate()?;
739 let mut cfg = WsEndpointConfig::from_uri(uri)?;
740 if let Some(v) = self.config.max_connections {
741 cfg.max_connections = v;
742 }
743 if let Some(v) = self.config.max_message_size {
744 cfg.max_message_size = v;
745 }
746 if let Some(v) = self.config.heartbeat_interval_ms {
747 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
748 }
749 if let Some(v) = self.config.idle_timeout_ms {
750 cfg.idle_timeout = std::time::Duration::from_millis(v);
751 }
752 if let Some(v) = self.config.connect_timeout_ms {
753 cfg.connect_timeout = std::time::Duration::from_millis(v);
754 }
755 if let Some(v) = self.config.response_timeout_ms {
756 cfg.response_timeout = std::time::Duration::from_millis(v);
757 }
758 if let Some(v) = self.config.send_timeout_ms {
759 cfg.send_timeout = std::time::Duration::from_millis(v);
760 }
761 if let Some(v) = self.config.binary_payload {
762 cfg.binary_payload = v;
763 }
764 if let Some(ref v) = self.config.subprotocols {
765 cfg.subprotocols = v.clone();
766 }
767 let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
768 ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
769 Ok(Box::new(WsEndpoint {
770 uri: uri.to_string(),
771 cfg,
772 }))
773 }
774}
775
776pub struct WssComponent {
777 pub(crate) config: WsConfig,
778}
779
780impl WssComponent {
781 pub fn new() -> Self {
782 Self {
783 config: WsConfig::default(),
784 }
785 }
786
787 pub fn with_config(config: WsConfig) -> Self {
788 Self { config }
789 }
790}
791
792impl Default for WssComponent {
793 fn default() -> Self {
794 Self::new()
795 }
796}
797
798impl Component for WssComponent {
799 fn scheme(&self) -> &str {
800 "wss"
801 }
802
803 fn create_endpoint(
804 &self,
805 uri: &str,
806 ctx: &dyn camel_component_api::ComponentContext,
807 ) -> Result<Box<dyn Endpoint>, CamelError> {
808 self.config.validate()?;
809 let mut cfg = WsEndpointConfig::from_uri(uri)?;
810 if let Some(v) = self.config.max_connections {
811 cfg.max_connections = v;
812 }
813 if let Some(v) = self.config.max_message_size {
814 cfg.max_message_size = v;
815 }
816 if let Some(v) = self.config.heartbeat_interval_ms {
817 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
818 }
819 if let Some(v) = self.config.idle_timeout_ms {
820 cfg.idle_timeout = std::time::Duration::from_millis(v);
821 }
822 if let Some(v) = self.config.connect_timeout_ms {
823 cfg.connect_timeout = std::time::Duration::from_millis(v);
824 }
825 if let Some(v) = self.config.response_timeout_ms {
826 cfg.response_timeout = std::time::Duration::from_millis(v);
827 }
828 if let Some(v) = self.config.send_timeout_ms {
829 cfg.send_timeout = std::time::Duration::from_millis(v);
830 }
831 if let Some(v) = self.config.binary_payload {
832 cfg.binary_payload = v;
833 }
834 if let Some(ref v) = self.config.subprotocols {
835 cfg.subprotocols = v.clone();
836 }
837 let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
838 ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
839 Ok(Box::new(WsEndpoint {
840 uri: uri.to_string(),
841 cfg,
842 }))
843 }
844}
845
846struct WsEndpoint {
847 uri: String,
848 cfg: WsEndpointConfig,
849}
850
851impl Endpoint for WsEndpoint {
852 fn uri(&self) -> &str {
853 &self.uri
854 }
855
856 fn create_consumer(
857 &self,
858 rt: Arc<dyn camel_component_api::RuntimeObservability>,
859 ) -> Result<Box<dyn Consumer>, CamelError> {
860 Ok(Box::new(WsConsumer::new(self.cfg.server_config(), rt)))
861 }
862
863 fn create_producer(
864 &self,
865 _rt: Arc<dyn camel_component_api::RuntimeObservability>,
866 _ctx: &ProducerContext,
867 ) -> Result<BoxProcessor, CamelError> {
868 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
869 }
870}
871
872pub struct WsConsumer {
873 cfg: WsServerConfig,
874 registry: Arc<WsConnectionRegistry>,
875 server_state: Option<WsAppState>,
876 registry_key: Option<(String, u16, String)>,
877 forward_task: Option<JoinHandle<Result<(), CamelError>>>,
878 security_ctx: Option<camel_component_api::SecurityContext>,
879 #[allow(dead_code)]
882 runtime: Arc<dyn camel_component_api::RuntimeObservability>,
883}
884
885impl WsConsumer {
886 pub fn new(
887 cfg: WsServerConfig,
888 runtime: Arc<dyn camel_component_api::RuntimeObservability>,
889 ) -> Self {
890 Self {
891 cfg,
892 registry: Arc::new(WsConnectionRegistry::new()),
893 server_state: None,
894 registry_key: None,
895 forward_task: None,
896 security_ctx: None,
897 runtime,
898 }
899 }
900}
901
902#[async_trait]
903impl Consumer for WsConsumer {
904 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
905 if self.server_state.is_some() {
907 return Err(CamelError::EndpointCreationFailed(
908 "WebSocket consumer already started".into(),
909 ));
910 }
911
912 tracing::info!(
913 host = self.cfg.inner.host,
914 port = self.cfg.inner.port,
915 path = self.cfg.inner.path,
916 scheme = self.cfg.inner.scheme,
917 "WebSocket consumer starting"
918 );
919
920 let tls_config = if self.cfg.inner.scheme == "wss" {
921 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
922 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
923 })?;
924 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
925 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
926 })?;
927 Some(WsTlsConfig {
928 cert_path,
929 key_path,
930 })
931 } else {
932 None
933 };
934
935 let state = ServerRegistry::global()
936 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
937 .await?;
938
939 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
940 {
941 let mut table = state.dispatch.write().await;
942 table.insert(self.cfg.inner.path.clone(), env_tx);
943 }
944
945 state.path_configs.insert(
946 self.cfg.inner.path.clone(),
947 WsPathConfig {
948 max_connections: self.cfg.inner.max_connections,
949 max_message_size: self.cfg.inner.max_message_size,
950 heartbeat_interval: self.cfg.inner.heartbeat_interval,
951 idle_timeout: self.cfg.inner.idle_timeout,
952 allow_origin: self.cfg.inner.allow_origin.clone(),
953 },
954 );
955
956 if let Some(ref sec_ctx) = self.security_ctx {
957 let path = self.cfg.inner.path.clone();
958 state.path_policies.insert(path, sec_ctx.clone());
959 }
960
961 let registry_key = (
962 self.cfg.inner.canonical_host(),
963 self.cfg.inner.port,
964 self.cfg.inner.path.clone(),
965 );
966 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
967
968 let sender = ctx.sender();
969 let forward_task: JoinHandle<Result<(), CamelError>> = tokio::spawn(async move {
970 while let Some(envelope) = env_rx.recv().await {
971 if sender.send(envelope).await.is_err() {
972 break;
973 }
974 }
975 Ok(())
976 });
977
978 self.server_state = Some(state);
979 self.registry_key = Some(registry_key);
980 self.forward_task = Some(forward_task);
981 Ok(())
982 }
983
984 async fn stop(&mut self) -> Result<(), CamelError> {
985 tracing::info!(
986 host = self.cfg.inner.host,
987 port = self.cfg.inner.port,
988 path = self.cfg.inner.path,
989 "WebSocket consumer stopping"
990 );
991
992 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
993 code: axum::extract::ws::CloseCode::from(1001u16),
994 reason: "consumer stopping".into(),
995 }));
996 for tx in self.registry.snapshot_senders() {
997 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
998 }
999
1000 let mut had_server_error = false;
1001
1002 if let Some(state) = self.server_state.take() {
1003 had_server_error = state.server_error.load(Ordering::Relaxed);
1004 state.path_policies.remove(&self.cfg.inner.path);
1005 let mut table = state.dispatch.write().await;
1006 table.remove(&self.cfg.inner.path);
1007 state.path_configs.remove(&self.cfg.inner.path);
1008 }
1009
1010 if let Some(key) = self.registry_key.take() {
1011 global_registries().remove(&key);
1012 ServerRegistry::global().release(key.1);
1013 }
1014
1015 if let Some(task) = self.forward_task.take() {
1016 task.abort();
1017 }
1018
1019 tracing::info!(
1020 host = self.cfg.inner.host,
1021 port = self.cfg.inner.port,
1022 path = self.cfg.inner.path,
1023 "WebSocket consumer stopped"
1024 );
1025
1026 if had_server_error {
1027 tracing::warn!(
1028 host = self.cfg.inner.host,
1029 port = self.cfg.inner.port,
1030 path = self.cfg.inner.path,
1031 "WebSocket server had errors during its lifetime"
1032 );
1033 return Err(CamelError::ProcessorError(
1034 "WebSocket server terminated with errors during its lifetime".into(),
1035 ));
1036 }
1037
1038 Ok(())
1039 }
1040
1041 fn concurrency_model(&self) -> ConcurrencyModel {
1042 ConcurrencyModel::Concurrent {
1043 max: Some(self.cfg.inner.max_connections as usize),
1044 }
1045 }
1046
1047 fn background_task_handle(&mut self) -> Option<JoinHandle<Result<(), CamelError>>> {
1048 self.forward_task.take()
1049 }
1050
1051 fn set_security_context(&mut self, ctx: camel_component_api::SecurityContext) {
1052 self.security_ctx = Some(ctx);
1053 }
1054}
1055
1056use std::sync::atomic::{AtomicBool, Ordering};
1057
/// Allocates a fresh shared flag initialised to `false` (used as a server's error flag).
fn new_atomic_false() -> Arc<AtomicBool> {
    let flag = AtomicBool::default();
    Arc::new(flag)
}
1061
1062#[inline]
1067fn is_retryable_ws_error(err: &CamelError) -> bool {
1068 let s = err.to_string();
1069 s.contains("connection refused") || s.contains("timeout") || s.contains("connection failed")
1070}
1071
1072#[derive(Clone)]
1073pub struct WsProducer {
1074 cfg: WsClientConfig,
1075 backpressure_flag: Arc<AtomicBool>,
1078}
1079
1080impl WsProducer {
1081 pub fn new(cfg: WsClientConfig) -> Self {
1082 Self {
1083 cfg,
1084 backpressure_flag: Arc::new(AtomicBool::new(false)),
1085 }
1086 }
1087}
1088
1089impl Service<Exchange> for WsProducer {
1090 type Response = Exchange;
1091 type Error = CamelError;
1092 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
1093
1094 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
1095 if self.backpressure_flag.swap(false, Ordering::Relaxed) {
1097 return Poll::Ready(Err(CamelError::ProcessorError(
1098 "WebSocket producer backpressure: previous send was dropped due to full channel"
1099 .into(),
1100 )));
1101 }
1102 Poll::Ready(Ok(()))
1103 }
1104
1105 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
1106 let cfg = self.cfg.clone();
1107 let backpressure_flag = Arc::clone(&self.backpressure_flag);
1108
1109 Box::pin(async move {
1110 let canonical_host = cfg.inner.canonical_host();
1111 let key = (
1112 canonical_host.clone(),
1113 cfg.inner.port,
1114 cfg.inner.path.clone(),
1115 );
1116
1117 let send_to_all = exchange
1118 .input
1119 .header("CamelWsSendToAll")
1120 .and_then(|v| v.as_bool())
1121 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
1122 .unwrap_or(false);
1123
1124 let conn_keys_header = exchange
1125 .input
1126 .header("CamelWsConnectionKey")
1127 .and_then(|v| v.as_str())
1128 .map(str::to_string);
1129
1130 let local_exists = global_registries().contains_key(&key);
1131 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
1132
1133 let message_type = exchange
1134 .input
1135 .header("CamelWsMessageType")
1136 .and_then(|v| v.as_str())
1137 .unwrap_or("text")
1138 .to_ascii_lowercase();
1139
1140 if server_send_mode {
1141 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
1142 let Some(registry) = registry else {
1143 return Err(CamelError::ProcessorError(format!(
1144 "WebSocket local consumer not found for {}:{}{}",
1145 canonical_host, cfg.inner.port, cfg.inner.path
1146 )));
1147 };
1148
1149 let out_msg = body_to_axum_ws_message(
1150 std::mem::take(&mut exchange.input.body),
1151 &message_type,
1152 )
1153 .await?;
1154
1155 let targets = if send_to_all {
1156 registry.snapshot_senders()
1157 } else if let Some(keys) = conn_keys_header {
1158 let parsed: Vec<String> = keys
1159 .split(',')
1160 .map(str::trim)
1161 .filter(|k| !k.is_empty())
1162 .map(|k| k.to_string())
1163 .collect();
1164 registry.get_senders_for_keys(&parsed)
1165 } else {
1166 registry.snapshot_senders()
1167 };
1168
1169 let mut dropped = 0usize;
1170 for tx in &targets {
1171 if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1172 dropped += 1;
1173 }
1174 }
1175
1176 if dropped > 0 {
1177 tracing::warn!(
1178 host = canonical_host,
1179 port = cfg.inner.port,
1180 path = cfg.inner.path,
1181 dropped,
1182 total = targets.len(),
1183 "WebSocket producer dropped messages due to backpressure"
1184 );
1185 exchange.input.set_header(
1186 "CamelWsDeliveryDropped",
1187 serde_json::Value::Number(dropped.into()),
1188 );
1189 backpressure_flag.store(true, Ordering::Relaxed);
1191 if dropped == targets.len() {
1192 return Err(CamelError::ProcessorError(format!(
1193 "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1194 )));
1195 }
1196 }
1197
1198 tracing::debug!(
1199 host = canonical_host,
1200 port = cfg.inner.port,
1201 path = cfg.inner.path,
1202 targets = targets.len(),
1203 "WebSocket producer server-send complete"
1204 );
1205
1206 return Ok(exchange);
1207 }
1208
1209 let url = format!(
1210 "{}://{}:{}{}",
1211 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1212 );
1213
1214 tracing::debug!(url = url, "WebSocket producer connecting");
1215
1216 #[allow(unused_mut)]
1217 let mut request = url
1218 .clone()
1219 .into_client_request()
1220 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1221
1222 #[cfg(feature = "otel")]
1223 {
1224 let mut otel_headers = HashMap::new();
1225 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1226 for (k, v) in otel_headers {
1227 if let (Ok(name), Ok(val)) = (
1228 http::header::HeaderName::from_bytes(k.as_bytes()),
1229 http::header::HeaderValue::from_str(&v),
1230 ) {
1231 request.headers_mut().insert(name, val);
1232 }
1233 }
1234 }
1235
1236 if !cfg.inner.subprotocols.is_empty() {
1238 let proto_value = cfg.inner.subprotocols.join(", ");
1239 if let (Ok(name), Ok(val)) = (
1240 http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
1241 http::header::HeaderValue::from_str(&proto_value),
1242 ) {
1243 request.headers_mut().insert(name, val);
1244 }
1245 }
1246
1247 let effective_message_type = if cfg.inner.binary_payload {
1249 "binary"
1250 } else {
1251 &message_type
1252 };
1253
1254 let reconnect_policy = cfg.inner.reconnect_policy.clone();
1255 let mut ws_stream =
1256 connect_ws_with_retry(request, &url, cfg.inner.connect_timeout, &reconnect_policy)
1257 .await?;
1258
1259 let attempts = 0u32;
1266
1267 let out_msg = body_to_client_ws_message(
1268 std::mem::take(&mut exchange.input.body),
1269 effective_message_type,
1270 )
1271 .await?;
1272
1273 ws_stream
1274 .send(out_msg)
1275 .await
1276 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1277
1278 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1279 loop {
1280 match ws_stream.next().await {
1281 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1282 continue;
1283 }
1284 other => break other,
1285 }
1286 }
1287 })
1288 .await
1289 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1290
1291 match incoming {
1292 Some(Ok(ClientWsMessage::Text(text))) => {
1293 tracing::debug!(url = url, "WebSocket producer received text response");
1294 exchange.input.body = CamelBody::Text(text.to_string());
1295 }
1296 Some(Ok(ClientWsMessage::Binary(data))) => {
1297 tracing::debug!(url = url, "WebSocket producer received binary response");
1298 exchange.input.body = CamelBody::Bytes(data);
1299 }
1300 Some(Ok(ClientWsMessage::Close(frame))) => {
1301 let normal = frame
1302 .as_ref()
1303 .map(|f| {
1304 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1305 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1306 })
1307 .unwrap_or(true);
1308
1309 if normal {
1310 tracing::debug!(url = url, "WebSocket producer received normal close");
1311 exchange.input.body = CamelBody::Empty;
1312 } else if reconnect_policy.should_retry(attempts + 1) {
1313 let delay = reconnect_policy.delay_for(0); tracing::warn!(
1315 url = url,
1316 attempt = attempts + 1,
1317 delay_ms = delay.as_millis(),
1318 "WebSocket closed by peer — reconnecting"
1319 );
1320 tokio::time::sleep(delay).await;
1321 return Err(CamelError::ProcessorError(format!(
1322 "WebSocket reconnect required after close: code {}",
1323 frame.map(|f| u16::from(f.code)).unwrap_or_default()
1324 )));
1325 } else {
1326 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1327 return Err(CamelError::ProcessorError(format!(
1328 "WebSocket peer closed: code {code}"
1329 )));
1330 }
1331 }
1332 Some(Ok(_)) | None => {
1333 exchange.input.body = CamelBody::Empty;
1334 }
1335 Some(Err(e)) => {
1336 return Err(CamelError::ProcessorError(format!(
1337 "WebSocket receive failed: {e}"
1338 )));
1339 }
1340 }
1341
1342 let _ = ws_stream.close(None).await;
1343 tracing::debug!(url = url, "WebSocket producer connection closed");
1344 Ok(exchange)
1345 })
1346 }
1347}
1348
1349async fn body_to_axum_ws_message(
1350 body: CamelBody,
1351 message_type: &str,
1352) -> Result<WsMessage, CamelError> {
1353 match message_type {
1354 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1355 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1356 }
1357}
1358
1359async fn body_to_client_ws_message(
1360 body: CamelBody,
1361 message_type: &str,
1362) -> Result<ClientWsMessage, CamelError> {
1363 match message_type {
1364 "binary" => Ok(ClientWsMessage::Binary(
1365 body.into_bytes(10 * 1024 * 1024).await?,
1366 )),
1367 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1368 }
1369}
1370
1371async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1372 Ok(match body {
1373 CamelBody::Empty => String::new(),
1374 CamelBody::Text(s) => s,
1375 CamelBody::Xml(s) => s,
1376 CamelBody::Json(v) => v.to_string(),
1377 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1378 CamelBody::Stream(stream) => {
1379 let bytes = CamelBody::Stream(stream)
1380 .into_bytes(10 * 1024 * 1024)
1381 .await?;
1382 String::from_utf8_lossy(&bytes).to_string()
1383 }
1384 })
1385}
1386
/// Decides whether a WebSocket upgrade request's `Origin` is acceptable.
///
/// An `allowed_origin` of `"*"` accepts every request, including ones without
/// an `Origin` header. Otherwise the request must carry an `Origin` that is
/// byte-for-byte equal to `allowed_origin`.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    match (allowed_origin, request_origin) {
        ("*", _) => true,
        (expected, Some(actual)) => actual == expected,
        (_, None) => false,
    }
}
1393
1394fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1395 match tx.try_send(msg) {
1396 Ok(()) => true,
1397 Err(error) => {
1398 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1399 false
1400 }
1401 }
1402}
1403
1404fn load_tls_config(
1405 cert_path: &str,
1406 key_path: &str,
1407) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1408 use std::fs::File;
1409 use std::io::BufReader;
1410
1411 let cert_file = File::open(cert_path)
1412 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1413 let key_file = File::open(key_path)
1414 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1415
1416 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1417 .collect::<Result<Vec<_>, _>>()
1418 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1419
1420 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1421 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1422 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1423
1424 tokio_rustls::rustls::ServerConfig::builder()
1425 .with_no_client_auth()
1426 .with_single_cert(certs, key)
1427 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1428}
1429
1430fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1431 match err {
1432 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1433 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1434 }
1435 tungstenite::Error::Tls(_) => {
1436 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1437 }
1438 other => {
1439 let msg = other.to_string();
1440 if msg.to_lowercase().contains("connection refused") {
1441 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1442 } else if msg.to_lowercase().contains("tls") {
1443 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1444 } else {
1445 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1446 }
1447 }
1448 }
1449}
1450
1451async fn connect_ws_with_retry<R>(
1456 request: R,
1457 url: &str,
1458 connect_timeout: std::time::Duration,
1459 reconnect_policy: &NetworkRetryPolicy,
1460) -> Result<
1461 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
1462 CamelError,
1463>
1464where
1465 R: IntoClientRequest + Unpin + Clone,
1466{
1467 let url_owned = url.to_string();
1468 retry_async(
1469 reconnect_policy,
1470 Some("ws-producer"),
1471 || {
1472 let r = request.clone();
1473 let url = url_owned.clone();
1474 async move {
1475 match tokio::time::timeout(connect_timeout, tokio_tungstenite::connect_async(r))
1476 .await
1477 {
1478 Ok(Ok((stream, _))) => Ok(stream),
1479 Ok(Err(e)) => Err(map_connect_error(e, &url)),
1480 Err(_) => Err(CamelError::ProcessorError(format!(
1481 "WebSocket connect timeout ({connect_timeout:?}) to {url}"
1482 ))),
1483 }
1484 }
1485 },
1486 is_retryable_ws_error,
1487 )
1488 .await
1489}
1490
1491#[cfg(test)]
1492mod tests {
1493 use camel_component_api::test_support::PanicRuntimeObservability;
1494 fn test_rt() -> std::sync::Arc<dyn camel_component_api::RuntimeObservability> {
1495 std::sync::Arc::new(PanicRuntimeObservability)
1496 }
1497 fn rt() -> std::sync::Arc<dyn camel_component_api::RuntimeObservability> {
1498 std::sync::Arc::new(PanicRuntimeObservability)
1499 }
1500
1501 use super::*;
1502 use camel_component_api::NoOpComponentContext;
1503 use std::time::Duration;
1504
1505 use tokio::sync::mpsc;
1506 use tokio_tungstenite::connect_async;
1507 use tokio_tungstenite::tungstenite::Message as ClientMessage;
1508 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1509 use tokio_util::sync::CancellationToken;
1510 use tower::ServiceExt;
1511
1512 fn free_port() -> u16 {
1513 std::net::TcpListener::bind("127.0.0.1:0")
1514 .unwrap()
1515 .local_addr()
1516 .unwrap()
1517 .port()
1518 }
1519
1520 #[test]
1521 fn ws_component_scheme_is_ws() {
1522 assert_eq!(WsComponent::new().scheme(), "ws");
1523 }
1524
1525 #[test]
1526 fn wss_component_scheme_is_wss() {
1527 assert_eq!(WssComponent::new().scheme(), "wss");
1528 }
1529
1530 #[test]
1531 fn endpoint_config_defaults_match_spec() {
1532 let cfg = WsEndpointConfig::default();
1533 assert_eq!(cfg.scheme, "ws");
1534 assert_eq!(cfg.host, "0.0.0.0");
1535 assert_eq!(cfg.port, 8080);
1536 assert_eq!(cfg.path, "/");
1537 assert_eq!(cfg.max_connections, 100);
1538 assert_eq!(cfg.max_message_size, 65536);
1539 assert!(!cfg.send_to_all);
1540 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1541 assert_eq!(cfg.idle_timeout, Duration::ZERO);
1542 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1543 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1544 assert_eq!(cfg.allow_origin, "*");
1545 assert_eq!(cfg.tls_cert, None);
1546 assert_eq!(cfg.tls_key, None);
1547 assert!(cfg.reconnect);
1548 assert_eq!(cfg.reconnect_max_attempts, 5);
1549 assert_eq!(cfg.reconnect_delay_ms, 1000);
1550 assert_eq!(cfg.send_timeout, Duration::from_secs(30));
1551 assert!(!cfg.binary_payload);
1552 assert!(cfg.subprotocols.is_empty());
1553 }
1554
1555 #[test]
1556 fn endpoint_config_parses_uri_params() {
1557 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1558 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1559
1560 assert_eq!(cfg.scheme, "ws");
1561 assert_eq!(cfg.host, "localhost");
1562 assert_eq!(cfg.port, 9001);
1563 assert_eq!(cfg.path, "/chat");
1564 assert_eq!(cfg.max_connections, 42);
1565 assert_eq!(cfg.max_message_size, 1024);
1566 assert!(cfg.send_to_all);
1567 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1568 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1569 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1570 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1571 assert_eq!(cfg.allow_origin, "https://example.com");
1572 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1573 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1574 assert!(cfg.reconnect);
1575 assert_eq!(cfg.reconnect_max_attempts, 5);
1576 assert_eq!(cfg.reconnect_delay_ms, 1000);
1577 }
1578
1579 #[test]
1580 fn endpoint_config_parses_reconnect_uri_params() {
1581 let uri =
1582 "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
1583 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1584 assert!(!cfg.reconnect);
1585 assert_eq!(cfg.reconnect_max_attempts, 2);
1586 assert_eq!(cfg.reconnect_delay_ms, 25);
1587 }
1588
1589 #[test]
1590 fn endpoint_config_override_chain_uri_overrides_defaults() {
1591 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1592 assert_eq!(cfg.max_connections, 7);
1593 assert_eq!(cfg.max_message_size, 65536);
1594 assert!(!cfg.send_to_all);
1595 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1596 }
1597
1598 #[test]
1599 fn endpoint_trait_creates_consumer_and_producer() {
1600 let ctx = NoOpComponentContext;
1601 let endpoint = WsComponent::new()
1602 .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1603 .unwrap();
1604
1605 endpoint.create_consumer(rt()).unwrap();
1606 endpoint
1607 .create_producer(rt(), &ProducerContext::default())
1608 .unwrap();
1609 }
1610
1611 #[test]
1612 fn ws_consumer_concurrency_model_uses_max_connections() {
1613 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1614 let consumer = WsConsumer::new(cfg.server_config(), test_rt());
1615 assert_eq!(
1616 consumer.concurrency_model(),
1617 ConcurrencyModel::Concurrent { max: Some(321) }
1618 );
1619 }
1620
1621 #[tokio::test]
1622 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1623 let registry = WsConnectionRegistry::new();
1624 let (tx1, mut rx1) = mpsc::channel(8);
1625 let (tx2, mut rx2) = mpsc::channel(8);
1626
1627 registry.insert("k1".into(), tx1);
1628 registry.insert("k2".into(), tx2);
1629 assert_eq!(registry.len(), 2);
1630
1631 for tx in registry.snapshot_senders() {
1632 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1633 }
1634
1635 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1636 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1637
1638 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1639 assert_eq!(target.len(), 1);
1640 target[0]
1641 .send(WsMessage::Text("targeted".into()))
1642 .await
1643 .unwrap();
1644
1645 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1646 assert!(
1647 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1648 .await
1649 .is_err()
1650 );
1651
1652 registry.remove("k1");
1653 assert_eq!(registry.len(), 1);
1654 }
1655
1656 #[test]
1657 fn host_canonicalization_maps_local_hosts_to_loopback() {
1658 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1659 .unwrap()
1660 .canonical_host();
1661 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1662 .unwrap()
1663 .canonical_host();
1664 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1665 .unwrap()
1666 .canonical_host();
1667
1668 assert_eq!(c1, "127.0.0.1");
1669 assert_eq!(c2, "127.0.0.1");
1670 assert_eq!(c3, "127.0.0.1");
1671 }
1672
1673 #[tokio::test]
1674 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1675 let port = free_port();
1676 let uri = format!("ws://127.0.0.1:{port}/echo");
1677 let component_ctx = NoOpComponentContext;
1678 let endpoint = WsComponent::new()
1679 .create_endpoint(&uri, &component_ctx)
1680 .unwrap();
1681
1682 let mut consumer = endpoint.create_consumer(rt()).unwrap();
1683 let producer = endpoint
1684 .create_producer(rt(), &ProducerContext::default())
1685 .unwrap();
1686
1687 let (route_tx, mut route_rx) = mpsc::channel(16);
1688 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1689 consumer.start(ctx).await.unwrap();
1690
1691 let route_task = tokio::spawn(async move {
1692 if let Some(envelope) = route_rx.recv().await {
1693 let payload = envelope
1694 .exchange
1695 .input
1696 .body
1697 .as_text()
1698 .unwrap_or_default()
1699 .to_string();
1700 let key = envelope
1701 .exchange
1702 .input
1703 .header("CamelWsConnectionKey")
1704 .and_then(|v| v.as_str())
1705 .unwrap()
1706 .to_string();
1707
1708 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1709 response
1710 .input
1711 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1712 producer.oneshot(response).await.unwrap();
1713 }
1714 });
1715
1716 let url = format!("ws://127.0.0.1:{port}/echo");
1717 let (mut client, _) = loop {
1718 match connect_async(&url).await {
1719 Ok(ok) => break ok,
1720 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1721 }
1722 };
1723
1724 client
1725 .send(ClientMessage::Text("hello-ws".into()))
1726 .await
1727 .unwrap();
1728
1729 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1730 loop {
1731 match client.next().await {
1732 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1733 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1734 Some(Ok(_)) => continue,
1735 Some(Err(e)) => panic!("ws receive failed: {e}"),
1736 None => panic!("websocket closed before echo"),
1737 }
1738 }
1739 })
1740 .await
1741 .unwrap();
1742
1743 assert_eq!(incoming, "hello-ws");
1744
1745 consumer.stop().await.unwrap();
1746 route_task.await.unwrap();
1747 }
1748
1749 #[tokio::test]
1750 async fn consumer_stop_sends_close_1001() {
1751 let port = free_port();
1752 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1753 let component_ctx = NoOpComponentContext;
1754 let endpoint = WsComponent::new()
1755 .create_endpoint(&uri, &component_ctx)
1756 .unwrap();
1757
1758 let mut consumer = endpoint.create_consumer(rt()).unwrap();
1759 let (route_tx, _route_rx) = mpsc::channel(16);
1760 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1761 consumer.start(ctx).await.unwrap();
1762
1763 let url = format!("ws://127.0.0.1:{port}/shutdown");
1764 let (mut client, _) = loop {
1765 match connect_async(&url).await {
1766 Ok(ok) => break ok,
1767 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1768 }
1769 };
1770
1771 client
1772 .send(ClientMessage::Text("keepalive".into()))
1773 .await
1774 .unwrap();
1775
1776 consumer.stop().await.unwrap();
1777
1778 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1779 loop {
1780 match client.next().await {
1781 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1782 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1783 Some(Ok(_)) => continue,
1784 Some(Err(e)) => panic!("ws receive failed: {e}"),
1785 None => panic!("websocket closed without close frame"),
1786 }
1787 }
1788 })
1789 .await
1790 .unwrap();
1791
1792 assert_eq!(close_code, Some(CloseCode::Away));
1793 }
1794
1795 #[test]
1796 fn wildcard_origin_allows_anything() {
1797 assert!(is_origin_allowed("*", None));
1798 assert!(is_origin_allowed("*", Some("https://example.com")));
1799 }
1800
1801 #[test]
1802 fn exact_origin_requires_match() {
1803 assert!(is_origin_allowed(
1804 "https://example.com",
1805 Some("https://example.com")
1806 ));
1807 assert!(!is_origin_allowed(
1808 "https://example.com",
1809 Some("https://other.com")
1810 ));
1811 assert!(!is_origin_allowed("https://example.com", None));
1812 }
1813
1814 #[test]
1815 fn endpoint_config_rejects_invalid_scheme() {
1816 let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1817 assert!(result.is_err());
1818 let msg = result.unwrap_err().to_string();
1819 assert!(
1820 msg.contains("Invalid WebSocket scheme"),
1821 "expected scheme error, got: {msg}"
1822 );
1823 }
1824
1825 #[tokio::test]
1826 async fn wss_consumer_start_fails_without_tls_cert() {
1827 let port = free_port();
1828 let component_ctx = NoOpComponentContext;
1829 let endpoint = WssComponent::new()
1830 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1831 .unwrap();
1832 let mut consumer = endpoint.create_consumer(rt()).unwrap();
1833 let (tx, _rx) = mpsc::channel(16);
1834 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1835 let result = consumer.start(ctx).await;
1836 assert!(result.is_err());
1837 let msg = result.unwrap_err().to_string();
1838 assert!(
1839 msg.contains("TLS cert path is required"),
1840 "expected TLS cert error, got: {msg}"
1841 );
1842 }
1843
1844 #[tokio::test]
1845 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1846 let port = free_port();
1847 let component_ctx = NoOpComponentContext;
1848 let endpoint = WssComponent::new()
1849 .create_endpoint(&format!(
1850 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1851 ), &component_ctx)
1852 .unwrap();
1853 let mut consumer = endpoint.create_consumer(rt()).unwrap();
1854 let (tx, _rx) = mpsc::channel(16);
1855 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1856 let result = consumer.start(ctx).await;
1857 assert!(result.is_err());
1858 let msg = result.unwrap_err().to_string();
1859 assert!(
1860 msg.contains("TLS cert file error"),
1861 "expected cert file error, got: {msg}"
1862 );
1863 }
1864
1865 #[tokio::test]
1866 async fn server_registry_returns_same_state_for_same_port() {
1867 let port = free_port();
1868 let state1 = ServerRegistry::global()
1869 .get_or_spawn("127.0.0.1", port, None)
1870 .await
1871 .unwrap();
1872 let state2 = ServerRegistry::global()
1873 .get_or_spawn("127.0.0.1", port, None)
1874 .await
1875 .unwrap();
1876 assert!(
1877 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1878 "expected same dispatch table for same port"
1879 );
1880 }
1881
1882 #[tokio::test]
1883 async fn dispatch_handler_returns_404_for_unregistered_path() {
1884 let port = free_port();
1885 let state = ServerRegistry::global()
1886 .get_or_spawn("127.0.0.1", port, None)
1887 .await
1888 .unwrap();
1889 let app = Router::new().fallback(dispatch_handler).with_state(state);
1890 let response = tokio::time::timeout(
1891 Duration::from_secs(2),
1892 tower::ServiceExt::oneshot(
1893 app,
1894 axum::http::Request::builder()
1895 .uri("/nonexistent")
1896 .body(Body::empty())
1897 .unwrap(),
1898 ),
1899 )
1900 .await
1901 .unwrap()
1902 .unwrap();
1903 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1904 }
1905
1906 #[tokio::test]
1907 async fn client_mode_producer_connects_and_echoes() {
1908 let app = Router::new().route(
1909 "/echo",
1910 axum::routing::get(|ws: WebSocketUpgrade| async move {
1911 ws.on_upgrade(|mut socket: WebSocket| async move {
1912 while let Some(Ok(msg)) = socket.recv().await {
1913 match msg {
1914 WsMessage::Text(text) => {
1915 let _ = socket.send(WsMessage::Text(text)).await;
1916 }
1917 WsMessage::Binary(data) => {
1918 let _ = socket.send(WsMessage::Binary(data)).await;
1919 }
1920 WsMessage::Close(_) => break,
1921 _ => {}
1922 }
1923 }
1924 })
1925 }),
1926 );
1927 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1929 let port = listener.local_addr().unwrap().port();
1930 let server_task = tokio::spawn(async move {
1931 let _ = serve(listener, app).await;
1932 });
1933
1934 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1935 let producer = WsProducer::new(cfg.client_config());
1936
1937 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1938 tokio::time::sleep(Duration::from_millis(25)).await;
1939 let result =
1940 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1941 Ok(Ok(r)) => r,
1942 Ok(Err(_)) => panic!("producer call failed"),
1943 Err(_) => panic!("producer call timed out"),
1944 };
1945
1946 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1947
1948 server_task.abort();
1949 }
1950
1951 #[tokio::test]
1952 async fn max_connections_rejects_with_close_1013() {
1953 let port = free_port();
1954 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1955 let component_ctx = NoOpComponentContext;
1956 let endpoint = WsComponent::new()
1957 .create_endpoint(&uri, &component_ctx)
1958 .unwrap();
1959 let mut consumer = endpoint.create_consumer(rt()).unwrap();
1960 let (route_tx, _route_rx) = mpsc::channel(16);
1961 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1962 consumer.start(ctx).await.unwrap();
1963
1964 let url = format!("ws://127.0.0.1:{port}/limited");
1965 let (_client1, _) = loop {
1966 match connect_async(&url).await {
1967 Ok(ok) => break ok,
1968 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1969 }
1970 };
1971
1972 tokio::time::sleep(Duration::from_millis(100)).await;
1973
1974 let (mut client2, _) = connect_async(&url).await.unwrap();
1975
1976 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1977 loop {
1978 match client2.next().await {
1979 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1980 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1981 Some(Ok(ClientMessage::Text(_))) => continue,
1982 Some(Ok(_)) => continue,
1983 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1984 None => panic!("client2 closed without close frame"),
1985 }
1986 }
1987 })
1988 .await
1989 .unwrap();
1990
1991 assert_eq!(
1992 close_code,
1993 Some(CloseCode::from(1013u16)),
1994 "expected 1013 (Try Again Later) for max connections"
1995 );
1996
1997 consumer.stop().await.unwrap();
1998 }
1999
2000 #[tokio::test]
2001 async fn max_message_size_rejects_with_close_1009() {
2002 let port = free_port();
2003 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
2004 let component_ctx = NoOpComponentContext;
2005 let endpoint = WsComponent::new()
2006 .create_endpoint(&uri, &component_ctx)
2007 .unwrap();
2008 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2009 let (route_tx, _route_rx) = mpsc::channel(16);
2010 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2011 consumer.start(ctx).await.unwrap();
2012
2013 let url = format!("ws://127.0.0.1:{port}/sizelimit");
2014 let (mut client, _) = loop {
2015 match connect_async(&url).await {
2016 Ok(ok) => break ok,
2017 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2018 }
2019 };
2020
2021 let oversized = "x".repeat(100);
2022 client
2023 .send(ClientMessage::Text(oversized.into()))
2024 .await
2025 .unwrap();
2026
2027 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
2028 loop {
2029 match client.next().await {
2030 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
2031 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2032 Some(Ok(_)) => continue,
2033 Some(Err(e)) => panic!("ws receive failed: {e}"),
2034 None => panic!("websocket closed without close frame"),
2035 }
2036 }
2037 })
2038 .await
2039 .unwrap();
2040
2041 assert_eq!(
2042 close_code,
2043 Some(CloseCode::from(1009u16)),
2044 "expected 1009 (Message Too Big) for oversized message"
2045 );
2046
2047 consumer.stop().await.unwrap();
2048 }
2049
2050 #[tokio::test]
2051 async fn origin_rejection_returns_403() {
2052 let port = free_port();
2053 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
2054 let component_ctx = NoOpComponentContext;
2055 let endpoint = WsComponent::new()
2056 .create_endpoint(&uri, &component_ctx)
2057 .unwrap();
2058 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2059 let (route_tx, _route_rx) = mpsc::channel(16);
2060 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2061 consumer.start(ctx).await.unwrap();
2062
2063 let state = ServerRegistry::global()
2064 .get_or_spawn("127.0.0.1", port, None)
2065 .await
2066 .unwrap();
2067 let app = Router::new().fallback(dispatch_handler).with_state(state);
2068
2069 let response = tokio::time::timeout(
2070 Duration::from_secs(2),
2071 tower::ServiceExt::oneshot(
2072 app,
2073 axum::http::Request::builder()
2074 .uri("/origintest")
2075 .header("origin", "https://evil.com")
2076 .header("upgrade", "websocket")
2077 .header("connection", "Upgrade")
2078 .header("sec-websocket-version", "13")
2079 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
2080 .body(Body::empty())
2081 .unwrap(),
2082 ),
2083 )
2084 .await
2085 .unwrap()
2086 .unwrap();
2087
2088 assert_eq!(
2089 response.status(),
2090 StatusCode::FORBIDDEN,
2091 "expected 403 for disallowed origin"
2092 );
2093
2094 consumer.stop().await.unwrap();
2095 }
2096
2097 #[tokio::test]
2098 async fn broadcast_sends_to_all_connected_clients() {
2099 let port = free_port();
2100 let uri = format!("ws://127.0.0.1:{port}/bc");
2101 let component_ctx = NoOpComponentContext;
2102 let endpoint = WsComponent::new()
2103 .create_endpoint(&uri, &component_ctx)
2104 .unwrap();
2105 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2106 let producer = endpoint
2107 .create_producer(rt(), &ProducerContext::default())
2108 .unwrap();
2109
2110 let (route_tx, _route_rx) = mpsc::channel(16);
2111 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2112 consumer.start(ctx).await.unwrap();
2113
2114 let url = format!("ws://127.0.0.1:{port}/bc");
2115
2116 let (mut client1, _) = loop {
2117 match connect_async(&url).await {
2118 Ok(ok) => break ok,
2119 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2120 }
2121 };
2122
2123 let (mut client2, _) = connect_async(&url).await.unwrap();
2124
2125 tokio::time::sleep(Duration::from_millis(100)).await;
2126
2127 let mut response =
2128 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
2129 response
2130 .input
2131 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
2132 producer.oneshot(response).await.unwrap();
2133
2134 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
2135 loop {
2136 match client1.next().await {
2137 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2138 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2139 _ => panic!("client1 unexpected message or close"),
2140 }
2141 }
2142 })
2143 .await
2144 .unwrap();
2145
2146 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
2147 loop {
2148 match client2.next().await {
2149 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2150 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2151 _ => panic!("client2 unexpected message or close"),
2152 }
2153 }
2154 })
2155 .await
2156 .unwrap();
2157
2158 assert_eq!(recv1, "broadcast-msg");
2159 assert_eq!(recv2, "broadcast-msg");
2160
2161 consumer.stop().await.unwrap();
2162 }
2163
2164 #[tokio::test]
2165 async fn concurrent_get_or_spawn_returns_same_state() {
2166 let port = free_port();
2167 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
2168 Arc::new(std::sync::Mutex::new(Vec::new()));
2169
2170 let mut handles = Vec::new();
2171 for _ in 0..4 {
2172 let results = results.clone();
2173 handles.push(tokio::spawn(async move {
2174 let state = ServerRegistry::global()
2175 .get_or_spawn("127.0.0.1", port, None)
2176 .await
2177 .unwrap();
2178 results.lock().unwrap().push(state);
2179 }));
2180 }
2181
2182 for h in handles {
2183 h.await.unwrap();
2184 }
2185
2186 let states = results.lock().unwrap();
2187 assert_eq!(states.len(), 4);
2188 for i in 1..states.len() {
2189 assert!(
2190 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
2191 "all concurrent callers should get the same dispatch table"
2192 );
2193 }
2194 }
2195
2196 #[tokio::test]
2197 async fn body_conversion_helpers_cover_text_and_binary_paths() {
2198 let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
2199 .await
2200 .unwrap();
2201 assert!(matches!(text_msg, WsMessage::Text(_)));
2202
2203 let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
2204 .await
2205 .unwrap();
2206 assert!(matches!(bin_msg, WsMessage::Binary(_)));
2207
2208 let client_text =
2209 body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
2210 .await
2211 .unwrap();
2212 assert!(matches!(client_text, ClientWsMessage::Text(_)));
2213
2214 let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
2215 .await
2216 .unwrap();
2217 assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
2218 }
2219
2220 #[tokio::test]
2221 async fn body_to_text_handles_empty_text_json_and_bytes() {
2222 assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2223 assert_eq!(
2224 body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2225 "hello"
2226 );
2227 assert_eq!(
2228 body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2229 .await
2230 .unwrap(),
2231 "{\"n\":1}"
2232 );
2233 assert_eq!(
2234 body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2235 .await
2236 .unwrap(),
2237 "hi"
2238 );
2239 }
2240
2241 #[test]
2242 fn try_send_with_backpressure_returns_false_when_channel_full() {
2243 let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2244 assert!(try_send_with_backpressure(
2245 &tx,
2246 WsMessage::Text("first".into()),
2247 "test"
2248 ));
2249 assert!(!try_send_with_backpressure(
2250 &tx,
2251 WsMessage::Text("second".into()),
2252 "test"
2253 ));
2254 }
2255
2256 #[test]
2257 fn map_connect_error_formats_connection_refused_and_generic_errors() {
2258 let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2259 let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2260 assert!(err.to_string().contains("WebSocket connection refused"));
2261
2262 let generic = map_connect_error(
2263 tungstenite::Error::Protocol(
2264 tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2265 ),
2266 "ws://localhost:2/y",
2267 );
2268 assert!(
2269 generic
2270 .to_string()
2271 .contains("WebSocket connection failed (ws://localhost:2/y)")
2272 );
2273 }
2274
2275 #[test]
2279 fn from_uri_rejects_max_connections_zero() {
2280 let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2281 assert!(result.is_err());
2282 let msg = result.unwrap_err().to_string();
2283 assert!(
2284 msg.contains("maxConnections must be >= 1"),
2285 "expected maxConnections validation error, got: {msg}"
2286 );
2287 }
2288
2289 #[test]
2291 fn from_uri_rejects_max_message_size_zero() {
2292 let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2293 assert!(result.is_err());
2294 let msg = result.unwrap_err().to_string();
2295 assert!(
2296 msg.contains("maxMessageSize must be > 0"),
2297 "expected maxMessageSize validation error, got: {msg}"
2298 );
2299 }
2300
2301 #[test]
2303 fn from_uri_rejects_empty_allow_origin() {
2304 let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2305 assert!(result.is_err());
2306 let msg = result.unwrap_err().to_string();
2307 assert!(
2308 msg.contains("allowOrigin must not be empty"),
2309 "expected allowOrigin validation error, got: {msg}"
2310 );
2311 }
2312
2313 #[tokio::test]
2315 async fn consumer_double_start_returns_error() {
2316 let port = free_port();
2317 let uri = format!("ws://127.0.0.1:{port}/doublestart");
2318 let component_ctx = NoOpComponentContext;
2319 let endpoint = WsComponent::new()
2320 .create_endpoint(&uri, &component_ctx)
2321 .unwrap();
2322
2323 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2324 let (route_tx, _route_rx) = mpsc::channel(16);
2325 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2326
2327 consumer.start(ctx).await.unwrap();
2329
2330 let (route_tx2, _route_rx2) = mpsc::channel(16);
2332 let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2333 let result = consumer.start(ctx2).await;
2334 assert!(result.is_err());
2335 let msg = result.unwrap_err().to_string();
2336 assert!(
2337 msg.contains("already started"),
2338 "expected double-start error, got: {msg}"
2339 );
2340
2341 consumer.stop().await.unwrap();
2342 }
2343
2344 #[tokio::test]
2346 async fn registry_cleanup_on_consumer_stop() {
2347 let port = free_port();
2348 let uri = format!("ws://127.0.0.1:{port}/cleanup");
2349 let component_ctx = NoOpComponentContext;
2350 let endpoint = WsComponent::new()
2351 .create_endpoint(&uri, &component_ctx)
2352 .unwrap();
2353
2354 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2355 let (route_tx, _route_rx) = mpsc::channel(16);
2356 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2357 consumer.start(ctx).await.unwrap();
2358
2359 let registries = global_registries();
2361 let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2362 assert!(
2363 registries.contains_key(&key),
2364 "registry should have entry after start"
2365 );
2366
2367 consumer.stop().await.unwrap();
2369
2370 assert!(
2372 !registries.contains_key(&key),
2373 "registry should be cleaned up after stop"
2374 );
2375
2376 let server_reg = ServerRegistry::global();
2379 let guard = server_reg.inner.lock().unwrap();
2380 assert!(
2381 !guard.contains_key(&port),
2382 "ServerRegistry should remove port entry after last consumer stops"
2383 );
2384 }
2385
2386 #[tokio::test]
2388 async fn producer_server_send_returns_error_when_all_dropped() {
2389 let port = free_port();
2390 let uri = format!("ws://127.0.0.1:{port}/backpressure");
2391 let component_ctx = NoOpComponentContext;
2392 let endpoint = WsComponent::new()
2393 .create_endpoint(&uri, &component_ctx)
2394 .unwrap();
2395
2396 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2397 let producer = endpoint
2398 .create_producer(rt(), &ProducerContext::default())
2399 .unwrap();
2400
2401 let (route_tx, _route_rx) = mpsc::channel(1); let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2403 consumer.start(ctx).await.unwrap();
2404
2405 let url = format!("ws://127.0.0.1:{port}/backpressure");
2407 let (mut client, _) = loop {
2408 match connect_async(&url).await {
2409 Ok(ok) => break ok,
2410 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2411 }
2412 };
2413
2414 tokio::time::sleep(Duration::from_millis(50)).await;
2416
2417 let mut all_dropped = false;
2419 for _ in 0..100 {
2420 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2421 match producer.clone().oneshot(exchange).await {
2422 Ok(_) => {}
2423 Err(e) => {
2424 if e.to_string().contains("backpressure") {
2425 all_dropped = true;
2426 break;
2427 }
2428 }
2429 }
2430 }
2431
2432 assert!(
2434 all_dropped,
2435 "producer should return error when all messages are dropped due to backpressure"
2436 );
2437
2438 let _ = client.close(None).await;
2440 consumer.stop().await.unwrap();
2441 }
2442
2443 #[tokio::test]
2445 async fn server_responds_to_client_ping_with_pong() {
2446 let port = free_port();
2447 let uri = format!("ws://127.0.0.1:{port}/pingpong");
2448 let component_ctx = NoOpComponentContext;
2449 let endpoint = WsComponent::new()
2450 .create_endpoint(&uri, &component_ctx)
2451 .unwrap();
2452
2453 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2454 let (route_tx, _route_rx) = mpsc::channel(16);
2455 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2456 consumer.start(ctx).await.unwrap();
2457
2458 let url = format!("ws://127.0.0.1:{port}/pingpong");
2459 let (mut client, _) = loop {
2460 match connect_async(&url).await {
2461 Ok(ok) => break ok,
2462 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2463 }
2464 };
2465
2466 client
2468 .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2469 .await
2470 .unwrap();
2471
2472 let pong = tokio::time::timeout(Duration::from_secs(2), async {
2474 loop {
2475 match client.next().await {
2476 Some(Ok(ClientMessage::Pong(data))) => break data,
2477 Some(Ok(ClientMessage::Ping(_))) => continue,
2478 Some(Ok(_)) => continue,
2479 Some(Err(e)) => panic!("ws receive failed: {e}"),
2480 None => panic!("websocket closed before pong"),
2481 }
2482 }
2483 })
2484 .await
2485 .unwrap();
2486
2487 assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2488
2489 consumer.stop().await.unwrap();
2490 }
2491
2492 #[tokio::test]
2494 async fn producer_retries_on_connection_refused() {
2495 let port = free_port();
2497 let cfg = WsEndpointConfig::from_uri(&format!(
2499 "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2500 ))
2501 .unwrap();
2502 let producer = WsProducer::new(cfg.client_config());
2503
2504 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2505
2506 let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2508 assert!(
2509 result.is_ok(),
2510 "producer should complete (with error) within timeout"
2511 );
2512 let result = result.unwrap();
2513 assert!(
2514 result.is_err(),
2515 "producer should fail when nothing is listening"
2516 );
2517 let msg = result.unwrap_err().to_string();
2518 assert!(
2519 msg.contains("connection refused"),
2520 "expected connection refused error, got: {msg}"
2521 );
2522 }
2523
2524 #[tokio::test]
2526 async fn server_bind_error_is_reported() {
2527 let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2529 let port = _listener.local_addr().unwrap().port();
2530
2531 let uri = format!("ws://127.0.0.1:{port}/binderror");
2534 let component_ctx = NoOpComponentContext;
2535 let endpoint = WsComponent::new()
2536 .create_endpoint(&uri, &component_ctx)
2537 .unwrap();
2538
2539 let mut consumer = endpoint.create_consumer(rt()).unwrap();
2540 let (route_tx, _route_rx) = mpsc::channel(16);
2541 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2542
2543 let start_result = consumer.start(ctx).await;
2545 let _ = start_result;
2548
2549 consumer.stop().await.unwrap();
2550 }
2551
2552 #[test]
2553 fn ws_app_state_server_error_starts_false() {
2554 let state = WsAppState {
2555 dispatch: Arc::new(RwLock::new(HashMap::new())),
2556 path_configs: Arc::new(DashMap::new()),
2557 path_policies: Arc::new(DashMap::new()),
2558 server_error: new_atomic_false(),
2559 };
2560 assert!(
2561 !state.server_error.load(Ordering::Relaxed),
2562 "server_error should start as false"
2563 );
2564 }
2565
2566 #[test]
2567 fn ws_app_state_server_error_can_be_set() {
2568 let state = WsAppState {
2569 dispatch: Arc::new(RwLock::new(HashMap::new())),
2570 path_configs: Arc::new(DashMap::new()),
2571 path_policies: Arc::new(DashMap::new()),
2572 server_error: new_atomic_false(),
2573 };
2574 assert!(!state.server_error.load(Ordering::Relaxed));
2575 state.server_error.store(true, Ordering::Relaxed);
2576 assert!(state.server_error.load(Ordering::Relaxed));
2577 }
2578
2579 #[tokio::test]
2580 async fn consumer_stop_returns_error_when_server_had_errors() {
2581 let port = free_port();
2582 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2583 let mut consumer = WsConsumer::new(cfg.server_config(), test_rt());
2584 let (route_tx, _route_rx) = mpsc::channel(16);
2585 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2586 consumer.start(ctx).await.unwrap();
2587
2588 if let Some(ref state) = consumer.server_state {
2590 state.server_error.store(true, Ordering::Relaxed);
2591 }
2592
2593 let result = consumer.stop().await;
2594 assert!(
2595 result.is_err(),
2596 "stop should return error when server had errors"
2597 );
2598 let msg = result.unwrap_err().to_string();
2599 assert!(
2600 msg.contains("terminated with errors"),
2601 "expected server error message, got: {msg}"
2602 );
2603 }
2604
2605 #[tokio::test]
2606 async fn consumer_stop_succeeds_when_server_healthy() {
2607 let port = free_port();
2608 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2609 let mut consumer = WsConsumer::new(cfg.server_config(), test_rt());
2610 let (route_tx, _route_rx) = mpsc::channel(16);
2611 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2612 consumer.start(ctx).await.unwrap();
2613
2614 let result = consumer.stop().await;
2615 assert!(
2616 result.is_ok(),
2617 "stop should succeed when server is healthy: {:?}",
2618 result
2619 );
2620 }
2621
2622 #[test]
2626 fn endpoint_config_parses_subprotocols() {
2627 let cfg = WsEndpointConfig::from_uri(
2628 "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws",
2629 )
2630 .unwrap();
2631 assert_eq!(cfg.subprotocols, vec!["graphql-ws", "graphql-transport-ws"]);
2632 }
2633
2634 #[test]
2635 fn endpoint_config_default_subprotocols_empty() {
2636 let cfg = WsEndpointConfig::default();
2637 assert!(cfg.subprotocols.is_empty());
2638 }
2639
2640 #[test]
2642 fn endpoint_config_parses_send_timeout() {
2643 let cfg =
2644 WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=5000").unwrap();
2645 assert_eq!(cfg.send_timeout, Duration::from_millis(5000));
2646 }
2647
2648 #[test]
2649 fn endpoint_config_default_send_timeout() {
2650 let cfg = WsEndpointConfig::default();
2651 assert_eq!(cfg.send_timeout, Duration::from_secs(30));
2652 }
2653
2654 #[test]
2655 fn endpoint_config_rejects_invalid_send_timeout() {
2656 let err =
2657 WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=abc").unwrap_err();
2658 assert!(err.to_string().contains("sendTimeoutMs"));
2659 }
2660
2661 #[test]
2663 fn endpoint_config_parses_binary_payload() {
2664 let cfg =
2665 WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=true").unwrap();
2666 assert!(cfg.binary_payload);
2667 }
2668
2669 #[test]
2670 fn endpoint_config_default_binary_payload_false() {
2671 let cfg = WsEndpointConfig::default();
2672 assert!(!cfg.binary_payload);
2673 }
2674
2675 #[test]
2676 fn endpoint_config_rejects_invalid_binary_payload() {
2677 let err =
2678 WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=yes").unwrap_err();
2679 assert!(err.to_string().contains("binaryPayload"));
2680 }
2681
2682 #[tokio::test]
2686 async fn retry_loop_invokes_operation_exactly_max_attempts_times() {
2687 use camel_component_api::NetworkRetryPolicy;
2688 use std::sync::Arc;
2689 use std::sync::atomic::{AtomicU32, Ordering};
2690
2691 let policy = NetworkRetryPolicy {
2692 max_attempts: 3,
2693 initial_delay: Duration::from_millis(1),
2694 max_delay: Duration::from_millis(1),
2695 multiplier: 1.0,
2696 ..NetworkRetryPolicy::default()
2697 };
2698
2699 let calls = Arc::new(AtomicU32::new(0));
2700 let calls_clone = Arc::clone(&calls);
2701 let mut attempts: u32 = 0;
2702
2703 let _result: Result<(), ()> = loop {
2704 calls_clone.fetch_add(1, Ordering::SeqCst);
2705 let op_result: Result<(), ()> = Err(());
2706 match op_result {
2707 Ok(_) => unreachable!(),
2708 Err(_) if policy.should_retry(attempts + 1) => {
2709 let delay = policy.delay_for(attempts);
2710 tokio::time::sleep(delay).await;
2711 attempts += 1;
2712 continue;
2713 }
2714 Err(_) => break Err(()),
2715 }
2716 };
2717
2718 assert_eq!(
2719 calls.load(Ordering::SeqCst),
2720 3,
2721 "max_attempts=3 must yield exactly 3 invocations"
2722 );
2723 }
2724
2725 #[tokio::test]
2728 async fn retry_loop_with_max_attempts_1_invokes_operation_once() {
2729 use camel_component_api::NetworkRetryPolicy;
2730 use std::sync::Arc;
2731 use std::sync::atomic::{AtomicU32, Ordering};
2732
2733 let policy = NetworkRetryPolicy {
2734 max_attempts: 1,
2735 initial_delay: Duration::from_millis(1),
2736 max_delay: Duration::from_millis(1),
2737 multiplier: 1.0,
2738 ..NetworkRetryPolicy::default()
2739 };
2740
2741 let calls = Arc::new(AtomicU32::new(0));
2742 let calls_clone = Arc::clone(&calls);
2743 let mut attempts: u32 = 0;
2744
2745 let _result: Result<(), ()> = loop {
2746 calls_clone.fetch_add(1, Ordering::SeqCst);
2747 let op_result: Result<(), ()> = Err(());
2748 match op_result {
2749 Ok(_) => unreachable!(),
2750 Err(_) if policy.should_retry(attempts + 1) => {
2751 let delay = policy.delay_for(attempts);
2752 tokio::time::sleep(delay).await;
2753 attempts += 1;
2754 continue;
2755 }
2756 Err(_) => break Err(()),
2757 }
2758 };
2759
2760 assert_eq!(
2761 calls.load(Ordering::SeqCst),
2762 1,
2763 "max_attempts=1 must yield exactly 1 invocation"
2764 );
2765 }
2766
2767 use std::fmt::Write as _;
2770 use std::sync::{Arc, Mutex};
2771 use tracing::Subscriber;
2772 use tracing_subscriber::Layer;
2773 use tracing_subscriber::layer::SubscriberExt;
2774
2775 struct CollectingLayer {
2776 events: Arc<Mutex<Vec<String>>>,
2777 }
2778
2779 impl<S: Subscriber> Layer<S> for CollectingLayer {
2780 fn on_event(
2781 &self,
2782 event: &tracing::Event<'_>,
2783 _ctx: tracing_subscriber::layer::Context<'_, S>,
2784 ) {
2785 let mut buf = String::new();
2786 let mut visitor = CollectingVisitor { fields: &mut buf };
2787 event.record(&mut visitor);
2788 if let Ok(mut events) = self.events.lock() {
2789 events.push(buf);
2790 }
2791 }
2792 }
2793
2794 struct CollectingVisitor<'a> {
2795 fields: &'a mut String,
2796 }
2797
2798 impl CollectingVisitor<'_> {
2799 fn record_field(&mut self, name: &str, value: &str) {
2800 write!(self.fields, " {name}={value}").ok();
2801 }
2802 }
2803
2804 impl tracing::field::Visit for CollectingVisitor<'_> {
2805 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
2806 self.record_field(field.name(), value);
2807 }
2808 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
2809 self.record_field(field.name(), &format!("{value:?}"));
2810 }
2811 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
2812 self.record_field(field.name(), &value.to_string());
2813 }
2814 }
2815
2816 #[tokio::test]
2822 async fn ws_producer_retry_log_emits_component_ws_producer() {
2823 let events = Arc::new(Mutex::new(Vec::new()));
2824 let layer = CollectingLayer {
2825 events: events.clone(),
2826 };
2827 let subscriber = tracing_subscriber::registry().with(layer);
2828 let _guard = tracing::subscriber::set_default(subscriber);
2829
2830 let policy = NetworkRetryPolicy {
2831 max_attempts: 2,
2832 initial_delay: Duration::from_millis(1),
2833 max_delay: Duration::from_millis(5),
2834 ..NetworkRetryPolicy::default()
2835 };
2836
2837 let request = "ws://127.0.0.1:1".into_client_request().unwrap();
2841
2842 let result: Result<_, CamelError> = connect_ws_with_retry(
2843 request,
2844 "ws://127.0.0.1:1",
2845 Duration::from_millis(100),
2846 &policy,
2847 )
2848 .await;
2849
2850 assert!(result.is_err(), "expected exhausted-retries error");
2851 let captured = events.lock().unwrap();
2852 assert!(
2853 !captured.is_empty(),
2854 "expected at least one retry log event, got none"
2855 );
2856 let first = &captured[0];
2857 assert!(
2858 first.contains("component=ws-producer"),
2859 "rc-1nm regression: expected 'component=ws-producer' in WS retry log, got: {first}"
2860 );
2861 }
2862}