1use std::collections::HashSet;
38use std::future::Future;
39use std::sync::Arc;
40
41use tower::Layer;
42
43#[derive(Debug, Clone)]
45pub enum AuthResult {
46 Authenticated(Option<AuthInfo>),
48 Failed(AuthError),
50}
51
52#[derive(Debug, Clone)]
54pub struct AuthInfo {
55 pub client_id: String,
57 pub claims: Option<serde_json::Value>,
59}
60
/// Reason a credential was rejected.
#[derive(Debug, Clone)]
pub struct AuthError {
    /// Machine-readable error code, e.g. `invalid_api_key` or `invalid_token`.
    pub code: String,
    /// Human-readable explanation; the HTTP auth service sends this text back to the client.
    pub message: String,
}

impl std::fmt::Display for AuthError {
    /// Renders the error as `code: message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { code, message } = self;
        write!(f, "{code}: {message}")
    }
}

// Lets `AuthError` be boxed or propagated with `?` alongside other error types.
impl std::error::Error for AuthError {}
77
78pub trait Validate: Clone + Send + Sync + 'static {
116 fn validate(&self, credential: &str) -> impl Future<Output = AuthResult> + Send;
118}
119
/// Validator that accepts any credential contained in a fixed set of API keys.
///
/// The key set lives behind an [`Arc`], so cloning the validator does not copy the keys.
#[derive(Clone)]
pub struct ApiKeyValidator {
    valid_keys: Arc<HashSet<String>>,
}

// `Debug` is written by hand instead of derived so the secret keys never end up in logs or
// panic messages via `{:?}`; only the number of configured keys is shown.
impl std::fmt::Debug for ApiKeyValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyValidator")
            .field("key_count", &self.valid_keys.len())
            .finish_non_exhaustive()
    }
}

impl ApiKeyValidator {
    /// Creates a validator that accepts exactly the given keys (duplicates are collapsed).
    pub fn new(keys: impl IntoIterator<Item = String>) -> Self {
        Self {
            valid_keys: Arc::new(keys.into_iter().collect()),
        }
    }

    /// Adds `key` to the set of accepted keys.
    ///
    /// This goes through [`Arc::make_mut`]. If the set is currently shared with clones of this
    /// validator (for example services already built from an auth layer), the set is copied
    /// first. Those clones keep their old key set and do **not** see the new key.
    pub fn add_key(&mut self, key: String) {
        Arc::make_mut(&mut self.valid_keys).insert(key);
    }

    /// Returns `true` if `key` is one of the accepted keys (exact, case-sensitive match).
    pub fn is_valid(&self, key: &str) -> bool {
        self.valid_keys.contains(key)
    }
}
153
154impl Validate for ApiKeyValidator {
155 async fn validate(&self, key: &str) -> AuthResult {
156 if self.valid_keys.contains(key) {
157 AuthResult::Authenticated(Some(AuthInfo {
158 client_id: format!("api_key:{}", &key[..8.min(key.len())]),
159 claims: None,
160 }))
161 } else {
162 AuthResult::Failed(AuthError {
163 code: "invalid_api_key".to_string(),
164 message: "The provided API key is not valid".to_string(),
165 })
166 }
167 }
168}
169
/// Validator that accepts bearer tokens from a fixed, pre-configured set.
///
/// The token set is shared behind an [`Arc`], so cloning the validator does not copy the tokens.
#[derive(Clone)]
pub struct StaticBearerValidator {
    valid_tokens: Arc<HashSet<String>>,
}

// `Debug` is written by hand instead of derived so the secret tokens never end up in logs or
// panic messages via `{:?}`; only the number of configured tokens is shown.
impl std::fmt::Debug for StaticBearerValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticBearerValidator")
            .field("token_count", &self.valid_tokens.len())
            .finish_non_exhaustive()
    }
}

