use axum::http::HeaderMap;
use std::collections::{HashMap, HashSet};
/// Reason a request failed authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No credentials were found on the request.
    Missing,
    /// Credentials were present but were malformed or not accepted.
    Invalid,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            AuthError::Missing => "missing credentials",
            AuthError::Invalid => "invalid credentials",
        };
        f.write_str(msg)
    }
}

// Lets `AuthError` flow through `?` into `Box<dyn Error>` and similar generic error types.
impl std::error::Error for AuthError {}
/// The principal a request was authenticated as, when a validator can determine one.
#[derive(Clone, Debug, Default)]
pub struct Identity {
    /// Identifier of the authenticated principal (the tests use an email address).
    pub subject: String,
    /// Extra attributes attached by the validator, keyed by name.
    pub claims: HashMap<String, serde_json::Value>,
}
/// Pluggable request authentication based on request headers.
#[async_trait::async_trait]
pub trait AuthValidator: Send + Sync {
    /// Inspects `headers` and decides whether the request is authenticated.
    ///
    /// Returns `Ok(Some(identity))` when the credentials identify a subject,
    /// `Ok(None)` when the request is accepted without an identity, and
    /// `Err(_)` when credentials are missing or not accepted.
    async fn validate(&self, headers: &HeaderMap) -> Result<Option<Identity>, AuthError>;

    /// Challenge string describing the credentials this validator expects.
    ///
    /// Defaults to `"Bearer"`; `ApiKeyHeaderAuth` overrides it.
    /// NOTE(review): presumably emitted as the `WWW-Authenticate` value when a
    /// request is rejected — confirm with the caller.
    fn challenge(&self) -> String {
        String::from("Bearer")
    }
}
/// Validator that accepts every request without establishing an identity.
pub struct NoAuth;

#[async_trait::async_trait]
impl AuthValidator for NoAuth {
    /// Always returns `Ok(None)`; the headers are never inspected.
    async fn validate(&self, _headers: &HeaderMap) -> Result<Option<Identity>, AuthError> {
        Ok(None)
    }
}
/// Accepts requests that present one of a fixed set of bearer keys.
pub struct BearerKeyAuth {
    /// Keys accepted in the `Authorization: Bearer <key>` header.
    allowed: HashSet<String>,
}

impl BearerKeyAuth {
    /// Builds a validator that accepts any key yielded by `keys`.
    ///
    /// Duplicate keys collapse into one entry.
    pub fn new<I, S>(keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut allowed: HashSet<String> = HashSet::new();
        for raw in keys {
            allowed.insert(raw.into());
        }
        Self { allowed }
    }

    /// Convenience constructor for a validator that accepts exactly one key.
    pub fn single(key: impl Into<String>) -> Self {
        let only: String = key.into();
        Self::new([only])
    }
}
#[async_trait::async_trait]
impl AuthValidator for BearerKeyAuth {
    /// Accepts the request when `Authorization` carries `Bearer <key>` with an allowed key.
    ///
    /// The auth-scheme is matched case-insensitively (RFC 7235 §2.1), so `Bearer`,
    /// `bearer` and `BEARER` are all accepted. One or more spaces may separate the
    /// scheme from the key. On success this returns `Ok(None)`, because a static key
    /// carries no identity.
    ///
    /// # Errors
    ///
    /// * [`AuthError::Missing`] if there is no `Authorization` header.
    /// * [`AuthError::Invalid`] if the header value is not visible ASCII, uses a
    ///   scheme other than Bearer, carries an empty key, or the key is not allowed.
    async fn validate(&self, headers: &HeaderMap) -> Result<Option<Identity>, AuthError> {
        let value = headers
            .get(axum::http::header::AUTHORIZATION)
            .ok_or(AuthError::Missing)?;
        // A header that exists but cannot be read is malformed, not absent; reporting
        // it as `Missing` would let a caller treat a bad credential as no credential.
        let header = value.to_str().map_err(|_| AuthError::Invalid)?;
        let (scheme, rest) = header.split_once(' ').ok_or(AuthError::Invalid)?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return Err(AuthError::Invalid);
        }
        // RFC 7235 credentials grammar allows 1*SP between scheme and token.
        let key = rest.trim_start_matches(' ');
        // Never accept an empty key, even if an empty string slipped into the allow-list.
        if !key.is_empty() && self.allowed.contains(key) {
            Ok(None)
        } else {
            Err(AuthError::Invalid)
        }
    }
}
/// Accepts requests that carry an allowed key in a caller-chosen header.
pub struct ApiKeyHeaderAuth {
    /// Name of the header the key is read from.
    header_name: String,
    /// Keys accepted in that header.
    allowed: HashSet<String>,
}

