mcpkit_transport/websocket/
config.rs1use std::time::Duration;
4
5#[derive(Debug, Clone)]
7pub struct WebSocketConfig {
8 pub url: String,
10 pub connect_timeout: Duration,
12 pub ping_interval: Duration,
14 pub pong_timeout: Duration,
16 pub max_message_size: usize,
18 pub auto_reconnect: bool,
20 pub max_reconnect_attempts: u32,
22 pub reconnect_backoff: ExponentialBackoff,
24 pub subprotocols: Vec<String>,
26 pub headers: Vec<(String, String)>,
28 pub allowed_origins: Vec<String>,
31}
32
33impl WebSocketConfig {
34 #[must_use]
36 pub fn new(url: impl Into<String>) -> Self {
37 Self {
38 url: url.into(),
39 connect_timeout: Duration::from_secs(30),
40 ping_interval: Duration::from_secs(30),
41 pong_timeout: Duration::from_secs(10),
42 max_message_size: 16 * 1024 * 1024, auto_reconnect: true,
44 max_reconnect_attempts: 10,
45 reconnect_backoff: ExponentialBackoff::default(),
46 subprotocols: vec!["mcp".to_string()],
47 headers: Vec::new(),
48 allowed_origins: Vec::new(),
49 }
50 }
51
52 #[must_use]
68 pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
69 self.allowed_origins.push(origin.into());
70 self
71 }
72
73 #[must_use]
75 pub fn with_allowed_origins(
76 mut self,
77 origins: impl IntoIterator<Item = impl Into<String>>,
78 ) -> Self {
79 self.allowed_origins
80 .extend(origins.into_iter().map(Into::into));
81 self
82 }
83
84 #[must_use]
90 pub fn is_origin_allowed(&self, origin: &str) -> bool {
91 self.allowed_origins.is_empty() || self.allowed_origins.iter().any(|o| o == origin)
92 }
93
94 #[must_use]
96 pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
97 self.connect_timeout = timeout;
98 self
99 }
100
101 #[must_use]
103 pub const fn with_ping_interval(mut self, interval: Duration) -> Self {
104 self.ping_interval = interval;
105 self
106 }
107
108 #[must_use]
110 pub const fn with_pong_timeout(mut self, timeout: Duration) -> Self {
111 self.pong_timeout = timeout;
112 self
113 }
114
115 #[must_use]
117 pub const fn with_max_message_size(mut self, size: usize) -> Self {
118 self.max_message_size = size;
119 self
120 }
121
122 #[must_use]
124 pub const fn without_auto_reconnect(mut self) -> Self {
125 self.auto_reconnect = false;
126 self
127 }
128
129 #[must_use]
131 pub const fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
132 self.max_reconnect_attempts = attempts;
133 self
134 }
135
136 #[must_use]
138 pub fn with_subprotocol(mut self, protocol: impl Into<String>) -> Self {
139 self.subprotocols.push(protocol.into());
140 self
141 }
142
143 #[must_use]
145 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
146 self.headers.push((name.into(), value.into()));
147 self
148 }
149}
150
151impl Default for WebSocketConfig {
152 fn default() -> Self {
153 Self::new("ws://localhost:8080/mcp")
154 }
155}
156
/// Exponential backoff schedule.
///
/// The delay for attempt `n` is `initial_delay * multiplier^n`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Delay for attempt 0 (before capping).
    pub initial_delay: Duration,
    /// Upper bound on every delay returned by [`ExponentialBackoff::delay_for_attempt`].
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub multiplier: f64,
}

impl ExponentialBackoff {
    /// Creates a schedule from its initial delay, cap, and growth factor.
    #[must_use]
    pub const fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
        Self {
            initial_delay,
            max_delay,
            multiplier,
        }
    }

    /// Returns the delay to wait before attempt `attempt` (0-based).
    ///
    /// The delay is `initial_delay * multiplier^attempt`, computed at nanosecond resolution and
    /// capped at `max_delay`. `attempt` values above `i32::MAX` are treated as `i32::MAX`, so
    /// the delay never shrinks because of a wrapped exponent. A negative or NaN product (e.g.
    /// from a negative or NaN `multiplier`) yields a zero delay.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`. Clamp instead of a bare `as` cast: that would wrap attempts
        // above `i32::MAX` to negative exponents and collapse the delay towards zero.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        // Scale in nanoseconds so sub-millisecond delays are not truncated away.
        let scaled_nanos = self.initial_delay.as_nanos() as f64 * self.multiplier.powi(exponent);
        if scaled_nanos >= self.max_delay.as_nanos() as f64 {
            // Also covers `+inf` from an overflowing `powi`.
            return self.max_delay;
        }
        // Float-to-int `as` saturates: negative/NaN -> 0, too large -> `u64::MAX`.
        std::cmp::min(Duration::from_nanos(scaled_nanos as u64), self.max_delay)
    }
}

