Skip to main content

camel_component_ws/
config.rs

1use std::time::Duration;
2
3use camel_component_api::CamelError;
4
5#[derive(Debug, Clone, Default, serde::Deserialize)]
6pub struct WsConfig {
7    pub max_connections: Option<u32>,
8    pub max_message_size: Option<u32>,
9    pub heartbeat_interval_ms: Option<u64>,
10    pub idle_timeout_ms: Option<u64>,
11    pub connect_timeout_ms: Option<u64>,
12    pub response_timeout_ms: Option<u64>,
13    pub send_timeout_ms: Option<u64>,
14    pub binary_payload: Option<bool>,
15    pub subprotocols: Option<Vec<String>>,
16}
17
/// Fully-resolved settings for one WebSocket endpoint.
///
/// Usually produced by `WsEndpointConfig::from_uri`; fields not present in the
/// URI keep their `Default` values. The `Debug` impl masks the TLS fields.
#[derive(Clone)]
pub struct WsEndpointConfig {
    // ---- addressing ----
    /// `"ws"` or `"wss"` (`from_uri` rejects anything else).
    pub scheme: String,
    /// Host taken from the URI authority; `from_uri` substitutes `0.0.0.0` when empty.
    pub host: String,
    /// Port; `from_uri` falls back to 80 (`ws`) / 443 (`wss`) when the URI omits it.
    pub port: u16,
    /// Request path, always starting with `/`.
    pub path: String,

    // ---- limits and fan-out ----
    /// From `maxConnections`; `from_uri` rejects 0.
    pub max_connections: u32,
    /// From `maxMessageSize`; `from_uri` rejects 0.
    pub max_message_size: u32,
    /// From `sendToAll`.
    pub send_to_all: bool,

    // ---- timing ----
    /// From `heartbeatIntervalMs`.
    pub heartbeat_interval: Duration,
    /// From `idleTimeoutMs`.
    pub idle_timeout: Duration,
    /// From `connectTimeoutMs`.
    pub connect_timeout: Duration,
    /// From `responseTimeoutMs`.
    pub response_timeout: Duration,
    /// From `sendTimeoutMs`.
    pub send_timeout: Duration,

    // ---- origin and TLS ----
    /// From `allowOrigin`; `from_uri` rejects an empty value.
    pub allow_origin: String,
    /// From `tlsCert` (a file path in the tests); redacted in `Debug` output.
    pub tls_cert: Option<String>,
    /// From `tlsKey` (a file path in the tests); redacted in `Debug` output.
    pub tls_key: Option<String>,

    // ---- reconnection ----
    /// From `reconnect`.
    pub reconnect: bool,
    /// From `reconnectMaxAttempts`.
    pub reconnect_max_attempts: u32,
    /// From `reconnectDelayMs`, kept as raw milliseconds.
    pub reconnect_delay_ms: u64,

