Skip to main content

turbomcp_websocket/
config.rs

//! Configuration types for the WebSocket bidirectional transport.
//!
//! This module provides configuration structures for the WebSocket transport,
//! including connection settings, reconnection policies, and TLS configuration.
5
6use std::time::Duration;
7
8/// Configuration for WebSocket bidirectional transport
9#[derive(Clone, Debug)]
10pub struct WebSocketBidirectionalConfig {
11    /// WebSocket URL to connect to (client mode)
12    pub url: Option<String>,
13
14    /// Bind address for server mode
15    pub bind_addr: Option<String>,
16
17    /// Maximum message size (default: 16MB)
18    pub max_message_size: usize,
19
20    /// Keep-alive interval
21    pub keep_alive_interval: Duration,
22
23    /// Reconnection configuration
24    pub reconnect: ReconnectConfig,
25
26    /// Elicitation timeout
27    pub elicitation_timeout: Duration,
28
29    /// Maximum concurrent elicitations
30    pub max_concurrent_elicitations: usize,
31
32    /// Enable compression
33    pub enable_compression: bool,
34
35    /// TLS configuration
36    pub tls_config: Option<TlsConfig>,
37}
38
39impl Default for WebSocketBidirectionalConfig {
40    fn default() -> Self {
41        Self {
42            url: None,
43            bind_addr: None,
44            max_message_size: 16 * 1024 * 1024, // 16MB
45            keep_alive_interval: Duration::from_secs(30),
46            reconnect: ReconnectConfig::default(),
47            elicitation_timeout: Duration::from_secs(30),
48            max_concurrent_elicitations: 10,
49            enable_compression: false,
50            tls_config: None,
51        }
52    }
53}
54
55impl WebSocketBidirectionalConfig {
56    /// Create a new configuration with default values
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Create client configuration with URL
62    pub fn client(url: String) -> Self {
63        Self {
64            url: Some(url),
65            ..Self::default()
66        }
67    }
68
69    /// Create server configuration with bind address
70    pub fn server(bind_addr: String) -> Self {
71        Self {
72            bind_addr: Some(bind_addr),
73            ..Self::default()
74        }
75    }
76
77    /// Set maximum message size
78    pub fn with_max_message_size(mut self, size: usize) -> Self {
79        self.max_message_size = size;
80        self
81    }
82
83    /// Set keep-alive interval
84    pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
85        self.keep_alive_interval = interval;
86        self
87    }
88
89    /// Set reconnection configuration
90    pub fn with_reconnect_config(mut self, config: ReconnectConfig) -> Self {
91        self.reconnect = config;
92        self
93    }
94
95    /// Set elicitation timeout
96    pub fn with_elicitation_timeout(mut self, timeout: Duration) -> Self {
97        self.elicitation_timeout = timeout;
98        self
99    }
100
101    /// Set maximum concurrent elicitations
102    pub fn with_max_concurrent_elicitations(mut self, max: usize) -> Self {
103        self.max_concurrent_elicitations = max;
104        self
105    }
106
107    /// Enable compression
108    pub fn with_compression(mut self, enable: bool) -> Self {
109        self.enable_compression = enable;
110        self
111    }
112
113    /// Set TLS configuration
114    pub fn with_tls_config(mut self, tls_config: TlsConfig) -> Self {
115        self.tls_config = Some(tls_config);
116        self
117    }
118}
119
/// Reconnection configuration.
///
/// The fields describe an exponential back-off schedule: an initial delay, a
/// growth factor, an upper bound on the delay, and a retry limit. This type only
/// stores the policy; nothing in it computes delays.
///
/// Defaults (see the [`Default`] implementation): enabled, 500 ms initial delay,
/// 30 s maximum delay, back-off factor `2.0`, at most 10 retries.
#[derive(Clone, Debug, PartialEq)]
pub struct ReconnectConfig {
    /// Enable automatic reconnection (default: `true`).
    pub enabled: bool,

    /// Initial retry delay (default: 500 ms).
    pub initial_delay: Duration,

    /// Maximum retry delay (default: 30 s).
    pub max_delay: Duration,

