mcpkit_transport/websocket/
config.rs

1//! WebSocket transport configuration types.
2
3use std::time::Duration;
4
5/// Configuration for WebSocket transport.
6#[derive(Debug, Clone)]
7pub struct WebSocketConfig {
8    /// WebSocket URL (ws:// or wss://).
9    pub url: String,
10    /// Connection timeout.
11    pub connect_timeout: Duration,
12    /// Ping interval for keeping the connection alive.
13    pub ping_interval: Duration,
14    /// Pong timeout (how long to wait for pong after sending ping).
15    pub pong_timeout: Duration,
16    /// Maximum message size in bytes.
17    pub max_message_size: usize,
18    /// Whether to enable automatic reconnection.
19    pub auto_reconnect: bool,
20    /// Maximum reconnection attempts.
21    pub max_reconnect_attempts: u32,
22    /// Reconnection backoff configuration.
23    pub reconnect_backoff: ExponentialBackoff,
24    /// Additional WebSocket subprotocols.
25    pub subprotocols: Vec<String>,
26    /// Custom headers for the WebSocket handshake.
27    pub headers: Vec<(String, String)>,
28    /// Allowed origins for DNS rebinding protection (server-side).
29    /// If empty, origin validation is disabled.
30    pub allowed_origins: Vec<String>,
31}
32
33impl WebSocketConfig {
34    /// Create a new WebSocket configuration.
35    #[must_use]
36    pub fn new(url: impl Into<String>) -> Self {
37        Self {
38            url: url.into(),
39            connect_timeout: Duration::from_secs(30),
40            ping_interval: Duration::from_secs(30),
41            pong_timeout: Duration::from_secs(10),
42            max_message_size: 16 * 1024 * 1024, // 16 MB
43            auto_reconnect: true,
44            max_reconnect_attempts: 10,
45            reconnect_backoff: ExponentialBackoff::default(),
46            subprotocols: vec!["mcp".to_string()],
47            headers: Vec::new(),
48            allowed_origins: Vec::new(),
49        }
50    }
51
52    /// Add an allowed origin for DNS rebinding protection.
53    ///
54    /// When origins are configured, the server will reject WebSocket
55    /// connections from origins not in the list. This helps prevent
56    /// DNS rebinding attacks.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use mcpkit_transport::websocket::WebSocketConfig;
62    ///
63    /// let config = WebSocketConfig::new("ws://localhost:8080/mcp")
64    ///     .with_allowed_origin("https://trusted-app.com")
65    ///     .with_allowed_origin("https://another-trusted-app.com");
66    /// ```
67    #[must_use]
68    pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
69        self.allowed_origins.push(origin.into());
70        self
71    }
72
73    /// Set multiple allowed origins at once.
74    #[must_use]
75    pub fn with_allowed_origins(
76        mut self,
77        origins: impl IntoIterator<Item = impl Into<String>>,
78    ) -> Self {
79        self.allowed_origins
80            .extend(origins.into_iter().map(Into::into));
81        self
82    }
83
84    /// Check if an origin is allowed.
85    ///
86    /// Returns `true` if:
87    /// - No origins are configured (origin validation disabled)
88    /// - The origin is in the allowed list
89    #[must_use]
90    pub fn is_origin_allowed(&self, origin: &str) -> bool {
91        self.allowed_origins.is_empty() || self.allowed_origins.iter().any(|o| o == origin)
92    }
93
94    /// Set the connection timeout.
95    #[must_use]
96    pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
97        self.connect_timeout = timeout;
98        self
99    }
100
101    /// Set the ping interval.
102    #[must_use]
103    pub const fn with_ping_interval(mut self, interval: Duration) -> Self {
104        self.ping_interval = interval;
105        self
106    }
107
108    /// Set the pong timeout.
109    #[must_use]
110    pub const fn with_pong_timeout(mut self, timeout: Duration) -> Self {
111        self.pong_timeout = timeout;
112        self
113    }
114
115    /// Set the maximum message size.
116    #[must_use]
117    pub const fn with_max_message_size(mut self, size: usize) -> Self {
118        self.max_message_size = size;
119        self
120    }
121
122    /// Disable automatic reconnection.
123    #[must_use]
124    pub const fn without_auto_reconnect(mut self) -> Self {
125        self.auto_reconnect = false;
126        self
127    }
128
129    /// Set maximum reconnection attempts.
130    #[must_use]
131    pub const fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
132        self.max_reconnect_attempts = attempts;
133        self
134    }
135
136    /// Add a WebSocket subprotocol.
137    #[must_use]
138    pub fn with_subprotocol(mut self, protocol: impl Into<String>) -> Self {
139        self.subprotocols.push(protocol.into());
140        self
141    }
142
143    /// Add a custom header for the WebSocket handshake.
144    #[must_use]
145    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
146        self.headers.push((name.into(), value.into()));
147        self
148    }
149}
150
151impl Default for WebSocketConfig {
152    fn default() -> Self {
153        Self::new("ws://localhost:8080/mcp")
154    }
155}
156
/// Exponential backoff configuration for reconnection.
///
/// The delay for zero-based attempt `n` is `initial_delay * multiplier^n`,
/// capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Delay used for attempt `0`.
    pub initial_delay: Duration,
    /// Upper bound on any computed delay.
    pub max_delay: Duration,
    /// Factor applied to the delay for each subsequent attempt.
    pub multiplier: f64,
}

