pulseengine_mcp_auth/transport/
websocket_auth.rs

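//! WebSocket Transport Authentication
//!
//! This module provides authentication for WebSocket-based MCP servers,
//! extracting credentials presented during the handshake (headers, query
//! parameters, subprotocol) or in the first message after connecting.
//! Per-message authentication is configurable but not yet enforced here.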
5
6use super::auth_extractors::{
7    AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8    TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14/// Configuration for WebSocket authentication
15#[derive(Debug, Clone)]
16pub struct WebSocketAuthConfig {
17    /// Require authentication during WebSocket handshake
18    pub require_handshake_auth: bool,
19
20    /// Allow authentication after connection (first message)
21    pub allow_post_connect_auth: bool,
22
23    /// Supported authentication methods
24    pub supported_methods: Vec<WebSocketAuthMethod>,
25
26    /// Enable per-message authentication
27    pub enable_per_message_auth: bool,
28
29    /// WebSocket subprotocol for authentication
30    pub auth_subprotocol: Option<String>,
31
32    /// Connection timeout for authentication (seconds)
33    pub auth_timeout_secs: u64,
34}
35
36impl Default for WebSocketAuthConfig {
37    fn default() -> Self {
38        Self {
39            require_handshake_auth: true,
40            allow_post_connect_auth: true,
41            supported_methods: vec![
42                WebSocketAuthMethod::HandshakeHeaders,
43                WebSocketAuthMethod::QueryParams,
44                WebSocketAuthMethod::FirstMessage,
45            ],
46            enable_per_message_auth: false,
47            auth_subprotocol: Some("mcp-auth".to_string()),
48            auth_timeout_secs: 30,
49        }
50    }
51}
52
/// Mechanisms a WebSocket client can use to present its credential.
///
/// The enum carries no data, so it is `Copy` and `Hash` and can be passed by
/// value or used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WebSocketAuthMethod {
    /// Credential carried in handshake headers (`Authorization`, API-key header).
    HandshakeHeaders,

    /// Credential carried in the handshake URL's query string.
    QueryParams,

    /// Credential carried in the first message sent over the connection.
    FirstMessage,

    /// Credential encoded in the `Sec-WebSocket-Protocol` value.
    Subprotocol,

    /// Credential attached to every message.
    PerMessage,
}
71
72/// WebSocket authentication extractor
73pub struct WebSocketAuthExtractor {
74    config: WebSocketAuthConfig,
75}
76
77impl WebSocketAuthExtractor {
78    /// Create a new WebSocket authentication extractor
79    pub fn new(config: WebSocketAuthConfig) -> Self {
80        Self { config }
81    }
82
83    /// Create with default configuration
84    pub fn default() -> Self {
85        Self::new(WebSocketAuthConfig::default())
86    }
87
88    /// Extract authentication from WebSocket handshake headers
89    fn extract_handshake_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
90        if !self
91            .config
92            .supported_methods
93            .contains(&WebSocketAuthMethod::HandshakeHeaders)
94        {
95            return Ok(None);
96        }
97
98        // Try Authorization header
99        if let Some(auth_header) = headers
100            .get("Authorization")
101            .or_else(|| headers.get("authorization"))
102        {
103            if auth_header.starts_with("Bearer ") {
104                match AuthUtils::extract_bearer_token(auth_header) {
105                    Ok(token) => {
106                        AuthUtils::validate_api_key_format(&token)?;
107                        let context = TransportAuthContext::new(
108                            token,
109                            "HandshakeHeaders".to_string(),
110                            TransportType::WebSocket,
111                        );
112                        return Ok(Some(context));
113                    }
114                    Err(e) => return Err(e),
115                }
116            }
117        }
118
119        // Try X-API-Key header
120        if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
121            AuthUtils::validate_api_key_format(&api_key)?;
122            let context = TransportAuthContext::new(
123                api_key,
124                "HandshakeHeaders".to_string(),
125                TransportType::WebSocket,
126            );
127            return Ok(Some(context));
128        }
129
130        // Try WebSocket-specific headers
131        if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") {
132            if let Some(auth_token) = self.extract_from_subprotocol(api_key) {
133                AuthUtils::validate_api_key_format(&auth_token)?;
134                let context = TransportAuthContext::new(
135                    auth_token,
136                    "Subprotocol".to_string(),
137                    TransportType::WebSocket,
138                );
139                return Ok(Some(context));
140            }
141        }
142
143        Ok(None)
144    }
145
146    /// Extract authentication from query parameters (during handshake)
147    fn extract_query_params(&self, request: &TransportRequest) -> AuthExtractionResult {
148        if !self
149            .config
150            .supported_methods
151            .contains(&WebSocketAuthMethod::QueryParams)
152        {
153            return Ok(None);
154        }
155
156        // Try common query parameter names
157        for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
158            if let Some(api_key) = request.get_query_param(param_name) {
159                AuthUtils::validate_api_key_format(api_key)?;
160                let context = TransportAuthContext::new(
161                    api_key.clone(),
162                    "QueryParams".to_string(),
163                    TransportType::WebSocket,
164                );
165                return Ok(Some(context));
166            }
167        }
168
169        Ok(None)
170    }
171
172    /// Extract authentication from first WebSocket message
173    fn extract_first_message(&self, request: &TransportRequest) -> AuthExtractionResult {
174        if !self
175            .config
176            .supported_methods
177            .contains(&WebSocketAuthMethod::FirstMessage)
178        {
179            return Ok(None);
180        }
181
182        if let Some(body) = &request.body {
183            // Look for authentication in message
184            if let Some(auth_data) = self.find_auth_in_message(body) {
185                AuthUtils::validate_api_key_format(&auth_data)?;
186                let context = TransportAuthContext::new(
187                    auth_data,
188                    "FirstMessage".to_string(),
189                    TransportType::WebSocket,
190                );
191                return Ok(Some(context));
192            }
193        }
194
195        Ok(None)
196    }
197
198    /// Extract authentication token from WebSocket subprotocol
199    fn extract_from_subprotocol(&self, subprotocol: &str) -> Option<String> {
200        // Format: "mcp-auth.TOKEN" or "mcp-auth-TOKEN"
201        if let Some(auth_protocol) = &self.config.auth_subprotocol {
202            let prefix = format!("{}.", auth_protocol);
203            if let Some(token) = subprotocol.strip_prefix(&prefix) {
204                return Some(token.to_string());
205            }
206
207            let prefix_dash = format!("{}-", auth_protocol);
208            if let Some(token) = subprotocol.strip_prefix(&prefix_dash) {
209                return Some(token.to_string());
210            }
211        }
212
213        None
214    }
215
216    /// Find authentication data in WebSocket message
217    fn find_auth_in_message(&self, message: &Value) -> Option<String> {
218        // Try direct auth field
219        if let Some(auth) = message.get("auth") {
220            if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
221                return Some(api_key.to_string());
222            }
223            if let Some(token) = auth.get("token").and_then(|v| v.as_str()) {
224                return Some(token.to_string());
225            }
226        }
227
228        // Try in params (for MCP initialize)
229        if let Some(params) = message.get("params") {
230            if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
231                return Some(api_key.to_string());
232            }
233
234            // Try nested in clientInfo
235            if let Some(client_info) = params.get("clientInfo") {
236                if let Some(auth) = client_info.get("authentication") {
237                    if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
238                        return Some(api_key.to_string());
239                    }
240                }
241            }
242        }
243
244        // Try root level for simple auth messages
245        if let Some(api_key) = message.get("api_key").and_then(|v| v.as_str()) {
246            return Some(api_key.to_string());
247        }
248
249        None
250    }
251
252    /// Add WebSocket-specific context information
253    fn enrich_context(
254        &self,
255        mut context: TransportAuthContext,
256        request: &TransportRequest,
257    ) -> TransportAuthContext {
258        // Add client IP
259        if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260            context = context.with_client_ip(client_ip);
261        }
262
263        // Add user agent
264        if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265            context = context.with_user_agent(user_agent);
266        }
267
268        // Add WebSocket-specific metadata
269        if let Some(origin) = request.get_header("Origin") {
270            context = context.with_metadata("origin".to_string(), origin.clone());
271        }
272
273        if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
274            context = context.with_metadata("protocols".to_string(), protocols.clone());
275        }
276
277        if let Some(version) = request.get_header("Sec-WebSocket-Version") {
278            context = context.with_metadata("ws_version".to_string(), version.clone());
279        }
280
281        context
282    }
283
284    /// Check if WebSocket handshake contains authentication
285    pub fn has_handshake_auth(&self, request: &TransportRequest) -> bool {
286        // Check headers for auth
287        if request.headers.contains_key("Authorization")
288            || AuthUtils::extract_api_key_header(&request.headers).is_some()
289        {
290            return true;
291        }
292
293        // Check query params for auth
294        for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
295            if request.query_params.contains_key(*param_name) {
296                return true;
297            }
298        }
299
300        // Check subprotocol for auth
301        if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
302            if let Some(auth_protocol) = &self.config.auth_subprotocol {
303                if protocols.contains(auth_protocol) {
304                    return true;
305                }
306            }
307        }
308
309        false
310    }
311}
312
313#[async_trait]
314impl AuthExtractor for WebSocketAuthExtractor {
315    async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
316        // Try different authentication methods in order of preference
317
318        // 1. Handshake headers
319        if let Ok(Some(context)) = self.extract_handshake_headers(&request.headers) {
320            return Ok(Some(self.enrich_context(context, request)));
321        }
322
323        // 2. Query parameters
324        if let Ok(Some(context)) = self.extract_query_params(request) {
325            return Ok(Some(self.enrich_context(context, request)));
326        }
327
328        // 3. First message (if body is present)
329        if let Ok(Some(context)) = self.extract_first_message(request) {
330            return Ok(Some(self.enrich_context(context, request)));
331        }
332
333        // No authentication found
334        if self.config.require_handshake_auth && !self.config.allow_post_connect_auth {
335            return Err(TransportAuthError::NoAuth);
336        }
337
338        Ok(None)
339    }
340
341    fn transport_type(&self) -> TransportType {
342        TransportType::WebSocket
343    }
344
345    fn can_handle(&self, request: &TransportRequest) -> bool {
346        // Check for WebSocket-specific headers
347        request.headers.contains_key("Sec-WebSocket-Key")
348            || request.headers.contains_key("Upgrade")
349            || request.metadata.contains_key("websocket")
350    }
351
352    async fn validate_auth(
353        &self,
354        context: &TransportAuthContext,
355    ) -> Result<(), TransportAuthError> {
356        // WebSocket-specific validation
357        if context.credential.is_empty() {
358            return Err(TransportAuthError::InvalidFormat(
359                "Empty credential".to_string(),
360            ));
361        }
362
363        // Warn about insecure authentication methods
364        if context.method == "QueryParams" {
365            tracing::warn!(
366                "WebSocket authentication via query parameters is less secure - consider using headers"
367            );
368        }
369
370        Ok(())
371    }
372}
373
374/// Helper for creating WebSocket authentication configuration
375impl WebSocketAuthConfig {
376    /// Create a secure configuration
377    pub fn secure() -> Self {
378        Self {
379            require_handshake_auth: true,
380            allow_post_connect_auth: false,
381            supported_methods: vec![WebSocketAuthMethod::HandshakeHeaders],
382            enable_per_message_auth: false,
383            auth_subprotocol: Some("mcp-auth".to_string()),
384            auth_timeout_secs: 10,
385        }
386    }
387
388    /// Create a flexible configuration
389    pub fn flexible() -> Self {
390        Self {
391            require_handshake_auth: false,
392            allow_post_connect_auth: true,
393            supported_methods: vec![
394                WebSocketAuthMethod::HandshakeHeaders,
395                WebSocketAuthMethod::QueryParams,
396                WebSocketAuthMethod::FirstMessage,
397            ],
398            enable_per_message_auth: false,
399            auth_subprotocol: Some("mcp-auth".to_string()),
400            auth_timeout_secs: 30,
401        }
402    }
403
404    /// Create a development-friendly configuration
405    pub fn development() -> Self {
406        Self {
407            require_handshake_auth: false,
408            allow_post_connect_auth: true,
409            supported_methods: vec![
410                WebSocketAuthMethod::HandshakeHeaders,
411                WebSocketAuthMethod::QueryParams,
412                WebSocketAuthMethod::FirstMessage,
413                WebSocketAuthMethod::Subprotocol,
414            ],
415            enable_per_message_auth: false,
416            auth_subprotocol: Some("mcp-auth".to_string()),
417            auth_timeout_secs: 60,
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// API key in the format the extraction tests expect to be accepted.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Run `extract_auth` with the default extractor, panicking on an error.
    fn extract(request: &TransportRequest) -> Option<TransportAuthContext> {
        let extractor = WebSocketAuthExtractor::default();
        tokio_test::block_on(extractor.extract_auth(request)).unwrap()
    }

    /// Header map holding only the key that marks a WebSocket handshake.
    fn handshake_headers() -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("Sec-WebSocket-Key".to_string(), "test-key".to_string());
        headers
    }

    /// Builder-style request carrying the WebSocket handshake key.
    fn handshake_request() -> TransportRequest {
        TransportRequest::new().with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
    }

    #[test]
    fn test_handshake_header_extraction() {
        let mut headers = handshake_headers();
        headers.insert("Authorization".to_string(), format!("Bearer {}", TEST_KEY));

        let context = extract(&TransportRequest::from_headers(headers))
            .expect("a bearer token in the handshake should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "HandshakeHeaders");
        assert_eq!(context.transport_type, TransportType::WebSocket);
    }

    #[test]
    fn test_query_parameter_extraction() {
        let request =
            handshake_request().with_query_param("api_key".to_string(), TEST_KEY.to_string());

        let context = extract(&request).expect("an api_key query parameter should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "QueryParams");
    }

    #[test]
    fn test_first_message_extraction() {
        let request = handshake_request().with_body(json!({ "auth": { "api_key": TEST_KEY } }));

        let context = extract(&request).expect("credentials in an auth message should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_subprotocol_extraction() {
        let mut headers = handshake_headers();
        headers.insert("Sec-WebSocket-Protocol".to_string(), format!("mcp-auth.{}", TEST_KEY));

        let context = extract(&TransportRequest::from_headers(headers))
            .expect("a token carried in the subprotocol should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Subprotocol");
    }

    #[test]
    fn test_mcp_initialize_message() {
        let init_message = json!({
            "method": "initialize",
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "authentication": { "api_key": TEST_KEY }
                }
            }
        });

        let context = extract(&handshake_request().with_body(init_message))
            .expect("credentials nested in clientInfo should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_has_handshake_auth() {
        let extractor = WebSocketAuthExtractor::default();

        let with_header = TransportRequest::new()
            .with_header("Authorization".to_string(), "Bearer token123".to_string());
        let with_query =
            TransportRequest::new().with_query_param("api_key".to_string(), "token123".to_string());
        let without_auth = TransportRequest::new();

        assert!(extractor.has_handshake_auth(&with_header));
        assert!(extractor.has_handshake_auth(&with_query));
        assert!(!extractor.has_handshake_auth(&without_auth));
    }

    #[test]
    fn test_configuration_presets() {
        let secure = WebSocketAuthConfig::secure();
        assert!(secure.require_handshake_auth);
        assert!(!secure.allow_post_connect_auth);
        assert_eq!(secure.auth_timeout_secs, 10);

        let flexible = WebSocketAuthConfig::flexible();
        assert!(!flexible.require_handshake_auth);
        assert!(flexible.allow_post_connect_auth);

        let development = WebSocketAuthConfig::development();
        assert!(!development.require_handshake_auth);
        assert_eq!(development.auth_timeout_secs, 60);
    }
}