impl ApiKeyHeaderAuth {
    /// Builds a validator that reads `header_name` and accepts any key yielded by `keys`.
    ///
    /// Duplicate keys collapse into one entry.
    pub fn new<I, S>(header_name: impl Into<String>, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let name: String = header_name.into();
        let accepted: HashSet<String> = keys.into_iter().map(|k| k.into()).collect();
        Self {
            header_name: name,
            allowed: accepted,
        }
    }
}
#[async_trait::async_trait]
impl AuthValidator for ApiKeyHeaderAuth {
    /// Accepts the request when the configured header carries an allowed key.
    ///
    /// On success this returns `Ok(None)`, because a static key carries no identity.
    ///
    /// # Errors
    ///
    /// * [`AuthError::Missing`] if the configured header is not present.
    /// * [`AuthError::Invalid`] if the header value is not visible ASCII, is empty,
    ///   or is not an allowed key.
    async fn validate(&self, headers: &HeaderMap) -> Result<Option<Identity>, AuthError> {
        let value = headers
            .get(self.header_name.as_str())
            .ok_or(AuthError::Missing)?;
        // A header that exists but cannot be read is malformed, not absent; reporting
        // it as `Missing` would let a caller treat a bad credential as no credential.
        let key = value.to_str().map_err(|_| AuthError::Invalid)?;
        // Never accept an empty key, even if an empty string slipped into the allow-list.
        if !key.is_empty() && self.allowed.contains(key) {
            Ok(None)
        } else {
            Err(AuthError::Invalid)
        }
    }

    /// Challenge naming the header the client is expected to send.
    fn challenge(&self) -> String {
        format!("ApiKey realm=\"{}\"", self.header_name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderName, HeaderValue};

    /// Builds a `HeaderMap` from `(name, value)` pairs; a repeated name replaces the earlier value.
    fn hdrs(pairs: &[(&str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(
                HeaderName::from_bytes(name.as_bytes()).expect("test header name is valid"),
                HeaderValue::from_str(value).expect("test header value is valid"),
            );
        }
        map
    }

    #[tokio::test]
    async fn no_auth_passes_everything() {
        let v = NoAuth;
        assert!(matches!(v.validate(&hdrs(&[])).await, Ok(None)));
    }

    #[tokio::test]
    async fn bearer_accepts_known_key() {
        let v = BearerKeyAuth::single("alpha");
        assert!(matches!(
            v.validate(&hdrs(&[("authorization", "Bearer alpha")])).await,
            Ok(None)
        ));
    }

    #[tokio::test]
    async fn bearer_accepts_lowercase_scheme() {
        let v = BearerKeyAuth::single("alpha");
        assert!(matches!(
            v.validate(&hdrs(&[("authorization", "bearer alpha")])).await,
            Ok(None)
        ));
    }

    #[tokio::test]
    async fn bearer_rejects_unknown_and_missing() {
        let v = BearerKeyAuth::single("alpha");
        assert!(matches!(
            v.validate(&hdrs(&[("authorization", "Bearer beta")])).await,
            Err(AuthError::Invalid)
        ));
        assert!(matches!(
            v.validate(&hdrs(&[])).await,
            Err(AuthError::Missing)
        ));
    }

    #[tokio::test]
    async fn bearer_rejects_other_schemes() {
        let v = BearerKeyAuth::single("alpha");
        assert!(matches!(
            v.validate(&hdrs(&[("authorization", "Basic alpha")])).await,
            Err(AuthError::Invalid)
        ));
    }

    #[tokio::test]
    async fn api_key_header_accepts_known_key() {
        let v = ApiKeyHeaderAuth::new("X-API-Key", ["secret"]);
        assert!(matches!(
            v.validate(&hdrs(&[("x-api-key", "secret")])).await,
            Ok(None)
        ));
    }

    #[tokio::test]
    async fn api_key_header_rejects_unknown_and_missing() {
        let v = ApiKeyHeaderAuth::new("X-API-Key", ["secret"]);
        assert!(matches!(
            v.validate(&hdrs(&[("x-api-key", "nope")])).await,
            Err(AuthError::Invalid)
        ));
        assert!(matches!(
            v.validate(&hdrs(&[])).await,
            Err(AuthError::Missing)
        ));
    }

    #[test]
    fn challenges_describe_expected_credentials() {
        assert_eq!(NoAuth.challenge(), "Bearer");
        assert_eq!(BearerKeyAuth::single("alpha").challenge(), "Bearer");
        assert_eq!(
            ApiKeyHeaderAuth::new("X-API-Key", ["secret"]).challenge(),
            "ApiKey realm=\"X-API-Key\""
        );
    }

    #[tokio::test]
    async fn custom_validator_can_surface_identity() {
        /// Maps bearer tokens to subjects, demonstrating a validator that returns an identity.
        struct TokenTableAuth(HashMap<&'static str, &'static str>);

        #[async_trait::async_trait]
        impl AuthValidator for TokenTableAuth {
            async fn validate(&self, headers: &HeaderMap) -> Result<Option<Identity>, AuthError> {
                let token = headers
                    .get(axum::http::header::AUTHORIZATION)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|h| {
                        h.strip_prefix("Bearer ")
                            .or_else(|| h.strip_prefix("bearer "))
                    })
                    .ok_or(AuthError::Missing)?;
                let subject = self.0.get(token).copied().ok_or(AuthError::Invalid)?;
                Ok(Some(Identity {
                    subject: subject.into(),
                    claims: HashMap::new(),
                }))
            }
        }

        let mut table = HashMap::new();
        table.insert("tok-1", "alice@example.com");
        let v = TokenTableAuth(table);
        let id = v
            .validate(&hdrs(&[("authorization", "Bearer tok-1")]))
            .await
            .expect("accepted")
            .expect("identity returned");
        assert_eq!(id.subject, "alice@example.com");
    }
}