use crate::security::audit::{AuditEvent, AuditEventType, AuditLogger, AuditActor, AuditOutcome};
use chrono::{DateTime, Utc};
use secrecy::{ExposeSecret, SecretString, SecretVec};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Persisted metadata and protected key material for a single API key.
///
/// Records are built by `ApiKeyManager::generate_key` and handed to a `KeyStore`.
/// The descriptions below state how that method populates each field.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ApiKeyRecord {
/// Unique identifier of this key (a random v4 UUID when created by the manager).
pub key_id: Uuid,
/// Tenant the key was issued to.
pub tenant_id: Uuid,
/// Leading characters (at most 16) of the full key string, used to look up candidates.
pub key_prefix: String,
/// SHA-256 digest of the raw key bytes; serialized as a hex string.
#[serde(with = "hex_array")]
pub key_hash: [u8; 32],
/// Raw key bytes encrypted with AES-256-GCM under the manager's KEK; serialized as base64.
#[serde(with = "base64_vec")]
pub encrypted_key: Vec<u8>,
/// Nonce used for the AES-GCM encryption of `encrypted_key`; serialized as a hex string.
#[serde(with = "hex_nonce")]
pub encryption_nonce: [u8; 12],
/// Scopes granted to the key; validation requires every requested scope to appear here.
pub scopes: Vec<String>,
/// Creation timestamp.
pub created_at: DateTime<Utc>,
/// Optional expiry; validation rejects the key once the current time is past this instant.
pub expires_at: Option<DateTime<Utc>>,
/// Timestamp of the most recent use; `None` on a freshly generated key.
pub last_used_at: Option<DateTime<Utc>>,
/// Lifecycle status of the key.
pub status: KeyStatus,
/// Rotation bookkeeping; `None` on a freshly generated key.
pub rotation: Option<KeyRotationInfo>,
}
/// Lifecycle status of an API key, serialized in `snake_case`.
///
/// `ApiKeyManager::validate_key` accepts `Active` and `Rotating` keys and rejects
/// `Revoked` and `Expired` ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KeyStatus {
/// Key is usable; the status assigned to newly generated keys.
Active,
/// Key is still accepted by validation while a rotation is (presumably) in progress.
Rotating,
/// Key has been revoked; validation fails with `KeyValidationError::KeyRevoked`.
Revoked,
/// Key is marked expired; validation fails with `KeyValidationError::KeyExpired`.
Expired,
}
/// Rotation bookkeeping attached to an `ApiKeyRecord`.
///
/// NOTE(review): the code that fills these fields is not in this file; the field
/// descriptions below are inferred from names and types and should be confirmed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct KeyRotationInfo {
/// Presumably the key being replaced by this rotation, if any.
pub previous_key_id: Option<Uuid>,
/// When the rotation began.
pub rotation_started: DateTime<Utc>,
/// When the rotation finished; presumably `None` while it is still in progress.
pub rotation_completed: Option<DateTime<Utc>>,
/// Why the rotation was initiated.
pub reason: RotationReason,
}
/// Reason recorded for a key rotation, serialized in `snake_case`
/// (for example `security_incident`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RotationReason {
/// Rotation on a schedule.
Scheduled,
/// Rotation in response to a security incident.
SecurityIncident,
/// Rotation because of a policy violation.
PolicyViolation,
/// Rotation requested by the customer.
CustomerRequest,
/// Rotation because the key is believed to be compromised.
KeyCompromise,
}
/// Result of `ApiKeyManager::generate_key`.
///
/// The store only receives a hash and an encrypted copy of the key, so this value is
/// where the plaintext key string is handed back to the caller.
pub struct GeneratedKey {
/// Identifier of the newly stored key record.
pub key_id: Uuid,
/// Full key string (environment prefix followed by the base58-encoded key bytes).
pub api_key: SecretString,
/// Expiry of the key, if one was requested.
pub expires_at: Option<DateTime<Utc>>,
}
/// Identity information returned by a successful `ApiKeyManager::validate_key` call.
#[derive(Debug, Clone)]
pub struct ValidatedKey {
/// Identifier of the matched key record.
pub key_id: Uuid,
/// Tenant that owns the matched key.
pub tenant_id: Uuid,
/// All scopes granted to the key (not only the ones that were required).
pub scopes: Vec<String>,
}
/// Reasons `ApiKeyManager::validate_key` can reject a presented key.
#[derive(Debug, thiserror::Error)]
pub enum KeyValidationError {
/// The key is malformed (bad prefix or base58) or its hash matches no stored record.
#[error("Invalid API key")]
InvalidKey,
/// The matched record has status `KeyStatus::Revoked`.
#[error("API key has been revoked")]
KeyRevoked,
/// The matched record has status `KeyStatus::Expired` or is past its `expires_at`.
#[error("API key has expired")]
KeyExpired,
/// The key lacks the named required scope (the first missing one is reported).
#[error("Insufficient scope: missing {0}")]
InsufficientScope(String),
/// Returned by `validate_key` when the store lookup itself fails.
#[error("Key not found")]
NotFound,
}
/// Errors raised by key management operations and `KeyStore` implementations.
#[derive(Debug, thiserror::Error)]
pub enum KeyError {
/// Building the AES-256-GCM cipher from the KEK, or the encryption itself, failed.
#[error("Encryption failed")]
EncryptionFailed,
/// Decryption failed. Not produced by any code visible in this file.
#[error("Decryption failed")]
DecryptionFailed,
/// The operating-system random number generator reported an error.
#[error("Random generation failed")]
RandomGenerationFailed,
/// Key format is invalid. Not produced by any code visible in this file.
#[error("Invalid key format")]
InvalidKeyFormat,
/// Store-level failure; also used by the manager with the message "Key not found".
#[error("Database error: {0}")]
Database(String),
/// Writing a mandatory audit event failed; carries the audit logger's error text.
#[error("Audit error: {0}")]
Audit(String),
}
/// Persistence interface used by `ApiKeyManager` to store and query `ApiKeyRecord`s.
///
/// Implementations must be `Send + Sync` because the manager holds the store behind
/// an `Arc<dyn KeyStore>`.
#[async_trait::async_trait]
pub trait KeyStore: Send + Sync {
/// Persists `record`.
async fn store_key(&self, record: &ApiKeyRecord) -> Result<(), KeyError>;
/// Returns candidate records for `prefix`. The manager passes the first (at most 16)
/// characters of the presented key and then compares hashes itself.
async fn find_keys_by_prefix(&self, prefix: &str) -> Result<Vec<ApiKeyRecord>, KeyError>;
/// Returns the record with the given id, or `None` when no such record exists.
async fn get_key(&self, key_id: Uuid) -> Result<Option<ApiKeyRecord>, KeyError>;
/// Sets the status of the given key.
async fn update_status(&self, key_id: Uuid, status: KeyStatus) -> Result<(), KeyError>;
/// Records that the given key was just used.
async fn update_last_used(&self, key_id: Uuid) -> Result<(), KeyError>;
/// Lists active keys. NOTE(review): the in-memory implementation returns only
/// `KeyStatus::Active` records (not `Rotating`); confirm other stores agree.
async fn list_active_keys(&self) -> Result<Vec<ApiKeyRecord>, KeyError>;
/// Attaches rotation information to the given key.
async fn set_rotation_info(&self, key_id: Uuid, info: KeyRotationInfo) -> Result<(), KeyError>;
}
/// Generates, validates and revokes API keys, writing audit events as it goes.
pub struct ApiKeyManager {
/// Key-encryption key used as the AES-256-GCM key when encrypting raw key bytes.
kek: SecretVec<u8>,
/// Backing store for key records.
db: Arc<dyn KeyStore>,
/// Audit sink for key lifecycle and validation events.
audit: Arc<AuditLogger>,
/// Prefix placed in front of generated keys: `rk_live_` or `rk_test_`.
key_prefix: String,
}
impl ApiKeyManager {
/// Creates a manager backed by the given store and audit logger.
///
/// `kek` is the key-encryption key used to encrypt generated key material.
/// `is_production` selects the prefix of generated keys: `rk_live_` when `true`,
/// `rk_test_` otherwise.
pub fn new(
    kek: SecretVec<u8>,
    db: Arc<dyn KeyStore>,
    audit: Arc<AuditLogger>,
    is_production: bool,
) -> Self {
    let environment_prefix = if is_production {
        "rk_live_"
    } else {
        "rk_test_"
    };
    let key_prefix = String::from(environment_prefix);

    Self {
        kek,
        db,
        audit,
        key_prefix,
    }
}
/// Generates a new API key for `tenant_id`, stores its record and audits the creation.
///
/// The key is 32 random bytes, rendered as the manager's prefix followed by the
/// base58 encoding of those bytes. Only the SHA-256 hash and an AES-256-GCM encrypted
/// copy are stored; the plaintext is returned once in [`GeneratedKey`].
///
/// `expires_in` is the requested lifetime. A lifetime that cannot be represented
/// (too large for `chrono::Duration`, or overflowing the timestamp range) falls back
/// to 365 days instead of panicking.
///
/// # Errors
///
/// * [`KeyError::RandomGenerationFailed`] if the OS RNG fails.
/// * [`KeyError::EncryptionFailed`] if the key material cannot be encrypted.
/// * Any error returned by the store's `store_key`.
/// * [`KeyError::Audit`] if the audit event cannot be written. The record has already
///   been stored at that point, so the key exists but its plaintext is never returned.
pub async fn generate_key(
    &self,
    tenant_id: Uuid,
    scopes: Vec<String>,
    expires_in: Option<Duration>,
) -> Result<GeneratedKey, KeyError> {
    let mut key_bytes = generate_secure_random(32)?;
    let key_id = Uuid::new_v4();
    let key_string = format!("{}{}", self.key_prefix, bs58::encode(&key_bytes).into_string());
    let key_hash = hash_key(&key_bytes);

    // Wipe the raw bytes before propagating any encryption error so the secret does
    // not outlive this point on either path.
    let encrypted = self.encrypt_key_material(&key_bytes);
    zeroize_memory(&mut key_bytes);
    let (encrypted_key, nonce) = encrypted?;

    // Take the lookup prefix, then move the plaintext into a `SecretString` right away
    // so it is wiped on drop even if storing or auditing fails below.
    let key_prefix = key_string[..16.min(key_string.len())].to_string();
    let api_key = SecretString::from(key_string);

    // One timestamp for both fields keeps `expires_at - created_at` equal to the
    // requested lifetime.
    let now = Utc::now();
    let expires_at = expires_in.map(|ttl| {
        chrono::Duration::from_std(ttl)
            .ok()
            // `DateTime + Duration` panics on overflow; the checked form does not.
            .and_then(|ttl| now.checked_add_signed(ttl))
            .unwrap_or_else(|| now + chrono::Duration::days(365))
    });

    let record = ApiKeyRecord {
        key_id,
        tenant_id,
        key_prefix,
        key_hash,
        encrypted_key,
        encryption_nonce: nonce,
        scopes: scopes.clone(),
        created_at: now,
        expires_at,
        last_used_at: None,
        status: KeyStatus::Active,
        rotation: None,
    };
    self.db.store_key(&record).await?;

    self.audit
        .log(
            AuditEvent::new(AuditEventType::ApiKeyCreated, AuditActor::System)
                .with_tenant(tenant_id)
                .with_key(key_id)
                .with_details(serde_json::json!({
                    "scopes": scopes,
                    "expires_at": expires_at,
                    "key_prefix": record.key_prefix,
                }))
                .with_outcome(AuditOutcome::Success),
        )
        .await
        .map_err(|e| KeyError::Audit(e.to_string()))?;

    Ok(GeneratedKey {
        key_id,
        api_key,
        expires_at,
    })
}
pub async fn validate_key(
&self,
provided_key: &str,
required_scopes: &[String],
client_ip: std::net::IpAddr,
) -> Result<ValidatedKey, KeyValidationError> {
let key_bytes = self.parse_key_format(provided_key)?;
let provided_hash = hash_key(&key_bytes);
let prefix = &provided_key[..16.min(provided_key.len())];
let candidates = self
.db
.find_keys_by_prefix(prefix)
.await
.map_err(|_| KeyValidationError::NotFound)?;
let mut matched_record: Option<ApiKeyRecord> = None;
for record in candidates {
if constant_time_eq(&provided_hash, &record.key_hash) {
matched_record = Some(record);
break;
}
}
let record = matched_record.ok_or(KeyValidationError::InvalidKey)?;
match record.status {
KeyStatus::Revoked => {
self.audit_failed_auth(&record, client_ip, "Key revoked")
.await;
return Err(KeyValidationError::KeyRevoked);
}
KeyStatus::Expired => {
self.audit_failed_auth(&record, client_ip, "Key expired")
.await;
return Err(KeyValidationError::KeyExpired);
}
KeyStatus::Active | KeyStatus::Rotating => {}
}
if let Some(expires_at) = record.expires_at {
if Utc::now() > expires_at {
self.audit_failed_auth(&record, client_ip, "Key past expiration")
.await;
return Err(KeyValidationError::KeyExpired);
}
}
for required in required_scopes {
if !record.scopes.contains(required) {
self.audit_failed_auth(
&record,
client_ip,
&format!("Missing scope: {}", required),
)
.await;
return Err(KeyValidationError::InsufficientScope(required.clone()));
}
}
let _ = self.db.update_last_used(record.key_id).await;
let _ = self
.audit
.log(
AuditEvent::new(AuditEventType::ApiKeyValidated, AuditActor::ApiKey(record.key_id))
.with_tenant(record.tenant_id)
.with_key(record.key_id)
.with_details(serde_json::json!({
"scopes_checked": required_scopes,
}))
.with_ip(client_ip)
.with_outcome(AuditOutcome::Success),
)
.await;
Ok(ValidatedKey {
key_id: record.key_id,
tenant_id: record.tenant_id,
scopes: record.scopes,
})
}
/// Marks the key as revoked and writes an audit event attributed to `actor`.
///
/// `reason` is recorded in the audit event details together with the key prefix.
///
/// # Errors
///
/// * [`KeyError::Database`] with the message "Key not found" when no record has `key_id`.
/// * Any error from the store's `get_key` or `update_status`.
/// * [`KeyError::Audit`] if the audit event cannot be written. The status has already
///   been set to revoked when this error is returned.
pub async fn revoke_key(
    &self,
    key_id: Uuid,
    reason: &str,
    actor: AuditActor,
) -> Result<(), KeyError> {
    let record = self
        .db
        .get_key(key_id)
        .await?
        // Build the error lazily so the found path does not allocate a message string.
        .ok_or_else(|| KeyError::Database("Key not found".to_string()))?;

    self.db.update_status(key_id, KeyStatus::Revoked).await?;

    self.audit
        .log(
            AuditEvent::new(AuditEventType::ApiKeyRevoked, actor)
                .with_tenant(record.tenant_id)
                .with_key(key_id)
                .with_details(serde_json::json!({
                    "reason": reason,
                    "key_prefix": record.key_prefix,
                }))
                .with_outcome(AuditOutcome::Success),
        )
        .await
        .map_err(|e| KeyError::Audit(e.to_string()))?;

    Ok(())
}
/// Fetches the record for `key_id` from the store.
///
/// # Errors
///
/// Returns [`KeyError::Database`] with the message "Key not found" when the store has
/// no such record, or whatever error the store lookup itself produced.
pub async fn get_key(&self, key_id: Uuid) -> Result<ApiKeyRecord, KeyError> {
    self.db
        .get_key(key_id)
        .await?
        // Build the error lazily so the found path does not allocate a message string.
        .ok_or_else(|| KeyError::Database("Key not found".to_string()))
}
/// Sets the stored status of `key_id` by delegating directly to the backing store.
///
/// No audit event is written by this method; any store error is returned unchanged.
pub async fn update_status(&self, key_id: Uuid, status: KeyStatus) -> Result<(), KeyError> {
self.db.update_status(key_id, status).await
}
/// Attaches `info` to the record for `key_id` by delegating directly to the backing store.
///
/// No audit event is written by this method; any store error is returned unchanged.
pub async fn set_rotation_info(
&self,
key_id: Uuid,
info: KeyRotationInfo,
) -> Result<(), KeyError> {
self.db.set_rotation_info(key_id, info).await
}
/// Returns the store's list of active keys, passing the store's result through unchanged.
pub async fn list_active_keys(&self) -> Result<Vec<ApiKeyRecord>, KeyError> {
self.db.list_active_keys().await
}
/// Strips the environment prefix from `key` and base58-decodes the remainder.
///
/// Both `rk_live_` and `rk_test_` are accepted, regardless of which prefix this
/// manager uses when generating keys.
///
/// NOTE(review): the decoded length is not checked here; generated keys are 32 bytes,
/// but any length that decodes successfully is returned.
///
/// # Errors
///
/// Returns [`KeyValidationError::InvalidKey`] when the prefix is missing or the
/// remainder is not valid base58.
fn parse_key_format(&self, key: &str) -> Result<Vec<u8>, KeyValidationError> {
    let encoded = key
        .strip_prefix("rk_live_")
        .or_else(|| key.strip_prefix("rk_test_"))
        .ok_or(KeyValidationError::InvalidKey)?;

    match bs58::decode(encoded).into_vec() {
        Ok(decoded) => Ok(decoded),
        Err(_) => Err(KeyValidationError::InvalidKey),
    }
}
/// Encrypts `plaintext` with AES-256-GCM under the manager's KEK.
///
/// A fresh 12-byte random nonce is generated for every call and returned next to the
/// ciphertext so the caller can store both.
///
/// # Errors
///
/// * [`KeyError::EncryptionFailed`] if the cipher cannot be built from the KEK or the
///   encryption itself fails.
/// * [`KeyError::RandomGenerationFailed`] if the nonce cannot be generated.
fn encrypt_key_material(&self, plaintext: &[u8]) -> Result<(Vec<u8>, [u8; 12]), KeyError> {
    use aes_gcm::{aead::Aead, Aes256Gcm, KeyInit, Nonce};

    let cipher = match Aes256Gcm::new_from_slice(self.kek.expose_secret()) {
        Ok(cipher) => cipher,
        Err(_) => return Err(KeyError::EncryptionFailed),
    };

    // Move the random bytes into the fixed-size array first; the same array is used
    // for the encryption and returned to the caller.
    let random = generate_secure_random(12)?;
    let mut nonce_arr = [0u8; 12];
    nonce_arr.copy_from_slice(&random);

    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_arr), plaintext)
        .map_err(|_| KeyError::EncryptionFailed)?;

    Ok((sealed, nonce_arr))
}
/// Writes a best-effort audit event for a rejected validation attempt against `record`.
///
/// The event is attributed to `AuditActor::Anonymous`, carries the client `ip`, and
/// uses `reason` both in the details and as the failure message. Errors from the audit
/// logger are deliberately ignored.
async fn audit_failed_auth(&self, record: &ApiKeyRecord, ip: std::net::IpAddr, reason: &str) {
    let details = serde_json::json!({
        "reason": reason,
        "key_prefix": record.key_prefix,
    });
    let outcome = AuditOutcome::Failure {
        error_code: "AUTH_FAILED".to_string(),
        error_message: reason.to_string(),
    };

    let _ = self
        .audit
        .log(
            AuditEvent::new(AuditEventType::ApiKeyValidationFailed, AuditActor::Anonymous)
                .with_tenant(record.tenant_id)
                .with_key(record.key_id)
                .with_details(details)
                .with_ip(ip)
                .with_outcome(outcome),
        )
        .await;
}
}
/// Compares two 32-byte digests without exiting early on the first difference.
///
/// Every byte pair is XOR-ed and OR-ed into one accumulator, so the loop does the same
/// work wherever the inputs differ. The function is kept out of line and a compiler
/// fence precedes the final test; both are intended to discourage the optimizer from
/// reintroducing a data-dependent shortcut.
#[inline(never)]
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let difference = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    difference == 0
}
/// Returns `len` bytes drawn from the operating system's random number generator.
///
/// # Errors
///
/// Returns [`KeyError::RandomGenerationFailed`] when the OS generator reports an error.
fn generate_secure_random(len: usize) -> Result<Vec<u8>, KeyError> {
    use rand::RngCore;

    let mut buffer = vec![0u8; len];
    match rand::rngs::OsRng.try_fill_bytes(&mut buffer) {
        Ok(()) => Ok(buffer),
        Err(_) => Err(KeyError::RandomGenerationFailed),
    }
}
fn zeroize_memory(data: &mut Vec<u8>) {
use zeroize::Zeroize;
data.zeroize();
}
/// Returns the SHA-256 digest of `key` as a fixed 32-byte array.
fn hash_key(key: &[u8]) -> [u8; 32] {
    // One-shot digest: same result as creating a hasher, feeding `key`, and finalizing.
    Sha256::digest(key).into()
}
mod hex_array {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(data: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
hex::encode(data).serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let bytes = hex::decode(&s).map_err(serde::de::Error::custom)?;
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(arr)
}
}
mod hex_nonce {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(data: &[u8; 12], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
hex::encode(data).serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 12], D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let bytes = hex::decode(&s).map_err(serde::de::Error::custom)?;
let mut arr = [0u8; 12];
arr.copy_from_slice(&bytes);
Ok(arr)
}
}
mod base64_vec {
use base64::{engine::general_purpose::STANDARD, Engine};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
STANDARD.encode(data).serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
STANDARD.decode(&s).map_err(serde::de::Error::custom)
}
}
/// `KeyStore` implementation that keeps all records in a process-local map.
///
/// Nothing is persisted: the contents live only as long as the value itself.
pub struct InMemoryKeyStore {
/// Records indexed by `key_id`, guarded by an async read-write lock.
keys: RwLock<HashMap<Uuid, ApiKeyRecord>>,
}
impl InMemoryKeyStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        let keys = RwLock::new(HashMap::new());
        Self { keys }
    }
}
impl Default for InMemoryKeyStore {
fn default() -> Self {
Self::new()
}
}
/// In-memory `KeyStore` operations.
///
/// The three update methods are silent no-ops (returning `Ok(())`) when `key_id` is
/// unknown, and none of the methods ever returns an error.
#[async_trait::async_trait]
impl KeyStore for InMemoryKeyStore {
    /// Inserts a clone of `record`, replacing any existing record with the same `key_id`.
    async fn store_key(&self, record: &ApiKeyRecord) -> Result<(), KeyError> {
        self.keys
            .write()
            .await
            .insert(record.key_id, record.clone());
        Ok(())
    }

