1pub mod bundle;
7pub mod config;
8pub mod health;
9
10pub use bundle::WsBundle;
11pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
12pub use health::WsHealthCheck;
13
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex, OnceLock};
16
17use async_trait::async_trait;
18use axum::body::Body;
19use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
20use axum::extract::{FromRequest, Request, State};
21use axum::http::{StatusCode, header};
22use axum::response::IntoResponse;
23use axum::{Router, serve};
24use camel_api::security_policy::AuthorizationDecision;
25use camel_api::{BackoffConfig, BackoffState};
26use camel_component_api::{
27 Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
28};
29use camel_component_api::{
30 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
31 ProducerContext,
32};
33use dashmap::DashMap;
34use futures::{SinkExt, StreamExt};
35use std::future::Future;
36use std::pin::Pin;
37use std::task::{Context, Poll};
38use tokio::sync::{OnceCell, RwLock, mpsc};
39use tokio::task::JoinHandle;
40use tokio_tungstenite::tungstenite;
41use tokio_tungstenite::tungstenite::client::IntoClientRequest;
42use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
43use tower::Service;
44
/// Per-path limits and policy applied to WebSocket connections.
///
/// One value is stored per registered path in [`WsAppState::path_configs`];
/// a path with no entry falls back to [`WsPathConfig::default`].
#[derive(Clone, Debug)]
pub struct WsPathConfig {
    /// Maximum number of connections allowed in the path's connection registry.
    pub max_connections: u32,
    /// Largest accepted text or binary payload length, in bytes.
    pub max_message_size: u32,
    /// Interval between server-initiated pings; a zero duration disables them.
    pub heartbeat_interval: std::time::Duration,
    /// How long to wait for the next inbound frame before closing the
    /// connection; a zero duration disables the timeout.
    pub idle_timeout: std::time::Duration,
    /// Origin rule checked against the request's `Origin` header on upgrade.
    pub allow_origin: String,
}
53
54impl Default for WsPathConfig {
55 fn default() -> Self {
56 let cfg = WsEndpointConfig::default();
57 Self {
58 max_connections: cfg.max_connections,
59 max_message_size: cfg.max_message_size,
60 heartbeat_interval: cfg.heartbeat_interval,
61 idle_timeout: cfg.idle_timeout,
62 allow_origin: cfg.allow_origin,
63 }
64 }
65}
66
/// Locations of the TLS material used when a server is started in `wss` mode.
///
/// Both paths are handed to `load_tls_config` when the server is spawned.
#[derive(Clone, Debug)]
pub struct WsTlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private key file.
    pub key_path: String,
}
72
/// Shared map from request path to the channel that feeds the consumer
/// registered for that path.
type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
74
/// A running server as tracked by [`ServerRegistry`].
struct ServerHandle {
    /// Shared state handed to every consumer attached to this server.
    state: WsAppState,
    /// Whether the server was started with TLS.
    is_tls: bool,
    /// Task running the accept loop; aborted when the last reference is released.
    _task: JoinHandle<()>,
}
80
/// Registry slot for one port.
struct ServerRegistryInner {
    /// Lazily initialised server handle, shared so initialisation can happen
    /// outside the registry mutex.
    cell: Arc<OnceCell<ServerHandle>>,
    /// Number of outstanding `get_or_spawn` acquisitions not yet released.
    ref_count: usize,
}
85
/// Reference-counted registry of WebSocket servers, keyed by listen port.
pub struct ServerRegistry {
    /// Port → registry slot. Guarded by a std mutex that is never held across an `.await`.
    inner: Mutex<HashMap<u16, ServerRegistryInner>>,
}
89
90impl ServerRegistry {
91 pub fn global() -> &'static Self {
92 static REG: OnceLock<ServerRegistry> = OnceLock::new();
93 REG.get_or_init(|| Self {
94 inner: Mutex::new(HashMap::new()),
95 })
96 }
97
98 pub async fn get_or_spawn(
99 &'static self,
100 host: &str,
101 port: u16,
102 tls_config: Option<WsTlsConfig>,
103 ) -> Result<WsAppState, CamelError> {
104 let wants_tls = tls_config.is_some();
105 let host_owned = host.to_string();
106
107 let (cell, _is_new) = {
108 let mut guard = self.inner.lock().map_err(|_| {
109 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
110 })?;
111 let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
112 cell: Arc::new(OnceCell::new()),
113 ref_count: 0,
114 });
115 entry.ref_count += 1;
116 (entry.cell.clone(), entry.ref_count == 1)
117 };
118
119 let handle = cell
120 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
121 .await?;
122
123 if wants_tls != handle.is_tls {
124 let mut guard = self.inner.lock().map_err(|_| {
126 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
127 })?;
128 if let Some(entry) = guard.get_mut(&port) {
129 entry.ref_count -= 1;
130 if entry.ref_count == 0 {
131 guard.remove(&port);
132 }
133 }
134 return Err(CamelError::EndpointCreationFailed(format!(
135 "Server on port {port} already running with different TLS mode"
136 )));
137 }
138
139 Ok(handle.state.clone())
140 }
141
142 pub(crate) fn release(&self, port: u16) {
145 let mut guard = match self.inner.lock() {
146 Ok(g) => g,
147 Err(_) => return,
148 };
149 if let Some(entry) = guard.get_mut(&port) {
150 entry.ref_count = entry.ref_count.saturating_sub(1);
151 if entry.ref_count == 0 {
152 if let Some(handle) = entry.cell.get() {
154 handle._task.abort();
155 }
156 guard.remove(&port);
157 tracing::info!(port, "WebSocket server registry entry removed");
158 }
159 }
160 }
161}
162
163async fn spawn_server(
164 host: &str,
165 port: u16,
166 tls_config: Option<WsTlsConfig>,
167) -> Result<ServerHandle, CamelError> {
168 let host_owned = host.to_string();
169 let addr = format!("{host}:{port}");
170 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
171 let path_configs = Arc::new(DashMap::new());
172 let path_policies = Arc::new(DashMap::new());
173 let server_error = new_atomic_false();
174 let state = WsAppState {
175 dispatch: Arc::clone(&dispatch),
176 path_configs: Arc::clone(&path_configs),
177 path_policies: Arc::clone(&path_policies),
178 server_error: Arc::clone(&server_error),
179 };
180 let app = Router::new()
181 .fallback(dispatch_handler)
182 .with_state(state.clone());
183
184 let (task, is_tls) = if let Some(ref tls) = tls_config {
185 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
186 let parsed_addr = addr.parse().map_err(|e| {
187 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
188 })?;
189 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
190 let error_flag = Arc::clone(&server_error);
191 let task = tokio::spawn(async move {
192 if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
193 .serve(app.into_make_service())
194 .await
195 {
196 tracing::error!(
197 host = host_owned,
198 port = port,
199 error = %e,
200 "WebSocket server terminated with error"
201 );
202 error_flag.store(true, Ordering::Relaxed);
203 }
204 });
205 (task, true)
206 } else {
207 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
208 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
209 })?;
210 let error_flag = Arc::clone(&server_error);
211 let task = tokio::spawn(async move {
212 if let Err(e) = serve(listener, app).await {
213 tracing::error!(
214 host = host_owned,
215 port = port,
216 error = %e,
217 "WebSocket server terminated with error"
218 );
219 error_flag.store(true, Ordering::Relaxed);
220 }
221 });
222 (task, false)
223 };
224
225 tracing::info!(host, port, is_tls, "WebSocket server started");
226
227 Ok(ServerHandle {
228 state,
229 is_tls,
230 _task: task,
231 })
232}
233
/// State shared by every request handled by one server.
///
/// Cloning is cheap: every field is an `Arc`.
#[derive(Clone)]
pub struct WsAppState {
    /// Path → channel feeding the consumer registered for that path.
    pub dispatch: DispatchTable,
    /// Path → connection limits and origin policy.
    pub path_configs: Arc<DashMap<String, WsPathConfig>>,
    /// Path → security context; a path without an entry is served without authentication.
    pub path_policies: Arc<DashMap<String, camel_component_api::SecurityContext>>,
    /// Set once the server task has terminated with an error.
    pub server_error: Arc<AtomicBool>,
}
241
/// Live connections of one consumer, addressable by connection key.
pub struct WsConnectionRegistry {
    /// Connection key → sender for that connection's outbound message queue.
    connections: DashMap<String, mpsc::Sender<WsMessage>>,
}
245
/// Process-wide map from `(host, port, path)` to the connection registry of
/// the consumer started for that endpoint. Initialised lazily.
static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
> = OnceLock::new();

