use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
use uuid::Uuid;
/// Errors returned by `ApiKeyStore` implementations and by `validate_api_key`.
#[derive(Debug, Error)]
pub enum ApiKeyError {
/// Failure reported by a storage backend, carrying its message.
/// (`MemoryApiKeyStore` in this file never constructs this variant.)
#[error("Storage error: {0}")]
Storage(String),
/// No record matches the given id or key hash.
#[error("Key not found")]
NotFound,
/// The key's `expires_at` lies in the past (raised by `validate_api_key`).
#[error("Key expired")]
Expired,
/// The key has been revoked (raised by `validate_api_key`, checked before expiry).
#[error("Key revoked")]
Revoked,
/// The plaintext key failed the `vex_` prefix / minimum-length check in `validate_api_key`.
#[error("Invalid key format")]
InvalidFormat,
}
/// A stored API key.
///
/// The secret is kept only as a SHA-256 hash plus a short prefix (`vex_` followed by
/// the first 8 hex digits of `id`), so the random part of the plaintext is never stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyRecord {
/// Unique record id, generated with `Uuid::new_v4` and also embedded (without hyphens)
/// in the plaintext key.
pub id: Uuid,
/// Hex-encoded SHA-256 of the plaintext key (see `ApiKeyRecord::hash_key`); used for lookup.
pub key_hash: String,
/// First 12 characters of the plaintext key. Presumably used to let users recognise a
/// key without revealing it — confirm with the callers.
pub key_prefix: String,
/// Owner of the key; `find_by_user` filters on this.
pub user_id: String,
/// Caller-supplied label for the key.
pub name: String,
/// Granted scopes; an entry of `"*"` makes `has_scope` return true for every scope.
pub scopes: Vec<String>,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Expiry timestamp; `None` means the key never expires.
pub expires_at: Option<DateTime<Utc>>,
/// Set by `ApiKeyStore::record_usage` (which `validate_api_key` calls after a key passes
/// its checks); `None` until the first recorded use.
pub last_used_at: Option<DateTime<Utc>>,
/// Set by `ApiKeyStore::revoke`; a revoked key fails validation regardless of expiry.
pub revoked: bool,
}
impl ApiKeyRecord {
    /// Characters the random part of a generated key is drawn from (62 symbols).
    const KEY_ALPHABET: &'static [u8] =
        b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    /// Number of random characters at the end of a generated key.
    const RANDOM_PART_LEN: usize = 32;
    /// Number of leading plaintext characters copied into `key_prefix`.
    const PREFIX_LEN: usize = 12;

    /// Generates a new key for `user_id` and returns `(record, plaintext_key)`.
    ///
    /// The plaintext has the shape `vex_<id as 32 hex digits>_<32 alphanumerics>`.
    /// The record keeps only its SHA-256 hash and its first 12 characters, so the
    /// returned string is the only copy of the secret.
    ///
    /// `expires_in_days = None` creates a key that never expires. If adding the
    /// requested number of days to the current time overflows chrono's
    /// representable range, the key is likewise created without an expiry instead
    /// of panicking.
    pub fn new(
        user_id: &str,
        name: &str,
        scopes: Vec<String>,
        expires_in_days: Option<u32>,
    ) -> (Self, String) {
        let id = Uuid::new_v4();
        // Read the clock once so `expires_at - created_at` is exactly the requested
        // number of days.
        let now = Utc::now();

        let alphabet_len = Self::KEY_ALPHABET.len() as u64;
        let random_part: String = (0..Self::RANDOM_PART_LEN)
            .map(|_| {
                // Sampling a u64 keeps the modulo bias below 2^-58 per character on every
                // platform; the remainder is < 62, so the cast back to usize is lossless.
                let idx = (rand::random::<u64>() % alphabet_len) as usize;
                char::from(Self::KEY_ALPHABET[idx])
            })
            .collect();

        let plaintext_key = format!("vex_{}_{}", id.to_string().replace('-', ""), random_part);
        let key_hash = Self::hash_key(&plaintext_key);
        let key_prefix = plaintext_key.chars().take(Self::PREFIX_LEN).collect();
        let expires_at = expires_in_days
            .and_then(|days| now.checked_add_signed(chrono::Duration::days(i64::from(days))));

        let record = Self {
            id,
            key_hash,
            key_prefix,
            user_id: user_id.to_string(),
            name: name.to_string(),
            scopes,
            created_at: now,
            expires_at,
            last_used_at: None,
            revoked: false,
        };
        (record, plaintext_key)
    }

