Skip to main content

camel_component_ws/
config.rs

1use std::time::Duration;
2
3use camel_component_api::CamelError;
4
/// Component-level WebSocket settings.
///
/// Every field is optional; `None` means "not configured" and is accepted by
/// [`WsConfig::validate`]. The struct derives `serde::Deserialize`, so it can be loaded from
/// any serde-supported source.
///
/// NOTE(review): field names here are snake_case while endpoint URI options are camelCase
/// (e.g. `maxConnections`); confirm how this struct is populated and merged into
/// `WsEndpointConfig`.
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct WsConfig {
    /// Connection limit; an explicit `Some(0)` is rejected by `validate`.
    pub max_connections: Option<u32>,
    /// Message-size limit; an explicit `Some(0)` is rejected by `validate`.
    /// Units presumably bytes — TODO confirm.
    pub max_message_size: Option<u32>,
    /// Heartbeat interval in milliseconds (per the `_ms` suffix).
    pub heartbeat_interval_ms: Option<u64>,
    /// Idle timeout in milliseconds (per the `_ms` suffix).
    pub idle_timeout_ms: Option<u64>,
    /// Connect timeout in milliseconds (per the `_ms` suffix).
    pub connect_timeout_ms: Option<u64>,
    /// Response timeout in milliseconds (per the `_ms` suffix).
    pub response_timeout_ms: Option<u64>,
    /// Send timeout in milliseconds (per the `_ms` suffix).
    pub send_timeout_ms: Option<u64>,
    /// Binary-payload flag; presumably selects binary vs. text frames — confirm.
    pub binary_payload: Option<bool>,
    /// WebSocket subprotocol names.
    pub subprotocols: Option<Vec<String>>,
}
17
/// Fully-resolved settings for one WebSocket endpoint.
///
/// Normally produced by `WsEndpointConfig::from_uri`, which fills the URI-derived fields and
/// takes every option that is absent from the URI from [`Default`].
#[derive(Debug, Clone)]
pub struct WsEndpointConfig {
    /// URI scheme; `from_uri` only accepts `"ws"` or `"wss"`.
    pub scheme: String,
    /// Host part of the URI; `from_uri` substitutes `0.0.0.0` when it is empty.
    pub host: String,
    /// Port; `from_uri` uses 80 (`ws`) or 443 (`wss`) when the URI has none.
    pub port: u16,
    /// Request path; `from_uri` always produces a leading `/`.
    pub path: String,
    /// `maxConnections` URI option (`from_uri` requires >= 1).
    pub max_connections: u32,
    /// `maxMessageSize` URI option (`from_uri` requires > 0). Units presumably bytes — TODO confirm.
    pub max_message_size: u32,
    /// `sendToAll` URI option. NOTE(review): presumably broadcast to every peer — confirm.
    pub send_to_all: bool,
    /// `heartbeatIntervalMs` URI option. Zero by default (presumably "disabled" — confirm).
    pub heartbeat_interval: Duration,
    /// `idleTimeoutMs` URI option. Zero by default (presumably "disabled" — confirm).
    pub idle_timeout: Duration,
    /// `connectTimeoutMs` URI option.
    pub connect_timeout: Duration,
    /// `responseTimeoutMs` URI option.
    pub response_timeout: Duration,
    /// `allowOrigin` URI option; `from_uri` rejects an empty value.
    pub allow_origin: String,
    /// `tlsCert` URI option, stored verbatim (presumably a file path — confirm).
    pub tls_cert: Option<String>,
    /// `tlsKey` URI option, stored verbatim (presumably a file path — confirm).
    pub tls_key: Option<String>,
    /// `reconnect` URI option.
    pub reconnect: bool,
    /// `reconnectMaxAttempts` URI option.
    pub reconnect_max_attempts: u32,
    /// `reconnectDelayMs` URI option, kept as raw milliseconds (unlike the `Duration` fields).
    pub reconnect_delay_ms: u64,
    /// `sendTimeoutMs` URI option.
    pub send_timeout: Duration,
    /// `binaryPayload` URI option.
    pub binary_payload: bool,
    /// `subprotocols` URI option: comma-separated, trimmed, empty entries dropped.
    pub subprotocols: Vec<String>,
}