/// Returns the global connection-registry map, creating it empty on first use.
fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
}
253
254impl Default for WsConnectionRegistry {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl WsConnectionRegistry {
261 pub fn new() -> Self {
262 Self {
263 connections: DashMap::new(),
264 }
265 }
266
267 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
268 self.connections.insert(key, tx);
269 }
270
271 pub fn remove(&self, key: &str) {
272 self.connections.remove(key);
273 }
274
275 pub fn len(&self) -> usize {
276 self.connections.len()
277 }
278
279 pub fn is_empty(&self) -> bool {
280 self.connections.is_empty()
281 }
282
283 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
284 self.connections.iter().map(|e| e.value().clone()).collect()
285 }
286
287 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
288 keys.iter()
289 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
290 .collect()
291 }
292}
293
294pub async fn dispatch_handler(
295 State(state): State<WsAppState>,
296 req: Request<Body>,
297) -> impl IntoResponse {
298 let path = req.uri().path().to_string();
299 let origin = req
300 .headers()
301 .get(header::ORIGIN)
302 .and_then(|value| value.to_str().ok())
303 .map(str::to_string);
304 let remote_addr = req
305 .extensions()
306 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
307 .map(|ci| ci.0.to_string())
308 .unwrap_or_default();
309 let table = state.dispatch.read().await;
310 if !table.contains_key(&path) {
311 return (
312 StatusCode::NOT_FOUND,
313 "no ws endpoint registered for this path",
314 )
315 .into_response();
316 }
317 drop(table);
318
319 let path_config = state
320 .path_configs
321 .get(&path)
322 .map(|entry| entry.value().clone())
323 .unwrap_or_default();
324 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
325 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
326 }
327
328 let mut principal_opt: Option<camel_api::security_policy::Principal> = None;
329 if let Some(sec_ctx) = state.path_policies.get(&path) {
330 let extracted =
331 camel_auth::extract_token_multi(req.headers(), req.uri(), &sec_ctx.credential_sources);
332
333 match extracted {
334 Some(extracted) => {
335 if matches!(
336 extracted.source,
337 camel_auth::CredentialSource::QueryParam { .. }
338 ) {
339 let redacted =
340 camel_auth::redact_query_params(req.uri(), &["access_token", "token"]);
341 tracing::debug!(path = %redacted, "WS upgrade with query token (redacted)");
342 }
343 match sec_ctx
344 .authenticator
345 .authenticate_bearer(&extracted.token)
346 .await
347 {
348 Ok(principal) => {
349 let mut exchange = camel_api::Exchange::new(camel_api::Message::new(
350 camel_api::Body::Empty,
351 ));
352 camel_api::store_principal_properties(&mut exchange, &principal);
353 match sec_ctx.policy.evaluate(&mut exchange).await {
354 Ok(AuthorizationDecision::Granted { principal: _p }) => {
355 tracing::debug!(path = %path, subject = %principal.subject, "WS upgrade authorized");
356 principal_opt = Some(principal);
357 }
358 Ok(AuthorizationDecision::Denied { reason, .. }) => {
359 tracing::warn!(path = %path, reason = %reason, "WS upgrade denied");
360 return (StatusCode::FORBIDDEN, "Forbidden").into_response();
361 }
362 Err(e) => {
363 tracing::error!(path = %path, error = %e, "Policy evaluation error during WS upgrade");
364 return (
365 StatusCode::INTERNAL_SERVER_ERROR,
366 "Internal Server Error",
367 )
368 .into_response();
369 }
370 }
371 }
372 Err(e) => {
373 let (status, body) = match &e {
374 camel_api::CamelError::Unauthenticated(_) => {
375 (StatusCode::UNAUTHORIZED, "Unauthorized")
376 }
377 camel_api::CamelError::ProcessorError(msg)
378 if msg.contains("auth provider unavailable") =>
379 {
380 (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
381 }
382 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
383 };
384 tracing::warn!(path = %path, error = %e, "WS upgrade authentication failed");
385 return (status, body).into_response();
386 }
387 }
388 }
389 None => {
390 tracing::warn!(path = %path, "WS upgrade rejected: no credential found in any source");
391 return (
392 StatusCode::UNAUTHORIZED,
393 [("WWW-Authenticate", "Bearer".to_string())],
394 "Unauthorized",
395 )
396 .into_response();
397 }
398 }
399 }
400
401 let upgrade_headers: HashMap<String, String> = req
402 .headers()
403 .iter()
404 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
405 .collect();
406
407 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
408 Ok(ws) => ws,
409 Err(_) => {
410 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
411 }
412 };
413
414 ws.on_upgrade(move |socket| {
415 ws_handler(
416 socket,
417 state,
418 path,
419 remote_addr,
420 upgrade_headers,
421 principal_opt,
422 )
423 })
424 .into_response()
425}
426
427#[allow(unused_variables)]
428async fn ws_handler(
429 socket: WebSocket,
430 state: WsAppState,
431 path: String,
432 remote_addr: String,
433 upgrade_headers: HashMap<String, String>,
434 principal: Option<camel_api::security_policy::Principal>,
435) {
436 let connection_key = uuid::Uuid::new_v4().to_string();
437 let path_config = state
438 .path_configs
439 .get(&path)
440 .map(|entry| entry.value().clone())
441 .unwrap_or_default();
442
443 let env_tx = {
444 let table = state.dispatch.read().await;
445 table.get(&path).cloned()
446 };
447 let Some(env_tx) = env_tx else {
448 return;
449 };
450
451 let (mut sink, mut stream) = socket.split();
452 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
453
454 let registry = global_registries();
455 let mut registry_key = None;
456 for entry in registry.iter() {
457 if entry.key().2 == path {
458 entry.value().insert(connection_key.clone(), out_tx.clone());
459 registry_key = Some(entry.key().clone());
460 break;
461 }
462 }
463
464 let conn_key_for_writer = connection_key.clone();
466 let path_for_writer = path.clone();
467
468 let writer = tokio::spawn(async move {
469 while let Some(msg) = out_rx.recv().await {
470 if let Err(e) = sink.send(msg).await {
471 tracing::warn!(
472 connection_key = conn_key_for_writer,
473 path = path_for_writer,
474 error = %e,
475 "WebSocket writer send error — closing connection"
476 );
477 break;
478 }
479 }
480 });
481
482 tracing::info!(
483 connection_key = connection_key,
484 path = path,
485 remote_addr = remote_addr,
486 "WebSocket connection opened"
487 );
488
489 let mut over_limit = false;
490 if let Some(key) = ®istry_key
491 && let Some(entry) = registry.get(key)
492 && entry.len() > path_config.max_connections as usize
493 {
494 over_limit = true;
495 }
496 if over_limit {
497 try_send_with_backpressure(
498 &out_tx,
499 WsMessage::Close(Some(CloseFrame {
500 code: CloseCode::from(1013u16),
501 reason: "max connections exceeded".into(),
502 })),
503 "max-connections-close",
504 );
505 if let Some(key) = registry_key.clone()
506 && let Some(entry) = registry.get(&key)
507 {
508 entry.remove(&connection_key);
509 }
510 drop(out_tx);
511 let _ = writer.await;
512 return;
513 }
514
515 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
516 let ping_tx = out_tx.clone();
517 let interval = path_config.heartbeat_interval;
518 Some(tokio::spawn(async move {
519 let mut ticker = tokio::time::interval(interval);
520 loop {
521 ticker.tick().await;
522 let _ = try_send_with_backpressure(
523 &ping_tx,
524 WsMessage::Ping(Vec::new().into()),
525 "heartbeat-ping",
526 );
527 }
528 }))
529 } else {
530 None
531 };
532
533 loop {
534 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
535 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
536 Ok(msg) => msg,
537 Err(_) => {
538 try_send_with_backpressure(
539 &out_tx,
540 WsMessage::Close(Some(CloseFrame {
541 code: CloseCode::from(1000u16),
542 reason: "idle timeout".into(),
543 })),
544 "idle-timeout-close",
545 );
546 break;
547 }
548 }
549 } else {
550 stream.next().await
551 };
552
553 let Some(msg) = next_msg else {
554 break;
555 };
556
557 match msg {
558 Ok(WsMessage::Ping(data)) => {
559 tracing::debug!(
560 connection_key = connection_key,
561 path = path,
562 "WebSocket ping received — sending pong"
563 );
564 let _ = try_send_with_backpressure(
565 &out_tx,
566 WsMessage::Pong(data),
567 "ping-pong-response",
568 );
569 }
570 Ok(WsMessage::Pong(_)) => {
571 tracing::debug!(
572 connection_key = connection_key,
573 path = path,
574 "WebSocket pong received"
575 );
576 }
577 Ok(WsMessage::Text(text)) => {
578 if text.len() > path_config.max_message_size as usize {
579 try_send_with_backpressure(
580 &out_tx,
581 WsMessage::Close(Some(CloseFrame {
582 code: CloseCode::from(1009u16),
583 reason: "message too large".into(),
584 })),
585 "max-message-size-close-text",
586 );
587 break;
588 }
589
590 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
591 message.set_header(
592 "CamelWsConnectionKey",
593 serde_json::Value::String(connection_key.clone()),
594 );
595 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
596 message.set_header(
597 "CamelWsRemoteAddress",
598 serde_json::Value::String(remote_addr.clone()),
599 );
600
601 #[allow(unused_mut)]
602 let mut exchange = Exchange::new(message);
603 if let Some(ref p) = principal {
604 camel_api::store_principal_properties(&mut exchange, p);
605 }
606 #[cfg(feature = "otel")]
607 {
608 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
609 }
610 if env_tx
611 .send(ExchangeEnvelope {
612 exchange,
613 reply_tx: None,
614 })
615 .await
616 .is_err()
617 {
618 break;
619 }
620 }
621 Ok(WsMessage::Binary(data)) => {
622 if data.len() > path_config.max_message_size as usize {
623 try_send_with_backpressure(
624 &out_tx,
625 WsMessage::Close(Some(CloseFrame {
626 code: CloseCode::from(1009u16),
627 reason: "message too large".into(),
628 })),
629 "max-message-size-close-binary",
630 );
631 break;
632 }
633
634 let mut message = CamelMessage::new(CamelBody::Bytes(data));
635 message.set_header(
636 "CamelWsConnectionKey",
637 serde_json::Value::String(connection_key.clone()),
638 );
639 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
640 message.set_header(
641 "CamelWsRemoteAddress",
642 serde_json::Value::String(remote_addr.clone()),
643 );
644
645 #[allow(unused_mut)]
646 let mut exchange = Exchange::new(message);
647 if let Some(ref p) = principal {
648 camel_api::store_principal_properties(&mut exchange, p);
649 }
650 #[cfg(feature = "otel")]
651 {
652 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
653 }
654 if env_tx
655 .send(ExchangeEnvelope {
656 exchange,
657 reply_tx: None,
658 })
659 .await
660 .is_err()
661 {
662 break;
663 }
664 }
665 Ok(WsMessage::Close(cf)) => {
666 let reason = cf
667 .as_ref()
668 .map(|f| f.reason.to_string())
669 .unwrap_or_default();
670 tracing::info!(
671 connection_key = connection_key,
672 path = path,
673 reason = reason,
674 "WebSocket connection closed by peer"
675 );
676 break;
677 }
678 Err(e) => {
679 tracing::warn!(
680 connection_key = connection_key,
681 path = path,
682 error = %e,
683 "WebSocket receive error"
684 );
685 break;
686 }
687 }
688 }
689
690 if let Some(task) = heartbeat_task {
691 task.abort();
692 }
693
694 if let Some(key) = registry_key
695 && let Some(entry) = registry.get(&key)
696 {
697 entry.remove(&connection_key);
698 }
699 drop(out_tx);
700 let _ = writer.await;
701
702 tracing::info!(
703 connection_key = connection_key,
704 path = path,
705 "WebSocket connection closed"
706 );
707}
708
/// Component registered for the `ws` URI scheme.
pub struct WsComponent {
    /// Component-level settings; each value that is set overrides the one parsed from the endpoint URI.
    pub(crate) config: WsConfig,
}
712
713impl WsComponent {
714 pub fn new() -> Self {
715 Self {
716 config: WsConfig::default(),
717 }
718 }
719
720 pub fn with_config(config: WsConfig) -> Self {
721 Self { config }
722 }
723}
724
725impl Default for WsComponent {
726 fn default() -> Self {
727 Self::new()
728 }
729}
730
731impl Component for WsComponent {
732 fn scheme(&self) -> &str {
733 "ws"
734 }
735
736 fn create_endpoint(
737 &self,
738 uri: &str,
739 ctx: &dyn camel_component_api::ComponentContext,
740 ) -> Result<Box<dyn Endpoint>, CamelError> {
741 self.config.validate()?;
742 let mut cfg = WsEndpointConfig::from_uri(uri)?;
743 if let Some(v) = self.config.max_connections {
744 cfg.max_connections = v;
745 }
746 if let Some(v) = self.config.max_message_size {
747 cfg.max_message_size = v;
748 }
749 if let Some(v) = self.config.heartbeat_interval_ms {
750 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
751 }
752 if let Some(v) = self.config.idle_timeout_ms {
753 cfg.idle_timeout = std::time::Duration::from_millis(v);
754 }
755 if let Some(v) = self.config.connect_timeout_ms {
756 cfg.connect_timeout = std::time::Duration::from_millis(v);
757 }
758 if let Some(v) = self.config.response_timeout_ms {
759 cfg.response_timeout = std::time::Duration::from_millis(v);
760 }
761 if let Some(v) = self.config.send_timeout_ms {
762 cfg.send_timeout = std::time::Duration::from_millis(v);
763 }
764 if let Some(v) = self.config.binary_payload {
765 cfg.binary_payload = v;
766 }
767 if let Some(ref v) = self.config.subprotocols {
768 cfg.subprotocols = v.clone();
769 }
770 let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
771 ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
772 Ok(Box::new(WsEndpoint {
773 uri: uri.to_string(),
774 cfg,
775 }))
776 }
777}
778
/// Component registered for the `wss` URI scheme.
pub struct WssComponent {
    /// Component-level settings; each value that is set overrides the one parsed from the endpoint URI.
    pub(crate) config: WsConfig,
}
782
783impl WssComponent {
784 pub fn new() -> Self {
785 Self {
786 config: WsConfig::default(),
787 }
788 }
789
790 pub fn with_config(config: WsConfig) -> Self {
791 Self { config }
792 }
793}
794
795impl Default for WssComponent {
796 fn default() -> Self {
797 Self::new()
798 }
799}
800
801impl Component for WssComponent {
802 fn scheme(&self) -> &str {
803 "wss"
804 }
805
806 fn create_endpoint(
807 &self,
808 uri: &str,
809 ctx: &dyn camel_component_api::ComponentContext,
810 ) -> Result<Box<dyn Endpoint>, CamelError> {
811 self.config.validate()?;
812 let mut cfg = WsEndpointConfig::from_uri(uri)?;
813 if let Some(v) = self.config.max_connections {
814 cfg.max_connections = v;
815 }
816 if let Some(v) = self.config.max_message_size {
817 cfg.max_message_size = v;
818 }
819 if let Some(v) = self.config.heartbeat_interval_ms {
820 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
821 }
822 if let Some(v) = self.config.idle_timeout_ms {
823 cfg.idle_timeout = std::time::Duration::from_millis(v);
824 }
825 if let Some(v) = self.config.connect_timeout_ms {
826 cfg.connect_timeout = std::time::Duration::from_millis(v);
827 }
828 if let Some(v) = self.config.response_timeout_ms {
829 cfg.response_timeout = std::time::Duration::from_millis(v);
830 }
831 if let Some(v) = self.config.send_timeout_ms {
832 cfg.send_timeout = std::time::Duration::from_millis(v);
833 }
834 if let Some(v) = self.config.binary_payload {
835 cfg.binary_payload = v;
836 }
837 if let Some(ref v) = self.config.subprotocols {
838 cfg.subprotocols = v.clone();
839 }
840 let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
841 ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
842 Ok(Box::new(WsEndpoint {
843 uri: uri.to_string(),
844 cfg,
845 }))
846 }
847}
848
/// Endpoint created by `WsComponent` and `WssComponent`.
struct WsEndpoint {
    /// The URI the endpoint was created from, returned verbatim by `uri()`.
    uri: String,
    /// Configuration parsed from the URI with component-level overrides applied.
    cfg: WsEndpointConfig,
}
853
854impl Endpoint for WsEndpoint {
855 fn uri(&self) -> &str {
856 &self.uri
857 }
858
859 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
860 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
861 }
862
863 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
864 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
865 }
866}
867
/// Consumer that accepts WebSocket connections on one path of a shared server.
pub struct WsConsumer {
    /// Server-side endpoint configuration.
    cfg: WsServerConfig,
    /// Connections currently open against this consumer.
    registry: Arc<WsConnectionRegistry>,
    /// State of the server this consumer is attached to; `Some` while started.
    server_state: Option<WsAppState>,
    /// Key under which `registry` is published globally; `Some` while started.
    registry_key: Option<(String, u16, String)>,
    /// Task forwarding envelopes from the server to the route.
    forward_task: Option<JoinHandle<Result<(), CamelError>>>,
    /// Security context to enforce on upgrade, if one was set before `start`.
    security_ctx: Option<camel_component_api::SecurityContext>,
}
876
877impl WsConsumer {
878 pub fn new(cfg: WsServerConfig) -> Self {
879 Self {
880 cfg,
881 registry: Arc::new(WsConnectionRegistry::new()),
882 server_state: None,
883 registry_key: None,
884 forward_task: None,
885 security_ctx: None,
886 }
887 }
888}
889
890#[async_trait]
891impl Consumer for WsConsumer {
892 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
893 if self.server_state.is_some() {
895 return Err(CamelError::EndpointCreationFailed(
896 "WebSocket consumer already started".into(),
897 ));
898 }
899
900 tracing::info!(
901 host = self.cfg.inner.host,
902 port = self.cfg.inner.port,
903 path = self.cfg.inner.path,
904 scheme = self.cfg.inner.scheme,
905 "WebSocket consumer starting"
906 );
907
908 let tls_config = if self.cfg.inner.scheme == "wss" {
909 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
910 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
911 })?;
912 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
913 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
914 })?;
915 Some(WsTlsConfig {
916 cert_path,
917 key_path,
918 })
919 } else {
920 None
921 };
922
923 let state = ServerRegistry::global()
924 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
925 .await?;
926
927 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
928 {
929 let mut table = state.dispatch.write().await;
930 table.insert(self.cfg.inner.path.clone(), env_tx);
931 }
932
933 state.path_configs.insert(
934 self.cfg.inner.path.clone(),
935 WsPathConfig {
936 max_connections: self.cfg.inner.max_connections,
937 max_message_size: self.cfg.inner.max_message_size,
938 heartbeat_interval: self.cfg.inner.heartbeat_interval,
939 idle_timeout: self.cfg.inner.idle_timeout,
940 allow_origin: self.cfg.inner.allow_origin.clone(),
941 },
942 );
943
944 if let Some(ref sec_ctx) = self.security_ctx {
945 let path = self.cfg.inner.path.clone();
946 state.path_policies.insert(path, sec_ctx.clone());
947 }
948
949 let registry_key = (
950 self.cfg.inner.canonical_host(),
951 self.cfg.inner.port,
952 self.cfg.inner.path.clone(),
953 );
954 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
955
956 let sender = ctx.sender();
957 let forward_task: JoinHandle<Result<(), CamelError>> = tokio::spawn(async move {
958 while let Some(envelope) = env_rx.recv().await {
959 if sender.send(envelope).await.is_err() {
960 break;
961 }
962 }
963 Ok(())
964 });
965
966 self.server_state = Some(state);
967 self.registry_key = Some(registry_key);
968 self.forward_task = Some(forward_task);
969 Ok(())
970 }
971
972 async fn stop(&mut self) -> Result<(), CamelError> {
973 tracing::info!(
974 host = self.cfg.inner.host,
975 port = self.cfg.inner.port,
976 path = self.cfg.inner.path,
977 "WebSocket consumer stopping"
978 );
979
980 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
981 code: axum::extract::ws::CloseCode::from(1001u16),
982 reason: "consumer stopping".into(),
983 }));
984 for tx in self.registry.snapshot_senders() {
985 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
986 }
987
988 let mut had_server_error = false;
989
990 if let Some(state) = self.server_state.take() {
991 had_server_error = state.server_error.load(Ordering::Relaxed);
992 state.path_policies.remove(&self.cfg.inner.path);
993 let mut table = state.dispatch.write().await;
994 table.remove(&self.cfg.inner.path);
995 state.path_configs.remove(&self.cfg.inner.path);
996 }
997
998 if let Some(key) = self.registry_key.take() {
999 global_registries().remove(&key);
1000 ServerRegistry::global().release(key.1);
1001 }
1002
1003 if let Some(task) = self.forward_task.take() {
1004 task.abort();
1005 }
1006
1007 tracing::info!(
1008 host = self.cfg.inner.host,
1009 port = self.cfg.inner.port,
1010 path = self.cfg.inner.path,
1011 "WebSocket consumer stopped"
1012 );
1013
1014 if had_server_error {
1015 tracing::warn!(
1016 host = self.cfg.inner.host,
1017 port = self.cfg.inner.port,
1018 path = self.cfg.inner.path,
1019 "WebSocket server had errors during its lifetime"
1020 );
1021 return Err(CamelError::ProcessorError(
1022 "WebSocket server terminated with errors during its lifetime".into(),
1023 ));
1024 }
1025
1026 Ok(())
1027 }
1028
1029 fn concurrency_model(&self) -> ConcurrencyModel {
1030 ConcurrencyModel::Concurrent {
1031 max: Some(self.cfg.inner.max_connections as usize),
1032 }
1033 }
1034
1035 fn background_task_handle(&mut self) -> Option<JoinHandle<Result<(), CamelError>>> {
1036 self.forward_task.take()
1037 }
1038
1039 fn set_security_context(&mut self, ctx: camel_component_api::SecurityContext) {
1040 self.security_ctx = Some(ctx);
1041 }
1042}
1043
1044use std::sync::atomic::{AtomicBool, Ordering};
1045
/// Returns a fresh shared flag initialised to `false`.
fn new_atomic_false() -> Arc<AtomicBool> {
    let flag = AtomicBool::new(false);
    Arc::new(flag)
}
1049
/// Producer that sends exchange bodies as WebSocket messages.
#[derive(Clone)]
pub struct WsProducer {
    /// Client-side endpoint configuration.
    cfg: WsClientConfig,
    /// Set when a send dropped messages because a channel was full; the next
    /// `poll_ready` clears it and reports the error. Shared between clones.
    backpressure_flag: Arc<AtomicBool>,
}
1057
1058impl WsProducer {
1059 pub fn new(cfg: WsClientConfig) -> Self {
1060 Self {
1061 cfg,
1062 backpressure_flag: Arc::new(AtomicBool::new(false)),
1063 }
1064 }
1065}
1066
1067impl Service<Exchange> for WsProducer {
1068 type Response = Exchange;
1069 type Error = CamelError;
1070 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
1071
1072 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
1073 if self.backpressure_flag.swap(false, Ordering::Relaxed) {
1075 return Poll::Ready(Err(CamelError::ProcessorError(
1076 "WebSocket producer backpressure: previous send was dropped due to full channel"
1077 .into(),
1078 )));
1079 }
1080 Poll::Ready(Ok(()))
1081 }
1082
1083 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
1084 let cfg = self.cfg.clone();
1085 let backpressure_flag = Arc::clone(&self.backpressure_flag);
1086
1087 Box::pin(async move {
1088 let canonical_host = cfg.inner.canonical_host();
1089 let key = (
1090 canonical_host.clone(),
1091 cfg.inner.port,
1092 cfg.inner.path.clone(),
1093 );
1094
1095 let send_to_all = exchange
1096 .input
1097 .header("CamelWsSendToAll")
1098 .and_then(|v| v.as_bool())
1099 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
1100 .unwrap_or(false);
1101
1102 let conn_keys_header = exchange
1103 .input
1104 .header("CamelWsConnectionKey")
1105 .and_then(|v| v.as_str())
1106 .map(str::to_string);
1107
1108 let local_exists = global_registries().contains_key(&key);
1109 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
1110
1111 let message_type = exchange
1112 .input
1113 .header("CamelWsMessageType")
1114 .and_then(|v| v.as_str())
1115 .unwrap_or("text")
1116 .to_ascii_lowercase();
1117
1118 if server_send_mode {
1119 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
1120 let Some(registry) = registry else {
1121 return Err(CamelError::ProcessorError(format!(
1122 "WebSocket local consumer not found for {}:{}{}",
1123 canonical_host, cfg.inner.port, cfg.inner.path
1124 )));
1125 };
1126
1127 let out_msg = body_to_axum_ws_message(
1128 std::mem::take(&mut exchange.input.body),
1129 &message_type,
1130 )
1131 .await?;
1132
1133 let targets = if send_to_all {
1134 registry.snapshot_senders()
1135 } else if let Some(keys) = conn_keys_header {
1136 let parsed: Vec<String> = keys
1137 .split(',')
1138 .map(str::trim)
1139 .filter(|k| !k.is_empty())
1140 .map(|k| k.to_string())
1141 .collect();
1142 registry.get_senders_for_keys(&parsed)
1143 } else {
1144 registry.snapshot_senders()
1145 };
1146
1147 let mut dropped = 0usize;
1148 for tx in &targets {
1149 if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1150 dropped += 1;
1151 }
1152 }
1153
1154 if dropped > 0 {
1155 tracing::warn!(
1156 host = canonical_host,
1157 port = cfg.inner.port,
1158 path = cfg.inner.path,
1159 dropped,
1160 total = targets.len(),
1161 "WebSocket producer dropped messages due to backpressure"
1162 );
1163 exchange.input.set_header(
1164 "CamelWsDeliveryDropped",
1165 serde_json::Value::Number(dropped.into()),
1166 );
1167 backpressure_flag.store(true, Ordering::Relaxed);
1169 if dropped == targets.len() {
1170 return Err(CamelError::ProcessorError(format!(
1171 "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1172 )));
1173 }
1174 }
1175
1176 tracing::debug!(
1177 host = canonical_host,
1178 port = cfg.inner.port,
1179 path = cfg.inner.path,
1180 targets = targets.len(),
1181 "WebSocket producer server-send complete"
1182 );
1183
1184 return Ok(exchange);
1185 }
1186
1187 let url = format!(
1188 "{}://{}:{}{}",
1189 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1190 );
1191
1192 tracing::debug!(url = url, "WebSocket producer connecting");
1193
1194 #[allow(unused_mut)]
1195 let mut request = url
1196 .clone()
1197 .into_client_request()
1198 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1199
1200 #[cfg(feature = "otel")]
1201 {
1202 let mut otel_headers = HashMap::new();
1203 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1204 for (k, v) in otel_headers {
1205 if let (Ok(name), Ok(val)) = (
1206 http::header::HeaderName::from_bytes(k.as_bytes()),
1207 http::header::HeaderValue::from_str(&v),
1208 ) {
1209 request.headers_mut().insert(name, val);
1210 }
1211 }
1212 }
1213
1214 if !cfg.inner.subprotocols.is_empty() {
1216 let proto_value = cfg.inner.subprotocols.join(", ");
1217 if let (Ok(name), Ok(val)) = (
1218 http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
1219 http::header::HeaderValue::from_str(&proto_value),
1220 ) {
1221 request.headers_mut().insert(name, val);
1222 }
1223 }
1224
1225 let effective_message_type = if cfg.inner.binary_payload {
1227 "binary"
1228 } else {
1229 &message_type
1230 };
1231
1232 let max_retries = if cfg.inner.reconnect {
1233 cfg.inner.reconnect_max_attempts as usize
1234 } else {
1235 0
1236 };
1237 let mut backoff = BackoffState::new(BackoffConfig {
1238 initial_delay: std::time::Duration::from_millis(cfg.inner.reconnect_delay_ms),
1239 multiplier: 2.0,
1240 max_delay: std::time::Duration::from_secs(30),
1241 });
1242 let mut attempts = 0usize;
1243 let mut ws_stream = loop {
1244 let connect_future = tokio_tungstenite::connect_async(request.clone());
1245 match tokio::time::timeout(cfg.inner.connect_timeout, connect_future).await {
1246 Ok(Ok((stream, _))) => break stream,
1247 Ok(Err(e)) => {
1248 let err = map_connect_error(e, &url);
1249 let is_transient = err.to_string().contains("connection refused")
1250 || err.to_string().contains("timeout")
1251 || err.to_string().contains("connection failed");
1252 if is_transient && attempts < max_retries {
1253 attempts += 1;
1254 tracing::warn!(
1255 url = url,
1256 error = %err,
1257 attempt = attempts,
1258 max_retries,
1259 "WebSocket connect failed — retrying"
1260 );
1261 tokio::time::sleep(backoff.next_delay()).await;
1262 continue;
1263 }
1264 return Err(err);
1265 }
1266 Err(_) => {
1267 let err = CamelError::ProcessorError(format!(
1268 "WebSocket connect timeout ({:?}) to {url}",
1269 cfg.inner.connect_timeout
1270 ));
1271 if attempts < max_retries {
1272 attempts += 1;
1273 tracing::warn!(
1274 url = url,
1275 attempt = attempts,
1276 max_retries,
1277 "WebSocket connect timeout — retrying"
1278 );
1279 tokio::time::sleep(backoff.next_delay()).await;
1280 continue;
1281 }
1282 return Err(err);
1283 }
1284 }
1285 };
1286 if attempts > 0 {
1287 tracing::info!(url = url, "WebSocket producer connected after retry");
1288 }
1289
1290 let out_msg = body_to_client_ws_message(
1291 std::mem::take(&mut exchange.input.body),
1292 effective_message_type,
1293 )
1294 .await?;
1295
1296 ws_stream
1297 .send(out_msg)
1298 .await
1299 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1300
1301 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1302 loop {
1303 match ws_stream.next().await {
1304 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1305 continue;
1306 }
1307 other => break other,
1308 }
1309 }
1310 })
1311 .await
1312 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1313
1314 match incoming {
1315 Some(Ok(ClientWsMessage::Text(text))) => {
1316 tracing::debug!(url = url, "WebSocket producer received text response");
1317 exchange.input.body = CamelBody::Text(text.to_string());
1318 }
1319 Some(Ok(ClientWsMessage::Binary(data))) => {
1320 tracing::debug!(url = url, "WebSocket producer received binary response");
1321 exchange.input.body = CamelBody::Bytes(data);
1322 }
1323 Some(Ok(ClientWsMessage::Close(frame))) => {
1324 let normal = frame
1325 .as_ref()
1326 .map(|f| {
1327 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1328 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1329 })
1330 .unwrap_or(true);
1331
1332 if normal {
1333 tracing::debug!(url = url, "WebSocket producer received normal close");
1334 exchange.input.body = CamelBody::Empty;
1335 } else if cfg.inner.reconnect && attempts < max_retries {
1336 backoff.reset();
1337 attempts += 1;
1338 tracing::warn!(
1339 url = url,
1340 attempt = attempts,
1341 max_retries,
1342 "WebSocket closed by peer — reconnecting"
1343 );
1344 tokio::time::sleep(backoff.next_delay()).await;
1345 return Err(CamelError::ProcessorError(format!(
1346 "WebSocket reconnect required after close: code {}",
1347 frame.map(|f| u16::from(f.code)).unwrap_or_default()
1348 )));
1349 } else {
1350 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1351 return Err(CamelError::ProcessorError(format!(
1352 "WebSocket peer closed: code {code}"
1353 )));
1354 }
1355 }
1356 Some(Ok(_)) | None => {
1357 exchange.input.body = CamelBody::Empty;
1358 }
1359 Some(Err(e)) => {
1360 return Err(CamelError::ProcessorError(format!(
1361 "WebSocket receive failed: {e}"
1362 )));
1363 }
1364 }
1365
1366 let _ = ws_stream.close(None).await;
1367 tracing::debug!(url = url, "WebSocket producer connection closed");
1368 Ok(exchange)
1369 })
1370 }
1371}
1372
1373async fn body_to_axum_ws_message(
1374 body: CamelBody,
1375 message_type: &str,
1376) -> Result<WsMessage, CamelError> {
1377 match message_type {
1378 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1379 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1380 }
1381}
1382
1383async fn body_to_client_ws_message(
1384 body: CamelBody,
1385 message_type: &str,
1386) -> Result<ClientWsMessage, CamelError> {
1387 match message_type {
1388 "binary" => Ok(ClientWsMessage::Binary(
1389 body.into_bytes(10 * 1024 * 1024).await?,
1390 )),
1391 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1392 }
1393}
1394
1395async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1396 Ok(match body {
1397 CamelBody::Empty => String::new(),
1398 CamelBody::Text(s) => s,
1399 CamelBody::Xml(s) => s,
1400 CamelBody::Json(v) => v.to_string(),
1401 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1402 CamelBody::Stream(stream) => {
1403 let bytes = CamelBody::Stream(stream)
1404 .into_bytes(10 * 1024 * 1024)
1405 .await?;
1406 String::from_utf8_lossy(&bytes).to_string()
1407 }
1408 })
1409}
1410
/// Decides whether a request's `Origin` value passes the configured policy.
///
/// * `allowed_origin == "*"` accepts every request, including ones without an origin.
/// * Otherwise the request must carry an origin that is byte-for-byte equal to
///   `allowed_origin`; a missing origin is rejected.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    match (allowed_origin, request_origin) {
        ("*", _) => true,
        (expected, Some(actual)) => actual == expected,
        (_, None) => false,
    }
}
1417
1418fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1419 match tx.try_send(msg) {
1420 Ok(()) => true,
1421 Err(error) => {
1422 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1423 false
1424 }
1425 }
1426}
1427
1428fn load_tls_config(
1429 cert_path: &str,
1430 key_path: &str,
1431) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1432 use std::fs::File;
1433 use std::io::BufReader;
1434
1435 let cert_file = File::open(cert_path)
1436 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1437 let key_file = File::open(key_path)
1438 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1439
1440 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1441 .collect::<Result<Vec<_>, _>>()
1442 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1443
1444 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1445 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1446 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1447
1448 tokio_rustls::rustls::ServerConfig::builder()
1449 .with_no_client_auth()
1450 .with_single_cert(certs, key)
1451 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1452}
1453
1454fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1455 match err {
1456 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1457 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1458 }
1459 tungstenite::Error::Tls(_) => {
1460 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1461 }
1462 other => {
1463 let msg = other.to_string();
1464 if msg.to_lowercase().contains("connection refused") {
1465 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1466 } else if msg.to_lowercase().contains("tls") {
1467 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1468 } else {
1469 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1470 }
1471 }
1472 }
1473}
1474
1475#[cfg(test)]
1476mod tests {
1477 use super::*;
1478 use camel_component_api::NoOpComponentContext;
1479 use std::time::Duration;
1480
1481 use tokio::sync::mpsc;
1482 use tokio_tungstenite::connect_async;
1483 use tokio_tungstenite::tungstenite::Message as ClientMessage;
1484 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1485 use tokio_util::sync::CancellationToken;
1486 use tower::ServiceExt;
1487
1488 fn free_port() -> u16 {
1489 std::net::TcpListener::bind("127.0.0.1:0")
1490 .unwrap()
1491 .local_addr()
1492 .unwrap()
1493 .port()
1494 }
1495
1496 #[test]
1497 fn ws_component_scheme_is_ws() {
1498 assert_eq!(WsComponent::new().scheme(), "ws");
1499 }
1500
1501 #[test]
1502 fn wss_component_scheme_is_wss() {
1503 assert_eq!(WssComponent::new().scheme(), "wss");
1504 }
1505
    /// Pins every field value produced by `WsEndpointConfig::default()`.
    #[test]
    fn endpoint_config_defaults_match_spec() {
        let cfg = WsEndpointConfig::default();
        // Addressing defaults.
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "0.0.0.0");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.path, "/");
        // Server-side limits and timers.
        assert_eq!(cfg.max_connections, 100);
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
        assert_eq!(cfg.idle_timeout, Duration::ZERO);
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
        assert_eq!(cfg.allow_origin, "*");
        // TLS material is unset by default.
        assert_eq!(cfg.tls_cert, None);
        assert_eq!(cfg.tls_key, None);
        // Reconnect and send settings.
        assert!(cfg.reconnect);
        assert_eq!(cfg.reconnect_max_attempts, 5);
        assert_eq!(cfg.reconnect_delay_ms, 1000);
        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
        assert!(!cfg.binary_payload);
        assert!(cfg.subprotocols.is_empty());
    }
