use std::collections::HashSet;
use std::future::Future;
use std::sync::Arc;
use tower::Layer;
#[cfg(feature = "http")]
use tower::ServiceExt;
/// Outcome of validating a credential with a [`Validate`] implementation.
///
/// Marked `#[non_exhaustive]` so new variants can be added without breaking downstream
/// `match` expressions outside this crate.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AuthResult {
/// The credential was accepted. When `Some`, [`AuthService`] inserts the info into the
/// request extensions before forwarding the request to the inner service.
Authenticated(Option<AuthInfo>),
/// The credential was rejected; [`AuthService`] answers with HTTP 401 whose JSON-RPC
/// error body carries the error's `message`.
Failed(AuthError),
}
/// Identity information produced by a successful [`Validate::validate`] call.
///
/// With the `http` feature, [`AuthService`] stores this value in the request's
/// extensions, so downstream handlers can retrieve it from there.
#[derive(Debug, Clone)]
pub struct AuthInfo {
/// Identifier for the authenticated client. The built-in validators use `api_key:` or
/// `bearer:` followed by the first few characters of the accepted credential.
pub client_id: String,
/// Optional structured claims; the built-in validators always leave this `None`.
pub claims: Option<serde_json::Value>,
}
/// Reason a credential was rejected.
#[derive(Debug, Clone)]
pub struct AuthError {
    /// Machine-readable error code, e.g. `"invalid_api_key"` or `"invalid_token"`.
    pub code: String,
    /// Human-readable description; [`AuthService`] returns this text in its 401 body.
    pub message: String,
}

/// Renders the error as `"<code>: <message>"`.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { code, message } = self;
        write!(f, "{code}: {message}")
    }
}

impl std::error::Error for AuthError {}
/// Checks a credential taken from a request and reports the outcome as an [`AuthResult`].
///
/// The `Clone + Send + Sync + 'static` bounds exist because [`AuthLayer`] clones the
/// validator into every service it builds. [`AuthService`] also clones it into the boxed,
/// `Send` future it returns for each request.
pub trait Validate: Clone + Send + Sync + 'static {
/// Validates `credential`. Through [`AuthService`] this is the value produced by
/// [`extract_api_key`] from the configured header. The returned future must be `Send`.
fn validate(&self, credential: &str) -> impl Future<Output = AuthResult> + Send;
}
/// Validator that accepts API keys from a fixed, in-memory set.
///
/// The key set sits behind an `Arc`, so cloning the validator is cheap. [`add_key`]
/// copies the set first if it is shared, which means clones taken earlier are not
/// affected.
///
/// The `Debug` output reports only how many keys are configured, never the keys
/// themselves, so the validator can be logged safely.
///
/// [`add_key`]: ApiKeyValidator::add_key
#[derive(Clone)]
pub struct ApiKeyValidator {
    valid_keys: Arc<HashSet<String>>,
}

impl std::fmt::Debug for ApiKeyValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would print every secret key.
        f.debug_struct("ApiKeyValidator")
            .field("key_count", &self.valid_keys.len())
            .finish_non_exhaustive()
    }
}

impl ApiKeyValidator {
    /// Creates a validator accepting exactly the given keys (duplicates collapse).
    pub fn new(keys: impl IntoIterator<Item = String>) -> Self {
        Self {
            valid_keys: Arc::new(keys.into_iter().collect()),
        }
    }

    /// Adds `key` to the accepted set.
    ///
    /// If the set is shared with clones of this validator, it is copied first, so only
    /// this instance sees the new key.
    pub fn add_key(&mut self, key: String) {
        Arc::make_mut(&mut self.valid_keys).insert(key);
    }

