use argon2::password_hash::{rand_core::OsRng, SaltString};
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use thiserror::Error;
use uuid::Uuid;
#[derive(Debug, Error)]
pub enum ApiKeyError {
#[error("Storage error: {0}")]
Storage(String),
#[error("Key not found")]
NotFound,
#[error("Key expired")]
Expired,
#[error("Key revoked")]
Revoked,
#[error("Invalid key format")]
InvalidFormat,
}
/// A stored API key.
///
/// The plaintext key is not kept here: only its hash (`key_hash`) and a short
/// leading fragment (`key_prefix`) are stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyRecord {
/// Unique record id; its hyphen-less form is embedded in the plaintext key by `ApiKeyRecord::new`.
pub id: Uuid,
/// Argon2 PHC string from `ApiKeyRecord::hash_key`, or a legacy hex SHA-256 digest
/// (accepted by `ApiKeyRecord::verify_key`).
pub key_hash: String,
/// First 12 characters of the plaintext key, as set by `ApiKeyRecord::new`.
pub key_prefix: String,
/// Id of the user that owns the key.
pub user_id: String,
/// Human-readable label for the key.
pub name: String,
/// Scopes granted to the key; the entry `"*"` grants every scope (see `has_scope`).
pub scopes: Vec<String>,
/// When the record was created.
pub created_at: DateTime<Utc>,
/// Expiry instant; `None` means the key never expires (see `is_valid`).
pub expires_at: Option<DateTime<Utc>>,
/// Last time the key was used; `None` until a store records a use.
pub last_used_at: Option<DateTime<Utc>>,
/// Set once the key has been revoked; revoked keys fail validation.
pub revoked: bool,
}
impl ApiKeyRecord {
    /// Generates a new API key for `user_id` and returns the storable record together with the
    /// plaintext key.
    ///
    /// The plaintext has the shape `vex_<32 hex chars of the record id>_<32 alphanumeric chars>`.
    /// The record keeps only its Argon2 hash and its first 12 characters, so the plaintext
    /// returned here is the only copy and must be handed to the caller now.
    ///
    /// `expires_in_days = None` creates a key with no expiry. If the requested expiry would fall
    /// beyond the range `DateTime<Utc>` can represent, the key is likewise created without an
    /// expiry rather than panicking on the date overflow.
    pub fn new(
        user_id: &str,
        name: &str,
        scopes: Vec<String>,
        expires_in_days: Option<u32>,
    ) -> (Self, String) {
        use rand::distributions::{Alphanumeric, DistString};

        let id = Uuid::new_v4();
        let random_part = Alphanumeric.sample_string(&mut rand::thread_rng(), 32);
        // The hyphen-less record id is embedded so a store can find the record by id before
        // verifying the hash (see `MemoryApiKeyStore::find_and_verify_key`).
        let plaintext_key = format!("vex_{}_{}", id.to_string().replace('-', ""), random_part);
        let key_hash = Self::hash_key(&plaintext_key);
        let key_prefix = plaintext_key.chars().take(12).collect();

        // One timestamp for both fields so `expires_at - created_at` is exactly the requested span.
        let now = Utc::now();
        // `DateTime + Duration` panics on overflow; an expiry past the representable date range
        // is treated as "never expires" instead.
        let expires_at = expires_in_days
            .and_then(|days| now.checked_add_signed(chrono::Duration::days(i64::from(days))));

        let record = Self {
            id,
            key_hash,
            key_prefix,
            user_id: user_id.to_string(),
            name: name.to_string(),
            scopes,
            created_at: now,
            expires_at,
            last_used_at: None,
            revoked: false,
        };
        (record, plaintext_key)
    }

    /// Hashes `key` with Argon2 (crate-default parameters) under a freshly generated random salt
    /// and returns the PHC-format string.
    ///
    /// Because the salt differs on every call, hashing the same key twice yields different
    /// strings; compare with [`ApiKeyRecord::verify_key`], never with `==`.
    ///
    /// # Panics
    /// Panics if Argon2 rejects the input; with default parameters and a generated salt this is
    /// not expected to happen.
    pub fn hash_key(key: &str) -> String {
        let salt = SaltString::generate(&mut OsRng);
        Argon2::default()
            .hash_password(key.as_bytes(), &salt)
            .expect("Argon2 hashing should not fail")
            .to_string()
    }

