pulseengine_mcp_auth/middleware/
mcp_auth.rs1use crate::{AuthContext, AuthenticationManager, models::Role, security::RequestSecurityValidator};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
10use std::collections::HashMap;
11use std::sync::Arc;
12use thiserror::Error;
13use tracing::{debug, error, warn};
14
15#[derive(Debug, Error)]
17pub enum AuthExtractionError {
18 #[error("No authentication provided")]
19 NoAuth,
20
21 #[error("Invalid authentication format: {0}")]
22 InvalidFormat(String),
23
24 #[error("Authentication method not supported: {0}")]
25 UnsupportedMethod(String),
26
27 #[error("Missing required header: {0}")]
28 MissingHeader(String),
29}
30
31#[derive(Debug, Clone)]
33pub struct McpAuthConfig {
34 pub require_auth: bool,
36
37 pub anonymous_methods: Vec<String>,
39
40 pub method_role_requirements: HashMap<String, Vec<Role>>,
42
43 pub enable_permission_checking: bool,
45
46 pub auth_header_name: String,
48
49 pub enable_audit_logging: bool,
51
52 pub client_ip_header: Option<String>,
54}
55
56impl Default for McpAuthConfig {
57 fn default() -> Self {
58 Self {
59 require_auth: true,
60 anonymous_methods: vec!["initialize".to_string(), "ping".to_string()],
61 method_role_requirements: HashMap::new(),
62 enable_permission_checking: true,
63 auth_header_name: "Authorization".to_string(),
64 enable_audit_logging: true,
65 client_ip_header: Some("X-Forwarded-For".to_string()),
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct McpAuthContext {
73 pub auth_context: Option<AuthContext>,
75
76 pub client_ip: Option<String>,
78
79 pub auth_method: Option<String>,
81
82 pub is_anonymous: bool,
84}
85
86#[derive(Debug, Clone)]
88pub struct McpRequestContext {
89 pub request_id: String,
91
92 pub auth: McpAuthContext,
94
95 pub timestamp: chrono::DateTime<chrono::Utc>,
97
98 pub metadata: HashMap<String, String>,
100}
101
102impl McpRequestContext {
103 pub fn new(request_id: String) -> Self {
104 Self {
105 request_id,
106 auth: McpAuthContext {
107 auth_context: None,
108 client_ip: None,
109 auth_method: None,
110 is_anonymous: true,
111 },
112 timestamp: chrono::Utc::now(),
113 metadata: HashMap::new(),
114 }
115 }
116
117 pub fn with_auth(mut self, auth_context: AuthContext, auth_method: String) -> Self {
118 self.auth.auth_context = Some(auth_context);
119 self.auth.auth_method = Some(auth_method);
120 self.auth.is_anonymous = false;
121 self
122 }
123
124 pub fn with_client_ip(mut self, client_ip: String) -> Self {
125 self.auth.client_ip = Some(client_ip);
126 self
127 }
128}
129
130pub struct McpAuthMiddleware {
132 auth_manager: Arc<AuthenticationManager>,
134
135 config: McpAuthConfig,
137
138 security_validator: Arc<RequestSecurityValidator>,
140}
141
142impl McpAuthMiddleware {
143 pub fn new(auth_manager: Arc<AuthenticationManager>, config: McpAuthConfig) -> Self {
145 Self {
146 auth_manager,
147 config,
148 security_validator: Arc::new(RequestSecurityValidator::default()),
149 }
150 }
151
152 pub fn with_security_validator(
154 auth_manager: Arc<AuthenticationManager>,
155 config: McpAuthConfig,
156 security_validator: Arc<RequestSecurityValidator>,
157 ) -> Self {
158 Self {
159 auth_manager,
160 config,
161 security_validator,
162 }
163 }
164
165 pub fn with_default_config(auth_manager: Arc<AuthenticationManager>) -> Self {
167 Self::new(auth_manager, McpAuthConfig::default())
168 }
169
170 pub fn security_validator(&self) -> &RequestSecurityValidator {
172 &self.security_validator
173 }
174
175 pub async fn process_request(
177 &self,
178 request: Request,
179 headers: Option<&HashMap<String, String>>,
180 ) -> Result<(Request, McpRequestContext), McpError> {
181 if let Err(security_error) = self
183 .security_validator
184 .validate_request(&request, None)
185 .await
186 {
187 error!("Request security validation failed: {}", security_error);
188 return Err(McpError::invalid_request(&format!(
189 "Security validation failed: {}",
190 security_error
191 )));
192 }
193
194 let sanitized_request = self.security_validator.sanitize_request(request).await;
196
197 let request_id = match &sanitized_request.id {
198 Some(id) => id.to_string(),
199 None => uuid::Uuid::new_v4().to_string(),
200 };
201 let mut context = McpRequestContext::new(request_id);
202
203 if let Some(headers) = headers {
205 if let Some(ip_header) = &self.config.client_ip_header {
206 if let Some(client_ip) = headers.get(ip_header) {
207 context = context.with_client_ip(client_ip.clone());
208 }
209 }
210 }
211
212 if self.should_skip_auth(&sanitized_request.method) {
214 debug!(
215 "Skipping authentication for method: {}",
216 sanitized_request.method
217 );
218 return Ok((sanitized_request, context));
219 }
220
221 let auth_result = if let Some(headers) = headers {
223 self.extract_authentication(headers).await
224 } else {
225 Err(AuthExtractionError::NoAuth)
226 };
227
228 match auth_result {
229 Ok((auth_context, auth_method)) => {
230 context = context.with_auth(auth_context, auth_method);
232
233 if let Err(e) = self
235 .check_method_permissions(&sanitized_request.method, &context)
236 .await
237 {
238 error!("Method permission check failed: {}", e);
239 return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
240 }
241
242 debug!("Request authenticated successfully");
243 Ok((sanitized_request, context))
244 }
245 Err(e) => {
246 if self.config.require_auth {
247 warn!("Authentication failed: {}", e);
248 Err(McpError::invalid_request(&format!(
249 "Authentication required: {}",
250 e
251 )))
252 } else {
253 debug!("Authentication failed but not required: {}", e);
254 Ok((sanitized_request, context))
255 }
256 }
257 }
258 }
259
260 pub async fn process_response(
262 &self,
263 response: Response,
264 _context: &McpRequestContext,
265 ) -> Result<Response, McpError> {
266 Ok(response)
269 }
270
271 async fn extract_authentication(
273 &self,
274 headers: &HashMap<String, String>,
275 ) -> Result<(AuthContext, String), AuthExtractionError> {
276 if let Some(auth_header) = headers.get(&self.config.auth_header_name) {
278 return self.parse_auth_header(auth_header).await;
279 }
280
281 if let Some(api_key) = headers.get("X-API-Key") {
283 return self.validate_api_key(api_key, "X-API-Key").await;
284 }
285
286 Err(AuthExtractionError::NoAuth)
287 }
288
289 async fn parse_auth_header(
291 &self,
292 auth_header: &str,
293 ) -> Result<(AuthContext, String), AuthExtractionError> {
294 let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
295 if parts.len() != 2 {
296 return Err(AuthExtractionError::InvalidFormat(
297 "Authorization header must be in format 'Type Token'".to_string(),
298 ));
299 }
300
301 let auth_type = parts[0].to_lowercase();
302 let token = parts[1];
303
304 match auth_type.as_str() {
305 "bearer" => self.validate_api_key(token, "Bearer").await,
306 "apikey" => self.validate_api_key(token, "ApiKey").await,
307 _ => Err(AuthExtractionError::UnsupportedMethod(auth_type)),
308 }
309 }
310
311 async fn validate_api_key(
313 &self,
314 api_key: &str,
315 method: &str,
316 ) -> Result<(AuthContext, String), AuthExtractionError> {
317 match self.auth_manager.validate_api_key(api_key, None).await {
318 Ok(Some(auth_context)) => Ok((auth_context, method.to_string())),
319 Ok(None) => Err(AuthExtractionError::InvalidFormat(
320 "Invalid API key".to_string(),
321 )),
322 Err(e) => {
323 error!("API key validation failed: {}", e);
324 Err(AuthExtractionError::InvalidFormat(
325 "Authentication failed".to_string(),
326 ))
327 }
328 }
329 }
330
331 fn should_skip_auth(&self, method: &str) -> bool {
333 if !self.config.require_auth {
334 return true;
335 }
336
337 self.config.anonymous_methods.contains(&method.to_string())
338 }
339
340 async fn check_method_permissions(
342 &self,
343 method: &str,
344 context: &McpRequestContext,
345 ) -> Result<(), String> {
346 if let Some(required_roles) = self.config.method_role_requirements.get(method) {
348 if let Some(auth_context) = &context.auth.auth_context {
349 let has_required_role = auth_context
351 .roles
352 .iter()
353 .any(|role| required_roles.contains(role));
354 if !has_required_role {
355 return Err(format!(
356 "Method '{}' requires one of these roles: {:?}, but user has roles: {:?}",
357 method, required_roles, auth_context.roles
358 ));
359 }
360 } else {
361 return Err(format!("Method '{}' requires authentication", method));
362 }
363 }
364
365 Ok(())
366 }
367}
368
369#[async_trait]
371pub trait McpMiddleware: Send + Sync {
372 async fn process_request(
374 &self,
375 request: Request,
376 context: &McpRequestContext,
377 ) -> Result<Request, McpError>;
378
379 async fn process_response(
381 &self,
382 response: Response,
383 context: &McpRequestContext,
384 ) -> Result<Response, McpError>;
385}
386
387#[async_trait]
388impl McpMiddleware for McpAuthMiddleware {
389 async fn process_request(
390 &self,
391 request: Request,
392 _context: &McpRequestContext,
393 ) -> Result<Request, McpError> {
394 Ok(request)
397 }
398
399 async fn process_response(
400 &self,
401 response: Response,
402 context: &McpRequestContext,
403 ) -> Result<Response, McpError> {
404 self.process_response(response, context).await
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthConfig;

    /// Builds a middleware with the default configuration on top of an
    /// in-memory authentication manager.
    async fn default_middleware() -> McpAuthMiddleware {
        let manager = AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        McpAuthMiddleware::with_default_config(Arc::new(manager))
    }

    #[tokio::test]
    async fn test_auth_middleware_creation() {
        let middleware = default_middleware().await;

        assert!(!middleware.config.anonymous_methods.is_empty());
        assert!(middleware.config.require_auth);
    }

    #[tokio::test]
    async fn test_anonymous_method_detection() {
        let middleware = default_middleware().await;

        assert!(middleware.should_skip_auth("initialize"));
        assert!(middleware.should_skip_auth("ping"));
        assert!(!middleware.should_skip_auth("tools/call"));
    }

    #[tokio::test]
    async fn test_auth_header_parsing() {
        let middleware = default_middleware().await;

        // A header without a scheme/token separator is rejected.
        let malformed = middleware.parse_auth_header("invalid").await;
        assert!(malformed.is_err());

        // A scheme other than bearer/apikey is reported as unsupported.
        let unsupported = middleware.parse_auth_header("Basic token123").await;
        assert!(matches!(
            unsupported,
            Err(AuthExtractionError::UnsupportedMethod(_))
        ));
    }
}