    /// Returns `true` if `key` is in the accepted set.
    pub fn is_valid(&self, key: &str) -> bool {
        self.valid_keys.contains(key)
    }
}
impl Validate for ApiKeyValidator {
    /// Accepts `key` when it is one of the configured keys.
    ///
    /// On success, the [`AuthInfo`] has `client_id` set to `api_key:` followed by the first
    /// eight characters of the key (or the whole key if shorter), and no claims.
    /// On failure, the error code is `invalid_api_key`.
    async fn validate(&self, key: &str) -> AuthResult {
        if !self.valid_keys.contains(key) {
            return AuthResult::Failed(AuthError {
                code: "invalid_api_key".to_string(),
                message: "The provided API key is not valid".to_string(),
            });
        }
        // Truncate on a character boundary: byte-slicing `&key[..8]` panics when byte 8
        // falls inside a multi-byte UTF-8 character.
        let prefix = key.char_indices().nth(8).map_or(key, |(end, _)| &key[..end]);
        AuthResult::Authenticated(Some(AuthInfo {
            client_id: format!("api_key:{prefix}"),
            claims: None,
        }))
    }
}
/// Validator that accepts bearer tokens from a fixed, in-memory set.
///
/// Cloning is cheap because the token set sits behind an `Arc`. The `Debug` output reports
/// only how many tokens are configured, never the tokens themselves.
#[derive(Clone)]
pub struct StaticBearerValidator {
    valid_tokens: Arc<HashSet<String>>,
}

impl std::fmt::Debug for StaticBearerValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would print every secret token.
        f.debug_struct("StaticBearerValidator")
            .field("token_count", &self.valid_tokens.len())
            .finish_non_exhaustive()
    }
}

impl StaticBearerValidator {
    /// Creates a validator accepting exactly the given tokens (duplicates collapse).
    pub fn new(tokens: impl IntoIterator<Item = String>) -> Self {
        Self {
            valid_tokens: Arc::new(tokens.into_iter().collect()),
        }
    }
}
impl Validate for StaticBearerValidator {
    /// Accepts `token` when it is one of the configured tokens.
    ///
    /// On success, the [`AuthInfo`] has `client_id` set to `bearer:` followed by the first
    /// eight characters of the token (or the whole token if shorter), and no claims.
    /// On failure, the error code is `invalid_token`.
    async fn validate(&self, token: &str) -> AuthResult {
        if !self.valid_tokens.contains(token) {
            return AuthResult::Failed(AuthError {
                code: "invalid_token".to_string(),
                message: "The provided bearer token is not valid".to_string(),
            });
        }
        // Truncate on a character boundary: byte-slicing `&token[..8]` panics when byte 8
        // falls inside a multi-byte UTF-8 character.
        let prefix = token
            .char_indices()
            .nth(8)
            .map_or(token, |(end, _)| &token[..end]);
        AuthResult::Authenticated(Some(AuthInfo {
            client_id: format!("bearer:{prefix}"),
            claims: None,
        }))
    }
}
/// Extracts an API key from a header value.
///
/// Accepted forms, after trimming surrounding whitespace:
/// - `Bearer <key>`
/// - `ApiKey <key>`
/// - a bare `<key>` containing no spaces
///
/// Scheme names are case-sensitive. Returns `None` for any other value containing a space
/// (e.g. `Basic user:pass`), for an empty value, and for a scheme with no credential
/// (e.g. `"Bearer "`, whose trailing space trimming leaves just `"Bearer"`).
pub fn extract_api_key(auth_header: &str) -> Option<&str> {
    let header = auth_header.trim();
    let credential = if let Some(rest) = header.strip_prefix("Bearer ") {
        rest.trim()
    } else if let Some(rest) = header.strip_prefix("ApiKey ") {
        rest.trim()
    } else if header == "Bearer" || header == "ApiKey" || header.contains(' ') {
        // A bare scheme name is a missing credential, not a raw key literally named
        // "Bearer"; other space-containing values use a scheme we do not support.
        return None;
    } else {
        header
    };
    // An empty/whitespace-only header should read as "missing", not as an empty key.
    (!credential.is_empty()).then_some(credential)
}
/// Extracts the token from a `Bearer <token>` header value.
///
/// Surrounding whitespace is ignored on both the header and the token. The `Bearer`
/// scheme is matched case-sensitively; any other value yields `None`.
pub fn extract_bearer_token(auth_header: &str) -> Option<&str> {
    let token = auth_header.trim().strip_prefix("Bearer ")?;
    Some(token.trim())
}
/// Tower layer that wraps an inner service in an [`AuthService`] using `validator`.
#[derive(Clone)]
pub struct AuthLayer<V> {
    validator: V,
    header_name: String,
}

