1use std::time::Duration;
2
3use camel_component_api::CamelError;
4
5#[derive(Debug, Clone, Default, serde::Deserialize)]
6pub struct WsConfig {
7 pub max_connections: Option<u32>,
8 pub max_message_size: Option<u32>,
9 pub heartbeat_interval_ms: Option<u64>,
10 pub idle_timeout_ms: Option<u64>,
11 pub connect_timeout_ms: Option<u64>,
12 pub response_timeout_ms: Option<u64>,
13 pub send_timeout_ms: Option<u64>,
14 pub binary_payload: Option<bool>,
15 pub subprotocols: Option<Vec<String>>,
16}
17
/// Fully resolved settings for one WebSocket endpoint.
///
/// Usually built by `WsEndpointConfig::from_uri` on top of the `Default`
/// values. `Debug` is implemented by hand so that the TLS certificate and key
/// paths are masked.
#[derive(Clone)]
pub struct WsEndpointConfig {
    /// URI scheme; `from_uri` only accepts `"ws"` or `"wss"`.
    pub scheme: String,
    /// Host name or address; `from_uri` substitutes `"0.0.0.0"` for an empty host.
    pub host: String,
    /// TCP port; `from_uri` falls back to 443 (`wss`) or 80 (`ws`) when none is given.
    pub port: u16,
    /// Request path; `from_uri` always produces a value starting with `/`.
    pub path: String,
    /// Connection limit (`maxConnections`); `from_uri` rejects 0.
    pub max_connections: u32,
    /// Message size limit (`maxMessageSize`); `from_uri` rejects 0. Presumably bytes.
    pub max_message_size: u32,
    /// `sendToAll` option; presumably broadcast to every peer. Confirm in the producer.
    pub send_to_all: bool,
    /// Heartbeat interval (`heartbeatIntervalMs`); `Duration::ZERO` by default.
    pub heartbeat_interval: Duration,
    /// Idle timeout (`idleTimeoutMs`); `Duration::ZERO` by default.
    pub idle_timeout: Duration,
    /// Connect timeout (`connectTimeoutMs`).
    pub connect_timeout: Duration,
    /// Response timeout (`responseTimeoutMs`).
    pub response_timeout: Duration,
    /// Allowed origin (`allowOrigin`); `"*"` by default and never empty via `from_uri`.
    pub allow_origin: String,
    /// TLS certificate path (`tlsCert`); masked in `Debug` output.
    pub tls_cert: Option<String>,
    /// TLS private key path (`tlsKey`); masked in `Debug` output.
    pub tls_key: Option<String>,
    /// Whether reconnection is enabled (`reconnect`).
    pub reconnect: bool,
    /// Maximum reconnection attempts (`reconnectMaxAttempts`).
    pub reconnect_max_attempts: u32,
    /// Reconnection delay in milliseconds (`reconnectDelayMs`).
    pub reconnect_delay_ms: u64,
    /// Send timeout (`sendTimeoutMs`).
    pub send_timeout: Duration,
    /// `binaryPayload` option; presumably selects binary frames. Confirm in the producer.
    pub binary_payload: bool,
    /// Subprotocol names parsed from the comma-separated `subprotocols` option.
    pub subprotocols: Vec<String>,
}
41
/// Masks an optional secret for diagnostic output.
///
/// Yields `Some("***")` whenever a value is configured (even an empty string)
/// and `None` otherwise. Debug output therefore reveals only whether the
/// secret is set, never its content.
fn redacted_opt(opt: &Option<String>) -> Option<&'static str> {
    opt.as_ref().map(|_secret| "***")
}
45
46impl std::fmt::Debug for WsEndpointConfig {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("WsEndpointConfig")
49 .field("scheme", &self.scheme)
50 .field("host", &self.host)
51 .field("port", &self.port)
52 .field("path", &self.path)
53 .field("max_connections", &self.max_connections)
54 .field("max_message_size", &self.max_message_size)
55 .field("send_to_all", &self.send_to_all)
56 .field("heartbeat_interval", &self.heartbeat_interval)
57 .field("idle_timeout", &self.idle_timeout)
58 .field("connect_timeout", &self.connect_timeout)
59 .field("response_timeout", &self.response_timeout)
60 .field("allow_origin", &self.allow_origin)
61 .field("tls_cert", &redacted_opt(&self.tls_cert))
62 .field("tls_key", &redacted_opt(&self.tls_key))
63 .field("reconnect", &self.reconnect)
64 .field("reconnect_max_attempts", &self.reconnect_max_attempts)
65 .field("reconnect_delay_ms", &self.reconnect_delay_ms)
66 .field("send_timeout", &self.send_timeout)
67 .field("binary_payload", &self.binary_payload)
68 .field("subprotocols", &self.subprotocols)
69 .finish()
70 }
71}
72
73impl Default for WsEndpointConfig {
74 fn default() -> Self {
75 Self {
76 scheme: "ws".into(),
77 host: "0.0.0.0".into(),
78 port: 8080,
79 path: "/".into(),
80 max_connections: 100,
81 max_message_size: 65536,
82 send_to_all: false,
83 heartbeat_interval: Duration::ZERO,
84 idle_timeout: Duration::ZERO,
85 connect_timeout: Duration::from_secs(10),
86 response_timeout: Duration::from_secs(30),
87 allow_origin: "*".into(),
88 tls_cert: None,
89 tls_key: None,
90 reconnect: true,
91 reconnect_max_attempts: 5,
92 reconnect_delay_ms: 1000,
93 send_timeout: Duration::from_secs(30),
94 binary_payload: false,
95 subprotocols: Vec::new(),
96 }
97 }
98}
99
100#[derive(Debug, Clone)]
101pub struct WsServerConfig {
102 pub inner: WsEndpointConfig,
103}
104
105#[derive(Debug, Clone)]
106pub struct WsClientConfig {
107 pub inner: WsEndpointConfig,
108}
109
110impl WsConfig {
111 pub fn validate(&self) -> Result<(), CamelError> {
116 if let Some(0) = self.max_connections {
117 return Err(CamelError::Config(
118 "maxConnections must be >= 1 when specified".into(),
119 ));
120 }
121 if let Some(0) = self.max_message_size {
122 return Err(CamelError::Config(
123 "maxMessageSize must be >= 1 when specified".into(),
124 ));
125 }
126 Ok(())
127 }
128}
129
130impl WsEndpointConfig {
131 pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
132 let parsed = camel_component_api::parse_uri(uri)
133 .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
134
135 let scheme = parsed.scheme;
136 if scheme != "ws" && scheme != "wss" {
137 return Err(CamelError::EndpointCreationFailed(format!(
138 "Invalid WebSocket scheme: {scheme}"
139 )));
140 }
141
142 let host_port_path = parsed.path;
143 let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
144 let (host_port, path) = match host_port_path.split_once('/') {
145 Some((hp, p)) => (hp, format!("/{p}")),
146 None => (host_port_path, "/".to_string()),
147 };
148
149 let (host, port) = match host_port.rsplit_once(':') {
150 Some((h, p)) if p.parse::<u16>().is_ok() => {
151 let parsed_port = p.parse::<u16>().unwrap(); (h.to_string(), parsed_port)
153 }
154 _ => (
155 host_port.to_string(),
156 if scheme == "wss" { 443 } else { 80 },
157 ),
158 };
159
160 let mut cfg = Self {
161 scheme,
162 host: if host.is_empty() {
163 "0.0.0.0".to_string()
164 } else {
165 host
166 },
167 port,
168 path,
169 ..Self::default()
170 };
171
172 let params = parsed.params;
173 if let Some(raw) = params.get("maxConnections") {
175 let v = raw.parse::<u32>().map_err(|_| {
176 CamelError::InvalidUri(format!(
177 "maxConnections must be an unsigned integer, got '{raw}'"
178 ))
179 })?;
180 if v == 0 {
181 return Err(CamelError::InvalidUri("maxConnections must be >= 1".into()));
182 }
183 cfg.max_connections = v;
184 }
185 if let Some(raw) = params.get("maxMessageSize") {
187 let v = raw.parse::<u32>().map_err(|_| {
188 CamelError::InvalidUri(format!(
189 "maxMessageSize must be an unsigned integer, got '{raw}'"
190 ))
191 })?;
192 if v == 0 {
193 return Err(CamelError::InvalidUri("maxMessageSize must be > 0".into()));
194 }
195 cfg.max_message_size = v;
196 }
197 if let Some(raw) = params.get("sendToAll") {
198 let v = raw.parse::<bool>().map_err(|_| {
199 CamelError::InvalidUri(format!(
200 "sendToAll must be a boolean ('true' or 'false'), got '{raw}'"
201 ))
202 })?;
203 cfg.send_to_all = v;
204 }
205 if let Some(raw) = params.get("heartbeatIntervalMs") {
206 let v = raw.parse::<u64>().map_err(|_| {
207 CamelError::InvalidUri(format!(
208 "heartbeatIntervalMs must be an unsigned integer, got '{raw}'"
209 ))
210 })?;
211 cfg.heartbeat_interval = Duration::from_millis(v);
212 }
213 if let Some(raw) = params.get("idleTimeoutMs") {
214 let v = raw.parse::<u64>().map_err(|_| {
215 CamelError::InvalidUri(format!(
216 "idleTimeoutMs must be an unsigned integer, got '{raw}'"
217 ))
218 })?;
219 cfg.idle_timeout = Duration::from_millis(v);
220 }
221 if let Some(raw) = params.get("connectTimeoutMs") {
222 let v = raw.parse::<u64>().map_err(|_| {
223 CamelError::InvalidUri(format!(
224 "connectTimeoutMs must be an unsigned integer, got '{raw}'"
225 ))
226 })?;
227 cfg.connect_timeout = Duration::from_millis(v);
228 }
229 if let Some(raw) = params.get("responseTimeoutMs") {
230 let v = raw.parse::<u64>().map_err(|_| {
231 CamelError::InvalidUri(format!(
232 "responseTimeoutMs must be an unsigned integer, got '{raw}'"
233 ))
234 })?;
235 cfg.response_timeout = Duration::from_millis(v);
236 }
237 if let Some(v) = params.get("allowOrigin") {
238 if v.is_empty() {
239 return Err(CamelError::InvalidUri(
240 "allowOrigin must not be empty when specified".into(),
241 ));
242 }
243 cfg.allow_origin = v.to_string();
244 }
245 if let Some(v) = params.get("tlsCert") {
246 cfg.tls_cert = Some(v.to_string());
247 }
248 if let Some(v) = params.get("tlsKey") {
249 cfg.tls_key = Some(v.to_string());
250 }
251 if let Some(raw) = params.get("reconnect") {
252 cfg.reconnect = raw.parse::<bool>().map_err(|_| {
253 CamelError::InvalidUri(format!(
254 "reconnect must be a boolean ('true' or 'false'), got '{raw}'"
255 ))
256 })?;
257 }
258 if let Some(raw) = params.get("reconnectMaxAttempts") {
259 cfg.reconnect_max_attempts = raw.parse::<u32>().map_err(|_| {
260 CamelError::InvalidUri(format!(
261 "reconnectMaxAttempts must be an unsigned integer, got '{raw}'"
262 ))
263 })?;
264 }
265 if let Some(raw) = params.get("reconnectDelayMs") {
266 cfg.reconnect_delay_ms = raw.parse::<u64>().map_err(|_| {
267 CamelError::InvalidUri(format!(
268 "reconnectDelayMs must be an unsigned integer, got '{raw}'"
269 ))
270 })?;
271 }
272 if let Some(raw) = params.get("sendTimeoutMs") {
273 let v = raw.parse::<u64>().map_err(|_| {
274 CamelError::InvalidUri(format!(
275 "sendTimeoutMs must be an unsigned integer, got '{raw}'"
276 ))
277 })?;
278 cfg.send_timeout = Duration::from_millis(v);
279 }
280 if let Some(raw) = params.get("binaryPayload") {
281 cfg.binary_payload = raw.parse::<bool>().map_err(|_| {
282 CamelError::InvalidUri(format!(
283 "binaryPayload must be a boolean ('true' or 'false'), got '{raw}'"
284 ))
285 })?;
286 }
287 if let Some(raw) = params.get("subprotocols") {
288 cfg.subprotocols = raw
289 .split(',')
290 .map(|s| s.trim().to_string())
291 .filter(|s| !s.is_empty())
292 .collect();
293 }
294
295 Ok(cfg)
296 }
297
298 pub fn server_config(&self) -> WsServerConfig {
299 WsServerConfig {
300 inner: self.clone(),
301 }
302 }
303
304 pub fn client_config(&self) -> WsClientConfig {
305 WsClientConfig {
306 inner: self.clone(),
307 }
308 }
309
310 pub fn canonical_host(&self) -> String {
311 match self.host.as_str() {
312 "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
313 h => h.to_string(),
314 }
315 }
316}
317
/// Unit tests for `WsConfig::validate` and `WsEndpointConfig` URI parsing,
/// helpers and `Debug` redaction.
#[cfg(test)]
mod config_validation_tests {
    use super::*;

