use super::storage::KeyStorage;
use super::types::{ApiKey, CreateKeyRequest, CreateKeyResponse, Result};
use crate::config::rate_limit::RateLimitTier;
use chrono::{Duration, Utc};
use llm_shield_core::Error;
use std::sync::Arc;
/// Service for creating, validating, revoking, deleting, and listing API keys.
///
/// All persistence is delegated to the shared [`KeyStorage`] backend supplied
/// at construction time.
pub struct AuthService {
/// Shared handle to the backend that holds the `ApiKey` records.
storage: Arc<dyn KeyStorage>,
}
impl AuthService {
pub fn new(storage: Arc<dyn KeyStorage>) -> Self {
Self { storage }
}
pub async fn create_key(
&self,
name: String,
tier: RateLimitTier,
expires_in_days: Option<u32>,
) -> Result<CreateKeyResponse> {
let expires_at = expires_in_days.map(|days| Utc::now() + Duration::days(days as i64));
let mut key = ApiKey::new(name, tier, expires_at)?;
self.storage.store(&key).await?;
let response = CreateKeyResponse::from(key.clone());
key.clear_value();
self.storage.update(&key).await?;
Ok(response)
}
pub async fn create_key_from_request(
&self,
request: CreateKeyRequest,
) -> Result<CreateKeyResponse> {
self.create_key(request.name, request.tier, request.expires_in_days)
.await
}
pub async fn validate_key(&self, raw_key: &str) -> Result<ApiKey> {
if !ApiKey::validate_format(raw_key) {
return Err(Error::unauthorized("Invalid API key format"));
}
let all_keys = self.storage.list().await?;
for key in all_keys {
if key.verify(raw_key)? {
if !key.active {
return Err(Error::unauthorized("API key has been revoked"));
}
if key.is_expired() {
return Err(Error::unauthorized("API key has expired"));
}
return Ok(key);
}
}
Err(Error::unauthorized("Invalid API key"))
}
pub async fn revoke_key(&self, id: &str) -> Result<()> {
let mut key = self
.storage
.get_by_id(id)
.await?
.ok_or_else(|| Error::not_found("API key not found"))?;
key.active = false;
self.storage.update(&key).await?;
Ok(())
}
pub async fn delete_key(&self, id: &str) -> Result<()> {
self.storage.delete(id).await
}
pub async fn list_keys(&self) -> Result<Vec<ApiKey>> {
self.storage.list().await
}
pub async fn get_key(&self, id: &str) -> Result<Option<ApiKey>> {
self.storage.get_by_id(id).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::storage::MemoryKeyStorage;

    /// Builds an `AuthService` backed by a fresh `MemoryKeyStorage`.
    fn create_service() -> AuthService {
        AuthService::new(Arc::new(MemoryKeyStorage::new()))
    }

    /// Creates a key through `service`, panicking if creation fails.
    async fn issue(
        service: &AuthService,
        name: &str,
        tier: RateLimitTier,
        ttl_days: Option<u32>,
    ) -> CreateKeyResponse {
        service
            .create_key(name.to_string(), tier, ttl_days)
            .await
            .unwrap()
    }

    /// A created key carries a prefixed raw value plus the requested metadata.
    #[tokio::test]
    async fn test_create_key() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Free, None).await;

        assert!(!created.key.is_empty());
        assert!(created.key.starts_with("llm_shield_"));
        assert_eq!(created.name, "test-key");
        assert_eq!(created.tier, RateLimitTier::Free);
    }

    /// Requesting an expiry yields a response with `expires_at` set.
    #[tokio::test]
    async fn test_create_key_with_expiration() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Pro, Some(30)).await;

        assert!(created.expires_at.is_some());
    }

    /// The raw key from the creation response validates to the same record.
    #[tokio::test]
    async fn test_validate_key_success() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Free, None).await;

        let validated = service.validate_key(&created.key).await.unwrap();

        assert_eq!(validated.id, created.id);
        assert_eq!(validated.name, "test-key");
        assert!(validated.is_valid());
    }

    /// A malformed key is rejected.
    #[tokio::test]
    async fn test_validate_key_invalid_format() {
        let service = create_service();

        let outcome = service.validate_key("invalid_key").await;

        assert!(outcome.is_err());
    }

    /// A key that was never issued is rejected.
    #[tokio::test]
    async fn test_validate_key_not_found() {
        let service = create_service();
        let unknown = "llm_shield_abcdefghijklmnopqrstuvwxyz01234567890123";

        let outcome = service.validate_key(unknown).await;

        assert!(outcome.is_err());
    }

    /// A key validates before revocation and fails afterwards.
    #[tokio::test]
    async fn test_revoke_key() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Free, None).await;

        assert!(service.validate_key(&created.key).await.is_ok());

        service.revoke_key(&created.id).await.unwrap();

        let outcome = service.validate_key(&created.key).await;
        assert!(outcome.is_err());
    }

    /// A deleted key can no longer be fetched by id.
    #[tokio::test]
    async fn test_delete_key() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Free, None).await;

        assert!(service.get_key(&created.id).await.unwrap().is_some());

        service.delete_key(&created.id).await.unwrap();

        assert!(service.get_key(&created.id).await.unwrap().is_none());
    }

    /// Listing returns one entry per created key.
    #[tokio::test]
    async fn test_list_keys() {
        let service = create_service();
        let specs = [
            ("key1", RateLimitTier::Free),
            ("key2", RateLimitTier::Pro),
            ("key3", RateLimitTier::Enterprise),
        ];
        for (name, tier) in specs {
            issue(&service, name, tier, None).await;
        }

        let keys = service.list_keys().await.unwrap();

        assert_eq!(keys.len(), 3);
    }

    /// A key created with a zero-day lifetime fails validation shortly after.
    #[tokio::test]
    async fn test_validate_expired_key() {
        let service = create_service();
        let created = issue(&service, "expired-key", RateLimitTier::Free, Some(0)).await;

        // Let the clock move past the expiry timestamp before validating.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let outcome = service.validate_key(&created.key).await;
        assert!(outcome.is_err());
    }

    /// The request-based entry point forwards name, tier, and expiry.
    #[tokio::test]
    async fn test_create_key_from_request() {
        let service = create_service();
        let request = CreateKeyRequest {
            name: "test-key".to_string(),
            tier: RateLimitTier::Pro,
            expires_in_days: Some(30),
        };

        let created = service.create_key_from_request(request).await.unwrap();

        assert_eq!(created.name, "test-key");
        assert_eq!(created.tier, RateLimitTier::Pro);
        assert!(created.expires_at.is_some());
    }

    /// The stored record holds no raw value, yet the raw key still validates.
    #[tokio::test]
    async fn test_raw_value_cleared_after_creation() {
        let service = create_service();
        let created = issue(&service, "test-key", RateLimitTier::Free, None).await;

        let stored = service.get_key(&created.id).await.unwrap().unwrap();

        assert!(stored.value.is_none());
        assert!(service.validate_key(&created.key).await.is_ok());
    }
}