use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Duration, Utc};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use thiserror::Error;
/// Errors returned by [`ApiKeyManager`] operations and [`ApiKeyStore`] implementations.
#[derive(Debug, Error)]
pub enum ApiKeyError {
    /// Returned by `ApiKeyManager::rotate` when no key with the given id exists.
    #[error("Invalid API key")]
    Invalid,
    /// Returned by `ApiKeyManager::validate` when the key's `expires_at` has passed.
    #[error("API key expired")]
    Expired,
    /// Returned by `ApiKeyManager::validate` when the key has been revoked.
    #[error("API key revoked")]
    Revoked,
    /// The key lacks a required permission; the payload describes what was missing.
    /// NOTE(review): not constructed in this file — presumably raised by callers
    /// after `ApiKeyManager::has_scope` returns `false`.
    #[error("Insufficient permissions: {0}")]
    InsufficientPermissions(String),
    /// A storage backend failure with a message.
    /// NOTE(review): presumably produced by `ApiKeyStore` implementations not shown here.
    #[error("Storage error: {0}")]
    Storage(String),
    /// NOTE(review): not produced in this file; presumably used by whatever enforces
    /// `ApiKey::rate_limit` — confirm.
    #[error("Rate limit exceeded")]
    RateLimitExceeded,
}
/// A persisted API key together with its ownership, permissions and lifecycle metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
    /// Unique identifier (a UUID v4 string when issued by `ApiKeyManager::generate`).
    pub id: String,
    /// The secret string clients present, shaped `<prefix>_<random>` when generated.
    /// NOTE(review): stored and looked up in plaintext; `ApiKeyManager::hash_key`
    /// exists but is not used — consider persisting only the hash.
    pub key: String,
    /// Identifier of the user that owns this key.
    pub user_id: String,
    /// Optional human-readable label; `ApiKeyManager::generate` leaves it `None`.
    pub name: Option<String>,
    /// Granted permission scopes; `"*"` matches any scope in `ApiKeyManager::has_scope`.
    pub scopes: Vec<String>,
    /// When the key was created.
    pub created_at: DateTime<Utc>,
    /// When the key stops being valid; `None` means it never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// Last successful validation time, recorded by `ApiKeyManager::validate`.
    pub last_used_at: Option<DateTime<Utc>>,
    /// Whether the key has been revoked; revoked keys fail validation.
    pub revoked: bool,
    /// Optional per-key rate limit.
    /// NOTE(review): not enforced in this file and its unit is unspecified — confirm.
    pub rate_limit: Option<u32>,
}
/// Persistence backend used by [`ApiKeyManager`].
///
/// Implementations must be `Send + Sync`; the manager holds one behind an `Arc`.
#[async_trait]
pub trait ApiKeyStore: Send + Sync {
    /// Persists `key`.
    /// NOTE(review): insert-vs-upsert semantics are implementation-defined
    /// (the in-memory test store simply appends).
    async fn save(&self, key: &ApiKey) -> Result<(), ApiKeyError>;
    /// Looks up a key by its full secret string (`ApiKey::key`); `Ok(None)` if absent.
    async fn find_by_key(&self, key: &str) -> Result<Option<ApiKey>, ApiKeyError>;
    /// Looks up a key by its identifier (`ApiKey::id`); `Ok(None)` if absent.
    async fn find_by_id(&self, id: &str) -> Result<Option<ApiKey>, ApiKeyError>;
    /// Returns every key whose `user_id` equals `user_id`.
    async fn list_by_user(&self, user_id: &str) -> Result<Vec<ApiKey>, ApiKeyError>;
    /// Marks the key with id `key_id` as revoked.
    async fn revoke(&self, key_id: &str) -> Result<(), ApiKeyError>;
    /// Records `timestamp` as the `last_used_at` of the key with id `key_id`.
    async fn update_last_used(
        &self,
        key_id: &str,
        timestamp: DateTime<Utc>,
    ) -> Result<(), ApiKeyError>;
}
/// Issues, validates, rotates and revokes API keys on top of an [`ApiKeyStore`].
pub struct ApiKeyManager {
    /// Backend where keys are persisted.
    store: Arc<dyn ApiKeyStore>,
    /// Prefix placed (followed by `_`) at the start of every generated key string.
    key_prefix: String,
    /// Lifetime applied to newly generated keys; `None` means they never expire.
    default_expiration: Option<Duration>,
}
impl ApiKeyManager {
    /// Creates a manager backed by `store`.
    ///
    /// Generated keys use the `ak` prefix and expire 365 days after creation;
    /// override these with [`with_prefix`](Self::with_prefix) and
    /// [`with_expiration`](Self::with_expiration).
    pub fn new(store: Arc<dyn ApiKeyStore>) -> Self {
        Self {
            store,
            key_prefix: "ak".to_string(),
            default_expiration: Some(Duration::days(365)),
        }
    }