    /// Returns clones of all records whose stored `key_prefix` starts with `prefix`.
    async fn find_keys_by_prefix(&self, prefix: &str) -> Result<Vec<ApiKeyRecord>, KeyError> {
        let guard = self.keys.read().await;
        let mut matches = Vec::new();
        for candidate in guard.values() {
            if candidate.key_prefix.starts_with(prefix) {
                matches.push(candidate.clone());
            }
        }
        Ok(matches)
    }

    /// Returns a clone of the record with `key_id`, or `None` when it is unknown.
    async fn get_key(&self, key_id: Uuid) -> Result<Option<ApiKeyRecord>, KeyError> {
        let guard = self.keys.read().await;
        let found = guard.get(&key_id).cloned();
        Ok(found)
    }

    /// Overwrites the status of the record with `key_id`, if it exists.
    async fn update_status(&self, key_id: Uuid, status: KeyStatus) -> Result<(), KeyError> {
        if let Some(entry) = self.keys.write().await.get_mut(&key_id) {
            entry.status = status;
        }
        Ok(())
    }

    /// Sets `last_used_at` of the record with `key_id` to the current time, if it exists.
    async fn update_last_used(&self, key_id: Uuid) -> Result<(), KeyError> {
        if let Some(entry) = self.keys.write().await.get_mut(&key_id) {
            entry.last_used_at = Some(Utc::now());
        }
        Ok(())
    }