1530
    /// Parses a URI that sets most query parameters and checks each one lands in the
    /// matching config field; reconnect settings are not in the URI and must keep
    /// their defaults.
    #[test]
    fn endpoint_config_parses_uri_params() {
        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();

        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 9001);
        assert_eq!(cfg.path, "/chat");
        assert_eq!(cfg.max_connections, 42);
        assert_eq!(cfg.max_message_size, 1024);
        assert!(cfg.send_to_all);
        // The `*Ms` parameters are expected to surface as millisecond durations.
        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
        assert_eq!(cfg.allow_origin, "https://example.com");
        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
        // Not present in the URI: defaults apply.
        assert!(cfg.reconnect);
        assert_eq!(cfg.reconnect_max_attempts, 5);
        assert_eq!(cfg.reconnect_delay_ms, 1000);
    }
1554
    /// Checks that `reconnect`, `reconnectMaxAttempts` and `reconnectDelayMs` query
    /// parameters override the corresponding config fields.
    #[test]
    fn endpoint_config_parses_reconnect_uri_params() {
        let uri =
            "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
        assert!(!cfg.reconnect);
        assert_eq!(cfg.reconnect_max_attempts, 2);
        assert_eq!(cfg.reconnect_delay_ms, 25);
    }
1564
    /// A URI that sets only `maxConnections` must change that field and leave the
    /// other checked fields at their default values.
    #[test]
    fn endpoint_config_override_chain_uri_overrides_defaults() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
        // Overridden by the URI.
        assert_eq!(cfg.max_connections, 7);
        // Untouched defaults.
        assert_eq!(cfg.max_message_size, 65536);
        assert!(!cfg.send_to_all);
        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
    }
