pulseengine_mcp_auth/transport/
http_auth.rs1use super::auth_extractors::{
7 AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8 TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use std::collections::HashMap;
12
13#[derive(Debug, Clone)]
15pub struct HttpAuthConfig {
16 pub supported_methods: Vec<HttpAuthMethod>,
18
19 pub require_https: bool,
21
22 pub allow_query_auth: bool,
24
25 pub custom_auth_headers: Vec<String>,
27
28 pub enable_cors_auth: bool,
30
31 pub trusted_proxies: Vec<String>,
33}
34
35impl Default for HttpAuthConfig {
36 fn default() -> Self {
37 Self {
38 supported_methods: vec![HttpAuthMethod::Bearer, HttpAuthMethod::ApiKeyHeader],
39 require_https: false, allow_query_auth: false, custom_auth_headers: vec![],
42 enable_cors_auth: true,
43 trusted_proxies: vec![],
44 }
45 }
46}
47
/// A way an HTTP client may present credentials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HttpAuthMethod {
    /// `Authorization: Bearer <token>`.
    Bearer,

    /// API key supplied in a dedicated header; the header lookup itself is
    /// delegated to `AuthUtils::extract_api_key_header`.
    ApiKeyHeader,

    /// API key supplied as a query parameter (`api_key`, `apikey`, `key` or
    /// `token`); additionally requires `HttpAuthConfig::allow_query_auth`.
    ApiKeyQuery,

    /// `Authorization: Basic <base64(user:pass)>`; the username part is used
    /// as the API key and the password part is ignored.
    Basic,

    /// A named custom method. Custom header extraction is driven by
    /// `HttpAuthConfig::custom_auth_headers`, not by this variant.
    Custom(String),
}
66
67impl HttpAuthMethod {
68 pub fn name(&self) -> String {
70 match self {
71 Self::Bearer => "Bearer".to_string(),
72 Self::ApiKeyHeader => "X-API-Key".to_string(),
73 Self::ApiKeyQuery => "Query".to_string(),
74 Self::Basic => "Basic".to_string(),
75 Self::Custom(name) => name.clone(),
76 }
77 }
78}
79
80pub struct HttpAuthExtractor {
82 config: HttpAuthConfig,
83}
84
85impl HttpAuthExtractor {
86 pub fn new(config: HttpAuthConfig) -> Self {
88 Self { config }
89 }
90
91 pub fn default() -> Self {
93 Self::new(HttpAuthConfig::default())
94 }
95
96 fn extract_authorization_header(
98 &self,
99 headers: &HashMap<String, String>,
100 ) -> AuthExtractionResult {
101 let auth_header = match headers
102 .get("Authorization")
103 .or_else(|| headers.get("authorization"))
104 {
105 Some(header) => header,
106 None => return Ok(None),
107 };
108
109 if auth_header.starts_with("Bearer ")
111 && self
112 .config
113 .supported_methods
114 .contains(&HttpAuthMethod::Bearer)
115 {
116 match AuthUtils::extract_bearer_token(auth_header) {
117 Ok(token) => {
118 AuthUtils::validate_api_key_format(&token)?;
119 let context =
120 TransportAuthContext::new(token, "Bearer".to_string(), TransportType::Http);
121 return Ok(Some(context));
122 }
123 Err(e) => return Err(e),
124 }
125 }
126
127 if auth_header.starts_with("Basic ")
129 && self
130 .config
131 .supported_methods
132 .contains(&HttpAuthMethod::Basic)
133 {
134 return self.extract_basic_auth(auth_header);
135 }
136
137 Err(TransportAuthError::InvalidFormat(format!(
138 "Unsupported Authorization header format: {}",
139 auth_header
140 )))
141 }
142
143 fn extract_basic_auth(&self, auth_header: &str) -> AuthExtractionResult {
145 if !auth_header.starts_with("Basic ") {
146 return Err(TransportAuthError::InvalidFormat(
147 "Invalid Basic auth format".to_string(),
148 ));
149 }
150
151 let encoded = &auth_header[6..]; use base64::{Engine as _, engine::general_purpose};
153 let decoded = match general_purpose::STANDARD.decode(encoded) {
154 Ok(bytes) => match String::from_utf8(bytes) {
155 Ok(string) => string,
156 Err(_) => {
157 return Err(TransportAuthError::InvalidFormat(
158 "Invalid UTF-8 in Basic auth".to_string(),
159 ));
160 }
161 },
162 Err(_) => {
163 return Err(TransportAuthError::InvalidFormat(
164 "Invalid Base64 in Basic auth".to_string(),
165 ));
166 }
167 };
168
169 let parts: Vec<&str> = decoded.splitn(2, ':').collect();
170 if parts.len() != 2 {
171 return Err(TransportAuthError::InvalidFormat(
172 "Basic auth must be username:password".to_string(),
173 ));
174 }
175
176 let api_key = parts[0];
178 AuthUtils::validate_api_key_format(api_key)?;
179
180 let context = TransportAuthContext::new(
181 api_key.to_string(),
182 "Basic".to_string(),
183 TransportType::Http,
184 );
185 Ok(Some(context))
186 }
187
188 fn extract_api_key_header(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
190 if !self
191 .config
192 .supported_methods
193 .contains(&HttpAuthMethod::ApiKeyHeader)
194 {
195 return Ok(None);
196 }
197
198 if let Some(api_key) = AuthUtils::extract_api_key_header(headers) {
199 AuthUtils::validate_api_key_format(&api_key)?;
200 let context =
201 TransportAuthContext::new(api_key, "X-API-Key".to_string(), TransportType::Http);
202 return Ok(Some(context));
203 }
204
205 Ok(None)
206 }
207
208 fn extract_query_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
210 if !self.config.allow_query_auth
211 || !self
212 .config
213 .supported_methods
214 .contains(&HttpAuthMethod::ApiKeyQuery)
215 {
216 return Ok(None);
217 }
218
219 for param_name in &["api_key", "apikey", "key", "token"] {
221 if let Some(api_key) = request.get_query_param(param_name) {
222 AuthUtils::validate_api_key_format(api_key)?;
223 let context = TransportAuthContext::new(
224 api_key.clone(),
225 "Query".to_string(),
226 TransportType::Http,
227 );
228 return Ok(Some(context));
229 }
230 }
231
232 Ok(None)
233 }
234
235 fn extract_custom_headers(&self, headers: &HashMap<String, String>) -> AuthExtractionResult {
237 for header_name in &self.config.custom_auth_headers {
238 if let Some(value) = headers.get(header_name) {
239 AuthUtils::validate_api_key_format(value)?;
240 let context = TransportAuthContext::new(
241 value.clone(),
242 format!("Custom({})", header_name),
243 TransportType::Http,
244 );
245 return Ok(Some(context));
246 }
247 }
248
249 Ok(None)
250 }
251
252 fn enrich_context(
254 &self,
255 mut context: TransportAuthContext,
256 request: &TransportRequest,
257 ) -> TransportAuthContext {
258 if let Some(client_ip) = AuthUtils::extract_client_ip(&request.headers) {
260 context = context.with_client_ip(client_ip);
261 }
262
263 if let Some(user_agent) = AuthUtils::extract_user_agent(&request.headers) {
265 context = context.with_user_agent(user_agent);
266 }
267
268 if let Some(host) = request.get_header("Host") {
270 context = context.with_metadata("host".to_string(), host.clone());
271 }
272
273 if let Some(referer) = request.get_header("Referer") {
274 context = context.with_metadata("referer".to_string(), referer.clone());
275 }
276
277 if let Some(origin) = request.get_header("Origin") {
278 context = context.with_metadata("origin".to_string(), origin.clone());
279 }
280
281 context
282 }
283
284 fn validate_https(&self, request: &TransportRequest) -> Result<(), TransportAuthError> {
286 if !self.config.require_https {
287 return Ok(());
288 }
289
290 let is_https = request
292 .get_header("X-Forwarded-Proto")
293 .map(|proto| proto == "https")
294 .or_else(|| {
295 request
296 .get_header("X-Scheme")
297 .map(|scheme| scheme == "https")
298 })
299 .or_else(|| request.metadata.get("is_https").map(|_| true))
300 .unwrap_or(false);
301
302 if !is_https {
303 return Err(TransportAuthError::AuthFailed(
304 "HTTPS required for authentication".to_string(),
305 ));
306 }
307
308 Ok(())
309 }
310}
311
312#[async_trait]
313impl AuthExtractor for HttpAuthExtractor {
314 async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
315 self.validate_https(request)?;
317
318 match self.extract_authorization_header(&request.headers) {
322 Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
323 Ok(None) => {} Err(e) => return Err(e), }
326
327 match self.extract_api_key_header(&request.headers) {
329 Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
330 Ok(None) => {} Err(e) => return Err(e), }
333
334 match self.extract_custom_headers(&request.headers) {
336 Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
337 Ok(None) => {} Err(e) => return Err(e), }
340
341 match self.extract_query_auth(request) {
343 Ok(Some(context)) => return Ok(Some(self.enrich_context(context, request))),
344 Ok(None) => {} Err(e) => return Err(e), }
347
348 Ok(None)
350 }
351
352 fn transport_type(&self) -> TransportType {
353 TransportType::Http
354 }
355
356 fn can_handle(&self, request: &TransportRequest) -> bool {
357 !request.headers.is_empty()
359 }
360
361 async fn validate_auth(
362 &self,
363 context: &TransportAuthContext,
364 ) -> Result<(), TransportAuthError> {
365 if context.credential.is_empty() {
367 return Err(TransportAuthError::InvalidFormat(
368 "Empty credential".to_string(),
369 ));
370 }
371
372 Ok(())
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// API key used by the happy-path tests; expected to pass key-format validation.
    const VALID_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Drives `extract_auth` to completion on the current thread.
    fn run(extractor: &HttpAuthExtractor, request: &TransportRequest) -> AuthExtractionResult {
        tokio_test::block_on(extractor.extract_auth(request))
    }

    /// Builds a request carrying exactly the given `(name, value)` headers.
    fn request_with_headers(pairs: &[(&str, &str)]) -> TransportRequest {
        let headers: HashMap<String, String> = pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect();
        TransportRequest::from_headers(headers)
    }

    #[test]
    fn test_bearer_token_extraction() {
        let extractor = HttpAuthExtractor::default();
        let bearer = format!("Bearer {}", VALID_KEY);
        let request = request_with_headers(&[("Authorization", bearer.as_str())]);

        let context = run(&extractor, &request)
            .unwrap()
            .expect("bearer token should be extracted");

        assert_eq!(context.credential, VALID_KEY);
        assert_eq!(context.method, "Bearer");
        assert_eq!(context.transport_type, TransportType::Http);
    }

    #[test]
    fn test_api_key_header_extraction() {
        let extractor = HttpAuthExtractor::default();
        let request = request_with_headers(&[("X-API-Key", VALID_KEY)]);

        let context = run(&extractor, &request)
            .unwrap()
            .expect("API key header should be extracted");

        assert_eq!(context.credential, VALID_KEY);
        assert_eq!(context.method, "X-API-Key");
    }

    #[test]
    fn test_basic_auth_extraction() {
        use base64::{engine::general_purpose, Engine as _};

        let extractor = HttpAuthExtractor::new(HttpAuthConfig {
            supported_methods: vec![HttpAuthMethod::Basic],
            ..Default::default()
        });

        // Username carries the key; the password half is left empty.
        let encoded = general_purpose::STANDARD.encode(format!("{}:", VALID_KEY));
        let basic = format!("Basic {}", encoded);
        let request = request_with_headers(&[("Authorization", basic.as_str())]);

        let context = run(&extractor, &request)
            .unwrap()
            .expect("basic credentials should be extracted");

        assert_eq!(context.credential, VALID_KEY);
        assert_eq!(context.method, "Basic");
    }

    #[test]
    fn test_query_parameter_extraction() {
        let extractor = HttpAuthExtractor::new(HttpAuthConfig {
            allow_query_auth: true,
            supported_methods: vec![HttpAuthMethod::ApiKeyQuery],
            ..Default::default()
        });
        let request = TransportRequest::new()
            .with_query_param("api_key".to_string(), VALID_KEY.to_string());

        let context = run(&extractor, &request)
            .unwrap()
            .expect("query parameter should be extracted");

        assert_eq!(context.credential, VALID_KEY);
        assert_eq!(context.method, "Query");
    }

    #[test]
    fn test_no_authentication() {
        let extractor = HttpAuthExtractor::default();
        let request = TransportRequest::new();

        let outcome = run(&extractor, &request).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_invalid_api_key_format() {
        let extractor = HttpAuthExtractor::default();
        let request = request_with_headers(&[("X-API-Key", "short")]);

        match run(&extractor, &request) {
            Err(error) => assert!(error.to_string().contains("too short")),
            Ok(_) => panic!("a too-short API key must be rejected"),
        }
    }

    #[test]
    fn test_context_enrichment() {
        let extractor = HttpAuthExtractor::default();
        let request = request_with_headers(&[
            ("X-API-Key", VALID_KEY),
            ("X-Forwarded-For", "192.168.1.100"),
            ("User-Agent", "TestClient/1.0"),
            ("Host", "api.example.com"),
        ]);

        let context = run(&extractor, &request)
            .unwrap()
            .expect("API key header should be extracted");

        assert_eq!(context.client_ip.unwrap(), "192.168.1.100");
        assert_eq!(context.user_agent.unwrap(), "TestClient/1.0");
        assert_eq!(context.metadata.get("host").unwrap(), "api.example.com");
    }
}