use async_trait::async_trait;
use axum::http::HeaderMap;
use std::sync::Arc;
use std::time::Duration;
use super::error::{AuthError, Result};
use super::principal::{AuthMethod, Principal, PrincipalBuilder, PrincipalType};
use super::store::{extract_key_prefix, verify_api_key, ApiKey, ApiKeyStore};
/// Name of the dedicated request header that carries an API key. This is the
/// header `ApiKeyAuthenticator::extract_key` reads.
pub const API_KEY_HEADER: &str = "X-API-Key";
/// Name of the standard HTTP `Authorization` header. Exported for callers; none
/// of the authenticators visible in this module read it.
pub const AUTHORIZATION_HEADER: &str = "Authorization";
/// Argon2id hash string (PHC format) that `ApiKeyAuthenticator::authenticate`
/// verifies the presented key against, discarding the result, when the store
/// returns no candidate for the key's prefix. As the name indicates, it exists
/// so that a hash verification is still performed on that "unknown prefix" path.
const DUMMY_HASH_FOR_TIMING: &str =
"$argon2id$v=19$m=19456,t=2,p=1$YTJiM2M0ZDVlNmY3ZzhoOQ$0X9ULfbvJjTfCNxvkXqWJ9Y7Pz8eS6fQrKhW4mN3dA0";
/// Common interface implemented by every request authenticator in this module.
///
/// Implementors must be `Send + Sync`; the chain authenticator below stores
/// them as `Arc<dyn Authenticator>`.
#[async_trait]
pub trait Authenticator: Send + Sync {
/// Attempts to authenticate a request from its HTTP headers.
///
/// Returns the authenticated [`Principal`] on success, or an [`AuthError`]
/// describing why authentication failed.
async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal>;
/// Reports which [`AuthMethod`] this authenticator implements.
fn auth_method(&self) -> AuthMethod;
}
/// Authenticator that accepts every request as the anonymous principal.
pub struct AllowAllAuthenticator;

#[async_trait]
impl Authenticator for AllowAllAuthenticator {
    /// Never inspects the headers and always succeeds with
    /// [`Principal::anonymous`].
    async fn authenticate(&self, _headers: &HeaderMap) -> Result<Principal> {
        let anonymous = Principal::anonymous();
        Ok(anonymous)
    }

    /// No credential mechanism is involved, so this reports [`AuthMethod::None`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::None
    }
}
/// Authenticator that rejects every request.
pub struct DenyAllAuthenticator;

#[async_trait]
impl Authenticator for DenyAllAuthenticator {
    /// Never inspects the headers and always fails with
    /// [`AuthError::Unauthenticated`].
    async fn authenticate(&self, _headers: &HeaderMap) -> Result<Principal> {
        let rejection = AuthError::Unauthenticated;
        Err(rejection)
    }

    /// No credential mechanism is involved, so this reports [`AuthMethod::None`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::None
    }
}
/// Authenticates requests by the API key sent in the `X-API-Key` header,
/// verifying it against hashed keys held in an [`ApiKeyStore`].
pub struct ApiKeyAuthenticator {
// Backing store queried by key prefix and used to record key usage.
store: Arc<dyn ApiKeyStore>,
}
// Upper bound on the whole key, prefix included, measured with `str::len` (bytes).
const MAX_API_KEY_LENGTH: usize = 64;
// Lower bound on the whole key, prefix included, measured with `str::len` (bytes).
const MIN_API_KEY_LENGTH: usize = 10;
// Literal prefix every accepted API key must start with.
const API_KEY_PREFIX: &str = "rb_";
impl ApiKeyAuthenticator {
    /// Creates an authenticator that resolves keys through `store`.
    pub fn new(store: Arc<dyn ApiKeyStore>) -> Self {
        Self { store }
    }