impl<V> AuthLayer<V> {
    /// Builds a layer that reads credentials from the `Authorization` header.
    pub fn new(validator: V) -> Self {
        let header_name = String::from("Authorization");
        Self {
            header_name,
            validator,
        }
    }

    /// Replaces the header that credentials are read from (builder style).
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }
}
impl<S, V: Clone> Layer<S> for AuthLayer<V> {
    type Service = AuthService<S, V>;

    /// Wraps `inner`, giving the new service its own copy of the validator and header name.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            validator,
            header_name,
        } = self;
        AuthService {
            inner,
            validator: validator.clone(),
            header_name: header_name.clone(),
        }
    }
}
/// Service produced by [`AuthLayer`]. It validates a credential from the configured header
/// before forwarding a request to `inner`.
#[derive(Clone)]
// The `Service` impl that reads these fields only exists with the `http` feature, so the
// dead-code lint is silenced when that feature is off.
#[cfg_attr(not(feature = "http"), allow(dead_code))]
pub struct AuthService<S, V> {
// Wrapped service; receives the request only after successful validation.
inner: S,
// Validator applied to the extracted credential.
validator: V,
// Name of the request header the credential is read from.
header_name: String,
}
#[cfg(feature = "http")]
impl<S, V> tower_service::Service<axum::http::Request<axum::body::Body>> for AuthService<S, V>
where
    S: tower_service::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::response::Response,
        > + Clone
        + Send
        + 'static,
    S::Future: Send,
    S::Error: Into<crate::BoxError> + Send,
    V: Validate,
{
    type Response = axum::response::Response;
    type Error = S::Error;
    type Future =
        std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is entirely that of the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req` and forwards it to the inner service on success.
    ///
    /// The response is a 401 JSON-RPC error when the configured header is absent, is not
    /// visible ASCII, or holds no extractable credential, and also when the validator
    /// rejects the credential. On success, any [`AuthInfo`] returned by the validator is
    /// inserted into the request extensions.
    fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
        let credential = req
            .headers()
            .get(&self.header_name)
            .and_then(|v| v.to_str().ok())
            .and_then(extract_api_key)
            .map(|s| s.to_owned());

        let Some(credential) = credential else {
            // Name the header actually configured so clients using a custom header are not
            // told to send `Authorization`.
            let message = format!(
                "Missing authentication credentials. Provide via {} header.",
                self.header_name
            );
            return Box::pin(async move { Ok::<_, S::Error>(unauthorized_response(&message)) });
        };

        // `poll_ready` was driven on `self.inner`, so that instance is the one holding the
        // readiness reservation. Move it into the future and leave a fresh clone behind.
        // Sending the request to a clone instead would leak the reservation.
        let fresh = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, fresh);
        let validator = self.validator.clone();

        Box::pin(async move {
            match validator.validate(&credential).await {
                AuthResult::Authenticated(info) => {
                    let mut req = req;
                    if let Some(info) = info {
                        req.extensions_mut().insert(info);
                    }
                    // `oneshot` re-checks readiness, which resolves immediately for the
                    // already-ready service moved in above.
                    inner.oneshot(req).await
                }
                AuthResult::Failed(err) => Ok(unauthorized_response(&err.message)),
            }
        })
    }
}
#[cfg(feature = "http")]
/// Builds an HTTP 401 response whose JSON body is a JSON-RPC 2.0 error with code `-32001`,
/// the given `message`, and a `null` id.
fn unauthorized_response(message: &str) -> axum::response::Response {
    use axum::http::StatusCode;
    use axum::response::IntoResponse;

    // -32001 is in the JSON-RPC range reserved for implementation-defined server errors.
    let error = serde_json::json!({ "code": -32001, "message": message });
    let envelope = serde_json::json!({ "jsonrpc": "2.0", "error": error, "id": null });
    let parts = (StatusCode::UNAUTHORIZED, axum::Json(envelope));
    parts.into_response()
}
/// Declarative authentication settings: an anonymous-access flag, a list of public paths,
/// and the name of the credential header.
#[derive(Clone, Debug)]
pub struct AuthConfig {
    /// Whether anonymous access is permitted (default `false`). Nothing in this module reads
    /// it; consumers of the config decide how to apply it.
    pub allow_anonymous: bool,
    /// Paths that bypass authentication; matched by [`AuthConfig::is_public`].
    pub public_paths: Vec<String>,
    /// Header carrying credentials (default `"Authorization"`).
    pub header_name: String,
}

impl Default for AuthConfig {
    fn default() -> Self {
        Self {
            allow_anonymous: false,
            public_paths: Vec::new(),
            header_name: "Authorization".to_string(),
        }
    }
}

impl AuthConfig {
    /// Same as [`AuthConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`allow_anonymous`](AuthConfig::allow_anonymous) (builder style).
    pub fn allow_anonymous(mut self, allow: bool) -> Self {
        self.allow_anonymous = allow;
        self
    }

    /// Appends a public path (builder style); see [`AuthConfig::is_public`] for matching.
    pub fn public_path(mut self, path: impl Into<String>) -> Self {
        self.public_paths.push(path.into());
        self
    }

    /// Sets [`header_name`](AuthConfig::header_name) (builder style).
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        self.header_name = name.into();
        self
    }

    /// Returns `true` if `path` equals a configured public path or lies beneath one.
    ///
    /// Matching respects path-segment boundaries: `"/health"` matches `"/health"` and
    /// `"/health/live"`, but not `"/healthz"` or `"/health-admin"`. A public path ending
    /// in `/` (e.g. `"/static/"`) matches everything under it.
    pub fn is_public(&self, path: &str) -> bool {
        self.public_paths.iter().any(|public| {
            path.strip_prefix(public.as_str()).is_some_and(|rest| {
                // A bare string-prefix check would let `/health` also expose `/healthz-admin`.
                rest.is_empty() || rest.starts_with('/') || public.ends_with('/')
            })
        })
    }
}
#[cfg(test)]
// Unit tests for credential extraction, the built-in validators and `AuthConfig`.
// With the `http` feature, they also cover the end-to-end behaviour of `AuthService`.
mod tests {
use super::*;
#[test]
fn test_extract_api_key_bearer() {
assert_eq!(extract_api_key("Bearer sk-123"), Some("sk-123"));
assert_eq!(extract_api_key("Bearer sk-123 "), Some("sk-123"));
}
#[test]
fn test_extract_api_key_apikey_prefix() {
assert_eq!(extract_api_key("ApiKey sk-123"), Some("sk-123"));
}
#[test]
fn test_extract_api_key_raw() {
assert_eq!(extract_api_key("sk-123"), Some("sk-123"));
}
#[test]
fn test_extract_api_key_invalid() {
assert_eq!(extract_api_key("Basic user:pass"), None);
}
// Scheme matching is case-sensitive, and a bare value without the scheme is rejected.
#[test]
fn test_extract_bearer_token() {
assert_eq!(extract_bearer_token("Bearer abc123"), Some("abc123"));
assert_eq!(extract_bearer_token("bearer abc123"), None); assert_eq!(extract_bearer_token("abc123"), None);
}
#[tokio::test]
async fn test_api_key_validator() {
let validator = ApiKeyValidator::new(vec!["valid-key".to_string()]);
match validator.validate("valid-key").await {
AuthResult::Authenticated(info) => {
assert!(info.is_some());
}
AuthResult::Failed(_) => panic!("Expected authentication to succeed"),
}
match validator.validate("invalid-key").await {
AuthResult::Authenticated(_) => panic!("Expected authentication to fail"),
AuthResult::Failed(err) => {
assert_eq!(err.code, "invalid_api_key");
}
}
}
#[tokio::test]
async fn test_bearer_validator() {
let validator = StaticBearerValidator::new(vec!["token123".to_string()]);
match validator.validate("token123").await {
AuthResult::Authenticated(info) => {
assert!(info.is_some());
}
AuthResult::Failed(_) => panic!("Expected authentication to succeed"),
}
match validator.validate("bad-token").await {
AuthResult::Authenticated(_) => panic!("Expected authentication to fail"),
AuthResult::Failed(err) => {
assert_eq!(err.code, "invalid_token");
}
}
}
// Builder round-trip plus public-path matching (including a sub-path of a public path).
#[test]
fn test_auth_config() {
let config = AuthConfig::new()
.allow_anonymous(false)
.public_path("/health")
.public_path("/metrics")
.header_name("X-API-Key");
assert!(!config.allow_anonymous);
assert!(config.is_public("/health"));
assert!(config.is_public("/metrics/cpu"));
assert!(!config.is_public("/api/tools"));
assert_eq!(config.header_name, "X-API-Key");
}
// Compile-level check: `Layer` works with any inner type, here the unit type.
#[test]
fn test_auth_layer_creates_service() {
let validator = ApiKeyValidator::new(vec!["key".to_string()]);
let layer = AuthLayer::new(validator);
let _service: AuthService<(), ApiKeyValidator> = layer.layer(());
}
// Service-level tests; they need the axum/tower types enabled by the `http` feature.
#[cfg(feature = "http")]
mod http_tests {
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
use tower_service::Service;
// Inner service that always answers 200 with an empty body.
#[derive(Clone)]
struct OkService;
impl Service<Request<Body>> for OkService {
type Response = axum::response::Response;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<Body>) -> Self::Future {
Box::pin(async {
Ok(axum::response::Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap())
})
}
}
#[tokio::test]
async fn test_auth_service_rejects_missing_credentials() {
let validator = ApiKeyValidator::new(vec!["sk-test-123".to_string()]);
let layer = AuthLayer::new(validator);
let mut service = layer.layer(OkService);
let req = Request::builder().uri("/").body(Body::empty()).unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_auth_service_rejects_invalid_key() {
let validator = ApiKeyValidator::new(vec!["sk-test-123".to_string()]);
let layer = AuthLayer::new(validator);
let mut service = layer.layer(OkService);
let req = Request::builder()
.uri("/")
.header("Authorization", "Bearer sk-wrong-key")
.body(Body::empty())
.unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_auth_service_accepts_valid_key() {
let validator = ApiKeyValidator::new(vec!["sk-test-123".to_string()]);
let layer = AuthLayer::new(validator);
let mut service = layer.layer(OkService);
let req = Request::builder()
.uri("/")
.header("Authorization", "Bearer sk-test-123")
.body(Body::empty())
.unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_auth_service_injects_auth_info() {
let validator = ApiKeyValidator::new(vec!["sk-test-123".to_string()]);
let layer = AuthLayer::new(validator);
// Inner service that answers 200 only when an `AuthInfo` extension is present, else 500.
#[derive(Clone)]
struct CheckAuthInfo;
impl Service<Request<Body>> for CheckAuthInfo {
type Response = axum::response::Response;
type Error = std::convert::Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let has_auth = req.extensions().get::<AuthInfo>().is_some();
Box::pin(async move {
let status = if has_auth {
StatusCode::OK
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
Ok(axum::response::Response::builder()
.status(status)
.body(Body::empty())
.unwrap())
})
}
}
let mut service = layer.layer(CheckAuthInfo);
let req = Request::builder()
.uri("/")
.header("Authorization", "Bearer sk-test-123")
.body(Body::empty())
.unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
// With a custom header configured, credentials sent in `Authorization` are ignored.
#[tokio::test]
async fn test_auth_service_custom_header() {
let validator = ApiKeyValidator::new(vec!["my-key".to_string()]);
let layer = AuthLayer::new(validator).header_name("X-API-Key");
let mut service = layer.layer(OkService);
let req = Request::builder()
.uri("/")
.header("Authorization", "Bearer my-key")
.body(Body::empty())
.unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
let req = Request::builder()
.uri("/")
.header("X-API-Key", "my-key")
.body(Body::empty())
.unwrap();
let resp = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
}
}