use crate::config::rate_limit::RateLimitTier;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use llm_shield_core::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
pub id: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<String>,
pub hashed_value: String,
pub tier: RateLimitTier,
pub created_at: DateTime<Utc>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<DateTime<Utc>>,
pub active: bool,
}
impl ApiKey {
pub fn new(
name: String,
tier: RateLimitTier,
expires_at: Option<DateTime<Utc>>,
) -> Result<Self> {
let id = uuid::Uuid::new_v4().to_string();
let raw_value = Self::generate_key_value()?;
let hashed_value = Self::hash_key(&raw_value)?;
Ok(Self {
id,
name,
value: Some(raw_value),
hashed_value,
tier,
created_at: Utc::now(),
expires_at,
active: true,
})
}
fn generate_key_value() -> Result<String> {
use rand::Rng;
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
let mut rng = rand::thread_rng();
let key_suffix: String = (0..40)
.map(|_| {
let idx = rng.gen_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect();
Ok(format!("llm_shield_{}", key_suffix))
}
fn hash_key(key: &str) -> Result<String> {
use argon2::{
password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
Argon2,
};
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let hash = argon2
.hash_password(key.as_bytes(), &salt)
.map_err(|e| Error::auth(format!("Failed to hash key: {}", e)))?;
Ok(hash.to_string())
}
pub fn verify(&self, raw_key: &str) -> Result<bool> {
use argon2::{password_hash::PasswordVerifier, Argon2};
let argon2 = Argon2::default();
let parsed_hash = argon2::password_hash::PasswordHash::new(&self.hashed_value)
.map_err(|e| Error::auth(format!("Invalid stored hash: {}", e)))?;
Ok(argon2
.verify_password(raw_key.as_bytes(), &parsed_hash)
.is_ok())
}
pub fn is_expired(&self) -> bool {
if let Some(expires_at) = self.expires_at {
Utc::now() > expires_at
} else {
false
}
}
pub fn is_valid(&self) -> bool {
self.active && !self.is_expired()
}
pub fn clear_value(&mut self) {
self.value = None;
}
pub fn validate_format(key: &str) -> bool {
key.starts_with("llm_shield_") && key.len() == 51 }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateKeyRequest {
pub name: String,
pub tier: RateLimitTier,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_in_days: Option<u32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateKeyResponse {
pub key: String,
pub id: String,
pub name: String,
pub tier: RateLimitTier,
pub created_at: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<String>,
}
impl From<ApiKey> for CreateKeyResponse {
fn from(key: ApiKey) -> Self {
Self {
key: key.value.unwrap_or_default(),
id: key.id,
name: key.name,
tier: key.tier,
created_at: key.created_at.to_rfc3339(),
expires_at: key.expires_at.map(|dt| dt.to_rfc3339()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_api_key_new() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
assert!(!key.id.is_empty());
assert_eq!(key.name, "test-key");
assert!(key.value.is_some());
assert!(!key.hashed_value.is_empty());
assert_eq!(key.tier, RateLimitTier::Free);
assert!(key.active);
assert!(!key.is_expired());
}
#[test]
fn test_api_key_format() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
let raw_value = key.value.as_ref().unwrap();
assert!(raw_value.starts_with("llm_shield_"));
assert_eq!(raw_value.len(), 51); assert!(ApiKey::validate_format(raw_value));
}
#[test]
fn test_api_key_verify_success() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
let raw_value = key.value.as_ref().unwrap().clone();
assert!(key.verify(&raw_value).unwrap());
}
#[test]
fn test_api_key_verify_failure() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
assert!(!key.verify("wrong_key").unwrap());
}
#[test]
fn test_api_key_expiration() {
use chrono::Duration;
let expired_key = ApiKey::new(
"expired-key".to_string(),
RateLimitTier::Free,
Some(Utc::now() - Duration::days(1)),
)
.unwrap();
assert!(expired_key.is_expired());
assert!(!expired_key.is_valid());
let valid_key = ApiKey::new(
"valid-key".to_string(),
RateLimitTier::Free,
Some(Utc::now() + Duration::days(1)),
)
.unwrap();
assert!(!valid_key.is_expired());
assert!(valid_key.is_valid());
}
#[test]
fn test_api_key_inactive() {
let mut key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
assert!(key.is_valid());
key.active = false;
assert!(!key.is_valid());
}
#[test]
fn test_api_key_clear_value() {
let mut key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
assert!(key.value.is_some());
key.clear_value();
assert!(key.value.is_none());
}
#[test]
fn test_validate_format() {
assert!(ApiKey::validate_format(
"llm_shield_abcdefghijklmnopqrstuvwxyz01234567890123"
));
assert!(!ApiKey::validate_format("invalid_key"));
assert!(!ApiKey::validate_format("llm_shield_short"));
}
#[test]
fn test_create_key_response_from_api_key() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Pro,
None,
)
.unwrap();
let response = CreateKeyResponse::from(key.clone());
assert_eq!(response.key, key.value.unwrap());
assert_eq!(response.id, key.id);
assert_eq!(response.name, key.name);
assert_eq!(response.tier, RateLimitTier::Pro);
}
#[test]
fn test_hash_is_different_from_raw() {
let key = ApiKey::new(
"test-key".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
let raw_value = key.value.as_ref().unwrap();
assert_ne!(raw_value, &key.hashed_value);
assert!(key.hashed_value.starts_with("$argon2"));
}
#[test]
fn test_different_keys_have_different_values() {
let key1 = ApiKey::new(
"key1".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
let key2 = ApiKey::new(
"key2".to_string(),
RateLimitTier::Free,
None,
)
.unwrap();
assert_ne!(key1.value, key2.value);
assert_ne!(key1.hashed_value, key2.hashed_value);
}
}