1573
    /// Smoke test: an endpoint created from a `ws://` URI can hand out both a consumer
    /// and a producer without error. Nothing is started or connected here.
    #[test]
    fn endpoint_trait_creates_consumer_and_producer() {
        let ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
            .unwrap();

        endpoint.create_consumer().unwrap();
        endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();
    }
1586
    /// The consumer's concurrency model must be `Concurrent` with `max` equal to the
    /// `maxConnections` value taken from the URI.
    #[test]
    fn ws_consumer_concurrency_model_uses_max_connections() {
        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
        let consumer = WsConsumer::new(cfg.server_config());
        assert_eq!(
            consumer.concurrency_model(),
            ConcurrencyModel::Concurrent { max: Some(321) }
        );
    }
1596
    /// Exercises `WsConnectionRegistry` directly: insert two senders, broadcast through
    /// `snapshot_senders`, send to one key through `get_senders_for_keys`, then remove
    /// a key and check the length.
    #[tokio::test]
    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
        let registry = WsConnectionRegistry::new();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);

        registry.insert("k1".into(), tx1);
        registry.insert("k2".into(), tx2);
        assert_eq!(registry.len(), 2);

        // Broadcast: every registered sender receives the message.
        for tx in registry.snapshot_senders() {
            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
        }

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));

        // Targeted send: only the sender registered under "k1" is returned.
        let target = registry.get_senders_for_keys(&["k1".to_string()]);
        assert_eq!(target.len(), 1);
        target[0]
            .send(WsMessage::Text("targeted".into()))
            .await
            .unwrap();

        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
        // The second receiver must stay silent: the 50 ms wait has to time out.
        assert!(
            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
                .await
                .is_err()
        );

        registry.remove("k1");
        assert_eq!(registry.len(), 1);
    }