    /// Checks the syntactic shape of a presented key before any store access.
    ///
    /// Checks run in this order, and the first failure determines the message:
    /// byte length above `MAX_API_KEY_LENGTH`, byte length below
    /// `MIN_API_KEY_LENGTH`, missing `API_KEY_PREFIX`, then any character after
    /// the prefix that is not ASCII alphanumeric, `-` or `_`.
    fn validate_key_format(key: &str) -> std::result::Result<(), &'static str> {
        let byte_len = key.len();
        if byte_len > MAX_API_KEY_LENGTH {
            return Err("API key too long");
        }
        if byte_len < MIN_API_KEY_LENGTH {
            return Err("API key too short");
        }

        // `strip_prefix` both tests for the prefix and yields the remainder.
        let key_material = match key.strip_prefix(API_KEY_PREFIX) {
            Some(rest) => rest,
            None => return Err("Invalid API key format"),
        };

        let is_allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        if key_material.chars().any(|c| !is_allowed(c)) {
            return Err("API key contains invalid characters");
        }

        Ok(())
    }

    /// Reads the `X-API-Key` header as an owned string.
    ///
    /// Returns `None` when the header is absent or when `HeaderValue::to_str`
    /// rejects its value.
    fn extract_key(headers: &HeaderMap) -> Option<String> {
        let value = headers.get(API_KEY_HEADER)?;
        let text = value.to_str().ok()?;
        Some(text.to_owned())
    }

    /// Builds the [`Principal`] for a stored key: id, name, tenant, every role
    /// and, when the key has one, its expiry.
    fn key_to_principal(api_key: &ApiKey) -> Principal {
        let mut principal = PrincipalBuilder::new(
            api_key.id.to_string(),
            api_key.name.clone(),
            PrincipalType::ApiKey,
            api_key.tenant_id.clone(),
            AuthMethod::ApiKey,
        );

        for granted in &api_key.roles {
            principal = principal.with_role(granted.clone());
        }

        let principal = match api_key.expires_at {
            Some(deadline) => principal.expires_at(deadline),
            None => principal,
        };

        principal.build()
    }
}
#[async_trait]
impl Authenticator for ApiKeyAuthenticator {
    /// Authenticates the key carried in the `X-API-Key` header.
    ///
    /// # Errors
    ///
    /// * [`AuthError::Unauthenticated`] - the header is missing or unreadable.
    /// * [`AuthError::InvalidCredentials`] - the key is empty, fails the format
    ///   check, or no prefix can be extracted from it.
    /// * [`AuthError::ApiKeyNotFound`] - no stored key verifies against it.
    /// * [`AuthError::ApiKeyDisabled`] - the matching key is disabled.
    /// * [`AuthError::TokenExpired`] - the matching key has expired.
    async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal> {
        let presented = match Self::extract_key(headers) {
            Some(value) => value,
            None => return Err(AuthError::Unauthenticated),
        };
        if presented.is_empty() {
            return Err(AuthError::InvalidCredentials("Empty API key".into()));
        }
        Self::validate_key_format(&presented)
            .map_err(|reason| AuthError::InvalidCredentials(reason.into()))?;

        let key_prefix = match extract_key_prefix(&presented) {
            Some(prefix) => prefix,
            None => return Err(AuthError::InvalidCredentials("Invalid key format".into())),
        };

        let candidates = self.store.get_by_prefix(&key_prefix).await;
        if candidates.is_empty() {
            // Still run one verification against the dummy hash so this path
            // does not skip the hashing step entirely.
            let _ = verify_api_key(&presented, DUMMY_HASH_FOR_TIMING);
            return Err(AuthError::ApiKeyNotFound);
        }

        // The first candidate whose hash verifies wins; later ones are not tried.
        let mut matched = None;
        for candidate in candidates {
            if verify_api_key(&presented, &candidate.key_hash) {
                matched = Some(candidate);
                break;
            }
        }
        let api_key = matched.ok_or(AuthError::ApiKeyNotFound)?;

        if !api_key.enabled {
            return Err(AuthError::ApiKeyDisabled);
        }
        if api_key.is_expired() {
            return Err(AuthError::TokenExpired);
        }

        // Best effort: the outcome of usage tracking never affects authentication.
        let _ = self.store.record_usage(&api_key.id).await;

        Ok(Self::key_to_principal(&api_key))
    }