    /// Returns `true` if `plaintext_key` matches `stored_hash`.
    ///
    /// A `stored_hash` that parses as a PHC string (as produced by [`ApiKeyRecord::hash_key`])
    /// is checked with Argon2. Anything else is treated as a legacy lowercase-hex SHA-256 digest
    /// of the key and compared in constant time.
    pub fn verify_key(plaintext_key: &str, stored_hash: &str) -> bool {
        match PasswordHash::new(stored_hash) {
            Ok(parsed_hash) => Argon2::default()
                .verify_password(plaintext_key.as_bytes(), &parsed_hash)
                .is_ok(),
            Err(_) => {
                let legacy_hash = {
                    use sha2::{Digest, Sha256};
                    let mut hasher = Sha256::new();
                    hasher.update(plaintext_key.as_bytes());
                    hex::encode(hasher.finalize())
                };
                legacy_hash.as_bytes().ct_eq(stored_hash.as_bytes()).into()
            }
        }
    }

    /// Returns `true` if the key is neither revoked nor past its expiry.
    ///
    /// A key whose `expires_at` equals the current instant still counts as valid.
    pub fn is_valid(&self) -> bool {
        !self.revoked && self.expires_at.map_or(true, |expires| Utc::now() <= expires)
    }

    /// Returns `true` if the key grants `scope`, either explicitly or via the `"*"` wildcard.
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope || s == "*")
    }
}
/// Persistence backend for [`ApiKeyRecord`]s.
///
/// Methods are async via `async_trait`, and implementors must be `Send + Sync`.
#[async_trait]
pub trait ApiKeyStore: Send + Sync {
/// Persists a newly created record.
async fn create(&self, record: &ApiKeyRecord) -> Result<(), ApiKeyError>;
/// Returns the record whose stored `key_hash` equals `hash`, if any.
///
/// NOTE(review): `ApiKeyRecord::hash_key` salts every hash, so re-hashing a plaintext key
/// will not reproduce a stored hash; plaintext lookups should use `find_and_verify_key`.
async fn find_by_hash(&self, hash: &str) -> Result<Option<ApiKeyRecord>, ApiKeyError>;
/// Resolves a plaintext key to its record, verifying the secret against the stored hash.
///
/// `validate_api_key` maps `Ok(None)` to `ApiKeyError::NotFound` and performs the
/// revocation/expiry checks itself, so implementations presumably should return revoked or
/// expired records unchanged (as `MemoryApiKeyStore` does).
async fn find_and_verify_key(
&self,
plaintext_key: &str,
) -> Result<Option<ApiKeyRecord>, ApiKeyError>;
/// Returns every record owned by `user_id`.
async fn find_by_user(&self, user_id: &str) -> Result<Vec<ApiKeyRecord>, ApiKeyError>;
/// Records that key `id` was just used (`MemoryApiKeyStore` sets `last_used_at`).
async fn record_usage(&self, id: Uuid) -> Result<(), ApiKeyError>;
/// Marks key `id` as revoked; revoked keys are rejected by `validate_api_key`.
async fn revoke(&self, id: Uuid) -> Result<(), ApiKeyError>;
/// Removes key `id` from the store.
async fn delete(&self, id: Uuid) -> Result<(), ApiKeyError>;
/// Revokes `old_key_id` and issues a replacement for the same user and scopes, returning the
/// new record and its plaintext key.
///
/// `expires_in_days` sets the replacement's lifetime; `MemoryApiKeyStore` applies a 90-day
/// default when it is `None`.
async fn rotate(
&self,
old_key_id: Uuid,
expires_in_days: Option<u32>,
) -> Result<(ApiKeyRecord, String), ApiKeyError>;
}
/// In-process [`ApiKeyStore`] that keeps every record in a `HashMap` keyed by record id,
/// guarded by an async `RwLock`. Contents live only as long as the store value itself.
#[derive(Default, Debug)]
pub struct MemoryApiKeyStore {
    /// All records, keyed by [`ApiKeyRecord::id`].
    keys: tokio::sync::RwLock<std::collections::HashMap<Uuid, ApiKeyRecord>>,
}
impl MemoryApiKeyStore {
    /// Creates an empty store; equivalent to `MemoryApiKeyStore::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
#[async_trait]
impl ApiKeyStore for MemoryApiKeyStore {
    /// Inserts `record`, replacing any existing record with the same id.
    async fn create(&self, record: &ApiKeyRecord) -> Result<(), ApiKeyError> {
        let mut keys = self.keys.write().await;
        keys.insert(record.id, record.clone());
        Ok(())
    }

    /// Linear scan for a record whose stored hash string equals `hash` exactly.
    ///
    /// [`ApiKeyRecord::hash_key`] uses a fresh salt per call, so re-hashing a plaintext will not
    /// reproduce the stored value; use `find_and_verify_key` for plaintext lookups.
    async fn find_by_hash(&self, hash: &str) -> Result<Option<ApiKeyRecord>, ApiKeyError> {
        let keys = self.keys.read().await;
        Ok(keys.values().find(|r| r.key_hash == hash).cloned())
    }

