// pulseengine_mcp_security_middleware/middleware.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex, PoisonError};
use std::time::{Duration, Instant};

use axum::{
    extract::Request,
    http::{HeaderMap, HeaderValue, StatusCode},
    middleware::Next,
    response::Response,
};
use tracing::{debug, info, warn};

use crate::auth::{ApiKeyValidator, AuthContext, TokenValidator};
use crate::config::SecurityConfig;
use crate::error::{SecurityError, SecurityResult};
use crate::utils::generate_request_id;

18#[derive(Debug, Clone)]
20pub struct SecurityMiddleware {
21 config: SecurityConfig,
22 api_key_validator: Option<ApiKeyValidator>,
23 token_validator: Option<Arc<TokenValidator>>,
24 rate_limiter: Arc<Mutex<RateLimiter>>,
25}
26
27impl SecurityMiddleware {
28 pub fn new(
30 config: SecurityConfig,
31 api_key_validator: Option<ApiKeyValidator>,
32 token_validator: Option<TokenValidator>,
33 ) -> Self {
34 let rate_limiter = Arc::new(Mutex::new(RateLimiter::new(
35 config.settings.rate_limit.max_requests,
36 config.settings.rate_limit.window_duration,
37 )));
38
39 Self {
40 config,
41 api_key_validator,
42 token_validator: token_validator.map(Arc::new),
43 rate_limiter,
44 }
45 }
46
47 async fn authenticate(&self, headers: &HeaderMap) -> SecurityResult<Option<AuthContext>> {
49 if !self.config.settings.require_authentication {
51 return Ok(None);
52 }
53
54 if let Some(ref validator) = self.api_key_validator {
56 if let Some(api_key) = extract_api_key(headers) {
57 match validator.validate_api_key(&api_key) {
58 Ok(user_id) => {
59 let auth_context = AuthContext::new(user_id)
60 .with_api_key(api_key)
61 .with_role("api_user");
62 return Ok(Some(auth_context));
63 }
64 Err(e) => {
65 debug!("API key validation failed: {}", e);
66 }
67 }
68 }
69 }
70
71 if let Some(ref validator) = self.token_validator {
73 if let Some(token) = extract_bearer_token(headers) {
74 match validator.validate_token(&token) {
75 Ok(claims) => {
76 let auth_context =
77 AuthContext::new(claims.sub.clone()).with_jwt_claims(claims);
78 return Ok(Some(auth_context));
79 }
80 Err(e) => {
81 debug!("JWT validation failed: {}", e);
82 }
83 }
84 }
85 }
86
87 Err(SecurityError::MissingAuth)
89 }
90
91 fn check_rate_limit(&self, client_id: &str) -> SecurityResult<()> {
93 if !self.config.settings.rate_limit.enabled {
94 return Ok(());
95 }
96
97 let mut limiter = self.rate_limiter.lock().unwrap();
98 if !limiter.allow_request(client_id) {
99 return Err(SecurityError::RateLimitExceeded);
100 }
101
102 Ok(())
103 }
104
105 pub async fn process(&self, request: Request, next: Next) -> Result<Response, StatusCode> {
107 let request_id = generate_request_id();
108 let start_time = Instant::now();
109
110 let client_id = extract_client_id(&request);
112
113 debug!(
114 "Processing request {} from client {}",
115 request_id, client_id
116 );
117
118 if let Err(e) = self.check_rate_limit(&client_id) {
120 warn!("Rate limit exceeded for client {}: {}", client_id, e);
121 return Err(StatusCode::TOO_MANY_REQUESTS);
122 }
123
124 let auth_context = match self.authenticate(request.headers()).await {
126 Ok(auth_context) => auth_context,
127 Err(SecurityError::MissingAuth) => {
128 if self.config.settings.require_authentication {
129 warn!(
130 "Authentication required but not provided for request {}",
131 request_id
132 );
133 return Err(StatusCode::UNAUTHORIZED);
134 } else {
135 None
136 }
137 }
138 Err(e) => {
139 warn!("Authentication failed for request {}: {}", request_id, e);
140 return match e {
141 SecurityError::InvalidApiKey => Err(StatusCode::UNAUTHORIZED),
142 SecurityError::TokenExpired => Err(StatusCode::UNAUTHORIZED),
143 SecurityError::InvalidToken(_) => Err(StatusCode::UNAUTHORIZED),
144 _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
145 };
146 }
147 };
148
149 if self.config.settings.require_https && !is_https_request(&request) {
151 warn!("HTTPS required but request {} is not secure", request_id);
152 return Err(StatusCode::FORBIDDEN);
153 }
154
155 let mut request = request;
157 if let Some(auth_context) = auth_context {
158 request.extensions_mut().insert(auth_context.clone());
159 info!(
160 "Authenticated request {} as user {} with roles {:?}",
161 request_id, auth_context.user_id, auth_context.roles
162 );
163 }
164
165 request
167 .extensions_mut()
168 .insert(RequestId(request_id.clone()));
169
170 let mut response = next.run(request).await;
172
173 add_security_headers(&mut response, &self.config);
175
176 response.headers_mut().insert(
178 "x-request-id",
179 HeaderValue::from_str(&request_id)
180 .unwrap_or_else(|_| HeaderValue::from_static("invalid")),
181 );
182
183 if self.config.settings.enable_audit_logging {
185 let duration = start_time.elapsed();
186 info!(
187 "Request {} completed in {:?} with status {}",
188 request_id,
189 duration,
190 response.status()
191 );
192 }
193
194 Ok(response)
195 }
196}
197
/// Per-request identifier generated by [`SecurityMiddleware::process`].
///
/// It is inserted into the request extensions before the inner service runs. The same
/// value is returned to the client in the `x-request-id` response header.
#[derive(Clone, Debug)]
pub struct RequestId(pub String);