    /// Exponential backoff factor (default: `2.0`).
    ///
    /// Stored exactly as given; no range check happens here.
    /// NOTE(review): callers presumably expect a finite value `>= 1.0` —
    /// confirm against the code that turns this policy into delays.
    pub backoff_factor: f64,

    /// Maximum number of retries (default: 10).
    pub max_retries: u32,
}

impl Default for ReconnectConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            max_retries: 10,
        }
    }
}

impl ReconnectConfig {
    /// Create a new reconnection configuration with default values.
    ///
    /// Equivalent to [`ReconnectConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set whether automatic reconnection is enabled.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Set the initial retry delay.
    #[must_use]
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Set the maximum retry delay.
    #[must_use]
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Set the exponential backoff factor (stored without validation).
    #[must_use]
    pub fn with_backoff_factor(mut self, factor: f64) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Set the maximum number of retries.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}
187
/// TLS configuration.
///
/// Paths are stored exactly as given; nothing in this type reads or validates
/// the referenced files. The default value has no certificate, key, or CA path
/// and keeps certificate verification on (`skip_verify == false`).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TlsConfig {
    /// Client certificate path
    pub cert_path: Option<String>,

    /// Client key path
    pub key_path: Option<String>,

    /// CA certificate path
    pub ca_path: Option<String>,

    /// Skip certificate verification (dangerous!)
    ///
    /// Leave this `false` outside of local testing.
    pub skip_verify: bool,
}

impl TlsConfig {
    /// Create a new TLS configuration with default values.
    ///
    /// Equivalent to [`TlsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a TLS configuration with a client certificate and key.
    ///
    /// This is a constructor, not a builder method: all other fields start from
    /// their defaults. Chain [`TlsConfig::with_ca_path`] or
    /// [`TlsConfig::with_skip_verify`] to set them.
    #[must_use]
    pub fn with_client_cert(cert_path: String, key_path: String) -> Self {
        Self {
            cert_path: Some(cert_path),
            key_path: Some(key_path),
            ..Self::default()
        }
    }

    /// Create a TLS configuration with a CA certificate.
    ///
    /// This is a constructor, not a builder method: all other fields start from
    /// their defaults. Chain [`TlsConfig::with_cert_path`] and
    /// [`TlsConfig::with_key_path`] to add a client certificate.
    #[must_use]
    pub fn with_ca_cert(ca_path: String) -> Self {
        Self {
            ca_path: Some(ca_path),
            ..Self::default()
        }
    }

    /// Create an insecure TLS configuration that skips certificate verification.
    ///
    /// Intended for local testing only; do not use it against untrusted peers.
    #[must_use]
    pub fn insecure() -> Self {
        Self {
            skip_verify: true,
            ..Self::default()
        }
    }

    /// Set the client certificate path.
    #[must_use]
    pub fn with_cert_path(mut self, path: String) -> Self {
        self.cert_path = Some(path);
        self
    }

    /// Set the client key path.
    #[must_use]
    pub fn with_key_path(mut self, path: String) -> Self {
        self.key_path = Some(path);
        self
    }

    /// Set the CA certificate path.
    #[must_use]
    pub fn with_ca_path(mut self, path: String) -> Self {
        self.ca_path = Some(path);
        self
    }

