pulseengine_mcp_auth/transport/
websocket_auth.rs1use super::auth_extractors::{
7 AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8 TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct WebSocketAuthConfig {
17 pub require_handshake_auth: bool,
19
20 pub allow_post_connect_auth: bool,
22
23 pub supported_methods: Vec<WebSocketAuthMethod>,
25
26 pub enable_per_message_auth: bool,
28
29 pub auth_subprotocol: Option<String>,
31
32 pub auth_timeout_secs: u64,
34}
35
36impl Default for WebSocketAuthConfig {
37 fn default() -> Self {
38 Self {
39 require_handshake_auth: true,
40 allow_post_connect_auth: true,
41 supported_methods: vec![
42 WebSocketAuthMethod::HandshakeHeaders,
43 WebSocketAuthMethod::QueryParams,
44 WebSocketAuthMethod::FirstMessage,
45 ],
46 enable_per_message_auth: false,
47 auth_subprotocol: Some("mcp-auth".to_string()),
48 auth_timeout_secs: 30,
49 }
50 }
51}
52
/// Credential sources a [`WebSocketAuthExtractor`] can be configured to probe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketAuthMethod {
    /// Upgrade-request headers: `Authorization: Bearer <token>`, an API-key header, or a
    /// token embedded in `Sec-WebSocket-Protocol`.
    HandshakeHeaders,

    /// Query-string parameters (`api_key`, `apikey`, `key`, `token`, `access_token`).
    QueryParams,

    /// A credential inside the JSON body of the first message, e.g. `auth.api_key` or
    /// `params.clientInfo.authentication.api_key`.
    FirstMessage,

    /// Token carried in `Sec-WebSocket-Protocol`. The extractor in this module does not
    /// check this variant; subprotocol tokens are read under `HandshakeHeaders`.
    Subprotocol,

    /// Authentication attached to every message. Not consulted by the extractor in this
    /// module.
    PerMessage,
}