    /// Returns the hex-encoded SHA-256 digest of `key`.
    ///
    /// This is the value stored in `key_hash` and used for store lookups; it is
    /// deterministic, so the same plaintext always maps to the same hash.
    pub fn hash_key(key: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(key.as_bytes());
        hex::encode(hasher.finalize())
    }

    /// Returns `true` if the key is not revoked and the current time is not past
    /// `expires_at`.
    ///
    /// A key is still valid at the exact instant of `expires_at`; only a strictly
    /// later time counts as expired.
    pub fn is_valid(&self) -> bool {
        if self.revoked {
            return false;
        }
        if let Some(expires) = self.expires_at {
            if Utc::now() > expires {
                return false;
            }
        }
        true
    }

    /// Returns `true` if `scope` is granted, either explicitly or via the `"*"` wildcard.
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope || s == "*")
    }
}
/// Persistence backend for `ApiKeyRecord`s.
///
/// `validate_api_key` looks keys up with `find_by_hash` and then calls `record_usage`.
/// The per-method notes below describe the contract as implemented by
/// `MemoryApiKeyStore` in this file; other backends should presumably match it.
#[async_trait]
pub trait ApiKeyStore: Send + Sync {
/// Persists `record` (the in-memory store replaces any record with the same id).
async fn create(&self, record: &ApiKeyRecord) -> Result<(), ApiKeyError>;
/// Returns the record whose `key_hash` equals `hash`, or `None` if there is none.
async fn find_by_hash(&self, hash: &str) -> Result<Option<ApiKeyRecord>, ApiKeyError>;
/// Returns every record whose `user_id` equals `user_id` (possibly empty).
async fn find_by_user(&self, user_id: &str) -> Result<Vec<ApiKeyRecord>, ApiKeyError>;
/// Sets `last_used_at` of record `id` to the current time; `NotFound` if absent.
async fn record_usage(&self, id: Uuid) -> Result<(), ApiKeyError>;
/// Marks record `id` as revoked without removing it; `NotFound` if absent.
async fn revoke(&self, id: Uuid) -> Result<(), ApiKeyError>;
/// Removes record `id`; `NotFound` if absent.
async fn delete(&self, id: Uuid) -> Result<(), ApiKeyError>;
}
/// `ApiKeyStore` that keeps records in a `HashMap` keyed by record id.
///
/// The map sits behind a `tokio::sync::RwLock`, so lookups take a read lock and
/// mutations a write lock. Nothing is persisted: records are gone once the store is
/// dropped. `find_by_hash` scans every record, so lookups are linear in the number of keys.
#[derive(Debug, Default)]
pub struct MemoryApiKeyStore {
// Record id -> record.
keys: tokio::sync::RwLock<std::collections::HashMap<Uuid, ApiKeyRecord>>,
}
impl MemoryApiKeyStore {
    /// Builds a store with no records.
    pub fn new() -> Self {
        Self {
            keys: Default::default(),
        }
    }
}
#[async_trait]
impl ApiKeyStore for MemoryApiKeyStore {
    /// Inserts a copy of `record` under its id, replacing any previous entry with that id.
    async fn create(&self, record: &ApiKeyRecord) -> Result<(), ApiKeyError> {
        self.keys.write().await.insert(record.id, record.clone());
        Ok(())
    }

    /// Linear scan for the entry whose `key_hash` equals `hash`; returns a clone.
    async fn find_by_hash(&self, hash: &str) -> Result<Option<ApiKeyRecord>, ApiKeyError> {
        let guard = self.keys.read().await;
        let hit = guard.values().find(|entry| entry.key_hash == hash).cloned();
        Ok(hit)
    }

    /// Clones every entry owned by `user_id`.
    async fn find_by_user(&self, user_id: &str) -> Result<Vec<ApiKeyRecord>, ApiKeyError> {
        let guard = self.keys.read().await;
        let owned: Vec<ApiKeyRecord> = guard
            .values()
            .filter(|entry| entry.user_id == user_id)
            .cloned()
            .collect();
        Ok(owned)
    }

