use chrono::{DateTime, Utc};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::RwLock;
/// One audit record describing a security-relevant action.
///
/// [`AuditLogger::log`] serializes it to JSON; the `Option` fields marked with
/// `skip_serializing_if = "Option::is_none"` are left out of the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct AuditEvent {
/// Creation time of the event; the `AuditLogger` helpers set it with `Utc::now()`.
pub timestamp: DateTime<Utc>,
/// Kind of action being audited.
pub event_type: AuditEventType,
/// Caller-supplied identifier of the request that produced the event.
pub request_id: String,
/// Client IP address, when provided and `AuditConfig::log_client_ip` allows it.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_ip: Option<String>,
/// Leading characters of the API key involved (the `AuditLogger` helpers keep at
/// most 8); `None` when no key applies or `AuditConfig::log_key_prefix` is off.
#[serde(skip_serializing_if = "Option::is_none")]
pub key_prefix: Option<String>,
/// Endpoint path the event relates to.
pub endpoint: String,
/// Whether the audited action succeeded.
pub success: bool,
/// Optional free-form context, such as a failure reason or a model id.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<String>,
}
/// Category of an [`AuditEvent`].
///
/// Serialized in snake_case (e.g. `AuthSuccess` -> `"auth_success"`), which are
/// the same strings returned by `AuditEventType::as_str`.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AuditEventType {
/// An API key authenticated successfully.
AuthSuccess,
/// Authentication failed; the reason is carried in `details`.
AuthFailure,
/// The presented API key has expired.
AuthExpired,
/// The presented API key is disabled.
AuthDisabled,
/// A scope check failed; the required scope is named in `details`.
ScopeViolation,
/// The request exceeded the rate limit.
RateLimited,
/// A model load was attempted (outcome in `success`).
ModelLoad,
/// A model unload was attempted (outcome in `success`).
ModelUnload,
/// An API key was created.
KeyCreated,
/// An API key was revoked.
KeyRevoked,
}
impl AuditEventType {
pub fn as_str(&self) -> &'static str {
match self {
Self::AuthSuccess => "auth_success",
Self::AuthFailure => "auth_failure",
Self::AuthExpired => "auth_expired",
Self::AuthDisabled => "auth_disabled",
Self::ScopeViolation => "scope_violation",
Self::RateLimited => "rate_limited",
Self::ModelLoad => "model_load",
Self::ModelUnload => "model_unload",
Self::KeyCreated => "key_created",
Self::KeyRevoked => "key_revoked",
}
}
}
/// Settings controlling what [`AuditLogger`] records.
#[derive(Debug, Clone)]
pub struct AuditConfig {
/// Master switch: when `false`, `AuditLogger::log` discards every event.
pub enabled: bool,
/// When `false`, the `AuditLogger` helpers record `client_ip` as `None`.
pub log_client_ip: bool,
/// When `false`, the `AuditLogger` helpers record `key_prefix` as `None`.
pub log_key_prefix: bool,
}
impl Default for AuditConfig {
fn default() -> Self {
Self {
enabled: true,
log_client_ip: true,
log_key_prefix: true,
}
}
}
impl AuditConfig {
    /// Auditing switched on; identical to [`AuditConfig::default`].
    pub fn enabled() -> Self {
        <Self as Default>::default()
    }

