pulseengine_mcp_auth/transport/
websocket_auth.rs

1//! WebSocket Transport Authentication
2//!
3//! This module provides authentication for WebSocket-based MCP servers,
4//! handling both connection-time and per-message authentication.
5
6use super::auth_extractors::{
7    AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8    TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14/// Configuration for WebSocket authentication
15#[derive(Debug, Clone)]
16pub struct WebSocketAuthConfig {
17    /// Require authentication during WebSocket handshake
18    pub require_handshake_auth: bool,
19
20    /// Allow authentication after connection (first message)
21    pub allow_post_connect_auth: bool,
22
23    /// Supported authentication methods
24    pub supported_methods: Vec<WebSocketAuthMethod>,
25
26    /// Enable per-message authentication
27    pub enable_per_message_auth: bool,
28
29    /// WebSocket subprotocol for authentication
30    pub auth_subprotocol: Option<String>,
31
32    /// Connection timeout for authentication (seconds)
33    pub auth_timeout_secs: u64,
34}
35
36impl Default for WebSocketAuthConfig {
37    fn default() -> Self {
38        Self {
39            require_handshake_auth: true,
40            allow_post_connect_auth: true,
41            supported_methods: vec![
42                WebSocketAuthMethod::HandshakeHeaders,
43                WebSocketAuthMethod::QueryParams,
44                WebSocketAuthMethod::FirstMessage,
45            ],
46            enable_per_message_auth: false,
47            auth_subprotocol: Some("mcp-auth".to_string()),
48            auth_timeout_secs: 30,
49        }
50    }
51}
52
/// The places where a WebSocket client may present its credential.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WebSocketAuthMethod {
    /// Credential carried in the upgrade request's HTTP headers.
    HandshakeHeaders,
    /// Credential carried in the upgrade URL's query string.
    QueryParams,
    /// Credential carried inside the first message sent over the socket.
    FirstMessage,
    /// Credential embedded in a `Sec-WebSocket-Protocol` entry.
    Subprotocol,
    /// Credential attached to individual messages.
    PerMessage,
}
71
72/// WebSocket authentication extractor
73pub struct WebSocketAuthExtractor {
74    config: WebSocketAuthConfig,
75}
76
77impl WebSocketAuthExtractor {
78    /// Create a new WebSocket authentication extractor
79    pub fn new(config: WebSocketAuthConfig) -> Self {
80        Self { config }
81    }
82
83    /// Create with default configuration
84    pub fn default() -> Self {
85        Self::new(WebSocketAuthConfig::default())
86    }
87
88    /// Extract authentication from WebSocket handshake headers
89    fn extract_handshake_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
90        if !self
91            .config
92            .supported_methods
93            .contains(&WebSocketAuthMethod::HandshakeHeaders)
94        {
95            return Ok(None);
96        }
97
98        // Try Authorization header
99        if let Some(auth_header) = headers
100            .get("Authorization")
101            .or_else(|| headers.get("authorization"))
102        {
103            if auth_header.starts_with("Bearer ") {
104                match AuthUtils::extract_bearer_token(auth_header) {
105                    Ok(token) => {
106                        AuthUtils::validate_api_key_format(&token)?;
107                        let context = TransportAuthContext::new(
108                            token,
109                            "HandshakeHeaders".to_string(),
110                            TransportType::WebSocket,
111                        );
112                        return Ok(Some(context));
113                    }
114                    Err(e) => return Err(e),
115                }
116            }
117        }
118
119        // Try X-API-Key header
120        if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
121            AuthUtils::validate_api_key_format(&api_key)?;
122            let context = TransportAuthContext::new(
123                api_key,
124                "HandshakeHeaders".to_string(),
125                TransportType::WebSocket,
126            );
127            return Ok(Some(context));
128        }
129
130        // Try WebSocket-specific headers
131        if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") {
132            if let Some(auth_token) = self.extract_from_subprotocol(api_key) {
133                AuthUtils::validate_api_key_format(&auth_token)?;
134                let context = TransportAuthContext::new(
135                    auth_token,
136                    "Subprotocol".to_string(),
137                    TransportType::WebSocket,
138                );
139                return Ok(Some(context));
140            }
141        }
142
143        Ok(None)
144    }
145
146    /// Extract authentication from query parameters (during handshake)
147    fn extract_query_params(&self, request: &TransportRequest) -> AuthExtractionResult {
148        if !self
149            .config
150            .supported_methods
151            .contains(&WebSocketAuthMethod::QueryParams)
152        {
153            return Ok(None);
154        }
155
156        // Try common query parameter names
157        for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
158            if let Some(api_key) = request.get_query_param(param_name) {
159                AuthUtils::validate_api_key_format(api_key)?;
160                let context = TransportAuthContext::new(
161                    api_key.clone(),
162                    "QueryParams".to_string(),
163                    TransportType::WebSocket,
164                );
165                return Ok(Some(context));
166            }
167        }
168
169        Ok(None)
170    }
171
172    /// Extract authentication from first WebSocket message
173    fn extract_first_message(&self, request: &TransportRequest) -> AuthExtractionResult {
174        if !self
175            .config
176            .supported_methods
177            .contains(&WebSocketAuthMethod::FirstMessage)
178        {
179            return Ok(None);
180        }
181
182        if let Some(body) = &request.body {
183            // Look for authentication in message
184            if let Some(auth_data) = self.find_auth_in_message(body) {
185                AuthUtils::validate_api_key_format(&auth_data)?;
186                let context = TransportAuthContext::new(
187                    auth_data,
188                    "FirstMessage".to_string(),
189                    TransportType::WebSocket,
190                );
191                return Ok(Some(context));
192            }
193        }
194
195        Ok(None)
196    }
197
198    /// Extract authentication token from WebSocket subprotocol
199    fn extract_from_subprotocol(&self, subprotocol: &str) -> Option<String> {
200        // Format: "mcp-auth.TOKEN" or "mcp-auth-TOKEN"
201        if let Some(auth_protocol) = &self.config.auth_subprotocol {
202            let prefix = format!("{}.", auth_protocol);
203            if let Some(token) = subprotocol.strip_prefix(&prefix) {
204                return Some(token.to_string());
205            }
206
207            let prefix_dash = format!("{}-", auth_protocol);
208            if let Some(token) = subprotocol.strip_prefix(&prefix_dash) {
209                return Some(token.to_string());
210            }
211        }
212
213        None
214    }
215
216    /// Find authentication data in WebSocket message
217    fn find_auth_in_message(&self, message: &Value) -> Option<String> {
218        // Try direct auth field
219        if let Some(auth) = message.get("auth") {
220            if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
221                return Some(api_key.to_string());
222            }
223            if let Some(token) = auth.get("token").and_then(|v| v.as_str()) {
224                return Some(token.to_string());
225            }
226        }
227
228        // Try in params (for MCP initialize)
229        if let Some(params) = message.get("params") {
230            if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
231                return Some(api_key.to_string());
232            }
233
234            // Try nested in clientInfo
235            if let Some(client_info) = params.get("clientInfo") {
236                if let Some(auth) = client_info.get("authentication") {
237                    if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
238                        return Some(api_key.to_string());
239                    }
240                }
241            }
242        }
243
244        // Try root level for simple auth messages
245        if let Some(api_key) = message.get("api_key").and_then(|v| v.as_str()) {
246            return Some(api_key.to_string());
247        }
248
249        None
250    }
251
252    /// Add WebSocket-specific context information
253    fn enrich_context(
254        &self,
255        mut context: TransportAuthContext,
256        request: &TransportRequest,
257    ) -> TransportAuthContext {
258        // Add client IP
259        if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260            context = context.with_client_ip(client_ip);
261        }
262
263        // Add user agent
264        if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265            context = context.with_user_agent(user_agent);
266        }
267
268        // Add WebSocket-specific metadata
269        if let Some(origin) = request.get_header("Origin") {
270            context = context.with_metadata("origin".to_string(), origin.clone());
271        }
272
273        if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
274            context = context.with_metadata("protocols".to_string(), protocols.clone());
275        }
276
277        if let Some(version) = request.get_header("Sec-WebSocket-Version") {
278            context = context.with_metadata("ws_version".to_string(), version.clone());
279        }
280
281        context
282    }
283
284    /// Check if WebSocket handshake contains authentication
285    pub fn has_handshake_auth(&self, request: &TransportRequest) -> bool {
286        // Check headers for auth
287        if request.headers.contains_key("Authorization")
288            || AuthUtils::extract_api_key_header(&request.headers).is_some()
289        {
290            return true;
291        }
292
293        // Check query params for auth
294        for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
295            if request.query_params.contains_key(*param_name) {
296                return true;
297            }
298        }
299
300        // Check subprotocol for auth
301        if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
302            if let Some(auth_protocol) = &self.config.auth_subprotocol {
303                if protocols.contains(auth_protocol) {
304                    return true;
305                }
306            }
307        }
308
309        false
310    }
311}
312
313#[async_trait]
314impl AuthExtractor for WebSocketAuthExtractor {
315    async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
316        // Try different authentication methods in order of preference
317
318        // 1. Handshake headers
319        if let Ok(Some(context)) = self.extract_handshake_headers(&request.headers) {
320            return Ok(Some(self.enrich_context(context, request)));
321        }
322
323        // 2. Query parameters
324        if let Ok(Some(context)) = self.extract_query_params(request) {
325            return Ok(Some(self.enrich_context(context, request)));
326        }
327
328        // 3. First message (if body is present)
329        if let Ok(Some(context)) = self.extract_first_message(request) {
330            return Ok(Some(self.enrich_context(context, request)));
331        }
332
333        // No authentication found
334        if self.config.require_handshake_auth && !self.config.allow_post_connect_auth {
335            return Err(TransportAuthError::NoAuth);
336        }
337
338        Ok(None)
339    }
340
341    fn transport_type(&self) -> TransportType {
342        TransportType::WebSocket
343    }
344
345    fn can_handle(&self, request: &TransportRequest) -> bool {
346        // Check for WebSocket-specific headers
347        request.headers.contains_key("Sec-WebSocket-Key")
348            || request.headers.contains_key("Upgrade")
349            || request.metadata.contains_key("websocket")
350    }
351
352    async fn validate_auth(
353        &self,
354        context: &TransportAuthContext,
355    ) -> Result<(), TransportAuthError> {
356        // WebSocket-specific validation
357        if context.credential.is_empty() {
358            return Err(TransportAuthError::InvalidFormat(
359                "Empty credential".to_string(),
360            ));
361        }
362
363        // Warn about insecure authentication methods
364        if context.method == "QueryParams" {
365            tracing::warn!("WebSocket authentication via query parameters is less secure - consider using headers");
366        }
367
368        Ok(())
369    }
370}
371
372/// Helper for creating WebSocket authentication configuration
373impl WebSocketAuthConfig {
374    /// Create a secure configuration
375    pub fn secure() -> Self {
376        Self {
377            require_handshake_auth: true,
378            allow_post_connect_auth: false,
379            supported_methods: vec![WebSocketAuthMethod::HandshakeHeaders],
380            enable_per_message_auth: false,
381            auth_subprotocol: Some("mcp-auth".to_string()),
382            auth_timeout_secs: 10,
383        }
384    }
385
386    /// Create a flexible configuration
387    pub fn flexible() -> Self {
388        Self {
389            require_handshake_auth: false,
390            allow_post_connect_auth: true,
391            supported_methods: vec![
392                WebSocketAuthMethod::HandshakeHeaders,
393                WebSocketAuthMethod::QueryParams,
394                WebSocketAuthMethod::FirstMessage,
395            ],
396            enable_per_message_auth: false,
397            auth_subprotocol: Some("mcp-auth".to_string()),
398            auth_timeout_secs: 30,
399        }
400    }
401
402    /// Create a development-friendly configuration
403    pub fn development() -> Self {
404        Self {
405            require_handshake_auth: false,
406            allow_post_connect_auth: true,
407            supported_methods: vec![
408                WebSocketAuthMethod::HandshakeHeaders,
409                WebSocketAuthMethod::QueryParams,
410                WebSocketAuthMethod::FirstMessage,
411                WebSocketAuthMethod::Subprotocol,
412            ],
413            enable_per_message_auth: false,
414            auth_subprotocol: Some("mcp-auth".to_string()),
415            auth_timeout_secs: 60,
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Well-formed credential shared by the extraction tests.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Runs a default-configured extractor on `request`; any extraction error
    /// fails the calling test.
    fn extract_default(request: &TransportRequest) -> Option<TransportAuthContext> {
        let extractor = WebSocketAuthExtractor::default();
        tokio_test::block_on(extractor.extract_auth(request)).unwrap()
    }

    #[test]
    fn test_handshake_header_extraction() {
        let headers = HashMap::from([
            ("Authorization".to_string(), format!("Bearer {TEST_KEY}")),
            ("Sec-WebSocket-Key".to_string(), "test-key".to_string()),
        ]);
        let request = TransportRequest::from_headers(headers);

        let context = extract_default(&request).expect("bearer credential should be found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "HandshakeHeaders");
        assert_eq!(context.transport_type, TransportType::WebSocket);
    }

    #[test]
    fn test_query_parameter_extraction() {
        let request = TransportRequest::new()
            .with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
            .with_query_param("api_key".to_string(), TEST_KEY.to_string());

        let context = extract_default(&request).expect("query credential should be found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "QueryParams");
    }

    #[test]
    fn test_first_message_extraction() {
        let request = TransportRequest::new()
            .with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
            .with_body(json!({ "auth": { "api_key": TEST_KEY } }));

        let context = extract_default(&request).expect("message credential should be found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_subprotocol_extraction() {
        let headers = HashMap::from([
            ("Sec-WebSocket-Protocol".to_string(), format!("mcp-auth.{TEST_KEY}")),
            ("Sec-WebSocket-Key".to_string(), "test-key".to_string()),
        ]);
        let request = TransportRequest::from_headers(headers);

        let context = extract_default(&request).expect("subprotocol credential should be found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Subprotocol");
    }

    #[test]
    fn test_mcp_initialize_message() {
        let init_message = json!({
            "method": "initialize",
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "authentication": { "api_key": TEST_KEY }
                }
            }
        });
        let request = TransportRequest::new()
            .with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
            .with_body(init_message);

        let context = extract_default(&request).expect("initialize credential should be found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_has_handshake_auth() {
        let extractor = WebSocketAuthExtractor::default();

        let bearer = TransportRequest::new()
            .with_header("Authorization".to_string(), "Bearer token123".to_string());
        let query =
            TransportRequest::new().with_query_param("api_key".to_string(), "token123".to_string());
        let anonymous = TransportRequest::new();

        assert!(extractor.has_handshake_auth(&bearer));
        assert!(extractor.has_handshake_auth(&query));
        assert!(!extractor.has_handshake_auth(&anonymous));
    }

    #[test]
    fn test_configuration_presets() {
        let secure = WebSocketAuthConfig::secure();
        assert!(secure.require_handshake_auth);
        assert!(!secure.allow_post_connect_auth);
        assert_eq!(secure.auth_timeout_secs, 10);

        let flexible = WebSocketAuthConfig::flexible();
        assert!(!flexible.require_handshake_auth);
        assert!(flexible.allow_post_connect_auth);

        let development = WebSocketAuthConfig::development();
        assert!(!development.require_handshake_auth);
        assert_eq!(development.auth_timeout_secs, 60);
    }
}