    /// Returns clones of all records whose status is exactly `KeyStatus::Active`.
    async fn list_active_keys(&self) -> Result<Vec<ApiKeyRecord>, KeyError> {
        let guard = self.keys.read().await;
        let mut active = Vec::new();
        for candidate in guard.values() {
            if candidate.status == KeyStatus::Active {
                active.push(candidate.clone());
            }
        }
        Ok(active)
    }

    /// Replaces the rotation info of the record with `key_id`, if it exists.
    async fn set_rotation_info(&self, key_id: Uuid, info: KeyRotationInfo) -> Result<(), KeyError> {
        if let Some(entry) = self.keys.write().await.get_mut(&key_id) {
            entry.rotation = Some(info);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the free helper functions of this module.
    use super::*;

    /// Identical digests compare equal.
    #[test]
    fn test_constant_time_eq_equal() {
        let lhs = [1u8; 32];
        let rhs = lhs;
        assert!(constant_time_eq(&lhs, &rhs));
    }

    /// A difference in the final byte is still detected.
    #[test]
    fn test_constant_time_eq_unequal() {
        let lhs = [1u8; 32];
        let mut rhs = lhs;
        rhs[31] = 2;
        assert!(!constant_time_eq(&lhs, &rhs));
    }

    /// Hashing the same input twice yields the same digest.
    #[test]
    fn test_hash_key_deterministic() {
        let material = b"test_key_material_here";
        assert_eq!(hash_key(material), hash_key(material));
    }

    /// Different inputs yield different digests.
    #[test]
    fn test_hash_key_different_inputs() {
        let first = hash_key(b"test_key_1_material_");
        let second = hash_key(b"test_key_2_material_");
        assert_ne!(first, second);
    }

    /// The requested number of random bytes is returned.
    #[test]
    fn test_secure_random_length() {
        let generated = generate_secure_random(32).unwrap();
        assert_eq!(generated.len(), 32);
    }

    /// Two independent draws differ.
    #[test]
    fn test_secure_random_uniqueness() {
        let first = generate_secure_random(32).unwrap();
        let second = generate_secure_random(32).unwrap();
        assert_ne!(first, second);
    }
}