    /// Always [`AuthMethod::ApiKey`].
    fn auth_method(&self) -> AuthMethod {
        AuthMethod::ApiKey
    }
}
/// Sizing and lifetime settings for the API-key authentication cache.
///
/// Despite the name this is plain configuration; it holds no cached data.
#[derive(Debug, Clone)]
pub struct ApiKeyCache {
    /// Maximum capacity handed to the cache builder.
    pub max_capacity: u64,
    /// Time-to-live applied to every cached entry.
    pub ttl: Duration,
}

impl ApiKeyCache {
    /// Creates a configuration with an explicit capacity and time-to-live.
    pub fn new(max_capacity: u64, ttl: Duration) -> Self {
        Self { ttl, max_capacity }
    }
}

impl Default for ApiKeyCache {
    /// Capacity of 10 000 entries with a 300-second (five minute) time-to-live.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);
        Self::new(10_000, five_minutes)
    }
}
/// API-key authenticator that keeps successful authentications in an
/// in-memory moka cache so repeat requests with the same key can be answered
/// from memory.
pub struct CachedApiKeyAuthenticator {
// Plain (uncached) API-key authenticator built over the same store.
inner: ApiKeyAuthenticator,
// Cached results, keyed by the 32-byte SHA-256 digest of the raw key
// (see `cache_key`), so the raw key itself is never used as a map key.
cache: moka::future::Cache<[u8; 32], CachedAuthResult>,
// Backing key store, shared with `inner`.
store: Arc<dyn ApiKeyStore>,
}
/// Snapshot of a successful authentication as stored in the cache.
#[derive(Clone)]
struct CachedAuthResult {
// Principal returned to callers on a cache hit.
principal: Principal,
// Id of the stored key, rendered with `to_string()`; used in log output.
key_id: String,
// The key's `enabled` flag as it was when the entry was cached.
enabled: bool,
// The key's expiry, if any; compared with the current time on every cache hit.
expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
impl CachedApiKeyAuthenticator {
    /// Creates a caching authenticator over `store`, with the cache sized and
    /// aged according to `config`.
    pub fn new(store: Arc<dyn ApiKeyStore>, config: ApiKeyCache) -> Self {
        let cache = moka::future::Cache::builder()
            .max_capacity(config.max_capacity)
            .time_to_live(config.ttl)
            .build();
        Self {
            inner: ApiKeyAuthenticator::new(store.clone()),
            cache,
            store,
        }
    }

    /// Creates a caching authenticator using [`ApiKeyCache::default`].
    pub fn with_defaults(store: Arc<dyn ApiKeyStore>) -> Self {
        Self::new(store, ApiKeyCache::default())
    }

    /// Derives the cache key for a raw API key: its SHA-256 digest.
    fn cache_key(raw_key: &str) -> [u8; 32] {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(raw_key.as_bytes());
        hasher.finalize().into()
    }

    /// Evicts every cached authentication that belongs to the key `key_id`.
    ///
    /// Call this after disabling, rotating or deleting a key so the change
    /// takes effect immediately instead of when the entry's time-to-live runs
    /// out. Entries are matched on the key id recorded when they were cached.
    /// The cache is scanned linearly, which is acceptable for an operation
    /// expected to be rare compared with authentication itself.
    pub async fn invalidate_key(&self, key_id: &str) {
        // Collect the matching digests first so the cache iterator is dropped
        // before the first `.await`.
        let stale: Vec<_> = self
            .cache
            .iter()
            .filter(|(_, entry)| entry.key_id == key_id)
            .map(|(digest, _)| digest)
            .collect();
        let invalidated = stale.len();

        for digest in stale {
            self.cache.invalidate(&*digest).await;
        }

        tracing::debug!(
            key_id = key_id,
            invalidated = invalidated,
            "API key cache entries invalidated"
        );
    }