    /// Auditing switched off; the privacy flags keep their default values.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::enabled()
        }
    }
}
/// Emits [`AuditEvent`]s as `tracing` records on the `audit` target.
///
/// Clones share one configuration: the config lives behind an `Arc`, and
/// `Clone` copies only that pointer.
#[derive(Debug, Clone)]
pub struct AuditLogger {
// NOTE(review): only read (never written) by the methods visible here; the `RwLock`
// presumably enables runtime reconfiguration elsewhere. TODO confirm, or simplify
// to `Arc<AuditConfig>`.
config: Arc<RwLock<AuditConfig>>,
}
impl Default for AuditLogger {
fn default() -> Self {
Self::new(AuditConfig::default())
}
}
impl AuditLogger {
pub fn new(config: AuditConfig) -> Self {
Self {
config: Arc::new(RwLock::new(config)),
}
}
pub fn enabled() -> Self {
Self::new(AuditConfig::enabled())
}
pub fn disabled() -> Self {
Self::new(AuditConfig::disabled())
}
pub async fn log(&self, event: AuditEvent) {
let config = self.config.read().await;
if !config.enabled {
return;
}
let event_json = serde_json::to_string(&event).unwrap_or_else(|_| "{}".to_string());
tracing::info!(
target: "audit",
event_type = event.event_type.as_str(),
request_id = %event.request_id,
client_ip = ?event.client_ip,
key_prefix = ?event.key_prefix,
endpoint = %event.endpoint,
success = event.success,
details = ?event.details,
"{}",
event_json
);
}
pub async fn auth_success(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
endpoint: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::AuthSuccess,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: endpoint.to_string(),
success: true,
details: None,
})
.await;
}
pub async fn auth_failure(
&self,
request_id: &str,
client_ip: Option<&str>,
reason: &str,
endpoint: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::AuthFailure,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: None,
endpoint: endpoint.to_string(),
success: false,
details: Some(reason.to_string()),
})
.await;
}
pub async fn auth_expired(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
endpoint: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::AuthExpired,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: endpoint.to_string(),
success: false,
details: Some("API key has expired".to_string()),
})
.await;
}
pub async fn auth_disabled(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
endpoint: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::AuthDisabled,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: endpoint.to_string(),
success: false,
details: Some("API key is disabled".to_string()),
})
.await;
}
pub async fn scope_violation(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
endpoint: &str,
required_scope: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::ScopeViolation,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: endpoint.to_string(),
success: false,
details: Some(format!("Required scope: {required_scope}")),
})
.await;
}
pub async fn rate_limited(
&self,
request_id: &str,
client_ip: Option<&str>,
key: Option<&str>,
endpoint: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::RateLimited,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: match key {
Some(k) => self.maybe_key_prefix(k).await,
None => None,
},
endpoint: endpoint.to_string(),
success: false,
details: Some("Request rate limit exceeded".to_string()),
})
.await;
}
pub async fn model_load(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
model_id: &str,
success: bool,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::ModelLoad,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: "/api/models/load".to_string(),
success,
details: Some(format!("Model: {model_id}")),
})
.await;
}
pub async fn model_unload(
&self,
request_id: &str,
client_ip: Option<&str>,
key: &str,
model_id: &str,
success: bool,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::ModelUnload,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(key).await,
endpoint: "/api/models/unload".to_string(),
success,
details: Some(format!("Model: {model_id}")),
})
.await;
}
pub async fn key_created(
&self,
request_id: &str,
client_ip: Option<&str>,
admin_key: &str,
new_key_prefix: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::KeyCreated,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(admin_key).await,
endpoint: "/api/keys".to_string(),
success: true,
details: Some(format!("Created key: {new_key_prefix}...")),
})
.await;
}
pub async fn key_revoked(
&self,
request_id: &str,
client_ip: Option<&str>,
admin_key: &str,
revoked_key_prefix: &str,
) {
self.log(AuditEvent {
timestamp: Utc::now(),
event_type: AuditEventType::KeyRevoked,
request_id: request_id.to_string(),
client_ip: self.maybe_client_ip(client_ip).await,
key_prefix: self.maybe_key_prefix(admin_key).await,
endpoint: "/api/keys".to_string(),
success: true,
details: Some(format!("Revoked key: {revoked_key_prefix}...")),
})
.await;
}
async fn maybe_key_prefix(&self, key: &str) -> Option<String> {
let config = self.config.read().await;
if config.log_key_prefix {
let end = std::cmp::min(8, key.len());
Some(key[..end].to_string())
} else {
None
}
}
async fn maybe_client_ip(&self, ip: Option<&str>) -> Option<String> {
let config = self.config.read().await;
if config.log_client_ip {
ip.map(String::from)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for checks that must hold across the whole enum.
    const ALL_EVENT_TYPES: [AuditEventType; 10] = [
        AuditEventType::AuthSuccess,
        AuditEventType::AuthFailure,
        AuditEventType::AuthExpired,
        AuditEventType::AuthDisabled,
        AuditEventType::ScopeViolation,
        AuditEventType::RateLimited,
        AuditEventType::ModelLoad,
        AuditEventType::ModelUnload,
        AuditEventType::KeyCreated,
        AuditEventType::KeyRevoked,
    ];

    #[test]
    fn test_audit_event_type_as_str() {
        assert_eq!(AuditEventType::AuthSuccess.as_str(), "auth_success");
        assert_eq!(AuditEventType::AuthFailure.as_str(), "auth_failure");
        assert_eq!(AuditEventType::AuthExpired.as_str(), "auth_expired");
        assert_eq!(AuditEventType::AuthDisabled.as_str(), "auth_disabled");
        assert_eq!(AuditEventType::ScopeViolation.as_str(), "scope_violation");
        assert_eq!(AuditEventType::RateLimited.as_str(), "rate_limited");
        assert_eq!(AuditEventType::ModelLoad.as_str(), "model_load");
        assert_eq!(AuditEventType::ModelUnload.as_str(), "model_unload");
        assert_eq!(AuditEventType::KeyCreated.as_str(), "key_created");
        assert_eq!(AuditEventType::KeyRevoked.as_str(), "key_revoked");
    }

    #[test]
    fn test_audit_event_type_as_str_matches_serde_name() {
        // The structured `event_type` log field uses `as_str` while the JSON message
        // uses serde; the two must never drift apart.
        for event_type in ALL_EVENT_TYPES.iter().copied() {
            let value = serde_json::to_value(event_type).expect("Should serialize");
            assert_eq!(value.as_str(), Some(event_type.as_str()));
        }
    }

    #[test]
    fn test_audit_config_default() {
        let config = AuditConfig::default();
        assert!(config.enabled);
        assert!(config.log_client_ip);
        assert!(config.log_key_prefix);
    }

    #[test]
    fn test_audit_config_disabled() {
        let config = AuditConfig::disabled();
        assert!(!config.enabled);
        // Only the master switch changes; the privacy flags keep their defaults.
        assert!(config.log_client_ip);
        assert!(config.log_key_prefix);
    }

    #[test]
    fn test_audit_event_serialization() {
        let event = AuditEvent {
            timestamp: Utc::now(),
            event_type: AuditEventType::AuthSuccess,
            request_id: "req-123".to_string(),
            client_ip: Some("192.168.1.1".to_string()),
            key_prefix: Some("sk-inf-a".to_string()),
            endpoint: "/v1/chat/completions".to_string(),
            success: true,
            details: None,
        };
        let json = serde_json::to_value(&event).expect("Should serialize");
        assert_eq!(json["event_type"], "auth_success");
        assert_eq!(json["request_id"], "req-123");
        assert_eq!(json["client_ip"], "192.168.1.1");
        assert_eq!(json["key_prefix"], "sk-inf-a");
        assert_eq!(json["endpoint"], "/v1/chat/completions");
        assert_eq!(json["success"], true);
        // `skip_serializing_if` must drop `None` fields entirely rather than emit `null`.
        assert!(json.get("details").is_none());
    }

    #[test]
    fn test_audit_event_serialization_omits_none_fields() {
        let event = AuditEvent {
            timestamp: Utc::now(),
            event_type: AuditEventType::AuthFailure,
            request_id: "req-456".to_string(),
            client_ip: None,
            key_prefix: None,
            endpoint: "/v1/models".to_string(),
            success: false,
            details: Some("bad key".to_string()),
        };
        let json = serde_json::to_value(&event).expect("Should serialize");
        assert!(json.get("client_ip").is_none());
        assert!(json.get("key_prefix").is_none());
        assert_eq!(json["details"], "bad key");
        assert_eq!(json["success"], false);
    }

    #[tokio::test]
    async fn test_audit_logger_disabled() {
        // Smoke test: a disabled logger must accept events without panicking.
        let logger = AuditLogger::disabled();
        logger
            .auth_success("req-1", Some("127.0.0.1"), "sk-test", "/test")
            .await;
        logger
            .auth_failure("req-2", Some("127.0.0.1"), "bad key", "/test")
            .await;
    }

    #[tokio::test]
    async fn test_audit_logger_key_prefix() {
        let logger = AuditLogger::enabled();
        let prefix = logger.maybe_key_prefix("sk-inf-abcdefghijklmnop").await;
        assert_eq!(prefix, Some("sk-inf-a".to_string()));
    }

    #[tokio::test]
    async fn test_audit_logger_short_key() {
        let logger = AuditLogger::enabled();
        let prefix = logger.maybe_key_prefix("short").await;
        assert_eq!(prefix, Some("short".to_string()));
    }

    #[tokio::test]
    async fn test_audit_logger_client_ip_passthrough() {
        let logger = AuditLogger::enabled();
        assert_eq!(
            logger.maybe_client_ip(Some("10.0.0.1")).await,
            Some("10.0.0.1".to_string())
        );
        assert_eq!(logger.maybe_client_ip(None).await, None);
    }

    #[tokio::test]
    async fn test_audit_logger_respects_privacy_flags() {
        let logger = AuditLogger::new(AuditConfig {
            enabled: true,
            log_client_ip: false,
            log_key_prefix: false,
        });
        assert_eq!(logger.maybe_client_ip(Some("10.0.0.1")).await, None);
        assert_eq!(logger.maybe_key_prefix("sk-inf-abcdefgh").await, None);
    }
}