202fn extract_api_key(headers: &HeaderMap) -> Option<String> {
204 if let Some(auth_header) = headers.get("authorization") {
206 if let Ok(auth_str) = auth_header.to_str() {
207 if let Some(key) = auth_str.strip_prefix("ApiKey ") {
208 return Some(key.to_string());
209 }
210 if let Some(key) = auth_str.strip_prefix("Bearer ") {
211 if key.starts_with("mcp_") {
212 return Some(key.to_string());
213 }
214 }
215 }
216 }
217
218 if let Some(key_header) = headers.get("x-api-key") {
220 if let Ok(key_str) = key_header.to_str() {
221 return Some(key_str.to_string());
222 }
223 }
224
225 None
226}
227
228fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
230 if let Some(auth_header) = headers.get("authorization") {
231 if let Ok(auth_str) = auth_header.to_str() {
232 if let Some(token) = auth_str.strip_prefix("Bearer ") {
233 if !token.starts_with("mcp_") {
235 return Some(token.to_string());
236 }
237 }
238 }
239 }
240
241 None
242}
243
244fn extract_client_id(request: &Request) -> String {
246 let headers = request.headers();
248
249 if let Some(forwarded_for) = headers.get("x-forwarded-for") {
250 if let Ok(ip_str) = forwarded_for.to_str() {
251 if let Some(first_ip) = ip_str.split(',').next() {
252 return first_ip.trim().to_string();
253 }
254 }
255 }
256
257 if let Some(real_ip) = headers.get("x-real-ip") {
258 if let Ok(ip_str) = real_ip.to_str() {
259 return ip_str.to_string();
260 }
261 }
262
263 "unknown".to_string()
266}
267
268fn is_https_request(request: &Request) -> bool {
270 if request.uri().scheme_str() == Some("https") {
272 return true;
273 }
274
275 let headers = request.headers();
277
278 if let Some(forwarded_proto) = headers.get("x-forwarded-proto") {
279 if let Ok(proto_str) = forwarded_proto.to_str() {
280 return proto_str.to_lowercase() == "https";
281 }
282 }
283
284 if let Some(forwarded_ssl) = headers.get("x-forwarded-ssl") {
285 if let Ok(ssl_str) = forwarded_ssl.to_str() {
286 return ssl_str.to_lowercase() == "on";
287 }
288 }
289
290 if let Some(host) = headers.get("host") {
292 if let Ok(host_str) = host.to_str() {
293 if host_str.starts_with("localhost") || host_str.starts_with("127.0.0.1") {
294 return true;
295 }
296 }
297 }
298
299 false
300}
301
302fn add_security_headers(response: &mut Response, config: &SecurityConfig) {
304 let headers = response.headers_mut();
305
306 headers.insert(
308 "content-security-policy",
309 HeaderValue::from_static("default-src 'self'"),
310 );
311
312 headers.insert("x-frame-options", HeaderValue::from_static("DENY"));
314
315 headers.insert(
317 "x-content-type-options",
318 HeaderValue::from_static("nosniff"),
319 );
320
321 headers.insert(
323 "referrer-policy",
324 HeaderValue::from_static("strict-origin-when-cross-origin"),
325 );
326
327 if config.settings.require_https {
329 headers.insert(
330 "strict-transport-security",
331 HeaderValue::from_static("max-age=31536000; includeSubDomains"),
332 );
333 }
334
335 headers.insert(
337 "server",
338 HeaderValue::from_static("MCP-Security-Middleware"),
339 );
340}
341
/// Number of tracked clients above which [`RateLimiter`] starts pruning stale entries.
const RATE_LIMITER_CLEANUP_THRESHOLD: usize = 10_000;