    /// Drops every cached authentication.
    pub fn clear_cache(&self) {
        self.cache.invalidate_all();
    }

    /// Returns `(entry_count, weighted_size)` as reported by the cache.
    ///
    /// NOTE(review): the tests in this file run the cache's pending tasks
    /// before reading these numbers, so they appear to lag behind recent
    /// inserts and invalidations - confirm against the moka documentation.
    pub fn cache_stats(&self) -> (u64, u64) {
        (self.cache.entry_count(), self.cache.weighted_size())
    }
}
#[async_trait]
impl Authenticator for CachedApiKeyAuthenticator {
async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal> {
let raw_key =
ApiKeyAuthenticator::extract_key(headers).ok_or(AuthError::Unauthenticated)?;
if raw_key.is_empty() {
return Err(AuthError::InvalidCredentials("Empty API key".into()));
}
if let Err(reason) = ApiKeyAuthenticator::validate_key_format(&raw_key) {
return Err(AuthError::InvalidCredentials(reason.into()));
}
let cache_key = Self::cache_key(&raw_key);
if let Some(cached) = self.cache.get(&cache_key).await {
if let Some(expires_at) = cached.expires_at {
if chrono::Utc::now() >= expires_at {
self.cache.invalidate(&cache_key).await;
return Err(AuthError::TokenExpired);
}
}
if !cached.enabled {
self.cache.invalidate(&cache_key).await;
return Err(AuthError::ApiKeyDisabled);
}
tracing::trace!(key_id = %cached.key_id, "API key cache hit");
return Ok(cached.principal);
}
let principal = self.inner.authenticate(headers).await?;
let key_prefix = extract_key_prefix(&raw_key)
.ok_or_else(|| AuthError::InvalidCredentials("Invalid key format".into()))?;
let candidates = self.store.get_by_prefix(&key_prefix).await;
if let Some(api_key) = candidates
.into_iter()
.find(|k| verify_api_key(&raw_key, &k.key_hash))
{
let cached_result = CachedAuthResult {
principal: principal.clone(),
key_id: api_key.id.to_string(),
enabled: api_key.enabled,
expires_at: api_key.expires_at,
};
self.cache.insert(cache_key, cached_result).await;
tracing::trace!(key_id = %api_key.id, "API key cached after verification");
}
Ok(principal)
}
fn auth_method(&self) -> AuthMethod {
AuthMethod::ApiKey
}
}
/// Authenticator that tries a list of other authenticators in order.
pub struct ChainAuthenticator {
    /// Delegates, in the order they are consulted.
    authenticators: Vec<Arc<dyn Authenticator>>,
}

impl ChainAuthenticator {
    /// Creates a chain that consults `authenticators` front to back.
    pub fn new(authenticators: Vec<Arc<dyn Authenticator>>) -> Self {
        ChainAuthenticator { authenticators }
    }

    /// Builder-style helper: appends `authenticator` to the end of the chain
    /// and returns the chain.
    pub fn with(self, authenticator: Arc<dyn Authenticator>) -> Self {
        let mut chain = self;
        chain.authenticators.push(authenticator);
        chain
    }
}
#[async_trait]
impl Authenticator for ChainAuthenticator {
    /// Returns the principal from the first delegate that succeeds.
    ///
    /// Every delegate is tried in order regardless of earlier failures. If
    /// none succeeds, the error returned is the most recent failure other
    /// than [`AuthError::Unauthenticated`]. If there was no such failure, or
    /// the chain is empty, the error is [`AuthError::Unauthenticated`].
    async fn authenticate(&self, headers: &HeaderMap) -> Result<Principal> {
        let mut failure = AuthError::Unauthenticated;

        for delegate in self.authenticators.iter() {
            let outcome = delegate.authenticate(headers).await;
            match outcome {
                Ok(principal) => return Ok(principal),
                // "No credentials offered" never replaces a more specific error.
                Err(AuthError::Unauthenticated) => {}
                Err(other) => failure = other,
            }
        }

        Err(failure)
    }