impl ExponentialBackoff {
    /// Create a new exponential backoff configuration.
    #[must_use]
    pub const fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
        Self {
            initial_delay,
            max_delay,
            multiplier,
        }
    }

    /// Calculate the delay for a given attempt number.
    ///
    /// Returns `initial_delay * multiplier^attempt`, capped at `max_delay`.
    /// The computation keeps nanosecond precision, so sub-millisecond initial
    /// delays are honoured rather than truncated to zero.
    ///
    /// Degenerate inputs never panic:
    /// - a NaN or non-positive result (e.g. from a NaN or negative
    ///   `multiplier`) yields [`Duration::ZERO`];
    /// - an infinite or otherwise overflowing result yields `max_delay`.
    #[must_use]
    // The float/int casts below are deliberate: the exponent is clamped before
    // narrowing, the final cast is range-checked, and f64 precision is ample
    // for a reconnection delay.
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_possible_wrap,
        clippy::cast_precision_loss,
        clippy::cast_sign_loss
    )]
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Clamp before narrowing: a plain `attempt as i32` wraps values above
        // `i32::MAX` into negative exponents, which would shrink the delay
        // instead of saturating it.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let nanos = self.initial_delay.as_nanos() as f64 * self.multiplier.powi(exponent);

        if nanos.is_nan() || nanos <= 0.0 {
            return Duration::ZERO;
        }
        if nanos >= self.max_delay.as_nanos() as f64 {
            return self.max_delay;
        }
        // `nanos` is finite and positive here; the float-to-int cast saturates
        // at `u64::MAX`, and `min` guards against rounding past `max_delay`.
        Duration::from_nanos(nanos as u64).min(self.max_delay)
    }
}

impl Default for ExponentialBackoff {
    /// 100 ms initial delay, doubling each attempt, capped at 30 s.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
        }
    }
}
197
/// Lifecycle state of a WebSocket connection.
///
/// Variant order determines the implicit discriminants (observable through
/// `as` casts), so new variants are best appended rather than inserted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is established and ready for use.
    Connected,
    /// A new connection is being attempted after a failure.
    Reconnecting,
    /// The connection has been closed.
    Closed,
}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// URL shared by every configuration built in these tests.
    const TEST_URL: &str = "ws://example.com/mcp";

    #[test]
    fn test_config_builder() {
        let built = WebSocketConfig::new(TEST_URL)
            .with_connect_timeout(Duration::from_secs(10))
            .with_ping_interval(Duration::from_secs(15))
            .with_pong_timeout(Duration::from_secs(5))
            .with_max_message_size(1024 * 1024)
            .with_subprotocol("custom")
            .with_header("Authorization", "Bearer token");

        assert_eq!(built.url, TEST_URL);
        assert_eq!(
            (built.connect_timeout, built.ping_interval, built.pong_timeout),
            (
                Duration::from_secs(10),
                Duration::from_secs(15),
                Duration::from_secs(5)
            )
        );
        assert_eq!(built.max_message_size, 1024 * 1024);
        assert!(built.subprotocols.iter().any(|p| p == "custom"));
        assert_eq!(built.headers.len(), 1);
    }

    #[test]
    fn test_origin_validation_empty_allows_all() {
        let built = WebSocketConfig::new(TEST_URL);
        // With no allowed origins configured, validation is off and every
        // origin (even an empty one) passes.
        for origin in &["https://anything.com", "http://malicious.com", ""] {
            assert!(built.is_origin_allowed(origin), "{:?} should be allowed", origin);
        }
    }

    #[test]
    fn test_origin_validation_with_allowed_origins() {
        let built = WebSocketConfig::new(TEST_URL)
            .with_allowed_origin("https://trusted-app.com")
            .with_allowed_origin("https://another-trusted.com");

        let accepted = ["https://trusted-app.com", "https://another-trusted.com"];
        let rejected = [
            "https://malicious.com",
            "http://trusted-app.com",           // different scheme
            "https://trusted-app.com.evil.com", // trusted host used as a prefix
        ];

        for origin in &accepted {
            assert!(built.is_origin_allowed(origin), "{:?} should be allowed", origin);
        }
        for origin in &rejected {
            assert!(!built.is_origin_allowed(origin), "{:?} should be rejected", origin);
        }
    }

    #[test]
    fn test_origin_validation_with_multiple_origins() {
        let origins = (1..=3).map(|n| format!("https://app{}.com", n));
        let built = WebSocketConfig::new(TEST_URL).with_allowed_origins(origins);

        for n in 1..=3 {
            assert!(built.is_origin_allowed(&format!("https://app{}.com", n)));
        }
        assert!(!built.is_origin_allowed("https://app4.com"));
    }

    #[test]
    fn test_exponential_backoff() {
        let backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10), 2.0);

        let schedule: [(u32, u64); 4] = [(0, 100), (1, 200), (2, 400), (3, 800)];
        for &(attempt, millis) in &schedule {
            assert_eq!(backoff.delay_for_attempt(attempt), Duration::from_millis(millis));
        }

        // Growth past max_delay is capped.
        assert_eq!(backoff.delay_for_attempt(10), Duration::from_secs(10));
    }

    #[test]
    fn test_default_backoff() {
        let ExponentialBackoff {
            initial_delay,
            max_delay,
            multiplier,
        } = ExponentialBackoff::default();

        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert!((multiplier - 2.0).abs() < f64::EPSILON);
    }
}