1pub mod bundle;
2pub mod config;
3
4pub use bundle::WsBundle;
5pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use async_trait::async_trait;
11use axum::body::Body;
12use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
13use axum::extract::{FromRequest, Request, State};
14use axum::http::{StatusCode, header};
15use axum::response::IntoResponse;
16use axum::{Router, serve};
17use camel_component_api::{
18 Body as CamelBody, BoxProcessor, CamelError, Exchange, Message as CamelMessage,
19};
20use camel_component_api::{
21 Component, ConcurrencyModel, Consumer, ConsumerContext, Endpoint, ExchangeEnvelope,
22 ProducerContext,
23};
24use dashmap::DashMap;
25use futures::{SinkExt, StreamExt};
26use std::future::Future;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29use tokio::sync::{OnceCell, RwLock, mpsc};
30use tokio::task::JoinHandle;
31use tokio_tungstenite::tungstenite;
32use tokio_tungstenite::tungstenite::client::IntoClientRequest;
33use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
34use tower::Service;
35
36#[derive(Clone)]
37struct WsPathConfig {
38 max_connections: u32,
39 max_message_size: u32,
40 heartbeat_interval: std::time::Duration,
41 idle_timeout: std::time::Duration,
42 allow_origin: String,
43}
44
45impl Default for WsPathConfig {
46 fn default() -> Self {
47 let cfg = WsEndpointConfig::default();
48 Self {
49 max_connections: cfg.max_connections,
50 max_message_size: cfg.max_message_size,
51 heartbeat_interval: cfg.heartbeat_interval,
52 idle_timeout: cfg.idle_timeout,
53 allow_origin: cfg.allow_origin,
54 }
55 }
56}
57
58#[derive(Clone)]
59struct WsTlsConfig {
60 cert_path: String,
61 key_path: String,
62}
63
64type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
65
66struct ServerHandle {
67 state: WsAppState,
68 is_tls: bool,
69 _task: JoinHandle<()>,
70}
71
72pub struct ServerRegistry {
73 inner: Mutex<HashMap<u16, Arc<OnceCell<ServerHandle>>>>,
74}
75
76impl ServerRegistry {
77 pub fn global() -> &'static Self {
78 static REG: OnceLock<ServerRegistry> = OnceLock::new();
79 REG.get_or_init(|| Self {
80 inner: Mutex::new(HashMap::new()),
81 })
82 }
83
84 pub(crate) async fn get_or_spawn(
85 &'static self,
86 host: &str,
87 port: u16,
88 tls_config: Option<WsTlsConfig>,
89 ) -> Result<WsAppState, CamelError> {
90 let wants_tls = tls_config.is_some();
91 let host_owned = host.to_string();
92
93 let cell = {
94 let mut guard = self.inner.lock().map_err(|_| {
95 CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
96 })?;
97 guard
98 .entry(port)
99 .or_insert_with(|| Arc::new(OnceCell::new()))
100 .clone()
101 };
102
103 let handle = cell
104 .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
105 .await?;
106
107 if wants_tls != handle.is_tls {
108 return Err(CamelError::EndpointCreationFailed(format!(
109 "Server on port {port} already running with different TLS mode"
110 )));
111 }
112
113 Ok(handle.state.clone())
114 }
115}
116
117async fn spawn_server(
118 host: &str,
119 port: u16,
120 tls_config: Option<WsTlsConfig>,
121) -> Result<ServerHandle, CamelError> {
122 let addr = format!("{host}:{port}");
123 let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
124 let path_configs = Arc::new(DashMap::new());
125 let state = WsAppState {
126 dispatch: Arc::clone(&dispatch),
127 path_configs: Arc::clone(&path_configs),
128 };
129 let app = Router::new()
130 .fallback(dispatch_handler)
131 .with_state(state.clone());
132
133 let (task, is_tls) = if let Some(ref tls) = tls_config {
134 let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
135 let parsed_addr = addr.parse().map_err(|e| {
136 CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
137 })?;
138 let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
139 let task = tokio::spawn(async move {
140 let _ = axum_server::bind_rustls(parsed_addr, tls_cfg)
141 .serve(app.into_make_service())
142 .await;
143 });
144 (task, true)
145 } else {
146 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
147 CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
148 })?;
149 let task = tokio::spawn(async move {
150 let _ = serve(listener, app).await;
151 });
152 (task, false)
153 };
154
155 Ok(ServerHandle {
156 state,
157 is_tls,
158 _task: task,
159 })
160}
161
162#[derive(Clone)]
163struct WsAppState {
164 dispatch: DispatchTable,
165 path_configs: Arc<DashMap<String, WsPathConfig>>,
166}
167
168pub struct WsConnectionRegistry {
169 connections: DashMap<String, mpsc::Sender<WsMessage>>,
170}
171
172static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
173 DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
174> = OnceLock::new();
175
176fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
177 GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
178}
179
180impl Default for WsConnectionRegistry {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
186impl WsConnectionRegistry {
187 pub fn new() -> Self {
188 Self {
189 connections: DashMap::new(),
190 }
191 }
192
193 pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
194 self.connections.insert(key, tx);
195 }
196
197 pub fn remove(&self, key: &str) {
198 self.connections.remove(key);
199 }
200
201 pub fn len(&self) -> usize {
202 self.connections.len()
203 }
204
205 pub fn is_empty(&self) -> bool {
206 self.connections.is_empty()
207 }
208
209 pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
210 self.connections.iter().map(|e| e.value().clone()).collect()
211 }
212
213 pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
214 keys.iter()
215 .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
216 .collect()
217 }
218}
219
220async fn dispatch_handler(
221 State(state): State<WsAppState>,
222 req: Request<Body>,
223) -> impl IntoResponse {
224 let path = req.uri().path().to_string();
225 let origin = req
226 .headers()
227 .get(header::ORIGIN)
228 .and_then(|value| value.to_str().ok())
229 .map(str::to_string);
230 let remote_addr = req
231 .extensions()
232 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
233 .map(|ci| ci.0.to_string())
234 .unwrap_or_default();
235 let table = state.dispatch.read().await;
236 if !table.contains_key(&path) {
237 return (
238 StatusCode::NOT_FOUND,
239 "no ws endpoint registered for this path",
240 )
241 .into_response();
242 }
243 drop(table);
244
245 let path_config = state
246 .path_configs
247 .get(&path)
248 .map(|entry| entry.value().clone())
249 .unwrap_or_default();
250 if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
251 return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
252 }
253
254 let upgrade_headers: HashMap<String, String> = req
255 .headers()
256 .iter()
257 .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
258 .collect();
259
260 let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
261 Ok(ws) => ws,
262 Err(_) => {
263 return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
264 }
265 };
266
267 ws.on_upgrade(move |socket| ws_handler(socket, state, path, remote_addr, upgrade_headers))
268 .into_response()
269}
270
271#[allow(unused_variables)]
272async fn ws_handler(
273 socket: WebSocket,
274 state: WsAppState,
275 path: String,
276 remote_addr: String,
277 upgrade_headers: HashMap<String, String>,
278) {
279 let connection_key = uuid::Uuid::new_v4().to_string();
280 let path_config = state
281 .path_configs
282 .get(&path)
283 .map(|entry| entry.value().clone())
284 .unwrap_or_default();
285
286 let env_tx = {
287 let table = state.dispatch.read().await;
288 table.get(&path).cloned()
289 };
290 let Some(env_tx) = env_tx else {
291 return;
292 };
293
294 let (mut sink, mut stream) = socket.split();
295 let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
296
297 let registry = global_registries();
298 let mut registry_key = None;
299 for entry in registry.iter() {
300 if entry.key().2 == path {
301 entry.value().insert(connection_key.clone(), out_tx.clone());
302 registry_key = Some(entry.key().clone());
303 break;
304 }
305 }
306
307 let writer = tokio::spawn(async move {
308 while let Some(msg) = out_rx.recv().await {
309 let _ = sink.send(msg).await;
310 }
311 });
312
313 let mut over_limit = false;
314 if let Some(key) = ®istry_key
315 && let Some(entry) = registry.get(key)
316 && entry.len() > path_config.max_connections as usize
317 {
318 over_limit = true;
319 }
320 if over_limit {
321 try_send_with_backpressure(
322 &out_tx,
323 WsMessage::Close(Some(CloseFrame {
324 code: CloseCode::from(1013u16),
325 reason: "max connections exceeded".into(),
326 })),
327 "max-connections-close",
328 );
329 if let Some(key) = registry_key.clone()
330 && let Some(entry) = registry.get(&key)
331 {
332 entry.remove(&connection_key);
333 }
334 drop(out_tx);
335 let _ = writer.await;
336 return;
337 }
338
339 let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
340 let ping_tx = out_tx.clone();
341 let interval = path_config.heartbeat_interval;
342 Some(tokio::spawn(async move {
343 let mut ticker = tokio::time::interval(interval);
344 loop {
345 ticker.tick().await;
346 let _ = try_send_with_backpressure(
347 &ping_tx,
348 WsMessage::Ping(Vec::new().into()),
349 "heartbeat-ping",
350 );
351 }
352 }))
353 } else {
354 None
355 };
356
357 loop {
358 let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
359 match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
360 Ok(msg) => msg,
361 Err(_) => {
362 try_send_with_backpressure(
363 &out_tx,
364 WsMessage::Close(Some(CloseFrame {
365 code: CloseCode::from(1000u16),
366 reason: "idle timeout".into(),
367 })),
368 "idle-timeout-close",
369 );
370 break;
371 }
372 }
373 } else {
374 stream.next().await
375 };
376
377 let Some(msg) = next_msg else {
378 break;
379 };
380
381 match msg {
382 Ok(WsMessage::Text(text)) => {
383 if text.len() > path_config.max_message_size as usize {
384 try_send_with_backpressure(
385 &out_tx,
386 WsMessage::Close(Some(CloseFrame {
387 code: CloseCode::from(1009u16),
388 reason: "message too large".into(),
389 })),
390 "max-message-size-close-text",
391 );
392 break;
393 }
394
395 let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
396 message.set_header(
397 "CamelWsConnectionKey",
398 serde_json::Value::String(connection_key.clone()),
399 );
400 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
401 message.set_header(
402 "CamelWsRemoteAddress",
403 serde_json::Value::String(remote_addr.clone()),
404 );
405
406 #[allow(unused_mut)]
407 let mut exchange = Exchange::new(message);
408 #[cfg(feature = "otel")]
409 {
410 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
411 }
412 if env_tx
413 .send(ExchangeEnvelope {
414 exchange,
415 reply_tx: None,
416 })
417 .await
418 .is_err()
419 {
420 break;
421 }
422 }
423 Ok(WsMessage::Binary(data)) => {
424 if data.len() > path_config.max_message_size as usize {
425 try_send_with_backpressure(
426 &out_tx,
427 WsMessage::Close(Some(CloseFrame {
428 code: CloseCode::from(1009u16),
429 reason: "message too large".into(),
430 })),
431 "max-message-size-close-binary",
432 );
433 break;
434 }
435
436 let mut message = CamelMessage::new(CamelBody::Bytes(data));
437 message.set_header(
438 "CamelWsConnectionKey",
439 serde_json::Value::String(connection_key.clone()),
440 );
441 message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
442 message.set_header(
443 "CamelWsRemoteAddress",
444 serde_json::Value::String(remote_addr.clone()),
445 );
446
447 #[allow(unused_mut)]
448 let mut exchange = Exchange::new(message);
449 #[cfg(feature = "otel")]
450 {
451 camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
452 }
453 if env_tx
454 .send(ExchangeEnvelope {
455 exchange,
456 reply_tx: None,
457 })
458 .await
459 .is_err()
460 {
461 break;
462 }
463 }
464 Ok(WsMessage::Close(_)) | Err(_) => break,
465 _ => {}
466 }
467 }
468
469 if let Some(task) = heartbeat_task {
470 task.abort();
471 }
472
473 if let Some(key) = registry_key
474 && let Some(entry) = registry.get(&key)
475 {
476 entry.remove(&connection_key);
477 }
478 drop(out_tx);
479 let _ = writer.await;
480}
481
482pub struct WsComponent {
483 pub(crate) config: WsConfig,
484}
485
486impl WsComponent {
487 pub fn new() -> Self {
488 Self {
489 config: WsConfig::default(),
490 }
491 }
492
493 pub fn with_config(config: WsConfig) -> Self {
494 Self { config }
495 }
496}
497
498impl Default for WsComponent {
499 fn default() -> Self {
500 Self::new()
501 }
502}
503
504impl Component for WsComponent {
505 fn scheme(&self) -> &str {
506 "ws"
507 }
508
509 fn create_endpoint(
510 &self,
511 uri: &str,
512 _ctx: &dyn camel_component_api::ComponentContext,
513 ) -> Result<Box<dyn Endpoint>, CamelError> {
514 let mut cfg = WsEndpointConfig::from_uri(uri)?;
515 if let Some(v) = self.config.max_connections {
516 cfg.max_connections = v;
517 }
518 if let Some(v) = self.config.max_message_size {
519 cfg.max_message_size = v;
520 }
521 if let Some(v) = self.config.heartbeat_interval_ms {
522 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
523 }
524 if let Some(v) = self.config.idle_timeout_ms {
525 cfg.idle_timeout = std::time::Duration::from_millis(v);
526 }
527 if let Some(v) = self.config.connect_timeout_ms {
528 cfg.connect_timeout = std::time::Duration::from_millis(v);
529 }
530 if let Some(v) = self.config.response_timeout_ms {
531 cfg.response_timeout = std::time::Duration::from_millis(v);
532 }
533 Ok(Box::new(WsEndpoint {
534 uri: uri.to_string(),
535 cfg,
536 }))
537 }
538}
539
540pub struct WssComponent {
541 pub(crate) config: WsConfig,
542}
543
544impl WssComponent {
545 pub fn new() -> Self {
546 Self {
547 config: WsConfig::default(),
548 }
549 }
550
551 pub fn with_config(config: WsConfig) -> Self {
552 Self { config }
553 }
554}
555
556impl Default for WssComponent {
557 fn default() -> Self {
558 Self::new()
559 }
560}
561
562impl Component for WssComponent {
563 fn scheme(&self) -> &str {
564 "wss"
565 }
566
567 fn create_endpoint(
568 &self,
569 uri: &str,
570 _ctx: &dyn camel_component_api::ComponentContext,
571 ) -> Result<Box<dyn Endpoint>, CamelError> {
572 let mut cfg = WsEndpointConfig::from_uri(uri)?;
573 if let Some(v) = self.config.max_connections {
574 cfg.max_connections = v;
575 }
576 if let Some(v) = self.config.max_message_size {
577 cfg.max_message_size = v;
578 }
579 if let Some(v) = self.config.heartbeat_interval_ms {
580 cfg.heartbeat_interval = std::time::Duration::from_millis(v);
581 }
582 if let Some(v) = self.config.idle_timeout_ms {
583 cfg.idle_timeout = std::time::Duration::from_millis(v);
584 }
585 if let Some(v) = self.config.connect_timeout_ms {
586 cfg.connect_timeout = std::time::Duration::from_millis(v);
587 }
588 if let Some(v) = self.config.response_timeout_ms {
589 cfg.response_timeout = std::time::Duration::from_millis(v);
590 }
591 Ok(Box::new(WsEndpoint {
592 uri: uri.to_string(),
593 cfg,
594 }))
595 }
596}
597
598struct WsEndpoint {
599 uri: String,
600 cfg: WsEndpointConfig,
601}
602
603impl Endpoint for WsEndpoint {
604 fn uri(&self) -> &str {
605 &self.uri
606 }
607
608 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
609 Ok(Box::new(WsConsumer::new(self.cfg.server_config())))
610 }
611
612 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
613 Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
614 }
615}
616
617pub struct WsConsumer {
618 cfg: WsServerConfig,
619 registry: Arc<WsConnectionRegistry>,
620 server_state: Option<WsAppState>,
621 registry_key: Option<(String, u16, String)>,
622 forward_task: Option<JoinHandle<()>>,
623}
624
625impl WsConsumer {
626 pub fn new(cfg: WsServerConfig) -> Self {
627 Self {
628 cfg,
629 registry: Arc::new(WsConnectionRegistry::new()),
630 server_state: None,
631 registry_key: None,
632 forward_task: None,
633 }
634 }
635}
636
637#[async_trait]
638impl Consumer for WsConsumer {
639 async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
640 let tls_config = if self.cfg.inner.scheme == "wss" {
641 let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
642 CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
643 })?;
644 let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
645 CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
646 })?;
647 Some(WsTlsConfig {
648 cert_path,
649 key_path,
650 })
651 } else {
652 None
653 };
654
655 let state = ServerRegistry::global()
656 .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
657 .await?;
658
659 let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
660 {
661 let mut table = state.dispatch.write().await;
662 table.insert(self.cfg.inner.path.clone(), env_tx);
663 }
664
665 state.path_configs.insert(
666 self.cfg.inner.path.clone(),
667 WsPathConfig {
668 max_connections: self.cfg.inner.max_connections,
669 max_message_size: self.cfg.inner.max_message_size,
670 heartbeat_interval: self.cfg.inner.heartbeat_interval,
671 idle_timeout: self.cfg.inner.idle_timeout,
672 allow_origin: self.cfg.inner.allow_origin.clone(),
673 },
674 );
675
676 let registry_key = (
677 self.cfg.inner.canonical_host(),
678 self.cfg.inner.port,
679 self.cfg.inner.path.clone(),
680 );
681 global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
682
683 let sender = ctx.sender();
684 let forward_task = tokio::spawn(async move {
685 while let Some(envelope) = env_rx.recv().await {
686 if sender.send(envelope).await.is_err() {
687 break;
688 }
689 }
690 });
691
692 self.server_state = Some(state);
693 self.registry_key = Some(registry_key);
694 self.forward_task = Some(forward_task);
695 Ok(())
696 }
697
698 async fn stop(&mut self) -> Result<(), CamelError> {
699 let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
700 code: axum::extract::ws::CloseCode::from(1001u16),
701 reason: "consumer stopping".into(),
702 }));
703 for tx in self.registry.snapshot_senders() {
704 let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
705 }
706
707 if let Some(state) = self.server_state.take() {
708 let mut table = state.dispatch.write().await;
709 table.remove(&self.cfg.inner.path);
710 state.path_configs.remove(&self.cfg.inner.path);
711 }
712
713 if let Some(key) = self.registry_key.take() {
714 global_registries().remove(&key);
715 }
716
717 if let Some(task) = self.forward_task.take() {
718 task.abort();
719 }
720
721 Ok(())
722 }
723
724 fn concurrency_model(&self) -> ConcurrencyModel {
725 ConcurrencyModel::Concurrent {
726 max: Some(self.cfg.inner.max_connections as usize),
727 }
728 }
729}
730
731#[derive(Clone)]
732pub struct WsProducer {
733 cfg: WsClientConfig,
734}
735
736impl WsProducer {
737 pub fn new(cfg: WsClientConfig) -> Self {
738 Self { cfg }
739 }
740}
741
742impl Service<Exchange> for WsProducer {
743 type Response = Exchange;
744 type Error = CamelError;
745 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
746
747 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
748 Poll::Ready(Ok(()))
749 }
750
751 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
752 let cfg = self.cfg.clone();
753
754 Box::pin(async move {
755 let canonical_host = cfg.inner.canonical_host();
756 let key = (
757 canonical_host.clone(),
758 cfg.inner.port,
759 cfg.inner.path.clone(),
760 );
761
762 let send_to_all = exchange
763 .input
764 .header("CamelWsSendToAll")
765 .and_then(|v| v.as_bool())
766 .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
767 .unwrap_or(false);
768
769 let conn_keys_header = exchange
770 .input
771 .header("CamelWsConnectionKey")
772 .and_then(|v| v.as_str())
773 .map(str::to_string);
774
775 let local_exists = global_registries().contains_key(&key);
776 let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
777
778 let message_type = exchange
779 .input
780 .header("CamelWsMessageType")
781 .and_then(|v| v.as_str())
782 .unwrap_or("text")
783 .to_ascii_lowercase();
784
785 if server_send_mode {
786 let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
787 let Some(registry) = registry else {
788 return Err(CamelError::ProcessorError(format!(
789 "WebSocket local consumer not found for {}:{}{}",
790 canonical_host, cfg.inner.port, cfg.inner.path
791 )));
792 };
793
794 let out_msg = body_to_axum_ws_message(
795 std::mem::take(&mut exchange.input.body),
796 &message_type,
797 )
798 .await?;
799
800 let targets = if send_to_all {
801 registry.snapshot_senders()
802 } else if let Some(keys) = conn_keys_header {
803 let parsed: Vec<String> = keys
804 .split(',')
805 .map(str::trim)
806 .filter(|k| !k.is_empty())
807 .map(str::to_string)
808 .collect();
809 registry.get_senders_for_keys(&parsed)
810 } else {
811 registry.snapshot_senders()
812 };
813
814 for tx in targets {
815 let _ = try_send_with_backpressure(&tx, out_msg.clone(), "producer-send");
816 }
817
818 return Ok(exchange);
819 }
820
821 let url = format!(
822 "{}://{}:{}{}",
823 cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
824 );
825
826 #[allow(unused_mut)]
827 let mut request = url
828 .clone()
829 .into_client_request()
830 .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
831
832 #[cfg(feature = "otel")]
833 {
834 let mut otel_headers = HashMap::new();
835 camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
836 for (k, v) in otel_headers {
837 if let (Ok(name), Ok(val)) = (
838 http::header::HeaderName::from_bytes(k.as_bytes()),
839 http::header::HeaderValue::from_str(&v),
840 ) {
841 request.headers_mut().insert(name, val);
842 }
843 }
844 }
845
846 let connect_future = tokio_tungstenite::connect_async(request);
847 let (mut ws_stream, _) =
848 tokio::time::timeout(cfg.inner.connect_timeout, connect_future)
849 .await
850 .map_err(|_| {
851 CamelError::ProcessorError(format!(
852 "WebSocket connect timeout ({:?}) to {url}",
853 cfg.inner.connect_timeout
854 ))
855 })?
856 .map_err(|e| map_connect_error(e, &url))?;
857
858 let out_msg =
859 body_to_client_ws_message(std::mem::take(&mut exchange.input.body), &message_type)
860 .await?;
861
862 ws_stream
863 .send(out_msg)
864 .await
865 .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
866
867 let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
868 loop {
869 match ws_stream.next().await {
870 Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
871 continue;
872 }
873 other => break other,
874 }
875 }
876 })
877 .await
878 .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
879
880 match incoming {
881 Some(Ok(ClientWsMessage::Text(text))) => {
882 exchange.input.body = CamelBody::Text(text.to_string());
883 }
884 Some(Ok(ClientWsMessage::Binary(data))) => {
885 exchange.input.body = CamelBody::Bytes(data);
886 }
887 Some(Ok(ClientWsMessage::Close(frame))) => {
888 let normal = frame
889 .as_ref()
890 .map(|f| {
891 f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
892 || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
893 })
894 .unwrap_or(true);
895
896 if normal {
897 exchange.input.body = CamelBody::Empty;
898 } else {
899 let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
900 return Err(CamelError::ProcessorError(format!(
901 "WebSocket peer closed: code {code}"
902 )));
903 }
904 }
905 Some(Ok(_)) | None => {
906 exchange.input.body = CamelBody::Empty;
907 }
908 Some(Err(e)) => {
909 return Err(CamelError::ProcessorError(format!(
910 "WebSocket receive failed: {e}"
911 )));
912 }
913 }
914
915 let _ = ws_stream.close(None).await;
916 Ok(exchange)
917 })
918 }
919}
920
921async fn body_to_axum_ws_message(
922 body: CamelBody,
923 message_type: &str,
924) -> Result<WsMessage, CamelError> {
925 match message_type {
926 "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
927 _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
928 }
929}
930
931async fn body_to_client_ws_message(
932 body: CamelBody,
933 message_type: &str,
934) -> Result<ClientWsMessage, CamelError> {
935 match message_type {
936 "binary" => Ok(ClientWsMessage::Binary(
937 body.into_bytes(10 * 1024 * 1024).await?,
938 )),
939 _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
940 }
941}
942
943async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
944 Ok(match body {
945 CamelBody::Empty => String::new(),
946 CamelBody::Text(s) => s,
947 CamelBody::Xml(s) => s,
948 CamelBody::Json(v) => v.to_string(),
949 CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
950 CamelBody::Stream(stream) => {
951 let bytes = CamelBody::Stream(stream)
952 .into_bytes(10 * 1024 * 1024)
953 .await?;
954 String::from_utf8_lossy(&bytes).to_string()
955 }
956 })
957}
958
/// Returns whether a request's `Origin` header satisfies the configured policy.
///
/// `allowed_origin` is a comma-separated list of entries; a single entry is the common
/// case.
/// - A `"*"` entry allows every request, including ones without an `Origin` header.
/// - Any other entry is an exact origin (e.g. `https://app.example`), compared
///   ASCII-case-insensitively because origin schemes and hosts are case-insensitive.
/// - Entries are trimmed, and empty entries are ignored.
///
/// When no `"*"` entry is present, a request without an `Origin` header is rejected.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    allowed_origin
        .split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .any(|entry| {
            entry == "*" || request_origin.is_some_and(|origin| entry.eq_ignore_ascii_case(origin))
        })
}
965
966fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
967 match tx.try_send(msg) {
968 Ok(()) => true,
969 Err(error) => {
970 tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
971 false
972 }
973 }
974}
975
976fn load_tls_config(
977 cert_path: &str,
978 key_path: &str,
979) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
980 use std::fs::File;
981 use std::io::BufReader;
982
983 let cert_file = File::open(cert_path)
984 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
985 let key_file = File::open(key_path)
986 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
987
988 let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
989 .collect::<Result<Vec<_>, _>>()
990 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
991
992 let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
993 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
994 .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
995
996 tokio_rustls::rustls::ServerConfig::builder()
997 .with_no_client_auth()
998 .with_single_cert(certs, key)
999 .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1000}
1001
1002fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1003 match err {
1004 tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1005 CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1006 }
1007 tungstenite::Error::Tls(_) => {
1008 CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1009 }
1010 other => {
1011 let msg = other.to_string();
1012 if msg.to_lowercase().contains("connection refused") {
1013 CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1014 } else if msg.to_lowercase().contains("tls") {
1015 CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1016 } else {
1017 CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1018 }
1019 }
1020 }
1021}
1022
1023#[cfg(test)]
1024mod tests {
1025 use super::*;
1026 use camel_component_api::NoOpComponentContext;
1027 use std::time::Duration;
1028
1029 use tokio::sync::mpsc;
1030 use tokio_tungstenite::connect_async;
1031 use tokio_tungstenite::tungstenite::Message as ClientMessage;
1032 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1033 use tokio_util::sync::CancellationToken;
1034 use tower::ServiceExt;
1035
1036 fn free_port() -> u16 {
1037 std::net::TcpListener::bind("127.0.0.1:0")
1038 .unwrap()
1039 .local_addr()
1040 .unwrap()
1041 .port()
1042 }
1043
1044 #[test]
1045 fn ws_component_scheme_is_ws() {
1046 assert_eq!(WsComponent::new().scheme(), "ws");
1047 }
1048
1049 #[test]
1050 fn wss_component_scheme_is_wss() {
1051 assert_eq!(WssComponent::new().scheme(), "wss");
1052 }
1053
1054 #[test]
1055 fn endpoint_config_defaults_match_spec() {
1056 let cfg = WsEndpointConfig::default();
1057 assert_eq!(cfg.scheme, "ws");
1058 assert_eq!(cfg.host, "0.0.0.0");
1059 assert_eq!(cfg.port, 8080);
1060 assert_eq!(cfg.path, "/");
1061 assert_eq!(cfg.max_connections, 100);
1062 assert_eq!(cfg.max_message_size, 65536);
1063 assert!(!cfg.send_to_all);
1064 assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1065 assert_eq!(cfg.idle_timeout, Duration::ZERO);
1066 assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1067 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1068 assert_eq!(cfg.allow_origin, "*");
1069 assert_eq!(cfg.tls_cert, None);
1070 assert_eq!(cfg.tls_key, None);
1071 }
1072
1073 #[test]
1074 fn endpoint_config_parses_uri_params() {
1075 let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1076 let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1077
1078 assert_eq!(cfg.scheme, "ws");
1079 assert_eq!(cfg.host, "localhost");
1080 assert_eq!(cfg.port, 9001);
1081 assert_eq!(cfg.path, "/chat");
1082 assert_eq!(cfg.max_connections, 42);
1083 assert_eq!(cfg.max_message_size, 1024);
1084 assert!(cfg.send_to_all);
1085 assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1086 assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1087 assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1088 assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1089 assert_eq!(cfg.allow_origin, "https://example.com");
1090 assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1091 assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1092 }
1093
1094 #[test]
1095 fn endpoint_config_override_chain_uri_overrides_defaults() {
1096 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1097 assert_eq!(cfg.max_connections, 7);
1098 assert_eq!(cfg.max_message_size, 65536);
1099 assert!(!cfg.send_to_all);
1100 assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1101 }
1102
1103 #[test]
1104 fn endpoint_trait_creates_consumer_and_producer() {
1105 let ctx = NoOpComponentContext;
1106 let endpoint = WsComponent::new()
1107 .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1108 .unwrap();
1109
1110 endpoint.create_consumer().unwrap();
1111 endpoint
1112 .create_producer(&ProducerContext::default())
1113 .unwrap();
1114 }
1115
1116 #[test]
1117 fn ws_consumer_concurrency_model_uses_max_connections() {
1118 let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1119 let consumer = WsConsumer::new(cfg.server_config());
1120 assert_eq!(
1121 consumer.concurrency_model(),
1122 ConcurrencyModel::Concurrent { max: Some(321) }
1123 );
1124 }
1125
1126 #[tokio::test]
1127 async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1128 let registry = WsConnectionRegistry::new();
1129 let (tx1, mut rx1) = mpsc::channel(8);
1130 let (tx2, mut rx2) = mpsc::channel(8);
1131
1132 registry.insert("k1".into(), tx1);
1133 registry.insert("k2".into(), tx2);
1134 assert_eq!(registry.len(), 2);
1135
1136 for tx in registry.snapshot_senders() {
1137 tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1138 }
1139
1140 assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1141 assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1142
1143 let target = registry.get_senders_for_keys(&["k1".to_string()]);
1144 assert_eq!(target.len(), 1);
1145 target[0]
1146 .send(WsMessage::Text("targeted".into()))
1147 .await
1148 .unwrap();
1149
1150 assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1151 assert!(
1152 tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1153 .await
1154 .is_err()
1155 );
1156
1157 registry.remove("k1");
1158 assert_eq!(registry.len(), 1);
1159 }
1160
/// `0.0.0.0`, `localhost` and `127.0.0.1` all canonicalize to `127.0.0.1`.
#[test]
fn host_canonicalization_maps_local_hosts_to_loopback() {
    let local_uris = [
        "ws://0.0.0.0:9100/a",
        "ws://localhost:9101/b",
        "ws://127.0.0.1:9102/c",
    ];

    for uri in local_uris {
        let host = WsEndpointConfig::from_uri(uri).unwrap().canonical_host();
        assert_eq!(host, "127.0.0.1");
    }
}
1177
1178 #[tokio::test]
1179 async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1180 let port = free_port();
1181 let uri = format!("ws://127.0.0.1:{port}/echo");
1182 let component_ctx = NoOpComponentContext;
1183 let endpoint = WsComponent::new()
1184 .create_endpoint(&uri, &component_ctx)
1185 .unwrap();
1186
1187 let mut consumer = endpoint.create_consumer().unwrap();
1188 let producer = endpoint
1189 .create_producer(&ProducerContext::default())
1190 .unwrap();
1191
1192 let (route_tx, mut route_rx) = mpsc::channel(16);
1193 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1194 consumer.start(ctx).await.unwrap();
1195
1196 let route_task = tokio::spawn(async move {
1197 if let Some(envelope) = route_rx.recv().await {
1198 let payload = envelope
1199 .exchange
1200 .input
1201 .body
1202 .as_text()
1203 .unwrap_or_default()
1204 .to_string();
1205 let key = envelope
1206 .exchange
1207 .input
1208 .header("CamelWsConnectionKey")
1209 .and_then(|v| v.as_str())
1210 .unwrap()
1211 .to_string();
1212
1213 let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1214 response
1215 .input
1216 .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1217 producer.oneshot(response).await.unwrap();
1218 }
1219 });
1220
1221 let url = format!("ws://127.0.0.1:{port}/echo");
1222 let (mut client, _) = loop {
1223 match connect_async(&url).await {
1224 Ok(ok) => break ok,
1225 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1226 }
1227 };
1228
1229 client
1230 .send(ClientMessage::Text("hello-ws".into()))
1231 .await
1232 .unwrap();
1233
1234 let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1235 loop {
1236 match client.next().await {
1237 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1238 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1239 Some(Ok(_)) => continue,
1240 Some(Err(e)) => panic!("ws receive failed: {e}"),
1241 None => panic!("websocket closed before echo"),
1242 }
1243 }
1244 })
1245 .await
1246 .unwrap();
1247
1248 assert_eq!(incoming, "hello-ws");
1249
1250 consumer.stop().await.unwrap();
1251 route_task.await.unwrap();
1252 }
1253
1254 #[tokio::test]
1255 async fn consumer_stop_sends_close_1001() {
1256 let port = free_port();
1257 let uri = format!("ws://127.0.0.1:{port}/shutdown");
1258 let component_ctx = NoOpComponentContext;
1259 let endpoint = WsComponent::new()
1260 .create_endpoint(&uri, &component_ctx)
1261 .unwrap();
1262
1263 let mut consumer = endpoint.create_consumer().unwrap();
1264 let (route_tx, _route_rx) = mpsc::channel(16);
1265 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1266 consumer.start(ctx).await.unwrap();
1267
1268 let url = format!("ws://127.0.0.1:{port}/shutdown");
1269 let (mut client, _) = loop {
1270 match connect_async(&url).await {
1271 Ok(ok) => break ok,
1272 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1273 }
1274 };
1275
1276 client
1277 .send(ClientMessage::Text("keepalive".into()))
1278 .await
1279 .unwrap();
1280
1281 consumer.stop().await.unwrap();
1282
1283 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1284 loop {
1285 match client.next().await {
1286 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1287 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1288 Some(Ok(_)) => continue,
1289 Some(Err(e)) => panic!("ws receive failed: {e}"),
1290 None => panic!("websocket closed without close frame"),
1291 }
1292 }
1293 })
1294 .await
1295 .unwrap();
1296
1297 assert_eq!(close_code, Some(CloseCode::Away));
1298 }
1299
/// The `*` policy accepts requests with or without an `Origin` header.
#[test]
fn wildcard_origin_allows_anything() {
    for origin in [None, Some("https://example.com")] {
        assert!(is_origin_allowed("*", origin));
    }
}
1305
/// A concrete origin policy accepts only that exact origin. A different
/// origin, or a missing `Origin` header, is rejected.
#[test]
fn exact_origin_requires_match() {
    let policy = "https://example.com";
    let cases = [
        (Some("https://example.com"), true),
        (Some("https://other.com"), false),
        (None, false),
    ];

    for (origin, expected) in cases {
        assert_eq!(is_origin_allowed(policy, origin), expected);
    }
}
1318
/// Non-WebSocket schemes such as `http://` are rejected at parse time with
/// an "Invalid WebSocket scheme" error.
#[test]
fn endpoint_config_rejects_invalid_scheme() {
    let err = WsEndpointConfig::from_uri("http://localhost:9000/path")
        .expect_err("an http:// URI must not parse as a WebSocket endpoint");

    let msg = err.to_string();
    assert!(
        msg.contains("Invalid WebSocket scheme"),
        "expected scheme error, got: {msg}"
    );
}
1329
1330 #[tokio::test]
1331 async fn wss_consumer_start_fails_without_tls_cert() {
1332 let port = free_port();
1333 let component_ctx = NoOpComponentContext;
1334 let endpoint = WssComponent::new()
1335 .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1336 .unwrap();
1337 let mut consumer = endpoint.create_consumer().unwrap();
1338 let (tx, _rx) = mpsc::channel(16);
1339 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1340 let result = consumer.start(ctx).await;
1341 assert!(result.is_err());
1342 let msg = result.unwrap_err().to_string();
1343 assert!(
1344 msg.contains("TLS cert path is required"),
1345 "expected TLS cert error, got: {msg}"
1346 );
1347 }
1348
1349 #[tokio::test]
1350 async fn wss_consumer_start_fails_with_nonexistent_cert() {
1351 let port = free_port();
1352 let component_ctx = NoOpComponentContext;
1353 let endpoint = WssComponent::new()
1354 .create_endpoint(&format!(
1355 "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1356 ), &component_ctx)
1357 .unwrap();
1358 let mut consumer = endpoint.create_consumer().unwrap();
1359 let (tx, _rx) = mpsc::channel(16);
1360 let ctx = ConsumerContext::new(tx, CancellationToken::new());
1361 let result = consumer.start(ctx).await;
1362 assert!(result.is_err());
1363 let msg = result.unwrap_err().to_string();
1364 assert!(
1365 msg.contains("TLS cert file error"),
1366 "expected cert file error, got: {msg}"
1367 );
1368 }
1369
1370 #[tokio::test]
1371 async fn server_registry_returns_same_state_for_same_port() {
1372 let port = free_port();
1373 let state1 = ServerRegistry::global()
1374 .get_or_spawn("127.0.0.1", port, None)
1375 .await
1376 .unwrap();
1377 let state2 = ServerRegistry::global()
1378 .get_or_spawn("127.0.0.1", port, None)
1379 .await
1380 .unwrap();
1381 assert!(
1382 Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1383 "expected same dispatch table for same port"
1384 );
1385 }
1386
1387 #[tokio::test]
1388 async fn dispatch_handler_returns_404_for_unregistered_path() {
1389 let port = free_port();
1390 let state = ServerRegistry::global()
1391 .get_or_spawn("127.0.0.1", port, None)
1392 .await
1393 .unwrap();
1394 let app = Router::new().fallback(dispatch_handler).with_state(state);
1395 let response = tokio::time::timeout(
1396 Duration::from_secs(2),
1397 tower::ServiceExt::oneshot(
1398 app,
1399 axum::http::Request::builder()
1400 .uri("/nonexistent")
1401 .body(Body::empty())
1402 .unwrap(),
1403 ),
1404 )
1405 .await
1406 .unwrap()
1407 .unwrap();
1408 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1409 }
1410
1411 #[tokio::test]
1412 async fn client_mode_producer_connects_and_echoes() {
1413 let port = free_port();
1414
1415 let app = Router::new().route(
1416 "/echo",
1417 axum::routing::get(|ws: WebSocketUpgrade| async move {
1418 ws.on_upgrade(|mut socket: WebSocket| async move {
1419 while let Some(Ok(msg)) = socket.recv().await {
1420 match msg {
1421 WsMessage::Text(text) => {
1422 let _ = socket.send(WsMessage::Text(text)).await;
1423 }
1424 WsMessage::Binary(data) => {
1425 let _ = socket.send(WsMessage::Binary(data)).await;
1426 }
1427 WsMessage::Close(_) => break,
1428 _ => {}
1429 }
1430 }
1431 })
1432 }),
1433 );
1434 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}"))
1435 .await
1436 .unwrap();
1437 let server_task = tokio::spawn(async move {
1438 let _ = serve(listener, app).await;
1439 });
1440
1441 let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1442 let producer = WsProducer::new(cfg.client_config());
1443
1444 let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1445 tokio::time::sleep(Duration::from_millis(25)).await;
1446 let result =
1447 match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1448 Ok(Ok(r)) => r,
1449 Ok(Err(_)) => panic!("producer call failed"),
1450 Err(_) => panic!("producer call timed out"),
1451 };
1452
1453 assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1454
1455 server_task.abort();
1456 }
1457
1458 #[tokio::test]
1459 async fn max_connections_rejects_with_close_1013() {
1460 let port = free_port();
1461 let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1462 let component_ctx = NoOpComponentContext;
1463 let endpoint = WsComponent::new()
1464 .create_endpoint(&uri, &component_ctx)
1465 .unwrap();
1466 let mut consumer = endpoint.create_consumer().unwrap();
1467 let (route_tx, _route_rx) = mpsc::channel(16);
1468 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1469 consumer.start(ctx).await.unwrap();
1470
1471 let url = format!("ws://127.0.0.1:{port}/limited");
1472 let (_client1, _) = loop {
1473 match connect_async(&url).await {
1474 Ok(ok) => break ok,
1475 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1476 }
1477 };
1478
1479 tokio::time::sleep(Duration::from_millis(100)).await;
1480
1481 let (mut client2, _) = connect_async(&url).await.unwrap();
1482
1483 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1484 loop {
1485 match client2.next().await {
1486 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1487 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1488 Some(Ok(ClientMessage::Text(_))) => continue,
1489 Some(Ok(_)) => continue,
1490 Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1491 None => panic!("client2 closed without close frame"),
1492 }
1493 }
1494 })
1495 .await
1496 .unwrap();
1497
1498 assert_eq!(
1499 close_code,
1500 Some(CloseCode::from(1013u16)),
1501 "expected 1013 (Try Again Later) for max connections"
1502 );
1503
1504 consumer.stop().await.unwrap();
1505 }
1506
1507 #[tokio::test]
1508 async fn max_message_size_rejects_with_close_1009() {
1509 let port = free_port();
1510 let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
1511 let component_ctx = NoOpComponentContext;
1512 let endpoint = WsComponent::new()
1513 .create_endpoint(&uri, &component_ctx)
1514 .unwrap();
1515 let mut consumer = endpoint.create_consumer().unwrap();
1516 let (route_tx, _route_rx) = mpsc::channel(16);
1517 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1518 consumer.start(ctx).await.unwrap();
1519
1520 let url = format!("ws://127.0.0.1:{port}/sizelimit");
1521 let (mut client, _) = loop {
1522 match connect_async(&url).await {
1523 Ok(ok) => break ok,
1524 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1525 }
1526 };
1527
1528 let oversized = "x".repeat(100);
1529 client
1530 .send(ClientMessage::Text(oversized.into()))
1531 .await
1532 .unwrap();
1533
1534 let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1535 loop {
1536 match client.next().await {
1537 Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1538 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1539 Some(Ok(_)) => continue,
1540 Some(Err(e)) => panic!("ws receive failed: {e}"),
1541 None => panic!("websocket closed without close frame"),
1542 }
1543 }
1544 })
1545 .await
1546 .unwrap();
1547
1548 assert_eq!(
1549 close_code,
1550 Some(CloseCode::from(1009u16)),
1551 "expected 1009 (Message Too Big) for oversized message"
1552 );
1553
1554 consumer.stop().await.unwrap();
1555 }
1556
1557 #[tokio::test]
1558 async fn origin_rejection_returns_403() {
1559 let port = free_port();
1560 let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
1561 let component_ctx = NoOpComponentContext;
1562 let endpoint = WsComponent::new()
1563 .create_endpoint(&uri, &component_ctx)
1564 .unwrap();
1565 let mut consumer = endpoint.create_consumer().unwrap();
1566 let (route_tx, _route_rx) = mpsc::channel(16);
1567 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1568 consumer.start(ctx).await.unwrap();
1569
1570 let state = ServerRegistry::global()
1571 .get_or_spawn("127.0.0.1", port, None)
1572 .await
1573 .unwrap();
1574 let app = Router::new().fallback(dispatch_handler).with_state(state);
1575
1576 let response = tokio::time::timeout(
1577 Duration::from_secs(2),
1578 tower::ServiceExt::oneshot(
1579 app,
1580 axum::http::Request::builder()
1581 .uri("/origintest")
1582 .header("origin", "https://evil.com")
1583 .header("upgrade", "websocket")
1584 .header("connection", "Upgrade")
1585 .header("sec-websocket-version", "13")
1586 .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
1587 .body(Body::empty())
1588 .unwrap(),
1589 ),
1590 )
1591 .await
1592 .unwrap()
1593 .unwrap();
1594
1595 assert_eq!(
1596 response.status(),
1597 StatusCode::FORBIDDEN,
1598 "expected 403 for disallowed origin"
1599 );
1600
1601 consumer.stop().await.unwrap();
1602 }
1603
1604 #[tokio::test]
1605 async fn broadcast_sends_to_all_connected_clients() {
1606 let port = free_port();
1607 let uri = format!("ws://127.0.0.1:{port}/bc");
1608 let component_ctx = NoOpComponentContext;
1609 let endpoint = WsComponent::new()
1610 .create_endpoint(&uri, &component_ctx)
1611 .unwrap();
1612 let mut consumer = endpoint.create_consumer().unwrap();
1613 let producer = endpoint
1614 .create_producer(&ProducerContext::default())
1615 .unwrap();
1616
1617 let (route_tx, _route_rx) = mpsc::channel(16);
1618 let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1619 consumer.start(ctx).await.unwrap();
1620
1621 let url = format!("ws://127.0.0.1:{port}/bc");
1622
1623 let (mut client1, _) = loop {
1624 match connect_async(&url).await {
1625 Ok(ok) => break ok,
1626 Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1627 }
1628 };
1629
1630 let (mut client2, _) = connect_async(&url).await.unwrap();
1631
1632 tokio::time::sleep(Duration::from_millis(100)).await;
1633
1634 let mut response =
1635 Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
1636 response
1637 .input
1638 .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
1639 producer.oneshot(response).await.unwrap();
1640
1641 let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
1642 loop {
1643 match client1.next().await {
1644 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1645 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1646 _ => panic!("client1 unexpected message or close"),
1647 }
1648 }
1649 })
1650 .await
1651 .unwrap();
1652
1653 let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
1654 loop {
1655 match client2.next().await {
1656 Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1657 Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1658 _ => panic!("client2 unexpected message or close"),
1659 }
1660 }
1661 })
1662 .await
1663 .unwrap();
1664
1665 assert_eq!(recv1, "broadcast-msg");
1666 assert_eq!(recv2, "broadcast-msg");
1667
1668 consumer.stop().await.unwrap();
1669 }
1670
1671 #[tokio::test]
1672 async fn concurrent_get_or_spawn_returns_same_state() {
1673 let port = free_port();
1674 let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
1675 Arc::new(std::sync::Mutex::new(Vec::new()));
1676
1677 let mut handles = Vec::new();
1678 for _ in 0..4 {
1679 let results = results.clone();
1680 handles.push(tokio::spawn(async move {
1681 let state = ServerRegistry::global()
1682 .get_or_spawn("127.0.0.1", port, None)
1683 .await
1684 .unwrap();
1685 results.lock().unwrap().push(state);
1686 }));
1687 }
1688
1689 for h in handles {
1690 h.await.unwrap();
1691 }
1692
1693 let states = results.lock().unwrap();
1694 assert_eq!(states.len(), 4);
1695 for i in 1..states.len() {
1696 assert!(
1697 Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
1698 "all concurrent callers should get the same dispatch table"
1699 );
1700 }
1701 }
1702}