use async_trait::async_trait;
use axum::body::Body;
use axum::http::{HeaderMap, Request};
use std::sync::Arc;
use super::error::{AuthError, Result};
use super::principal::{AuthMethod, Principal, PrincipalBuilder, PrincipalType};
use super::store::{extract_key_prefix, verify_api_key, ApiKey, ApiKeyStore};
/// Request header from which [`ApiKeyAuthenticator`] reads the presented API key.
pub const API_KEY_HEADER: &str = "X-API-Key";

/// Name of the standard `Authorization` request header.
///
/// NOTE(review): not referenced by the authenticators in this file — confirm whether callers
/// elsewhere rely on it.
pub const AUTHORIZATION_HEADER: &str = "Authorization";

/// A fixed Argon2id hash (PHC string format) that is verified against when no stored key shares
/// the presented key's prefix.
///
/// Doing a verification on that path too means an unknown prefix still pays for one hash check,
/// narrowing the timing gap between "no such prefix" and "wrong key". Presumably its parameters
/// (`m=19456,t=2,p=1`) should track those used by `hash_api_key` — TODO confirm.
const DUMMY_HASH_FOR_TIMING: &str =
    "$argon2id$v=19$m=19456,t=2,p=1$YTJiM2M0ZDVlNmY3ZzhoOQ$0X9ULfbvJjTfCNxvkXqWJ9Y7Pz8eS6fQrKhW4mN3dA0";
/// Authenticates incoming requests from their headers.
///
/// Implementors must be `Send + Sync` so they can be shared across threads, for example behind
/// the `Arc<dyn Authenticator>` handles that [`ChainAuthenticator`] stores.
#[async_trait]
pub trait Authenticator: Send + Sync {
    /// Inspects `headers` and returns the authenticated [`Principal`].
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when the request cannot be authenticated. By the convention used
    /// in this module, [`AuthError::Unauthenticated`] means "no credentials this authenticator
    /// handles" (which [`ChainAuthenticator`] skips over); other variants are rejections.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal>;

    /// Identifies the kind of authentication this implementation performs.
    fn auth_method(&self) -> AuthMethod;
}
#[allow(dead_code)]
pub trait RequestAuthExt {
fn headers_for_auth(&self) -> &HeaderMap;
}
impl RequestAuthExt for Request<Body> {
fn headers_for_auth(&self) -> &HeaderMap {
self.headers()
}
}
/// An [`Authenticator`] that accepts every request as an anonymous principal.
///
/// The request headers are ignored entirely.
///
/// NOTE(review): grants access without checking any credential; placed after other
/// authenticators in a [`ChainAuthenticator`], it accepts anything they did not.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAllAuthenticator;

#[async_trait]
impl Authenticator for AllowAllAuthenticator {
    /// Always succeeds with [`Principal::anonymous`].
    async fn authenticate(&self, _headers: &HeaderMap) -> Result<Principal> {
        Ok(Principal::anonymous())
    }

    /// Reports [`AuthMethod::None`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::None
    }
}
/// An [`Authenticator`] that rejects every request.
///
/// It fails with [`AuthError::Unauthenticated`], which [`ChainAuthenticator`] treats as "no
/// credentials here" and skips rather than recording as a rejection.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenyAllAuthenticator;

#[async_trait]
impl Authenticator for DenyAllAuthenticator {
    /// Always fails with [`AuthError::Unauthenticated`].
    async fn authenticate(&self, _headers: &HeaderMap) -> Result<Principal> {
        Err(AuthError::Unauthenticated)
    }

    /// Reports [`AuthMethod::None`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::None
    }
}
/// Authenticates requests by the API key carried in the [`API_KEY_HEADER`] header.
///
/// Presented keys are looked up by prefix in the backing [`ApiKeyStore`] and then verified
/// against each candidate's stored hash.
pub struct ApiKeyAuthenticator {
    /// Store used to fetch candidate keys by prefix and to record key usage.
    store: Arc<dyn ApiKeyStore>,
}
/// Prefix every well-formed API key must start with.
const API_KEY_PREFIX: &str = "rb_";
/// Smallest accepted API key length in bytes, prefix included.
const MIN_API_KEY_LENGTH: usize = 10;
/// Largest accepted API key length in bytes, prefix included.
const MAX_API_KEY_LENGTH: usize = 64;
impl ApiKeyAuthenticator {
    /// Creates an authenticator backed by `store`.
    pub fn new(store: Arc<dyn ApiKeyStore>) -> Self {
        Self { store }
    }