1631
1632 #[test]
1633 fn host_canonicalization_maps_local_hosts_to_loopback() {
1634 let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1635 .unwrap()
1636 .canonical_host();
1637 let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1638 .unwrap()
1639 .canonical_host();
1640 let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1641 .unwrap()
1642 .canonical_host();
1643
1644 assert_eq!(c1, "127.0.0.1");
1645 assert_eq!(c2, "127.0.0.1");
1646 assert_eq!(c3, "127.0.0.1");
1647 }
1648
    /// End-to-end echo: a real client sends a text frame to the started consumer, a
    /// stand-in route task feeds the payload back through the producer (addressed by
    /// the `CamelWsConnectionKey` header), and the client must receive the same text.
    #[tokio::test]
    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/echo");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, mut route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // Stand-in for a route: take the first inbound exchange and echo its text
        // back to the originating connection via the producer.
        let route_task = tokio::spawn(async move {
            if let Some(envelope) = route_rx.recv().await {
                let payload = envelope
                    .exchange
                    .input
                    .body
                    .as_text()
                    .unwrap_or_default()
                    .to_string();
                let key = envelope
                    .exchange
                    .input
                    .header("CamelWsConnectionKey")
                    .and_then(|v| v.as_str())
                    .unwrap()
                    .to_string();

                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
                response
                    .input
                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
                producer.oneshot(response).await.unwrap();
            }
        });

        // Retry the connect every 25 ms until it succeeds.
        let url = format!("ws://127.0.0.1:{port}/echo");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("hello-ws".into()))
            .await
            .unwrap();

        // Wait up to 2 s for the first text frame, skipping every other frame type.
        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed before echo"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(incoming, "hello-ws");

        consumer.stop().await.unwrap();
        route_task.await.unwrap();
    }
