use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use crate::error::{Error, Result, StorageError, ValidationError};
use crate::random::{generate_random_alphanumeric, generate_random_bytes};
/// Policy knobs that [`ApiKeyBuilder::build`] applies when generating a new key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
/// Length passed to the random-string generator for the secret part of each key.
pub key_length: usize,
/// Prefix used when the builder is not given one via `with_prefix`; the plaintext key is `<prefix>_<secret>`.
pub default_prefix: String,
/// When `true`, `build` rejects keys that have no expiration time.
pub require_expiration: bool,
/// Upper bound, in days from the time `build` runs, on a key's expiration; `None` means no cap.
pub max_expiration_days: Option<u32>,
/// When `false`, `build` rejects keys that were given no scopes.
pub allow_empty_scopes: bool,
}
impl Default for ApiKeyConfig {
    /// Permissive defaults: `sk` prefix, secret length 32, no mandatory
    /// expiration, no expiration cap, and scope-less keys allowed.
    fn default() -> Self {
        let default_prefix = String::from("sk");
        Self {
            default_prefix,
            key_length: 32,
            allow_empty_scopes: true,
            require_expiration: false,
            max_expiration_days: None,
        }
    }
}
impl ApiKeyConfig {
    /// Strict preset for live environments.
    ///
    /// Keys get the `sk_live` prefix, must carry an expiration no more than
    /// 365 days out, and must declare at least one scope.
    pub fn production() -> Self {
        Self {
            default_prefix: String::from("sk_live"),
            require_expiration: true,
            max_expiration_days: Some(365),
            allow_empty_scopes: false,
            key_length: 32,
        }
    }

    /// Relaxed preset for test environments.
    ///
    /// Keys get the `sk_test` prefix; expiration is optional and uncapped, and
    /// scope-less keys are allowed.
    pub fn test() -> Self {
        Self {
            default_prefix: String::from("sk_test"),
            require_expiration: false,
            max_expiration_days: None,
            allow_empty_scopes: true,
            key_length: 32,
        }
    }
}
/// Lifecycle state of an [`ApiKey`]. Only `Active` keys can pass `ApiKey::is_valid`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ApiKeyStatus {
/// Usable, subject to expiration; the status given to newly built keys.
#[default]
Active,
/// Withdrawn via `revoke` (also applied to the old key by `ApiKeyManager::rotate`); `enable` does not reactivate it.
Revoked,
/// NOTE(review): nothing in this module assigns this status; expiry is detected from `expires_at` instead.
Expired,
/// Switched off via `disable`; `ApiKey::enable` returns it to `Active`.
Disabled,
}
/// Stored record for one API key. Holds a SHA-256 hash of the plaintext key, not the plaintext itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
/// Identifier; keys built by `ApiKeyBuilder` use `key_` plus the output of `generate_random_alphanumeric(16)`.
pub id: String,
/// Prefix that precedes `_` and the secret in the plaintext key.
pub prefix: String,
/// Lowercase hex SHA-256 of the full plaintext key; used to find the record from a presented key.
pub key_hash: String,
/// `<prefix>_` followed by up to the first 8 characters of the secret, for display.
/// NOTE(review): this reveals part of the secret to anyone who can read the record.
pub key_hint: String,
/// Owner label supplied to `ApiKeyManager::create_key`.
pub owner: String,
/// Scopes granted to the key; `has_scope` matches them by exact string comparison.
pub scopes: Vec<String>,
/// Time at which the builder produced this record.
pub created_at: DateTime<Utc>,
/// Expiration instant; `None` means the key never expires. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<DateTime<Utc>>,
/// Time of the most recent `record_usage` call; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub last_used_at: Option<DateTime<Utc>>,
/// Number of `record_usage` calls; defaults to 0 when missing from serialized input.
#[serde(default)]
pub use_count: u64,
/// Lifecycle state; see [`ApiKeyStatus`].
pub status: ApiKeyStatus,
/// Free-form string annotations; defaults to empty when missing from serialized input.
#[serde(default)]
pub metadata: HashMap<String, String>,
/// Id of the key this one replaced via `ApiKeyManager::rotate`; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub rotated_from: Option<String>,
}
impl ApiKey {
    /// Returns `true` when the key is [`ApiKeyStatus::Active`] and not expired
    /// (see [`ApiKey::is_expired`]). A key without `expires_at` never expires.
    pub fn is_valid(&self) -> bool {
        self.status == ApiKeyStatus::Active && !self.is_expired()
    }

