//! WebSocket component configuration (`camel_component_ws/config.rs`).

use std::time::Duration;

use camel_component_api::CamelError;

5#[derive(Debug, Clone, Default, serde::Deserialize)]
6pub struct WsConfig {
7    pub max_connections: Option<u32>,
8    pub max_message_size: Option<u32>,
9    pub heartbeat_interval_ms: Option<u64>,
10    pub idle_timeout_ms: Option<u64>,
11    pub connect_timeout_ms: Option<u64>,
12    pub response_timeout_ms: Option<u64>,
13}
14
/// Fully resolved configuration of a single WebSocket endpoint.
///
/// Unlike [`WsConfig`], every setting carries a concrete value; see the
/// [`Default`] implementation for the values used when nothing is specified.
#[derive(Debug, Clone)]
pub struct WsEndpointConfig {
    /// URI scheme (`from_uri` only accepts `ws` and `wss`).
    pub scheme: String,
    /// Host name or address.
    pub host: String,
    /// TCP port.
    pub port: u16,
    /// Resource path, always starting with `/` when built by `from_uri`.
    pub path: String,
    /// Connection limit (`maxConnections` URI parameter).
    pub max_connections: u32,
    /// Message size limit (`maxMessageSize` URI parameter).
    pub max_message_size: u32,
    /// `sendToAll` URI parameter.
    pub send_to_all: bool,
    /// Heartbeat interval (`heartbeatIntervalMs` URI parameter).
    // NOTE(review): the default is `Duration::ZERO`; presumably zero means
    // "disabled" — confirm against the code that consumes this value.
    pub heartbeat_interval: Duration,
    /// Idle timeout (`idleTimeoutMs` URI parameter).
    // NOTE(review): default is `Duration::ZERO`; presumably "no timeout" — confirm.
    pub idle_timeout: Duration,
    /// Connect timeout (`connectTimeoutMs` URI parameter).
    pub connect_timeout: Duration,
    /// Response timeout (`responseTimeoutMs` URI parameter).
    pub response_timeout: Duration,
    /// `allowOrigin` URI parameter.
    pub allow_origin: String,
    /// `tlsCert` URI parameter, if given.
    pub tls_cert: Option<String>,
    /// `tlsKey` URI parameter, if given.
    pub tls_key: Option<String>,
}

impl Default for WsEndpointConfig {
    /// Returns the baseline configuration: `ws://0.0.0.0:8080/`, 100
    /// connections, 65536-byte messages, no broadcast, zero heartbeat and idle
    /// durations, a 10 s connect timeout, a 30 s response timeout, origin `*`
    /// and no TLS material.
    fn default() -> Self {
        Self {
            scheme: "ws".into(),
            host: "0.0.0.0".into(),
            port: 8080,
            path: "/".into(),
            max_connections: 100,
            max_message_size: 65_536,
            send_to_all: false,
            heartbeat_interval: Duration::ZERO,
            idle_timeout: Duration::ZERO,
            connect_timeout: Duration::from_secs(10),
            response_timeout: Duration::from_secs(30),
            allow_origin: "*".into(),
            tls_cert: None,
            tls_key: None,
        }
    }
}

54#[derive(Debug, Clone)]
55pub struct WsServerConfig {
56    pub inner: WsEndpointConfig,
57}
58
59#[derive(Debug, Clone)]
60pub struct WsClientConfig {
61    pub inner: WsEndpointConfig,
62}
63
64impl WsEndpointConfig {
65    pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
66        let parsed = camel_component_api::parse_uri(uri)
67            .map_err(|e| CamelError::EndpointCreationFailed(e.to_string()))?;
68
69        let scheme = parsed.scheme;
70        if scheme != "ws" && scheme != "wss" {
71            return Err(CamelError::EndpointCreationFailed(format!(
72                "Invalid WebSocket scheme: {scheme}"
73            )));
74        }
75
76        let host_port_path = parsed.path;
77        let host_port_path = host_port_path.strip_prefix("//").unwrap_or(&host_port_path);
78        let (host_port, path) = match host_port_path.split_once('/') {
79            Some((hp, p)) => (hp, format!("/{p}")),
80            None => (host_port_path, "/".to_string()),
81        };
82
83        let (host, port) = match host_port.rsplit_once(':') {
84            Some((h, p)) if p.parse::<u16>().is_ok() => {
85                let parsed_port = p.parse::<u16>().unwrap();
86                (h.to_string(), parsed_port)
87            }
88            _ => (
89                host_port.to_string(),
90                if scheme == "wss" { 443 } else { 80 },
91            ),
92        };
93
94        let mut cfg = Self {
95            scheme,
96            host: if host.is_empty() {
97                "0.0.0.0".to_string()
98            } else {
99                host
100            },
101            port,
102            path,
103            ..Self::default()
104        };
105
106        let params = parsed.params;
107        if let Some(v) = params
108            .get("maxConnections")
109            .and_then(|v| v.parse::<u32>().ok())
110        {
111            cfg.max_connections = v;
112        }
113        if let Some(v) = params
114            .get("maxMessageSize")
115            .and_then(|v| v.parse::<u32>().ok())
116        {
117            cfg.max_message_size = v;
118        }
119        if let Some(v) = params.get("sendToAll").and_then(|v| v.parse::<bool>().ok()) {
120            cfg.send_to_all = v;
121        }
122        if let Some(v) = params
123            .get("heartbeatIntervalMs")
124            .and_then(|v| v.parse::<u64>().ok())
125        {
126            cfg.heartbeat_interval = Duration::from_millis(v);
127        }
128        if let Some(v) = params
129            .get("idleTimeoutMs")
130            .and_then(|v| v.parse::<u64>().ok())
131        {
132            cfg.idle_timeout = Duration::from_millis(v);
133        }
134        if let Some(v) = params
135            .get("connectTimeoutMs")
136            .and_then(|v| v.parse::<u64>().ok())
137        {
138            cfg.connect_timeout = Duration::from_millis(v);
139        }
140        if let Some(v) = params
141            .get("responseTimeoutMs")
142            .and_then(|v| v.parse::<u64>().ok())
143        {
144            cfg.response_timeout = Duration::from_millis(v);
145        }
146        if let Some(v) = params.get("allowOrigin") {
147            cfg.allow_origin = v.to_string();
148        }
149        if let Some(v) = params.get("tlsCert") {
150            cfg.tls_cert = Some(v.to_string());
151        }
152        if let Some(v) = params.get("tlsKey") {
153            cfg.tls_key = Some(v.to_string());
154        }
155
156        Ok(cfg)
157    }
158
159    pub fn server_config(&self) -> WsServerConfig {
160        WsServerConfig {
161            inner: self.clone(),
162        }
163    }
164
165    pub fn client_config(&self) -> WsClientConfig {
166        WsClientConfig {
167            inner: self.clone(),
168        }
169    }
170
171    pub fn canonical_host(&self) -> String {
172        match self.host.as_str() {
173            "0.0.0.0" | "localhost" => "127.0.0.1".to_string(),
174            h => h.to_string(),
175        }
176    }
177}