use crate::storage::database::Database;
use crate::utils::error::{GatewayError, Result};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// A stored virtual API key together with its limits, budget and usage bookkeeping.
///
/// The struct carries only `key_hash`; the plaintext key is returned separately
/// by `VirtualKeyManager::create_key` and has no field here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VirtualKey {
/// Identifier of the key; `create_key` fills it with a random UUID v4 string.
/// Rate-limit state in the manager is keyed by this value.
pub key_id: String,
/// Lowercase hex SHA-256 of the plaintext key (see `hash_key`); used as the cache key.
pub key_hash: String,
/// Optional alias supplied by the caller — presumably a human-readable label.
pub key_alias: Option<String>,
/// User the key was created for.
pub user_id: String,
/// Optional team association, copied from the create request.
pub team_id: Option<String>,
/// Optional organization association; `create_key` always leaves this `None`.
pub organization_id: Option<String>,
/// Model names attached to the key. NOTE(review): not enforced anywhere in this file —
/// presumably an allow-list checked elsewhere; confirm.
pub models: Vec<String>,
/// Spending cap compared against `spend + cost` in `check_budget`; `None` disables the check.
pub max_budget: Option<f64>,
/// Accumulated spend; increased by `update_spend` and zeroed by `reset_expired_budgets`.
pub spend: f64,
/// Budget period label. `calculate_budget_reset` recognizes "1d", "1w" and "1m" (30 days).
pub budget_duration: Option<String>,
/// Next budget reset time derived from `budget_duration`; `None` for a missing or unrecognized label.
pub budget_reset_at: Option<DateTime<Utc>>,
/// Limits consulted by `check_rate_limits`; `None` skips rate limiting for this key.
pub rate_limits: Option<RateLimits>,
/// Permissions granted to the key; `create_key` substitutes the configured defaults when the request list is empty.
pub permissions: Vec<Permission>,
/// Free-form string metadata copied from the create/update request.
pub metadata: HashMap<String, String>,
/// Expiry instant; `is_key_valid` rejects the key once the current time is past it.
pub expires_at: Option<DateTime<Utc>>,
/// Inactive keys are rejected by `is_key_valid`.
pub is_active: bool,
/// Creation timestamp, set to the current UTC time by `create_key`.
pub created_at: DateTime<Utc>,
/// Time of last use as recorded by `validate_key` when the key is loaded from the database.
pub last_used_at: Option<DateTime<Utc>>,
/// Use counter incremented by `validate_key` when the key is loaded from the database.
pub usage_count: u64,
/// Free-form tags copied from the create/update request.
pub tags: Vec<String>,
}
/// Optional per-key throttling limits.
///
/// Within this file only `rpm`, `tpm` and `max_parallel_requests` are evaluated
/// (by `VirtualKeyManager::check_rate_limits`); a `None` field is not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimits {
/// Maximum requests per minute.
pub rpm: Option<u32>,
/// Presumably requests per hour — NOTE(review): not enforced in this file.
pub rph: Option<u32>,
/// Presumably requests per day — NOTE(review): not enforced in this file.
pub rpd: Option<u32>,
/// Maximum tokens per minute.
pub tpm: Option<u32>,
/// Presumably tokens per hour — NOTE(review): not enforced in this file.
pub tph: Option<u32>,
/// Presumably tokens per day — NOTE(review): not enforced in this file.
pub tpd: Option<u32>,
/// Maximum number of requests that may be in flight at once.
pub max_parallel_requests: Option<u32>,
}
/// Capabilities that can be attached to a virtual key.
///
/// NOTE(review): no permission checks appear in this file, so the exact meaning
/// of each variant is inferred from its name only — confirm against the code
/// that enforces them. `ChatCompletion`, `TextCompletion` and `Embedding` are
/// the defaults in `KeyGenerationSettings::default`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum Permission {
ChatCompletion,
TextCompletion,
Embedding,
ImageGeneration,
/// Carries a string payload — presumably a model name; TODO confirm.
ModelAccess(String),
Admin,
KeyManagement,
ViewUsage,
TeamManagement,
/// Carries an arbitrary string payload for permissions not covered above.
Custom(String),
}
/// Input to `VirtualKeyManager::create_key`.
///
/// Fields are copied onto the new `VirtualKey`, with defaults from the
/// manager's `KeyGenerationSettings` filling in where noted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateKeyRequest {
/// Optional alias stored on the key.
pub key_alias: Option<String>,
/// User the key is created for.
pub user_id: String,
/// Optional team association.
pub team_id: Option<String>,
/// Model names stored on the key.
pub models: Vec<String>,
/// Spending cap; when `None` the configured `default_budget` is used.
pub max_budget: Option<f64>,
/// Budget period label ("1d", "1w" or "1m" are recognized when computing the reset time).
pub budget_duration: Option<String>,
/// Rate limits; when `None` the configured `default_rate_limits` are used.
pub rate_limits: Option<RateLimits>,
/// Permissions; an empty list is replaced by the configured `default_permissions`.
pub permissions: Vec<Permission>,
/// Free-form string metadata stored on the key.
pub metadata: HashMap<String, String>,
/// Optional expiry instant stored on the key.
pub expires_at: Option<DateTime<Utc>>,
/// Free-form tags stored on the key.
pub tags: Vec<String>,
}
/// Partial update applied by `VirtualKeyManager::update_key`.
///
/// Every field is optional: `Some(value)` overwrites the stored value and
/// `None` leaves it unchanged. As a consequence an optional field on the key
/// (alias, budget, rate limits, expiry, ...) cannot be cleared back to `None`
/// through this request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateKeyRequest {
/// New alias.
pub key_alias: Option<String>,
/// Replacement model list.
pub models: Option<Vec<String>>,
/// New spending cap.
pub max_budget: Option<f64>,
/// New budget period label; also triggers recomputation of the key's `budget_reset_at`.
pub budget_duration: Option<String>,
/// Replacement rate limits.
pub rate_limits: Option<RateLimits>,
/// Replacement permission list.
pub permissions: Option<Vec<Permission>>,
/// Replacement metadata map (replaces the whole map, not merged).
pub metadata: Option<HashMap<String, String>>,
/// New expiry instant.
pub expires_at: Option<DateTime<Utc>>,
/// Activate or deactivate the key.
pub is_active: Option<bool>,
/// Replacement tag list.
pub tags: Option<Vec<String>>,
}
/// Creates, validates, updates and deletes virtual keys, and tracks their
/// rate-limit and budget state.
pub struct VirtualKeyManager {
/// Persistent store for virtual keys.
database: Arc<Database>,
/// In-memory cache of keys, indexed by `VirtualKey::key_hash`.
key_cache: Arc<RwLock<HashMap<String, VirtualKey>>>,
/// In-memory rate-limit counters, indexed by `VirtualKey::key_id`.
rate_limiter: Arc<RwLock<HashMap<String, RateLimitState>>>,
/// Settings used when generating keys and filling in request defaults.
key_settings: KeyGenerationSettings,
}
/// Per-key counters maintained by `VirtualKeyManager::check_rate_limits`.
///
/// `request_count` and `token_count` cover the window that began at
/// `window_start`; they are reset once more than one minute has elapsed.
#[derive(Debug, Clone)]
pub struct RateLimitState {
/// Requests admitted in the current window.
pub request_count: u32,
/// Tokens admitted in the current window.
pub token_count: u32,
/// Start of the current counting window.
pub window_start: DateTime<Utc>,
/// Requests admitted and not yet reported via `record_request_completion`.
pub parallel_requests: u32,
}
/// Settings controlling key generation and the defaults applied to new keys.
#[derive(Debug, Clone)]
pub struct KeyGenerationSettings {
/// Number of random characters generated after `key_prefix`.
pub key_length: usize,
/// Literal prefix prepended to every generated key.
pub key_prefix: String,
/// Permissions given to a new key when the create request lists none.
pub default_permissions: Vec<Permission>,
/// Spending cap given to a new key when the create request has no `max_budget`.
pub default_budget: Option<f64>,
/// Rate limits given to a new key when the create request has no `rate_limits`.
pub default_rate_limits: Option<RateLimits>,
}
impl Default for KeyGenerationSettings {
    /// Built-in settings: 32 random characters behind an `sk-` prefix, the
    /// three completion/embedding permissions, a budget of 100.0 and a full
    /// set of rate limits.
    fn default() -> Self {
        let permissions = vec![
            Permission::ChatCompletion,
            Permission::TextCompletion,
            Permission::Embedding,
        ];

        let limits = RateLimits {
            // Requests per minute / hour / day.
            rpm: Some(60),
            rph: Some(3600),
            rpd: Some(86400),
            // Tokens per minute / hour / day.
            tpm: Some(100000),
            tph: Some(6000000),
            tpd: Some(144000000),
            max_parallel_requests: Some(10),
        };

        Self {
            key_length: 32,
            key_prefix: String::from("sk-"),
            default_permissions: permissions,
            default_budget: Some(100.0),
            default_rate_limits: Some(limits),
        }
    }
}
impl VirtualKeyManager {
/// Builds a manager over `database` with empty in-memory caches and the
/// default `KeyGenerationSettings`.
///
/// Nothing here can currently fail; the `Result` return is kept for callers.
pub async fn new(database: Arc<Database>) -> Result<Self> {
    let key_cache = Arc::new(RwLock::new(HashMap::new()));
    let rate_limiter = Arc::new(RwLock::new(HashMap::new()));
    let key_settings = KeyGenerationSettings::default();

    Ok(Self {
        database,
        key_cache,
        rate_limiter,
        key_settings,
    })
}
/// Generates a new API key, persists its `VirtualKey` record and caches it.
///
/// Defaults from the manager's settings are applied where the request leaves
/// a value out: `max_budget`, `rate_limits`, and `permissions` (when the list
/// is empty).
///
/// # Returns
/// The plaintext API key together with the stored record. The record only
/// contains the hash, so this is the only place the plaintext is available.
///
/// # Errors
/// Propagates any error from `Database::store_virtual_key`; in that case the
/// key is not cached.
pub async fn create_key(&self, request: CreateKeyRequest) -> Result<(String, VirtualKey)> {
    info!("Creating virtual key for user: {}", request.user_id);

    let api_key = self.generate_api_key();
    let key_hash = self.hash_key(&api_key);

    // Derive the reset time while `request.budget_duration` is still owned by
    // `request`: the struct literal below moves that field, so borrowing it
    // afterwards is a use-after-move.
    let budget_reset_at = self.calculate_budget_reset(&request.budget_duration);

    let permissions = if request.permissions.is_empty() {
        self.key_settings.default_permissions.clone()
    } else {
        request.permissions
    };

    let virtual_key = VirtualKey {
        key_id: Uuid::new_v4().to_string(),
        key_hash: key_hash.clone(),
        key_alias: request.key_alias,
        user_id: request.user_id,
        team_id: request.team_id,
        organization_id: None,
        models: request.models,
        max_budget: request.max_budget.or(self.key_settings.default_budget),
        spend: 0.0,
        budget_duration: request.budget_duration,
        budget_reset_at,
        // Clone the defaults only when the request did not bring its own limits.
        rate_limits: request
            .rate_limits
            .or_else(|| self.key_settings.default_rate_limits.clone()),
        permissions,
        metadata: request.metadata,
        expires_at: request.expires_at,
        is_active: true,
        created_at: Utc::now(),
        last_used_at: None,
        usage_count: 0,
        tags: request.tags,
    };

    // Persist first so a storage failure never leaves a cache-only key behind.
    self.database.store_virtual_key(&virtual_key).await?;
    {
        let mut cache = self.key_cache.write().await;
        cache.insert(key_hash, virtual_key.clone());
    }

    info!("Virtual key created successfully: {}", virtual_key.key_id);
    Ok((api_key, virtual_key))
}
/// Resolves a plaintext API key to its `VirtualKey`, rejecting unknown,
/// inactive or expired keys.
///
/// A valid cached entry is returned as-is. Otherwise the key is loaded from
/// the database, its `last_used_at` / `usage_count` are bumped, the usage
/// update is written back on a detached task, and the cache is refreshed.
///
/// NOTE(review): the cache-hit path returns before the usage bookkeeping, so
/// `last_used_at` and `usage_count` only change when the key is loaded from
/// the database — confirm this is intended.
///
/// # Errors
/// `GatewayError::Unauthorized` when the key is unknown, inactive or expired;
/// database lookup errors are propagated.
pub async fn validate_key(&self, api_key: &str) -> Result<VirtualKey> {
    let key_hash = self.hash_key(api_key);

    // Fast path: a still-valid cached copy.
    let cached_hit = {
        let cache = self.key_cache.read().await;
        cache
            .get(&key_hash)
            .filter(|candidate| self.is_key_valid(candidate))
            .cloned()
    };
    if let Some(hit) = cached_hit {
        return Ok(hit);
    }

    // Slow path: consult the database.
    let fetched = self.database.get_virtual_key(&key_hash).await?;
    let mut virtual_key = match fetched {
        Some(found) => found,
        None => return Err(GatewayError::Unauthorized("Invalid API key".to_string())),
    };
    if !self.is_key_valid(&virtual_key) {
        return Err(GatewayError::Unauthorized("API key is expired or inactive".to_string()));
    }

    virtual_key.usage_count += 1;
    virtual_key.last_used_at = Some(Utc::now());

    // Best-effort write-back: a failure is logged, never surfaced to the caller.
    let snapshot = virtual_key.clone();
    let database = self.database.clone();
    tokio::spawn(async move {
        if let Err(e) = database.update_virtual_key_usage(&snapshot).await {
            error!("Failed to update key usage: {}", e);
        }
    });

    self.key_cache
        .write()
        .await
        .insert(key_hash, virtual_key.clone());

    Ok(virtual_key)
}
/// Admits or rejects a request against the key's per-minute limits.
///
/// Checks, in order, requests per minute (`rpm`), tokens per minute (`tpm`)
/// and `max_parallel_requests`. The hourly and daily fields of `RateLimits`
/// are not evaluated here. Counters live in a window that restarts once more
/// than one minute has passed since it began.
///
/// On success the request count, token count and in-flight count are all
/// incremented; `record_request_completion` is what decrements the in-flight
/// count again. A rejected request changes no counters. Keys without
/// `rate_limits` are always admitted and leave no state behind.
///
/// # Errors
/// `GatewayError::RateLimit` naming the limit that was hit.
pub async fn check_rate_limits(
    &self,
    key: &VirtualKey,
    tokens_requested: u32,
) -> Result<()> {
    let rate_limits = match &key.rate_limits {
        Some(limits) => limits,
        None => return Ok(()),
    };

    let mut rate_limiter = self.rate_limiter.write().await;
    let state = rate_limiter
        .entry(key.key_id.clone())
        .or_insert_with(|| RateLimitState {
            request_count: 0,
            token_count: 0,
            window_start: Utc::now(),
            parallel_requests: 0,
        });

    // Start a fresh window once the current one is older than a minute.
    let now = Utc::now();
    if now.signed_duration_since(state.window_start) > Duration::minutes(1) {
        state.request_count = 0;
        state.token_count = 0;
        state.window_start = now;
    }

    if let Some(rpm) = rate_limits.rpm {
        if state.request_count >= rpm {
            return Err(GatewayError::RateLimit(
                format!("Rate limit exceeded: {} requests per minute", rpm)
            ));
        }
    }

    if let Some(tpm) = rate_limits.tpm {
        // `checked_add` so an oversized request cannot wrap the sum around
        // (release builds) or panic (debug builds); a sum that does not fit
        // in `u32` necessarily exceeds any `u32` limit.
        let over_limit = state
            .token_count
            .checked_add(tokens_requested)
            .map_or(true, |projected| projected > tpm);
        if over_limit {
            return Err(GatewayError::RateLimit(
                format!("Token rate limit exceeded: {} tokens per minute", tpm)
            ));
        }
    }

    if let Some(max_parallel) = rate_limits.max_parallel_requests {
        if state.parallel_requests >= max_parallel {
            return Err(GatewayError::RateLimit(
                format!("Too many parallel requests: max {}", max_parallel)
            ));
        }
    }

    // Saturating updates: with a limit unset the counter is unbounded and
    // must not overflow.
    state.request_count = state.request_count.saturating_add(1);
    state.token_count = state.token_count.saturating_add(tokens_requested);
    state.parallel_requests = state.parallel_requests.saturating_add(1);

    Ok(())
}
/// Marks one in-flight request for `key_id` as finished.
///
/// Decrements the key's parallel-request counter, never going below zero.
/// Unknown key ids are ignored.
pub async fn record_request_completion(&self, key_id: &str) {
    let mut limiter = self.rate_limiter.write().await;
    if let Some(entry) = limiter.get_mut(key_id) {
        entry.parallel_requests = entry.parallel_requests.saturating_sub(1);
    }
}
/// Verifies that charging `cost` would keep the key within its budget.
///
/// Uses the `spend` and `max_budget` stored on the `key` value passed in;
/// keys without a `max_budget` always pass. Nothing is modified.
///
/// # Errors
/// `GatewayError::BudgetExceeded` when `spend + cost` is greater than `max_budget`.
pub async fn check_budget(&self, key: &VirtualKey, cost: f64) -> Result<()> {
    match key.max_budget {
        Some(max_budget) if key.spend + cost > max_budget => {
            let message = format!(
                "Budget exceeded: ${:.2} + ${:.2} > ${:.2}",
                key.spend, cost, max_budget
            );
            Err(GatewayError::BudgetExceeded(message))
        }
        _ => Ok(()),
    }
}
/// Adds `cost` to the spend of the key identified by `key_id`.
///
/// The database is updated first; the cached copy (if any) is then adjusted
/// by the same amount. The cache is indexed by key hash, so the entry is
/// located by scanning for a matching `key_id`.
///
/// # Errors
/// Propagates any error from `Database::update_key_spend`; the cache is left
/// untouched in that case.
pub async fn update_spend(&self, key_id: &str, cost: f64) -> Result<()> {
    self.database.update_key_spend(key_id, cost).await?;

    let mut cache = self.key_cache.write().await;
    if let Some(cached) = cache.values_mut().find(|entry| entry.key_id == key_id) {
        cached.spend += cost;
    }

    Ok(())
}
/// Returns the virtual keys stored for `user_id`.
///
/// Reads straight from the database; the in-memory cache is not consulted.
///
/// # Errors
/// Propagates any error from `Database::list_user_keys`.
pub async fn list_user_keys(&self, user_id: &str) -> Result<Vec<VirtualKey>> {
    let keys = self.database.list_user_keys(user_id).await?;
    Ok(keys)
}
pub async fn update_key(&self, key_id: &str, request: UpdateKeyRequest) -> Result<VirtualKey> {
let mut key = self.database.get_virtual_key_by_id(key_id).await?
.ok_or_else(|| GatewayError::NotFound("Virtual key not found".to_string()))?;
if let Some(alias) = request.key_alias {
key.key_alias = Some(alias);
}
if let Some(models) = request.models {
key.models = models;
}
if let Some(budget) = request.max_budget {
key.max_budget = Some(budget);
}
if let Some(duration) = request.budget_duration {
key.budget_duration = Some(duration.clone());
key.budget_reset_at = self.calculate_budget_reset(&Some(duration));
}
if let Some(rate_limits) = request.rate_limits {
key.rate_limits = Some(rate_limits);
}
if let Some(permissions) = request.permissions {
key.permissions = permissions;
}
if let Some(metadata) = request.metadata {
key.metadata = metadata;
}
if let Some(expires_at) = request.expires_at {
key.expires_at = Some(expires_at);
}
if let Some(is_active) = request.is_active {
key.is_active = is_active;
}
if let Some(tags) = request.tags {
key.tags = tags;
}
self.database.update_virtual_key(&key).await?;
{
let mut cache = self.key_cache.write().await;
cache.insert(key.key_hash.clone(), key.clone());
}
Ok(key)
}
/// Deletes the key identified by `key_id` from the database and drops all
/// in-memory state held for it.
///
/// Both the cached record (indexed by key hash) and the rate-limit counters
/// (indexed by key id) are removed, so deleted keys do not accumulate in
/// memory.
///
/// # Errors
/// `GatewayError::NotFound` when no key has this id; database errors are
/// propagated, in which case no in-memory state is touched.
pub async fn delete_key(&self, key_id: &str) -> Result<()> {
    // The cache is keyed by hash, so the record is needed to find its entry.
    let key = self.database.get_virtual_key_by_id(key_id).await?
        .ok_or_else(|| GatewayError::NotFound("Virtual key not found".to_string()))?;

    self.database.delete_virtual_key(key_id).await?;

    {
        let mut cache = self.key_cache.write().await;
        cache.remove(&key.key_hash);
    }
    {
        // Without this the limiter entry for a deleted key would live forever.
        let mut rate_limiter = self.rate_limiter.write().await;
        rate_limiter.remove(key_id);
    }

    info!("Virtual key deleted: {}", key_id);
    Ok(())
}
/// Produces a fresh plaintext API key: the configured prefix followed by
/// `key_length` characters drawn uniformly from `A-Z`, `a-z` and `0-9`.
fn generate_api_key(&self) -> String {
    use rand::Rng;
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

    let settings = &self.key_settings;
    let mut rng = rand::thread_rng();

    let mut api_key = String::with_capacity(settings.key_prefix.len() + settings.key_length);
    api_key.push_str(&settings.key_prefix);
    for _ in 0..settings.key_length {
        let pick = rng.gen_range(0..CHARSET.len());
        api_key.push(char::from(CHARSET[pick]));
    }
    api_key
}
/// Returns the SHA-256 digest of `key` as a lowercase hexadecimal string.
///
/// The same input always yields the same output, which is what lets the hash
/// serve as the cache and database lookup key.
fn hash_key(&self, key: &str) -> String {
    use sha2::{Digest, Sha256};

    let digest = Sha256::digest(key.as_bytes());
    format!("{:x}", digest)
}
/// Reports whether `key` may currently be used: it must be active and, if it
/// has an expiry, the current time must not be past it.
fn is_key_valid(&self, key: &VirtualKey) -> bool {
    key.is_active
        && key
            .expires_at
            .map_or(true, |deadline| Utc::now() <= deadline)
}
/// Computes when a budget with the given period label should next reset.
///
/// Recognized labels are `"1d"` (one day), `"1w"` (one week) and `"1m"`
/// (30 days), each measured from the current time. A missing or unrecognized
/// label yields `None`.
fn calculate_budget_reset(&self, duration: &Option<String>) -> Option<DateTime<Utc>> {
    let period = match duration.as_deref()? {
        "1d" => Duration::days(1),
        "1w" => Duration::weeks(1),
        "1m" => Duration::days(30),
        _ => return None,
    };
    Some(Utc::now() + period)
}
pub async fn reset_expired_budgets(&self) -> Result<()> {
let keys_to_reset = self.database.get_keys_with_expired_budgets().await?;
for mut key in keys_to_reset {
key.spend = 0.0;
key.budget_reset_at = self.calculate_budget_reset(&key.budget_duration);
self.database.update_virtual_key(&key).await?;
{
let mut cache = self.key_cache.write().await;
cache.insert(key.key_hash.clone(), key);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager wired to the mock database, with empty caches and default settings.
    fn mock_manager() -> VirtualKeyManager {
        VirtualKeyManager {
            database: Arc::new(Database::new_mock()),
            key_cache: Arc::new(RwLock::new(HashMap::new())),
            rate_limiter: Arc::new(RwLock::new(HashMap::new())),
            key_settings: KeyGenerationSettings::default(),
        }
    }

    /// An active key with no expiry, limits or budget.
    fn baseline_key() -> VirtualKey {
        VirtualKey {
            key_id: "test".to_string(),
            key_hash: "hash".to_string(),
            key_alias: None,
            user_id: "user1".to_string(),
            team_id: None,
            organization_id: None,
            models: vec![],
            max_budget: None,
            spend: 0.0,
            budget_duration: None,
            budget_reset_at: None,
            rate_limits: None,
            permissions: vec![],
            metadata: HashMap::new(),
            expires_at: None,
            is_active: true,
            created_at: Utc::now(),
            last_used_at: None,
            usage_count: 0,
            tags: vec![],
        }
    }

    #[test]
    fn test_key_generation() {
        let generated = mock_manager().generate_api_key();
        assert!(generated.starts_with("sk-"));
        // Three prefix characters plus the default 32 random characters.
        assert_eq!(generated.len(), 35);
    }

    #[test]
    fn test_key_hashing() {
        let manager = mock_manager();
        let plaintext = "sk-test123";

        let first = manager.hash_key(plaintext);
        let second = manager.hash_key(plaintext);

        // Hashing is deterministic and never echoes the input.
        assert_eq!(first, second);
        assert_ne!(first, plaintext);
    }

    #[test]
    fn test_key_validation() {
        let manager = mock_manager();

        let active_key = baseline_key();
        assert!(manager.is_key_valid(&active_key));

        let mut inactive_key = active_key.clone();
        inactive_key.is_active = false;
        assert!(!manager.is_key_valid(&inactive_key));

        let mut expired_key = active_key;
        expired_key.expires_at = Some(Utc::now() - Duration::hours(1));
        assert!(!manager.is_key_valid(&expired_key));
    }
}