    /// Returns `true` when `expires_at` is set and is at or before the current
    /// time. The status is not consulted.
    pub fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(expires_at) if expires_at <= Utc::now())
    }

    /// Returns `true` if `scope` is one of the key's scopes (exact string match).
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope)
    }

    /// Returns `true` if the key has every scope in `scopes` (vacuously `true`
    /// for an empty slice).
    pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
        scopes.iter().all(|s| self.has_scope(s))
    }

    /// Returns `true` if the key has at least one scope in `scopes` (`false`
    /// for an empty slice).
    pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
        scopes.iter().any(|s| self.has_scope(s))
    }

    /// Time left until `expires_at`, or `None` for keys that never expire.
    ///
    /// Anything shorter than one whole second, including an expiration that
    /// has already passed, is reported as zero.
    pub fn remaining_lifetime(&self) -> Option<Duration> {
        self.expires_at.map(|exp| {
            let remaining = exp - Utc::now();
            if remaining.num_seconds() > 0 {
                remaining
            } else {
                Duration::zero()
            }
        })
    }

    /// Formats `key_hint`, then `...`, then the last four characters of
    /// `key_hint` (the whole hint if it is shorter).
    pub fn display_hint(&self) -> String {
        // Locate the start of the last four *characters*. Slicing at a raw byte
        // offset would panic if the hint contains multi-byte UTF-8, which a
        // custom prefix or a deserialized record can introduce.
        let tail_start = self
            .key_hint
            .char_indices()
            .rev()
            .nth(3)
            .map_or(0, |(idx, _)| idx);
        format!("{}...{}", self.key_hint, &self.key_hint[tail_start..])
    }

    /// Stamps `last_used_at` with the current time and increments `use_count`.
    pub fn record_usage(&mut self) {
        self.last_used_at = Some(Utc::now());
        // Saturate rather than overflow (a panic in debug builds) if a
        // deserialized record already carries an enormous count.
        self.use_count = self.use_count.saturating_add(1);
    }

    /// Sets the status to [`ApiKeyStatus::Revoked`]; [`ApiKey::enable`] does not undo this.
    pub fn revoke(&mut self) {
        self.status = ApiKeyStatus::Revoked;
    }

    /// Sets the status to [`ApiKeyStatus::Disabled`]; reversible with [`ApiKey::enable`].
    pub fn disable(&mut self) {
        self.status = ApiKeyStatus::Disabled;
    }

    /// Returns a disabled key to [`ApiKeyStatus::Active`]. Keys in any other
    /// status are left untouched.
    pub fn enable(&mut self) {
        if self.status == ApiKeyStatus::Disabled {
            self.status = ApiKeyStatus::Active;
        }
    }
}
/// Fluent builder for a new key, obtained from `ApiKeyManager::create_key`.
///
/// `build` checks the options against the config and generates the key; it does not
/// register the key with any manager (callers pass the result to `add_key`).
#[derive(Debug)]
pub struct ApiKeyBuilder<'a> {
// Despite its name, this is the manager's `ApiKeyConfig`, not the manager itself.
manager: &'a ApiKeyConfig,
// Owner label copied into `ApiKey::owner`.
owner: String,
// Overrides `default_prefix` when set.
prefix: Option<String>,
scopes: Vec<String>,
expires_at: Option<DateTime<Utc>>,
metadata: HashMap<String, String>,
}
impl<'a> ApiKeyBuilder<'a> {
    /// Starts a builder for `owner` whose options will be checked against
    /// `manager` (the owning manager's [`ApiKeyConfig`]).
    fn new(manager: &'a ApiKeyConfig, owner: String) -> Self {
        Self {
            manager,
            owner,
            prefix: None,
            scopes: Vec::new(),
            expires_at: None,
            metadata: HashMap::new(),
        }
    }

