1use std::time::Duration;
2
3use camel_component_api::{CamelError, NetworkRetryPolicy};
4
5#[derive(Debug, Clone, Default, serde::Deserialize)]
6pub struct WsConfig {
7 pub max_connections: Option<u32>,
8 pub max_message_size: Option<u32>,
9 pub heartbeat_interval_ms: Option<u64>,
10 pub idle_timeout_ms: Option<u64>,
11 pub connect_timeout_ms: Option<u64>,
12 pub response_timeout_ms: Option<u64>,
13 pub send_timeout_ms: Option<u64>,
14 pub binary_payload: Option<bool>,
15 pub subprotocols: Option<Vec<String>>,
16}
17
18#[derive(Clone)]
19pub struct WsEndpointConfig {
20 pub scheme: String,
21 pub host: String,
22 pub port: u16,
23 pub path: String,
24 pub max_connections: u32,
25 pub max_message_size: u32,
26 pub send_to_all: bool,
27 pub heartbeat_interval: Duration,
28 pub idle_timeout: Duration,
29 pub connect_timeout: Duration,
30 pub response_timeout: Duration,
31 pub allow_origin: String,
32 pub tls_cert: Option<String>,
33 pub tls_key: Option<String>,
34 pub reconnect: bool,
35 pub reconnect_max_attempts: u32,
36 pub reconnect_delay_ms: u64,
37 pub send_timeout: Duration,
38 pub binary_payload: bool,
39 pub subprotocols: Vec<String>,
40 pub reconnect_policy: NetworkRetryPolicy,
45}
46
/// Masks an optional secret for `Debug` output.
///
/// `Some(_)` becomes `Some("***")` and `None` stays `None`, so readers can see
/// whether a value is configured without seeing its content.
fn redacted_opt(opt: &Option<String>) -> Option<&'static str> {
    opt.as_ref().map(|_| "***")
}
50
51impl std::fmt::Debug for WsEndpointConfig {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("WsEndpointConfig")
54 .field("scheme", &self.scheme)
55 .field("host", &self.host)
56 .field("port", &self.port)
57 .field("path", &self.path)
58 .field("max_connections", &self.max_connections)
59 .field("max_message_size", &self.max_message_size)
60 .field("send_to_all", &self.send_to_all)
61 .field("heartbeat_interval", &self.heartbeat_interval)
62 .field("idle_timeout", &self.idle_timeout)
63 .field("connect_timeout", &self.connect_timeout)
64 .field("response_timeout", &self.response_timeout)
65 .field("allow_origin", &self.allow_origin)
66 .field("tls_cert", &redacted_opt(&self.tls_cert))
67 .field("tls_key", &redacted_opt(&self.tls_key))
68 .field("reconnect", &self.reconnect)
69 .field("reconnect_max_attempts", &self.reconnect_max_attempts)
70 .field("reconnect_delay_ms", &self.reconnect_delay_ms)
71 .field("reconnect_policy", &self.reconnect_policy)
72 .field("send_timeout", &self.send_timeout)
73 .field("binary_payload", &self.binary_payload)
74 .field("subprotocols", &self.subprotocols)
75 .finish()
76 }
77}
78
79impl Default for WsEndpointConfig {
80 fn default() -> Self {
81 Self {
82 scheme: "ws".into(),
83 host: "0.0.0.0".into(),
84 port: 8080,
85 path: "/".into(),
86 max_connections: 100,
87 max_message_size: 65536,
88 send_to_all: false,
89 heartbeat_interval: Duration::ZERO,
90 idle_timeout: Duration::ZERO,
91 connect_timeout: Duration::from_secs(10),
92 response_timeout: Duration::from_secs(30),
93 allow_origin: "*".into(),
94 tls_cert: None,
95 tls_key: None,
96 reconnect: true,
97 reconnect_max_attempts: 5,
98 reconnect_delay_ms: 1000,
99 send_timeout: Duration::from_secs(30),
100 binary_payload: false,
101 subprotocols: Vec::new(),
102 reconnect_policy: NetworkRetryPolicy {
103 enabled: true,
104 max_attempts: 5,
105 initial_delay: Duration::from_millis(1000),
106 multiplier: 2.0,
107 max_delay: Duration::from_secs(30),
108 jitter_factor: 0.0, },
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
115pub struct WsServerConfig {
116 pub inner: WsEndpointConfig,
117}
118
119#[derive(Debug, Clone)]
120pub struct WsClientConfig {
121 pub inner: WsEndpointConfig,
122}
123
124impl WsConfig {
125 pub fn validate(&self) -> Result<(), CamelError> {
130 if let Some(0) = self.max_connections {
131 return Err(CamelError::Config(
132 "maxConnections must be >= 1 when specified".into(),
133 ));
134 }
135 if let Some(0) = self.max_message_size {
136 return Err(CamelError::Config(
137 "maxMessageSize must be >= 1 when specified".into(),
138 ));
139 }
140 Ok(())
141 }
142}
143
144impl WsEndpointConfig {
145 pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
146 let parsed = camel_component_api::parse_uri(uri)
147 .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
148
149 let scheme = parsed.scheme;
150 if scheme != "ws" && scheme != "wss" {
151 return Err(CamelError::EndpointCreationFailed(format!(
152 "Invalid WebSocket scheme: {scheme}"
153 )));
154 }
155
156 let host_port_path = parsed.path;
157 let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
158 let (host_port, path) = match host_port_path.split_once('/') {
159 Some((hp, p)) => (hp, format!("/{p}")),
160 None => (host_port_path, "/".to_string()),
161 };
162
163 let (host, port) = match host_port.rsplit_once(':') {
164 Some((h, p)) if p.parse::<u16>().is_ok() => {
165 let parsed_port = p.parse::<u16>().unwrap(); (h.to_string(), parsed_port)
167 }
168 _ => (
169 host_port.to_string(),
170 if scheme == "wss" { 443 } else { 80 },
171 ),
172 };
173
174 let mut cfg = Self {
175 scheme,
176 host: if host.is_empty() {
177 "0.0.0.0".to_string()
178 } else {
179 host
180 },
181 port,
182 path,
183 ..Self::default()
184 };
185
186 let params = parsed.params;
187 if let Some(raw) = params.get("maxConnections") {
189 let v = raw.parse::<u32>().map_err(|_| {
190 CamelError::InvalidUri(format!(
191 "maxConnections must be an unsigned integer, got '{raw}'"
192 ))
193 })?;
194 if v == 0 {
195 return Err(CamelError::InvalidUri("maxConnections must be >= 1".into()));
196 }
197 cfg.max_connections = v;
198 }
199 if let Some(raw) = params.get("maxMessageSize") {
201 let v = raw.parse::<u32>().map_err(|_| {
202 CamelError::InvalidUri(format!(
203 "maxMessageSize must be an unsigned integer, got '{raw}'"
204 ))
205 })?;
206 if v == 0 {
207 return Err(CamelError::InvalidUri("maxMessageSize must be > 0".into()));
208 }
209 cfg.max_message_size = v;
210 }
211 if let Some(raw) = params.get("sendToAll") {
212 let v = raw.parse::<bool>().map_err(|_| {
213 CamelError::InvalidUri(format!(
214 "sendToAll must be a boolean ('true' or 'false'), got '{raw}'"
215 ))
216 })?;
217 cfg.send_to_all = v;
218 }
219 if let Some(raw) = params.get("heartbeatIntervalMs") {
220 let v = raw.parse::<u64>().map_err(|_| {
221 CamelError::InvalidUri(format!(
222 "heartbeatIntervalMs must be an unsigned integer, got '{raw}'"
223 ))
224 })?;
225 cfg.heartbeat_interval = Duration::from_millis(v);
226 }
227 if let Some(raw) = params.get("idleTimeoutMs") {
228 let v = raw.parse::<u64>().map_err(|_| {
229 CamelError::InvalidUri(format!(
230 "idleTimeoutMs must be an unsigned integer, got '{raw}'"
231 ))
232 })?;
233 cfg.idle_timeout = Duration::from_millis(v);
234 }
235 if let Some(raw) = params.get("connectTimeoutMs") {
236 let v = raw.parse::<u64>().map_err(|_| {
237 CamelError::InvalidUri(format!(
238 "connectTimeoutMs must be an unsigned integer, got '{raw}'"
239 ))
240 })?;
241 cfg.connect_timeout = Duration::from_millis(v);
242 }
243 if let Some(raw) = params.get("responseTimeoutMs") {
244 let v = raw.parse::<u64>().map_err(|_| {
245 CamelError::InvalidUri(format!(
246 "responseTimeoutMs must be an unsigned integer, got '{raw}'"
247 ))
248 })?;
249 cfg.response_timeout = Duration::from_millis(v);
250 }
251 if let Some(v) = params.get("allowOrigin") {
252 if v.is_empty() {
253 return Err(CamelError::InvalidUri(
254 "allowOrigin must not be empty when specified".into(),
255 ));
256 }
257 cfg.allow_origin = v.to_string();
258 }
259 if let Some(v) = params.get("tlsCert") {
260 cfg.tls_cert = Some(v.to_string());
261 }
262 if let Some(v) = params.get("tlsKey") {
263 cfg.tls_key = Some(v.to_string());
264 }
265 let mut reconnect_explicit = false;
270 let mut reconnect_max_attempts_explicit = false;
271 let mut reconnect_delay_ms_explicit = false;
272
273 if let Some(raw) = params.get("reconnect") {
274 cfg.reconnect = raw.parse::<bool>().map_err(|_| {
275 CamelError::InvalidUri(format!(
276 "reconnect must be a boolean ('true' or 'false'), got '{raw}'"
277 ))
278 })?;
279 reconnect_explicit = true;
280 }
281 if let Some(raw) = params.get("reconnectMaxAttempts") {
282 cfg.reconnect_max_attempts = raw.parse::<u32>().map_err(|_| {
283 CamelError::InvalidUri(format!(
284 "reconnectMaxAttempts must be an unsigned integer, got '{raw}'"
285 ))
286 })?;
287 reconnect_max_attempts_explicit = true;
288 }
289 if let Some(raw) = params.get("reconnectDelayMs") {
290 cfg.reconnect_delay_ms = raw.parse::<u64>().map_err(|_| {
291 CamelError::InvalidUri(format!(
292 "reconnectDelayMs must be an unsigned integer, got '{raw}'"
293 ))
294 })?;
295 reconnect_delay_ms_explicit = true;
296 }
297 if let Some(raw) = params.get("sendTimeoutMs") {
298 let v = raw.parse::<u64>().map_err(|_| {
299 CamelError::InvalidUri(format!(
300 "sendTimeoutMs must be an unsigned integer, got '{raw}'"
301 ))
302 })?;
303 cfg.send_timeout = Duration::from_millis(v);
304 }
305 if let Some(raw) = params.get("binaryPayload") {
306 cfg.binary_payload = raw.parse::<bool>().map_err(|_| {
307 CamelError::InvalidUri(format!(
308 "binaryPayload must be a boolean ('true' or 'false'), got '{raw}'"
309 ))
310 })?;
311 }
312 if let Some(raw) = params.get("subprotocols") {
313 cfg.subprotocols = raw
314 .split(',')
315 .map(|s| s.trim().to_string())
316 .filter(|s| !s.is_empty())
317 .collect();
318 }
319
320 if reconnect_explicit {
329 cfg.reconnect_policy.enabled = cfg.reconnect;
330 }
331 if reconnect_max_attempts_explicit {
332 cfg.reconnect_policy.max_attempts = cfg.reconnect_max_attempts;
333 }
334 if reconnect_delay_ms_explicit {
335 cfg.reconnect_policy.initial_delay = Duration::from_millis(cfg.reconnect_delay_ms);
336 }
337
338 Ok(cfg)
339 }
340
341 pub fn server_config(&self) -> WsServerConfig {
342 WsServerConfig {
343 inner: self.clone(),
344 }
345 }
346
347 pub fn client_config(&self) -> WsClientConfig {
348 WsClientConfig {
349 inner: self.clone(),
350 }
351 }
352
353 pub fn canonical_host(&self) -> String {
354 match self.host.as_str() {
355 "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
356 h => h.to_string(),
357 }
358 }
359}
360
#[cfg(test)]
mod config_validation_tests {
    use super::*;

    // ---- WsConfig::validate ------------------------------------------------

    #[test]
    fn test_rejects_zero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_rejects_zero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_accepts_valid_config() {
        let cfg = WsConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(50),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(1024),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    // ---- WsEndpointConfig::from_uri: parameter parsing ----------------------

    #[test]
    fn test_from_uri_rejects_invalid_send_to_all() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendToAll=yes").unwrap_err();
        assert!(err.to_string().contains("sendToAll"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_connections_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=abc").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_zero_max_connections() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=0").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_message_size_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=abc").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_rejects_zero_max_message_size() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=0").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_heartbeat_interval_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?heartbeatIntervalMs=abc").unwrap_err();
        assert!(err.to_string().contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_idle_timeout_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?idleTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("idleTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_connect_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?connectTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("connectTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_response_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?responseTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("responseTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_send_timeout_ms() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=7500").unwrap();
        assert_eq!(cfg.send_timeout, Duration::from_millis(7500));
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_timeout_ms() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=xyz").unwrap_err();
        assert!(err.to_string().contains("sendTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_binary_payload_true() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=true").unwrap();
        assert!(cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_parses_binary_payload_false() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=false").unwrap();
        assert!(!cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_rejects_invalid_binary_payload() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=sure").unwrap_err();
        assert!(err.to_string().contains("binaryPayload"));
    }

    #[test]
    fn test_from_uri_parses_subprotocols() {
        let cfg =
            WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=json,protobuf").unwrap();
        assert_eq!(cfg.subprotocols, vec!["json", "protobuf"]);
    }

    #[test]
    fn test_from_uri_subprotocols_trims_whitespace() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=a, b").unwrap();
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_skips_empty_entries() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=a,,b").unwrap();
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_empty_when_not_specified() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert!(cfg.subprotocols.is_empty());
    }

    // ---- Debug redaction ------------------------------------------------------

    #[test]
    fn ws_endpoint_config_debug_redacts_tls() {
        let config = WsEndpointConfig {
            tls_cert: Some("/secret/cert.pem".to_string()),
            tls_key: Some("/secret/key.pem".to_string()),
            ..WsEndpointConfig::default()
        };
        let debug = format!("{:?}", config);
        assert!(
            !debug.contains("/secret/"),
            "TLS paths must be redacted: {debug}"
        );
        assert!(
            debug.contains("tls_cert"),
            "field name should appear: {debug}"
        );
        assert!(
            debug.contains("tls_key"),
            "field name should appear: {debug}"
        );
    }

    // ---- Reconnect policy -------------------------------------------------------

    #[test]
    fn ws_endpoint_config_has_reconnect_policy_field() {
        let cfg = WsEndpointConfig::default();
        assert!(cfg.reconnect_policy.enabled);
        assert_eq!(cfg.reconnect_policy.max_attempts, 5);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(1000));
        // The legacy flat fields and the structured policy must describe the same
        // defaults.
        assert_eq!(cfg.reconnect_policy.enabled, cfg.reconnect);
        assert_eq!(cfg.reconnect_policy.max_attempts, cfg.reconnect_max_attempts);
        assert_eq!(
            cfg.reconnect_policy.initial_delay,
            Duration::from_millis(cfg.reconnect_delay_ms)
        );
    }

    #[test]
    fn ws_endpoint_uri_bridges_flat_fields_to_policy() {
        let uri =
            "ws://localhost:9001/test?reconnect=false&reconnectMaxAttempts=7&reconnectDelayMs=250";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
        assert!(!cfg.reconnect);
        assert_eq!(cfg.reconnect_max_attempts, 7);
        assert_eq!(cfg.reconnect_delay_ms, 250);
        assert!(!cfg.reconnect_policy.enabled);
        assert_eq!(cfg.reconnect_policy.max_attempts, 7);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(250));
    }

    #[test]
    fn ws_endpoint_policy_defaults_match_old_flat_defaults() {
        let cfg = WsEndpointConfig::default();
        assert!(cfg.reconnect_policy.enabled);
        assert_eq!(cfg.reconnect_policy.max_attempts, 5);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(1000));
        assert!((cfg.reconnect_policy.multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(cfg.reconnect_policy.max_delay, Duration::from_secs(30));
        assert!((cfg.reconnect_policy.jitter_factor - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn ws_endpoint_policy_preserved_when_no_flat_fields() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080/echo").unwrap();
        assert!(cfg.reconnect_policy.enabled);
        assert_eq!(cfg.reconnect_policy.max_attempts, 5);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(1000));
    }

    #[test]
    fn ws_endpoint_policy_from_toml_preserved_across_from_uri() {
        // A policy loaded from TOML is assigned after `from_uri`. A URI without flat
        // reconnect parameters must leave the default policy in place, and the
        // assigned policy must then be kept verbatim.
        let toml_policy = NetworkRetryPolicy {
            enabled: true,
            max_attempts: 10,
            initial_delay: Duration::from_millis(250),
            multiplier: 3.0,
            max_delay: Duration::from_secs(60),
            jitter_factor: 0.1,
        };

        let mut cfg = WsEndpointConfig::from_uri("ws://localhost:8080/echo").unwrap();
        assert_eq!(
            cfg.reconnect_policy,
            WsEndpointConfig::default().reconnect_policy
        );

        cfg.reconnect_policy = toml_policy.clone();
        assert_eq!(cfg.reconnect_policy, toml_policy);
        assert_eq!(cfg.reconnect_policy.max_attempts, 10);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(250));
    }

    #[test]
    fn ws_endpoint_policy_partial_bridge() {
        // Only the parameter that is present is mirrored. The rest of the policy
        // keeps its defaults.
        let uri = "ws://localhost:9001/test?reconnectMaxAttempts=10";
        let cfg = WsEndpointConfig::from_uri(uri).unwrap();
        assert_eq!(cfg.reconnect_policy.max_attempts, 10);
        assert!(cfg.reconnect_policy.enabled);
        assert_eq!(cfg.reconnect_policy.initial_delay, Duration::from_millis(1000));
    }

    #[test]
    fn network_retry_policy_from_toml() {
        let toml_str = r#"
            enabled = true
            max_attempts = 10
            initial_delay_ms = 250
            multiplier = 3.0
            max_delay_ms = 60000
            jitter_factor = 0.1
        "#;
        let policy: NetworkRetryPolicy = toml::from_str(toml_str).unwrap();
        assert!(policy.enabled);
        assert_eq!(policy.max_attempts, 10);
        assert_eq!(policy.initial_delay, Duration::from_millis(250));
        assert!((policy.multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(policy.max_delay, Duration::from_millis(60_000));
        assert!((policy.jitter_factor - 0.1).abs() < f64::EPSILON);
    }
}