pulseengine_mcp_auth/middleware/
mcp_auth.rs1use crate::{models::Role, security::RequestSecurityValidator, AuthContext, AuthenticationManager};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
10use std::collections::HashMap;
11use std::sync::Arc;
12use thiserror::Error;
13use tracing::{debug, error, warn};
14
15#[derive(Debug, Error)]
17pub enum AuthExtractionError {
18 #[error("No authentication provided")]
19 NoAuth,
20
21 #[error("Invalid authentication format: {0}")]
22 InvalidFormat(String),
23
24 #[error("Authentication method not supported: {0}")]
25 UnsupportedMethod(String),
26
27 #[error("Missing required header: {0}")]
28 MissingHeader(String),
29}
30
31#[derive(Debug, Clone)]
33pub struct McpAuthConfig {
34 pub require_auth: bool,
36
37 pub anonymous_methods: Vec<String>,
39
40 pub method_role_requirements: HashMap<String, Vec<Role>>,
42
43 pub enable_permission_checking: bool,
45
46 pub auth_header_name: String,
48
49 pub enable_audit_logging: bool,
51
52 pub client_ip_header: Option<String>,
54}
55
56impl Default for McpAuthConfig {
57 fn default() -> Self {
58 Self {
59 require_auth: true,
60 anonymous_methods: vec!["initialize".to_string(), "ping".to_string()],
61 method_role_requirements: HashMap::new(),
62 enable_permission_checking: true,
63 auth_header_name: "Authorization".to_string(),
64 enable_audit_logging: true,
65 client_ip_header: Some("X-Forwarded-For".to_string()),
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct McpAuthContext {
73 pub auth_context: Option<AuthContext>,
75
76 pub client_ip: Option<String>,
78
79 pub auth_method: Option<String>,
81
82 pub is_anonymous: bool,
84}
85
86#[derive(Debug, Clone)]
88pub struct McpRequestContext {
89 pub request_id: String,
91
92 pub auth: McpAuthContext,
94
95 pub timestamp: chrono::DateTime<chrono::Utc>,
97
98 pub metadata: HashMap<String, String>,
100}
101
102impl McpRequestContext {
103 pub fn new(request_id: String) -> Self {
104 Self {
105 request_id,
106 auth: McpAuthContext {
107 auth_context: None,
108 client_ip: None,
109 auth_method: None,
110 is_anonymous: true,
111 },
112 timestamp: chrono::Utc::now(),
113 metadata: HashMap::new(),
114 }
115 }
116
117 pub fn with_auth(mut self, auth_context: AuthContext, auth_method: String) -> Self {
118 self.auth.auth_context = Some(auth_context);
119 self.auth.auth_method = Some(auth_method);
120 self.auth.is_anonymous = false;
121 self
122 }
123
124 pub fn with_client_ip(mut self, client_ip: String) -> Self {
125 self.auth.client_ip = Some(client_ip);
126 self
127 }
128}
129
130pub struct McpAuthMiddleware {
132 auth_manager: Arc<AuthenticationManager>,
134
135 config: McpAuthConfig,
137
138 security_validator: Arc<RequestSecurityValidator>,
140}
141
142impl McpAuthMiddleware {
143 pub fn new(auth_manager: Arc<AuthenticationManager>, config: McpAuthConfig) -> Self {
145 Self {
146 auth_manager,
147 config,
148 security_validator: Arc::new(RequestSecurityValidator::default()),
149 }
150 }
151
152 pub fn with_security_validator(
154 auth_manager: Arc<AuthenticationManager>,
155 config: McpAuthConfig,
156 security_validator: Arc<RequestSecurityValidator>,
157 ) -> Self {
158 Self {
159 auth_manager,
160 config,
161 security_validator,
162 }
163 }
164
165 pub fn with_default_config(auth_manager: Arc<AuthenticationManager>) -> Self {
167 Self::new(auth_manager, McpAuthConfig::default())
168 }
169
170 pub fn security_validator(&self) -> &RequestSecurityValidator {
172 &self.security_validator
173 }
174
175 pub async fn process_request(
177 &self,
178 request: Request,
179 headers: Option<&HashMap<String, String>>,
180 ) -> Result<(Request, McpRequestContext), McpError> {
181 if let Err(security_error) = self
183 .security_validator
184 .validate_request(&request, None)
185 .await
186 {
187 error!("Request security validation failed: {}", security_error);
188 return Err(McpError::invalid_request(&format!(
189 "Security validation failed: {}",
190 security_error
191 )));
192 }
193
194 let sanitized_request = self.security_validator.sanitize_request(request).await;
196
197 let request_id = match &sanitized_request.id {
198 serde_json::Value::String(s) => s.clone(),
199 serde_json::Value::Number(n) => n.to_string(),
200 serde_json::Value::Null => uuid::Uuid::new_v4().to_string(),
201 _ => uuid::Uuid::new_v4().to_string(),
202 };
203 let mut context = McpRequestContext::new(request_id);
204
205 if let Some(headers) = headers {
207 if let Some(ip_header) = &self.config.client_ip_header {
208 if let Some(client_ip) = headers.get(ip_header) {
209 context = context.with_client_ip(client_ip.clone());
210 }
211 }
212 }
213
214 if self.should_skip_auth(&sanitized_request.method) {
216 debug!(
217 "Skipping authentication for method: {}",
218 sanitized_request.method
219 );
220 return Ok((sanitized_request, context));
221 }
222
223 let auth_result = if let Some(headers) = headers {
225 self.extract_authentication(headers).await
226 } else {
227 Err(AuthExtractionError::NoAuth)
228 };
229
230 match auth_result {
231 Ok((auth_context, auth_method)) => {
232 context = context.with_auth(auth_context, auth_method);
234
235 if let Err(e) = self
237 .check_method_permissions(&sanitized_request.method, &context)
238 .await
239 {
240 error!("Method permission check failed: {}", e);
241 return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
242 }
243
244 debug!("Request authenticated successfully");
245 Ok((sanitized_request, context))
246 }
247 Err(e) => {
248 if self.config.require_auth {
249 warn!("Authentication failed: {}", e);
250 Err(McpError::invalid_request(&format!(
251 "Authentication required: {}",
252 e
253 )))
254 } else {
255 debug!("Authentication failed but not required: {}", e);
256 Ok((sanitized_request, context))
257 }
258 }
259 }
260 }
261
262 pub async fn process_response(
264 &self,
265 response: Response,
266 _context: &McpRequestContext,
267 ) -> Result<Response, McpError> {
268 Ok(response)
271 }
272
273 async fn extract_authentication(
275 &self,
276 headers: &HashMap<String, String>,
277 ) -> Result<(AuthContext, String), AuthExtractionError> {
278 if let Some(auth_header) = headers.get(&self.config.auth_header_name) {
280 return self.parse_auth_header(auth_header).await;
281 }
282
283 if let Some(api_key) = headers.get("X-API-Key") {
285 return self.validate_api_key(api_key, "X-API-Key").await;
286 }
287
288 Err(AuthExtractionError::NoAuth)
289 }
290
291 async fn parse_auth_header(
293 &self,
294 auth_header: &str,
295 ) -> Result<(AuthContext, String), AuthExtractionError> {
296 let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
297 if parts.len() != 2 {
298 return Err(AuthExtractionError::InvalidFormat(
299 "Authorization header must be in format 'Type Token'".to_string(),
300 ));
301 }
302
303 let auth_type = parts[0].to_lowercase();
304 let token = parts[1];
305
306 match auth_type.as_str() {
307 "bearer" => self.validate_api_key(token, "Bearer").await,
308 "apikey" => self.validate_api_key(token, "ApiKey").await,
309 _ => Err(AuthExtractionError::UnsupportedMethod(auth_type)),
310 }
311 }
312
313 async fn validate_api_key(
315 &self,
316 api_key: &str,
317 method: &str,
318 ) -> Result<(AuthContext, String), AuthExtractionError> {
319 match self.auth_manager.validate_api_key(api_key, None).await {
320 Ok(Some(auth_context)) => Ok((auth_context, method.to_string())),
321 Ok(None) => Err(AuthExtractionError::InvalidFormat(
322 "Invalid API key".to_string(),
323 )),
324 Err(e) => {
325 error!("API key validation failed: {}", e);
326 Err(AuthExtractionError::InvalidFormat(
327 "Authentication failed".to_string(),
328 ))
329 }
330 }
331 }
332
333 fn should_skip_auth(&self, method: &str) -> bool {
335 if !self.config.require_auth {
336 return true;
337 }
338
339 self.config.anonymous_methods.contains(&method.to_string())
340 }
341
342 async fn check_method_permissions(
344 &self,
345 method: &str,
346 context: &McpRequestContext,
347 ) -> Result<(), String> {
348 if let Some(required_roles) = self.config.method_role_requirements.get(method) {
350 if let Some(auth_context) = &context.auth.auth_context {
351 let has_required_role = auth_context
353 .roles
354 .iter()
355 .any(|role| required_roles.contains(role));
356 if !has_required_role {
357 return Err(format!(
358 "Method '{}' requires one of these roles: {:?}, but user has roles: {:?}",
359 method, required_roles, auth_context.roles
360 ));
361 }
362 } else {
363 return Err(format!("Method '{}' requires authentication", method));
364 }
365 }
366
367 Ok(())
368 }
369}
370
371#[async_trait]
373pub trait McpMiddleware: Send + Sync {
374 async fn process_request(
376 &self,
377 request: Request,
378 context: &McpRequestContext,
379 ) -> Result<Request, McpError>;
380
381 async fn process_response(
383 &self,
384 response: Response,
385 context: &McpRequestContext,
386 ) -> Result<Response, McpError>;
387}
388
389#[async_trait]
390impl McpMiddleware for McpAuthMiddleware {
391 async fn process_request(
392 &self,
393 request: Request,
394 _context: &McpRequestContext,
395 ) -> Result<Request, McpError> {
396 Ok(request)
399 }
400
401 async fn process_response(
402 &self,
403 response: Response,
404 context: &McpRequestContext,
405 ) -> Result<Response, McpError> {
406 self.process_response(response, context).await
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthConfig;

    /// Builds a middleware over an in-memory auth backend with the default config.
    async fn default_middleware() -> McpAuthMiddleware {
        let manager = AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        McpAuthMiddleware::with_default_config(Arc::new(manager))
    }

    #[tokio::test]
    async fn test_auth_middleware_creation() {
        let middleware = default_middleware().await;

        assert!(middleware.config.require_auth);
        assert!(!middleware.config.anonymous_methods.is_empty());
    }

    #[tokio::test]
    async fn test_anonymous_method_detection() {
        let middleware = default_middleware().await;

        for anonymous in ["initialize", "ping"].iter() {
            assert!(middleware.should_skip_auth(anonymous));
        }
        assert!(!middleware.should_skip_auth("tools/call"));
    }

    #[tokio::test]
    async fn test_auth_header_parsing() {
        let middleware = default_middleware().await;

        // No scheme/token separator at all.
        assert!(middleware.parse_auth_header("invalid").await.is_err());

        // Well-formed but unsupported scheme.
        let basic = middleware.parse_auth_header("Basic token123").await;
        assert!(matches!(
            basic,
            Err(AuthExtractionError::UnsupportedMethod(_))
        ));
    }
}
453}