Skip to main content

camel_component_ws/
lib.rs

1//! WebSocket component for rust-camel — Axum-based WebSocket server and Tokio-tungstenite client for bidirectional messaging.
2//!
3//! Main types: `WsComponent`, `WsBundle`, `WsConfig`, `WsServerConfig`, `WsClientConfig`, `WsEndpointConfig`.
4//! Main modules: `bundle`, `config`, `health`.
5
6pub mod bundle;
7pub mod config;
8pub mod health;
9
10pub use bundle::WsBundle;
11pub use config::{WsClientConfig, WsConfig, WsEndpointConfig, WsServerConfig};
12pub use health::WsHealthCheck;
13
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex, OnceLock};
16
17use async_trait::async_trait;
18use axum::body::Body;
19use axum::extract::ws::{CloseCode, CloseFrame, Message as WsMessage, WebSocket, WebSocketUpgrade};
20use axum::extract::{FromRequest, Request, State};
21use axum::http::{StatusCode, header};
22use axum::response::IntoResponse;
23use axum::{Router, serve};
24use camel_api::security_policy::AuthorizationDecision;
25use camel_component_api::{
26    Body as CamelBody, BoxProcessor, CamelError, Component, ConcurrencyModel, Consumer,
27    ConsumerContext, Endpoint, Exchange, ExchangeEnvelope, Message as CamelMessage,
28    NetworkRetryPolicy, ProducerContext, retry_async,
29};
30use dashmap::DashMap;
31use futures::{SinkExt, StreamExt};
32use std::future::Future;
33use std::pin::Pin;
34use std::task::{Context, Poll};
35use tokio::sync::{OnceCell, RwLock, mpsc};
36use tokio::task::JoinHandle;
37use tokio_tungstenite::tungstenite;
38use tokio_tungstenite::tungstenite::client::IntoClientRequest;
39use tokio_tungstenite::tungstenite::protocol::Message as ClientWsMessage;
40use tower::Service;
41
42#[derive(Clone)]
43pub struct WsPathConfig {
44    pub max_connections: u32,
45    pub max_message_size: u32,
46    pub heartbeat_interval: std::time::Duration,
47    pub idle_timeout: std::time::Duration,
48    pub allow_origin: String,
49}
50
51impl Default for WsPathConfig {
52    fn default() -> Self {
53        let cfg = WsEndpointConfig::default();
54        Self {
55            max_connections: cfg.max_connections,
56            max_message_size: cfg.max_message_size,
57            heartbeat_interval: cfg.heartbeat_interval,
58            idle_timeout: cfg.idle_timeout,
59            allow_origin: cfg.allow_origin,
60        }
61    }
62}
63
64#[derive(Clone)]
65pub struct WsTlsConfig {
66    pub cert_path: String,
67    pub key_path: String,
68}
69
70type DispatchTable = Arc<RwLock<HashMap<String, mpsc::Sender<ExchangeEnvelope>>>>;
71
72struct ServerHandle {
73    state: WsAppState,
74    is_tls: bool,
75    _task: JoinHandle<()>,
76}
77
78struct ServerRegistryInner {
79    cell: Arc<OnceCell<ServerHandle>>,
80    ref_count: usize,
81}
82
83pub struct ServerRegistry {
84    inner: Mutex<HashMap<u16, ServerRegistryInner>>,
85}
86
87impl ServerRegistry {
88    pub fn global() -> &'static Self {
89        static REG: OnceLock<ServerRegistry> = OnceLock::new();
90        REG.get_or_init(|| Self {
91            inner: Mutex::new(HashMap::new()),
92        })
93    }
94
95    pub async fn get_or_spawn(
96        &'static self,
97        host: &str,
98        port: u16,
99        tls_config: Option<WsTlsConfig>,
100    ) -> Result<WsAppState, CamelError> {
101        let wants_tls = tls_config.is_some();
102        let host_owned = host.to_string();
103
104        let (cell, _is_new) = {
105            let mut guard = self.inner.lock().map_err(|_| {
106                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
107            })?;
108            let entry = guard.entry(port).or_insert_with(|| ServerRegistryInner {
109                cell: Arc::new(OnceCell::new()),
110                ref_count: 0,
111            });
112            entry.ref_count += 1;
113            (entry.cell.clone(), entry.ref_count == 1)
114        };
115
116        let handle = cell
117            .get_or_try_init(|| async { spawn_server(&host_owned, port, tls_config).await })
118            .await?;
119
120        if wants_tls != handle.is_tls {
121            // Decrement ref count since we're rejecting this caller
122            let mut guard = self.inner.lock().map_err(|_| {
123                CamelError::EndpointCreationFailed("ServerRegistry lock poisoned".into())
124            })?;
125            if let Some(entry) = guard.get_mut(&port) {
126                entry.ref_count -= 1;
127                if entry.ref_count == 0 {
128                    guard.remove(&port);
129                }
130            }
131            return Err(CamelError::EndpointCreationFailed(format!(
132                "Server on port {port} already running with different TLS mode"
133            )));
134        }
135
136        Ok(handle.state.clone())
137    }
138
139    /// Release a reference to the server on the given port.
140    /// When the last reference is released, the server task is aborted and the entry removed. (WS-005)
141    pub(crate) fn release(&self, port: u16) {
142        let mut guard = match self.inner.lock() {
143            Ok(g) => g,
144            Err(_) => return,
145        };
146        if let Some(entry) = guard.get_mut(&port) {
147            entry.ref_count = entry.ref_count.saturating_sub(1);
148            if entry.ref_count == 0 {
149                // Abort the server task if it exists (WS-001, WS-005)
150                if let Some(handle) = entry.cell.get() {
151                    handle._task.abort();
152                }
153                guard.remove(&port);
154                tracing::info!(port, "WebSocket server registry entry removed");
155            }
156        }
157    }
158}
159
160async fn spawn_server(
161    host: &str,
162    port: u16,
163    tls_config: Option<WsTlsConfig>,
164) -> Result<ServerHandle, CamelError> {
165    let host_owned = host.to_string();
166    let addr = format!("{host}:{port}");
167    let dispatch: DispatchTable = Arc::new(RwLock::new(HashMap::new()));
168    let path_configs = Arc::new(DashMap::new());
169    let path_policies = Arc::new(DashMap::new());
170    let server_error = new_atomic_false();
171    let state = WsAppState {
172        dispatch: Arc::clone(&dispatch),
173        path_configs: Arc::clone(&path_configs),
174        path_policies: Arc::clone(&path_policies),
175        server_error: Arc::clone(&server_error),
176    };
177    let app = Router::new()
178        .fallback(dispatch_handler)
179        .with_state(state.clone());
180
181    let (task, is_tls) = if let Some(ref tls) = tls_config {
182        let rustls = load_tls_config(&tls.cert_path, &tls.key_path)?;
183        let parsed_addr = addr.parse().map_err(|e| {
184            CamelError::EndpointCreationFailed(format!("Invalid listen address {addr}: {e}"))
185        })?;
186        let tls_cfg = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(rustls));
187        let error_flag = Arc::clone(&server_error);
188        let task = tokio::spawn(async move {
189            if let Err(e) = axum_server::bind_rustls(parsed_addr, tls_cfg)
190                .serve(app.into_make_service())
191                .await
192            {
193                tracing::error!(
194                    host = host_owned,
195                    port = port,
196                    error = %e,
197                    "WebSocket server terminated with error"
198                );
199                error_flag.store(true, Ordering::Relaxed);
200            }
201        });
202        (task, true)
203    } else {
204        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
205            CamelError::EndpointCreationFailed(format!("Failed to bind {addr}: {e}"))
206        })?;
207        let error_flag = Arc::clone(&server_error);
208        let task = tokio::spawn(async move {
209            if let Err(e) = serve(listener, app).await {
210                tracing::error!(
211                    host = host_owned,
212                    port = port,
213                    error = %e,
214                    "WebSocket server terminated with error"
215                );
216                error_flag.store(true, Ordering::Relaxed);
217            }
218        });
219        (task, false)
220    };
221
222    tracing::info!(host, port, is_tls, "WebSocket server started");
223
224    Ok(ServerHandle {
225        state,
226        is_tls,
227        _task: task,
228    })
229}
230
231#[derive(Clone)]
232pub struct WsAppState {
233    pub dispatch: DispatchTable,
234    pub path_configs: Arc<DashMap<String, WsPathConfig>>,
235    pub path_policies: Arc<DashMap<String, camel_component_api::SecurityContext>>,
236    pub server_error: Arc<AtomicBool>,
237}
238
239pub struct WsConnectionRegistry {
240    connections: DashMap<String, mpsc::Sender<WsMessage>>,
241}
242
243static GLOBAL_CONNECTION_REGISTRIES: OnceLock<
244    DashMap<(String, u16, String), Arc<WsConnectionRegistry>>,
245> = OnceLock::new();
246
247fn global_registries() -> &'static DashMap<(String, u16, String), Arc<WsConnectionRegistry>> {
248    GLOBAL_CONNECTION_REGISTRIES.get_or_init(DashMap::new)
249}
250
251impl Default for WsConnectionRegistry {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
257impl WsConnectionRegistry {
258    pub fn new() -> Self {
259        Self {
260            connections: DashMap::new(),
261        }
262    }
263
264    pub fn insert(&self, key: String, tx: mpsc::Sender<WsMessage>) {
265        self.connections.insert(key, tx);
266    }
267
268    pub fn remove(&self, key: &str) {
269        self.connections.remove(key);
270    }
271
272    pub fn len(&self) -> usize {
273        self.connections.len()
274    }
275
276    pub fn is_empty(&self) -> bool {
277        self.connections.is_empty()
278    }
279
280    pub fn snapshot_senders(&self) -> Vec<mpsc::Sender<WsMessage>> {
281        self.connections.iter().map(|e| e.value().clone()).collect()
282    }
283
284    pub fn get_senders_for_keys(&self, keys: &[String]) -> Vec<mpsc::Sender<WsMessage>> {
285        keys.iter()
286            .filter_map(|k| self.connections.get(k).map(|e| e.value().clone()))
287            .collect()
288    }
289}
290
291pub async fn dispatch_handler(
292    State(state): State<WsAppState>,
293    req: Request<Body>,
294) -> impl IntoResponse {
295    let path = req.uri().path().to_string();
296    let origin = req
297        .headers()
298        .get(header::ORIGIN)
299        .and_then(|value| value.to_str().ok())
300        .map(str::to_string);
301    let remote_addr = req
302        .extensions()
303        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
304        .map(|ci| ci.0.to_string())
305        .unwrap_or_default();
306    let table = state.dispatch.read().await;
307    if !table.contains_key(&path) {
308        return (
309            StatusCode::NOT_FOUND,
310            "no ws endpoint registered for this path",
311        )
312            .into_response();
313    }
314    drop(table);
315
316    let path_config = state
317        .path_configs
318        .get(&path)
319        .map(|entry| entry.value().clone())
320        .unwrap_or_default();
321    if !is_origin_allowed(&path_config.allow_origin, origin.as_deref()) {
322        return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
323    }
324
325    let mut principal_opt: Option<camel_api::security_policy::Principal> = None;
326    if let Some(sec_ctx) = state.path_policies.get(&path) {
327        let extracted =
328            camel_auth::extract_token_multi(req.headers(), req.uri(), &sec_ctx.credential_sources);
329
330        match extracted {
331            Some(extracted) => {
332                if matches!(
333                    extracted.source,
334                    camel_auth::CredentialSource::QueryParam { .. }
335                ) {
336                    let redacted =
337                        camel_auth::redact_query_params(req.uri(), &["access_token", "token"]);
338                    tracing::debug!(path = %redacted, "WS upgrade with query token (redacted)");
339                }
340                match sec_ctx
341                    .authenticator
342                    .authenticate_bearer(&extracted.token)
343                    .await
344                {
345                    Ok(principal) => {
346                        let mut exchange = camel_api::Exchange::new(camel_api::Message::new(
347                            camel_api::Body::Empty,
348                        ));
349                        camel_api::store_principal_properties(&mut exchange, &principal);
350                        match sec_ctx.policy.evaluate(&mut exchange).await {
351                            Ok(AuthorizationDecision::Granted { principal: _p }) => {
352                                tracing::debug!(path = %path, subject = %principal.subject, "WS upgrade authorized");
353                                principal_opt = Some(principal);
354                            }
355                            Ok(AuthorizationDecision::Denied { reason, .. }) => {
356                                tracing::warn!(path = %path, reason = %reason, "WS upgrade denied");
357                                return (StatusCode::FORBIDDEN, "Forbidden").into_response();
358                            }
359                            Err(e) => {
360                                tracing::error!(path = %path, error = %e, "Policy evaluation error during WS upgrade");
361                                return (
362                                    StatusCode::INTERNAL_SERVER_ERROR,
363                                    "Internal Server Error",
364                                )
365                                    .into_response();
366                            }
367                        }
368                    }
369                    Err(e) => {
370                        let (status, body) = match &e {
371                            camel_api::CamelError::Unauthenticated(_) => {
372                                (StatusCode::UNAUTHORIZED, "Unauthorized")
373                            }
374                            camel_api::CamelError::ProcessorError(msg)
375                                if msg.contains("auth provider unavailable") =>
376                            {
377                                (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
378                            }
379                            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
380                        };
381                        tracing::warn!(path = %path, error = %e, "WS upgrade authentication failed");
382                        return (status, body).into_response();
383                    }
384                }
385            }
386            None => {
387                tracing::warn!(path = %path, "WS upgrade rejected: no credential found in any source");
388                return (
389                    StatusCode::UNAUTHORIZED,
390                    [("WWW-Authenticate", "Bearer".to_string())],
391                    "Unauthorized",
392                )
393                    .into_response();
394            }
395        }
396    }
397
398    let upgrade_headers: HashMap<String, String> = req
399        .headers()
400        .iter()
401        .filter_map(|(k, v)| Some((k.as_str().to_lowercase(), v.to_str().ok()?.to_string())))
402        .collect();
403
404    let ws: WebSocketUpgrade = match WebSocketUpgrade::from_request(req, &()).await {
405        Ok(ws) => ws,
406        Err(_) => {
407            return (StatusCode::BAD_REQUEST, "not a websocket request").into_response();
408        }
409    };
410
411    ws.on_upgrade(move |socket| {
412        ws_handler(
413            socket,
414            state,
415            path,
416            remote_addr,
417            upgrade_headers,
418            principal_opt,
419        )
420    })
421    .into_response()
422}
423
424#[allow(unused_variables)]
425async fn ws_handler(
426    socket: WebSocket,
427    state: WsAppState,
428    path: String,
429    remote_addr: String,
430    upgrade_headers: HashMap<String, String>,
431    principal: Option<camel_api::security_policy::Principal>,
432) {
433    let connection_key = uuid::Uuid::new_v4().to_string();
434    let path_config = state
435        .path_configs
436        .get(&path)
437        .map(|entry| entry.value().clone())
438        .unwrap_or_default();
439
440    let env_tx = {
441        let table = state.dispatch.read().await;
442        table.get(&path).cloned()
443    };
444    let Some(env_tx) = env_tx else {
445        return;
446    };
447
448    let (mut sink, mut stream) = socket.split();
449    let (out_tx, mut out_rx) = mpsc::channel::<WsMessage>(32);
450
451    let registry = global_registries();
452    let mut registry_key = None;
453    for entry in registry.iter() {
454        if entry.key().2 == path {
455            entry.value().insert(connection_key.clone(), out_tx.clone());
456            registry_key = Some(entry.key().clone());
457            break;
458        }
459    }
460
461    // Clone for writer closure and subsequent tracing (WS-009)
462    let conn_key_for_writer = connection_key.clone();
463    let path_for_writer = path.clone();
464
465    let writer = tokio::spawn(async move {
466        while let Some(msg) = out_rx.recv().await {
467            if let Err(e) = sink.send(msg).await {
468                tracing::warn!(
469                    connection_key = conn_key_for_writer,
470                    path = path_for_writer,
471                    error = %e,
472                    "WebSocket writer send error — closing connection"
473                );
474                break;
475            }
476        }
477    });
478
479    tracing::info!(
480        connection_key = connection_key,
481        path = path,
482        remote_addr = remote_addr,
483        "WebSocket connection opened"
484    );
485
486    let mut over_limit = false;
487    if let Some(key) = &registry_key
488        && let Some(entry) = registry.get(key)
489        && entry.len() > path_config.max_connections as usize
490    {
491        over_limit = true;
492    }
493    if over_limit {
494        try_send_with_backpressure(
495            &out_tx,
496            WsMessage::Close(Some(CloseFrame {
497                code: CloseCode::from(1013u16),
498                reason: "max connections exceeded".into(),
499            })),
500            "max-connections-close",
501        );
502        if let Some(key) = registry_key.clone()
503            && let Some(entry) = registry.get(&key)
504        {
505            entry.remove(&connection_key);
506        }
507        drop(out_tx);
508        let _ = writer.await;
509        return;
510    }
511
512    let heartbeat_task = if path_config.heartbeat_interval > std::time::Duration::ZERO {
513        let ping_tx = out_tx.clone();
514        let interval = path_config.heartbeat_interval;
515        Some(tokio::spawn(async move {
516            let mut ticker = tokio::time::interval(interval);
517            loop {
518                ticker.tick().await;
519                let _ = try_send_with_backpressure(
520                    &ping_tx,
521                    WsMessage::Ping(Vec::new().into()),
522                    "heartbeat-ping",
523                );
524            }
525        }))
526    } else {
527        None
528    };
529
530    loop {
531        let next_msg = if path_config.idle_timeout > std::time::Duration::ZERO {
532            match tokio::time::timeout(path_config.idle_timeout, stream.next()).await {
533                Ok(msg) => msg,
534                Err(_) => {
535                    try_send_with_backpressure(
536                        &out_tx,
537                        WsMessage::Close(Some(CloseFrame {
538                            code: CloseCode::from(1000u16),
539                            reason: "idle timeout".into(),
540                        })),
541                        "idle-timeout-close",
542                    );
543                    break;
544                }
545            }
546        } else {
547            stream.next().await
548        };
549
550        let Some(msg) = next_msg else {
551            break;
552        };
553
554        match msg {
555            Ok(WsMessage::Ping(data)) => {
556                tracing::debug!(
557                    connection_key = connection_key,
558                    path = path,
559                    "WebSocket ping received — sending pong"
560                );
561                let _ = try_send_with_backpressure(
562                    &out_tx,
563                    WsMessage::Pong(data),
564                    "ping-pong-response",
565                );
566            }
567            Ok(WsMessage::Pong(_)) => {
568                tracing::debug!(
569                    connection_key = connection_key,
570                    path = path,
571                    "WebSocket pong received"
572                );
573            }
574            Ok(WsMessage::Text(text)) => {
575                if text.len() > path_config.max_message_size as usize {
576                    try_send_with_backpressure(
577                        &out_tx,
578                        WsMessage::Close(Some(CloseFrame {
579                            code: CloseCode::from(1009u16),
580                            reason: "message too large".into(),
581                        })),
582                        "max-message-size-close-text",
583                    );
584                    break;
585                }
586
587                let mut message = CamelMessage::new(CamelBody::Text(text.to_string()));
588                message.set_header(
589                    "CamelWsConnectionKey",
590                    serde_json::Value::String(connection_key.clone()),
591                );
592                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
593                message.set_header(
594                    "CamelWsRemoteAddress",
595                    serde_json::Value::String(remote_addr.clone()),
596                );
597
598                #[allow(unused_mut)]
599                let mut exchange = Exchange::new(message);
600                if let Some(ref p) = principal {
601                    camel_api::store_principal_properties(&mut exchange, p);
602                }
603                #[cfg(feature = "otel")]
604                {
605                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
606                }
607                if env_tx
608                    .send(ExchangeEnvelope {
609                        exchange,
610                        reply_tx: None,
611                    })
612                    .await
613                    .is_err()
614                {
615                    break;
616                }
617            }
618            Ok(WsMessage::Binary(data)) => {
619                if data.len() > path_config.max_message_size as usize {
620                    try_send_with_backpressure(
621                        &out_tx,
622                        WsMessage::Close(Some(CloseFrame {
623                            code: CloseCode::from(1009u16),
624                            reason: "message too large".into(),
625                        })),
626                        "max-message-size-close-binary",
627                    );
628                    break;
629                }
630
631                let mut message = CamelMessage::new(CamelBody::Bytes(data));
632                message.set_header(
633                    "CamelWsConnectionKey",
634                    serde_json::Value::String(connection_key.clone()),
635                );
636                message.set_header("CamelWsPath", serde_json::Value::String(path.clone()));
637                message.set_header(
638                    "CamelWsRemoteAddress",
639                    serde_json::Value::String(remote_addr.clone()),
640                );
641
642                #[allow(unused_mut)]
643                let mut exchange = Exchange::new(message);
644                if let Some(ref p) = principal {
645                    camel_api::store_principal_properties(&mut exchange, p);
646                }
647                #[cfg(feature = "otel")]
648                {
649                    camel_otel::extract_into_exchange(&mut exchange, &upgrade_headers);
650                }
651                if env_tx
652                    .send(ExchangeEnvelope {
653                        exchange,
654                        reply_tx: None,
655                    })
656                    .await
657                    .is_err()
658                {
659                    break;
660                }
661            }
662            Ok(WsMessage::Close(cf)) => {
663                let reason = cf
664                    .as_ref()
665                    .map(|f| f.reason.to_string())
666                    .unwrap_or_default();
667                tracing::info!(
668                    connection_key = connection_key,
669                    path = path,
670                    reason = reason,
671                    "WebSocket connection closed by peer"
672                );
673                break;
674            }
675            Err(e) => {
676                tracing::warn!(
677                    connection_key = connection_key,
678                    path = path,
679                    error = %e,
680                    "WebSocket receive error"
681                );
682                break;
683            }
684        }
685    }
686
687    if let Some(task) = heartbeat_task {
688        task.abort();
689    }
690
691    if let Some(key) = registry_key
692        && let Some(entry) = registry.get(&key)
693    {
694        entry.remove(&connection_key);
695    }
696    drop(out_tx);
697    let _ = writer.await;
698
699    tracing::info!(
700        connection_key = connection_key,
701        path = path,
702        "WebSocket connection closed"
703    );
704}
705
706pub struct WsComponent {
707    pub(crate) config: WsConfig,
708}
709
710impl WsComponent {
711    pub fn new() -> Self {
712        Self {
713            config: WsConfig::default(),
714        }
715    }
716
717    pub fn with_config(config: WsConfig) -> Self {
718        Self { config }
719    }
720}
721
722impl Default for WsComponent {
723    fn default() -> Self {
724        Self::new()
725    }
726}
727
728impl Component for WsComponent {
729    fn scheme(&self) -> &str {
730        "ws"
731    }
732
733    fn create_endpoint(
734        &self,
735        uri: &str,
736        ctx: &dyn camel_component_api::ComponentContext,
737    ) -> Result<Box<dyn Endpoint>, CamelError> {
738        self.config.validate()?;
739        let mut cfg = WsEndpointConfig::from_uri(uri)?;
740        if let Some(v) = self.config.max_connections {
741            cfg.max_connections = v;
742        }
743        if let Some(v) = self.config.max_message_size {
744            cfg.max_message_size = v;
745        }
746        if let Some(v) = self.config.heartbeat_interval_ms {
747            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
748        }
749        if let Some(v) = self.config.idle_timeout_ms {
750            cfg.idle_timeout = std::time::Duration::from_millis(v);
751        }
752        if let Some(v) = self.config.connect_timeout_ms {
753            cfg.connect_timeout = std::time::Duration::from_millis(v);
754        }
755        if let Some(v) = self.config.response_timeout_ms {
756            cfg.response_timeout = std::time::Duration::from_millis(v);
757        }
758        if let Some(v) = self.config.send_timeout_ms {
759            cfg.send_timeout = std::time::Duration::from_millis(v);
760        }
761        if let Some(v) = self.config.binary_payload {
762            cfg.binary_payload = v;
763        }
764        if let Some(ref v) = self.config.subprotocols {
765            cfg.subprotocols = v.clone();
766        }
767        let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
768        ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
769        Ok(Box::new(WsEndpoint {
770            uri: uri.to_string(),
771            cfg,
772        }))
773    }
774}
775
776pub struct WssComponent {
777    pub(crate) config: WsConfig,
778}
779
780impl WssComponent {
781    pub fn new() -> Self {
782        Self {
783            config: WsConfig::default(),
784        }
785    }
786
787    pub fn with_config(config: WsConfig) -> Self {
788        Self { config }
789    }
790}
791
792impl Default for WssComponent {
793    fn default() -> Self {
794        Self::new()
795    }
796}
797
798impl Component for WssComponent {
799    fn scheme(&self) -> &str {
800        "wss"
801    }
802
803    fn create_endpoint(
804        &self,
805        uri: &str,
806        ctx: &dyn camel_component_api::ComponentContext,
807    ) -> Result<Box<dyn Endpoint>, CamelError> {
808        self.config.validate()?;
809        let mut cfg = WsEndpointConfig::from_uri(uri)?;
810        if let Some(v) = self.config.max_connections {
811            cfg.max_connections = v;
812        }
813        if let Some(v) = self.config.max_message_size {
814            cfg.max_message_size = v;
815        }
816        if let Some(v) = self.config.heartbeat_interval_ms {
817            cfg.heartbeat_interval = std::time::Duration::from_millis(v);
818        }
819        if let Some(v) = self.config.idle_timeout_ms {
820            cfg.idle_timeout = std::time::Duration::from_millis(v);
821        }
822        if let Some(v) = self.config.connect_timeout_ms {
823            cfg.connect_timeout = std::time::Duration::from_millis(v);
824        }
825        if let Some(v) = self.config.response_timeout_ms {
826            cfg.response_timeout = std::time::Duration::from_millis(v);
827        }
828        if let Some(v) = self.config.send_timeout_ms {
829            cfg.send_timeout = std::time::Duration::from_millis(v);
830        }
831        if let Some(v) = self.config.binary_payload {
832            cfg.binary_payload = v;
833        }
834        if let Some(ref v) = self.config.subprotocols {
835            cfg.subprotocols = v.clone();
836        }
837        let health_check = WsHealthCheck::new(cfg.host.clone(), cfg.port);
838        ctx.register_current_route_health_check(std::sync::Arc::new(health_check));
839        Ok(Box::new(WsEndpoint {
840            uri: uri.to_string(),
841            cfg,
842        }))
843    }
844}
845
846struct WsEndpoint {
847    uri: String,
848    cfg: WsEndpointConfig,
849}
850
851impl Endpoint for WsEndpoint {
852    fn uri(&self) -> &str {
853        &self.uri
854    }
855
856    fn create_consumer(
857        &self,
858        rt: Arc<dyn camel_component_api::RuntimeObservability>,
859    ) -> Result<Box<dyn Consumer>, CamelError> {
860        Ok(Box::new(WsConsumer::new(self.cfg.server_config(), rt)))
861    }
862
863    fn create_producer(
864        &self,
865        _rt: Arc<dyn camel_component_api::RuntimeObservability>,
866        _ctx: &ProducerContext,
867    ) -> Result<BoxProcessor, CamelError> {
868        Ok(BoxProcessor::new(WsProducer::new(self.cfg.client_config())))
869    }
870}
871
872pub struct WsConsumer {
873    cfg: WsServerConfig,
874    registry: Arc<WsConnectionRegistry>,
875    server_state: Option<WsAppState>,
876    registry_key: Option<(String, u16, String)>,
877    forward_task: Option<JoinHandle<Result<(), CamelError>>>,
878    security_ctx: Option<camel_component_api::SecurityContext>,
879    /// Phase B will use this for `rt.metrics().increment_errors(...)` and
880    /// `rt.health().force_unhealthy_for_route(...)` calls per ADR-0012.
881    #[allow(dead_code)]
882    runtime: Arc<dyn camel_component_api::RuntimeObservability>,
883}
884
885impl WsConsumer {
886    pub fn new(
887        cfg: WsServerConfig,
888        runtime: Arc<dyn camel_component_api::RuntimeObservability>,
889    ) -> Self {
890        Self {
891            cfg,
892            registry: Arc::new(WsConnectionRegistry::new()),
893            server_state: None,
894            registry_key: None,
895            forward_task: None,
896            security_ctx: None,
897            runtime,
898        }
899    }
900}
901
902#[async_trait]
903impl Consumer for WsConsumer {
904    async fn start(&mut self, ctx: ConsumerContext) -> Result<(), CamelError> {
905        // Reject double-start (WS-006)
906        if self.server_state.is_some() {
907            return Err(CamelError::EndpointCreationFailed(
908                "WebSocket consumer already started".into(),
909            ));
910        }
911
912        tracing::info!(
913            host = self.cfg.inner.host,
914            port = self.cfg.inner.port,
915            path = self.cfg.inner.path,
916            scheme = self.cfg.inner.scheme,
917            "WebSocket consumer starting"
918        );
919
920        let tls_config = if self.cfg.inner.scheme == "wss" {
921            let cert_path = self.cfg.inner.tls_cert.clone().ok_or_else(|| {
922                CamelError::EndpointCreationFailed("TLS cert path is required for wss".into())
923            })?;
924            let key_path = self.cfg.inner.tls_key.clone().ok_or_else(|| {
925                CamelError::EndpointCreationFailed("TLS key path is required for wss".into())
926            })?;
927            Some(WsTlsConfig {
928                cert_path,
929                key_path,
930            })
931        } else {
932            None
933        };
934
935        let state = ServerRegistry::global()
936            .get_or_spawn(&self.cfg.inner.host, self.cfg.inner.port, tls_config)
937            .await?;
938
939        let (env_tx, mut env_rx) = mpsc::channel::<ExchangeEnvelope>(64);
940        {
941            let mut table = state.dispatch.write().await;
942            table.insert(self.cfg.inner.path.clone(), env_tx);
943        }
944
945        state.path_configs.insert(
946            self.cfg.inner.path.clone(),
947            WsPathConfig {
948                max_connections: self.cfg.inner.max_connections,
949                max_message_size: self.cfg.inner.max_message_size,
950                heartbeat_interval: self.cfg.inner.heartbeat_interval,
951                idle_timeout: self.cfg.inner.idle_timeout,
952                allow_origin: self.cfg.inner.allow_origin.clone(),
953            },
954        );
955
956        if let Some(ref sec_ctx) = self.security_ctx {
957            let path = self.cfg.inner.path.clone();
958            state.path_policies.insert(path, sec_ctx.clone());
959        }
960
961        let registry_key = (
962            self.cfg.inner.canonical_host(),
963            self.cfg.inner.port,
964            self.cfg.inner.path.clone(),
965        );
966        global_registries().insert(registry_key.clone(), Arc::clone(&self.registry));
967
968        let sender = ctx.sender();
969        let forward_task: JoinHandle<Result<(), CamelError>> = tokio::spawn(async move {
970            while let Some(envelope) = env_rx.recv().await {
971                if sender.send(envelope).await.is_err() {
972                    break;
973                }
974            }
975            Ok(())
976        });
977
978        self.server_state = Some(state);
979        self.registry_key = Some(registry_key);
980        self.forward_task = Some(forward_task);
981        Ok(())
982    }
983
984    async fn stop(&mut self) -> Result<(), CamelError> {
985        tracing::info!(
986            host = self.cfg.inner.host,
987            port = self.cfg.inner.port,
988            path = self.cfg.inner.path,
989            "WebSocket consumer stopping"
990        );
991
992        let close_msg = WsMessage::Close(Some(axum::extract::ws::CloseFrame {
993            code: axum::extract::ws::CloseCode::from(1001u16),
994            reason: "consumer stopping".into(),
995        }));
996        for tx in self.registry.snapshot_senders() {
997            let _ = try_send_with_backpressure(&tx, close_msg.clone(), "consumer-stop-close");
998        }
999
1000        let mut had_server_error = false;
1001
1002        if let Some(state) = self.server_state.take() {
1003            had_server_error = state.server_error.load(Ordering::Relaxed);
1004            state.path_policies.remove(&self.cfg.inner.path);
1005            let mut table = state.dispatch.write().await;
1006            table.remove(&self.cfg.inner.path);
1007            state.path_configs.remove(&self.cfg.inner.path);
1008        }
1009
1010        if let Some(key) = self.registry_key.take() {
1011            global_registries().remove(&key);
1012            ServerRegistry::global().release(key.1);
1013        }
1014
1015        if let Some(task) = self.forward_task.take() {
1016            task.abort();
1017        }
1018
1019        tracing::info!(
1020            host = self.cfg.inner.host,
1021            port = self.cfg.inner.port,
1022            path = self.cfg.inner.path,
1023            "WebSocket consumer stopped"
1024        );
1025
1026        if had_server_error {
1027            tracing::warn!(
1028                host = self.cfg.inner.host,
1029                port = self.cfg.inner.port,
1030                path = self.cfg.inner.path,
1031                "WebSocket server had errors during its lifetime"
1032            );
1033            return Err(CamelError::ProcessorError(
1034                "WebSocket server terminated with errors during its lifetime".into(),
1035            ));
1036        }
1037
1038        Ok(())
1039    }
1040
1041    fn concurrency_model(&self) -> ConcurrencyModel {
1042        ConcurrencyModel::Concurrent {
1043            max: Some(self.cfg.inner.max_connections as usize),
1044        }
1045    }
1046
1047    fn background_task_handle(&mut self) -> Option<JoinHandle<Result<(), CamelError>>> {
1048        self.forward_task.take()
1049    }
1050
1051    fn set_security_context(&mut self, ctx: camel_component_api::SecurityContext) {
1052        self.security_ctx = Some(ctx);
1053    }
1054}
1055
1056use std::sync::atomic::{AtomicBool, Ordering};
1057
/// Allocate a fresh shared boolean flag that starts out cleared (`false`).
///
/// Every call returns an independent flag. Clones of the returned `Arc`
/// observe the same flag.
fn new_atomic_false() -> Arc<AtomicBool> {
    let flag = AtomicBool::default();
    Arc::new(flag)
}
1061
1062/// Classify a WebSocket error as retryable (transient network failure).
1063///
1064/// Retryable: connection refused, timeout, connection failed.
1065/// Permanent: anything else (protocol errors, auth failures, etc.).
1066#[inline]
1067fn is_retryable_ws_error(err: &CamelError) -> bool {
1068    let s = err.to_string();
1069    s.contains("connection refused") || s.contains("timeout") || s.contains("connection failed")
1070}
1071
1072#[derive(Clone)]
1073pub struct WsProducer {
1074    cfg: WsClientConfig,
1075    /// Shared flag set by the async future when server-send hits backpressure,
1076    /// so that the next `poll_ready` call can return an error. (WS-003)
1077    backpressure_flag: Arc<AtomicBool>,
1078}
1079
1080impl WsProducer {
1081    pub fn new(cfg: WsClientConfig) -> Self {
1082        Self {
1083            cfg,
1084            backpressure_flag: Arc::new(AtomicBool::new(false)),
1085        }
1086    }
1087}
1088
1089impl Service<Exchange> for WsProducer {
1090    type Response = Exchange;
1091    type Error = CamelError;
1092    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
1093
1094    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
1095        // Return error if last server-send hit backpressure (WS-003)
1096        if self.backpressure_flag.swap(false, Ordering::Relaxed) {
1097            return Poll::Ready(Err(CamelError::ProcessorError(
1098                "WebSocket producer backpressure: previous send was dropped due to full channel"
1099                    .into(),
1100            )));
1101        }
1102        Poll::Ready(Ok(()))
1103    }
1104
1105    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
1106        let cfg = self.cfg.clone();
1107        let backpressure_flag = Arc::clone(&self.backpressure_flag);
1108
1109        Box::pin(async move {
1110            let canonical_host = cfg.inner.canonical_host();
1111            let key = (
1112                canonical_host.clone(),
1113                cfg.inner.port,
1114                cfg.inner.path.clone(),
1115            );
1116
1117            let send_to_all = exchange
1118                .input
1119                .header("CamelWsSendToAll")
1120                .and_then(|v| v.as_bool())
1121                .or_else(|| exchange.input.header("sendToAll").and_then(|v| v.as_bool()))
1122                .unwrap_or(false);
1123
1124            let conn_keys_header = exchange
1125                .input
1126                .header("CamelWsConnectionKey")
1127                .and_then(|v| v.as_str())
1128                .map(str::to_string);
1129
1130            let local_exists = global_registries().contains_key(&key);
1131            let server_send_mode = send_to_all || conn_keys_header.is_some() || local_exists;
1132
1133            let message_type = exchange
1134                .input
1135                .header("CamelWsMessageType")
1136                .and_then(|v| v.as_str())
1137                .unwrap_or("text")
1138                .to_ascii_lowercase();
1139
1140            if server_send_mode {
1141                let registry = global_registries().get(&key).map(|e| Arc::clone(e.value()));
1142                let Some(registry) = registry else {
1143                    return Err(CamelError::ProcessorError(format!(
1144                        "WebSocket local consumer not found for {}:{}{}",
1145                        canonical_host, cfg.inner.port, cfg.inner.path
1146                    )));
1147                };
1148
1149                let out_msg = body_to_axum_ws_message(
1150                    std::mem::take(&mut exchange.input.body),
1151                    &message_type,
1152                )
1153                .await?;
1154
1155                let targets = if send_to_all {
1156                    registry.snapshot_senders()
1157                } else if let Some(keys) = conn_keys_header {
1158                    let parsed: Vec<String> = keys
1159                        .split(',')
1160                        .map(str::trim)
1161                        .filter(|k| !k.is_empty())
1162                        .map(|k| k.to_string())
1163                        .collect();
1164                    registry.get_senders_for_keys(&parsed)
1165                } else {
1166                    registry.snapshot_senders()
1167                };
1168
1169                let mut dropped = 0usize;
1170                for tx in &targets {
1171                    if !try_send_with_backpressure(tx, out_msg.clone(), "producer-send") {
1172                        dropped += 1;
1173                    }
1174                }
1175
1176                if dropped > 0 {
1177                    tracing::warn!(
1178                        host = canonical_host,
1179                        port = cfg.inner.port,
1180                        path = cfg.inner.path,
1181                        dropped,
1182                        total = targets.len(),
1183                        "WebSocket producer dropped messages due to backpressure"
1184                    );
1185                    exchange.input.set_header(
1186                        "CamelWsDeliveryDropped",
1187                        serde_json::Value::Number(dropped.into()),
1188                    );
1189                    // Signal backpressure for next poll_ready call (WS-003)
1190                    backpressure_flag.store(true, Ordering::Relaxed);
1191                    if dropped == targets.len() {
1192                        return Err(CamelError::ProcessorError(format!(
1193                            "WebSocket producer: all {dropped} message(s) dropped due to backpressure"
1194                        )));
1195                    }
1196                }
1197
1198                tracing::debug!(
1199                    host = canonical_host,
1200                    port = cfg.inner.port,
1201                    path = cfg.inner.path,
1202                    targets = targets.len(),
1203                    "WebSocket producer server-send complete"
1204                );
1205
1206                return Ok(exchange);
1207            }
1208
1209            let url = format!(
1210                "{}://{}:{}{}",
1211                cfg.inner.scheme, cfg.inner.host, cfg.inner.port, cfg.inner.path
1212            );
1213
1214            tracing::debug!(url = url, "WebSocket producer connecting");
1215
1216            #[allow(unused_mut)]
1217            let mut request = url
1218                .clone()
1219                .into_client_request()
1220                .map_err(|e| CamelError::ProcessorError(format!("WebSocket request error: {e}")))?;
1221
1222            #[cfg(feature = "otel")]
1223            {
1224                let mut otel_headers = HashMap::new();
1225                camel_otel::inject_from_exchange(&exchange, &mut otel_headers);
1226                for (k, v) in otel_headers {
1227                    if let (Ok(name), Ok(val)) = (
1228                        http::header::HeaderName::from_bytes(k.as_bytes()),
1229                        http::header::HeaderValue::from_str(&v),
1230                    ) {
1231                        request.headers_mut().insert(name, val);
1232                    }
1233                }
1234            }
1235
1236            // Add Sec-WebSocket-Protocol header if subprotocols configured (WS-007)
1237            if !cfg.inner.subprotocols.is_empty() {
1238                let proto_value = cfg.inner.subprotocols.join(", ");
1239                if let (Ok(name), Ok(val)) = (
1240                    http::header::HeaderName::from_bytes(b"Sec-WebSocket-Protocol"),
1241                    http::header::HeaderValue::from_str(&proto_value),
1242                ) {
1243                    request.headers_mut().insert(name, val);
1244                }
1245            }
1246
1247            // Determine message type: respect binary_payload config (WS-018)
1248            let effective_message_type = if cfg.inner.binary_payload {
1249                "binary"
1250            } else {
1251                &message_type
1252            };
1253
1254            let reconnect_policy = cfg.inner.reconnect_policy.clone();
1255            let mut ws_stream =
1256                connect_ws_with_retry(request, &url, cfg.inner.connect_timeout, &reconnect_policy)
1257                    .await?;
1258
1259            // Close/reconnect path: rate-limited bail. On close frame, sleep
1260            // delay_for(0) and return Err to signal the outer route to re-invoke
1261            // the producer. The attempts counter below bounds how many times
1262            // we'll signal reconnect before terminating. Independent counter —
1263            // OLD code shared a counter with the connect loop above; this is a
1264            // behavior change (cleaner separation of concerns).
1265            let attempts = 0u32;
1266
1267            let out_msg = body_to_client_ws_message(
1268                std::mem::take(&mut exchange.input.body),
1269                effective_message_type,
1270            )
1271            .await?;
1272
1273            ws_stream
1274                .send(out_msg)
1275                .await
1276                .map_err(|e| CamelError::ProcessorError(format!("WebSocket send failed: {e}")))?;
1277
1278            let incoming = tokio::time::timeout(cfg.inner.response_timeout, async {
1279                loop {
1280                    match ws_stream.next().await {
1281                        Some(Ok(ClientWsMessage::Ping(_))) | Some(Ok(ClientWsMessage::Pong(_))) => {
1282                            continue;
1283                        }
1284                        other => break other,
1285                    }
1286                }
1287            })
1288            .await
1289            .map_err(|_| CamelError::ProcessorError("WebSocket response timeout".into()))?;
1290
1291            match incoming {
1292                Some(Ok(ClientWsMessage::Text(text))) => {
1293                    tracing::debug!(url = url, "WebSocket producer received text response");
1294                    exchange.input.body = CamelBody::Text(text.to_string());
1295                }
1296                Some(Ok(ClientWsMessage::Binary(data))) => {
1297                    tracing::debug!(url = url, "WebSocket producer received binary response");
1298                    exchange.input.body = CamelBody::Bytes(data);
1299                }
1300                Some(Ok(ClientWsMessage::Close(frame))) => {
1301                    let normal = frame
1302                        .as_ref()
1303                        .map(|f| {
1304                            f.code == tungstenite::protocol::frame::coding::CloseCode::Normal
1305                                || f.code == tungstenite::protocol::frame::coding::CloseCode::Away
1306                        })
1307                        .unwrap_or(true);
1308
1309                    if normal {
1310                        tracing::debug!(url = url, "WebSocket producer received normal close");
1311                        exchange.input.body = CamelBody::Empty;
1312                    } else if reconnect_policy.should_retry(attempts + 1) {
1313                        let delay = reconnect_policy.delay_for(0); // fresh delay on close
1314                        tracing::warn!(
1315                            url = url,
1316                            attempt = attempts + 1,
1317                            delay_ms = delay.as_millis(),
1318                            "WebSocket closed by peer — reconnecting"
1319                        );
1320                        tokio::time::sleep(delay).await;
1321                        return Err(CamelError::ProcessorError(format!(
1322                            "WebSocket reconnect required after close: code {}",
1323                            frame.map(|f| u16::from(f.code)).unwrap_or_default()
1324                        )));
1325                    } else {
1326                        let code = frame.map(|f| u16::from(f.code)).unwrap_or_default();
1327                        return Err(CamelError::ProcessorError(format!(
1328                            "WebSocket peer closed: code {code}"
1329                        )));
1330                    }
1331                }
1332                Some(Ok(_)) | None => {
1333                    exchange.input.body = CamelBody::Empty;
1334                }
1335                Some(Err(e)) => {
1336                    return Err(CamelError::ProcessorError(format!(
1337                        "WebSocket receive failed: {e}"
1338                    )));
1339                }
1340            }
1341
1342            let _ = ws_stream.close(None).await;
1343            tracing::debug!(url = url, "WebSocket producer connection closed");
1344            Ok(exchange)
1345        })
1346    }
1347}
1348
1349async fn body_to_axum_ws_message(
1350    body: CamelBody,
1351    message_type: &str,
1352) -> Result<WsMessage, CamelError> {
1353    match message_type {
1354        "binary" => Ok(WsMessage::Binary(body.into_bytes(10 * 1024 * 1024).await?)),
1355        _ => Ok(WsMessage::Text(body_to_text(body).await?.into())),
1356    }
1357}
1358
1359async fn body_to_client_ws_message(
1360    body: CamelBody,
1361    message_type: &str,
1362) -> Result<ClientWsMessage, CamelError> {
1363    match message_type {
1364        "binary" => Ok(ClientWsMessage::Binary(
1365            body.into_bytes(10 * 1024 * 1024).await?,
1366        )),
1367        _ => Ok(ClientWsMessage::Text(body_to_text(body).await?.into())),
1368    }
1369}
1370
1371async fn body_to_text(body: CamelBody) -> Result<String, CamelError> {
1372    Ok(match body {
1373        CamelBody::Empty => String::new(),
1374        CamelBody::Text(s) => s,
1375        CamelBody::Xml(s) => s,
1376        CamelBody::Json(v) => v.to_string(),
1377        CamelBody::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
1378        CamelBody::Stream(stream) => {
1379            let bytes = CamelBody::Stream(stream)
1380                .into_bytes(10 * 1024 * 1024)
1381                .await?;
1382            String::from_utf8_lossy(&bytes).to_string()
1383        }
1384    })
1385}
1386
/// Check a request's `Origin` header against the configured allowed origin.
///
/// `"*"` allows every request, including ones without an `Origin` header.
/// Otherwise the request must carry an `Origin` equal to `allowed_origin`.
/// The comparison ignores ASCII case: origin schemes and hosts are
/// case-insensitive, and browsers send them lowercased (RFC 6454). A
/// configured value such as `https://Example.com` therefore still matches.
fn is_origin_allowed(allowed_origin: &str, request_origin: Option<&str>) -> bool {
    if allowed_origin == "*" {
        return true;
    }
    request_origin.is_some_and(|origin| origin.eq_ignore_ascii_case(allowed_origin))
}
1393
1394fn try_send_with_backpressure(tx: &mpsc::Sender<WsMessage>, msg: WsMessage, context: &str) -> bool {
1395    match tx.try_send(msg) {
1396        Ok(()) => true,
1397        Err(error) => {
1398            tracing::warn!(%context, %error, "dropping websocket outbound message due to backpressure");
1399            false
1400        }
1401    }
1402}
1403
1404fn load_tls_config(
1405    cert_path: &str,
1406    key_path: &str,
1407) -> Result<tokio_rustls::rustls::ServerConfig, CamelError> {
1408    use std::fs::File;
1409    use std::io::BufReader;
1410
1411    let cert_file = File::open(cert_path)
1412        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert file error: {e}")))?;
1413    let key_file = File::open(key_path)
1414        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key file error: {e}")))?;
1415
1416    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
1417        .collect::<Result<Vec<_>, _>>()
1418        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS cert parse error: {e}")))?;
1419
1420    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
1421        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS key parse error: {e}")))?
1422        .ok_or_else(|| CamelError::EndpointCreationFailed("TLS: no private key found".into()))?;
1423
1424    tokio_rustls::rustls::ServerConfig::builder()
1425        .with_no_client_auth()
1426        .with_single_cert(certs, key)
1427        .map_err(|e| CamelError::EndpointCreationFailed(format!("TLS config error: {e}")))
1428}
1429
1430fn map_connect_error(err: tungstenite::Error, url: &str) -> CamelError {
1431    match err {
1432        tungstenite::Error::Io(ioe) if ioe.kind() == std::io::ErrorKind::ConnectionRefused => {
1433            CamelError::ProcessorError(format!("WebSocket connection refused: {ioe}"))
1434        }
1435        tungstenite::Error::Tls(_) => {
1436            CamelError::ProcessorError("WebSocket TLS handshake failed: handshake error".into())
1437        }
1438        other => {
1439            let msg = other.to_string();
1440            if msg.to_lowercase().contains("connection refused") {
1441                CamelError::ProcessorError(format!("WebSocket connection refused: {msg}"))
1442            } else if msg.to_lowercase().contains("tls") {
1443                CamelError::ProcessorError(format!("WebSocket TLS handshake failed: {msg}"))
1444            } else {
1445                CamelError::ProcessorError(format!("WebSocket connection failed ({url}): {msg}"))
1446            }
1447        }
1448    }
1449}
1450
1451/// Connect to a WebSocket server with retry logic using the configured
1452/// [`NetworkRetryPolicy`]. Extracted for testability so the regression test
1453/// (rc-1nm) can drive the real production connect path rather than a
1454/// synthetic fake.
1455async fn connect_ws_with_retry<R>(
1456    request: R,
1457    url: &str,
1458    connect_timeout: std::time::Duration,
1459    reconnect_policy: &NetworkRetryPolicy,
1460) -> Result<
1461    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
1462    CamelError,
1463>
1464where
1465    R: IntoClientRequest + Unpin + Clone,
1466{
1467    let url_owned = url.to_string();
1468    retry_async(
1469        reconnect_policy,
1470        Some("ws-producer"),
1471        || {
1472            let r = request.clone();
1473            let url = url_owned.clone();
1474            async move {
1475                match tokio::time::timeout(connect_timeout, tokio_tungstenite::connect_async(r))
1476                    .await
1477                {
1478                    Ok(Ok((stream, _))) => Ok(stream),
1479                    Ok(Err(e)) => Err(map_connect_error(e, &url)),
1480                    Err(_) => Err(CamelError::ProcessorError(format!(
1481                        "WebSocket connect timeout ({connect_timeout:?}) to {url}"
1482                    ))),
1483                }
1484            }
1485        },
1486        is_retryable_ws_error,
1487    )
1488    .await
1489}
1490
1491#[cfg(test)]
1492mod tests {
1493    use camel_component_api::test_support::PanicRuntimeObservability;
1494    fn test_rt() -> std::sync::Arc<dyn camel_component_api::RuntimeObservability> {
1495        std::sync::Arc::new(PanicRuntimeObservability)
1496    }
1497    fn rt() -> std::sync::Arc<dyn camel_component_api::RuntimeObservability> {
1498        std::sync::Arc::new(PanicRuntimeObservability)
1499    }
1500
1501    use super::*;
1502    use camel_component_api::NoOpComponentContext;
1503    use std::time::Duration;
1504
1505    use tokio::sync::mpsc;
1506    use tokio_tungstenite::connect_async;
1507    use tokio_tungstenite::tungstenite::Message as ClientMessage;
1508    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
1509    use tokio_util::sync::CancellationToken;
1510    use tower::ServiceExt;
1511
1512    fn free_port() -> u16 {
1513        std::net::TcpListener::bind("127.0.0.1:0")
1514            .unwrap()
1515            .local_addr()
1516            .unwrap()
1517            .port()
1518    }
1519
1520    #[test]
1521    fn ws_component_scheme_is_ws() {
1522        assert_eq!(WsComponent::new().scheme(), "ws");
1523    }
1524
1525    #[test]
1526    fn wss_component_scheme_is_wss() {
1527        assert_eq!(WssComponent::new().scheme(), "wss");
1528    }
1529
1530    #[test]
1531    fn endpoint_config_defaults_match_spec() {
1532        let cfg = WsEndpointConfig::default();
1533        assert_eq!(cfg.scheme, "ws");
1534        assert_eq!(cfg.host, "0.0.0.0");
1535        assert_eq!(cfg.port, 8080);
1536        assert_eq!(cfg.path, "/");
1537        assert_eq!(cfg.max_connections, 100);
1538        assert_eq!(cfg.max_message_size, 65536);
1539        assert!(!cfg.send_to_all);
1540        assert_eq!(cfg.heartbeat_interval, Duration::ZERO);
1541        assert_eq!(cfg.idle_timeout, Duration::ZERO);
1542        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
1543        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1544        assert_eq!(cfg.allow_origin, "*");
1545        assert_eq!(cfg.tls_cert, None);
1546        assert_eq!(cfg.tls_key, None);
1547        assert!(cfg.reconnect);
1548        assert_eq!(cfg.reconnect_max_attempts, 5);
1549        assert_eq!(cfg.reconnect_delay_ms, 1000);
1550        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
1551        assert!(!cfg.binary_payload);
1552        assert!(cfg.subprotocols.is_empty());
1553    }
1554
1555    #[test]
1556    fn endpoint_config_parses_uri_params() {
1557        let uri = "ws://localhost:9001/chat?maxConnections=42&maxMessageSize=1024&sendToAll=true&heartbeatIntervalMs=1500&idleTimeoutMs=2500&connectTimeoutMs=3500&responseTimeoutMs=4500&allowOrigin=https://example.com&tlsCert=/tmp/cert.pem&tlsKey=/tmp/key.pem";
1558        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1559
1560        assert_eq!(cfg.scheme, "ws");
1561        assert_eq!(cfg.host, "localhost");
1562        assert_eq!(cfg.port, 9001);
1563        assert_eq!(cfg.path, "/chat");
1564        assert_eq!(cfg.max_connections, 42);
1565        assert_eq!(cfg.max_message_size, 1024);
1566        assert!(cfg.send_to_all);
1567        assert_eq!(cfg.heartbeat_interval, Duration::from_millis(1500));
1568        assert_eq!(cfg.idle_timeout, Duration::from_millis(2500));
1569        assert_eq!(cfg.connect_timeout, Duration::from_millis(3500));
1570        assert_eq!(cfg.response_timeout, Duration::from_millis(4500));
1571        assert_eq!(cfg.allow_origin, "https://example.com");
1572        assert_eq!(cfg.tls_cert.as_deref(), Some("/tmp/cert.pem"));
1573        assert_eq!(cfg.tls_key.as_deref(), Some("/tmp/key.pem"));
1574        assert!(cfg.reconnect);
1575        assert_eq!(cfg.reconnect_max_attempts, 5);
1576        assert_eq!(cfg.reconnect_delay_ms, 1000);
1577    }
1578
1579    #[test]
1580    fn endpoint_config_parses_reconnect_uri_params() {
1581        let uri =
1582            "ws://localhost:9001/chat?reconnect=false&reconnectMaxAttempts=2&reconnectDelayMs=25";
1583        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
1584        assert!(!cfg.reconnect);
1585        assert_eq!(cfg.reconnect_max_attempts, 2);
1586        assert_eq!(cfg.reconnect_delay_ms, 25);
1587    }
1588
1589    #[test]
1590    fn endpoint_config_override_chain_uri_overrides_defaults() {
1591        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:8089/echo?maxConnections=7").unwrap();
1592        assert_eq!(cfg.max_connections, 7);
1593        assert_eq!(cfg.max_message_size, 65536);
1594        assert!(!cfg.send_to_all);
1595        assert_eq!(cfg.response_timeout, Duration::from_secs(30));
1596    }
1597
1598    #[test]
1599    fn endpoint_trait_creates_consumer_and_producer() {
1600        let ctx = NoOpComponentContext;
1601        let endpoint = WsComponent::new()
1602            .create_endpoint("ws://127.0.0.1:9010/trait", &ctx)
1603            .unwrap();
1604
1605        endpoint.create_consumer(rt()).unwrap();
1606        endpoint
1607            .create_producer(rt(), &ProducerContext::default())
1608            .unwrap();
1609    }
1610
1611    #[test]
1612    fn ws_consumer_concurrency_model_uses_max_connections() {
1613        let cfg = WsEndpointConfig::from_uri("ws://127.0.0.1:9011/cm?maxConnections=321").unwrap();
1614        let consumer = WsConsumer::new(cfg.server_config(), test_rt());
1615        assert_eq!(
1616            consumer.concurrency_model(),
1617            ConcurrencyModel::Concurrent { max: Some(321) }
1618        );
1619    }
1620
1621    #[tokio::test]
1622    async fn connection_registry_add_remove_broadcast_and_targeted_send() {
1623        let registry = WsConnectionRegistry::new();
1624        let (tx1, mut rx1) = mpsc::channel(8);
1625        let (tx2, mut rx2) = mpsc::channel(8);
1626
1627        registry.insert("k1".into(), tx1);
1628        registry.insert("k2".into(), tx2);
1629        assert_eq!(registry.len(), 2);
1630
1631        for tx in registry.snapshot_senders() {
1632            tx.send(WsMessage::Text("broadcast".into())).await.unwrap();
1633        }
1634
1635        assert_eq!(rx1.recv().await, Some(WsMessage::Text("broadcast".into())));
1636        assert_eq!(rx2.recv().await, Some(WsMessage::Text("broadcast".into())));
1637
1638        let target = registry.get_senders_for_keys(&["k1".to_string()]);
1639        assert_eq!(target.len(), 1);
1640        target[0]
1641            .send(WsMessage::Text("targeted".into()))
1642            .await
1643            .unwrap();
1644
1645        assert_eq!(rx1.recv().await, Some(WsMessage::Text("targeted".into())));
1646        assert!(
1647            tokio::time::timeout(Duration::from_millis(50), rx2.recv())
1648                .await
1649                .is_err()
1650        );
1651
1652        registry.remove("k1");
1653        assert_eq!(registry.len(), 1);
1654    }
1655
1656    #[test]
1657    fn host_canonicalization_maps_local_hosts_to_loopback() {
1658        let c1 = WsEndpointConfig::from_uri("ws://0.0.0.0:9100/a")
1659            .unwrap()
1660            .canonical_host();
1661        let c2 = WsEndpointConfig::from_uri("ws://localhost:9101/b")
1662            .unwrap()
1663            .canonical_host();
1664        let c3 = WsEndpointConfig::from_uri("ws://127.0.0.1:9102/c")
1665            .unwrap()
1666            .canonical_host();
1667
1668        assert_eq!(c1, "127.0.0.1");
1669        assert_eq!(c2, "127.0.0.1");
1670        assert_eq!(c3, "127.0.0.1");
1671    }
1672
1673    #[tokio::test]
1674    async fn echo_flow_round_trips_message_through_consumer_and_producer() {
1675        let port = free_port();
1676        let uri = format!("ws://127.0.0.1:{port}/echo");
1677        let component_ctx = NoOpComponentContext;
1678        let endpoint = WsComponent::new()
1679            .create_endpoint(&uri, &component_ctx)
1680            .unwrap();
1681
1682        let mut consumer = endpoint.create_consumer(rt()).unwrap();
1683        let producer = endpoint
1684            .create_producer(rt(), &ProducerContext::default())
1685            .unwrap();
1686
1687        let (route_tx, mut route_rx) = mpsc::channel(16);
1688        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1689        consumer.start(ctx).await.unwrap();
1690
1691        let route_task = tokio::spawn(async move {
1692            if let Some(envelope) = route_rx.recv().await {
1693                let payload = envelope
1694                    .exchange
1695                    .input
1696                    .body
1697                    .as_text()
1698                    .unwrap_or_default()
1699                    .to_string();
1700                let key = envelope
1701                    .exchange
1702                    .input
1703                    .header("CamelWsConnectionKey")
1704                    .and_then(|v| v.as_str())
1705                    .unwrap()
1706                    .to_string();
1707
1708                let mut response = Exchange::new(CamelMessage::new(CamelBody::Text(payload)));
1709                response
1710                    .input
1711                    .set_header("CamelWsConnectionKey", serde_json::Value::String(key));
1712                producer.oneshot(response).await.unwrap();
1713            }
1714        });
1715
1716        let url = format!("ws://127.0.0.1:{port}/echo");
1717        let (mut client, _) = loop {
1718            match connect_async(&url).await {
1719                Ok(ok) => break ok,
1720                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1721            }
1722        };
1723
1724        client
1725            .send(ClientMessage::Text("hello-ws".into()))
1726            .await
1727            .unwrap();
1728
1729        let incoming = tokio::time::timeout(Duration::from_secs(2), async {
1730            loop {
1731                match client.next().await {
1732                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
1733                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1734                    Some(Ok(_)) => continue,
1735                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1736                    None => panic!("websocket closed before echo"),
1737                }
1738            }
1739        })
1740        .await
1741        .unwrap();
1742
1743        assert_eq!(incoming, "hello-ws");
1744
1745        consumer.stop().await.unwrap();
1746        route_task.await.unwrap();
1747    }
1748
1749    #[tokio::test]
1750    async fn consumer_stop_sends_close_1001() {
1751        let port = free_port();
1752        let uri = format!("ws://127.0.0.1:{port}/shutdown");
1753        let component_ctx = NoOpComponentContext;
1754        let endpoint = WsComponent::new()
1755            .create_endpoint(&uri, &component_ctx)
1756            .unwrap();
1757
1758        let mut consumer = endpoint.create_consumer(rt()).unwrap();
1759        let (route_tx, _route_rx) = mpsc::channel(16);
1760        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1761        consumer.start(ctx).await.unwrap();
1762
1763        let url = format!("ws://127.0.0.1:{port}/shutdown");
1764        let (mut client, _) = loop {
1765            match connect_async(&url).await {
1766                Ok(ok) => break ok,
1767                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1768            }
1769        };
1770
1771        client
1772            .send(ClientMessage::Text("keepalive".into()))
1773            .await
1774            .unwrap();
1775
1776        consumer.stop().await.unwrap();
1777
1778        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1779            loop {
1780                match client.next().await {
1781                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1782                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1783                    Some(Ok(_)) => continue,
1784                    Some(Err(e)) => panic!("ws receive failed: {e}"),
1785                    None => panic!("websocket closed without close frame"),
1786                }
1787            }
1788        })
1789        .await
1790        .unwrap();
1791
1792        assert_eq!(close_code, Some(CloseCode::Away));
1793    }
1794
1795    #[test]
1796    fn wildcard_origin_allows_anything() {
1797        assert!(is_origin_allowed("*", None));
1798        assert!(is_origin_allowed("*", Some("https://example.com")));
1799    }
1800
1801    #[test]
1802    fn exact_origin_requires_match() {
1803        assert!(is_origin_allowed(
1804            "https://example.com",
1805            Some("https://example.com")
1806        ));
1807        assert!(!is_origin_allowed(
1808            "https://example.com",
1809            Some("https://other.com")
1810        ));
1811        assert!(!is_origin_allowed("https://example.com", None));
1812    }
1813
1814    #[test]
1815    fn endpoint_config_rejects_invalid_scheme() {
1816        let result = WsEndpointConfig::from_uri("http://localhost:9000/path");
1817        assert!(result.is_err());
1818        let msg = result.unwrap_err().to_string();
1819        assert!(
1820            msg.contains("Invalid WebSocket scheme"),
1821            "expected scheme error, got: {msg}"
1822        );
1823    }
1824
1825    #[tokio::test]
1826    async fn wss_consumer_start_fails_without_tls_cert() {
1827        let port = free_port();
1828        let component_ctx = NoOpComponentContext;
1829        let endpoint = WssComponent::new()
1830            .create_endpoint(&format!("wss://127.0.0.1:{port}/secure"), &component_ctx)
1831            .unwrap();
1832        let mut consumer = endpoint.create_consumer(rt()).unwrap();
1833        let (tx, _rx) = mpsc::channel(16);
1834        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1835        let result = consumer.start(ctx).await;
1836        assert!(result.is_err());
1837        let msg = result.unwrap_err().to_string();
1838        assert!(
1839            msg.contains("TLS cert path is required"),
1840            "expected TLS cert error, got: {msg}"
1841        );
1842    }
1843
1844    #[tokio::test]
1845    async fn wss_consumer_start_fails_with_nonexistent_cert() {
1846        let port = free_port();
1847        let component_ctx = NoOpComponentContext;
1848        let endpoint = WssComponent::new()
1849            .create_endpoint(&format!(
1850                "wss://127.0.0.1:{port}/secure?tlsCert=/nonexistent/cert.pem&tlsKey=/nonexistent/key.pem"
1851            ), &component_ctx)
1852            .unwrap();
1853        let mut consumer = endpoint.create_consumer(rt()).unwrap();
1854        let (tx, _rx) = mpsc::channel(16);
1855        let ctx = ConsumerContext::new(tx, CancellationToken::new());
1856        let result = consumer.start(ctx).await;
1857        assert!(result.is_err());
1858        let msg = result.unwrap_err().to_string();
1859        assert!(
1860            msg.contains("TLS cert file error"),
1861            "expected cert file error, got: {msg}"
1862        );
1863    }
1864
1865    #[tokio::test]
1866    async fn server_registry_returns_same_state_for_same_port() {
1867        let port = free_port();
1868        let state1 = ServerRegistry::global()
1869            .get_or_spawn("127.0.0.1", port, None)
1870            .await
1871            .unwrap();
1872        let state2 = ServerRegistry::global()
1873            .get_or_spawn("127.0.0.1", port, None)
1874            .await
1875            .unwrap();
1876        assert!(
1877            Arc::ptr_eq(&state1.dispatch, &state2.dispatch),
1878            "expected same dispatch table for same port"
1879        );
1880    }
1881
1882    #[tokio::test]
1883    async fn dispatch_handler_returns_404_for_unregistered_path() {
1884        let port = free_port();
1885        let state = ServerRegistry::global()
1886            .get_or_spawn("127.0.0.1", port, None)
1887            .await
1888            .unwrap();
1889        let app = Router::new().fallback(dispatch_handler).with_state(state);
1890        let response = tokio::time::timeout(
1891            Duration::from_secs(2),
1892            tower::ServiceExt::oneshot(
1893                app,
1894                axum::http::Request::builder()
1895                    .uri("/nonexistent")
1896                    .body(Body::empty())
1897                    .unwrap(),
1898            ),
1899        )
1900        .await
1901        .unwrap()
1902        .unwrap();
1903        assert_eq!(response.status(), StatusCode::NOT_FOUND);
1904    }
1905
1906    #[tokio::test]
1907    async fn client_mode_producer_connects_and_echoes() {
1908        let app = Router::new().route(
1909            "/echo",
1910            axum::routing::get(|ws: WebSocketUpgrade| async move {
1911                ws.on_upgrade(|mut socket: WebSocket| async move {
1912                    while let Some(Ok(msg)) = socket.recv().await {
1913                        match msg {
1914                            WsMessage::Text(text) => {
1915                                let _ = socket.send(WsMessage::Text(text)).await;
1916                            }
1917                            WsMessage::Binary(data) => {
1918                                let _ = socket.send(WsMessage::Binary(data)).await;
1919                            }
1920                            WsMessage::Close(_) => break,
1921                            _ => {}
1922                        }
1923                    }
1924                })
1925            }),
1926        );
1927        // Bind to port 0 directly to avoid TOCTOU race with free_port() + re-bind
1928        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1929        let port = listener.local_addr().unwrap().port();
1930        let server_task = tokio::spawn(async move {
1931            let _ = serve(listener, app).await;
1932        });
1933
1934        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/echo")).unwrap();
1935        let producer = WsProducer::new(cfg.client_config());
1936
1937        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello-client".into())));
1938        tokio::time::sleep(Duration::from_millis(25)).await;
1939        let result =
1940            match tokio::time::timeout(Duration::from_secs(3), producer.oneshot(exchange)).await {
1941                Ok(Ok(r)) => r,
1942                Ok(Err(_)) => panic!("producer call failed"),
1943                Err(_) => panic!("producer call timed out"),
1944            };
1945
1946        assert_eq!(result.input.body.as_text().unwrap(), "hello-client");
1947
1948        server_task.abort();
1949    }
1950
1951    #[tokio::test]
1952    async fn max_connections_rejects_with_close_1013() {
1953        let port = free_port();
1954        let uri = format!("ws://127.0.0.1:{port}/limited?maxConnections=1");
1955        let component_ctx = NoOpComponentContext;
1956        let endpoint = WsComponent::new()
1957            .create_endpoint(&uri, &component_ctx)
1958            .unwrap();
1959        let mut consumer = endpoint.create_consumer(rt()).unwrap();
1960        let (route_tx, _route_rx) = mpsc::channel(16);
1961        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
1962        consumer.start(ctx).await.unwrap();
1963
1964        let url = format!("ws://127.0.0.1:{port}/limited");
1965        let (_client1, _) = loop {
1966            match connect_async(&url).await {
1967                Ok(ok) => break ok,
1968                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
1969            }
1970        };
1971
1972        tokio::time::sleep(Duration::from_millis(100)).await;
1973
1974        let (mut client2, _) = connect_async(&url).await.unwrap();
1975
1976        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
1977            loop {
1978                match client2.next().await {
1979                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
1980                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
1981                    Some(Ok(ClientMessage::Text(_))) => continue,
1982                    Some(Ok(_)) => continue,
1983                    Some(Err(e)) => panic!("client2 ws receive failed: {e}"),
1984                    None => panic!("client2 closed without close frame"),
1985                }
1986            }
1987        })
1988        .await
1989        .unwrap();
1990
1991        assert_eq!(
1992            close_code,
1993            Some(CloseCode::from(1013u16)),
1994            "expected 1013 (Try Again Later) for max connections"
1995        );
1996
1997        consumer.stop().await.unwrap();
1998    }
1999
2000    #[tokio::test]
2001    async fn max_message_size_rejects_with_close_1009() {
2002        let port = free_port();
2003        let uri = format!("ws://127.0.0.1:{port}/sizelimit?maxMessageSize=10");
2004        let component_ctx = NoOpComponentContext;
2005        let endpoint = WsComponent::new()
2006            .create_endpoint(&uri, &component_ctx)
2007            .unwrap();
2008        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2009        let (route_tx, _route_rx) = mpsc::channel(16);
2010        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2011        consumer.start(ctx).await.unwrap();
2012
2013        let url = format!("ws://127.0.0.1:{port}/sizelimit");
2014        let (mut client, _) = loop {
2015            match connect_async(&url).await {
2016                Ok(ok) => break ok,
2017                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2018            }
2019        };
2020
2021        let oversized = "x".repeat(100);
2022        client
2023            .send(ClientMessage::Text(oversized.into()))
2024            .await
2025            .unwrap();
2026
2027        let close_code = tokio::time::timeout(Duration::from_secs(2), async {
2028            loop {
2029                match client.next().await {
2030                    Some(Ok(ClientMessage::Close(frame))) => break frame.map(|f| f.code),
2031                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2032                    Some(Ok(_)) => continue,
2033                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2034                    None => panic!("websocket closed without close frame"),
2035                }
2036            }
2037        })
2038        .await
2039        .unwrap();
2040
2041        assert_eq!(
2042            close_code,
2043            Some(CloseCode::from(1009u16)),
2044            "expected 1009 (Message Too Big) for oversized message"
2045        );
2046
2047        consumer.stop().await.unwrap();
2048    }
2049
2050    #[tokio::test]
2051    async fn origin_rejection_returns_403() {
2052        let port = free_port();
2053        let uri = format!("ws://127.0.0.1:{port}/origintest?allowOrigin=https://allowed.com");
2054        let component_ctx = NoOpComponentContext;
2055        let endpoint = WsComponent::new()
2056            .create_endpoint(&uri, &component_ctx)
2057            .unwrap();
2058        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2059        let (route_tx, _route_rx) = mpsc::channel(16);
2060        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2061        consumer.start(ctx).await.unwrap();
2062
2063        let state = ServerRegistry::global()
2064            .get_or_spawn("127.0.0.1", port, None)
2065            .await
2066            .unwrap();
2067        let app = Router::new().fallback(dispatch_handler).with_state(state);
2068
2069        let response = tokio::time::timeout(
2070            Duration::from_secs(2),
2071            tower::ServiceExt::oneshot(
2072                app,
2073                axum::http::Request::builder()
2074                    .uri("/origintest")
2075                    .header("origin", "https://evil.com")
2076                    .header("upgrade", "websocket")
2077                    .header("connection", "Upgrade")
2078                    .header("sec-websocket-version", "13")
2079                    .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
2080                    .body(Body::empty())
2081                    .unwrap(),
2082            ),
2083        )
2084        .await
2085        .unwrap()
2086        .unwrap();
2087
2088        assert_eq!(
2089            response.status(),
2090            StatusCode::FORBIDDEN,
2091            "expected 403 for disallowed origin"
2092        );
2093
2094        consumer.stop().await.unwrap();
2095    }
2096
2097    #[tokio::test]
2098    async fn broadcast_sends_to_all_connected_clients() {
2099        let port = free_port();
2100        let uri = format!("ws://127.0.0.1:{port}/bc");
2101        let component_ctx = NoOpComponentContext;
2102        let endpoint = WsComponent::new()
2103            .create_endpoint(&uri, &component_ctx)
2104            .unwrap();
2105        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2106        let producer = endpoint
2107            .create_producer(rt(), &ProducerContext::default())
2108            .unwrap();
2109
2110        let (route_tx, _route_rx) = mpsc::channel(16);
2111        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2112        consumer.start(ctx).await.unwrap();
2113
2114        let url = format!("ws://127.0.0.1:{port}/bc");
2115
2116        let (mut client1, _) = loop {
2117            match connect_async(&url).await {
2118                Ok(ok) => break ok,
2119                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2120            }
2121        };
2122
2123        let (mut client2, _) = connect_async(&url).await.unwrap();
2124
2125        tokio::time::sleep(Duration::from_millis(100)).await;
2126
2127        let mut response =
2128            Exchange::new(CamelMessage::new(CamelBody::Text("broadcast-msg".into())));
2129        response
2130            .input
2131            .set_header("CamelWsSendToAll", serde_json::Value::Bool(true));
2132        producer.oneshot(response).await.unwrap();
2133
2134        let recv1 = tokio::time::timeout(Duration::from_secs(2), async {
2135            loop {
2136                match client1.next().await {
2137                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2138                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2139                    _ => panic!("client1 unexpected message or close"),
2140                }
2141            }
2142        })
2143        .await
2144        .unwrap();
2145
2146        let recv2 = tokio::time::timeout(Duration::from_secs(2), async {
2147            loop {
2148                match client2.next().await {
2149                    Some(Ok(ClientMessage::Text(txt))) => break txt.to_string(),
2150                    Some(Ok(ClientMessage::Ping(_))) | Some(Ok(ClientMessage::Pong(_))) => continue,
2151                    _ => panic!("client2 unexpected message or close"),
2152                }
2153            }
2154        })
2155        .await
2156        .unwrap();
2157
2158        assert_eq!(recv1, "broadcast-msg");
2159        assert_eq!(recv2, "broadcast-msg");
2160
2161        consumer.stop().await.unwrap();
2162    }
2163
2164    #[tokio::test]
2165    async fn concurrent_get_or_spawn_returns_same_state() {
2166        let port = free_port();
2167        let results: Arc<std::sync::Mutex<Vec<WsAppState>>> =
2168            Arc::new(std::sync::Mutex::new(Vec::new()));
2169
2170        let mut handles = Vec::new();
2171        for _ in 0..4 {
2172            let results = results.clone();
2173            handles.push(tokio::spawn(async move {
2174                let state = ServerRegistry::global()
2175                    .get_or_spawn("127.0.0.1", port, None)
2176                    .await
2177                    .unwrap();
2178                results.lock().unwrap().push(state);
2179            }));
2180        }
2181
2182        for h in handles {
2183            h.await.unwrap();
2184        }
2185
2186        let states = results.lock().unwrap();
2187        assert_eq!(states.len(), 4);
2188        for i in 1..states.len() {
2189            assert!(
2190                Arc::ptr_eq(&states[0].dispatch, &states[i].dispatch),
2191                "all concurrent callers should get the same dispatch table"
2192            );
2193        }
2194    }
2195
2196    #[tokio::test]
2197    async fn body_conversion_helpers_cover_text_and_binary_paths() {
2198        let text_msg = body_to_axum_ws_message(CamelBody::Text("abc".into()), "text")
2199            .await
2200            .unwrap();
2201        assert!(matches!(text_msg, WsMessage::Text(_)));
2202
2203        let bin_msg = body_to_axum_ws_message(CamelBody::Bytes(vec![1, 2, 3].into()), "binary")
2204            .await
2205            .unwrap();
2206        assert!(matches!(bin_msg, WsMessage::Binary(_)));
2207
2208        let client_text =
2209            body_to_client_ws_message(CamelBody::Json(serde_json::json!({"k":"v"})), "text")
2210                .await
2211                .unwrap();
2212        assert!(matches!(client_text, ClientWsMessage::Text(_)));
2213
2214        let client_bin = body_to_client_ws_message(CamelBody::Bytes(vec![7, 8].into()), "binary")
2215            .await
2216            .unwrap();
2217        assert!(matches!(client_bin, ClientWsMessage::Binary(_)));
2218    }
2219
2220    #[tokio::test]
2221    async fn body_to_text_handles_empty_text_json_and_bytes() {
2222        assert_eq!(body_to_text(CamelBody::Empty).await.unwrap(), "");
2223        assert_eq!(
2224            body_to_text(CamelBody::Text("hello".into())).await.unwrap(),
2225            "hello"
2226        );
2227        assert_eq!(
2228            body_to_text(CamelBody::Json(serde_json::json!({"n":1})))
2229                .await
2230                .unwrap(),
2231            "{\"n\":1}"
2232        );
2233        assert_eq!(
2234            body_to_text(CamelBody::Bytes(b"hi".to_vec().into()))
2235                .await
2236                .unwrap(),
2237            "hi"
2238        );
2239    }
2240
2241    #[test]
2242    fn try_send_with_backpressure_returns_false_when_channel_full() {
2243        let (tx, _rx) = mpsc::channel::<WsMessage>(1);
2244        assert!(try_send_with_backpressure(
2245            &tx,
2246            WsMessage::Text("first".into()),
2247            "test"
2248        ));
2249        assert!(!try_send_with_backpressure(
2250            &tx,
2251            WsMessage::Text("second".into()),
2252            "test"
2253        ));
2254    }
2255
2256    #[test]
2257    fn map_connect_error_formats_connection_refused_and_generic_errors() {
2258        let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
2259        let err = map_connect_error(tungstenite::Error::Io(refused), "ws://localhost:1/x");
2260        assert!(err.to_string().contains("WebSocket connection refused"));
2261
2262        let generic = map_connect_error(
2263            tungstenite::Error::Protocol(
2264                tokio_tungstenite::tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
2265            ),
2266            "ws://localhost:2/y",
2267        );
2268        assert!(
2269            generic
2270                .to_string()
2271                .contains("WebSocket connection failed (ws://localhost:2/y)")
2272        );
2273    }
2274
2275    // === Phase B Finding Tests ===
2276
2277    // WS-015: maxConnections=0 must be rejected
2278    #[test]
2279    fn from_uri_rejects_max_connections_zero() {
2280        let result = WsEndpointConfig::from_uri("ws://localhost:9200/test?maxConnections=0");
2281        assert!(result.is_err());
2282        let msg = result.unwrap_err().to_string();
2283        assert!(
2284            msg.contains("maxConnections must be >= 1"),
2285            "expected maxConnections validation error, got: {msg}"
2286        );
2287    }
2288
2289    // WS-019: maxMessageSize=0 must be rejected
2290    #[test]
2291    fn from_uri_rejects_max_message_size_zero() {
2292        let result = WsEndpointConfig::from_uri("ws://localhost:9201/test?maxMessageSize=0");
2293        assert!(result.is_err());
2294        let msg = result.unwrap_err().to_string();
2295        assert!(
2296            msg.contains("maxMessageSize must be > 0"),
2297            "expected maxMessageSize validation error, got: {msg}"
2298        );
2299    }
2300
2301    // WS-020: allowOrigin="" must be rejected
2302    #[test]
2303    fn from_uri_rejects_empty_allow_origin() {
2304        let result = WsEndpointConfig::from_uri("ws://localhost:9202/test?allowOrigin=");
2305        assert!(result.is_err());
2306        let msg = result.unwrap_err().to_string();
2307        assert!(
2308            msg.contains("allowOrigin must not be empty"),
2309            "expected allowOrigin validation error, got: {msg}"
2310        );
2311    }
2312
2313    // WS-006: Double-start must be rejected
2314    #[tokio::test]
2315    async fn consumer_double_start_returns_error() {
2316        let port = free_port();
2317        let uri = format!("ws://127.0.0.1:{port}/doublestart");
2318        let component_ctx = NoOpComponentContext;
2319        let endpoint = WsComponent::new()
2320            .create_endpoint(&uri, &component_ctx)
2321            .unwrap();
2322
2323        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2324        let (route_tx, _route_rx) = mpsc::channel(16);
2325        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2326
2327        // First start should succeed
2328        consumer.start(ctx).await.unwrap();
2329
2330        // Second start should fail
2331        let (route_tx2, _route_rx2) = mpsc::channel(16);
2332        let ctx2 = ConsumerContext::new(route_tx2, CancellationToken::new());
2333        let result = consumer.start(ctx2).await;
2334        assert!(result.is_err());
2335        let msg = result.unwrap_err().to_string();
2336        assert!(
2337            msg.contains("already started"),
2338            "expected double-start error, got: {msg}"
2339        );
2340
2341        consumer.stop().await.unwrap();
2342    }
2343
2344    // WS-005: Registry cleanup on stop + port reuse
2345    #[tokio::test]
2346    async fn registry_cleanup_on_consumer_stop() {
2347        let port = free_port();
2348        let uri = format!("ws://127.0.0.1:{port}/cleanup");
2349        let component_ctx = NoOpComponentContext;
2350        let endpoint = WsComponent::new()
2351            .create_endpoint(&uri, &component_ctx)
2352            .unwrap();
2353
2354        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2355        let (route_tx, _route_rx) = mpsc::channel(16);
2356        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2357        consumer.start(ctx).await.unwrap();
2358
2359        // Verify registry entry exists
2360        let registries = global_registries();
2361        let key = ("127.0.0.1".to_string(), port, "/cleanup".to_string());
2362        assert!(
2363            registries.contains_key(&key),
2364            "registry should have entry after start"
2365        );
2366
2367        // Stop consumer
2368        consumer.stop().await.unwrap();
2369
2370        // Verify registry entry is removed
2371        assert!(
2372            !registries.contains_key(&key),
2373            "registry should be cleaned up after stop"
2374        );
2375
2376        // Verify server registry reference is released (port can be reused)
2377        // The ServerRegistry should have removed the entry
2378        let server_reg = ServerRegistry::global();
2379        let guard = server_reg.inner.lock().unwrap();
2380        assert!(
2381            !guard.contains_key(&port),
2382            "ServerRegistry should remove port entry after last consumer stops"
2383        );
2384    }
2385
2386    // WS-003 + WS-004: poll_ready backpressure and server-send error handling
2387    #[tokio::test]
2388    async fn producer_server_send_returns_error_when_all_dropped() {
2389        let port = free_port();
2390        let uri = format!("ws://127.0.0.1:{port}/backpressure");
2391        let component_ctx = NoOpComponentContext;
2392        let endpoint = WsComponent::new()
2393            .create_endpoint(&uri, &component_ctx)
2394            .unwrap();
2395
2396        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2397        let producer = endpoint
2398            .create_producer(rt(), &ProducerContext::default())
2399            .unwrap();
2400
2401        let (route_tx, _route_rx) = mpsc::channel(1); // Tiny channel to force backpressure
2402        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2403        consumer.start(ctx).await.unwrap();
2404
2405        // Connect a client so the registry has an entry
2406        let url = format!("ws://127.0.0.1:{port}/backpressure");
2407        let (mut client, _) = loop {
2408            match connect_async(&url).await {
2409                Ok(ok) => break ok,
2410                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2411            }
2412        };
2413
2414        // Don't consume messages — let the channel fill up
2415        tokio::time::sleep(Duration::from_millis(50)).await;
2416
2417        // Flood the channel to trigger backpressure
2418        let mut all_dropped = false;
2419        for _ in 0..100 {
2420            let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("flood".into())));
2421            match producer.clone().oneshot(exchange).await {
2422                Ok(_) => {}
2423                Err(e) => {
2424                    if e.to_string().contains("backpressure") {
2425                        all_dropped = true;
2426                        break;
2427                    }
2428                }
2429            }
2430        }
2431
2432        // The producer should eventually return a backpressure error
2433        assert!(
2434            all_dropped,
2435            "producer should return error when all messages are dropped due to backpressure"
2436        );
2437
2438        // Clean up
2439        let _ = client.close(None).await;
2440        consumer.stop().await.unwrap();
2441    }
2442
2443    // WS-012: Ping/pong round-trip in server mode
2444    #[tokio::test]
2445    async fn server_responds_to_client_ping_with_pong() {
2446        let port = free_port();
2447        let uri = format!("ws://127.0.0.1:{port}/pingpong");
2448        let component_ctx = NoOpComponentContext;
2449        let endpoint = WsComponent::new()
2450            .create_endpoint(&uri, &component_ctx)
2451            .unwrap();
2452
2453        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2454        let (route_tx, _route_rx) = mpsc::channel(16);
2455        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2456        consumer.start(ctx).await.unwrap();
2457
2458        let url = format!("ws://127.0.0.1:{port}/pingpong");
2459        let (mut client, _) = loop {
2460            match connect_async(&url).await {
2461                Ok(ok) => break ok,
2462                Err(_) => tokio::time::sleep(Duration::from_millis(25)).await,
2463            }
2464        };
2465
2466        // Send a ping
2467        client
2468            .send(ClientMessage::Ping(vec![1, 2, 3].into()))
2469            .await
2470            .unwrap();
2471
2472        // Expect a pong with the same payload
2473        let pong = tokio::time::timeout(Duration::from_secs(2), async {
2474            loop {
2475                match client.next().await {
2476                    Some(Ok(ClientMessage::Pong(data))) => break data,
2477                    Some(Ok(ClientMessage::Ping(_))) => continue,
2478                    Some(Ok(_)) => continue,
2479                    Some(Err(e)) => panic!("ws receive failed: {e}"),
2480                    None => panic!("websocket closed before pong"),
2481                }
2482            }
2483        })
2484        .await
2485        .unwrap();
2486
2487        assert_eq!(pong, vec![1, 2, 3], "pong should echo ping payload");
2488
2489        consumer.stop().await.unwrap();
2490    }
2491
2492    // WS-008: Client-side retry on transient connect failures
2493    #[tokio::test]
2494    async fn producer_retries_on_connection_refused() {
2495        // Use a port that nothing is listening on
2496        let port = free_port();
2497        // Ensure nothing is on this port
2498        let cfg = WsEndpointConfig::from_uri(&format!(
2499            "ws://127.0.0.1:{port}/retry?reconnect=true&reconnectMaxAttempts=2&reconnectDelayMs=50"
2500        ))
2501        .unwrap();
2502        let producer = WsProducer::new(cfg.client_config());
2503
2504        let exchange = Exchange::new(CamelMessage::new(CamelBody::Text("hello".into())));
2505
2506        // Should fail after retries (nothing listening)
2507        let result = tokio::time::timeout(Duration::from_secs(5), producer.oneshot(exchange)).await;
2508        assert!(
2509            result.is_ok(),
2510            "producer should complete (with error) within timeout"
2511        );
2512        let result = result.unwrap();
2513        assert!(
2514            result.is_err(),
2515            "producer should fail when nothing is listening"
2516        );
2517        let msg = result.unwrap_err().to_string();
2518        assert!(
2519            msg.contains("connection refused"),
2520            "expected connection refused error, got: {msg}"
2521        );
2522    }
2523
2524    // WS-001: Server bind error is visible (fake server-start error test)
2525    #[tokio::test]
2526    async fn server_bind_error_is_reported() {
2527        // Bind a port manually to cause a conflict
2528        let _listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
2529        let port = _listener.local_addr().unwrap().port();
2530
2531        // Try to start a consumer on the same port — should succeed since axum binds lazily
2532        // The actual bind error happens when the server task runs
2533        let uri = format!("ws://127.0.0.1:{port}/binderror");
2534        let component_ctx = NoOpComponentContext;
2535        let endpoint = WsComponent::new()
2536            .create_endpoint(&uri, &component_ctx)
2537            .unwrap();
2538
2539        let mut consumer = endpoint.create_consumer(rt()).unwrap();
2540        let (route_tx, _route_rx) = mpsc::channel(16);
2541        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2542
2543        // Start should succeed (server spawns, but bind may fail)
2544        let start_result = consumer.start(ctx).await;
2545        // The server may or may not have bound yet — this test verifies no panic
2546        // The actual error is logged by the server task
2547        let _ = start_result;
2548
2549        consumer.stop().await.unwrap();
2550    }
2551
2552    #[test]
2553    fn ws_app_state_server_error_starts_false() {
2554        let state = WsAppState {
2555            dispatch: Arc::new(RwLock::new(HashMap::new())),
2556            path_configs: Arc::new(DashMap::new()),
2557            path_policies: Arc::new(DashMap::new()),
2558            server_error: new_atomic_false(),
2559        };
2560        assert!(
2561            !state.server_error.load(Ordering::Relaxed),
2562            "server_error should start as false"
2563        );
2564    }
2565
2566    #[test]
2567    fn ws_app_state_server_error_can_be_set() {
2568        let state = WsAppState {
2569            dispatch: Arc::new(RwLock::new(HashMap::new())),
2570            path_configs: Arc::new(DashMap::new()),
2571            path_policies: Arc::new(DashMap::new()),
2572            server_error: new_atomic_false(),
2573        };
2574        assert!(!state.server_error.load(Ordering::Relaxed));
2575        state.server_error.store(true, Ordering::Relaxed);
2576        assert!(state.server_error.load(Ordering::Relaxed));
2577    }
2578
2579    #[tokio::test]
2580    async fn consumer_stop_returns_error_when_server_had_errors() {
2581        let port = free_port();
2582        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/errorflag")).unwrap();
2583        let mut consumer = WsConsumer::new(cfg.server_config(), test_rt());
2584        let (route_tx, _route_rx) = mpsc::channel(16);
2585        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2586        consumer.start(ctx).await.unwrap();
2587
2588        // Simulate server error by setting the flag directly
2589        if let Some(ref state) = consumer.server_state {
2590            state.server_error.store(true, Ordering::Relaxed);
2591        }
2592
2593        let result = consumer.stop().await;
2594        assert!(
2595            result.is_err(),
2596            "stop should return error when server had errors"
2597        );
2598        let msg = result.unwrap_err().to_string();
2599        assert!(
2600            msg.contains("terminated with errors"),
2601            "expected server error message, got: {msg}"
2602        );
2603    }
2604
2605    #[tokio::test]
2606    async fn consumer_stop_succeeds_when_server_healthy() {
2607        let port = free_port();
2608        let cfg = WsEndpointConfig::from_uri(&format!("ws://127.0.0.1:{port}/healthy")).unwrap();
2609        let mut consumer = WsConsumer::new(cfg.server_config(), test_rt());
2610        let (route_tx, _route_rx) = mpsc::channel(16);
2611        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());
2612        consumer.start(ctx).await.unwrap();
2613
2614        let result = consumer.stop().await;
2615        assert!(
2616            result.is_ok(),
2617            "stop should succeed when server is healthy: {:?}",
2618            result
2619        );
2620    }
2621
2622    // === H-10 Finding Tests ===
2623
2624    // WS-007: subprotocol negotiation support
2625    #[test]
2626    fn endpoint_config_parses_subprotocols() {
2627        let cfg = WsEndpointConfig::from_uri(
2628            "ws://localhost:9001/chat?subprotocols=graphql-ws,graphql-transport-ws",
2629        )
2630        .unwrap();
2631        assert_eq!(cfg.subprotocols, vec!["graphql-ws", "graphql-transport-ws"]);
2632    }
2633
2634    #[test]
2635    fn endpoint_config_default_subprotocols_empty() {
2636        let cfg = WsEndpointConfig::default();
2637        assert!(cfg.subprotocols.is_empty());
2638    }
2639
2640    // WS-017: sendTimeoutMs URI option
2641    #[test]
2642    fn endpoint_config_parses_send_timeout() {
2643        let cfg =
2644            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=5000").unwrap();
2645        assert_eq!(cfg.send_timeout, Duration::from_millis(5000));
2646    }
2647
2648    #[test]
2649    fn endpoint_config_default_send_timeout() {
2650        let cfg = WsEndpointConfig::default();
2651        assert_eq!(cfg.send_timeout, Duration::from_secs(30));
2652    }
2653
2654    #[test]
2655    fn endpoint_config_rejects_invalid_send_timeout() {
2656        let err =
2657            WsEndpointConfig::from_uri("ws://localhost:9001/chat?sendTimeoutMs=abc").unwrap_err();
2658        assert!(err.to_string().contains("sendTimeoutMs"));
2659    }
2660
2661    // WS-018: binaryPayload URI option
2662    #[test]
2663    fn endpoint_config_parses_binary_payload() {
2664        let cfg =
2665            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=true").unwrap();
2666        assert!(cfg.binary_payload);
2667    }
2668
2669    #[test]
2670    fn endpoint_config_default_binary_payload_false() {
2671        let cfg = WsEndpointConfig::default();
2672        assert!(!cfg.binary_payload);
2673    }
2674
2675    #[test]
2676    fn endpoint_config_rejects_invalid_binary_payload() {
2677        let err =
2678            WsEndpointConfig::from_uri("ws://localhost:9001/chat?binaryPayload=yes").unwrap_err();
2679        assert!(err.to_string().contains("binaryPayload"));
2680    }
2681
2682    /// Regression: max_attempts=N → exactly N invocations (caught OpenSearch off-by-one 1f5c4c2a).
2683    /// Replicates the exact retry loop from the WebSocket producer connect (lib.rs:~1228-1275):
2684    ///   attempts starts at 0, should_retry(attempts+1), delay_for(attempts), attempts += 1
2685    #[tokio::test]
2686    async fn retry_loop_invokes_operation_exactly_max_attempts_times() {
2687        use camel_component_api::NetworkRetryPolicy;
2688        use std::sync::Arc;
2689        use std::sync::atomic::{AtomicU32, Ordering};
2690
2691        let policy = NetworkRetryPolicy {
2692            max_attempts: 3,
2693            initial_delay: Duration::from_millis(1),
2694            max_delay: Duration::from_millis(1),
2695            multiplier: 1.0,
2696            ..NetworkRetryPolicy::default()
2697        };
2698
2699        let calls = Arc::new(AtomicU32::new(0));
2700        let calls_clone = Arc::clone(&calls);
2701        let mut attempts: u32 = 0;
2702
2703        let _result: Result<(), ()> = loop {
2704            calls_clone.fetch_add(1, Ordering::SeqCst);
2705            let op_result: Result<(), ()> = Err(());
2706            match op_result {
2707                Ok(_) => unreachable!(),
2708                Err(_) if policy.should_retry(attempts + 1) => {
2709                    let delay = policy.delay_for(attempts);
2710                    tokio::time::sleep(delay).await;
2711                    attempts += 1;
2712                    continue;
2713                }
2714                Err(_) => break Err(()),
2715            }
2716        };
2717
2718        assert_eq!(
2719            calls.load(Ordering::SeqCst),
2720            3,
2721            "max_attempts=3 must yield exactly 3 invocations"
2722        );
2723    }
2724
2725    /// Edge case: max_attempts=1 → exactly 1 invocation (initial attempt only, no retry).
2726    /// Locks the edge that originally broke OpenSearch.
2727    #[tokio::test]
2728    async fn retry_loop_with_max_attempts_1_invokes_operation_once() {
2729        use camel_component_api::NetworkRetryPolicy;
2730        use std::sync::Arc;
2731        use std::sync::atomic::{AtomicU32, Ordering};
2732
2733        let policy = NetworkRetryPolicy {
2734            max_attempts: 1,
2735            initial_delay: Duration::from_millis(1),
2736            max_delay: Duration::from_millis(1),
2737            multiplier: 1.0,
2738            ..NetworkRetryPolicy::default()
2739        };
2740
2741        let calls = Arc::new(AtomicU32::new(0));
2742        let calls_clone = Arc::clone(&calls);
2743        let mut attempts: u32 = 0;
2744
2745        let _result: Result<(), ()> = loop {
2746            calls_clone.fetch_add(1, Ordering::SeqCst);
2747            let op_result: Result<(), ()> = Err(());
2748            match op_result {
2749                Ok(_) => unreachable!(),
2750                Err(_) if policy.should_retry(attempts + 1) => {
2751                    let delay = policy.delay_for(attempts);
2752                    tokio::time::sleep(delay).await;
2753                    attempts += 1;
2754                    continue;
2755                }
2756                Err(_) => break Err(()),
2757            }
2758        };
2759
2760        assert_eq!(
2761            calls.load(Ordering::SeqCst),
2762            1,
2763            "max_attempts=1 must yield exactly 1 invocation"
2764        );
2765    }
2766
2767    // ── rc-1nm regression: WS producer retry emits component=ws-producer ──
2768
2769    use std::fmt::Write as _;
2770    use std::sync::{Arc, Mutex};
2771    use tracing::Subscriber;
2772    use tracing_subscriber::Layer;
2773    use tracing_subscriber::layer::SubscriberExt;
2774
2775    struct CollectingLayer {
2776        events: Arc<Mutex<Vec<String>>>,
2777    }
2778
2779    impl<S: Subscriber> Layer<S> for CollectingLayer {
2780        fn on_event(
2781            &self,
2782            event: &tracing::Event<'_>,
2783            _ctx: tracing_subscriber::layer::Context<'_, S>,
2784        ) {
2785            let mut buf = String::new();
2786            let mut visitor = CollectingVisitor { fields: &mut buf };
2787            event.record(&mut visitor);
2788            if let Ok(mut events) = self.events.lock() {
2789                events.push(buf);
2790            }
2791        }
2792    }
2793
2794    struct CollectingVisitor<'a> {
2795        fields: &'a mut String,
2796    }
2797
2798    impl CollectingVisitor<'_> {
2799        fn record_field(&mut self, name: &str, value: &str) {
2800            write!(self.fields, " {name}={value}").ok();
2801        }
2802    }
2803
2804    impl tracing::field::Visit for CollectingVisitor<'_> {
2805        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
2806            self.record_field(field.name(), value);
2807        }
2808        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
2809            self.record_field(field.name(), &format!("{value:?}"));
2810        }
2811        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
2812            self.record_field(field.name(), &value.to_string());
2813        }
2814    }
2815
2816    /// Regression for rc-1nm: the WS producer connect path must emit
2817    /// `component=ws-producer` in retry log events so operators can
2818    /// identify which component is retrying. Drives through the production
2819    /// `connect_ws_with_retry` helper with a guaranteed-fail URL instead of
2820    /// calling `retry_async` with a synthetic op.
2821    #[tokio::test]
2822    async fn ws_producer_retry_log_emits_component_ws_producer() {
2823        let events = Arc::new(Mutex::new(Vec::new()));
2824        let layer = CollectingLayer {
2825            events: events.clone(),
2826        };
2827        let subscriber = tracing_subscriber::registry().with(layer);
2828        let _guard = tracing::subscriber::set_default(subscriber);
2829
2830        let policy = NetworkRetryPolicy {
2831            max_attempts: 2,
2832            initial_delay: Duration::from_millis(1),
2833            max_delay: Duration::from_millis(5),
2834            ..NetworkRetryPolicy::default()
2835        };
2836
2837        // Drive through the production connect path with a guaranteed-fail
2838        // URL. Port 1 is reserved on Linux; connect immediately fails with
2839        // ECONNREFUSED.
2840        let request = "ws://127.0.0.1:1".into_client_request().unwrap();
2841
2842        let result: Result<_, CamelError> = connect_ws_with_retry(
2843            request,
2844            "ws://127.0.0.1:1",
2845            Duration::from_millis(100),
2846            &policy,
2847        )
2848        .await;
2849
2850        assert!(result.is_err(), "expected exhausted-retries error");
2851        let captured = events.lock().unwrap();
2852        assert!(
2853            !captured.is_empty(),
2854            "expected at least one retry log event, got none"
2855        );
2856        let first = &captured[0];
2857        assert!(
2858            first.contains("component=ws-producer"),
2859            "rc-1nm regression: expected 'component=ws-producer' in WS retry log, got: {first}"
2860        );
2861    }
2862}