impl Default for WsEndpointConfig {
    /// Baseline endpoint: plain `ws` on `0.0.0.0:8080` at `/`, 100 connections, 64 KiB
    /// messages, reconnect enabled (5 attempts, 1000 ms apart), 10 s connect timeout and
    /// 30 s response/send timeouts.
    ///
    /// `from_uri` always overwrites `scheme`, `host`, `port` and `path`, so those four
    /// defaults only matter when the struct is built directly.
    fn default() -> Self {
        // Shared by the response and send timeouts.
        let thirty_seconds = Duration::from_secs(30);

        WsEndpointConfig {
            scheme: String::from("ws"),
            host: String::from("0.0.0.0"),
            port: 8080,
            path: String::from("/"),
            max_connections: 100,
            max_message_size: 64 * 1024,
            send_to_all: false,
            heartbeat_interval: Duration::from_secs(0),
            idle_timeout: Duration::from_secs(0),
            connect_timeout: Duration::from_secs(10),
            response_timeout: thirty_seconds,
            allow_origin: String::from("*"),
            tls_cert: None,
            tls_key: None,
            reconnect: true,
            reconnect_max_attempts: 5,
            reconnect_delay_ms: 1_000,
            send_timeout: thirty_seconds,
            binary_payload: false,
            subprotocols: vec![],
        }
    }
}
68
/// Server-side wrapper around an endpoint configuration.
///
/// Produced by `WsEndpointConfig::server_config`, which stores a clone of the endpoint config.
#[derive(Debug, Clone)]
pub struct WsServerConfig {
    /// The complete endpoint configuration this wrapper carries.
    pub inner: WsEndpointConfig,
}

