use super::types::{ApiKeyVerification, CreateApiKeyRequest};
use crate::core::models::user::types::User;
use crate::core::models::{ApiKey, Metadata, UsageStats};
use crate::storage::StorageLayer;
use crate::utils::auth::crypto::keys::{extract_api_key_prefix, generate_api_key, hash_api_key};
use crate::utils::error::gateway_error::{GatewayError, Result};
use chrono::Utc;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Permission strings that may be requested when creating an API key.
///
/// `validate_create_key_input` rejects any requested permission that does not
/// appear in this list. The `"*"` entry is the wildcard permission.
// NOTE(review): what each permission actually grants is decided wherever
// permissions are enforced, which is not visible here — confirm there.
const VALID_PERMISSIONS: &[&str] = &[
"*",
"users.read",
"users.write",
"users.delete",
"teams.read",
"teams.write",
"teams.delete",
"api.chat",
"api.embeddings",
"api.images",
"api_keys.read",
"api_keys.write",
"api_keys.delete",
"analytics.read",
"system.admin",
];
/// Validates the caller-supplied name and permission list for a new API key.
///
/// An empty `permissions` slice is accepted.
///
/// # Errors
///
/// Returns [`GatewayError::Validation`] when:
/// - `name` is empty,
/// - `name` is longer than 255 characters,
/// - `name` contains a control character (NUL, newline, ...), or
/// - any entry of `permissions` is not listed in [`VALID_PERMISSIONS`]
///   (the first offending entry is reported).
fn validate_create_key_input(name: &str, permissions: &[String]) -> Result<()> {
    /// Maximum number of characters allowed in an API key name.
    const MAX_NAME_CHARS: usize = 255;

    if name.is_empty() {
        return Err(GatewayError::Validation(
            "API key name must not be empty".to_string(),
        ));
    }
    // Count characters, not bytes: `str::len` is a byte count, which would
    // reject short non-ASCII names with a message that talks about
    // "characters". `nth` stops after MAX_NAME_CHARS + 1 characters, so an
    // oversized name is not walked to the end.
    if name.chars().nth(MAX_NAME_CHARS).is_some() {
        return Err(GatewayError::Validation(
            "API key name must not exceed 255 characters".to_string(),
        ));
    }
    if name.chars().any(char::is_control) {
        return Err(GatewayError::Validation(
            "API key name must not contain control characters".to_string(),
        ));
    }
    let unknown = permissions
        .iter()
        .find(|perm| !VALID_PERMISSIONS.contains(&perm.as_str()));
    if let Some(perm) = unknown {
        return Err(GatewayError::Validation(format!(
            "Unknown permission: '{}'. Valid permissions: {}",
            perm,
            // List every accepted value (including "*") so the hint matches
            // what the check above actually accepts.
            VALID_PERMISSIONS.join(", "),
        )));
    }
    Ok(())
}
/// Minimum interval (5 minutes) between two scheduled "last used" database
/// updates for the same key; `update_last_used` skips the write while the
/// previously recorded one is more recent than this.
const LAST_USED_THROTTLE: Duration = Duration::from_secs(5 * 60);
/// TTL handed to `cache_set` when an API key record is cached.
// NOTE(review): presumably seconds (300 s = 5 min) — confirm against
// `StorageLayer::cache_set`.
const API_KEY_CACHE_TTL: u64 = 300;
/// Builds the cache key under which the API key record for `key_hash` is
/// stored: the fixed `api_key:hash:` namespace followed by the hash itself.
fn api_key_cache_key(key_hash: &str) -> String {
    const NAMESPACE: &str = "api_key:hash:";

    let mut cache_key = String::with_capacity(NAMESPACE.len() + key_hash.len());
    cache_key.push_str(NAMESPACE);
    cache_key.push_str(key_hash);
    cache_key
}
/// Creates and verifies API keys on top of the shared [`StorageLayer`].
///
/// Clones share the storage handle and the last-used throttle map, because
/// both are held behind `Arc`.
#[derive(Clone)]
pub struct ApiKeyHandler {
    /// Storage backend used for database and cache access.
    pub(super) storage: Arc<StorageLayer>,
    /// Per-key instant at which a "last used" database update was last
    /// scheduled; consulted to throttle those writes.
    last_used_cache: Arc<DashMap<Uuid, Instant>>,
    /// Optional secret forwarded to `hash_api_key` when hashing raw keys.
    hmac_secret: Option<String>,
}

