pulseengine_mcp_security_middleware/
middleware.rs

1//! Axum middleware implementation for MCP security
2
3use crate::auth::{ApiKeyValidator, AuthContext, TokenValidator};
4use crate::config::SecurityConfig;
5use crate::error::{SecurityError, SecurityResult};
6use crate::utils::generate_request_id;
7use axum::{
8    extract::Request,
9    http::{HeaderMap, HeaderValue, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use tracing::{debug, info, warn};
17
18/// Main security middleware
19#[derive(Debug, Clone)]
20pub struct SecurityMiddleware {
21    config: SecurityConfig,
22    api_key_validator: Option<ApiKeyValidator>,
23    token_validator: Option<Arc<TokenValidator>>,
24    rate_limiter: Arc<Mutex<RateLimiter>>,
25}
26
27impl SecurityMiddleware {
28    /// Create a new security middleware
29    pub fn new(
30        config: SecurityConfig,
31        api_key_validator: Option<ApiKeyValidator>,
32        token_validator: Option<TokenValidator>,
33    ) -> Self {
34        let rate_limiter = Arc::new(Mutex::new(RateLimiter::new(
35            config.settings.rate_limit.max_requests,
36            config.settings.rate_limit.window_duration,
37        )));
38
39        Self {
40            config,
41            api_key_validator,
42            token_validator: token_validator.map(Arc::new),
43            rate_limiter,
44        }
45    }
46
47    /// Authenticate a request
48    async fn authenticate(&self, headers: &HeaderMap) -> SecurityResult<Option<AuthContext>> {
49        // If authentication is not required, return None
50        if !self.config.settings.require_authentication {
51            return Ok(None);
52        }
53
54        // Try API key authentication first
55        if let Some(ref validator) = self.api_key_validator {
56            if let Some(api_key) = extract_api_key(headers) {
57                match validator.validate_api_key(&api_key) {
58                    Ok(user_id) => {
59                        let auth_context = AuthContext::new(user_id)
60                            .with_api_key(api_key)
61                            .with_role("api_user");
62                        return Ok(Some(auth_context));
63                    }
64                    Err(e) => {
65                        debug!("API key validation failed: {}", e);
66                    }
67                }
68            }
69        }
70
71        // Try JWT token authentication
72        if let Some(ref validator) = self.token_validator {
73            if let Some(token) = extract_bearer_token(headers) {
74                match validator.validate_token(&token) {
75                    Ok(claims) => {
76                        let auth_context =
77                            AuthContext::new(claims.sub.clone()).with_jwt_claims(claims);
78                        return Ok(Some(auth_context));
79                    }
80                    Err(e) => {
81                        debug!("JWT validation failed: {}", e);
82                    }
83                }
84            }
85        }
86
87        // No valid authentication found
88        Err(SecurityError::MissingAuth)
89    }
90
91    /// Check rate limiting
92    fn check_rate_limit(&self, client_id: &str) -> SecurityResult<()> {
93        if !self.config.settings.rate_limit.enabled {
94            return Ok(());
95        }
96
97        let mut limiter = self.rate_limiter.lock().unwrap();
98        if !limiter.allow_request(client_id) {
99            return Err(SecurityError::RateLimitExceeded);
100        }
101
102        Ok(())
103    }
104
105    /// Process the request
106    pub async fn process(&self, request: Request, next: Next) -> Result<Response, StatusCode> {
107        let request_id = generate_request_id();
108        let start_time = Instant::now();
109
110        // Extract client identifier for rate limiting
111        let client_id = extract_client_id(&request);
112
113        debug!(
114            "Processing request {} from client {}",
115            request_id, client_id
116        );
117
118        // Check rate limiting
119        if let Err(e) = self.check_rate_limit(&client_id) {
120            warn!("Rate limit exceeded for client {}: {}", client_id, e);
121            return Err(StatusCode::TOO_MANY_REQUESTS);
122        }
123
124        // Authenticate the request
125        let auth_context = match self.authenticate(request.headers()).await {
126            Ok(auth_context) => auth_context,
127            Err(SecurityError::MissingAuth) => {
128                if self.config.settings.require_authentication {
129                    warn!(
130                        "Authentication required but not provided for request {}",
131                        request_id
132                    );
133                    return Err(StatusCode::UNAUTHORIZED);
134                } else {
135                    None
136                }
137            }
138            Err(e) => {
139                warn!("Authentication failed for request {}: {}", request_id, e);
140                return match e {
141                    SecurityError::InvalidApiKey => Err(StatusCode::UNAUTHORIZED),
142                    SecurityError::TokenExpired => Err(StatusCode::UNAUTHORIZED),
143                    SecurityError::InvalidToken(_) => Err(StatusCode::UNAUTHORIZED),
144                    _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
145                };
146            }
147        };
148
149        // HTTPS enforcement
150        if self.config.settings.require_https && !is_https_request(&request) {
151            warn!("HTTPS required but request {} is not secure", request_id);
152            return Err(StatusCode::FORBIDDEN);
153        }
154
155        // Add auth context to request extensions if available
156        let mut request = request;
157        if let Some(auth_context) = auth_context {
158            request.extensions_mut().insert(auth_context.clone());
159            info!(
160                "Authenticated request {} as user {} with roles {:?}",
161                request_id, auth_context.user_id, auth_context.roles
162            );
163        }
164
165        // Add request ID to extensions
166        request
167            .extensions_mut()
168            .insert(RequestId(request_id.clone()));
169
170        // Process the request
171        let mut response = next.run(request).await;
172
173        // Add security headers
174        add_security_headers(&mut response, &self.config);
175
176        // Add request ID to response headers
177        response.headers_mut().insert(
178            "x-request-id",
179            HeaderValue::from_str(&request_id)
180                .unwrap_or_else(|_| HeaderValue::from_static("invalid")),
181        );
182
183        // Audit logging
184        if self.config.settings.enable_audit_logging {
185            let duration = start_time.elapsed();
186            info!(
187                "Request {} completed in {:?} with status {}",
188                request_id,
189                duration,
190                response.status()
191            );
192        }
193
194        Ok(response)
195    }
196}
197
/// Request ID wrapper stored in request extensions.
///
/// Inserted by `SecurityMiddleware::process` so downstream handlers can
/// correlate their logs with the `x-request-id` response header. Comparable
/// and hashable so it can be matched or used as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
201
202/// Extract API key from request headers
203fn extract_api_key(headers: &HeaderMap) -> Option<String> {
204    // Try Authorization header first
205    if let Some(auth_header) = headers.get("authorization") {
206        if let Ok(auth_str) = auth_header.to_str() {
207            if let Some(key) = auth_str.strip_prefix("ApiKey ") {
208                return Some(key.to_string());
209            }
210            if let Some(key) = auth_str.strip_prefix("Bearer ") {
211                if key.starts_with("mcp_") {
212                    return Some(key.to_string());
213                }
214            }
215        }
216    }
217
218    // Try X-API-Key header
219    if let Some(key_header) = headers.get("x-api-key") {
220        if let Ok(key_str) = key_header.to_str() {
221            return Some(key_str.to_string());
222        }
223    }
224
225    None
226}
227
228/// Extract Bearer token from request headers
229fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
230    if let Some(auth_header) = headers.get("authorization") {
231        if let Ok(auth_str) = auth_header.to_str() {
232            if let Some(token) = auth_str.strip_prefix("Bearer ") {
233                // Make sure it's not an API key
234                if !token.starts_with("mcp_") {
235                    return Some(token.to_string());
236                }
237            }
238        }
239    }
240
241    None
242}
243
244/// Extract client identifier for rate limiting
245fn extract_client_id(request: &Request) -> String {
246    // Try to get client IP from headers (proxy headers)
247    let headers = request.headers();
248
249    if let Some(forwarded_for) = headers.get("x-forwarded-for") {
250        if let Ok(ip_str) = forwarded_for.to_str() {
251            if let Some(first_ip) = ip_str.split(',').next() {
252                return first_ip.trim().to_string();
253            }
254        }
255    }
256
257    if let Some(real_ip) = headers.get("x-real-ip") {
258        if let Ok(ip_str) = real_ip.to_str() {
259            return ip_str.to_string();
260        }
261    }
262
263    // Fallback to connection info (if available)
264    // This is simplified - in a real implementation you'd extract from connection info
265    "unknown".to_string()
266}
267
268/// Check if request is HTTPS
269fn is_https_request(request: &Request) -> bool {
270    // Check scheme if available in URI
271    if request.uri().scheme_str() == Some("https") {
272        return true;
273    }
274
275    // Check forwarded protocol headers (common in proxy setups)
276    let headers = request.headers();
277
278    if let Some(forwarded_proto) = headers.get("x-forwarded-proto") {
279        if let Ok(proto_str) = forwarded_proto.to_str() {
280            return proto_str.to_lowercase() == "https";
281        }
282    }
283
284    if let Some(forwarded_ssl) = headers.get("x-forwarded-ssl") {
285        if let Ok(ssl_str) = forwarded_ssl.to_str() {
286            return ssl_str.to_lowercase() == "on";
287        }
288    }
289
290    // For development, assume localhost connections are acceptable
291    if let Some(host) = headers.get("host") {
292        if let Ok(host_str) = host.to_str() {
293            if host_str.starts_with("localhost") || host_str.starts_with("127.0.0.1") {
294                return true;
295            }
296        }
297    }
298
299    false
300}
301
302/// Add security headers to response
303fn add_security_headers(response: &mut Response, config: &SecurityConfig) {
304    let headers = response.headers_mut();
305
306    // Content Security Policy
307    headers.insert(
308        "content-security-policy",
309        HeaderValue::from_static("default-src 'self'"),
310    );
311
312    // X-Frame-Options
313    headers.insert("x-frame-options", HeaderValue::from_static("DENY"));
314
315    // X-Content-Type-Options
316    headers.insert(
317        "x-content-type-options",
318        HeaderValue::from_static("nosniff"),
319    );
320
321    // Referrer Policy
322    headers.insert(
323        "referrer-policy",
324        HeaderValue::from_static("strict-origin-when-cross-origin"),
325    );
326
327    // HTTPS enforcement
328    if config.settings.require_https {
329        headers.insert(
330            "strict-transport-security",
331            HeaderValue::from_static("max-age=31536000; includeSubDomains"),
332        );
333    }
334
335    // Server identification
336    headers.insert(
337        "server",
338        HeaderValue::from_static("MCP-Security-Middleware"),
339    );
340}
341
/// Simple fixed-window rate limiter keyed by client identifier.
///
/// Each client may make at most `max_requests` requests per window. A client's
/// window opens on its first request and restarts on the first request made
/// after `window_duration` has elapsed.
#[derive(Debug)]
struct RateLimiter {
    /// Requests allowed per client per window.
    max_requests: u32,
    /// Length of each client's window.
    window_duration: Duration,
    /// Per-client counters.
    clients: HashMap<String, ClientRateLimit>,
    /// When stale entries were last purged. Used to throttle cleanup so a large
    /// table of still-active clients is not rescanned on every request.
    last_cleanup: Instant,
}

/// Request counter for one client's current window.
#[derive(Debug)]
struct ClientRateLimit {
    /// Requests accepted in the current window.
    requests: u32,
    /// When the current window began.
    window_start: Instant,
}

impl ClientRateLimit {
    /// Count one request against this client, first opening a new window if
    /// the current one has elapsed. Returns `false` when the limit is reached.
    fn try_acquire(&mut self, now: Instant, window: Duration, max_requests: u32) -> bool {
        if now.duration_since(self.window_start) >= window {
            self.requests = 0;
            self.window_start = now;
        }

        if self.requests >= max_requests {
            false
        } else {
            self.requests += 1;
            true
        }
    }
}

impl RateLimiter {
    /// Table size above which stale client entries are purged.
    const CLEANUP_THRESHOLD: usize = 10_000;

    /// Create a limiter allowing `max_requests` per client per `window_duration`.
    fn new(max_requests: u32, window_duration: Duration) -> Self {
        Self {
            max_requests,
            window_duration,
            clients: HashMap::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Record a request from `client_id` and report whether it is allowed.
    fn allow_request(&mut self, client_id: &str) -> bool {
        let now = Instant::now();

        // Purge once the table is large, but at most once per window. Entries
        // only expire as windows elapse, so scanning more often mostly
        // re-walks live entries at O(clients) cost per request.
        if self.clients.len() > Self::CLEANUP_THRESHOLD
            && now.duration_since(self.last_cleanup) >= self.window_duration
        {
            self.cleanup_old_entries(now);
        }

        // Known client: look up by `&str` so the common path doesn't allocate.
        if let Some(limit) = self.clients.get_mut(client_id) {
            return limit.try_acquire(now, self.window_duration, self.max_requests);
        }

        let mut limit = ClientRateLimit {
            requests: 0,
            window_start: now,
        };
        let allowed = limit.try_acquire(now, self.window_duration, self.max_requests);
        self.clients.insert(client_id.to_owned(), limit);
        allowed
    }

    /// Drop clients whose current window started two or more windows ago.
    fn cleanup_old_entries(&mut self, now: Instant) {
        // Saturate: `Duration * 2` panics on overflow (e.g. `Duration::MAX`),
        // and a panic here would poison the mutex that guards this limiter.
        let max_age = self.window_duration.saturating_mul(2);
        self.clients
            .retain(|_, limit| now.duration_since(limit.window_start) < max_age);
        self.last_cleanup = now;
    }
}
402
403/// Main MCP authentication middleware function for use with Axum
404///
405/// This is the primary entry point for integrating MCP security into an Axum application.
406///
407/// # Example
408/// ```rust,no_run
409/// use axum::{Router, routing::get, middleware::from_fn};
410/// use pulseengine_mcp_security_middleware::*;
411///
412/// #[tokio::main]
413/// async fn main() {
414///     let security_config = SecurityConfig::development();
415///     let middleware = security_config.create_middleware().await.unwrap();
416///     
417///     let app: Router = Router::new()
418///         .route("/", get(|| async { "Hello, secure world!" }))
419///         .layer(from_fn(move |req, next| {
420///             let middleware = middleware.clone();
421///             async move { middleware.process(req, next).await }
422///         }));
423///     
424///     // Start server...
425/// }
426/// ```
427pub async fn mcp_auth_middleware(
428    middleware: SecurityMiddleware,
429) -> impl Fn(
430    Request,
431    Next,
432)
433    -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>>
434+ Clone {
435    move |req, next| {
436        let middleware = middleware.clone();
437        Box::pin(async move { middleware.process(req, next).await })
438    }
439}
440
441/// Rate limiting middleware function
442pub async fn mcp_rate_limit_middleware(
443    config: SecurityConfig,
444) -> impl Fn(
445    Request,
446    Next,
447)
448    -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>>
449+ Clone {
450    let rate_limiter = Arc::new(Mutex::new(RateLimiter::new(
451        config.settings.rate_limit.max_requests,
452        config.settings.rate_limit.window_duration,
453    )));
454
455    move |req, next| {
456        let rate_limiter = rate_limiter.clone();
457        Box::pin(async move {
458            let client_id = extract_client_id(&req);
459
460            {
461                let mut limiter = rate_limiter.lock().unwrap();
462                if !limiter.allow_request(&client_id) {
463                    return Err(StatusCode::TOO_MANY_REQUESTS);
464                }
465            }
466
467            let result = next.run(req).await;
468            Ok(result)
469        })
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Method, Request},
        middleware::from_fn,
        routing::get,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    #[tokio::test]
    async fn test_development_middleware() {
        let config = SecurityConfig::development();
        let middleware = config.create_middleware().await.unwrap();

        let app = Router::new()
            .route("/", get(test_handler))
            .layer(from_fn(move |req, next| {
                let middleware = middleware.clone();
                async move { middleware.process(req, next).await }
            }));

        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_extract_api_key() {
        // Test Authorization: ApiKey format
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_static("ApiKey mcp_test_key"),
        );
        assert_eq!(extract_api_key(&headers), Some("mcp_test_key".to_string()));

        // Test Authorization: Bearer format (for API keys) - clear previous header first
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_bearer_key"),
        );
        assert_eq!(
            extract_api_key(&headers),
            Some("mcp_bearer_key".to_string())
        );

        // Test X-API-Key header - clear previous headers first
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("mcp_x_api_key"));
        assert_eq!(extract_api_key(&headers), Some("mcp_x_api_key".to_string()));
    }

    #[test]
    fn test_extract_bearer_token() {
        let mut headers = HeaderMap::new();

        // Test JWT Bearer token (not starting with mcp_)
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9"),
        );
        assert!(extract_bearer_token(&headers).is_some());

        // Test API key in Bearer (should be None for JWT extraction)
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_not_a_jwt"),
        );
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new(2, Duration::from_secs(1));

        // First request should be allowed
        assert!(limiter.allow_request("client1"));

        // Second request should be allowed
        assert!(limiter.allow_request("client1"));

        // Third request should be denied
        assert!(!limiter.allow_request("client1"));

        // Different client should be allowed
        assert!(limiter.allow_request("client2"));
    }

    #[test]
    fn test_is_https_request() {
        // Test with HTTPS URI
        let request = Request::builder()
            .uri("https://example.com/test")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        // Test with X-Forwarded-Proto header
        let request = Request::builder()
            .uri("/test")
            .header("x-forwarded-proto", "https")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        // Test with localhost (should be accepted)
        let request = Request::builder()
            .uri("/test")
            .header("host", "localhost:3000")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));
    }

    #[test]
    fn test_rate_limiter_edge_cases() {
        let mut limiter = RateLimiter::new(1, Duration::from_millis(100));

        // Test with empty client identifier
        assert!(limiter.allow_request(""));
        assert!(!limiter.allow_request(""));

        // A client never seen before is unaffected by the exhausted one
        assert!(limiter.allow_request("client1"));

        // Test that the exhausted client's limit resets after the window
        std::thread::sleep(Duration::from_millis(150));
        assert!(limiter.allow_request(""));
    }

    #[test]
    fn test_extract_bearer_token_edge_cases() {
        let mut headers = HeaderMap::new();

        // Test case-insensitive header names
        headers.insert("Authorization", HeaderValue::from_static("Bearer token123"));
        assert_eq!(extract_bearer_token(&headers), Some("token123".to_string()));

        // Test with spaces after Bearer - actual behavior preserves spaces
        headers.clear();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer    token456"),
        );
        assert_eq!(
            extract_bearer_token(&headers),
            Some("   token456".to_string())
        );

        // Test with non-UTF8 header value (should return None)
        headers.clear();
        let invalid_utf8 = HeaderValue::from_bytes(b"Bearer \xff\xfe token").unwrap();
        headers.insert("authorization", invalid_utf8);
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_extract_api_key_edge_cases() {
        let mut headers = HeaderMap::new();

        // Test empty API key - function returns the empty string
        headers.insert("x-api-key", HeaderValue::from_static(""));
        assert_eq!(extract_api_key(&headers), Some("".to_string()));

        // Test whitespace-only API key - function returns the whitespace
        headers.clear();
        headers.insert("x-api-key", HeaderValue::from_static("   "));
        assert_eq!(extract_api_key(&headers), Some("   ".to_string()));

        // Test valid mcp_ API key via Bearer
        headers.clear();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_test12345678901234567890"),
        );
        assert_eq!(
            extract_api_key(&headers),
            Some("mcp_test12345678901234567890".to_string())
        );
    }

    #[test]
    fn test_is_https_request_edge_cases() {
        // Test HTTP URI (should fail)
        let request = Request::builder()
            .uri("http://example.com/test")
            .body(Body::empty())
            .unwrap();
        assert!(!is_https_request(&request));

        // Test with X-Forwarded-Proto: http
        let request = Request::builder()
            .uri("/test")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        assert!(!is_https_request(&request));

        // Test with 127.0.0.1 (localhost variant)
        let request = Request::builder()
            .uri("/test")
            .header("host", "127.0.0.1:3000")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        // Test with no host header and HTTP URI
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert!(!is_https_request(&request));
    }
}