    /// Performs cheap, store-independent sanity checks on a presented key.
    ///
    /// The checks run in this order:
    /// 1. the byte length is at most [`MAX_API_KEY_LENGTH`];
    /// 2. the byte length is at least [`MIN_API_KEY_LENGTH`];
    /// 3. the key starts with [`API_KEY_PREFIX`];
    /// 4. every character after the prefix is an ASCII letter, an ASCII digit, `-`, or `_`.
    ///
    /// # Errors
    ///
    /// Returns a static, human-readable reason for the first check that fails.
    fn validate_key_format(key: &str) -> std::result::Result<(), &'static str> {
        // `str::len` counts bytes, so multi-byte characters count toward both limits.
        if key.len() > MAX_API_KEY_LENGTH {
            return Err("API key too long");
        }
        if key.len() < MIN_API_KEY_LENGTH {
            return Err("API key too short");
        }
        // `strip_prefix` checks the prefix and yields the remainder in one step. This avoids a
        // separate manual byte-index slice.
        let key_material = key
            .strip_prefix(API_KEY_PREFIX)
            .ok_or("Invalid API key format")?;
        if !key_material
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
        {
            return Err("API key contains invalid characters");
        }
        Ok(())
    }

    /// Returns the value of the [`API_KEY_HEADER`] header as an owned string.
    ///
    /// Returns `None` when the header is absent or its value cannot be converted with
    /// `HeaderValue::to_str`.
    fn extract_key(headers: &HeaderMap) -> Option<String> {
        headers
            .get(API_KEY_HEADER)
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
    }

    /// Builds the [`Principal`] for a verified key.
    ///
    /// The principal:
    /// - is built from the key's id (stringified), name and tenant id;
    /// - has type [`PrincipalType::ApiKey`] and method [`AuthMethod::ApiKey`];
    /// - carries every role attached to the key;
    /// - inherits the key's expiry when it has one.
    fn key_to_principal(api_key: &ApiKey) -> Principal {
        let mut builder = PrincipalBuilder::new(
            api_key.id.to_string(),
            api_key.name.clone(),
            PrincipalType::ApiKey,
            api_key.tenant_id.clone(),
            AuthMethod::ApiKey,
        );
        for role in &api_key.roles {
            builder = builder.with_role(role.clone());
        }
        if let Some(expires_at) = api_key.expires_at {
            builder = builder.expires_at(expires_at);
        }
        builder.build()
    }
}
#[async_trait]
impl Authenticator for ApiKeyAuthenticator {
    /// Authenticates a request by the API key in its [`API_KEY_HEADER`] header.
    ///
    /// The authentication steps are:
    /// 1. Format-check the presented key.
    /// 2. Use its prefix to fetch candidate keys from the store.
    /// 3. Accept the first candidate whose stored hash verifies, provided it is enabled and not
    ///    expired.
    /// 4. On success, record usage and convert the key into a [`Principal`].
    ///
    /// # Errors
    ///
    /// * [`AuthError::Unauthenticated`]: the header is absent or not convertible to a string.
    ///   This lets a [`ChainAuthenticator`] move on to its next authenticator.
    /// * [`AuthError::InvalidCredentials`]: the key is empty, fails the format checks, or yields
    ///   no prefix.
    /// * [`AuthError::ApiKeyNotFound`]: no stored key with this prefix verifies.
    /// * [`AuthError::ApiKeyDisabled`]: the matching key is disabled.
    /// * [`AuthError::TokenExpired`]: the matching key has expired.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal> {
        let presented = match Self::extract_key(headers) {
            Some(value) => value,
            None => return Err(AuthError::Unauthenticated),
        };
        if presented.is_empty() {
            return Err(AuthError::InvalidCredentials("Empty API key".into()));
        }
        Self::validate_key_format(&presented)
            .map_err(|reason| AuthError::InvalidCredentials(reason.into()))?;

        let prefix = match extract_key_prefix(&presented) {
            Some(prefix) => prefix,
            None => return Err(AuthError::InvalidCredentials("Invalid key format".into())),
        };