impl StaticBearerValidator {
    /// Creates a validator that accepts exactly the given tokens (duplicates are collapsed).
    pub fn new(tokens: impl IntoIterator<Item = String>) -> Self {
        Self {
            valid_tokens: Arc::new(tokens.into_iter().collect()),
        }
    }
}
193
194impl Validate for StaticBearerValidator {
195 async fn validate(&self, token: &str) -> AuthResult {
196 if self.valid_tokens.contains(token) {
197 AuthResult::Authenticated(Some(AuthInfo {
198 client_id: format!("bearer:{}", &token[..8.min(token.len())]),
199 claims: None,
200 }))
201 } else {
202 AuthResult::Failed(AuthError {
203 code: "invalid_token".to_string(),
204 message: "The provided bearer token is not valid".to_string(),
205 })
206 }
207 }
208}
209
/// Extracts an API key from a header value in any of these forms:
///
/// * `Bearer <key>`
/// * `ApiKey <key>`
/// * `<key>`, a bare value containing no spaces
///
/// Surrounding whitespace is ignored and scheme names are case-sensitive. Returns `None` in
/// three cases:
///
/// * any other scheme (e.g. `Basic user:pass`);
/// * a known scheme with no key after it (e.g. `Bearer `);
/// * an empty or whitespace-only value.
///
/// Callers can therefore treat `None` uniformly as "no credential supplied".
pub fn extract_api_key(auth_header: &str) -> Option<&str> {
    let auth_header = auth_header.trim();

    let credential = if let Some(key) = auth_header.strip_prefix("Bearer ") {
        key.trim()
    } else if let Some(key) = auth_header.strip_prefix("ApiKey ") {
        key.trim()
    } else if matches!(auth_header, "Bearer" | "ApiKey") {
        // `"Bearer "` trims down to `"Bearer"`: a scheme without a key, not a raw key.
        return None;
    } else if !auth_header.contains(' ') {
        // No scheme at all: the whole value is the key.
        auth_header
    } else {
        // Unknown scheme.
        return None;
    };

    (!credential.is_empty()).then_some(credential)
}
234
/// Returns the token from a `Bearer <token>` header value.
///
/// Only the exact, case-sensitive `Bearer ` scheme is recognised. Any other value, including a
/// bare token or a lowercase `bearer`, yields `None`. Whitespace around the value and around
/// the token is ignored.
pub fn extract_bearer_token(auth_header: &str) -> Option<&str> {
    let trimmed = auth_header.trim();
    let token = trimmed.strip_prefix("Bearer ")?;
    Some(token.trim())
}
239
/// Tower layer that wraps a service with credential validation.
///
/// Every service produced by this layer gets its own clone of the validator and header name.
#[derive(Clone)]
pub struct AuthLayer<V> {
    /// Validator consulted for each request.
    validator: V,
    /// Name of the request header the credential is read from.
    header_name: String,
}

impl<V> AuthLayer<V> {
    /// Creates a layer that reads credentials from the `Authorization` header.
    pub fn new(validator: V) -> Self {
        let header_name = String::from("Authorization");
        Self {
            validator,
            header_name,
        }
    }