    /// Reports the method of the first delegate, or [`AuthMethod::None`] for
    /// an empty chain.
    fn auth_method(&self) -> AuthMethod {
        match self.authenticators.first() {
            Some(head) => head.auth_method(),
            None => AuthMethod::None,
        }
    }
}
// Unit tests for the authenticators, the key-format validation and the
// caching layer defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// AllowAll must succeed with an anonymous principal even without headers.
#[tokio::test]
async fn test_allow_all_authenticator() {
let auth = AllowAllAuthenticator;
let headers = HeaderMap::new();
let result = auth.authenticate(&headers).await;
assert!(result.is_ok());
assert!(result.unwrap().is_anonymous());
}
// DenyAll must fail regardless of the headers.
#[tokio::test]
async fn test_deny_all_authenticator() {
let auth = DenyAllAuthenticator;
let headers = HeaderMap::new();
let result = auth.authenticate(&headers).await;
assert!(result.is_err());
}
// Hashing the same key twice must give different hash strings that both
// verify, and the hash string must carry the Argon2id identifier.
#[test]
fn test_api_key_hashing_argon2() {
use super::super::store::{hash_api_key, verify_api_key};
let key = "rb_test-api-key-12345";
let hash1 = hash_api_key(key);
let hash2 = hash_api_key(key);
assert_ne!(hash1, hash2);
assert!(verify_api_key(key, &hash1));
assert!(verify_api_key(key, &hash2));
assert!(hash1.starts_with("$argon2id$"));
}
// Verification accepts the original key and rejects a different one.
#[test]
fn test_api_key_verification() {
use super::super::store::{hash_api_key, verify_api_key};
let key = "rb_correct-key-12345";
let hash = hash_api_key(key);
assert!(verify_api_key(key, &hash));
assert!(!verify_api_key("rb_wrong-key-54321", &hash));
}
// Well-formed keys: correct prefix, allowed characters, length within bounds.
#[test]
fn test_validate_key_format_valid() {
assert!(ApiKeyAuthenticator::validate_key_format("rb_abcdefghij").is_ok());
assert!(ApiKeyAuthenticator::validate_key_format("rb_ABC123xyz-_").is_ok());
assert!(ApiKeyAuthenticator::validate_key_format(
"rb_0123456789abcdefghijklmnopqrstuvwxyz"
)
.is_ok());
}
// Keys shorter than the minimum length are rejected.
#[test]
fn test_validate_key_format_too_short() {
assert!(ApiKeyAuthenticator::validate_key_format("rb_").is_err());
assert!(ApiKeyAuthenticator::validate_key_format("rb_abc").is_err());
assert!(ApiKeyAuthenticator::validate_key_format("short").is_err());
}
// A 103-byte key exceeds the maximum length and is rejected.
#[test]
fn test_validate_key_format_too_long() {
let long_key = format!("rb_{}", "a".repeat(100));
assert!(ApiKeyAuthenticator::validate_key_format(&long_key).is_err());
}
// Keys of acceptable length that do not start with "rb_" are rejected.
#[test]
fn test_validate_key_format_wrong_prefix() {
assert!(ApiKeyAuthenticator::validate_key_format("sk_abcdefghij").is_err());
assert!(ApiKeyAuthenticator::validate_key_format("api_abcdefghij").is_err());
assert!(ApiKeyAuthenticator::validate_key_format("abcdefghijklmnop").is_err());
}
// Space, '@', '!', '+' and '/' after the prefix are each rejected.
#[test]
fn test_validate_key_format_invalid_chars() {
assert!(ApiKeyAuthenticator::validate_key_format("rb_abc def").is_err()); assert!(ApiKeyAuthenticator::validate_key_format("rb_abc@def").is_err()); assert!(ApiKeyAuthenticator::validate_key_format("rb_abc!def").is_err()); assert!(ApiKeyAuthenticator::validate_key_format("rb_abc+def").is_err()); assert!(ApiKeyAuthenticator::validate_key_format("rb_abc/def").is_err());
}
// A successful authentication populates the cache, and authenticating again
// with the same key succeeds without adding a second entry.
// NOTE(review): the assertions only check the entry count; they would also
// pass if the second call missed the cache and re-inserted the same entry.
#[tokio::test]
async fn test_cached_authenticator_cache_hit() {
use super::super::store::{ApiKeyBuilder, InMemoryApiKeyStore};
let store = Arc::new(InMemoryApiKeyStore::new());
let (api_key, raw_key) = ApiKeyBuilder::new("test-key", "default").build();
store.store(api_key).await.unwrap();
let config = ApiKeyCache {
max_capacity: 100,
ttl: Duration::from_secs(60),
};
let auth = CachedApiKeyAuthenticator::new(store, config);
let mut headers = HeaderMap::new();
headers.insert(API_KEY_HEADER, raw_key.parse().unwrap());
let result1 = auth.authenticate(&headers).await;
assert!(result1.is_ok());
// Run the cache's pending tasks before reading the entry count.
auth.cache.run_pending_tasks().await;
let (count_before, _) = auth.cache_stats();
assert_eq!(
count_before, 1,
"Cache should be populated after first request"
);
let result2 = auth.authenticate(&headers).await;
assert!(result2.is_ok());
let (count, _) = auth.cache_stats();
assert_eq!(count, 1);
}
// A well-formed key that is not in the store fails and leaves the cache empty.
#[tokio::test]
async fn test_cached_authenticator_invalid_key() {
use super::super::store::InMemoryApiKeyStore;
let store = Arc::new(InMemoryApiKeyStore::new());
let auth = CachedApiKeyAuthenticator::with_defaults(store);
let mut headers = HeaderMap::new();
headers.insert(API_KEY_HEADER, "rb_invalid-key-12345".parse().unwrap());
let result = auth.authenticate(&headers).await;
assert!(result.is_err());
let (count, _) = auth.cache_stats();
assert_eq!(count, 0);
}
// The cache key is a pure function of the raw key: equal inputs give equal
// digests, different inputs give different digests.
#[test]
fn test_cache_key_deterministic() {
let key1 = CachedApiKeyAuthenticator::cache_key("rb_test-key-12345");
let key2 = CachedApiKeyAuthenticator::cache_key("rb_test-key-12345");
let key3 = CachedApiKeyAuthenticator::cache_key("rb_different-key-67890");
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
// `clear_cache` removes a previously cached authentication.
#[tokio::test]
async fn test_cached_authenticator_clear_cache() {
use super::super::store::{ApiKeyBuilder, InMemoryApiKeyStore};
let store = Arc::new(InMemoryApiKeyStore::new());
let (api_key, raw_key) = ApiKeyBuilder::new("test-key", "default").build();
store.store(api_key).await.unwrap();
let auth = CachedApiKeyAuthenticator::with_defaults(store);
let mut headers = HeaderMap::new();
headers.insert(API_KEY_HEADER, raw_key.parse().unwrap());
let _ = auth.authenticate(&headers).await;
auth.cache.run_pending_tasks().await;
let (count, _) = auth.cache_stats();
assert_eq!(count, 1);
auth.clear_cache();
// Pending tasks are run again before the count is re-read.
auth.cache.run_pending_tasks().await;
let (count_after, _) = auth.cache_stats();
assert_eq!(count_after, 0);
}
}