        let candidates = self.store.get_by_prefix(&prefix).await;
        if candidates.is_empty() {
            // No stored key shares this prefix. Still run one verification against a fixed dummy
            // hash, so this path is not conspicuously faster than the "wrong key" path.
            let _ = verify_api_key(&presented, DUMMY_HASH_FOR_TIMING);
            return Err(AuthError::ApiKeyNotFound);
        }

        let api_key = match candidates
            .into_iter()
            .find(|candidate| verify_api_key(&presented, &candidate.key_hash))
        {
            Some(found) => found,
            None => return Err(AuthError::ApiKeyNotFound),
        };

        if !api_key.enabled {
            return Err(AuthError::ApiKeyDisabled);
        }
        if api_key.is_expired() {
            return Err(AuthError::TokenExpired);
        }

        // Usage tracking is best-effort. Whatever `record_usage` returns is discarded, so it
        // cannot reject an otherwise valid key.
        let _ = self.store.record_usage(&api_key.id).await;
        Ok(Self::key_to_principal(&api_key))
    }

    /// Reports [`AuthMethod::ApiKey`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::ApiKey
    }
}
/// An [`Authenticator`] that delegates to an ordered list of other authenticators.
pub struct ChainAuthenticator {
    /// Authenticators consulted front to back when authenticating a request.
    authenticators: Vec<Arc<dyn Authenticator>>,
}
impl ChainAuthenticator {
    /// Creates a chain that consults `authenticators` in the given order.
    pub fn new(authenticators: Vec<Arc<dyn Authenticator>>) -> Self {
        Self { authenticators }
    }

