pulseengine_mcp_auth/transport/
websocket_auth.rs1use super::auth_extractors::{
7 AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8 TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct WebSocketAuthConfig {
17 pub require_handshake_auth: bool,
19
20 pub allow_post_connect_auth: bool,
22
23 pub supported_methods: Vec<WebSocketAuthMethod>,
25
26 pub enable_per_message_auth: bool,
28
29 pub auth_subprotocol: Option<String>,
31
32 pub auth_timeout_secs: u64,
34}
35
36impl Default for WebSocketAuthConfig {
37 fn default() -> Self {
38 Self {
39 require_handshake_auth: true,
40 allow_post_connect_auth: true,
41 supported_methods: vec![
42 WebSocketAuthMethod::HandshakeHeaders,
43 WebSocketAuthMethod::QueryParams,
44 WebSocketAuthMethod::FirstMessage,
45 ],
46 enable_per_message_auth: false,
47 auth_subprotocol: Some("mcp-auth".to_string()),
48 auth_timeout_secs: 30,
49 }
50 }
51}
52
/// A place where a WebSocket client may supply its credential.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketAuthMethod {
    /// `Authorization: Bearer …` or an API-key header on the upgrade request.
    HandshakeHeaders,

    /// A query-string parameter (`api_key`, `apikey`, `key`, `token` or `access_token`).
    QueryParams,

    /// A field inside the JSON body of the first message (for example `auth.api_key`).
    FirstMessage,

    /// A token embedded in a `Sec-WebSocket-Protocol` entry, written as
    /// `<auth_subprotocol>.<token>` or `<auth_subprotocol>-<token>`.
    Subprotocol,

    /// Credentials attached to every message.
    /// NOTE(review): not consulted by the extractor in this file.
    PerMessage,
}
71
72pub struct WebSocketAuthExtractor {
74 config: WebSocketAuthConfig,
75}
76
77impl WebSocketAuthExtractor {
78 pub fn new(config: WebSocketAuthConfig) -> Self {
80 Self { config }
81 }
82
83 pub fn default() -> Self {
85 Self::new(WebSocketAuthConfig::default())
86 }
87
88 fn extract_handshake_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
90 if !self
91 .config
92 .supported_methods
93 .contains(&WebSocketAuthMethod::HandshakeHeaders)
94 {
95 return Ok(None);
96 }
97
98 if let Some(auth_header) = headers
100 .get("Authorization")
101 .or_else(|| headers.get("authorization"))
102 {
103 if auth_header.starts_with("Bearer ") {
104 match AuthUtils::extract_bearer_token(auth_header) {
105 Ok(token) => {
106 AuthUtils::validate_api_key_format(&token)?;
107 let context = TransportAuthContext::new(
108 token,
109 "HandshakeHeaders".to_string(),
110 TransportType::WebSocket,
111 );
112 return Ok(Some(context));
113 }
114 Err(e) => return Err(e),
115 }
116 }
117 }
118
119 if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
121 AuthUtils::validate_api_key_format(&api_key)?;
122 let context = TransportAuthContext::new(
123 api_key,
124 "HandshakeHeaders".to_string(),
125 TransportType::WebSocket,
126 );
127 return Ok(Some(context));
128 }
129
130 if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") {
132 if let Some(auth_token) = self.extract_from_subprotocol(api_key) {
133 AuthUtils::validate_api_key_format(&auth_token)?;
134 let context = TransportAuthContext::new(
135 auth_token,
136 "Subprotocol".to_string(),
137 TransportType::WebSocket,
138 );
139 return Ok(Some(context));
140 }
141 }
142
143 Ok(None)
144 }
145
146 fn extract_query_params(&self, request: &TransportRequest) -> AuthExtractionResult {
148 if !self
149 .config
150 .supported_methods
151 .contains(&WebSocketAuthMethod::QueryParams)
152 {
153 return Ok(None);
154 }
155
156 for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
158 if let Some(api_key) = request.get_query_param(param_name) {
159 AuthUtils::validate_api_key_format(api_key)?;
160 let context = TransportAuthContext::new(
161 api_key.clone(),
162 "QueryParams".to_string(),
163 TransportType::WebSocket,
164 );
165 return Ok(Some(context));
166 }
167 }
168
169 Ok(None)
170 }
171
172 fn extract_first_message(&self, request: &TransportRequest) -> AuthExtractionResult {
174 if !self
175 .config
176 .supported_methods
177 .contains(&WebSocketAuthMethod::FirstMessage)
178 {
179 return Ok(None);
180 }
181
182 if let Some(body) = &request.body {
183 if let Some(auth_data) = self.find_auth_in_message(body) {
185 AuthUtils::validate_api_key_format(&auth_data)?;
186 let context = TransportAuthContext::new(
187 auth_data,
188 "FirstMessage".to_string(),
189 TransportType::WebSocket,
190 );
191 return Ok(Some(context));
192 }
193 }
194
195 Ok(None)
196 }
197
198 fn extract_from_subprotocol(&self, subprotocol: &str) -> Option<String> {
200 if let Some(auth_protocol) = &self.config.auth_subprotocol {
202 let prefix = format!("{}.", auth_protocol);
203 if let Some(token) = subprotocol.strip_prefix(&prefix) {
204 return Some(token.to_string());
205 }
206
207 let prefix_dash = format!("{}-", auth_protocol);
208 if let Some(token) = subprotocol.strip_prefix(&prefix_dash) {
209 return Some(token.to_string());
210 }
211 }
212
213 None
214 }
215
216 fn find_auth_in_message(&self, message: &Value) -> Option<String> {
218 if let Some(auth) = message.get("auth") {
220 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
221 return Some(api_key.to_string());
222 }
223 if let Some(token) = auth.get("token").and_then(|v| v.as_str()) {
224 return Some(token.to_string());
225 }
226 }
227
228 if let Some(params) = message.get("params") {
230 if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
231 return Some(api_key.to_string());
232 }
233
234 if let Some(client_info) = params.get("clientInfo") {
236 if let Some(auth) = client_info.get("authentication") {
237 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
238 return Some(api_key.to_string());
239 }
240 }
241 }
242 }
243
244 if let Some(api_key) = message.get("api_key").and_then(|v| v.as_str()) {
246 return Some(api_key.to_string());
247 }
248
249 None
250 }
251
252 fn enrich_context(
254 &self,
255 mut context: TransportAuthContext,
256 request: &TransportRequest,
257 ) -> TransportAuthContext {
258 if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260 context = context.with_client_ip(client_ip);
261 }
262
263 if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265 context = context.with_user_agent(user_agent);
266 }
267
268 if let Some(origin) = request.get_header("Origin") {
270 context = context.with_metadata("origin".to_string(), origin.clone());
271 }
272
273 if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
274 context = context.with_metadata("protocols".to_string(), protocols.clone());
275 }
276
277 if let Some(version) = request.get_header("Sec-WebSocket-Version") {
278 context = context.with_metadata("ws_version".to_string(), version.clone());
279 }
280
281 context
282 }
283
284 pub fn has_handshake_auth(&self, request: &TransportRequest) -> bool {
286 if request.headers.contains_key("Authorization")
288 || AuthUtils::extract_api_key_header(&request.headers).is_some()
289 {
290 return true;
291 }
292
293 for param_name in &["api_key", "apikey", "key", "token", "access_token"] {
295 if request.query_params.contains_key(*param_name) {
296 return true;
297 }
298 }
299
300 if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") {
302 if let Some(auth_protocol) = &self.config.auth_subprotocol {
303 if protocols.contains(auth_protocol) {
304 return true;
305 }
306 }
307 }
308
309 false
310 }
311}
312
313#[async_trait]
314impl AuthExtractor for WebSocketAuthExtractor {
315 async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
316 if let Ok(Some(context)) = self.extract_handshake_headers(&request.headers) {
320 return Ok(Some(self.enrich_context(context, request)));
321 }
322
323 if let Ok(Some(context)) = self.extract_query_params(request) {
325 return Ok(Some(self.enrich_context(context, request)));
326 }
327
328 if let Ok(Some(context)) = self.extract_first_message(request) {
330 return Ok(Some(self.enrich_context(context, request)));
331 }
332
333 if self.config.require_handshake_auth && !self.config.allow_post_connect_auth {
335 return Err(TransportAuthError::NoAuth);
336 }
337
338 Ok(None)
339 }
340
341 fn transport_type(&self) -> TransportType {
342 TransportType::WebSocket
343 }
344
345 fn can_handle(&self, request: &TransportRequest) -> bool {
346 request.headers.contains_key("Sec-WebSocket-Key")
348 || request.headers.contains_key("Upgrade")
349 || request.metadata.contains_key("websocket")
350 }
351
352 async fn validate_auth(
353 &self,
354 context: &TransportAuthContext,
355 ) -> Result<(), TransportAuthError> {
356 if context.credential.is_empty() {
358 return Err(TransportAuthError::InvalidFormat(
359 "Empty credential".to_string(),
360 ));
361 }
362
363 if context.method == "QueryParams" {
365 tracing::warn!(
366 "WebSocket authentication via query parameters is less secure - consider using headers"
367 );
368 }
369
370 Ok(())
371 }
372}
373
374impl WebSocketAuthConfig {
376 pub fn secure() -> Self {
378 Self {
379 require_handshake_auth: true,
380 allow_post_connect_auth: false,
381 supported_methods: vec![WebSocketAuthMethod::HandshakeHeaders],
382 enable_per_message_auth: false,
383 auth_subprotocol: Some("mcp-auth".to_string()),
384 auth_timeout_secs: 10,
385 }
386 }
387
388 pub fn flexible() -> Self {
390 Self {
391 require_handshake_auth: false,
392 allow_post_connect_auth: true,
393 supported_methods: vec![
394 WebSocketAuthMethod::HandshakeHeaders,
395 WebSocketAuthMethod::QueryParams,
396 WebSocketAuthMethod::FirstMessage,
397 ],
398 enable_per_message_auth: false,
399 auth_subprotocol: Some("mcp-auth".to_string()),
400 auth_timeout_secs: 30,
401 }
402 }
403
404 pub fn development() -> Self {
406 Self {
407 require_handshake_auth: false,
408 allow_post_connect_auth: true,
409 supported_methods: vec![
410 WebSocketAuthMethod::HandshakeHeaders,
411 WebSocketAuthMethod::QueryParams,
412 WebSocketAuthMethod::FirstMessage,
413 WebSocketAuthMethod::Subprotocol,
414 ],
415 enable_per_message_auth: false,
416 auth_subprotocol: Some("mcp-auth".to_string()),
417 auth_timeout_secs: 60,
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Credential used by every extraction fixture below.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Runs a default-configured extractor over `request`, panicking if extraction errors.
    fn extract(request: &TransportRequest) -> Option<TransportAuthContext> {
        let extractor = WebSocketAuthExtractor::default();
        tokio_test::block_on(extractor.extract_auth(request)).unwrap()
    }

    /// Header map for a WebSocket handshake (`Sec-WebSocket-Key`) plus one extra header.
    fn handshake_headers(name: &str, value: &str) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("Sec-WebSocket-Key".to_string(), "test-key".to_string());
        headers.insert(name.to_string(), value.to_string());
        headers
    }

    /// Builder-style handshake request carrying only `Sec-WebSocket-Key`.
    fn handshake_request() -> TransportRequest {
        TransportRequest::new().with_header("Sec-WebSocket-Key".to_string(), "test-key".to_string())
    }

    #[test]
    fn test_handshake_header_extraction() {
        let bearer = format!("Bearer {}", TEST_KEY);
        let request = TransportRequest::from_headers(handshake_headers("Authorization", &bearer));

        let context = extract(&request).expect("credential from the Authorization header");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "HandshakeHeaders");
        assert_eq!(context.transport_type, TransportType::WebSocket);
    }

    #[test]
    fn test_query_parameter_extraction() {
        let request =
            handshake_request().with_query_param("api_key".to_string(), TEST_KEY.to_string());

        let context = extract(&request).expect("credential from the query string");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "QueryParams");
    }

    #[test]
    fn test_first_message_extraction() {
        let request = handshake_request().with_body(json!({
            "auth": { "api_key": TEST_KEY }
        }));

        let context = extract(&request).expect("credential from the message body");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_subprotocol_extraction() {
        let protocol = format!("mcp-auth.{}", TEST_KEY);
        let request =
            TransportRequest::from_headers(handshake_headers("Sec-WebSocket-Protocol", &protocol));

        let context = extract(&request).expect("credential from Sec-WebSocket-Protocol");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Subprotocol");
    }

    #[test]
    fn test_mcp_initialize_message() {
        let request = handshake_request().with_body(json!({
            "method": "initialize",
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "authentication": { "api_key": TEST_KEY }
                }
            }
        }));

        let context = extract(&request).expect("credential from clientInfo.authentication");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "FirstMessage");
    }

    #[test]
    fn test_has_handshake_auth() {
        let extractor = WebSocketAuthExtractor::default();

        let with_bearer = TransportRequest::new()
            .with_header("Authorization".to_string(), "Bearer token123".to_string());
        let with_query_key =
            TransportRequest::new().with_query_param("api_key".to_string(), "token123".to_string());
        let bare = TransportRequest::new();

        assert!(extractor.has_handshake_auth(&with_bearer));
        assert!(extractor.has_handshake_auth(&with_query_key));
        assert!(!extractor.has_handshake_auth(&bare));
    }

    #[test]
    fn test_configuration_presets() {
        let secure = WebSocketAuthConfig::secure();
        assert!(secure.require_handshake_auth);
        assert!(!secure.allow_post_connect_auth);
        assert_eq!(secure.auth_timeout_secs, 10);

        let flexible = WebSocketAuthConfig::flexible();
        assert!(!flexible.require_handshake_auth);
        assert!(flexible.allow_post_connect_auth);

        let development = WebSocketAuthConfig::development();
        assert!(!development.require_handshake_auth);
        assert_eq!(development.auth_timeout_secs, 60);
    }
}