    /// Resolves a plaintext key of the form `vex_<32 hex id>_<secret>` to its record.
    ///
    /// Returns `Ok(None)` when no record has that id or the key does not verify against the
    /// stored hash. Revoked and expired records are returned unchanged; `validate_api_key`
    /// decides whether they are usable.
    ///
    /// # Errors
    /// `ApiKeyError::InvalidFormat` if the key has fewer than three `_`-separated parts or the
    /// id segment is not exactly 32 ASCII hex digits.
    async fn find_and_verify_key(
        &self,
        plaintext_key: &str,
    ) -> Result<Option<ApiKeyRecord>, ApiKeyError> {
        let parts: Vec<&str> = plaintext_key.split('_').collect();
        if parts.len() < 3 {
            return Err(ApiKeyError::InvalidFormat);
        }
        let uuid_str = parts[1];
        // The key is untrusted input: check for ASCII hex digits rather than just the byte
        // length, so multi-byte characters are rejected instead of reaching UUID parsing.
        if uuid_str.len() != 32 || !uuid_str.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(ApiKeyError::InvalidFormat);
        }
        // `parse_str` accepts the hyphen-less ("simple") form directly.
        let id = Uuid::parse_str(uuid_str).map_err(|_| ApiKeyError::InvalidFormat)?;

        // Copy the candidate out and release the lock before the deliberately slow Argon2
        // verification, so writers are not blocked while it runs.
        // NOTE(review): verification is CPU-bound and still runs on the async executor thread.
        let candidate = {
            let keys = self.keys.read().await;
            keys.get(&id).cloned()
        };
        Ok(candidate.filter(|record| ApiKeyRecord::verify_key(plaintext_key, &record.key_hash)))
    }

    /// Returns every record owned by `user_id`, in unspecified order.
    async fn find_by_user(&self, user_id: &str) -> Result<Vec<ApiKeyRecord>, ApiKeyError> {
        let keys = self.keys.read().await;
        Ok(keys
            .values()
            .filter(|r| r.user_id == user_id)
            .cloned()
            .collect())
    }

    /// Sets `last_used_at` of record `id` to now; `NotFound` if the id is unknown.
    async fn record_usage(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut keys = self.keys.write().await;
        let record = keys.get_mut(&id).ok_or(ApiKeyError::NotFound)?;
        record.last_used_at = Some(Utc::now());
        Ok(())
    }

    /// Marks record `id` as revoked (idempotent); `NotFound` if the id is unknown.
    async fn revoke(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut keys = self.keys.write().await;
        let record = keys.get_mut(&id).ok_or(ApiKeyError::NotFound)?;
        record.revoked = true;
        Ok(())
    }

    /// Removes record `id`; `NotFound` if the id is unknown.
    async fn delete(&self, id: Uuid) -> Result<(), ApiKeyError> {
        let mut keys = self.keys.write().await;
        keys.remove(&id).map(|_| ()).ok_or(ApiKeyError::NotFound)
    }