72pub struct WebSocketAuthExtractor {
74 config: WebSocketAuthConfig,
75}
76
77impl WebSocketAuthExtractor {
78 pub fn new(config: WebSocketAuthConfig) -> Self {
80 Self { config }
81 }
82
83 pub fn default() -> Self {
85 Self::new(WebSocketAuthConfig::default())
86 }
87
88 fn extract_handshake_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
90 if !self
91 .config
92 .supported_methods
93 .contains(&WebSocketAuthMethod::HandshakeHeaders)
94 {
95 return Ok(None);
96 }
97
98 if let Some(auth_header) = headers
100 .get("Authorization")
101 .or_else(|| headers.get("authorization"))
102 {
103 if auth_header.starts_with("Bearer ") {
104 match AuthUtils::extract_bearer_token(auth_header) {
105 Ok(token) => {
106 AuthUtils::validate_api_key_format(&token)?;
107 let context = TransportAuthContext::new(
108 token,
109 "HandshakeHeaders".to_string(),
110 TransportType::WebSocket,
111 );
112 return Ok(Some(context));
113 }
114 Err(e) => return Err(e),
115 }
116 }
117 }
118
119 if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
121 AuthUtils::validate_api_key_format(&api_key)?;
122 let context = TransportAuthContext::new(
123 api_key,
124 "HandshakeHeaders".to_string(),
125 TransportType::WebSocket,
126 );
127 return Ok(Some(context));
128 }
129
130 if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") {
132 if let Some(auth_token) = self.extract_from_subprotocol(api_key) {
133 AuthUtils::validate_api_key_format(&auth_token)?;
134 let context = TransportAuthContext::new(
135 auth_token,
136 "Subprotocol".to_string(),
137 TransportType::WebSocket,
138 );
139 return Ok(Some(context));
140 }
141 }
142
143 Ok(None)
144 }
145
146 fn extract_query_params(&self, request: &TransportRequest) -> AuthExtractionResult {
148 if !self
149 .config
150 .supported_methods
151 .contains(&WebSocketAuthMethod::QueryParams)
152 {
153 return Ok(None);
154 }
155
156 for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
158 if let Some(api_key) = request.get_query_param(param_name) {
159 AuthUtils::validate_api_key_format(api_key)?;
160 let context = TransportAuthContext::new(
161 api_key.clone(),
162 "QueryParams".to_string(),
163 TransportType::WebSocket,
164 );
165 return Ok(Some(context));
166 }
167 }
168
169 Ok(None)
170 }
171
172 fn extract_first_message(&self, request: &TransportRequest) -> AuthExtractionResult {
174 if !self
175 .config
176 .supported_methods
177 .contains(&WebSocketAuthMethod::FirstMessage)
178 {
179 return Ok(None);
180 }
181
182 if let Some(body) = &request.body {
183 if let Some(auth_data) = self.find_auth_in_message(body) {
185 AuthUtils::validate_api_key_format(&auth_data)?;
186 let context = TransportAuthContext::new(
187 auth_data,
188 "FirstMessage".to_string(),
189 TransportType::WebSocket,
190 );
191 return Ok(Some(context));
192 }
193 }
194
195 Ok(None)
196 }
197
198 fn extract_from_subprotocol(&self, subprotocol: &str) -> Option<String> {
200 if let Some(auth_protocol) = &self.config.auth_subprotocol {
202 let prefix = format!("{}.", auth_protocol);
203 if let Some(token) = subprotocol.strip_prefix(&prefix) {
204 return Some(token.to_string());
205 }
206
207 let prefix_dash = format!("{}-", auth_protocol);
208 if let Some(token) = subprotocol.strip_prefix(&prefix_dash) {
209 return Some(token.to_string());
210 }
211 }
212
213 None
214 }
215
216 fn find_auth_in_message(&self, message: &Value) -> Option<String> {
218 if let Some(auth) = message.get("auth") {
220 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
221 return Some(api_key.to_string());
222 }
223 if let Some(token) = auth.get("token").and_then(|v| v.as_str()) {
224 return Some(token.to_string());
225 }
226 }
227
228 if let Some(params) = message.get("params") {
230 if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
231 return Some(api_key.to_string());
232 }
233
234 if let Some(client_info) = params.get("clientInfo") {
236 if let Some(auth) = client_info.get("authentication") {
237 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
238 return Some(api_key.to_string());
239 }
240 }
241 }
242 }
243
244 if let Some(api_key) = message.get("api_key").and_then(|v| v.as_str()) {
246 return Some(api_key.to_string());
247 }
248
249 None
250 }
251
252 fn enrich_context(
254 &self,
255 mut context: TransportAuthContext,
256 request: &TransportRequest,
257 ) -> TransportAuthContext {
258 if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260 context = context.with_client_ip(client_ip);
261 }
262
263 if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265 context = context.with_user_agent(user_agent);
266 }
267
268 if let Some(origin) = request.get_header("Origin") {
270 context = context.with_metadata("origin".to_string(), origin.clone());
271 }
272
273 if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
274 context = context.with_metadata("protocols".to_string(), protocols.clone());
275 }
276
277 if let Some(version) = request.get_header("Sec-WebSocket-Version") {
278 context = context.with_metadata("ws_version".to_string(), version.clone());
279 }
280
281 context
282 }
283
284 pub fn has_handshake_auth(&self, request: &TransportRequest) -> bool {
286 if request.headers.contains_key("Authorization")
288 || AuthUtils::extract_api_key_header(&request.headers).is_some()
289 {
290 return true;
291 }
292
293 for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
295 if request.query_params.contains_key(*param_name) {
296 return true;
297 }
298 }
299
300 if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
302 if let Some(auth_protocol) = &self.config.auth_subprotocol {
303 if protocols.contains(auth_protocol) {
304 return true;
305 }
306 }
307 }
308
309 false
310 }
311}
312
313#[async_trait]
314impl AuthExtractor for WebSocketAuthExtractor {
315 async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
316 if let Ok(Some(context)) = self.extract_handshake_headers(&request.headers) {
320 return Ok(Some(self.enrich_context(context, request)));
321 }
322
323 if let Ok(Some(context)) = self.extract_query_params(request) {
325 return Ok(Some(self.enrich_context(context, request)));
326 }
327
328 if let Ok(Some(context)) = self.extract_first_message(request) {
330 return Ok(Some(self.enrich_context(context, request)));
331 }
332
333 if self.config.require_handshake_auth && !self.config.allow_post_connect_auth {
335 return Err(TransportAuthError::NoAuth);
336 }
337
338 Ok(None)
339 }
340
341 fn transport_type(&self) -> TransportType {
342 TransportType::WebSocket
343 }
344
345 fn can_handle(&self, request: &TransportRequest) -> bool {
346 request.headers.contains_key("Sec-WebSocket-Key")
348 || request.headers.contains_key("Upgrade")
349 || request.metadata.contains_key("websocket")
350 }
351
352 async fn validate_auth(
353 &self,
354 context: &TransportAuthContext,
355 ) -> Result<(), TransportAuthError> {
356 if context.credential.is_empty() {
358 return Err(TransportAuthError::InvalidFormat(
359 "Empty credential".to_string(),
360 ));
361 }
362
363 if context.method == "QueryParams" {
365 tracing::warn!("WebSocket authentication via query parameters is less secure - consider using headers");
366 }
367
368 Ok(())
369 }
370}
371
372impl WebSocketAuthConfig {
374 pub fn secure() -> Self {
376 Self {
377 require_handshake_auth: true,
378 allow_post_connect_auth: false,
379 supported_methods: vec![WebSocketAuthMethod::HandshakeHeaders],
380 enable_per_message_auth: false,
381 auth_subprotocol: Some("mcp-auth".to_string()),
382 auth_timeout_secs: 10,
383 }
384 }
385
386 pub fn flexible() -> Self {
388 Self {
389 require_handshake_auth: false,
390 allow_post_connect_auth: true,
391 supported_methods: vec![
392 WebSocketAuthMethod::HandshakeHeaders,
393 WebSocketAuthMethod::QueryParams,
394 WebSocketAuthMethod::FirstMessage,
395 ],
396 enable_per_message_auth: false,
397 auth_subprotocol: Some("mcp-auth".to_string()),
398 auth_timeout_secs: 30,
399 }
400 }
401
402 pub fn development() -> Self {
404 Self {
405 require_handshake_auth: false,
406 allow_post_connect_auth: true,
407 supported_methods: vec![
408 WebSocketAuthMethod::HandshakeHeaders,
409 WebSocketAuthMethod::QueryParams,
410 WebSocketAuthMethod::FirstMessage,
411 WebSocketAuthMethod::Subprotocol,
412 ],
413 enable_per_message_auth: false,
414 auth_subprotocol: Some("mcp-auth".to_string()),
415 auth_timeout_secs: 60,
416 }
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Well-formed credential shared by the extraction tests.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Drives `extract_auth` to completion and returns the extracted context. Fails the
    /// test if extraction errors or finds nothing.
    fn extract_context(
        extractor: &WebSocketAuthExtractor,
        request: &TransportRequest,
    ) -> TransportAuthContext {
        let outcome = tokio_test::block_on(extractor.extract_auth(request)).unwrap();
        assert!(outcome.is_some());
        outcome.unwrap()
    }

    /// A request carrying only the `Sec-WebSocket-Key` upgrade header.
    fn upgrade_request() -> TransportRequest {
        TransportRequest::new()
            .with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
    }

    #[test]
    fn test_handshake_header_extraction() {
        let headers: HashMap<String, String> = vec![
            ("Authorization".to_string(), format!("Bearer {}", TEST_KEY)),
            ("Sec-WebSocket-Key".to_string(), "test-key".to_string()),
        ]
        .into_iter()
        .collect();

        let context = extract_context(
            &WebSocketAuthExtractor::default(),
            &TransportRequest::from_headers(headers),
        );

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "HandshakeHeaders");
        assert_eq!(context.transport_type, TransportType::WebSocket);
    }

    #[test]
    fn test_query_parameter_extraction() {
        let request =
            upgrade_request().with_query_param("api_key".to_string(), TEST_KEY.to_string());

        let context = extract_context(&WebSocketAuthExtractor::default(), &request);

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "QueryParams");
    }

    #[test]
    fn test_first_message_extraction() {
        let request = upgrade_request().with_body(json!({
            "auth": { "api_key": TEST_KEY }
        }));

        let context = extract_context(&WebSocketAuthExtractor::default(), &request);

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_subprotocol_extraction() {
        let headers: HashMap<String, String> = vec![
            ("Sec-WebSocket-Protocol".to_string(), format!("mcp-auth.{}", TEST_KEY)),
            ("Sec-WebSocket-Key".to_string(), "test-key".to_string()),
        ]
        .into_iter()
        .collect();

        let context = extract_context(
            &WebSocketAuthExtractor::default(),
            &TransportRequest::from_headers(headers),
        );

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Subprotocol");
    }

    #[test]
    fn test_mcp_initialize_message() {
        let request = upgrade_request().with_body(json!({
            "method": "initialize",
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "authentication": { "api_key": TEST_KEY }
                }
            }
        }));

        let context = extract_context(&WebSocketAuthExtractor::default(), &request);

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_has_handshake_auth() {
        let extractor = WebSocketAuthExtractor::default();

        let with_bearer = TransportRequest::new()
            .with_header("Authorization".to_string(), "Bearer token123".to_string());
        let with_query =
            TransportRequest::new().with_query_param("api_key".to_string(), "token123".to_string());
        let bare = TransportRequest::new();

        assert!(extractor.has_handshake_auth(&with_bearer));
        assert!(extractor.has_handshake_auth(&with_query));
        assert!(!extractor.has_handshake_auth(&bare));
    }

    #[test]
    fn test_configuration_presets() {
        let secure = WebSocketAuthConfig::secure();
        assert!(secure.require_handshake_auth);
        assert!(!secure.allow_post_connect_auth);
        assert_eq!(secure.auth_timeout_secs, 10);

        let flexible = WebSocketAuthConfig::flexible();
        assert!(!flexible.require_handshake_auth);
        assert!(flexible.allow_post_connect_auth);

        let development = WebSocketAuthConfig::development();
        assert!(!development.require_handshake_auth);
        assert_eq!(development.auth_timeout_secs, 60);
    }
}