    /// Reads credentials from the header called `name` instead of `Authorization`.
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }
}
271
272impl<S, V: Clone> Layer<S> for AuthLayer<V> {
273 type Service = AuthService<S, V>;
274
275 fn layer(&self, inner: S) -> Self::Service {
276 AuthService {
277 inner,
278 validator: self.validator.clone(),
279 header_name: self.header_name.clone(),
280 }
281 }
282}
283
/// Service produced by [`AuthLayer`].
///
/// With the `http` feature it implements `tower_service::Service` for axum requests. It reads
/// the credential from the configured header and validates it. It then either forwards the
/// request (with any [`AuthInfo`] stored in its extensions) or answers `401 Unauthorized`.
#[derive(Clone)]
// Without the `http` feature no code reads these fields.
#[cfg_attr(not(feature = "http"), allow(dead_code))]
pub struct AuthService<S, V> {
    /// Wrapped service that receives authenticated requests.
    inner: S,
    /// Validator applied to each extracted credential.
    validator: V,
    /// Name of the header the credential is read from.
    header_name: String,
}
311
/// Authenticates axum requests before handing them to the wrapped service.
#[cfg(feature = "http")]
impl<S, V> tower_service::Service<axum::http::Request<axum::body::Body>> for AuthService<S, V>
where
    S: tower_service::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::response::Response,
        > + Clone
        + Send
        + 'static,
    S::Future: Send,
    S::Error: Into<crate::BoxError> + Send,
    V: Validate,
{
    type Response = axum::response::Response;
    type Error = S::Error;
    type Future =
        std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Extracts and validates the credential, then forwards or rejects the request.
    ///
    /// Missing or invalid credentials produce a `401` JSON-RPC error response (as `Ok`).
    /// Only errors from the wrapped service surface as `Err`.
    fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
        let credential = req
            .headers()
            .get(&self.header_name)
            .and_then(|value| value.to_str().ok())
            .and_then(extract_api_key)
            .map(str::to_owned);

        let Some(credential) = credential else {
            // Name the header actually configured, which may not be `Authorization`. The wrapped
            // service is left untouched and stays ready for the next request.
            let message = format!(
                "Missing authentication credentials. Provide via {} header.",
                self.header_name
            );
            return Box::pin(async move { Ok::<_, S::Error>(unauthorized_response(&message)) });
        };

        // `poll_ready` drove `self.inner`, not a fresh clone, and a clone is not guaranteed to be
        // ready (e.g. buffers or concurrency limits reserve capacity per instance). Move the ready
        // instance into the future and leave the clone behind for the next `poll_ready`.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let validator = self.validator.clone();

        Box::pin(async move {
            match validator.validate(&credential).await {
                AuthResult::Authenticated(info) => {
                    let mut req = req;
                    if let Some(info) = info {
                        // Expose the caller's identity to downstream handlers.
                        req.extensions_mut().insert(info);
                    }
                    inner.call(req).await
                }
                AuthResult::Failed(err) => Ok(unauthorized_response(&err.message)),
            }
        })
    }
}
368
/// Builds the `401 Unauthorized` reply sent when authentication fails.
///
/// The body is a JSON-RPC 2.0 error envelope with a `null` id and error code `-32001`. That code
/// lies in the `-32000..=-32099` range the JSON-RPC 2.0 specification reserves for
/// implementation-defined server errors. `message` is forwarded to the client verbatim.
#[cfg(feature = "http")]
fn unauthorized_response(message: &str) -> axum::response::Response {
    use axum::http::StatusCode;
    use axum::response::IntoResponse;

    let error = serde_json::json!({
        "code": -32001,
        "message": message
    });
    let envelope = serde_json::json!({
        "jsonrpc": "2.0",
        "error": error,
        "id": null
    });

    (StatusCode::UNAUTHORIZED, axum::Json(envelope)).into_response()
}
386
/// Declarative authentication settings: anonymous access, public paths and credential header.
///
/// NOTE(review): nothing in this module reads `AuthConfig`; `AuthLayer` keeps its own header
/// name. Confirm where these settings are enforced.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Presumably permits requests without credentials when `true` — TODO confirm at the
    /// enforcement site. Defaults to `false`.
    pub allow_anonymous: bool,
    /// Paths exempt from authentication. See [`AuthConfig::is_public`] for the matching rules.
    /// Empty by default.
    pub public_paths: Vec<String>,
    /// Name of the header carrying the credential. Defaults to `Authorization`.
    pub header_name: String,
}

impl Default for AuthConfig {
    fn default() -> Self {
        Self {
            allow_anonymous: false,
            public_paths: Vec::new(),
            header_name: "Authorization".to_string(),
        }
    }
}

impl AuthConfig {
    /// Creates the default configuration: no anonymous access, no public paths,
    /// `Authorization` header.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether anonymous access is allowed.
    pub fn allow_anonymous(mut self, allow: bool) -> Self {
        self.allow_anonymous = allow;
        self
    }

    /// Adds a public path. See [`AuthConfig::is_public`] for how paths are matched.
    pub fn public_path(mut self, path: impl Into<String>) -> Self {
        self.public_paths.push(path.into());
        self
    }