    /// Sets the prefix placed before `_` in every generated key string.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    /// Sets the lifetime of newly generated keys; `None` issues keys that never expire.
    pub fn with_expiration(mut self, duration: Option<Duration>) -> Self {
        self.default_expiration = duration;
        self
    }

    /// Generates a new key for `user_id` with the given `scopes`, persists it and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ApiKeyStore::save`].
    pub async fn generate(
        &self,
        user_id: impl Into<String>,
        scopes: Vec<String>,
    ) -> Result<ApiKey, ApiKeyError> {
        let api_key = self.build_key(user_id.into(), scopes);
        self.store.save(&api_key).await?;
        Ok(api_key)
    }

    /// Looks up `key` and checks that it may be used.
    ///
    /// Returns `Ok(None)` when no stored key matches. On success, the use is recorded
    /// through [`ApiKeyStore::update_last_used`]. The returned key carries that same
    /// `last_used_at` timestamp.
    ///
    /// # Errors
    ///
    /// - [`ApiKeyError::Revoked`] if the key has been revoked.
    /// - [`ApiKeyError::Expired`] if the current time is at or past `expires_at`.
    /// - Any error from the store lookups/updates.
    pub async fn validate(&self, key: &str) -> Result<Option<ApiKey>, ApiKeyError> {
        let Some(mut api_key) = self.store.find_by_key(key).await? else {
            return Ok(None);
        };
        if api_key.revoked {
            return Err(ApiKeyError::Revoked);
        }
        let now = Utc::now();
        // A key is no longer valid at the very instant it expires, hence `>=`.
        if api_key.expires_at.is_some_and(|expires_at| now >= expires_at) {
            return Err(ApiKeyError::Expired);
        }
        self.store.update_last_used(&api_key.id, now).await?;
        // Keep the returned copy consistent with what was just persisted.
        api_key.last_used_at = Some(now);
        Ok(Some(api_key))
    }

    /// Returns `true` if `api_key` holds `required_scope` or the wildcard scope `"*"`.
    pub fn has_scope(&self, api_key: &ApiKey, required_scope: &str) -> bool {
        api_key
            .scopes
            .iter()
            .any(|s| s == required_scope || s == "*")
    }

    /// Revokes the key with id `key_id` via the store.
    pub async fn revoke(&self, key_id: &str) -> Result<(), ApiKeyError> {
        self.store.revoke(key_id).await
    }

    /// Lists every key owned by `user_id`.
    pub async fn list_user_keys(&self, user_id: &str) -> Result<Vec<ApiKey>, ApiKeyError> {
        self.store.list_by_user(user_id).await
    }