1724
    /// Stopping the consumer while a client is connected must deliver a close frame
    /// whose code is `CloseCode::Away`.
    #[tokio::test]
    async fn consumer_stop_sends_close_1001() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/shutdown");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();
        // The receiver is kept alive (bound to `_route_rx`) but never read.
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // Retry the connect every 25 ms until it succeeds.
        let url = format!("ws://127.0.0.1:{port}/shutdown");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        client
            .send(ClientMessage::Text("keepalive".into()))
            .await
            .unwrap();

        consumer.stop().await.unwrap();

        // Wait up to 2 s for the close frame, ignoring any other frames first.
        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(close_code, Some(CloseCode::Away));
    }
1770
1771 #[test]
1772 fn wildcard_origin_allows_anything() {
1773 assert!(is_origin_allowed("*", None));
1774 assert!(is_origin_allowed("*", Some("https://example.com")));
1775 }
1776
1777 #[test]
1778 fn exact_origin_requires_match() {
1779 assert!(is_origin_allowed(
1780 "https://example.com",
1781 Some("https://example.com")
1782 ));
1783 assert!(!is_origin_allowed(
1784 "https://example.com",
1785 Some("https://other.com")
1786 ));
1787 assert!(!is_origin_allowed("https://example.com", None));
1788 }
1789
    /// An `http://` URI must be rejected, and the error text must mention
    /// "Invalid WebSocket scheme".
    #[test]
    fn endpoint_config_rejects_invalid_scheme() {
        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid WebSocket scheme"),
            "expected scheme error, got: {msg}"
        );
    }
1800
    /// Starting a `wss://` consumer whose URI carries no `tlsCert`/`tlsKey` must fail
    /// with an error mentioning "TLS cert path is required".
    #[tokio::test]
    async fn wss_consumer_start_fails_without_tls_cert() {
        let port = free_port();
        let component_ctx = NoOpComponentContext;
        let endpoint = WssComponent::new()
            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        // Endpoint and consumer creation succeed; the failure is expected at start.
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert path is required"),
            "expected TLS cert error, got: {msg}"
        );
    }
1819
    /// Starting a `wss://` consumer that points at certificate/key paths which do not
    /// exist must fail with an error mentioning "TLS cert file error".
    #[tokio::test]
    async fn wss_consumer_start_fails_with_nonexistent_cert() {
        let port = free_port();
        let component_ctx = NoOpComponentContext;
        let endpoint = WssComponent::new()
            .create_endpoint(&format!(
                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
            ), &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        // Endpoint and consumer creation succeed; the failure is expected at start.
        let result = consumer.start(ctx).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TLS cert file error"),
            "expected cert file error, got: {msg}"
        );
    }
1840
    /// Two sequential `get_or_spawn` calls for the same host and port must return
    /// states that share one dispatch table (same `Arc` pointer).
    #[tokio::test]
    async fn server_registry_returns_same_state_for_same_port() {
        let port = free_port();
        let state1 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let state2 = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        // Pointer identity, not structural equality.
        assert!(
            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
            "expected same dispatch table for same port"
        );
    }
1857
    /// A request for a path nobody registered must be answered with 404 by
    /// `dispatch_handler`.
    #[tokio::test]
    async fn dispatch_handler_returns_404_for_unregistered_path() {
        let port = free_port();
        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        // Drive the handler in-process through a router instead of over the network.
        let app = Router::new().fallback(dispatch_handler).with_state(state);
        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/nonexistent")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
1881
    /// Client-mode producer: with a plain axum echo server that is not this
    /// component's consumer, the producer must connect, send the body, and return the
    /// echoed text as the exchange body.
    #[tokio::test]
    async fn client_mode_producer_connects_and_echoes() {
        // Minimal echo server: text and binary frames are sent straight back; a close
        // frame ends the loop; other frames are ignored.
        let app = Router::new().route(
            "/echo",
            axum::routing::get(|ws: WebSocketUpgrade| async move {
                ws.on_upgrade(|mut socket: WebSocket| async move {
                    while let Some(Ok(msg)) = socket.recv().await {
                        match msg {
                            WsMessage::Text(text) => {
                                let _ = socket.send(WsMessage::Text(text)).await;
                            }
                            WsMessage::Binary(data) => {
                                let _ = socket.send(WsMessage::Binary(data)).await;
                            }
                            WsMessage::Close(_) => break,
                            _ => {}
                        }
                    }
                })
            }),
        );
        // Bind port 0 and keep the listener, so the port stays reserved for the server.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let server_task = tokio::spawn(async move {
            let _ = serve(listener, app).await;
        });

        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
        let producer = WsProducer::new(cfg.client_config());

        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
        // Short pause before the call; presumably lets the spawned server task start.
        tokio::time::sleep(Duration::from_millis(25)).await;
        let result =
            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
                Ok(Ok(r)) => r,
                Ok(Err(_)) => panic!("producer call failed"),
                Err(_) => panic!("producer call timed out"),
            };

        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");

        server_task.abort();
    }