/// Fixed-window request counter keyed by client id.
///
/// Each client may make `max_requests` requests per `window_duration`. A client's window
/// restarts on its first request after the previous window has elapsed.
#[derive(Debug)]
struct RateLimiter {
    max_requests: u32,
    window_duration: Duration,
    clients: HashMap<String, ClientRateLimit>,
    /// Time of the last prune; limits pruning to at most once per window.
    last_cleanup: Instant,
}

/// Counter state for a single client.
#[derive(Debug)]
struct ClientRateLimit {
    /// Requests admitted since `window_start`.
    requests: u32,
    /// Start of the client's current window.
    window_start: Instant,
}

impl RateLimiter {
    /// Creates a limiter allowing `max_requests` per client per `window_duration`.
    fn new(max_requests: u32, window_duration: Duration) -> Self {
        Self {
            max_requests,
            window_duration,
            clients: HashMap::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Records a request from `client_id`. Returns `true` if it is within budget.
    ///
    /// A rejected request does not consume budget.
    fn allow_request(&mut self, client_id: &str) -> bool {
        let now = Instant::now();

        // Prune once the table is large, but at most once per window. If every tracked
        // client is still active, pruning frees nothing. Running an O(n) scan on every
        // request under the shared mutex would make a flood of distinct ids expensive.
        if self.clients.len() > RATE_LIMITER_CLEANUP_THRESHOLD
            && now.duration_since(self.last_cleanup) >= self.window_duration
        {
            self.cleanup_old_entries(now);
            self.last_cleanup = now;
        }

        let client_limit = self
            .clients
            .entry(client_id.to_string())
            .or_insert(ClientRateLimit {
                requests: 0,
                window_start: now,
            });

        if now.duration_since(client_limit.window_start) >= self.window_duration {
            client_limit.requests = 0;
            client_limit.window_start = now;
        }

        if client_limit.requests >= self.max_requests {
            false
        } else {
            client_limit.requests += 1;
            true
        }
    }

    /// Drops clients whose current window started two or more windows before `now`.
    fn cleanup_old_entries(&mut self, now: Instant) {
        // `Duration * 2` panics on overflow for very large configured windows.
        let max_age = self.window_duration.saturating_mul(2);
        self.clients
            .retain(|_, client_limit| now.duration_since(client_limit.window_start) < max_age);
    }
}

403pub async fn mcp_auth_middleware(
428 middleware: SecurityMiddleware,
429) -> impl Fn(
430 Request,
431 Next,
432)
433 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>>
434+ Clone {
435 move |req, next| {
436 let middleware = middleware.clone();
437 Box::pin(async move { middleware.process(req, next).await })
438 }
439}
440
441pub async fn mcp_rate_limit_middleware(
443 config: SecurityConfig,
444) -> impl Fn(
445 Request,
446 Next,
447)
448 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>>
449+ Clone {
450 let rate_limiter = Arc::new(Mutex::new(RateLimiter::new(
451 config.settings.rate_limit.max_requests,
452 config.settings.rate_limit.window_duration,
453 )));
454
455 move |req, next| {
456 let rate_limiter = rate_limiter.clone();
457 Box::pin(async move {
458 let client_id = extract_client_id(&req);
459
460 {
461 let mut limiter = rate_limiter.lock().unwrap();
462 if !limiter.allow_request(&client_id) {
463 return Err(StatusCode::TOO_MANY_REQUESTS);
464 }
465 }
466
467 let result = next.run(req).await;
468 Ok(result)
469 })
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Method, Request},
        middleware::from_fn,
        routing::get,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    #[tokio::test]
    async fn test_development_middleware() {
        let config = SecurityConfig::development();
        let middleware = config.create_middleware().await.unwrap();

        let app = Router::new()
            .route("/", get(test_handler))
            .layer(from_fn(move |req, next| {
                let middleware = middleware.clone();
                async move { middleware.process(req, next).await }
            }));

        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_extract_api_key() {
        // `ApiKey` scheme.
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_static("ApiKey mcp_test_key"),
        );
        assert_eq!(extract_api_key(&headers), Some("mcp_test_key".to_string()));

        // Bearer scheme carrying an `mcp_` key.
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_bearer_key"),
        );
        assert_eq!(
            extract_api_key(&headers),
            Some("mcp_bearer_key".to_string())
        );

        // Dedicated `x-api-key` header.
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("mcp_x_api_key"));
        assert_eq!(extract_api_key(&headers), Some("mcp_x_api_key".to_string()));
    }