    #[test]
    fn test_rejects_zero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_rejects_zero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_accepts_valid_config() {
        let cfg = WsConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(50),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(1024),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_to_all() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendToAll=yes").unwrap_err();
        assert!(err.to_string().contains("sendToAll"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_connections_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=abc").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_message_size_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=abc").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_heartbeat_interval_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?heartbeatIntervalMs=abc").unwrap_err();
        assert!(err.to_string().contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_idle_timeout_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?idleTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("idleTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_connect_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?connectTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("connectTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_response_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?responseTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("responseTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_send_timeout_ms() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=7500").unwrap();
        assert_eq!(cfg.send_timeout, Duration::from_millis(7500));
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_timeout_ms() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=xyz").unwrap_err();
        assert!(err.to_string().contains("sendTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_binary_payload_true() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=true").unwrap();
        assert!(cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_parses_binary_payload_false() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=false").unwrap();
        assert!(!cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_rejects_invalid_binary_payload() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=sure").unwrap_err();
        assert!(err.to_string().contains("binaryPayload"));
    }

    #[test]
    fn test_from_uri_parses_subprotocols() {
        let cfg =
            WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=json,protobuf").unwrap();
        assert_eq!(cfg.subprotocols, vec!["json", "protobuf"]);
    }

    #[test]
    fn test_from_uri_subprotocols_trims_whitespace() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=a, b").unwrap();
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_empty_when_not_specified() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert!(cfg.subprotocols.is_empty());
    }

    #[test]
    fn test_from_uri_subprotocols_skips_empty_entries() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=a,,b").unwrap();
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_parses_host_port_and_path() {
        let cfg = WsEndpointConfig::from_uri("ws://example.com:9000/chat/room").unwrap();
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "example.com");
        assert_eq!(cfg.port, 9000);
        assert_eq!(cfg.path, "/chat/room");
    }

    #[test]
    fn test_from_uri_defaults_port_from_scheme() {
        let wss = WsEndpointConfig::from_uri("wss://example.com").unwrap();
        assert_eq!(wss.port, 443);
        assert_eq!(wss.path, "/");

        let ws = WsEndpointConfig::from_uri("ws://example.com/feed").unwrap();
        assert_eq!(ws.port, 80);
        assert_eq!(ws.path, "/feed");
    }

    #[test]
    fn test_from_uri_rejects_non_websocket_scheme() {
        assert!(WsEndpointConfig::from_uri("http://localhost:8080").is_err());
    }

    #[test]
    fn test_from_uri_rejects_zero_max_connections() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=0").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_zero_max_message_size() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=0").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_parses_reconnect_settings() {
        let cfg = WsEndpointConfig::from_uri(
            "ws://localhost:8080?reconnect=false&reconnectMaxAttempts=3&reconnectDelayMs=250",
        )
        .unwrap();
        assert!(!cfg.reconnect);
        assert_eq!(cfg.reconnect_max_attempts, 3);
        assert_eq!(cfg.reconnect_delay_ms, 250);
    }

    #[test]
    fn test_canonical_host_maps_wildcard_and_localhost() {
        let wildcard = WsEndpointConfig::default();
        assert_eq!(wildcard.canonical_host(), "127.0.0.1");

        let local = WsEndpointConfig {
            host: "localhost".to_string(),
            ..WsEndpointConfig::default()
        };
        assert_eq!(local.canonical_host(), "127.0.0.1");

        let remote = WsEndpointConfig {
            host: "example.com".to_string(),
            ..WsEndpointConfig::default()
        };
        assert_eq!(remote.canonical_host(), "example.com");
    }

    #[test]
    fn test_server_and_client_config_copy_endpoint() {
        let cfg = WsEndpointConfig {
            port: 9100,
            ..WsEndpointConfig::default()
        };
        assert_eq!(cfg.server_config().inner.port, 9100);
        assert_eq!(cfg.client_config().inner.port, 9100);
    }

    #[test]
    fn ws_endpoint_config_debug_redacts_tls() {
        let config = WsEndpointConfig {
            tls_cert: Some("/secret/cert.pem".to_string()),
            tls_key: Some("/secret/key.pem".to_string()),
            ..WsEndpointConfig::default()
        };
        let debug = format!("{:?}", config);
        assert!(
            !debug.contains("/secret/"),
            "TLS paths must be redacted: {debug}"
        );
        assert!(
            debug.contains("tls_cert"),
            "field name should appear: {debug}"
        );
        assert!(
            debug.contains("tls_key"),
            "field name should appear: {debug}"
        );
    }

    #[test]
    fn ws_endpoint_config_debug_marks_tls_presence() {
        let only_cert = WsEndpointConfig {
            tls_cert: Some("/secret/cert.pem".to_string()),
            ..WsEndpointConfig::default()
        };
        let debug = format!("{:?}", only_cert);
        assert!(debug.contains(r#"tls_cert: Some("***")"#), "{debug}");
        assert!(debug.contains("tls_key: None"), "{debug}");
    }
}