    /// Uses `prefix` instead of the config's `default_prefix`.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Appends one scope to those collected so far.
    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Replaces every scope collected so far with `scopes`.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Sets an absolute expiration time.
    pub fn with_expires_at(mut self, expires_at: DateTime<Utc>) -> Self {
        self.expires_at = Some(expires_at);
        self
    }

    /// Sets the expiration to `days` days after the moment this method is
    /// called (not the moment [`Self::build`] runs).
    pub fn with_expires_in_days(mut self, days: u32) -> Self {
        self.expires_at = Some(Utc::now() + Duration::days(i64::from(days)));
        self
    }

    /// Adds a metadata entry, overwriting any earlier value for `key`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Checks the options against the config and generates the key.
    ///
    /// Returns the [`ApiKey`] record together with the plaintext key
    /// `<prefix>_<secret>`. Only the record's hash of that plaintext is kept,
    /// so the plaintext must be handed to the caller now.
    ///
    /// # Errors
    ///
    /// Returns `Error::Validation` when the configured `key_length` is zero,
    /// when scopes are required but none were given, when an expiration is
    /// required but missing, or when the expiration lies beyond
    /// `max_expiration_days` from now. Failures from the random generators and
    /// id generation are propagated.
    pub fn build(self) -> Result<(ApiKey, String)> {
        let config = self.manager;

        // A zero-length secret makes every key the literal "<prefix>_": trivially
        // guessable, and every such key collides on the same hash.
        if config.key_length == 0 {
            return Err(Error::Validation(ValidationError::Custom(
                "Key length must be greater than zero".to_string(),
            )));
        }
        if !config.allow_empty_scopes && self.scopes.is_empty() {
            return Err(Error::Validation(ValidationError::Custom(
                "At least one scope is required".to_string(),
            )));
        }
        if config.require_expiration && self.expires_at.is_none() {
            return Err(Error::Validation(ValidationError::Custom(
                "Expiration time is required".to_string(),
            )));
        }

        // One timestamp serves both the expiration-cap check and `created_at`.
        let now = Utc::now();
        if let (Some(max_days), Some(expires_at)) = (config.max_expiration_days, self.expires_at) {
            // `checked_add_signed` yields None when the cap lies beyond chrono's
            // representable range; no expiration can exceed such a cap, and the
            // plain `+` operator would panic there instead.
            if let Some(max_expiration) = now.checked_add_signed(Duration::days(i64::from(max_days)))
            {
                if expires_at > max_expiration {
                    return Err(Error::Validation(ValidationError::Custom(format!(
                        "Expiration cannot exceed {} days",
                        max_days
                    ))));
                }
            }
        }

        let prefix = self
            .prefix
            .unwrap_or_else(|| config.default_prefix.clone());
        // NOTE(review): these bytes are generated and discarded; the only observable
        // effect is that a generator failure aborts `build`. Kept because removing it
        // would leave the `generate_random_bytes` import unused — remove both together.
        let _random_bytes = generate_random_bytes(config.key_length)?;
        let random_part = generate_random_alphanumeric(config.key_length)?;
        let plain_key = format!("{}_{}", prefix, random_part);
        let key_hash = hash_api_key(&plain_key);
        let id = generate_key_id()?;
        // Up to the first 8 characters of the secret, counted in chars so the
        // hint can never split a multi-byte UTF-8 sequence.
        let hint_chars: String = random_part.chars().take(8).collect();
        let key_hint = format!("{}_{}", prefix, hint_chars);

        let api_key = ApiKey {
            id,
            prefix,
            key_hash,
            key_hint,
            owner: self.owner,
            scopes: self.scopes,
            created_at: now,
            expires_at: self.expires_at,
            last_used_at: None,
            use_count: 0,
            status: ApiKeyStatus::Active,
            metadata: self.metadata,
            rotated_from: None,
        };
        Ok((api_key, plain_key))
    }
}
/// Registry of API keys held in `HashMap`s, with lookup of a presented plaintext key by its hash.
#[derive(Debug)]
pub struct ApiKeyManager {
// Policy handed to every builder created by `create_key`.
config: ApiKeyConfig,
// Primary storage, keyed by `ApiKey::id`.
keys_by_id: HashMap<String, ApiKey>,
// Index from `ApiKey::key_hash` to `ApiKey::id`, used by `validate` to find a key from its plaintext.
hash_to_id: HashMap<String, String>,
}
impl ApiKeyManager {
pub fn new(config: ApiKeyConfig) -> Self {
Self {
config,
keys_by_id: HashMap::new(),
hash_to_id: HashMap::new(),
}
}
pub fn with_default_config() -> Self {
Self::new(ApiKeyConfig::default())
}
pub fn config(&self) -> &ApiKeyConfig {
&self.config
}
pub fn create_key(&self, owner: impl Into<String>) -> ApiKeyBuilder<'_> {
ApiKeyBuilder::new(&self.config, owner.into())
}
pub fn add_key(&mut self, key: ApiKey) {
self.hash_to_id.insert(key.key_hash.clone(), key.id.clone());
self.keys_by_id.insert(key.id.clone(), key);
}
pub fn validate(&mut self, plain_key: &str) -> Option<&ApiKey> {
let hash = hash_api_key(plain_key);
let id = self.hash_to_id.get(&hash)?;
let key = self.keys_by_id.get_mut(id)?;
if !key.is_valid() {
return None;
}
key.record_usage();
Some(key)
}
pub fn validate_with_scopes(
&mut self,
plain_key: &str,
required_scopes: &[&str],
) -> Option<&ApiKey> {
let key = self.validate(plain_key)?;
if key.has_all_scopes(required_scopes) {
Some(key)
} else {
None
}
}
pub fn get_by_id(&self, id: &str) -> Option<&ApiKey> {
self.keys_by_id.get(id)
}
pub fn get_by_id_mut(&mut self, id: &str) -> Option<&mut ApiKey> {
self.keys_by_id.get_mut(id)
}
pub fn revoke(&mut self, id: &str) -> Result<()> {
let key = self
.keys_by_id
.get_mut(id)
.ok_or_else(|| Error::Storage(StorageError::NotFound(id.to_string())))?;
key.revoke();
Ok(())
}
pub fn delete(&mut self, id: &str) -> Result<ApiKey> {
let key = self
.keys_by_id
.remove(id)
.ok_or_else(|| Error::Storage(StorageError::NotFound(id.to_string())))?;
self.hash_to_id.remove(&key.key_hash);
Ok(key)
}
pub fn rotate(&mut self, id: &str) -> Result<(ApiKey, String)> {
let old_key = self
.keys_by_id
.get(id)
.ok_or_else(|| Error::Storage(StorageError::NotFound(id.to_string())))?;
if !old_key.is_valid() {
return Err(Error::Validation(ValidationError::Custom(
"Cannot rotate an invalid key".to_string(),
)));
}
let mut builder = self
.create_key(&old_key.owner)
.with_prefix(&old_key.prefix)
.with_scopes(old_key.scopes.clone());
if let Some(expires_at) = old_key.expires_at {
builder = builder.with_expires_at(expires_at);
}
for (k, v) in &old_key.metadata {
builder = builder.with_metadata(k, v);
}
let (mut new_key, plain_key) = builder.build()?;
new_key.rotated_from = Some(id.to_string());
self.revoke(id)?;
self.add_key(new_key.clone());
Ok((new_key, plain_key))
}
pub fn list(&self) -> Vec<&ApiKey> {
self.keys_by_id.values().collect()
}
pub fn list_by_owner(&self, owner: &str) -> Vec<&ApiKey> {
self.keys_by_id
.values()
.filter(|k| k.owner == owner)
.collect()
}
pub fn list_active(&self) -> Vec<&ApiKey> {
self.keys_by_id.values().filter(|k| k.is_valid()).collect()
}
pub fn list_expiring_soon(&self, days: i64) -> Vec<&ApiKey> {
let threshold = Utc::now() + Duration::days(days);
self.keys_by_id
.values()
.filter(|k| k.is_valid() && k.expires_at.map(|exp| exp <= threshold).unwrap_or(false))
.collect()
}
pub fn cleanup_expired(&mut self) -> Vec<ApiKey> {
let expired_ids: Vec<String> = self
.keys_by_id
.values()
.filter(|k| k.is_expired())
.map(|k| k.id.clone())
.collect();
expired_ids
.into_iter()
.filter_map(|id| self.delete(&id).ok())
.collect()
}
pub fn stats(&self) -> ApiKeyStats {
let total = self.keys_by_id.len();
let active = self.keys_by_id.values().filter(|k| k.is_valid()).count();
let expired = self.keys_by_id.values().filter(|k| k.is_expired()).count();
let revoked = self
.keys_by_id
.values()
.filter(|k| k.status == ApiKeyStatus::Revoked)
.count();
ApiKeyStats {
total,
active,
expired,
revoked,
}
}
}
/// Snapshot of key counts produced by `ApiKeyManager::stats`. The categories may overlap.
#[derive(Debug, Clone)]
pub struct ApiKeyStats {
/// Number of registered keys.
pub total: usize,
/// Keys that are `Active` and not past their expiration.
pub active: usize,
/// Keys whose `expires_at` has passed, regardless of status (can overlap `revoked`).
pub expired: usize,
/// Keys with status `Revoked` (can overlap `expired`).
pub revoked: usize,
}
/// Asynchronous persistence interface for [`ApiKey`] records. Implementors must be `Send + Sync`.
#[async_trait]
pub trait ApiKeyStore: Send + Sync {
/// Persists `key`. NOTE(review): replacement semantics are up to the implementation; the in-memory store replaces by `key.id`.
async fn save(&mut self, key: &ApiKey) -> Result<()>;
/// Fetches the record with `id`; `Ok(None)` signals that none was found.
async fn load(&self, id: &str) -> Result<Option<ApiKey>>;
/// Fetches the record whose `key_hash` equals `hash`; `Ok(None)` signals that none was found.
async fn load_by_hash(&self, hash: &str) -> Result<Option<ApiKey>>;
/// Removes the record with `id`. NOTE(review): the in-memory store returns `Ok(())` for unknown ids — confirm other backends match.
async fn delete(&mut self, id: &str) -> Result<()>;
/// Returns every stored record.
async fn list(&self) -> Result<Vec<ApiKey>>;
/// Returns the stored records whose `owner` equals `owner`.
async fn list_by_owner(&self, owner: &str) -> Result<Vec<ApiKey>>;
}
/// [`ApiKeyStore`] kept entirely in process memory in two `HashMap`s.
#[derive(Debug, Default)]
pub struct InMemoryApiKeyStore {
// Records keyed by `ApiKey::id`.
keys: HashMap<String, ApiKey>,
// Secondary index from `ApiKey::key_hash` to `ApiKey::id`.
hash_index: HashMap<String, String>,
}
impl InMemoryApiKeyStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl ApiKeyStore for InMemoryApiKeyStore {
    /// Inserts or replaces the record with `key.id`.
    ///
    /// If the replaced record carried a different hash, that hash is removed
    /// from the index. Otherwise `load_by_hash(old_hash)` would return the new
    /// record.
    async fn save(&mut self, key: &ApiKey) -> Result<()> {
        if let Some(previous) = self.keys.insert(key.id.clone(), key.clone()) {
            if previous.key_hash != key.key_hash
                && self.hash_index.get(&previous.key_hash) == Some(&key.id)
            {
                self.hash_index.remove(&previous.key_hash);
            }
        }
        self.hash_index.insert(key.key_hash.clone(), key.id.clone());
        Ok(())
    }