impl Default for ExponentialBackoff {
    /// 100 ms initial delay, doubling per attempt, capped at 30 s.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
        }
    }
}
197
/// Lifecycle state of a WebSocket connection.
///
/// This file only declares the states; the code that transitions between them lives elsewhere.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Not connected.
    Disconnected,
    /// Connecting.
    Connecting,
    /// Connected.
    Connected,
    /// Reconnecting.
    Reconnecting,
    /// Closed.
    Closed,
}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by every config built in these tests.
    const ENDPOINT: &str = "ws://example.com/mcp";

    #[test]
    fn test_config_builder() {
        let built = WebSocketConfig::new(ENDPOINT)
            .with_connect_timeout(Duration::from_secs(10))
            .with_ping_interval(Duration::from_secs(15))
            .with_pong_timeout(Duration::from_secs(5))
            .with_max_message_size(1024 * 1024)
            .with_subprotocol("custom")
            .with_header("Authorization", "Bearer token");

        assert_eq!(built.url, ENDPOINT);
        assert_eq!(built.connect_timeout, Duration::from_secs(10));
        assert_eq!(built.ping_interval, Duration::from_secs(15));
        assert_eq!(built.pong_timeout, Duration::from_secs(5));
        assert_eq!(built.max_message_size, 1024 * 1024);
        assert!(built.subprotocols.iter().any(|p| p == "custom"));
        assert_eq!(built.headers.len(), 1);
    }

    #[test]
    fn test_origin_validation_empty_allows_all() {
        // With no allow-list configured, every origin (even an empty one) is accepted.
        let open = WebSocketConfig::new(ENDPOINT);
        for origin in ["https://anything.com", "http://malicious.com", ""] {
            assert!(open.is_origin_allowed(origin));
        }
    }

    #[test]
    fn test_origin_validation_with_allowed_origins() {
        let restricted = WebSocketConfig::new(ENDPOINT)
            .with_allowed_origin("https://trusted-app.com")
            .with_allowed_origin("https://another-trusted.com");

        for trusted in ["https://trusted-app.com", "https://another-trusted.com"] {
            assert!(restricted.is_origin_allowed(trusted));
        }

        // Unknown hosts, scheme mismatches and suffix tricks must all be rejected.
        for untrusted in [
            "https://malicious.com",
            "http://trusted-app.com",
            "https://trusted-app.com.evil.com",
        ] {
            assert!(!restricted.is_origin_allowed(untrusted));
        }
    }

    #[test]
    fn test_origin_validation_with_multiple_origins() {
        let listed = ["https://app1.com", "https://app2.com", "https://app3.com"];
        let config = WebSocketConfig::new(ENDPOINT).with_allowed_origins(listed.iter().copied());

        assert!(listed.iter().all(|origin| config.is_origin_allowed(origin)));
        assert!(!config.is_origin_allowed("https://app4.com"));
    }

    #[test]
    fn test_exponential_backoff() {
        let backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10), 2.0);

        // (attempt, expected delay in milliseconds) while still below the cap.
        let doubling = [(0, 100), (1, 200), (2, 400), (3, 800)];
        for (attempt, millis) in doubling {
            assert_eq!(backoff.delay_for_attempt(attempt), Duration::from_millis(millis));
        }

        // Far past the cap, the delay is clamped to `max_delay`.
        assert_eq!(backoff.delay_for_attempt(10), Duration::from_secs(10));
    }

    #[test]
    fn test_default_backoff() {
        let ExponentialBackoff {
            initial_delay,
            max_delay,
            multiplier,
        } = ExponentialBackoff::default();

        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert!((multiplier - 2.0).abs() < f64::EPSILON);
    }
}