    /// Stamps `last_used_at` with the current time, or fails with `NotFound`.
    async fn record_usage(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut guard = self.keys.write().await;
        match guard.get_mut(&id) {
            Some(entry) => {
                entry.last_used_at = Some(Utc::now());
                Ok(())
            }
            None => Err(ApiKeyError::NotFound),
        }
    }

    /// Flags the entry as revoked (it stays in the map), or fails with `NotFound`.
    async fn revoke(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut guard = self.keys.write().await;
        match guard.get_mut(&id) {
            Some(entry) => {
                entry.revoked = true;
                Ok(())
            }
            None => Err(ApiKeyError::NotFound),
        }
    }

    /// Drops the entry from the map, or fails with `NotFound` if it was not there.
    async fn delete(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut guard = self.keys.write().await;
        match guard.remove(&id) {
            Some(_) => Ok(()),
            None => Err(ApiKeyError::NotFound),
        }
    }
}
/// Checks `plaintext_key` against `store` and returns its record if the key is usable.
///
/// Checks run in this order:
/// 1. Shape: the key must start with `vex_` and be at least 40 bytes long, else
///    `ApiKeyError::InvalidFormat`.
/// 2. Lookup by `ApiKeyRecord::hash_key`; no match gives `ApiKeyError::NotFound`.
/// 3. A revoked record gives `ApiKeyError::Revoked` (takes precedence over expiry).
/// 4. An `expires_at` strictly before now gives `ApiKeyError::Expired`.
///
/// On success the usage is recorded on a best-effort basis: a failing
/// `record_usage` does not reject the key. The returned record is the snapshot
/// read during lookup, so its `last_used_at` predates this usage.
///
/// # Errors
/// The variants above, plus any error propagated from `store.find_by_hash`.
pub async fn validate_api_key<S: ApiKeyStore>(
    store: &S,
    plaintext_key: &str,
) -> Result<ApiKeyRecord, ApiKeyError> {
    let well_formed = plaintext_key.starts_with("vex_") && plaintext_key.len() >= 40;
    if !well_formed {
        return Err(ApiKeyError::InvalidFormat);
    }

    let lookup = store
        .find_by_hash(&ApiKeyRecord::hash_key(plaintext_key))
        .await?;
    let record = match lookup {
        Some(found) => found,
        None => return Err(ApiKeyError::NotFound),
    };

    if record.revoked {
        return Err(ApiKeyError::Revoked);
    }
    let expired = match record.expires_at {
        Some(deadline) => Utc::now() > deadline,
        None => false,
    };
    if expired {
        return Err(ApiKeyError::Expired);
    }

    // Best effort: failing to bump `last_used_at` must not reject an otherwise valid key.
    let _ = store.record_usage(record.id).await;
    Ok(record)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_api_key_creation() {
        let (record, key) =
            ApiKeyRecord::new("user123", "My API Key", vec!["read".to_string()], None);
        assert!(key.starts_with("vex_"));
        assert!(key.len() > 40);
        assert_eq!(record.user_id, "user123");
        assert_eq!(record.name, "My API Key");
        assert!(record.is_valid());
        // Only the hash and a short prefix of the secret are kept in the record.
        assert_eq!(record.key_hash, ApiKeyRecord::hash_key(&key));
        assert_eq!(record.key_prefix.chars().count(), 12);
        assert!(key.starts_with(&record.key_prefix));
        assert!(record.expires_at.is_none());
        assert!(record.last_used_at.is_none());
        assert!(!record.revoked);
    }

    #[test]
    fn test_api_key_generation_is_unique() {
        let (first, first_key) = ApiKeyRecord::new("u", "a", vec![], None);
        let (second, second_key) = ApiKeyRecord::new("u", "b", vec![], None);
        assert_ne!(first.id, second.id);
        assert_ne!(first_key, second_key);
        assert_ne!(first.key_hash, second.key_hash);
    }

    #[test]
    fn test_api_key_hash_consistency() {
        let key = "vex_test123_abcdefgh";
        let hash1 = ApiKeyRecord::hash_key(key);
        let hash2 = ApiKeyRecord::hash_key(key);
        assert_eq!(hash1, hash2);
        // Known-answer SHA-256 test vector for "abc" (FIPS 180-2).
        assert_eq!(
            ApiKeyRecord::hash_key("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_expiry_revocation_and_scopes() {
        let (mut record, _) = ApiKeyRecord::new("u", "n", vec!["read".to_string()], Some(1));
        let expires = record.expires_at.expect("an expiry was requested");
        assert!(expires > record.created_at);
        assert!(record.is_valid());

        record.expires_at = Some(Utc::now() - chrono::Duration::seconds(1));
        assert!(!record.is_valid());

        record.expires_at = None;
        record.revoked = true;
        assert!(!record.is_valid());

        assert!(record.has_scope("read"));
        assert!(!record.has_scope("write"));
        let (admin, _) = ApiKeyRecord::new("u", "n", vec!["*".to_string()], None);
        assert!(admin.has_scope("anything"));
    }

    #[tokio::test]
    async fn test_memory_store_crud() {
        let store = MemoryApiKeyStore::new();
        let (record, key) = ApiKeyRecord::new("user1", "Test Key", vec![], None);
        store.create(&record).await.unwrap();
        let hash = ApiKeyRecord::hash_key(&key);
        let found = store.find_by_hash(&hash).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, record.id);

        let owned = store.find_by_user("user1").await.unwrap();
        assert_eq!(owned.len(), 1);
        assert!(store.find_by_user("someone-else").await.unwrap().is_empty());

        store.record_usage(record.id).await.unwrap();
        let used = store.find_by_hash(&hash).await.unwrap().unwrap();
        assert!(used.last_used_at.is_some());

        store.revoke(record.id).await.unwrap();
        let revoked = store.find_by_hash(&hash).await.unwrap().unwrap();
        assert!(revoked.revoked);
        assert!(!revoked.is_valid());

        store.delete(record.id).await.unwrap();
        assert!(store.find_by_hash(&hash).await.unwrap().is_none());
        assert!(matches!(store.delete(record.id).await, Err(ApiKeyError::NotFound)));
        assert!(matches!(store.revoke(record.id).await, Err(ApiKeyError::NotFound)));
        assert!(matches!(
            store.record_usage(record.id).await,
            Err(ApiKeyError::NotFound)
        ));
    }

    #[tokio::test]
    async fn test_validate_api_key() {
        let store = MemoryApiKeyStore::new();
        let (record, key) = ApiKeyRecord::new("user1", "Test Key", vec!["admin".to_string()], None);
        store.create(&record).await.unwrap();
        let validated = validate_api_key(&store, &key).await.unwrap();
        assert_eq!(validated.id, record.id);
        assert!(validated.has_scope("admin"));

        // A successful validation records the usage in the store.
        let stored = store
            .find_by_hash(&ApiKeyRecord::hash_key(&key))
            .await
            .unwrap()
            .unwrap();
        assert!(stored.last_used_at.is_some());

        let result = validate_api_key(&store, "invalid").await;
        assert!(matches!(result, Err(ApiKeyError::InvalidFormat)));
        let result =
            validate_api_key(&store, "vex_00000000000000000000000000000000_wrongkey").await;
        assert!(matches!(result, Err(ApiKeyError::NotFound)));
    }

    #[tokio::test]
    async fn test_validate_rejects_revoked_and_expired() {
        let store = MemoryApiKeyStore::new();

        let (revoked, revoked_key) = ApiKeyRecord::new("user1", "Revoked", vec![], None);
        store.create(&revoked).await.unwrap();
        store.revoke(revoked.id).await.unwrap();
        assert!(matches!(
            validate_api_key(&store, &revoked_key).await,
            Err(ApiKeyError::Revoked)
        ));

        let (mut expired, expired_key) = ApiKeyRecord::new("user1", "Expired", vec![], Some(1));
        expired.expires_at = Some(Utc::now() - chrono::Duration::seconds(1));
        store.create(&expired).await.unwrap();
        assert!(matches!(
            validate_api_key(&store, &expired_key).await,
            Err(ApiKeyError::Expired)
        ));
    }
}