// `Debug` is implemented by hand rather than derived so that the HMAC secret
// can never end up in logs or panic output; only its presence is shown.
impl std::fmt::Debug for ApiKeyHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyHandler")
            .field("storage", &self.storage)
            .field("last_used_cache", &self.last_used_cache)
            .field(
                "hmac_secret",
                &self.hmac_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
impl ApiKeyHandler {
/// Builds a handler backed by `storage`.
///
/// `hmac_secret`, when present, is forwarded to `hash_api_key` for every key
/// this handler hashes. The last-used throttle map starts out empty.
///
/// The constructor awaits nothing and, as written, always returns `Ok`.
pub async fn new(storage: Arc<StorageLayer>, hmac_secret: Option<String>) -> Result<Self> {
    let last_used_cache = Arc::new(DashMap::new());
    let handler = Self {
        storage,
        last_used_cache,
        hmac_secret,
    };
    Ok(handler)
}
/// Borrows the configured HMAC secret as a string slice, or `None` when the
/// handler was built without one.
fn hmac_secret(&self) -> Option<&str> {
    Option::as_deref(&self.hmac_secret)
}
/// Creates and persists a new API key with no rate limits and no expiry.
///
/// The name and permissions are validated first. A fresh raw key is then
/// generated, and only its hash and prefix are placed in the stored record.
///
/// # Returns
///
/// The record returned by the database together with the raw key. The raw
/// key is not part of the stored record, so this return value is the only
/// place the caller can obtain it.
///
/// # Errors
///
/// Returns a validation error for an invalid name or unknown permission, and
/// propagates any error from the database insert.
pub async fn create_key(
    &self,
    user_id: Option<Uuid>,
    team_id: Option<Uuid>,
    name: String,
    permissions: Vec<String>,
) -> Result<(ApiKey, String)> {
    validate_create_key_input(&name, &permissions)?;
    info!("Creating API key: {}", name);

    let secret = generate_api_key();
    let record = ApiKey {
        // Hash and prefix are derived first so the evaluation order matches
        // generate -> hash -> prefix -> metadata.
        key_hash: hash_api_key(&secret, self.hmac_secret()),
        key_prefix: extract_api_key_prefix(&secret),
        metadata: Metadata::new(),
        name,
        user_id,
        team_id,
        permissions,
        rate_limits: None,
        expires_at: None,
        is_active: true,
        last_used_at: None,
        usage_stats: UsageStats::default(),
    };

    let persisted = self.storage.db().create_api_key(&record).await?;
    info!("API key created successfully: {}", persisted.metadata.id);
    Ok((persisted, secret))
}
pub async fn create_key_with_options(
&self,
request: CreateApiKeyRequest,
) -> Result<(ApiKey, String)> {
validate_create_key_input(&request.name, &request.permissions)?;
info!("Creating API key with options: {}", request.name);
let raw_key = generate_api_key();
let key_hash = hash_api_key(&raw_key, self.hmac_secret());
let key_prefix = extract_api_key_prefix(&raw_key);
let api_key = ApiKey {
metadata: Metadata::new(),
name: request.name,
key_hash,
key_prefix,
user_id: request.user_id,
team_id: request.team_id,
permissions: request.permissions,
rate_limits: request.rate_limits,
expires_at: request.expires_at,
is_active: true,
last_used_at: None,
usage_stats: UsageStats::default(),
};
let stored_key = self.storage.db().create_api_key(&api_key).await?;
info!("API key created successfully: {}", stored_key.metadata.id);
Ok((stored_key, raw_key))
}
/// Looks up an API key record by hash, consulting the cache before the
/// database.
///
/// Cache problems never fail the lookup:
/// - a cache read error is logged and the database is queried instead;
/// - an entry that cannot be deserialized is logged, deleted (best effort)
///   and the database is queried instead;
/// - a failure to write the record back to the cache is only logged.
///
/// A record found in the database is written to the cache with
/// `API_KEY_CACHE_TTL`. A missing key is not cached.
///
/// # Errors
///
/// Propagates errors from the database lookup only.
async fn find_api_key_cached(&self, key_hash: &str) -> Result<Option<ApiKey>> {
    let cache_key = api_key_cache_key(key_hash);

    match self.storage.cache_get(&cache_key).await {
        Err(e) => {
            warn!("Redis cache_get failed, falling back to DB: {}", e);
        }
        Ok(None) => {
            debug!("API key cache miss");
        }
        Ok(Some(cached)) => {
            debug!("API key cache hit");
            let parse_error = match serde_json::from_str::<ApiKey>(&cached) {
                Ok(api_key) => return Ok(Some(api_key)),
                Err(e) => e,
            };
            warn!(
                "Failed to deserialize cached API key, falling back to DB: {}",
                parse_error
            );
            // Drop the unreadable entry so later lookups do not trip over it.
            let removal = self.storage.cache_delete(&cache_key).await;
            if let Err(del_err) = removal {
                warn!("Failed to delete corrupt API key cache entry: {}", del_err);
            }
        }
    }

    let api_key = self.storage.db().find_api_key_by_hash(key_hash).await?;

    // A record that fails to serialize is silently left out of the cache.
    let serialized = api_key
        .as_ref()
        .and_then(|key| serde_json::to_string(key).ok());
    if let Some(serialized) = serialized
        && let Err(e) = self
            .storage
            .cache_set(&cache_key, &serialized, Some(API_KEY_CACHE_TTL))
            .await
    {
        warn!("Failed to populate API key cache: {}", e);
    }

    Ok(api_key)
}
/// Removes the cached record for `key_hash`, if any.
///
/// Best effort: a cache failure is logged as a warning and otherwise
/// ignored, so this method never fails.
pub(super) async fn invalidate_api_key_cache(&self, key_hash: &str) {
    let cache_key = api_key_cache_key(key_hash);
    let removal = self.storage.cache_delete(&cache_key).await;
    if let Err(e) = removal {
        warn!(
            "Failed to invalidate API key cache for {}: {}",
            cache_key, e
        );
    }
}
/// Verifies a raw API key and returns its record plus the owning user.
///
/// Returns `Ok(None)` when the key is unknown, inactive, or past its
/// `expires_at`. Otherwise the key's user (if it has a `user_id` and the
/// lookup finds one) is loaded, a throttled "last used" update is scheduled,
/// and `Ok(Some((key, user)))` is returned.
///
/// Note that, unlike `verify_key_detailed`, this method does not look at the
/// user's active status; the user is handed back for the caller to inspect.
///
/// # Errors
///
/// Propagates errors from the key lookup, the user lookup, and
/// `update_last_used`.
pub async fn verify_key(&self, raw_key: &str) -> Result<Option<(ApiKey, Option<User>)>> {
    debug!("Verifying API key");
    let key_hash = hash_api_key(raw_key, self.hmac_secret());

    let Some(api_key) = self.find_api_key_cached(&key_hash).await? else {
        debug!("API key not found");
        return Ok(None);
    };

    if !api_key.is_active {
        debug!("API key is inactive");
        return Ok(None);
    }

    // A key is still valid at the exact expiry instant; only strictly later
    // times count as expired.
    let expired = api_key
        .expires_at
        .is_some_and(|deadline| Utc::now() > deadline);
    if expired {
        debug!("API key is expired");
        return Ok(None);
    }

    let user = match api_key.user_id {
        Some(owner_id) => self.storage.db().find_user_by_id(owner_id).await?,
        None => None,
    };

    self.update_last_used(api_key.metadata.id).await?;
    debug!("API key verified successfully");
    Ok(Some((api_key, user)))
}
pub async fn verify_key_detailed(&self, raw_key: &str) -> Result<ApiKeyVerification> {
let key_hash = hash_api_key(raw_key, self.hmac_secret());
let api_key = match self.find_api_key_cached(&key_hash).await? {
Some(key) => key,
None => {
return Ok(ApiKeyVerification {
api_key: ApiKey {
metadata: Metadata::new(),
name: "".to_string(),
key_hash: "".to_string(),
key_prefix: "".to_string(),
user_id: None,
team_id: None,
permissions: vec![],
rate_limits: None,
expires_at: None,
is_active: false,
last_used_at: None,
usage_stats: UsageStats::default(),
},
user: None,
is_valid: false,
invalid_reason: Some("API key not found".to_string()),
});
}
};
if !api_key.is_active {
return Ok(ApiKeyVerification {
api_key,
user: None,
is_valid: false,
invalid_reason: Some("API key is inactive".to_string()),
});
}
if let Some(expires_at) = api_key.expires_at
&& Utc::now() > expires_at
{
return Ok(ApiKeyVerification {
api_key,
user: None,
is_valid: false,
invalid_reason: Some("API key is expired".to_string()),
});
}
let user = if let Some(user_id) = api_key.user_id {
self.storage.db().find_user_by_id(user_id).await?
} else {
None
};
if let Some(ref user) = user
&& !user.is_active()
{
return Ok(ApiKeyVerification {
api_key,
user: Some(user.clone()),
is_valid: false,
invalid_reason: Some("Associated user is inactive".to_string()),
});
}
self.update_last_used(api_key.metadata.id).await?;
Ok(ApiKeyVerification {
api_key,
user,
is_valid: true,
invalid_reason: None,
})
}
pub(super) async fn update_last_used(&self, key_id: Uuid) -> Result<()> {
let now = Instant::now();
if let Some(last_persisted) = self.last_used_cache.get(&key_id)
&& now.duration_since(*last_persisted) < LAST_USED_THROTTLE
{
return Ok(());
}
self.last_used_cache.insert(key_id, now);
let storage = self.storage.clone();
tokio::spawn(async move {
if let Err(e) = storage.db().update_api_key_last_used(key_id).await {
warn!("Failed to update API key last used timestamp: {}", e);
}
});
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned permission list the validator takes.
    fn perms(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Runs the validator and returns the validation message, failing the
    /// test if the input was accepted or a different error was produced.
    fn rejection_message(name: &str, permissions: &[String]) -> String {
        match validate_create_key_input(name, permissions) {
            Err(GatewayError::Validation(msg)) => msg,
            _ => panic!("expected GatewayError::Validation"),
        }
    }

    #[test]
    fn test_valid_name_and_permissions() {
        assert!(validate_create_key_input("My API Key", &perms(&["api.chat"])).is_ok());
    }

    #[test]
    fn test_empty_name_is_rejected() {
        let msg = rejection_message("", &[]);
        assert!(msg.contains("empty"), "unexpected message: {msg}");
    }

    #[test]
    fn test_name_at_max_length_is_accepted() {
        let longest_allowed = "a".repeat(255);
        assert!(validate_create_key_input(&longest_allowed, &[]).is_ok());
    }

    #[test]
    fn test_name_exceeding_max_length_is_rejected() {
        let too_long = "a".repeat(256);
        let msg = rejection_message(&too_long, &[]);
        assert!(msg.contains("255"), "expected '255' in message: {msg}");
    }

    #[test]
    fn test_name_with_control_character_is_rejected() {
        let msg = rejection_message("bad\x00name", &[]);
        assert!(
            msg.contains("control"),
            "expected 'control' in message: {msg}"
        );
    }

    #[test]
    fn test_name_with_newline_is_rejected() {
        // Only the error variant matters here; the helper asserts it.
        let _ = rejection_message("bad\nname", &[]);
    }

    #[test]
    fn test_empty_permissions_are_accepted() {
        assert!(validate_create_key_input("My Key", &[]).is_ok());
    }

    #[test]
    fn test_wildcard_permission_is_accepted() {
        assert!(validate_create_key_input("My Key", &perms(&["*"])).is_ok());
    }

    #[test]
    fn test_all_known_permissions_are_accepted() {
        let everything = perms(VALID_PERMISSIONS);
        assert!(validate_create_key_input("My Key", &everything).is_ok());
    }

    #[test]
    fn test_unknown_permission_is_rejected() {
        let msg = rejection_message("My Key", &perms(&["invalid.perm"]));
        assert!(
            msg.contains("invalid.perm"),
            "expected permission name in message: {msg}"
        );
    }

    #[test]
    fn test_mixed_valid_and_invalid_permissions_are_rejected() {
        let _ = rejection_message("My Key", &perms(&["api.chat", "not.a.perm"]));
    }
}