    /// Sets the name of the header carrying the credential.
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        self.header_name = name.into();
        self
    }

    /// Returns `true` if `path` falls under one of the configured public paths.
    ///
    /// Matching respects path-segment boundaries:
    ///
    /// * `/health` covers `/health` and `/health/live`, but not `/healthz` or `/health-admin`.
    ///   A plain string-prefix test would silently exempt those unrelated routes from
    ///   authentication.
    /// * A public path ending in `/` (e.g. `/static/`) covers everything beneath it.
    pub fn is_public(&self, path: &str) -> bool {
        self.public_paths.iter().any(|public| {
            path.strip_prefix(public.as_str()).is_some_and(|rest| {
                rest.is_empty() || rest.starts_with('/') || public.ends_with('/')
            })
        })
    }
}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `validator` on `credential` and flattens [`AuthResult`] into a `Result`, so tests can
    /// use `expect` / `expect_err` instead of hand-written `match` arms.
    async fn outcome<V: Validate>(
        validator: &V,
        credential: &str,
    ) -> Result<Option<AuthInfo>, AuthError> {
        match validator.validate(credential).await {
            AuthResult::Authenticated(info) => Ok(info),
            AuthResult::Failed(err) => Err(err),
        }
    }

    #[test]
    fn test_extract_api_key_bearer() {
        // Trailing whitespace after the key is ignored.
        for header in ["Bearer sk-123", "Bearer sk-123 "] {
            assert_eq!(extract_api_key(header), Some("sk-123"));
        }
    }

    #[test]
    fn test_extract_api_key_apikey_prefix() {
        assert_eq!(extract_api_key("ApiKey sk-123"), Some("sk-123"));
    }

    #[test]
    fn test_extract_api_key_raw() {
        assert_eq!(extract_api_key("sk-123"), Some("sk-123"));
    }

    #[test]
    fn test_extract_api_key_invalid() {
        assert_eq!(extract_api_key("Basic user:pass"), None);
    }

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123"), Some("abc123"));
        // The scheme is case-sensitive, and a bare token has no scheme at all.
        assert_eq!(extract_bearer_token("bearer abc123"), None);
        assert_eq!(extract_bearer_token("abc123"), None);
    }

    #[tokio::test]
    async fn test_api_key_validator() {
        let validator = ApiKeyValidator::new(["valid-key".to_string()]);

        let info = outcome(&validator, "valid-key")
            .await
            .expect("Expected authentication to succeed");
        assert!(info.is_some());

        let err = outcome(&validator, "invalid-key")
            .await
            .expect_err("Expected authentication to fail");
        assert_eq!(err.code, "invalid_api_key");
    }

    #[tokio::test]
    async fn test_bearer_validator() {
        let validator = StaticBearerValidator::new(["token123".to_string()]);

        let info = outcome(&validator, "token123")
            .await
            .expect("Expected authentication to succeed");
        assert!(info.is_some());

        let err = outcome(&validator, "bad-token")
            .await
            .expect_err("Expected authentication to fail");
        assert_eq!(err.code, "invalid_token");
    }

    #[test]
    fn test_auth_config() {
        let config = AuthConfig::new()
            .allow_anonymous(false)
            .public_path("/health")
            .public_path("/metrics")
            .header_name("X-API-Key");

        assert!(!config.allow_anonymous);
        assert_eq!(config.header_name, "X-API-Key");
        for public in ["/health", "/metrics/cpu"] {
            assert!(config.is_public(public), "{public} should be public");
        }
        assert!(!config.is_public("/api/tools"));
    }

    #[test]
    fn test_auth_layer_creates_service() {
        let layer = AuthLayer::new(ApiKeyValidator::new(["key".to_string()]));
        // `()` stands in for the inner service: building the wrapper needs no real service.
        let _service: AuthService<(), ApiKeyValidator> = layer.layer(());
    }

    #[cfg(feature = "http")]
    mod http_tests {
        use super::*;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use tower::ServiceExt;
        use tower_service::Service;

        /// Boxed future type shared by the stub inner services below.
        type BoxResponseFuture = Pin<
            Box<dyn Future<Output = Result<axum::response::Response, std::convert::Infallible>> + Send>,
        >;

        /// Builds an empty-bodied response with the given status.
        fn empty_response(status: StatusCode) -> axum::response::Response {
            axum::response::Response::builder()
                .status(status)
                .body(Body::empty())
                .unwrap()
        }

        /// Builds a `/` request, optionally carrying one `(name, value)` header.
        fn request_with(header: Option<(&str, &str)>) -> Request<Body> {
            let mut builder = Request::builder().uri("/");
            if let Some((name, value)) = header {
                builder = builder.header(name, value);
            }
            builder.body(Body::empty()).unwrap()
        }

        /// Drives `service` to readiness, sends `req`, and returns the response status.
        async fn status_of<S>(service: &mut S, req: Request<Body>) -> StatusCode
        where
            S: Service<Request<Body>, Response = axum::response::Response>,
            S::Error: std::fmt::Debug,
        {
            service.ready().await.unwrap().call(req).await.unwrap().status()
        }

        /// Auth layer backed by a single accepted API key.
        fn api_key_layer(key: &str) -> AuthLayer<ApiKeyValidator> {
            AuthLayer::new(ApiKeyValidator::new([key.to_string()]))
        }

        /// Inner service that always answers `200 OK`.
        #[derive(Clone)]
        struct OkService;

        impl Service<Request<Body>> for OkService {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = BoxResponseFuture;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Body>) -> Self::Future {
                Box::pin(async { Ok(empty_response(StatusCode::OK)) })
            }
        }

        /// Inner service that answers `200 OK` only when the auth layer attached an `AuthInfo`,
        /// and `500` otherwise.
        #[derive(Clone)]
        struct CheckAuthInfo;

        impl Service<Request<Body>> for CheckAuthInfo {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = BoxResponseFuture;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, req: Request<Body>) -> Self::Future {
                let status = match req.extensions().get::<AuthInfo>() {
                    Some(_) => StatusCode::OK,
                    None => StatusCode::INTERNAL_SERVER_ERROR,
                };
                Box::pin(async move { Ok(empty_response(status)) })
            }
        }

        #[tokio::test]
        async fn test_auth_service_rejects_missing_credentials() {
            let mut service = api_key_layer("sk-test-123").layer(OkService);
            let status = status_of(&mut service, request_with(None)).await;
            assert_eq!(status, StatusCode::UNAUTHORIZED);
        }

        #[tokio::test]
        async fn test_auth_service_rejects_invalid_key() {
            let mut service = api_key_layer("sk-test-123").layer(OkService);
            let req = request_with(Some(("Authorization", "Bearer sk-wrong-key")));
            assert_eq!(status_of(&mut service, req).await, StatusCode::UNAUTHORIZED);
        }

        #[tokio::test]
        async fn test_auth_service_accepts_valid_key() {
            let mut service = api_key_layer("sk-test-123").layer(OkService);
            let req = request_with(Some(("Authorization", "Bearer sk-test-123")));
            assert_eq!(status_of(&mut service, req).await, StatusCode::OK);
        }

        #[tokio::test]
        async fn test_auth_service_injects_auth_info() {
            let mut service = api_key_layer("sk-test-123").layer(CheckAuthInfo);
            let req = request_with(Some(("Authorization", "Bearer sk-test-123")));
            assert_eq!(status_of(&mut service, req).await, StatusCode::OK);
        }

        #[tokio::test]
        async fn test_auth_service_custom_header() {
            let mut service = api_key_layer("my-key")
                .header_name("X-API-Key")
                .layer(OkService);

            // Once a custom header is configured, credentials on `Authorization` are ignored.
            let on_default_header = request_with(Some(("Authorization", "Bearer my-key")));
            assert_eq!(
                status_of(&mut service, on_default_header).await,
                StatusCode::UNAUTHORIZED
            );

            let on_custom_header = request_with(Some(("X-API-Key", "my-key")));
            assert_eq!(
                status_of(&mut service, on_custom_header).await,
                StatusCode::OK
            );
        }
    }
}