    /// Set the skip-verification flag.
    ///
    /// Passing `true` disables certificate verification (dangerous!).
    #[must_use]
    pub fn with_skip_verify(mut self, skip: bool) -> Self {
        self.skip_verify = skip;
        self
    }
}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketBidirectionalConfig::default();
        assert_eq!(config.url, None);
        assert_eq!(config.bind_addr, None);
        assert_eq!(config.max_message_size, 16 * 1024 * 1024);
        assert_eq!(config.keep_alive_interval, Duration::from_secs(30));
        assert_eq!(config.elicitation_timeout, Duration::from_secs(30));
        assert_eq!(config.max_concurrent_elicitations, 10);
        assert!(!config.enable_compression);
        assert!(config.tls_config.is_none());
    }

    #[test]
    fn test_websocket_config_default_reconnect() {
        // Field access only, so this module needs no extra imports.
        let reconnect = WebSocketBidirectionalConfig::default().reconnect;
        assert!(reconnect.enabled);
        assert_eq!(reconnect.initial_delay, Duration::from_millis(500));
        assert_eq!(reconnect.max_delay, Duration::from_secs(30));
        assert_eq!(reconnect.backoff_factor, 2.0);
        assert_eq!(reconnect.max_retries, 10);
    }

    #[test]
    fn test_websocket_config_client() {
        let config = WebSocketBidirectionalConfig::client("ws://example.com".to_string());
        assert_eq!(config.url, Some("ws://example.com".to_string()));
        assert_eq!(config.bind_addr, None);
    }

    #[test]
    fn test_websocket_config_server() {
        let config = WebSocketBidirectionalConfig::server("0.0.0.0:8080".to_string());
        assert_eq!(config.bind_addr, Some("0.0.0.0:8080".to_string()));
        assert_eq!(config.url, None);
    }

    #[test]
    fn test_websocket_config_builder() {
        let config = WebSocketBidirectionalConfig::new()
            .with_max_message_size(1024)
            .with_keep_alive_interval(Duration::from_secs(60))
            .with_compression(true)
            .with_max_concurrent_elicitations(5);

        assert_eq!(config.max_message_size, 1024);
        assert_eq!(config.keep_alive_interval, Duration::from_secs(60));
        assert!(config.enable_compression);
        assert_eq!(config.max_concurrent_elicitations, 5);

        // with_compression(false) must switch compression back off.
        let config = config.with_compression(false);
        assert!(!config.enable_compression);
    }

    #[test]
    fn test_websocket_config_builder_nested() {
        let reconnect = WebSocketBidirectionalConfig::default()
            .reconnect
            .with_enabled(false)
            .with_max_retries(3);
        let tls = TlsConfig::with_ca_cert("ca.pem".to_string());

        let config = WebSocketBidirectionalConfig::new()
            .with_elicitation_timeout(Duration::from_secs(5))
            .with_reconnect_config(reconnect)
            .with_tls_config(tls);

        assert_eq!(config.elicitation_timeout, Duration::from_secs(5));
        assert!(!config.reconnect.enabled);
        assert_eq!(config.reconnect.max_retries, 3);

        let tls = config
            .tls_config
            .expect("with_tls_config should populate tls_config");
        assert_eq!(tls.ca_path, Some("ca.pem".to_string()));
        assert!(!tls.skip_verify);
    }

    #[test]
    fn test_tls_config_presets() {
        let client_cert =
            TlsConfig::with_client_cert("cert.pem".to_string(), "key.pem".to_string());
        assert_eq!(client_cert.cert_path, Some("cert.pem".to_string()));
        assert_eq!(client_cert.key_path, Some("key.pem".to_string()));
        assert_eq!(client_cert.ca_path, None);
        assert!(!client_cert.skip_verify);

        let ca_cert = TlsConfig::with_ca_cert("ca.pem".to_string());
        assert_eq!(ca_cert.ca_path, Some("ca.pem".to_string()));
        assert_eq!(ca_cert.cert_path, None);
        assert_eq!(ca_cert.key_path, None);

        let insecure = TlsConfig::insecure();
        assert!(insecure.skip_verify);
        assert_eq!(insecure.cert_path, None);
    }

    #[test]
    fn test_tls_config_setters() {
        let defaults = TlsConfig::new();
        assert_eq!(defaults.cert_path, None);
        assert_eq!(defaults.key_path, None);
        assert_eq!(defaults.ca_path, None);
        assert!(!defaults.skip_verify);

        let tls = TlsConfig::new()
            .with_cert_path("cert.pem".to_string())
            .with_key_path("key.pem".to_string())
            .with_ca_path("ca.pem".to_string())
            .with_skip_verify(true);
        assert_eq!(tls.cert_path, Some("cert.pem".to_string()));
        assert_eq!(tls.key_path, Some("key.pem".to_string()));
        assert_eq!(tls.ca_path, Some("ca.pem".to_string()));
        assert!(tls.skip_verify);
    }
}