/// Client-side wrapper around an endpoint configuration.
///
/// Produced by `WsEndpointConfig::client_config`, which stores a clone of the endpoint config.
#[derive(Debug, Clone)]
pub struct WsClientConfig {
    /// The complete endpoint configuration this wrapper carries.
    pub inner: WsEndpointConfig,
}
78
79impl WsConfig {
80    /// Validate configuration values.
81    ///
82    /// Returns an error if any explicitly-set value is invalid (e.g. zero).
83    /// `None` values are valid — they mean "use the default / unlimited".
84    pub fn validate(&self) -> Result<(), CamelError> {
85        if let Some(0) = self.max_connections {
86            return Err(CamelError::Config(
87                "maxConnections must be >= 1 when specified".into(),
88            ));
89        }
90        if let Some(0) = self.max_message_size {
91            return Err(CamelError::Config(
92                "maxMessageSize must be >= 1 when specified".into(),
93            ));
94        }
95        Ok(())
96    }
97}
98
99impl WsEndpointConfig {
100    pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
101        let parsed = camel_component_api::parse_uri(uri)
102            .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
103
104        let scheme = parsed.scheme;
105        if scheme != "ws" && scheme != "wss" {
106            return Err(CamelError::EndpointCreationFailed(format!(
107                "Invalid WebSocket scheme: {scheme}"
108            )));
109        }
110
111        let host_port_path = parsed.path;
112        let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
113        let (host_port, path) = match host_port_path.split_once('/') {
114            Some((hp, p)) => (hp, format!("/{p}")),
115            None => (host_port_path, "/".to_string()),
116        };
117
118        let (host, port) = match host_port.rsplit_once(':') {
119            Some((h, p)) if p.parse::<u16>().is_ok() => {
120                let parsed_port = p.parse::<u16>().unwrap(); // allow-unwrap
121                (h.to_string(), parsed_port)
122            }
123            _ => (
124                host_port.to_string(),
125                if scheme == "wss" { 443 } else { 80 },
126            ),
127        };
128
129        let mut cfg = Self {
130            scheme,
131            host: if host.is_empty() {
132                "0.0.0.0".to_string()
133            } else {
134                host
135            },
136            port,
137            path,
138            ..Self::default()
139        };
140
141        let params = parsed.params;
142        // Validate maxConnections >= 1 (WS-015)
143        if let Some(raw) = params.get("maxConnections") {
144            let v = raw.parse::<u32>().map_err(|_| {
145                CamelError::InvalidUri(format!(
146                    "maxConnections must be an unsigned integer, got '{raw}'"
147                ))
148            })?;
149            if v == 0 {
150                return Err(CamelError::InvalidUri("maxConnections must be >= 1".into()));
151            }
152            cfg.max_connections = v;
153        }
154        // Validate maxMessageSize > 0 (WS-019)
155        if let Some(raw) = params.get("maxMessageSize") {
156            let v = raw.parse::<u32>().map_err(|_| {
157                CamelError::InvalidUri(format!(
158                    "maxMessageSize must be an unsigned integer, got '{raw}'"
159                ))
160            })?;
161            if v == 0 {
162                return Err(CamelError::InvalidUri("maxMessageSize must be > 0".into()));
163            }
164            cfg.max_message_size = v;
165        }
166        if let Some(raw) = params.get("sendToAll") {
167            let v = raw.parse::<bool>().map_err(|_| {
168                CamelError::InvalidUri(format!(
169                    "sendToAll must be a boolean ('true' or 'false'), got '{raw}'"
170                ))
171            })?;
172            cfg.send_to_all = v;
173        }
174        if let Some(raw) = params.get("heartbeatIntervalMs") {
175            let v = raw.parse::<u64>().map_err(|_| {
176                CamelError::InvalidUri(format!(
177                    "heartbeatIntervalMs must be an unsigned integer, got '{raw}'"
178                ))
179            })?;
180            cfg.heartbeat_interval = Duration::from_millis(v);
181        }
182        if let Some(raw) = params.get("idleTimeoutMs") {
183            let v = raw.parse::<u64>().map_err(|_| {
184                CamelError::InvalidUri(format!(
185                    "idleTimeoutMs must be an unsigned integer, got '{raw}'"
186                ))
187            })?;
188            cfg.idle_timeout = Duration::from_millis(v);
189        }
190        if let Some(raw) = params.get("connectTimeoutMs") {
191            let v = raw.parse::<u64>().map_err(|_| {
192                CamelError::InvalidUri(format!(
193                    "connectTimeoutMs must be an unsigned integer, got '{raw}'"
194                ))
195            })?;
196            cfg.connect_timeout = Duration::from_millis(v);
197        }
198        if let Some(raw) = params.get("responseTimeoutMs") {
199            let v = raw.parse::<u64>().map_err(|_| {
200                CamelError::InvalidUri(format!(
201                    "responseTimeoutMs must be an unsigned integer, got '{raw}'"
202                ))
203            })?;
204            cfg.response_timeout = Duration::from_millis(v);
205        }
206        if let Some(v) = params.get("allowOrigin") {
207            if v.is_empty() {
208                return Err(CamelError::InvalidUri(
209                    "allowOrigin must not be empty when specified".into(),
210                ));
211            }
212            cfg.allow_origin = v.to_string();
213        }
214        if let Some(v) = params.get("tlsCert") {
215            cfg.tls_cert = Some(v.to_string());
216        }
217        if let Some(v) = params.get("tlsKey") {
218            cfg.tls_key = Some(v.to_string());
219        }
220        if let Some(raw) = params.get("reconnect") {
221            cfg.reconnect = raw.parse::<bool>().map_err(|_| {
222                CamelError::InvalidUri(format!(
223                    "reconnect must be a boolean ('true' or 'false'), got '{raw}'"
224                ))
225            })?;
226        }
227        if let Some(raw) = params.get("reconnectMaxAttempts") {
228            cfg.reconnect_max_attempts = raw.parse::<u32>().map_err(|_| {
229                CamelError::InvalidUri(format!(
230                    "reconnectMaxAttempts must be an unsigned integer, got '{raw}'"
231                ))
232            })?;
233        }
234        if let Some(raw) = params.get("reconnectDelayMs") {
235            cfg.reconnect_delay_ms = raw.parse::<u64>().map_err(|_| {
236                CamelError::InvalidUri(format!(
237                    "reconnectDelayMs must be an unsigned integer, got '{raw}'"
238                ))
239            })?;
240        }
241        if let Some(raw) = params.get("sendTimeoutMs") {
242            let v = raw.parse::<u64>().map_err(|_| {
243                CamelError::InvalidUri(format!(
244                    "sendTimeoutMs must be an unsigned integer, got '{raw}'"
245                ))
246            })?;
247            cfg.send_timeout = Duration::from_millis(v);
248        }
249        if let Some(raw) = params.get("binaryPayload") {
250            cfg.binary_payload = raw.parse::<bool>().map_err(|_| {
251                CamelError::InvalidUri(format!(
252                    "binaryPayload must be a boolean ('true' or 'false'), got '{raw}'"
253                ))
254            })?;
255        }
256        if let Some(raw) = params.get("subprotocols") {
257            cfg.subprotocols = raw
258                .split(',')
259                .map(|s| s.trim().to_string())
260                .filter(|s| !s.is_empty())
261                .collect();
262        }
263
264        Ok(cfg)
265    }
266
267    pub fn server_config(&self) -> WsServerConfig {
268        WsServerConfig {
269            inner: self.clone(),
270        }
271    }
272
273    pub fn client_config(&self) -> WsClientConfig {
274        WsClientConfig {
275            inner: self.clone(),
276        }
277    }
278
279    pub fn canonical_host(&self) -> String {
280        match self.host.as_str() {
281            "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
282            h => h.to_string(),
283        }
284    }
285}
286
#[cfg(test)]
mod config_validation_tests {
    use super::*;