    // ---- payload ----
    /// From `binaryPayload`.
    pub binary_payload: bool,
    /// From the comma-separated `subprotocols` parameter (blank entries dropped).
    pub subprotocols: Vec<String>,
}
41
/// Masks an optional secret for diagnostic output.
///
/// Any `Some(_)` becomes `Some("***")` (even an empty string); `None` stays `None`,
/// so readers can still tell whether the value was configured.
fn redacted_opt(opt: &Option<String>) -> Option<&'static str> {
    opt.as_ref().map(|_| "***")
}
45
46impl std::fmt::Debug for WsEndpointConfig {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("WsEndpointConfig")
49            .field("scheme", &self.scheme)
50            .field("host", &self.host)
51            .field("port", &self.port)
52            .field("path", &self.path)
53            .field("max_connections", &self.max_connections)
54            .field("max_message_size", &self.max_message_size)
55            .field("send_to_all", &self.send_to_all)
56            .field("heartbeat_interval", &self.heartbeat_interval)
57            .field("idle_timeout", &self.idle_timeout)
58            .field("connect_timeout", &self.connect_timeout)
59            .field("response_timeout", &self.response_timeout)
60            .field("allow_origin", &self.allow_origin)
61            .field("tls_cert", &redacted_opt(&self.tls_cert))
62            .field("tls_key", &redacted_opt(&self.tls_key))
63            .field("reconnect", &self.reconnect)
64            .field("reconnect_max_attempts", &self.reconnect_max_attempts)
65            .field("reconnect_delay_ms", &self.reconnect_delay_ms)
66            .field("send_timeout", &self.send_timeout)
67            .field("binary_payload", &self.binary_payload)
68            .field("subprotocols", &self.subprotocols)
69            .finish()
70    }
71}
72
73impl Default for WsEndpointConfig {
74    fn default() -> Self {
75        Self {
76            scheme: "ws".into(),
77            host: "0.0.0.0".into(),
78            port: 8080,
79            path: "/".into(),
80            max_connections: 100,
81            max_message_size: 65536,
82            send_to_all: false,
83            heartbeat_interval: Duration::ZERO,
84            idle_timeout: Duration::ZERO,
85            connect_timeout: Duration::from_secs(10),
86            response_timeout: Duration::from_secs(30),
87            allow_origin: "*".into(),
88            tls_cert: None,
89            tls_key: None,
90            reconnect: true,
91            reconnect_max_attempts: 5,
92            reconnect_delay_ms: 1000,
93            send_timeout: Duration::from_secs(30),
94            binary_payload: false,
95            subprotocols: Vec::new(),
96        }
97    }
98}
99
100#[derive(Debug, Clone)]
101pub struct WsServerConfig {
102    pub inner: WsEndpointConfig,
103}
104
105#[derive(Debug, Clone)]
106pub struct WsClientConfig {
107    pub inner: WsEndpointConfig,
108}
109
110impl WsConfig {
111    /// Validate configuration values.
112    ///
113    /// Returns an error if any explicitly-set value is invalid (e.g. zero).
114    /// `None` values are valid — they mean "use the default / unlimited".
115    pub fn validate(&self) -> Result<(), CamelError> {
116        if let Some(0) = self.max_connections {
117            return Err(CamelError::Config(
118                "maxConnections must be >= 1 when specified".into(),
119            ));
120        }
121        if let Some(0) = self.max_message_size {
122            return Err(CamelError::Config(
123                "maxMessageSize must be >= 1 when specified".into(),
124            ));
125        }
126        Ok(())
127    }
128}
129
130impl WsEndpointConfig {
131    pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
132        let parsed = camel_component_api::parse_uri(uri)
133            .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
134
135        let scheme = parsed.scheme;
136        if scheme != "ws" && scheme != "wss" {
137            return Err(CamelError::EndpointCreationFailed(format!(
138                "Invalid WebSocket scheme: {scheme}"
139            )));
140        }
141
142        let host_port_path = parsed.path;
143        let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
144        let (host_port, path) = match host_port_path.split_once('/') {
145            Some((hp, p)) => (hp, format!("/{p}")),
146            None => (host_port_path, "/".to_string()),
147        };
148
149        let (host, port) = match host_port.rsplit_once(':') {
150            Some((h, p)) if p.parse::<u16>().is_ok() => {
151                let parsed_port = p.parse::<u16>().unwrap(); // allow-unwrap
152                (h.to_string(), parsed_port)
153            }
154            _ => (
155                host_port.to_string(),
156                if scheme == "wss" { 443 } else { 80 },
157            ),
158        };
159
160        let mut cfg = Self {
161            scheme,
162            host: if host.is_empty() {
163                "0.0.0.0".to_string()
164            } else {
165                host
166            },
167            port,
168            path,
169            ..Self::default()
170        };
171
172        let params = parsed.params;
173        // Validate maxConnections >= 1 (WS-015)
174        if let Some(raw) = params.get("maxConnections") {
175            let v = raw.parse::<u32>().map_err(|_| {
176                CamelError::InvalidUri(format!(
177                    "maxConnections must be an unsigned integer, got '{raw}'"
178                ))
179            })?;
180            if v == 0 {
181                return Err(CamelError::InvalidUri("maxConnections must be >= 1".into()));
182            }
183            cfg.max_connections = v;
184        }
185        // Validate maxMessageSize > 0 (WS-019)
186        if let Some(raw) = params.get("maxMessageSize") {
187            let v = raw.parse::<u32>().map_err(|_| {
188                CamelError::InvalidUri(format!(
189                    "maxMessageSize must be an unsigned integer, got '{raw}'"
190                ))
191            })?;
192            if v == 0 {
193                return Err(CamelError::InvalidUri("maxMessageSize must be > 0".into()));
194            }
195            cfg.max_message_size = v;
196        }
197        if let Some(raw) = params.get("sendToAll") {
198            let v = raw.parse::<bool>().map_err(|_| {
199                CamelError::InvalidUri(format!(
200                    "sendToAll must be a boolean ('true' or 'false'), got '{raw}'"
201                ))
202            })?;
203            cfg.send_to_all = v;
204        }
205        if let Some(raw) = params.get("heartbeatIntervalMs") {
206            let v = raw.parse::<u64>().map_err(|_| {
207                CamelError::InvalidUri(format!(
208                    "heartbeatIntervalMs must be an unsigned integer, got '{raw}'"
209                ))
210            })?;
211            cfg.heartbeat_interval = Duration::from_millis(v);
212        }
213        if let Some(raw) = params.get("idleTimeoutMs") {
214            let v = raw.parse::<u64>().map_err(|_| {
215                CamelError::InvalidUri(format!(
216                    "idleTimeoutMs must be an unsigned integer, got '{raw}'"
217                ))
218            })?;
219            cfg.idle_timeout = Duration::from_millis(v);
220        }
221        if let Some(raw) = params.get("connectTimeoutMs") {
222            let v = raw.parse::<u64>().map_err(|_| {
223                CamelError::InvalidUri(format!(
224                    "connectTimeoutMs must be an unsigned integer, got '{raw}'"
225                ))
226            })?;
227            cfg.connect_timeout = Duration::from_millis(v);
228        }
229        if let Some(raw) = params.get("responseTimeoutMs") {
230            let v = raw.parse::<u64>().map_err(|_| {
231                CamelError::InvalidUri(format!(
232                    "responseTimeoutMs must be an unsigned integer, got '{raw}'"
233                ))
234            })?;
235            cfg.response_timeout = Duration::from_millis(v);
236        }
237        if let Some(v) = params.get("allowOrigin") {
238            if v.is_empty() {
239                return Err(CamelError::InvalidUri(
240                    "allowOrigin must not be empty when specified".into(),
241                ));
242            }
243            cfg.allow_origin = v.to_string();
244        }
245        if let Some(v) = params.get("tlsCert") {
246            cfg.tls_cert = Some(v.to_string());
247        }
248        if let Some(v) = params.get("tlsKey") {
249            cfg.tls_key = Some(v.to_string());
250        }
251        if let Some(raw) = params.get("reconnect") {
252            cfg.reconnect = raw.parse::<bool>().map_err(|_| {
253                CamelError::InvalidUri(format!(
254                    "reconnect must be a boolean ('true' or 'false'), got '{raw}'"
255                ))
256            })?;
257        }
258        if let Some(raw) = params.get("reconnectMaxAttempts") {
259            cfg.reconnect_max_attempts = raw.parse::<u32>().map_err(|_| {
260                CamelError::InvalidUri(format!(
261                    "reconnectMaxAttempts must be an unsigned integer, got '{raw}'"
262                ))
263            })?;
264        }
265        if let Some(raw) = params.get("reconnectDelayMs") {
266            cfg.reconnect_delay_ms = raw.parse::<u64>().map_err(|_| {
267                CamelError::InvalidUri(format!(
268                    "reconnectDelayMs must be an unsigned integer, got '{raw}'"
269                ))
270            })?;
271        }
272        if let Some(raw) = params.get("sendTimeoutMs") {
273            let v = raw.parse::<u64>().map_err(|_| {
274                CamelError::InvalidUri(format!(
275                    "sendTimeoutMs must be an unsigned integer, got '{raw}'"
276                ))
277            })?;
278            cfg.send_timeout = Duration::from_millis(v);
279        }
280        if let Some(raw) = params.get("binaryPayload") {
281            cfg.binary_payload = raw.parse::<bool>().map_err(|_| {
282                CamelError::InvalidUri(format!(
283                    "binaryPayload must be a boolean ('true' or 'false'), got '{raw}'"
284                ))
285            })?;
286        }
287        if let Some(raw) = params.get("subprotocols") {
288            cfg.subprotocols = raw
289                .split(',')
290                .map(|s| s.trim().to_string())
291                .filter(|s| !s.is_empty())
292                .collect();
293        }
294
295        Ok(cfg)
296    }
297
298    pub fn server_config(&self) -> WsServerConfig {
299        WsServerConfig {
300            inner: self.clone(),
301        }
302    }
303
304    pub fn client_config(&self) -> WsClientConfig {
305        WsClientConfig {
306            inner: self.clone(),
307        }
308    }
309
310    pub fn canonical_host(&self) -> String {
311        match self.host.as_str() {
312            "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
313            h => h.to_string(),
314        }
315    }
316}
317
#[cfg(test)]
mod config_validation_tests {
    use super::*;

