1use std::time::Duration;
2
3use camel_component_api::CamelError;
4
/// Optional WebSocket settings that can be deserialized with serde.
///
/// Every field is an `Option`; `None` means "not specified". Call
/// [`WsConfig::validate`] after deserializing to reject an explicit `0` for
/// `max_connections` or `max_message_size`.
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct WsConfig {
    /// Connection limit; must be >= 1 when specified (checked by `validate`).
    pub max_connections: Option<u32>,
    /// Message size limit; must be >= 1 when specified (checked by `validate`).
    /// NOTE(review): unit presumably bytes — confirm with the consumer.
    pub max_message_size: Option<u32>,
    /// Heartbeat interval; the `_ms` suffix suggests milliseconds.
    pub heartbeat_interval_ms: Option<u64>,
    /// Idle timeout; the `_ms` suffix suggests milliseconds.
    pub idle_timeout_ms: Option<u64>,
    /// Connect timeout; the `_ms` suffix suggests milliseconds.
    pub connect_timeout_ms: Option<u64>,
    /// Response timeout; the `_ms` suffix suggests milliseconds.
    pub response_timeout_ms: Option<u64>,
    /// Send timeout; the `_ms` suffix suggests milliseconds.
    pub send_timeout_ms: Option<u64>,
    /// NOTE(review): presumably selects binary instead of text frames — confirm.
    pub binary_payload: Option<bool>,
    /// WebSocket subprotocol names.
    pub subprotocols: Option<Vec<String>>,
}
17
/// Fully resolved configuration for one WebSocket endpoint.
///
/// Usually built by `WsEndpointConfig::from_uri`, which starts from
/// `WsEndpointConfig::default()` and overrides fields from the URI and its
/// query parameters (parameter names are given on each field below).
#[derive(Debug, Clone)]
pub struct WsEndpointConfig {
    /// `ws` or `wss`; `from_uri` rejects any other scheme.
    pub scheme: String,
    /// Host name or address; `from_uri` substitutes `0.0.0.0` for an empty host.
    pub host: String,
    /// Port. `Default` uses 8080; `from_uri` falls back to 80 (`ws`) or
    /// 443 (`wss`) when the URI carries no port.
    pub port: u16,
    /// Request path; `from_uri` always produces a value starting with `/`.
    pub path: String,
    /// `maxConnections` (default 100); `from_uri` rejects 0.
    pub max_connections: u32,
    /// `maxMessageSize` (default 65536); `from_uri` rejects 0.
    /// NOTE(review): unit presumably bytes — confirm.
    pub max_message_size: u32,
    /// `sendToAll` (default false).
    /// NOTE(review): presumably broadcasts to every connected peer — confirm.
    pub send_to_all: bool,
    /// `heartbeatIntervalMs` (default zero).
    /// NOTE(review): zero presumably disables heartbeats — confirm.
    pub heartbeat_interval: Duration,
    /// `idleTimeoutMs` (default zero).
    /// NOTE(review): zero presumably disables the timeout — confirm.
    pub idle_timeout: Duration,
    /// `connectTimeoutMs` (default 10 s).
    pub connect_timeout: Duration,
    /// `responseTimeoutMs` (default 30 s).
    pub response_timeout: Duration,
    /// `allowOrigin` (default `*`); `from_uri` rejects an empty value.
    pub allow_origin: String,
    /// `tlsCert`. NOTE(review): presumably a certificate file path — confirm.
    pub tls_cert: Option<String>,
    /// `tlsKey`. NOTE(review): presumably a private-key file path — confirm.
    pub tls_key: Option<String>,
    /// `reconnect` (default true).
    pub reconnect: bool,
    /// `reconnectMaxAttempts` (default 5).
    pub reconnect_max_attempts: u32,
    /// `reconnectDelayMs`, kept as raw milliseconds (default 1000).
    pub reconnect_delay_ms: u64,
    /// `sendTimeoutMs` (default 30 s).
    pub send_timeout: Duration,
    /// `binaryPayload` (default false).
    pub binary_payload: bool,
    /// `subprotocols`: comma-separated in the URI, entries trimmed and empty
    /// entries dropped.
    pub subprotocols: Vec<String>,
}
41
42impl Default for WsEndpointConfig {
43 fn default() -> Self {
44 Self {
45 scheme: "ws".into(),
46 host: "0.0.0.0".into(),
47 port: 8080,
48 path: "/".into(),
49 max_connections: 100,
50 max_message_size: 65536,
51 send_to_all: false,
52 heartbeat_interval: Duration::ZERO,
53 idle_timeout: Duration::ZERO,
54 connect_timeout: Duration::from_secs(10),
55 response_timeout: Duration::from_secs(30),
56 allow_origin: "*".into(),
57 tls_cert: None,
58 tls_key: None,
59 reconnect: true,
60 reconnect_max_attempts: 5,
61 reconnect_delay_ms: 1000,
62 send_timeout: Duration::from_secs(30),
63 binary_payload: false,
64 subprotocols: Vec::new(),
65 }
66 }
67}
68
/// Server-side view of an endpoint configuration.
///
/// Produced by `WsEndpointConfig::server_config`, which wraps a clone of the
/// endpoint configuration.
#[derive(Debug, Clone)]
pub struct WsServerConfig {
    /// The full endpoint configuration.
    pub inner: WsEndpointConfig,
}
73
/// Client-side view of an endpoint configuration.
///
/// Produced by `WsEndpointConfig::client_config`, which wraps a clone of the
/// endpoint configuration.
#[derive(Debug, Clone)]
pub struct WsClientConfig {
    /// The full endpoint configuration.
    pub inner: WsEndpointConfig,
}
78
79impl WsConfig {
80 pub fn validate(&self) -> Result<(), CamelError> {
85 if let Some(0) = self.max_connections {
86 return Err(CamelError::Config(
87 "maxConnections must be >= 1 when specified".into(),
88 ));
89 }
90 if let Some(0) = self.max_message_size {
91 return Err(CamelError::Config(
92 "maxMessageSize must be >= 1 when specified".into(),
93 ));
94 }
95 Ok(())
96 }
97}
98
99impl WsEndpointConfig {
100 pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
101 let parsed = camel_component_api::parse_uri(uri)
102 .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
103
104 let scheme = parsed.scheme;
105 if scheme != "ws" && scheme != "wss" {
106 return Err(CamelError::EndpointCreationFailed(format!(
107 "Invalid WebSocket scheme: {scheme}"
108 )));
109 }
110
111 let host_port_path = parsed.path;
112 let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
113 let (host_port, path) = match host_port_path.split_once('/') {
114 Some((hp, p)) => (hp, format!("/{p}")),
115 None => (host_port_path, "/".to_string()),
116 };
117
118 let (host, port) = match host_port.rsplit_once(':') {
119 Some((h, p)) if p.parse::<u16>().is_ok() => {
120 let parsed_port = p.parse::<u16>().unwrap(); (h.to_string(), parsed_port)
122 }
123 _ => (
124 host_port.to_string(),
125 if scheme == "wss" { 443 } else { 80 },
126 ),
127 };
128
129 let mut cfg = Self {
130 scheme,
131 host: if host.is_empty() {
132 "0.0.0.0".to_string()
133 } else {
134 host
135 },
136 port,
137 path,
138 ..Self::default()
139 };
140
141 let params = parsed.params;
142 if let Some(raw) = params.get("maxConnections") {
144 let v = raw.parse::<u32>().map_err(|_| {
145 CamelError::InvalidUri(format!(
146 "maxConnections must be an unsigned integer, got '{raw}'"
147 ))
148 })?;
149 if v == 0 {
150 return Err(CamelError::InvalidUri("maxConnections must be >= 1".into()));
151 }
152 cfg.max_connections = v;
153 }
154 if let Some(raw) = params.get("maxMessageSize") {
156 let v = raw.parse::<u32>().map_err(|_| {
157 CamelError::InvalidUri(format!(
158 "maxMessageSize must be an unsigned integer, got '{raw}'"
159 ))
160 })?;
161 if v == 0 {
162 return Err(CamelError::InvalidUri("maxMessageSize must be > 0".into()));
163 }
164 cfg.max_message_size = v;
165 }
166 if let Some(raw) = params.get("sendToAll") {
167 let v = raw.parse::<bool>().map_err(|_| {
168 CamelError::InvalidUri(format!(
169 "sendToAll must be a boolean ('true' or 'false'), got '{raw}'"
170 ))
171 })?;
172 cfg.send_to_all = v;
173 }
174 if let Some(raw) = params.get("heartbeatIntervalMs") {
175 let v = raw.parse::<u64>().map_err(|_| {
176 CamelError::InvalidUri(format!(
177 "heartbeatIntervalMs must be an unsigned integer, got '{raw}'"
178 ))
179 })?;
180 cfg.heartbeat_interval = Duration::from_millis(v);
181 }
182 if let Some(raw) = params.get("idleTimeoutMs") {
183 let v = raw.parse::<u64>().map_err(|_| {
184 CamelError::InvalidUri(format!(
185 "idleTimeoutMs must be an unsigned integer, got '{raw}'"
186 ))
187 })?;
188 cfg.idle_timeout = Duration::from_millis(v);
189 }
190 if let Some(raw) = params.get("connectTimeoutMs") {
191 let v = raw.parse::<u64>().map_err(|_| {
192 CamelError::InvalidUri(format!(
193 "connectTimeoutMs must be an unsigned integer, got '{raw}'"
194 ))
195 })?;
196 cfg.connect_timeout = Duration::from_millis(v);
197 }
198 if let Some(raw) = params.get("responseTimeoutMs") {
199 let v = raw.parse::<u64>().map_err(|_| {
200 CamelError::InvalidUri(format!(
201 "responseTimeoutMs must be an unsigned integer, got '{raw}'"
202 ))
203 })?;
204 cfg.response_timeout = Duration::from_millis(v);
205 }
206 if let Some(v) = params.get("allowOrigin") {
207 if v.is_empty() {
208 return Err(CamelError::InvalidUri(
209 "allowOrigin must not be empty when specified".into(),
210 ));
211 }
212 cfg.allow_origin = v.to_string();
213 }
214 if let Some(v) = params.get("tlsCert") {
215 cfg.tls_cert = Some(v.to_string());
216 }
217 if let Some(v) = params.get("tlsKey") {
218 cfg.tls_key = Some(v.to_string());
219 }
220 if let Some(raw) = params.get("reconnect") {
221 cfg.reconnect = raw.parse::<bool>().map_err(|_| {
222 CamelError::InvalidUri(format!(
223 "reconnect must be a boolean ('true' or 'false'), got '{raw}'"
224 ))
225 })?;
226 }
227 if let Some(raw) = params.get("reconnectMaxAttempts") {
228 cfg.reconnect_max_attempts = raw.parse::<u32>().map_err(|_| {
229 CamelError::InvalidUri(format!(
230 "reconnectMaxAttempts must be an unsigned integer, got '{raw}'"
231 ))
232 })?;
233 }
234 if let Some(raw) = params.get("reconnectDelayMs") {
235 cfg.reconnect_delay_ms = raw.parse::<u64>().map_err(|_| {
236 CamelError::InvalidUri(format!(
237 "reconnectDelayMs must be an unsigned integer, got '{raw}'"
238 ))
239 })?;
240 }
241 if let Some(raw) = params.get("sendTimeoutMs") {
242 let v = raw.parse::<u64>().map_err(|_| {
243 CamelError::InvalidUri(format!(
244 "sendTimeoutMs must be an unsigned integer, got '{raw}'"
245 ))
246 })?;
247 cfg.send_timeout = Duration::from_millis(v);
248 }
249 if let Some(raw) = params.get("binaryPayload") {
250 cfg.binary_payload = raw.parse::<bool>().map_err(|_| {
251 CamelError::InvalidUri(format!(
252 "binaryPayload must be a boolean ('true' or 'false'), got '{raw}'"
253 ))
254 })?;
255 }
256 if let Some(raw) = params.get("subprotocols") {
257 cfg.subprotocols = raw
258 .split(',')
259 .map(|s| s.trim().to_string())
260 .filter(|s| !s.is_empty())
261 .collect();
262 }
263
264 Ok(cfg)
265 }
266
267 pub fn server_config(&self) -> WsServerConfig {
268 WsServerConfig {
269 inner: self.clone(),
270 }
271 }
272
273 pub fn client_config(&self) -> WsClientConfig {
274 WsClientConfig {
275 inner: self.clone(),
276 }
277 }
278
279 pub fn canonical_host(&self) -> String {
280 match self.host.as_str() {
281 "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
282 h => h.to_string(),
283 }
284 }
285}
286
#[cfg(test)]
// Unit tests for `WsConfig::validate` and `WsEndpointConfig::from_uri`.
mod config_validation_tests {
    use super::*;