1926
    /// With `maxConnections=1`, a second client connecting while the first is still
    /// open must receive a close frame with code 1013.
    #[tokio::test]
    async fn max_connections_rejects_with_close_1013() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // First client: retried until connected, then held open (bound to `_client1`)
        // for the rest of the test so it keeps occupying the single slot.
        let url = format!("ws://127.0.0.1:{port}/limited");
        let (_client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // Presumably gives the server time to register the first connection.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // The handshake itself is expected to succeed; rejection arrives as a close frame.
        let (mut client2, _) = connect_async(&url).await.unwrap();

        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(ClientMessage::Text(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
                    None => panic!("client2 closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1013u16)),
            "expected 1013 (Try Again Later) for max connections"
        );

        consumer.stop().await.unwrap();
    }
1975
    /// With `maxMessageSize=10`, sending a 100-character text frame must make the
    /// server answer with a close frame carrying code 1009.
    #[tokio::test]
    async fn max_message_size_rejects_with_close_1009() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // Retry the connect every 25 ms until it succeeds.
        let url = format!("ws://127.0.0.1:{port}/sizelimit");
        let (mut client, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        // 100 characters against a configured limit of 10.
        let oversized = "x".repeat(100);
        client
            .send(ClientMessage::Text(oversized.into()))
            .await
            .unwrap();

        // Wait up to 2 s for the close frame, ignoring any other frames first.
        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => panic!("ws receive failed: {e}"),
                    None => panic!("websocket closed without close frame"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(
            close_code,
            Some(CloseCode::from(1009u16)),
            "expected 1009 (Message Too Big) for oversized message"
        );

        consumer.stop().await.unwrap();
    }
2025
    /// With `allowOrigin=https://allowed.com`, an upgrade request carrying
    /// `Origin: https://evil.com` must be answered with 403 Forbidden.
    #[tokio::test]
    async fn origin_rejection_returns_403() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        // Fetch the state for the same host/port the consumer was started on and drive
        // `dispatch_handler` in-process, so the raw HTTP status can be inspected.
        let state = ServerRegistry::global()
            .get_or_spawn("127.0.0.1", port, None)
            .await
            .unwrap();
        let app = Router::new().fallback(dispatch_handler).with_state(state);

        // Hand-built WebSocket upgrade request with a disallowed origin.
        let response = tokio::time::timeout(
            Duration::from_secs(2),
            tower::ServiceExt::oneshot(
                app,
                axum::http::Request::builder()
                    .uri("/origintest")
                    .header("origin", "https://evil.com")
                    .header("upgrade", "websocket")
                    .header("connection", "Upgrade")
                    .header("sec-websocket-version", "13")
                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                    .body(Body::empty())
                    .unwrap(),
            ),
        )
        .await
        .unwrap()
        .unwrap();

        assert_eq!(
            response.status(),
            StatusCode::FORBIDDEN,
            "expected 403 for disallowed origin"
        );

        consumer.stop().await.unwrap();
    }
2072
    /// An exchange sent through the producer with the `CamelWsSendToAll` header set to
    /// `true` must reach every connected client (two in this test).
    #[tokio::test]
    async fn broadcast_sends_to_all_connected_clients() {
        let port = free_port();
        let uri = format!("ws://127.0.0.1:{port}/bc");
        let component_ctx = NoOpComponentContext;
        let endpoint = WsComponent::new()
            .create_endpoint(&uri, &component_ctx)
            .unwrap();
        let mut consumer = endpoint.create_consumer().unwrap();
        let producer = endpoint
            .create_producer(&ProducerContext::default())
            .unwrap();

        let (route_tx, _route_rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
        consumer.start(ctx).await.unwrap();

        let url = format!("ws://127.0.0.1:{port}/bc");

        // The first connect is retried until the server accepts; the second is
        // attempted once.
        let (mut client1, _) = loop {
            match connect_async(&url).await {
                Ok(ok) => break ok,
                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
            }
        };

        let (mut client2, _) = connect_async(&url).await.unwrap();

        // Presumably gives the server time to register both connections.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut response =
            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
        response
            .input
            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
        producer.oneshot(response).await.unwrap();

        // Each client must see the text within 2 s; anything other than ping/pong
        // before it is a failure.
        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client1.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client1 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client2.next().await {
                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
                    _ => panic!("client2 unexpected message or close"),
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(recv1, "broadcast-msg");
        assert_eq!(recv2, "broadcast-msg");

        consumer.stop().await.unwrap();
    }
2139
    /// Four tasks calling `get_or_spawn` for the same port must all end up with
    /// states that share one dispatch table (same `Arc` pointer).
    #[tokio::test]
    async fn concurrent_get_or_spawn_returns_same_state() {
        let port = free_port();
        // Shared sink where each spawned task deposits the state it obtained.
        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let results = results.clone();
            handles.push(tokio::spawn(async move {
                let state = ServerRegistry::global()
                    .get_or_spawn("127.0.0.1", port, None)
                    .await
                    .unwrap();
                // The guard is taken only after the await above has completed.
                results.lock().unwrap().push(state);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let states = results.lock().unwrap();
        assert_eq!(states.len(), 4);
        // Compare every later state against the first one by pointer identity.
        for i in 1..states.len() {
            assert!(
                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
                "all concurrent callers should get the same dispatch table"
            );
        }
    }
2171
    /// Checks that both conversion helpers pick the frame variant named by the
    /// `message_type` argument ("text" gives `Text`, "binary" gives `Binary`). Only the
    /// variant is asserted, not the payload.
    #[tokio::test]
    async fn body_conversion_helpers_cover_text_and_binary_paths() {
        // Server-side (axum) frames.
        let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
            .await
            .unwrap();
        assert!(matches!(text_msg, WsMessage::Text(_)));

        let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
            .await
            .unwrap();
        assert!(matches!(bin_msg, WsMessage::Binary(_)));

        // Client-side (tungstenite) frames.
        let client_text =
            body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
                .await
                .unwrap();
        assert!(matches!(client_text, ClientWsMessage::Text(_)));

        let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
            .await
            .unwrap();
        assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
    }
2195
2196 #[tokio::test]
2197 async fn body_to_text_handles_empty_text_json_and_bytes() {
2198 assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2199 assert_eq!(
2200 body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2201 "hello"
2202 );
2203 assert_eq!(
2204 body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2205 .await
2206 .unwrap(),
2207 "{\"n\":1}"
2208 );
2209 assert_eq!(
2210 body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2211 .await
2212 .unwrap(),
2213 "hi"
2214 );
2215 }
2216
/// With a capacity-1 channel whose receiver is never drained, the first
/// `try_send_with_backpressure` call must report success and the second must
/// report failure.
#[test]
fn try_send_with_backpressure_returns_false_when_channel_full() {
    // `_rx` stays bound so the channel remains open for both sends.
    let (tx, _rx) = mpsc::channel::<WsMessage>(1);

    let first_accepted = try_send_with_backpressure(&tx, WsMessage::Text("first".into()), "test");
    assert!(first_accepted);

    let second_accepted =
        try_send_with_backpressure(&tx, WsMessage::Text("second".into()), "test");
    assert!(!second_accepted);
}
2231
/// Checks the messages `map_connect_error` produces for an I/O
/// `ConnectionRefused` error and for a non-I/O (protocol) error.
#[test]
fn map_connect_error_formats_connection_refused_and_generic_errors() {
    let io_refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
    let refused_msg =
        map_connect_error(tungstenite::Error::Io(io_refused), "ws://localhost:1/x").to_string();
    assert!(refused_msg.contains("WebSocket connection refused"));

    let reset = tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake;
    let generic_msg =
        map_connect_error(tungstenite::Error::Protocol(reset), "ws://localhost:2/y").to_string();
    // The generic message is expected to embed the target URL.
    assert!(generic_msg.contains("WebSocket connection failed (ws://localhost:2/y)"));
}
2250
/// `maxConnections=0` in the URI must make `WsEndpointConfig::from_uri` fail
/// with a message naming the constraint.
#[test]
fn from_uri_rejects_max_connections_zero() {
    // `unwrap_err` panics if parsing unexpectedly succeeds.
    let msg = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0")
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("maxConnections must be >= 1"),
        "expected maxConnections validation error, got: {msg}"
    );
}
2264
/// `maxMessageSize=0` in the URI must make `WsEndpointConfig::from_uri` fail
/// with a message naming the constraint.
#[test]
fn from_uri_rejects_max_message_size_zero() {
    // `unwrap_err` panics if parsing unexpectedly succeeds.
    let msg = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0")
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("maxMessageSize must be > 0"),
        "expected maxMessageSize validation error, got: {msg}"
    );
}
2276
/// An empty `allowOrigin=` query value must make `WsEndpointConfig::from_uri`
/// fail with a message naming the constraint.
#[test]
fn from_uri_rejects_empty_allow_origin() {
    // `unwrap_err` panics if parsing unexpectedly succeeds.
    let msg = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=")
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("allowOrigin must not be empty"),
        "expected allowOrigin validation error, got: {msg}"
    );
}
2288
2289 #[tokio::test]
2291 async fn consumer_double_start_returns_error() {
2292 let port = free_port();
2293 let uri = format!("ws://127.0.0.1:{port}/doublestart");
2294 let component_ctx = NoOpComponentContext;
2295 let endpoint = WsComponent::new()
2296 .create_endpoint(&uri, &component_ctx)
2297 .unwrap();
2298
2299 let mut consumer = endpoint.create_consumer().unwrap();
2300 let (route_tx, _route_rx) = mpsc::channel(16);
2301 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2302
2303 consumer.start(ctx).await.unwrap();
2305
2306 let (route_tx2, _route_rx2) = mpsc::channel(16);
2308 let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2309 let result = consumer.start(ctx2).await;
2310 assert!(result.is_err());
2311 let msg = result.unwrap_err().to_string();
2312 assert!(
2313 msg.contains("already started"),
2314 "expected double-start error, got: {msg}"
2315 );
2316
2317 consumer.stop().await.unwrap();
2318 }
2319
2320 #[tokio::test]
2322 async fn registry_cleanup_on_consumer_stop() {
2323 let port = free_port();
2324 let uri = format!("ws://127.0.0.1:{port}/cleanup");
2325 let component_ctx = NoOpComponentContext;
2326 let endpoint = WsComponent::new()
2327 .create_endpoint(&uri, &component_ctx)
2328 .unwrap();
2329
2330 let mut consumer = endpoint.create_consumer().unwrap();
2331 let (route_tx, _route_rx) = mpsc::channel(16);
2332 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2333 consumer.start(ctx).await.unwrap();
2334
2335 let registries = global_registries();
2337 let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2338 assert!(
2339 registries.contains_key(&key),
2340 "registry should have entry after start"
2341 );
2342
2343 consumer.stop().await.unwrap();
2345
2346 assert!(
2348 !registries.contains_key(&key),
2349 "registry should be cleaned up after stop"
2350 );
2351
2352 let server_reg = ServerRegistry::global();
2355 let guard = server_reg.inner.lock().unwrap();
2356 assert!(
2357 !guard.contains_key(&port),
2358 "ServerRegistry should remove port entry after last consumer stops"
2359 );
2360 }
2361
2362 #[tokio::test]
2364 async fn producer_server_send_returns_error_when_all_dropped() {
2365 let port = free_port();
2366 let uri = format!("ws://127.0.0.1:{port}/backpressure");
2367 let component_ctx = NoOpComponentContext;
2368 let endpoint = WsComponent::new()
2369 .create_endpoint(&uri, &component_ctx)
2370 .unwrap();
2371
2372 let mut consumer = endpoint.create_consumer().unwrap();
2373 let producer = endpoint
2374 .create_producer(&ProducerContext::default())
2375 .unwrap();
2376
2377 let (route_tx, _route_rx) = mpsc::channel(1); let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2379 consumer.start(ctx).await.unwrap();
2380
2381 let url = format!("ws://127.0.0.1:{port}/backpressure");
2383 let (mut client, _) = loop {
2384 match connect_async(&url).await {
2385 Ok(ok) => break ok,
2386 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2387 }
2388 };
2389
2390 tokio::time::sleep(Duration::from_millis(50)).await;
2392
2393 let mut all_dropped = false;
2395 for _ in 0..100 {
2396 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2397 match producer.clone().oneshot(exchange).await {
2398 Ok(_) => {}
2399 Err(e) => {
2400 if e.to_string().contains("backpressure") {
2401 all_dropped = true;
2402 break;
2403 }
2404 }
2405 }
2406 }
2407
2408 assert!(
2410 all_dropped,
2411 "producer should return error when all messages are dropped due to backpressure"
2412 );
2413
2414 let _ = client.close(None).await;
2416 consumer.stop().await.unwrap();
2417 }
2418
2419 #[tokio::test]
2421 async fn server_responds_to_client_ping_with_pong() {
2422 let port = free_port();
2423 let uri = format!("ws://127.0.0.1:{port}/pingpong");
2424 let component_ctx = NoOpComponentContext;
2425 let endpoint = WsComponent::new()
2426 .create_endpoint(&uri, &component_ctx)
2427 .unwrap();
2428
2429 let mut consumer = endpoint.create_consumer().unwrap();
2430 let (route_tx, _route_rx) = mpsc::channel(16);
2431 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2432 consumer.start(ctx).await.unwrap();
2433
2434 let url = format!("ws://127.0.0.1:{port}/pingpong");
2435 let (mut client, _) = loop {
2436 match connect_async(&url).await {
2437 Ok(ok) => break ok,
2438 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2439 }
2440 };
2441
2442 client
2444 .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2445 .await
2446 .unwrap();
2447
2448 let pong = tokio::time::timeout(Duration::from_secs(2), async {
2450 loop {
2451 match client.next().await {
2452 Some(Ok(ClientMessage::Pong(data))) => break data,
2453 Some(Ok(ClientMessage::Ping(_))) => continue,
2454 Some(Ok(_)) => continue,
2455 Some(Err(e)) => panic!("ws receive failed: {e}"),
2456 None => panic!("websocket closed before pong"),
2457 }
2458 }
2459 })
2460 .await
2461 .unwrap();
2462
2463 assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2464
2465 consumer.stop().await.unwrap();
2466 }
2467
2468 #[tokio::test]
2470 async fn producer_retries_on_connection_refused() {
2471 let port = free_port();
2473 let cfg = WsEndpointConfig::from_uri(&format!(
2475 "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2476 ))
2477 .unwrap();
2478 let producer = WsProducer::new(cfg.client_config());
2479
2480 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2481
2482 let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2484 assert!(
2485 result.is_ok(),
2486 "producer should complete (with error) within timeout"
2487 );
2488 let result = result.unwrap();
2489 assert!(
2490 result.is_err(),
2491 "producer should fail when nothing is listening"
2492 );
2493 let msg = result.unwrap_err().to_string();
2494 assert!(
2495 msg.contains("connection refused"),
2496 "expected connection refused error, got: {msg}"
2497 );
2498 }
2499
2500 #[tokio::test]
2502 async fn server_bind_error_is_reported() {
2503 let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2505 let port = _listener.local_addr().unwrap().port();
2506
2507 let uri = format!("ws://127.0.0.1:{port}/binderror");
2510 let component_ctx = NoOpComponentContext;
2511 let endpoint = WsComponent::new()
2512 .create_endpoint(&uri, &component_ctx)
2513 .unwrap();
2514
2515 let mut consumer = endpoint.create_consumer().unwrap();
2516 let (route_tx, _route_rx) = mpsc::channel(16);
2517 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2518
2519 let start_result = consumer.start(ctx).await;
2521 let _ = start_result;
2524
2525 consumer.stop().await.unwrap();
2526 }
2527
/// A `WsAppState` built with `new_atomic_false()` must report `false` for
/// `server_error`.
#[test]
fn ws_app_state_server_error_starts_false() {
    let fresh = WsAppState {
        dispatch: Arc::new(RwLock::new(HashMap::new())),
        path_configs: Arc::new(DashMap::new()),
        path_policies: Arc::new(DashMap::new()),
        server_error: new_atomic_false(),
    };

    let flagged = fresh.server_error.load(Ordering::Relaxed);
    assert!(!flagged, "server_error should start as false");
}
2541
/// The `server_error` flag on `WsAppState` reads `false` initially and `true`
/// after a store.
#[test]
fn ws_app_state_server_error_can_be_set() {
    let app_state = WsAppState {
        dispatch: Arc::new(RwLock::new(HashMap::new())),
        path_configs: Arc::new(DashMap::new()),
        path_policies: Arc::new(DashMap::new()),
        server_error: new_atomic_false(),
    };
    let flag = &app_state.server_error;

    assert!(!flag.load(Ordering::Relaxed));
    flag.store(true, Ordering::Relaxed);
    assert!(flag.load(Ordering::Relaxed));
}
2554
2555 #[tokio::test]
2556 async fn consumer_stop_returns_error_when_server_had_errors() {
2557 let port = free_port();
2558 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2559 let mut consumer = WsConsumer::new(cfg.server_config());
2560 let (route_tx, _route_rx) = mpsc::channel(16);
2561 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2562 consumer.start(ctx).await.unwrap();
2563
2564 if let Some(ref state) = consumer.server_state {
2566 state.server_error.store(true, Ordering::Relaxed);
2567 }
2568
2569 let result = consumer.stop().await;
2570 assert!(
2571 result.is_err(),
2572 "stop should return error when server had errors"
2573 );
2574 let msg = result.unwrap_err().to_string();
2575 assert!(
2576 msg.contains("terminated with errors"),
2577 "expected server error message, got: {msg}"
2578 );
2579 }
2580
2581 #[tokio::test]
2582 async fn consumer_stop_succeeds_when_server_healthy() {
2583 let port = free_port();
2584 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2585 let mut consumer = WsConsumer::new(cfg.server_config());
2586 let (route_tx, _route_rx) = mpsc::channel(16);
2587 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2588 consumer.start(ctx).await.unwrap();
2589
2590 let result = consumer.stop().await;
2591 assert!(
2592 result.is_ok(),
2593 "stop should succeed when server is healthy: {:?}",
2594 result
2595 );
2596 }
2597
/// A comma-separated `subprotocols` query value is parsed into the listed
/// entries, in order.
#[test]
fn endpoint_config_parses_subprotocols() {
    let uri = "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();

    let expected = vec!["graphql-ws", "graphql-transport-ws"];
    assert_eq!(parsed.subprotocols, expected);
}
2609
/// The default `WsEndpointConfig` has no subprotocols.
#[test]
fn endpoint_config_default_subprotocols_empty() {
    let defaults = WsEndpointConfig::default();
    let has_none = defaults.subprotocols.is_empty();
    assert!(has_none);
}
2615
/// `sendTimeoutMs=5000` in the URI yields a `send_timeout` of 5000 ms.
#[test]
fn endpoint_config_parses_send_timeout() {
    let uri = "ws://localhost:9001/chat?sendTimeoutMs=5000";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();

    let expected = Duration::from_millis(5000);
    assert_eq!(parsed.send_timeout, expected);
}
2623
/// The default `WsEndpointConfig` uses a 30-second `send_timeout`.
#[test]
fn endpoint_config_default_send_timeout() {
    let defaults = WsEndpointConfig::default();
    let thirty_seconds = Duration::from_secs(30);
    assert_eq!(defaults.send_timeout, thirty_seconds);
}
2629
/// A non-numeric `sendTimeoutMs` value is rejected with an error that names
/// the parameter.
#[test]
fn endpoint_config_rejects_invalid_send_timeout() {
    let uri = "ws://localhost:9001/chat?sendTimeoutMs=abc";
    let msg = WsEndpointConfig::from_uri(uri).unwrap_err().to_string();
    assert!(msg.contains("sendTimeoutMs"));
}
2636
/// `binaryPayload=true` in the URI sets `binary_payload`.
#[test]
fn endpoint_config_parses_binary_payload() {
    let uri = "ws://localhost:9001/chat?binaryPayload=true";
    let parsed = WsEndpointConfig::from_uri(uri).unwrap();
    assert!(parsed.binary_payload);
}
2644
/// The default `WsEndpointConfig` leaves `binary_payload` off.
#[test]
fn endpoint_config_default_binary_payload_false() {
    let defaults = WsEndpointConfig::default();
    let binary_enabled = defaults.binary_payload;
    assert!(!binary_enabled);
}
2650
/// A `binaryPayload` value of `yes` is rejected with an error that names the
/// parameter.
#[test]
fn endpoint_config_rejects_invalid_binary_payload() {
    let uri = "ws://localhost:9001/chat?binaryPayload=yes";
    let msg = WsEndpointConfig::from_uri(uri).unwrap_err().to_string();
    assert!(msg.contains("binaryPayload"));
}
2657}