    /// A default `WsConfig` with only the two validated limits overridden.
    fn limits(max_connections: Option<u32>, max_message_size: Option<u32>) -> WsConfig {
        WsConfig {
            max_connections,
            max_message_size,
            ..WsConfig::default()
        }
    }

    /// Parse `ws://localhost:8080?{query}`, expect failure, and return the error text.
    fn uri_error(query: &str) -> String {
        WsEndpointConfig::from_uri(&format!("ws://localhost:8080?{query}"))
            .unwrap_err()
            .to_string()
    }

    /// Parse `ws://localhost:8080?{query}` and expect success.
    fn uri_config(query: &str) -> WsEndpointConfig {
        WsEndpointConfig::from_uri(&format!("ws://localhost:8080?{query}")).unwrap()
    }

    // --- WsConfig::validate ---

    #[test]
    fn test_rejects_zero_max_connections() {
        assert!(limits(Some(0), None).validate().is_err());
    }

    #[test]
    fn test_rejects_zero_max_message_size() {
        assert!(limits(None, Some(0)).validate().is_err());
    }

    #[test]
    fn test_accepts_valid_config() {
        assert!(WsConfig::default().validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_connections() {
        assert!(limits(Some(50), None).validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_message_size() {
        assert!(limits(None, Some(1024)).validate().is_ok());
    }

    // --- WsEndpointConfig::from_uri: malformed option values name the option ---

    #[test]
    fn test_from_uri_rejects_invalid_send_to_all() {
        assert!(uri_error("sendToAll=yes").contains("sendToAll"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_connections_numeric() {
        assert!(uri_error("maxConnections=abc").contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_message_size_numeric() {
        assert!(uri_error("maxMessageSize=abc").contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_heartbeat_interval_numeric() {
        assert!(uri_error("heartbeatIntervalMs=abc").contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_idle_timeout_numeric() {
        assert!(uri_error("idleTimeoutMs=abc").contains("idleTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_connect_timeout_numeric() {
        assert!(uri_error("connectTimeoutMs=abc").contains("connectTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_response_timeout_numeric() {
        assert!(uri_error("responseTimeoutMs=abc").contains("responseTimeoutMs"));
    }

    // WS-017: sendTimeoutMs parsing
    #[test]
    fn test_from_uri_parses_send_timeout_ms() {
        let cfg = uri_config("sendTimeoutMs=7500");
        assert_eq!(cfg.send_timeout, Duration::from_millis(7500));
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_timeout_ms() {
        assert!(uri_error("sendTimeoutMs=xyz").contains("sendTimeoutMs"));
    }

    // WS-018: binaryPayload parsing
    #[test]
    fn test_from_uri_parses_binary_payload_true() {
        assert!(uri_config("binaryPayload=true").binary_payload);
    }

    #[test]
    fn test_from_uri_parses_binary_payload_false() {
        assert!(!uri_config("binaryPayload=false").binary_payload);
    }

    #[test]
    fn test_from_uri_rejects_invalid_binary_payload() {
        assert!(uri_error("binaryPayload=sure").contains("binaryPayload"));
    }

    // WS-007: subprotocols parsing
    #[test]
    fn test_from_uri_parses_subprotocols() {
        let cfg = uri_config("subprotocols=json,protobuf");
        assert_eq!(cfg.subprotocols, vec!["json", "protobuf"]);
    }

    #[test]
    fn test_from_uri_subprotocols_trims_whitespace() {
        let cfg = uri_config("subprotocols=a, b");
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_empty_when_not_specified() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert!(cfg.subprotocols.is_empty());
    }
}
429}