    #[test]
    fn test_rejects_zero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_rejects_zero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(0),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_accepts_valid_config() {
        let cfg = WsConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_connections() {
        let cfg = WsConfig {
            max_connections: Some(50),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accepts_nonzero_max_message_size() {
        let cfg = WsConfig {
            max_message_size: Some(1024),
            ..WsConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_to_all() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendToAll=yes").unwrap_err();
        assert!(err.to_string().contains("sendToAll"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_connections_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=abc").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_max_message_size_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=abc").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_heartbeat_interval_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?heartbeatIntervalMs=abc").unwrap_err();
        assert!(err.to_string().contains("heartbeatIntervalMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_idle_timeout_numeric() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?idleTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("idleTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_connect_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?connectTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("connectTimeoutMs"));
    }

    #[test]
    fn test_from_uri_rejects_invalid_response_timeout_numeric() {
        let err =
            WsEndpointConfig::from_uri("ws://localhost:8080?responseTimeoutMs=abc").unwrap_err();
        assert!(err.to_string().contains("responseTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_send_timeout_ms() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=7500").unwrap();
        assert_eq!(cfg.send_timeout, Duration::from_millis(7500));
    }

    #[test]
    fn test_from_uri_rejects_invalid_send_timeout_ms() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?sendTimeoutMs=xyz").unwrap_err();
        assert!(err.to_string().contains("sendTimeoutMs"));
    }

    #[test]
    fn test_from_uri_parses_binary_payload_true() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=true").unwrap();
        assert!(cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_parses_binary_payload_false() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=false").unwrap();
        assert!(!cfg.binary_payload);
    }

    #[test]
    fn test_from_uri_rejects_invalid_binary_payload() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?binaryPayload=sure").unwrap_err();
        assert!(err.to_string().contains("binaryPayload"));
    }

    #[test]
    fn test_from_uri_parses_subprotocols() {
        let cfg =
            WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=json,protobuf").unwrap();
        assert_eq!(cfg.subprotocols, vec!["json", "protobuf"]);
    }

    #[test]
    fn test_from_uri_subprotocols_trims_whitespace() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080?subprotocols=a, b").unwrap();
        assert_eq!(cfg.subprotocols, vec!["a", "b"]);
    }

    #[test]
    fn test_from_uri_subprotocols_empty_when_not_specified() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert!(cfg.subprotocols.is_empty());
    }

    // Coverage for URI structure (scheme / host / port / path), which the
    // parameter tests above do not exercise.

    #[test]
    fn test_from_uri_rejects_non_websocket_scheme() {
        assert!(WsEndpointConfig::from_uri("http://localhost:8080").is_err());
    }

    #[test]
    fn test_from_uri_parses_host_port_and_path() {
        let cfg = WsEndpointConfig::from_uri("ws://example.com:9000/chat/room").unwrap();
        assert_eq!(cfg.scheme, "ws");
        assert_eq!(cfg.host, "example.com");
        assert_eq!(cfg.port, 9000);
        assert_eq!(cfg.path, "/chat/room");
    }

    #[test]
    fn test_from_uri_defaults_port_by_scheme() {
        assert_eq!(WsEndpointConfig::from_uri("ws://example.com").unwrap().port, 80);
        assert_eq!(WsEndpointConfig::from_uri("wss://example.com").unwrap().port, 443);
    }

    #[test]
    fn test_from_uri_defaults_path_to_root() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert_eq!(cfg.path, "/");
    }

    #[test]
    fn test_from_uri_rejects_zero_max_connections() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxConnections=0").unwrap_err();
        assert!(err.to_string().contains("maxConnections"));
    }

    #[test]
    fn test_from_uri_rejects_zero_max_message_size() {
        let err = WsEndpointConfig::from_uri("ws://localhost:8080?maxMessageSize=0").unwrap_err();
        assert!(err.to_string().contains("maxMessageSize"));
    }

    #[test]
    fn test_canonical_host_maps_wildcard_and_localhost_to_loopback() {
        let cfg = WsEndpointConfig::from_uri("ws://localhost:8080").unwrap();
        assert_eq!(cfg.canonical_host(), "127.0.0.1");
        let cfg = WsEndpointConfig::from_uri("ws://0.0.0.0:8080").unwrap();
        assert_eq!(cfg.canonical_host(), "127.0.0.1");
        let cfg = WsEndpointConfig::from_uri("ws://example.com:8080").unwrap();
        assert_eq!(cfg.canonical_host(), "example.com");
    }
}