    /// Parses `uri`, panicking at the caller's location if it is rejected.
    #[track_caller]
    fn parse(uri: &str) -> WsEndpointConfig {
        WsEndpointConfig::from_uri(uri).unwrap()
    }

    /// Asserts that `uri` is rejected and that the error text names `param`.
    #[track_caller]
    fn assert_rejected_naming(uri: &str, param: &str) {
        let message = WsEndpointConfig::from_uri(uri).unwrap_err().to_string();
        assert!(
            message.contains(param),
            "expected an error naming {param}, got: {message}"
        );
    }

    // ---- WsConfig::validate ----

    #[test]
    fn test_rejects_zero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(0),
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_rejects_zero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(0),
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_accepts_valid_config() {
        assert!(WsConfig::default().validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(50),
            ..Default::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(1024),
            ..Default::default()
        };
        assert!(cfg.validate().is_ok());
    }

    // ---- WsEndpointConfig::from_uri: malformed query parameters ----

    #[test]
    fn test_from_uri_rejects_invalid_send_to_all() {
        assert_rejected_naming("ws://localhost:8080?sendToAll=yes", "sendToAll");
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_connections_numeric() {
        assert_rejected_naming("ws://localhost:8080?maxConnections=abc", "maxConnections");
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_message_size_numeric() {
        assert_rejected_naming("ws://localhost:8080?maxMessageSize=abc", "maxMessageSize");
    }

    #[test]
    fn test_from_uri_rejects_invalid_heartbeat_interval_numeric() {
        assert_rejected_naming(
            "ws://localhost:8080?heartbeatIntervalMs=abc",
            "heartbeatIntervalMs",
        );
    }

    #[test]
    fn test_from_uri_rejects_invalid_idle_timeout_numeric() {
        assert_rejected_naming("ws://localhost:8080?idleTimeoutMs=abc", "idleTimeoutMs");
    }

    #[test]
    fn test_from_uri_rejects_invalid_connect_timeout_numeric() {
        assert_rejected_naming("ws://localhost:8080?connectTimeoutMs=abc", "connectTimeoutMs");
    }

    #[test]
    fn test_from_uri_rejects_invalid_response_timeout_numeric() {
        assert_rejected_naming(
            "ws://localhost:8080?responseTimeoutMs=abc",
            "responseTimeoutMs",
        );
    }

    // ---- WS-017: sendTimeoutMs parsing ----

    #[test]
    fn test_from_uri_parses_send_timeout_ms() {
        let cfg = parse("ws://localhost:8080?sendTimeoutMs=7500");
        assert_eq!(cfg.send_timeout, Duration::from_millis(7500));
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_timeout_ms() {
        assert_rejected_naming("ws://localhost:8080?sendTimeoutMs=xyz", "sendTimeoutMs");
    }

    // ---- WS-018: binaryPayload parsing ----

    #[test]
    fn test_from_uri_parses_binary_payload_true() {
        assert!(parse("ws://localhost:8080?binaryPayload=true").binary_payload);
    }

    #[test]
    fn test_from_uri_parses_binary_payload_false() {
        assert!(!parse("ws://localhost:8080?binaryPayload=false").binary_payload);
    }

    #[test]
    fn test_from_uri_rejects_invalid_binary_payload() {
        assert_rejected_naming("ws://localhost:8080?binaryPayload=sure", "binaryPayload");
    }

    // ---- WS-007: subprotocols parsing ----

    #[test]
    fn test_from_uri_parses_subprotocols() {
        let cfg = parse("ws://localhost:8080?subprotocols=json,protobuf");
        assert_eq!(cfg.subprotocols, ["json", "protobuf"]);
    }

    #[test]
    fn test_from_uri_subprotocols_trims_whitespace() {
        let cfg = parse("ws://localhost:8080?subprotocols=a, b");
        assert_eq!(cfg.subprotocols, ["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_empty_when_not_specified() {
        assert!(parse("ws://localhost:8080").subprotocols.is_empty());
    }

    // ---- Debug redaction ----

    #[test]
    fn ws_endpoint_config_debug_redacts_tls() {
        let config = WsEndpointConfig {
            tls_cert: Some("/secret/cert.pem".to_string()),
            tls_key: Some("/secret/key.pem".to_string()),
            ..Default::default()
        };
        let debug = format!("{config:?}");
        assert!(
            !debug.contains("/secret/"),
            "TLS paths must be redacted: {debug}"
        );
        for field in ["tls_cert", "tls_key"] {
            assert!(debug.contains(field), "field name should appear: {debug}");
        }
    }
}