    /// Appends `authenticator` to the end of the chain and returns the extended chain.
    ///
    /// This consumes `self`, so ignoring the return value silently discards the whole chain.
    /// `#[must_use]` turns that mistake into a compiler warning.
    #[must_use = "`with` returns the extended chain; dropping it discards every authenticator"]
    pub fn with(mut self, authenticator: Arc<dyn Authenticator>) -> Self {
        self.authenticators.push(authenticator);
        self
    }
}
#[async_trait]
impl Authenticator for ChainAuthenticator {
    /// Tries each authenticator in order and returns the first successful [`Principal`].
    ///
    /// An [`AuthError::Unauthenticated`] result is skipped without being recorded. Any other
    /// error is remembered, but it does not stop the chain: later authenticators are still
    /// consulted.
    ///
    /// NOTE(review): because rejections do not short-circuit, a permissive authenticator later
    /// in the chain (e.g. [`AllowAllAuthenticator`]) will accept a request whose credentials an
    /// earlier authenticator rejected. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// When no authenticator succeeds, the result is one of:
    /// - the error from the last authenticator that failed with something other than
    ///   `Unauthenticated`;
    /// - `Unauthenticated` if none did, including when the chain is empty.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal> {
        let mut failure = AuthError::Unauthenticated;
        for authenticator in self.authenticators.iter() {
            match authenticator.authenticate(headers).await {
                Ok(principal) => return Ok(principal),
                // Nothing for this authenticator to check; keep any failure already recorded.
                Err(AuthError::Unauthenticated) => {}
                Err(rejection) => failure = rejection,
            }
        }
        Err(failure)
    }

    /// Reports the method of the first authenticator in the chain.
    ///
    /// Returns [`AuthMethod::None`] when the chain is empty.
    fn auth_method(&self) -> AuthMethod {
        match self.authenticators.first() {
            Some(head) => head.auth_method(),
            None => AuthMethod::None,
        }
    }
}
/// Unit tests for the authenticators and API-key helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_allow_all_authenticator() {
        let auth = AllowAllAuthenticator;
        let headers = HeaderMap::new();
        let result = auth.authenticate(&headers).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_anonymous());
    }

    #[tokio::test]
    async fn test_deny_all_authenticator() {
        let auth = DenyAllAuthenticator;
        let headers = HeaderMap::new();
        let result = auth.authenticate(&headers).await;
        // Pin the variant: ChainAuthenticator relies on `Unauthenticated` to skip ahead.
        assert!(matches!(result, Err(AuthError::Unauthenticated)));
    }

    #[tokio::test]
    async fn test_chain_skips_unauthenticated() {
        let deny: Arc<dyn Authenticator> = Arc::new(DenyAllAuthenticator);
        let allow: Arc<dyn Authenticator> = Arc::new(AllowAllAuthenticator);
        let chain = ChainAuthenticator::new(vec![deny, allow]);
        let result = chain.authenticate(&HeaderMap::new()).await;
        assert!(result.unwrap().is_anonymous());
    }

    #[tokio::test]
    async fn test_chain_built_with_with() {
        let chain = ChainAuthenticator::new(Vec::new())
            .with(Arc::new(DenyAllAuthenticator))
            .with(Arc::new(AllowAllAuthenticator));
        let result = chain.authenticate(&HeaderMap::new()).await;
        assert!(result.unwrap().is_anonymous());
    }

    #[tokio::test]
    async fn test_chain_empty_is_unauthenticated() {
        let chain = ChainAuthenticator::new(Vec::new());
        let result = chain.authenticate(&HeaderMap::new()).await;
        assert!(matches!(result, Err(AuthError::Unauthenticated)));
        assert!(matches!(chain.auth_method(), AuthMethod::None));
    }

    #[tokio::test]
    async fn test_chain_only_deny_is_unauthenticated() {
        let deny: Arc<dyn Authenticator> = Arc::new(DenyAllAuthenticator);
        let chain = ChainAuthenticator::new(vec![deny]);
        let result = chain.authenticate(&HeaderMap::new()).await;
        assert!(matches!(result, Err(AuthError::Unauthenticated)));
    }

    #[test]
    fn test_api_key_hashing_argon2() {
        use super::super::store::{hash_api_key, verify_api_key};
        let key = "rb_test-api-key-12345";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);
        // Salted hashing: the same key must hash differently each time.
        assert_ne!(hash1, hash2);
        assert!(verify_api_key(key, &hash1));
        assert!(verify_api_key(key, &hash2));
        assert!(hash1.starts_with("$argon2id$"));
    }

    #[test]
    fn test_api_key_verification() {
        use super::super::store::{hash_api_key, verify_api_key};
        let key = "rb_correct-key-12345";
        let hash = hash_api_key(key);
        assert!(verify_api_key(key, &hash));
        assert!(!verify_api_key("rb_wrong-key-54321", &hash));
    }

    #[test]
    fn test_validate_key_format_valid() {
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abcdefghij").is_ok());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_ABC123xyz-_").is_ok());
        assert!(ApiKeyAuthenticator::validate_key_format(
            "rb_0123456789abcdefghijklmnopqrstuvwxyz"
        )
        .is_ok());
    }

    #[test]
    fn test_validate_key_format_too_short() {
        assert!(ApiKeyAuthenticator::validate_key_format("rb_").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("short").is_err());
    }

    #[test]
    fn test_validate_key_format_too_long() {
        let long_key = format!("rb_{}", "a".repeat(100));
        assert!(ApiKeyAuthenticator::validate_key_format(&long_key).is_err());
    }

    #[test]
    fn test_validate_key_format_length_boundaries() {
        let max_key = format!(
            "{}{}",
            API_KEY_PREFIX,
            "a".repeat(MAX_API_KEY_LENGTH - API_KEY_PREFIX.len())
        );
        assert!(ApiKeyAuthenticator::validate_key_format(&max_key).is_ok());
        let over_max = format!("{}a", max_key);
        assert_eq!(
            ApiKeyAuthenticator::validate_key_format(&over_max),
            Err("API key too long")
        );

        let min_key = format!(
            "{}{}",
            API_KEY_PREFIX,
            "a".repeat(MIN_API_KEY_LENGTH - API_KEY_PREFIX.len())
        );
        assert!(ApiKeyAuthenticator::validate_key_format(&min_key).is_ok());
        let under_min = &min_key[..min_key.len() - 1];
        assert_eq!(
            ApiKeyAuthenticator::validate_key_format(under_min),
            Err("API key too short")
        );
    }

    #[test]
    fn test_validate_key_format_wrong_prefix() {
        assert!(ApiKeyAuthenticator::validate_key_format("sk_abcdefghij").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("api_abcdefghij").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("abcdefghijklmnop").is_err());
    }

    #[test]
    fn test_validate_key_format_invalid_chars() {
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc def").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc@def").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc!def").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc+def").is_err());
        assert!(ApiKeyAuthenticator::validate_key_format("rb_abc/def").is_err());
    }

    #[test]
    fn test_validate_key_format_rejects_non_ascii() {
        assert_eq!(
            ApiKeyAuthenticator::validate_key_format("rb_abcdéfgh"),
            Err("API key contains invalid characters")
        );
    }
}