    #[test]
    fn test_extract_bearer_token() {
        let mut headers = HeaderMap::new();

        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9"),
        );
        assert!(extract_bearer_token(&headers).is_some());

        // `mcp_` bearer values are API keys, not JWTs.
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_not_a_jwt"),
        );
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new(2, Duration::from_secs(1));

        assert!(limiter.allow_request("client1"));
        assert!(limiter.allow_request("client1"));
        assert!(!limiter.allow_request("client1"));

        // Budgets are per client.
        assert!(limiter.allow_request("client2"));
    }

    #[test]
    fn test_is_https_request() {
        let request = Request::builder()
            .uri("https://example.com/test")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        let request = Request::builder()
            .uri("/test")
            .header("x-forwarded-proto", "https")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        let request = Request::builder()
            .uri("/test")
            .header("host", "localhost:3000")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));
    }

    #[test]
    fn test_rate_limiter_edge_cases() {
        let mut limiter = RateLimiter::new(1, Duration::from_millis(100));

        // An empty client id is tracked like any other id.
        assert!(limiter.allow_request(""));
        assert!(!limiter.allow_request(""));

        // Once the window has elapsed, the exhausted client is admitted again.
        std::thread::sleep(Duration::from_millis(150));
        assert!(limiter.allow_request(""));
        assert!(limiter.allow_request("client1"));
    }

    #[test]
    fn test_extract_bearer_token_edge_cases() {
        let mut headers = HeaderMap::new();

        // Header names are case-insensitive.
        headers.insert("Authorization", HeaderValue::from_static("Bearer token123"));
        assert_eq!(extract_bearer_token(&headers), Some("token123".to_string()));

        // Only the single space after the scheme is stripped; extra whitespace is kept.
        headers.clear();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer  token456"),
        );
        assert_eq!(
            extract_bearer_token(&headers),
            Some(" token456".to_string())
        );

        // Non-ASCII header bytes cannot be read as a token.
        headers.clear();
        let invalid_utf8 = HeaderValue::from_bytes(b"Bearer \xff\xfe token").unwrap();
        headers.insert("authorization", invalid_utf8);
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_extract_api_key_edge_cases() {
        let mut headers = HeaderMap::new();

        headers.insert("x-api-key", HeaderValue::from_static(""));
        assert_eq!(extract_api_key(&headers), Some("".to_string()));

        headers.clear();
        headers.insert("x-api-key", HeaderValue::from_static(" "));
        assert_eq!(extract_api_key(&headers), Some(" ".to_string()));

        headers.clear();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer mcp_test12345678901234567890"),
        );
        assert_eq!(
            extract_api_key(&headers),
            Some("mcp_test12345678901234567890".to_string())
        );
    }

    #[test]
    fn test_is_https_request_edge_cases() {
        let request = Request::builder()
            .uri("http://example.com/test")
            .body(Body::empty())
            .unwrap();
        assert!(!is_https_request(&request));

        let request = Request::builder()
            .uri("/test")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        assert!(!is_https_request(&request));

        let request = Request::builder()
            .uri("/test")
            .header("host", "127.0.0.1:3000")
            .body(Body::empty())
            .unwrap();
        assert!(is_https_request(&request));

        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert!(!is_https_request(&request));
    }
}