pulseengine_mcp_auth/transport/
http_auth.rs

1//! HTTP Transport Authentication
2//!
3//! This module provides authentication extraction for HTTP-based transports
4//! including REST APIs and Server-Sent Events.
5
6use super::auth_extractors::{
7    AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8    TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use std::collections::HashMap;
12
13/// Configuration for HTTP authentication
14#[derive(Debug, Clone)]
15pub struct HttpAuthConfig {
16    /// Supported authentication methods
17    pub supported_methods: Vec<HttpAuthMethod>,
18
19    /// Require HTTPS for authentication
20    pub require_https: bool,
21
22    /// Allow authentication in query parameters
23    pub allow_query_auth: bool,
24
25    /// Custom header names for authentication
26    pub custom_auth_headers: Vec<String>,
27
28    /// Enable CORS preflight authentication
29    pub enable_cors_auth: bool,
30
31    /// Trusted proxy IPs for X-Forwarded-For
32    pub trusted_proxies: Vec<String>,
33}
34
35impl Default for HttpAuthConfig {
36    fn default() -> Self {
37        Self {
38            supported_methods: vec![HttpAuthMethod::Bearer, HttpAuthMethod::ApiKeyHeader],
39            require_https: false,    // Allow HTTP for development
40            allow_query_auth: false, // Discourage query auth for security
41            custom_auth_headers: vec![],
42            enable_cors_auth: true,
43            trusted_proxies: vec![],
44        }
45    }
46}
47
/// The ways a credential can be presented on an HTTP request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HttpAuthMethod {
    /// `Authorization: Bearer <token>`.
    Bearer,

    /// Credential carried in the `X-API-Key` header.
    ApiKeyHeader,

    /// Credential carried in a query parameter.
    ApiKeyQuery,

    /// `Authorization: Basic <base64>`.
    Basic,

    /// Credential carried in a caller-named header; the payload is that name.
    Custom(String),
}
66
67impl HttpAuthMethod {
68    /// Get the method name as string
69    pub fn name(&self) -> String {
70        match self {
71            Self::Bearer => "Bearer".to_string(),
72            Self::ApiKeyHeader => "X-API-Key".to_string(),
73            Self::ApiKeyQuery => "Query".to_string(),
74            Self::Basic => "Basic".to_string(),
75            Self::Custom(name) => name.clone(),
76        }
77    }
78}
79
80/// HTTP authentication extractor
81pub struct HttpAuthExtractor {
82    config: HttpAuthConfig,
83}
84
85impl HttpAuthExtractor {
86    /// Create a new HTTP authentication extractor
87    pub fn new(config: HttpAuthConfig) -> Self {
88        Self { config }
89    }
90
91    /// Create with default configuration
92    pub fn default() -> Self {
93        Self::new(HttpAuthConfig::default())
94    }
95
96    /// Extract authentication from Authorization header
97    fn extract_authorization_header(
98        &self,
99        headers: &HashMap<String, String>,
100    ) -> AuthExtractionResult {
101        let auth_header = match headers
102            .get("Authorization")
103            .or_else(|| headers.get("authorization"))
104        {
105            Some(header) => header,
106            None => return Ok(None),
107        };
108
109        // Try Bearer token
110        if auth_header.starts_with("Bearer ")
111            && self
112                .config
113                .supported_methods
114                .contains(&HttpAuthMethod::Bearer)
115        {
116            match AuthUtils::extract_bearer_token(auth_header) {
117                Ok(token) => {
118                    AuthUtils::validate_api_key_format(&token)?;
119                    let context =
120                        TransportAuthContext::new(token, "Bearer".to_string(), TransportType::Http);
121                    return Ok(Some(context));
122                }
123                Err(e) => return Err(e),
124            }
125        }
126
127        // Try Basic authentication
128        if auth_header.starts_with("Basic ")
129            && self
130                .config
131                .supported_methods
132                .contains(&HttpAuthMethod::Basic)
133        {
134            return self.extract_basic_auth(auth_header);
135        }
136
137        Err(TransportAuthError::InvalidFormat(format!(
138            "Unsupported Authorization header format: {}",
139            auth_header
140        )))
141    }
142
143    /// Extract Basic authentication
144    fn extract_basic_auth(&self, auth_header: &str) -> AuthExtractionResult {
145        if !auth_header.starts_with("Basic ") {
146            return Err(TransportAuthError::InvalidFormat(
147                "Invalid Basic auth format".to_string(),
148            ));
149        }
150
151        let encoded = &auth_header[6..]; // Skip "Basic "
152        use base64::{Engine as _, engine::general_purpose};
153        let decoded = match general_purpose::STANDARD.decode(encoded) {
154            Ok(bytes) => match String::from_utf8(bytes) {
155                Ok(string) => string,
156                Err(_) => {
157                    return Err(TransportAuthError::InvalidFormat(
158                        "Invalid UTF-8 in Basic auth".to_string(),
159                    ));
160                }
161            },
162            Err(_) => {
163                return Err(TransportAuthError::InvalidFormat(
164                    "Invalid Base64 in Basic auth".to_string(),
165                ));
166            }
167        };
168
169        let parts: Vec<&str> = decoded.splitn(2, ':').collect();
170        if parts.len() != 2 {
171            return Err(TransportAuthError::InvalidFormat(
172                "Basic auth must be username:password".to_string(),
173            ));
174        }
175
176        // For API key auth, we expect username to be the API key and password to be empty or a specific value
177        let api_key = parts[0];
178        AuthUtils::validate_api_key_format(api_key)?;
179
180        let context = TransportAuthContext::new(
181            api_key.to_string(),
182            "Basic".to_string(),
183            TransportType::Http,
184        );
185        Ok(Some(context))
186    }
187
188    /// Extract authentication from X-API-Key header
189    fn extract_api_key_header(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
190        if !self
191            .config
192            .supported_methods
193            .contains(&HttpAuthMethod::ApiKeyHeader)
194        {
195            return Ok(None);
196        }
197
198        if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
199            AuthUtils::validate_api_key_format(&api_key)?;
200            let context =
201                TransportAuthContext::new(api_key, "X-API-Key".to_string(), TransportType::Http);
202            return Ok(Some(context));
203        }
204
205        Ok(None)
206    }
207
208    /// Extract authentication from query parameters
209    fn extract_query_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
210        if !self.config.allow_query_auth
211            || !self
212                .config
213                .supported_methods
214                .contains(&HttpAuthMethod::ApiKeyQuery)
215        {
216            return Ok(None);
217        }
218
219        // Try common query parameter names
220        for param_name in &["api_key", "apikey", "key", "token"] {
221            if let Some(api_key) = request.get_query_param(param_name) {
222                AuthUtils::validate_api_key_format(api_key)?;
223                let context = TransportAuthContext::new(
224                    api_key.clone(),
225                    "Query".to_string(),
226                    TransportType::Http,
227                );
228                return Ok(Some(context));
229            }
230        }
231
232        Ok(None)
233    }
234
235    /// Extract authentication from custom headers
236    fn extract_custom_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
237        for header_name in &self.config.custom_auth_headers {
238            if let Some(value) = headers.get(header_name) {
239                AuthUtils::validate_api_key_format(value)?;
240                let context = TransportAuthContext::new(
241                    value.clone(),
242                    format!("Custom({})", header_name),
243                    TransportType::Http,
244                );
245                return Ok(Some(context));
246            }
247        }
248
249        Ok(None)
250    }
251
252    /// Add HTTP-specific context information
253    fn enrich_context(
254        &self,
255        mut context: TransportAuthContext,
256        request: &TransportRequest,
257    ) -> TransportAuthContext {
258        // Add client IP
259        if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260            context = context.with_client_ip(client_ip);
261        }
262
263        // Add user agent
264        if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265            context = context.with_user_agent(user_agent);
266        }
267
268        // Add HTTP-specific metadata
269        if let Some(host) = request.get_header("Host") {
270            context = context.with_metadata("host".to_string(), host.clone());
271        }
272
273        if let Some(referer) = request.get_header("Referer") {
274            context = context.with_metadata("referer".to_string(), referer.clone());
275        }
276
277        if let Some(origin) = request.get_header("Origin") {
278            context = context.with_metadata("origin".to_string(), origin.clone());
279        }
280
281        context
282    }
283
284    /// Check if request is HTTPS (when required)
285    fn validate_https(&self, request: &TransportRequest) -> Result<(), TransportAuthError> {
286        if !self.config.require_https {
287            return Ok(());
288        }
289
290        // Check various headers that indicate HTTPS
291        let is_https = request
292            .get_header("X-Forwarded-Proto")
293            .map(|proto| proto == "https")
294            .or_else(|| {
295                request
296                    .get_header("X-Scheme")
297                    .map(|scheme| scheme == "https")
298            })
299            .or_else(|| request.metadata.get("is_https").map(|_| true))
300            .unwrap_or(false);
301
302        if !is_https {
303            return Err(TransportAuthError::AuthFailed(
304                "HTTPS required for authentication".to_string(),
305            ));
306        }
307
308        Ok(())
309    }
310}
311
312#[async_trait]
313impl AuthExtractor for HttpAuthExtractor {
314    async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
315        // Validate HTTPS requirement
316        self.validate_https(request)?;
317
318        // Try different authentication methods in order of preference
319
320        // 1. Authorization header (Bearer, Basic)
321        match self.extract_authorization_header(&request.headers) {
322            Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
323            Ok(None) => {}           // Try next method
324            Err(e) => return Err(e), // Propagate validation errors
325        }
326
327        // 2. X-API-Key header
328        match self.extract_api_key_header(&request.headers) {
329            Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
330            Ok(None) => {}           // Try next method
331            Err(e) => return Err(e), // Propagate validation errors
332        }
333
334        // 3. Custom headers
335        match self.extract_custom_headers(&request.headers) {
336            Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
337            Ok(None) => {}           // Try next method
338            Err(e) => return Err(e), // Propagate validation errors
339        }
340
341        // 4. Query parameters (if allowed)
342        match self.extract_query_auth(request) {
343            Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
344            Ok(None) => {}           // Try next method
345            Err(e) => return Err(e), // Propagate validation errors
346        }
347
348        // No authentication found
349        Ok(None)
350    }
351
352    fn transport_type(&self) -> TransportType {
353        TransportType::Http
354    }
355
356    fn can_handle(&self, request: &TransportRequest) -> bool {
357        // HTTP extractor can handle any request with headers
358        !request.headers.is_empty()
359    }
360
361    async fn validate_auth(
362        &self,
363        context: &TransportAuthContext,
364    ) -> Result<(), TransportAuthError> {
365        // Additional HTTP-specific validation can go here
366        if context.credential.is_empty() {
367            return Err(TransportAuthError::InvalidFormat(
368                "Empty credential".to_string(),
369            ));
370        }
371
372        Ok(())
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Well-formed API key shared by the happy-path tests.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Build an owned header map from borrowed name/value pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// Drive `extract_auth` to completion on the current thread.
    fn run(extractor: &HttpAuthExtractor, request: &TransportRequest) -> AuthExtractionResult {
        tokio_test::block_on(extractor.extract_auth(request))
    }

    #[test]
    fn test_bearer_token_extraction() {
        let extractor = HttpAuthExtractor::default();
        let bearer = format!("Bearer {}", TEST_KEY);
        let request =
            TransportRequest::from_headers(header_map(&[("Authorization", bearer.as_str())]));

        let context = run(&extractor, &request)
            .unwrap()
            .expect("bearer credential should be extracted");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "Bearer");
        assert_eq!(context.transport_type, TransportType::Http);
    }

    #[test]
    fn test_api_key_header_extraction() {
        let extractor = HttpAuthExtractor::default();
        let request = TransportRequest::from_headers(header_map(&[("X-API-Key", TEST_KEY)]));

        let context = run(&extractor, &request)
            .unwrap()
            .expect("X-API-Key credential should be extracted");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "X-API-Key");
    }

    #[test]
    fn test_basic_auth_extraction() {
        use base64::{Engine as _, engine::general_purpose};

        let extractor = HttpAuthExtractor::new(HttpAuthConfig {
            supported_methods: vec![HttpAuthMethod::Basic],
            ..Default::default()
        });

        // The API key is the username; the password part is left empty.
        let encoded = general_purpose::STANDARD.encode(format!("{}:", TEST_KEY));
        let basic = format!("Basic {}", encoded);
        let request =
            TransportRequest::from_headers(header_map(&[("Authorization", basic.as_str())]));

        let context = run(&extractor, &request)
            .unwrap()
            .expect("basic credential should be extracted");

        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Basic");
    }

    #[test]
    fn test_query_parameter_extraction() {
        let extractor = HttpAuthExtractor::new(HttpAuthConfig {
            allow_query_auth: true,
            supported_methods: vec![HttpAuthMethod::ApiKeyQuery],
            ..Default::default()
        });

        let request = TransportRequest::new()
            .with_query_param("api_key".to_string(), TEST_KEY.to_string());

        let context = run(&extractor, &request)
            .unwrap()
            .expect("query credential should be extracted");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "Query");
    }

    #[test]
    fn test_no_authentication() {
        let extractor = HttpAuthExtractor::default();
        let request = TransportRequest::new();

        let outcome = run(&extractor, &request).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_invalid_api_key_format() {
        let extractor = HttpAuthExtractor::default();
        // "short" is below the minimum key length.
        let request = TransportRequest::from_headers(header_map(&[("X-API-Key", "short")]));

        // The API key should be found and validation should fail
        match run(&extractor, &request) {
            Err(e) => assert!(e.to_string().contains("too short")),
            Ok(_) => panic!("expected key-format validation to fail"),
        }
    }

    #[test]
    fn test_context_enrichment() {
        let extractor = HttpAuthExtractor::default();
        let request = TransportRequest::from_headers(header_map(&[
            ("X-API-Key", TEST_KEY),
            ("X-Forwarded-For", "192.168.1.100"),
            ("User-Agent", "TestClient/1.0"),
            ("Host", "api.example.com"),
        ]));

        let context = run(&extractor, &request)
            .unwrap()
            .expect("credential should be extracted");

        assert_eq!(context.client_ip.unwrap(), "192.168.1.100");
        assert_eq!(context.user_agent.unwrap(), "TestClient/1.0");
        assert_eq!(context.metadata.get("host").unwrap(), "api.example.com");
    }
}