    /// Replaces the active key `old_key_id` with a freshly generated one.
    ///
    /// The new key keeps the old key's owner, scopes, name and rate limit. It is saved
    /// *before* the old key is revoked, so a failed save leaves the old key usable
    /// rather than leaving the user with no working key.
    ///
    /// # Errors
    ///
    /// - [`ApiKeyError::Invalid`] if no key with `old_key_id` exists.
    /// - [`ApiKeyError::Revoked`] if that key is already revoked.
    /// - Any store error. If revoking the old key fails, the just-saved new key is
    ///   revoked on a best-effort basis before the error is returned.
    pub async fn rotate(&self, old_key_id: &str) -> Result<ApiKey, ApiKeyError> {
        let old_key = self
            .store
            .find_by_id(old_key_id)
            .await?
            .ok_or(ApiKeyError::Invalid)?;
        // Rotating a revoked key would silently re-grant the access it lost.
        if old_key.revoked {
            return Err(ApiKeyError::Revoked);
        }
        let mut new_key = self.build_key(old_key.user_id, old_key.scopes);
        new_key.name = old_key.name;
        new_key.rate_limit = old_key.rate_limit;
        self.store.save(&new_key).await?;
        if let Err(err) = self.revoke(old_key_id).await {
            // Don't leave behind an active key the caller never received.
            let _ = self.revoke(&new_key.id).await;
            return Err(err);
        }
        Ok(new_key)
    }

    /// Builds (without persisting) a fresh, unrevoked key for `user_id`.
    ///
    /// `created_at` and `expires_at` are derived from a single clock reading.
    fn build_key(&self, user_id: String, scopes: Vec<String>) -> ApiKey {
        let now = Utc::now();
        ApiKey {
            id: uuid::Uuid::new_v4().to_string(),
            key: format!("{}_{}", self.key_prefix, self.generate_random_key()),
            user_id,
            name: None,
            scopes,
            created_at: now,
            expires_at: self.default_expiration.map(|d| now + d),
            last_used_at: None,
            revoked: false,
            rate_limit: None,
        }
    }