    /// Revokes `old_key_id` and issues a replacement with the same user and scopes, named
    /// `"<old name> (rotated)"` and expiring after `expires_in_days` (default 90 days).
    ///
    /// # Errors
    /// - `NotFound` if `old_key_id` does not exist.
    /// - `Revoked` if the old key is already revoked. Without this check, a revoked key could
    ///   be rotated into a fresh valid one, and two concurrent rotations of one key would both
    ///   succeed.
    async fn rotate(
        &self,
        old_key_id: Uuid,
        expires_in_days: Option<u32>,
    ) -> Result<(ApiKeyRecord, String), ApiKeyError> {
        const DEFAULT_ROTATED_TTL_DAYS: u32 = 90;

        let old_key = {
            let keys = self.keys.read().await;
            keys.get(&old_key_id)
                .cloned()
                .ok_or(ApiKeyError::NotFound)?
        };
        if old_key.revoked {
            return Err(ApiKeyError::Revoked);
        }

        // Generate (and Argon2-hash) the replacement before taking the write lock.
        let ttl = expires_in_days.unwrap_or(DEFAULT_ROTATED_TTL_DAYS);
        let (new_record, plaintext) = ApiKeyRecord::new(
            &old_key.user_id,
            &format!("{} (rotated)", old_key.name),
            old_key.scopes,
            Some(ttl),
        );

        {
            let mut keys = self.keys.write().await;
            // Re-check under the write lock: the old key may have been deleted or revoked
            // (e.g. by a concurrent rotation) since the snapshot above. Revoking and inserting
            // under one guard makes the swap atomic.
            let old = keys.get_mut(&old_key_id).ok_or(ApiKeyError::NotFound)?;
            if old.revoked {
                return Err(ApiKeyError::Revoked);
            }
            old.revoked = true;
            keys.insert(new_record.id, new_record.clone());
        }

        tracing::info!(
            old_key_id = %old_key_id,
            new_key_id = %new_record.id,
            user_id = %old_key.user_id,
            "API key rotated successfully"
        );
        Ok((new_record, plaintext))
    }
}
pub async fn validate_api_key<S: ApiKeyStore>(
store: &S,
plaintext_key: &str,
) -> Result<ApiKeyRecord, ApiKeyError> {
if !plaintext_key.starts_with("vex_") || plaintext_key.len() < 40 {
return Err(ApiKeyError::InvalidFormat);
}
let record = store
.find_and_verify_key(plaintext_key)
.await?
.ok_or(ApiKeyError::NotFound)?;
if record.revoked {
return Err(ApiKeyError::Revoked);
}
if let Some(expires) = record.expires_at {
if Utc::now() > expires {
return Err(ApiKeyError::Expired);
}
}
let _ = store.record_usage(record.id).await;
Ok(record)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Key generation and hashing are synchronous, so these tests need no async runtime.
    #[test]
    fn test_api_key_creation() {
        let (record, key) =
            ApiKeyRecord::new("user123", "My API Key", vec!["read".to_string()], None);
        assert!(key.starts_with("vex_"));
        assert!(key.len() > 40);
        assert_eq!(record.user_id, "user123");
        assert_eq!(record.name, "My API Key");
        assert!(record.is_valid());
    }

    #[test]
    fn test_api_key_hash_verification() {
        let key = "vex_test123456789_abcdefghijklmnopqrst";
        let hash = ApiKeyRecord::hash_key(key);
        assert!(ApiKeyRecord::verify_key(key, &hash));
        assert!(!ApiKeyRecord::verify_key(
            "vex_wrong_key_12345678901234567890",
            &hash
        ));
    }

    #[test]
    fn test_legacy_sha256_hash_verification() {
        // SHA-256("abc") in lowercase hex; not a PHC string, so the legacy path is taken.
        let legacy = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";
        assert!(ApiKeyRecord::verify_key("abc", legacy));
        assert!(!ApiKeyRecord::verify_key("abd", legacy));
    }

    #[tokio::test]
    async fn test_memory_store_crud() {
        let store = MemoryApiKeyStore::new();
        let (record, key) = ApiKeyRecord::new("user1", "Test Key", vec![], None);
        store.create(&record).await.unwrap();
        let found = store.find_and_verify_key(&key).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, record.id);
        store.revoke(record.id).await.unwrap();
        let revoked = store.find_and_verify_key(&key).await.unwrap().unwrap();
        assert!(revoked.revoked);
    }

    #[tokio::test]
    async fn test_find_and_verify_rejects_wrong_secret() {
        let store = MemoryApiKeyStore::new();
        let (record, key) = ApiKeyRecord::new("user1", "Test Key", vec![], None);
        store.create(&record).await.unwrap();
        // Keep the real id segment but swap the secret: the id lookup succeeds,
        // hash verification must not.
        let (id_segment, _secret) = key.rsplit_once('_').unwrap();
        let forged = format!("{}_{}", id_segment, "x".repeat(32));
        assert!(store.find_and_verify_key(&forged).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_validate_api_key() {
        let store = MemoryApiKeyStore::new();
        let (record, key) =
            ApiKeyRecord::new("user1", "Test Key", vec!["admin".to_string()], None);
        store.create(&record).await.unwrap();
        let validated = validate_api_key(&store, &key).await.unwrap();
        assert_eq!(validated.id, record.id);
        assert!(validated.has_scope("admin"));
        let result = validate_api_key(&store, "invalid").await;
        assert!(matches!(result, Err(ApiKeyError::InvalidFormat)));
        let result =
            validate_api_key(&store, "vex_00000000000000000000000000000000_wrongkey").await;
        assert!(matches!(result, Err(ApiKeyError::NotFound)));
    }

    #[tokio::test]
    async fn test_rotate_revokes_old_key_and_issues_new_one() {
        let store = MemoryApiKeyStore::new();
        let (record, old_key) =
            ApiKeyRecord::new("user1", "Deploy", vec!["read".to_string()], None);
        store.create(&record).await.unwrap();

        let (rotated, new_key) = store.rotate(record.id, None).await.unwrap();
        assert_ne!(rotated.id, record.id);
        assert_eq!(rotated.user_id, "user1");
        assert_eq!(rotated.name, "Deploy (rotated)");
        assert_eq!(rotated.scopes, vec!["read".to_string()]);
        assert!(rotated.expires_at.is_some());

        let old_result = validate_api_key(&store, &old_key).await;
        assert!(matches!(old_result, Err(ApiKeyError::Revoked)));
        let validated = validate_api_key(&store, &new_key).await.unwrap();
        assert_eq!(validated.id, rotated.id);
    }
}