    /// Returns a clone of the record with `id`, if present.
    async fn load(&self, id: &str) -> Result<Option<ApiKey>> {
        Ok(self.keys.get(id).cloned())
    }

    /// Returns a clone of the record indexed under `hash`, if present.
    async fn load_by_hash(&self, hash: &str) -> Result<Option<ApiKey>> {
        Ok(self
            .hash_index
            .get(hash)
            .and_then(|id| self.keys.get(id))
            // Never return a record whose own hash disagrees with the index.
            .filter(|key| key.key_hash == hash)
            .cloned())
    }

    /// Removes the record with `id`; succeeds even if no such record exists.
    async fn delete(&mut self, id: &str) -> Result<()> {
        if let Some(key) = self.keys.remove(id) {
            // Only drop the index entry if it still points at the removed record.
            if self.hash_index.get(&key.key_hash).map(String::as_str) == Some(id) {
                self.hash_index.remove(&key.key_hash);
            }
        }
        Ok(())
    }

    /// Returns clones of every stored record, in unspecified order.
    async fn list(&self) -> Result<Vec<ApiKey>> {
        Ok(self.keys.values().cloned().collect())
    }

    /// Returns clones of the records whose `owner` equals `owner`.
    async fn list_by_owner(&self, owner: &str) -> Result<Vec<ApiKey>> {
        Ok(self
            .keys
            .values()
            .filter(|k| k.owner == owner)
            .cloned()
            .collect())
    }
}
/// Returns the lowercase hexadecimal SHA-256 digest of `key`'s UTF-8 bytes.
///
/// Records store only this digest; presented keys are hashed again and compared.
fn hash_api_key(key: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let digest = Sha256::digest(key.as_bytes());
    let mut encoded = String::with_capacity(digest.len() * 2);
    for &byte in digest.iter() {
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
/// Produces a new key identifier: `key_` followed by the output of
/// `generate_random_alphanumeric(16)`.
///
/// Any failure from the random generator is propagated.
fn generate_key_id() -> Result<String> {
    let suffix = generate_random_alphanumeric(16)?;
    let mut id = String::from("key_");
    id.push_str(&suffix);
    Ok(id)
}
/// Cheap structural check of a plaintext key that consults no store.
///
/// A key is well-formed when it contains at least one `_` and the segment after
/// the last `_` is at least 16 ASCII alphanumeric characters. The part before
/// that underscore may itself contain underscores (e.g. `sk_live`).
///
/// ASCII is required because the minimum is measured in bytes. Accepting
/// Unicode alphanumerics would let, for example, nine two-byte letters satisfy
/// a 16-character minimum. Key hints elsewhere are also sliced on the
/// assumption that the secret is ASCII.
pub fn validate_api_key_format(key: &str) -> bool {
    const MIN_RANDOM_LEN: usize = 16;
    match key.rsplit_once('_') {
        Some((_prefix, random_part)) => {
            random_part.len() >= MIN_RANDOM_LEN
                && random_part.bytes().all(|b| b.is_ascii_alphanumeric())
        }
        None => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key with no extra options; the default config accepts this.
    fn mint(manager: &ApiKeyManager, owner: &str) -> (ApiKey, String) {
        manager
            .create_key(owner)
            .build()
            .expect("default config should accept a bare key")
    }

    #[test]
    fn test_create_api_key() {
        let mut manager = ApiKeyManager::with_default_config();
        let (key, secret) = manager
            .create_key("test-service")
            .with_prefix("sk_test")
            .with_scopes(vec!["read".to_string(), "write".to_string()])
            .build()
            .unwrap();

        assert!(secret.starts_with("sk_test_"));
        assert!(key.is_valid());
        for granted in &["read", "write"] {
            assert!(key.has_scope(granted));
        }
        assert!(!key.has_scope("admin"));

        manager.add_key(key);
        assert!(manager.validate(&secret).is_some());
    }

    #[test]
    fn test_key_expiration() {
        let manager = ApiKeyManager::with_default_config();
        let (mut key, _) = manager
            .create_key("test")
            .with_expires_in_days(0)
            .build()
            .unwrap();
        // Push the expiry firmly into the past so the check cannot race the clock.
        key.expires_at = Some(Utc::now() - Duration::hours(1));

        assert!(key.is_expired());
        assert!(!key.is_valid());
    }

    #[test]
    fn test_key_revocation() {
        let mut manager = ApiKeyManager::with_default_config();
        let (key, secret) = mint(&manager, "test");
        let key_id = key.id.clone();
        manager.add_key(key);

        assert!(manager.validate(&secret).is_some());
        manager.revoke(&key_id).unwrap();
        assert!(manager.validate(&secret).is_none());
    }

    #[test]
    fn test_key_rotation() {
        let mut manager = ApiKeyManager::with_default_config();
        let (original, original_secret) = manager
            .create_key("test")
            .with_scope("read")
            .build()
            .unwrap();
        let original_id = original.id.clone();
        manager.add_key(original);

        let (rotated, rotated_secret) = manager.rotate(&original_id).unwrap();

        assert_ne!(original_secret, rotated_secret);
        assert!(rotated.has_scope("read"));
        assert_eq!(rotated.rotated_from.as_deref(), Some(original_id.as_str()));
        assert!(manager.validate(&original_secret).is_none());
        assert!(manager.validate(&rotated_secret).is_some());
    }

    #[test]
    fn test_scope_validation() {
        let mut manager = ApiKeyManager::with_default_config();
        let (key, secret) = manager
            .create_key("test")
            .with_scope("read")
            .with_scope("write")
            .build()
            .unwrap();
        manager.add_key(key);

        let cases: [(&[&str], bool); 4] = [
            (&["read"][..], true),
            (&["read", "write"][..], true),
            (&["admin"][..], false),
            (&["read", "admin"][..], false),
        ];
        for (required, expected) in cases.iter() {
            let accepted = manager.validate_with_scopes(&secret, required).is_some();
            assert_eq!(accepted, *expected, "required scopes {:?}", required);
        }
    }

    #[test]
    fn test_production_config() {
        let manager = ApiKeyManager::new(ApiKeyConfig::production());

        let bare = manager.create_key("test").build();
        assert!(bare.is_err());

        let compliant = manager
            .create_key("test")
            .with_scope("read")
            .with_expires_in_days(30)
            .build();
        assert!(compliant.is_ok());
    }

    #[test]
    fn test_usage_tracking() {
        let mut manager = ApiKeyManager::with_default_config();
        let (key, secret) = mint(&manager, "test");
        let key_id = key.id.clone();
        manager.add_key(key);

        for _ in 0..3 {
            manager.validate(&secret);
        }

        let tracked = manager.get_by_id(&key_id).unwrap();
        assert_eq!(tracked.use_count, 3);
        assert!(tracked.last_used_at.is_some());
    }

    #[test]
    fn test_list_expiring_soon() {
        let mut manager = ApiKeyManager::with_default_config();
        for (owner, days) in [("test1", 5), ("test2", 60)].iter() {
            let (key, _) = manager
                .create_key(*owner)
                .with_expires_in_days(*days)
                .build()
                .unwrap();
            manager.add_key(key);
        }

        assert_eq!(manager.list_expiring_soon(7).len(), 1);
    }

    #[test]
    fn test_stats() {
        let mut manager = ApiKeyManager::with_default_config();
        let (kept, _) = mint(&manager, "test1");
        manager.add_key(kept);
        let (doomed, _) = mint(&manager, "test2");
        let doomed_id = doomed.id.clone();
        manager.add_key(doomed);
        manager.revoke(&doomed_id).unwrap();

        let stats = manager.stats();
        assert_eq!((stats.total, stats.active, stats.revoked), (2, 1, 1));
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let mut store = InMemoryApiKeyStore::new();
        let manager = ApiKeyManager::with_default_config();
        let (key, _) = mint(&manager, "test");
        let (key_id, key_hash) = (key.id.clone(), key.key_hash.clone());

        store.save(&key).await.unwrap();
        assert!(store.load(&key_id).await.unwrap().is_some());
        assert!(store.load_by_hash(&key_hash).await.unwrap().is_some());
        assert_eq!(store.list().await.unwrap().len(), 1);

        store.delete(&key_id).await.unwrap();
        assert!(store.load(&key_id).await.unwrap().is_none());
    }

    #[test]
    fn test_validate_api_key_format() {
        for well_formed in &["sk_test_abcdefghijklmnop", "sk_live_1234567890abcdef"] {
            assert!(validate_api_key_format(well_formed));
        }
        for malformed in &["invalid", "sk_short"] {
            assert!(!validate_api_key_format(malformed));
        }
    }

    #[test]
    fn test_key_metadata() {
        let manager = ApiKeyManager::with_default_config();
        let (key, _) = manager
            .create_key("test")
            .with_metadata("env", "production")
            .with_metadata("team", "backend")
            .build()
            .unwrap();

        for (name, value) in &[("env", "production"), ("team", "backend")] {
            assert_eq!(key.metadata.get(*name).map(String::as_str), Some(*value));
        }
    }
}