    /// Returns 32 random bytes encoded as URL-safe base64 without padding.
    fn generate_random_key(&self) -> String {
        let mut rng = rand::rng();
        let bytes: Vec<u8> = (0..32).map(|_| rng.random()).collect();
        URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Returns the hex-encoded SHA-256 digest of `key`.
    ///
    /// NOTE(review): none of this manager's methods call this. Keys are currently
    /// stored and looked up in plaintext.
    pub fn hash_key(key: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(key.as_bytes());
        hex::encode(hasher.finalize())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Vec-backed `ApiKeyStore` used only by these tests.
    struct InMemoryStore {
        keys: std::sync::Mutex<Vec<ApiKey>>,
    }

    impl InMemoryStore {
        fn new() -> Self {
            Self {
                keys: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    // The std mutex guard is never held across an `.await`, so it is safe here.
    #[async_trait]
    impl ApiKeyStore for InMemoryStore {
        async fn save(&self, key: &ApiKey) -> Result<(), ApiKeyError> {
            let mut keys = self.keys.lock().unwrap();
            keys.push(key.clone());
            Ok(())
        }
        async fn find_by_key(&self, key: &str) -> Result<Option<ApiKey>, ApiKeyError> {
            let keys = self.keys.lock().unwrap();
            Ok(keys.iter().find(|k| k.key == key).cloned())
        }
        async fn find_by_id(&self, id: &str) -> Result<Option<ApiKey>, ApiKeyError> {
            let keys = self.keys.lock().unwrap();
            Ok(keys.iter().find(|k| k.id == id).cloned())
        }
        async fn list_by_user(&self, user_id: &str) -> Result<Vec<ApiKey>, ApiKeyError> {
            let keys = self.keys.lock().unwrap();
            Ok(keys
                .iter()
                .filter(|k| k.user_id == user_id)
                .cloned()
                .collect())
        }
        async fn revoke(&self, key_id: &str) -> Result<(), ApiKeyError> {
            let mut keys = self.keys.lock().unwrap();
            if let Some(key) = keys.iter_mut().find(|k| k.id == key_id) {
                key.revoked = true;
            }
            Ok(())
        }
        async fn update_last_used(
            &self,
            key_id: &str,
            timestamp: DateTime<Utc>,
        ) -> Result<(), ApiKeyError> {
            let mut keys = self.keys.lock().unwrap();
            if let Some(key) = keys.iter_mut().find(|k| k.id == key_id) {
                key.last_used_at = Some(timestamp);
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_generate_api_key() {
        let store = Arc::new(InMemoryStore::new());
        let manager = ApiKeyManager::new(store);
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        assert!(key.key.starts_with("ak_"));
        assert_eq!(key.user_id, "user_123");
        assert_eq!(key.scopes, vec!["read"]);
    }

    #[tokio::test]
    async fn test_generate_with_custom_prefix() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new())).with_prefix("sk");
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        assert!(key.key.starts_with("sk_"));
    }

    #[tokio::test]
    async fn test_validate_api_key() {
        let store = Arc::new(InMemoryStore::new());
        let manager = ApiKeyManager::new(store);
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        let validated = manager.validate(&key.key).await.unwrap();
        assert!(validated.is_some());
        assert_eq!(validated.unwrap().user_id, "user_123");
    }

    #[tokio::test]
    async fn test_validate_unknown_key_returns_none() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()));
        let validated = manager.validate("ak_does_not_exist").await.unwrap();
        assert!(validated.is_none());
    }

    #[tokio::test]
    async fn test_validate_records_last_used() {
        let store = Arc::new(InMemoryStore::new());
        let manager = ApiKeyManager::new(store.clone());
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        assert!(key.last_used_at.is_none());
        manager.validate(&key.key).await.unwrap();
        let stored = store.find_by_id(&key.id).await.unwrap().unwrap();
        assert!(stored.last_used_at.is_some());
    }

    #[tokio::test]
    async fn test_expired_key_is_rejected() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()))
            .with_expiration(Some(Duration::seconds(-1)));
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        let result = manager.validate(&key.key).await;
        assert!(matches!(result, Err(ApiKeyError::Expired)));
    }

    #[tokio::test]
    async fn test_key_without_expiration_never_expires() {
        let manager =
            ApiKeyManager::new(Arc::new(InMemoryStore::new())).with_expiration(None);
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        assert!(key.expires_at.is_none());
        assert!(manager.validate(&key.key).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_revoke_api_key() {
        let store = Arc::new(InMemoryStore::new());
        let manager = ApiKeyManager::new(store);
        let key = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        manager.revoke(&key.id).await.unwrap();
        let result = manager.validate(&key.key).await;
        assert!(matches!(result, Err(ApiKeyError::Revoked)));
    }

    #[tokio::test]
    async fn test_rotate_revokes_old_key_and_issues_new_one() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()));
        let old = manager
            .generate("user_123", vec!["read".to_string()])
            .await
            .unwrap();
        let rotated = manager.rotate(&old.id).await.unwrap();
        assert_ne!(rotated.id, old.id);
        assert_ne!(rotated.key, old.key);
        assert_eq!(rotated.user_id, "user_123");
        assert_eq!(rotated.scopes, vec!["read"]);
        assert!(matches!(
            manager.validate(&old.key).await,
            Err(ApiKeyError::Revoked)
        ));
        assert!(manager.validate(&rotated.key).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_rotate_unknown_key_is_invalid() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()));
        let result = manager.rotate("missing").await;
        assert!(matches!(result, Err(ApiKeyError::Invalid)));
    }

    #[tokio::test]
    async fn test_list_user_keys() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()));
        manager.generate("alice", vec![]).await.unwrap();
        manager.generate("alice", vec![]).await.unwrap();
        manager.generate("bob", vec![]).await.unwrap();
        assert_eq!(manager.list_user_keys("alice").await.unwrap().len(), 2);
        assert_eq!(manager.list_user_keys("bob").await.unwrap().len(), 1);
        assert!(manager.list_user_keys("carol").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_has_scope() {
        let store = Arc::new(InMemoryStore::new());
        let manager = ApiKeyManager::new(store);
        let key = manager
            .generate("user_123", vec!["read".to_string(), "write".to_string()])
            .await
            .unwrap();
        assert!(manager.has_scope(&key, "read"));
        assert!(manager.has_scope(&key, "write"));
        assert!(!manager.has_scope(&key, "admin"));
    }

    #[tokio::test]
    async fn test_wildcard_scope_grants_everything() {
        let manager = ApiKeyManager::new(Arc::new(InMemoryStore::new()));
        let key = manager
            .generate("admin", vec!["*".to_string()])
            .await
            .unwrap();
        assert!(manager.has_scope(&key, "read"));
        assert!(manager.has_scope(&key, "anything"));
    }
}