use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
/// Size limits for audit entries and for an in-memory audit buffer.
///
/// The `*_LEN` limits appear to be byte lengths (the unit tests compare them
/// against `str::len`). NOTE(review): nothing in this module truncates or
/// rejects `AuditEntry` fields that exceed these limits — confirm where, or
/// whether, they are enforced.
pub mod bounds {
    /// Maximum length of `AuditEntry::subject`.
    pub const MAX_SUBJECT_LEN: usize = 256;
    /// Maximum length of `AuditEntry::operation`.
    pub const MAX_OPERATION_LEN: usize = 50;
    /// Maximum length of `AuditEntry::error_message` (1 KiB).
    pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
    /// Maximum length of `AuditEntry::context` (2 KiB).
    pub const MAX_CONTEXT_LEN: usize = 2 * 1024;
    /// Intended cap on buffered entries. Per the name only; no buffer using it is visible here.
    pub const MAX_ENTRIES_IN_MEMORY: usize = 10_000;
    /// Estimated memory per buffered entry (4 KiB). Multiplied by
    /// `MAX_ENTRIES_IN_MEMORY` to estimate total buffer memory.
    pub const BYTES_PER_ENTRY: usize = 4 * 1024;
}
/// Kind of security-relevant event recorded in an [`AuditEntry`].
///
/// `as_str` returns the snake_case label that `StructuredAuditLogger` emits as the
/// `event_type` field. Serde (de)serializes these variants by their Rust names
/// (e.g. `"JwtValidation"`), not by the `as_str` labels, because no
/// `#[serde(rename_all = ...)]` attribute is applied.
///
/// Variant order determines the implicit discriminants; append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuditEventType {
// JWT events
JwtValidation,
JwtRefresh,
// OIDC events
OidcCredentialAccess,
OidcTokenExchange,
// Session-token events
SessionTokenCreated,
SessionTokenValidation,
SessionTokenRevoked,
// CSRF state events
CsrfStateGenerated,
CsrfStateValidated,
// OAuth flow events
OauthStart,
OauthCallback,
// Authentication outcome events
AuthSuccess,
AuthFailure,
}
impl AuditEventType {
pub fn as_str(&self) -> &'static str {
match self {
AuditEventType::JwtValidation => "jwt_validation",
AuditEventType::JwtRefresh => "jwt_refresh",
AuditEventType::OidcCredentialAccess => "oidc_credential_access",
AuditEventType::OidcTokenExchange => "oidc_token_exchange",
AuditEventType::SessionTokenCreated => "session_token_created",
AuditEventType::SessionTokenValidation => "session_token_validation",
AuditEventType::SessionTokenRevoked => "session_token_revoked",
AuditEventType::CsrfStateGenerated => "csrf_state_generated",
AuditEventType::CsrfStateValidated => "csrf_state_validated",
AuditEventType::OauthStart => "oauth_start",
AuditEventType::OauthCallback => "oauth_callback",
AuditEventType::AuthSuccess => "auth_success",
AuditEventType::AuthFailure => "auth_failure",
}
}
}
/// Kind of secret material an audited event concerns.
///
/// `as_str` returns the snake_case label that `StructuredAuditLogger` emits as the
/// `secret_type` field. Serde (de)serializes these variants by their Rust names
/// (e.g. `"JwtToken"`), not by the `as_str` labels, because no
/// `#[serde(rename_all = ...)]` attribute is applied.
///
/// Variant order determines the implicit discriminants; append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SecretType {
JwtToken,
SessionToken,
ClientSecret,
RefreshToken,
AuthorizationCode,
StateToken,
CsrfToken,
}
impl SecretType {
pub fn as_str(&self) -> &'static str {
match self {
SecretType::JwtToken => "jwt_token",
SecretType::SessionToken => "session_token",
SecretType::ClientSecret => "client_secret",
SecretType::RefreshToken => "refresh_token",
SecretType::AuthorizationCode => "authorization_code",
SecretType::StateToken => "state_token",
SecretType::CsrfToken => "csrf_token",
}
}
}
/// One audit record handed to an [`AuditLogger`].
///
/// Field lengths are not checked here. The `bounds` module defines the intended
/// limits. NOTE(review): confirm where those limits are enforced.
///
/// Field order is also the serde serialization order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
/// What kind of security event occurred.
pub event_type: AuditEventType,
/// Which kind of secret the event concerns.
pub secret_type: SecretType,
/// Identity the event relates to, if known. Caller-defined; tests use a user id.
pub subject: Option<String>,
/// Short name of the operation performed, e.g. `"validate"`.
pub operation: String,
/// Whether the operation succeeded. `StructuredAuditLogger` logs failures at warn level.
pub success: bool,
/// Error description. Set by `log_failure`; `None` from `log_success`.
pub error_message: Option<String>,
/// Optional free-form detail. The default trait helpers always leave it `None`.
pub context: Option<String>,
}
/// Sink for security audit events.
///
/// Implementors provide only [`AuditLogger::log_entry`]. The `log_success` and
/// `log_failure` helpers build an [`AuditEntry`] (with `context: None`) and
/// forward it there. `Send + Sync` lets one logger be shared across threads
/// behind an `Arc`.
pub trait AuditLogger: Send + Sync {
    /// Records one fully populated audit entry.
    fn log_entry(&self, entry: AuditEntry);

    /// Records a successful `operation`.
    ///
    /// The resulting entry has `success: true`, `error_message: None`, and `context: None`.
    fn log_success(
        &self,
        event_type: AuditEventType,
        secret_type: SecretType,
        subject: Option<String>,
        operation: &str,
    ) {
        let record = AuditEntry {
            event_type,
            secret_type,
            subject,
            operation: String::from(operation),
            success: true,
            error_message: None,
            context: None,
        };
        self.log_entry(record);
    }

    /// Records a failed `operation`, keeping `error` as the entry's error message.
    ///
    /// The resulting entry has `success: false` and `context: None`.
    fn log_failure(
        &self,
        event_type: AuditEventType,
        secret_type: SecretType,
        subject: Option<String>,
        operation: &str,
        error: &str,
    ) {
        let record = AuditEntry {
            event_type,
            secret_type,
            subject,
            operation: operation.to_owned(),
            success: false,
            error_message: Some(error.to_owned()),
            context: None,
        };
        self.log_entry(record);
    }
}
/// Default [`AuditLogger`], which emits each entry as a `tracing` event.
///
/// It carries no state (a unit struct), so constructing it costs nothing.
#[derive(Default)]
pub struct StructuredAuditLogger;

impl StructuredAuditLogger {
    /// Creates the logger; equivalent to `StructuredAuditLogger::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl AuditLogger for StructuredAuditLogger {
    /// Emits `entry` as a structured `tracing` event.
    ///
    /// Failures go out at `warn` level and include the `error` field; successes
    /// go out at `info` level without it. Both emit `event_type`, `secret_type`,
    /// `subject`, `operation`, and `context`.
    fn log_entry(&self, entry: AuditEntry) {
        let event_type = entry.event_type.as_str();
        let secret_type = entry.secret_type.as_str();

        // `tracing` fixes the level per call site at compile time, so each outcome
        // needs its own macro invocation rather than a computed level.
        if !entry.success {
            warn!(
                event_type = event_type,
                secret_type = secret_type,
                subject = ?entry.subject,
                operation = entry.operation,
                error = ?entry.error_message,
                context = ?entry.context,
                "Security event: failed operation"
            );
            return;
        }

        info!(
            event_type = event_type,
            secret_type = secret_type,
            subject = ?entry.subject,
            operation = entry.operation,
            context = ?entry.context,
            "Security event: successful operation"
        );
    }
}
/// Process-wide audit logger slot.
///
/// Filled either explicitly by `init_audit_logger` or lazily with a
/// `StructuredAuditLogger` the first time `get_audit_logger` runs. Once filled,
/// a `OnceLock` cannot be overwritten, so whichever happens first wins for the
/// life of the process.
pub static AUDIT_LOGGER: std::sync::OnceLock<Arc<dyn AuditLogger>> = std::sync::OnceLock::new();
pub fn init_audit_logger(logger: Arc<dyn AuditLogger>) {
let _ = AUDIT_LOGGER.set(logger);
}
/// Returns a shared handle to the process-wide audit logger.
///
/// If no logger has been installed yet, installs and returns a
/// `StructuredAuditLogger`. After that, `init_audit_logger` can no longer
/// replace it. Each call clones the `Arc`, which only bumps a reference count.
pub fn get_audit_logger() -> Arc<dyn AuditLogger> {
    let installed = AUDIT_LOGGER.get_or_init(|| Arc::new(StructuredAuditLogger::new()));
    Arc::clone(installed)
}
/// Extension for recording the outcome of a fallible operation in the audit log.
pub trait AuditableResult<T, E> {
/// Logs this result through the global audit logger and returns it unchanged.
///
/// `Ok` is recorded as a success. `Err` is recorded as a failure, with the
/// error's text as the message. `subject` and `operation` are passed through
/// to the entry.
fn audit_log(
self,
event_type: AuditEventType,
secret_type: SecretType,
subject: Option<String>,
operation: &str,
) -> Result<T, E>;
}
impl<T, E> AuditableResult<T, E> for Result<T, E>
where
    E: std::fmt::Display,
{
    /// Records `Ok` as a success and `Err` (via its `Display` text) as a failure,
    /// then hands the original result back so it can be chained with `?`.
    fn audit_log(
        self,
        event_type: AuditEventType,
        secret_type: SecretType,
        subject: Option<String>,
        operation: &str,
    ) -> Result<T, E> {
        let sink = get_audit_logger();
        match self.as_ref().err() {
            None => sink.log_success(event_type, secret_type, subject, operation),
            Some(failure) => {
                let message = failure.to_string();
                sink.log_failure(event_type, secret_type, subject, operation, &message);
            }
        }
        self
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    // The glob import brings in everything from the parent module, including the
    // `bounds` submodule and the parent's `Arc` import. Tests therefore never need
    // a crate-absolute path and keep working if this module is renamed or moved.
    use super::*;

    /// `AuditLogger` that keeps every entry in memory so tests can inspect what was logged.
    struct TestAuditLogger {
        entries: Mutex<Vec<AuditEntry>>,
    }

    impl TestAuditLogger {
        fn new() -> Self {
            Self {
                entries: Mutex::new(Vec::new()),
            }
        }

        /// Returns a snapshot of all entries logged so far.
        fn get_entries(&self) -> Vec<AuditEntry> {
            self.entries.lock().unwrap().clone()
        }
    }

    impl AuditLogger for TestAuditLogger {
        fn log_entry(&self, entry: AuditEntry) {
            self.entries.lock().unwrap().push(entry);
        }
    }

    #[test]
    fn test_audit_entry_creation() {
        let entry = AuditEntry {
            event_type: AuditEventType::JwtValidation,
            secret_type: SecretType::JwtToken,
            subject: Some("user123".to_string()),
            operation: "validate".to_string(),
            success: true,
            error_message: None,
            context: None,
        };
        assert_eq!(entry.event_type, AuditEventType::JwtValidation);
        assert_eq!(entry.subject, Some("user123".to_string()));
        assert!(entry.success);
    }

    #[test]
    fn test_audit_logger_logs_entry() {
        let logger = TestAuditLogger::new();
        logger.log_success(
            AuditEventType::JwtValidation,
            SecretType::JwtToken,
            Some("user123".to_string()),
            "validate",
        );
        let entries = logger.get_entries();
        assert_eq!(entries.len(), 1);
        assert!(entries[0].success);
    }

    #[test]
    fn test_audit_logger_logs_failure() {
        let logger = TestAuditLogger::new();
        logger.log_failure(
            AuditEventType::JwtValidation,
            SecretType::JwtToken,
            Some("user123".to_string()),
            "validate",
            "Invalid signature",
        );
        let entries = logger.get_entries();
        assert_eq!(entries.len(), 1);
        assert!(!entries[0].success);
        assert_eq!(entries[0].error_message, Some("Invalid signature".to_string()));
    }

    #[test]
    fn test_event_type_strings() {
        assert_eq!(AuditEventType::JwtValidation.as_str(), "jwt_validation");
        assert_eq!(AuditEventType::OidcTokenExchange.as_str(), "oidc_token_exchange");
    }

    #[test]
    fn test_secret_type_strings() {
        assert_eq!(SecretType::JwtToken.as_str(), "jwt_token");
        assert_eq!(SecretType::ClientSecret.as_str(), "client_secret");
    }

    #[test]
    fn test_bounds_constants_are_reasonable() {
        let max_subject = bounds::MAX_SUBJECT_LEN;
        assert!(max_subject >= 128, "Subject length too small");
        let max_operation = bounds::MAX_OPERATION_LEN;
        assert!(max_operation >= 20, "Operation length too small");
        let max_error = bounds::MAX_ERROR_MESSAGE_LEN;
        assert!(max_error >= 512, "Error message length too small");
        let max_context = bounds::MAX_CONTEXT_LEN;
        assert!(max_context >= 1024, "Context length too small");
        let max_entries = bounds::MAX_ENTRIES_IN_MEMORY;
        assert!(max_entries >= 1000, "Max entries in memory too small");
        assert!(max_entries <= 100_000, "Max entries in memory too large");
    }

    #[test]
    fn test_bounds_constants_match_documentation() {
        assert_eq!(bounds::MAX_SUBJECT_LEN, 256, "Subject length bound mismatch");
        assert_eq!(bounds::MAX_OPERATION_LEN, 50, "Operation length bound mismatch");
        assert_eq!(bounds::MAX_ERROR_MESSAGE_LEN, 1024, "Error message length bound mismatch");
        assert_eq!(bounds::MAX_CONTEXT_LEN, 2048, "Context length bound mismatch");
        assert_eq!(bounds::MAX_ENTRIES_IN_MEMORY, 10_000, "Max entries in memory bound mismatch");
    }

    #[test]
    fn test_memory_per_entry_constant_is_reasonable() {
        let bytes_per_entry = bounds::BYTES_PER_ENTRY;
        let max_entries = bounds::MAX_ENTRIES_IN_MEMORY;
        let total_memory_mb = (bytes_per_entry * max_entries) / (1024 * 1024);
        assert!(
            total_memory_mb < 100,
            "Total memory for full buffer too large: {} MB",
            total_memory_mb
        );
        assert!(
            total_memory_mb > 10,
            "Total memory for full buffer too small: {} MB",
            total_memory_mb
        );
    }

    #[test]
    fn test_audit_entry_field_sizes_within_bounds() {
        let entry = AuditEntry {
            event_type: AuditEventType::JwtValidation,
            secret_type: SecretType::JwtToken,
            subject: Some("a".repeat(bounds::MAX_SUBJECT_LEN)),
            operation: "validate".to_string(),
            success: true,
            error_message: None,
            context: None,
        };
        assert!(entry.subject.as_ref().unwrap().len() <= bounds::MAX_SUBJECT_LEN);
    }

    #[test]
    fn test_error_message_bound_accommodates_typical_errors() {
        let error_messages = [
            "Invalid signature",
            "Token expired",
            "User not authorized",
            "Failed to decrypt payload: AES-256-GCM decryption returned InvalidTag",
            "Database connection timeout after 30 seconds waiting for available connection",
        ];
        for msg in error_messages {
            assert!(
                msg.len() <= bounds::MAX_ERROR_MESSAGE_LEN,
                "Error message too long: {} bytes for: {}",
                msg.len(),
                msg
            );
        }
    }

    #[test]
    fn test_operation_bound_covers_all_audit_operations() {
        let operations = ["validate", "create", "revoke", "refresh", "exchange", "logout"];
        for op in operations {
            assert!(
                op.len() <= bounds::MAX_OPERATION_LEN,
                "Operation name too long: {} bytes for: {}",
                op.len(),
                op
            );
        }
    }

    #[test]
    fn test_global_audit_logger_is_singleton() {
        let logger1 = get_audit_logger();
        let logger2 = get_audit_logger();
        // `Arc::ptr_eq` compares only the data address. Comparing `Arc::as_ptr` on a
        // `dyn` pointer would also compare vtable pointers (a wide-pointer comparison).
        assert!(
            Arc::ptr_eq(&logger1, &logger2),
            "Audit loggers are not the same singleton instance"
        );
    }

    #[test]
    fn test_audit_log_passes_result_through() {
        let ok: Result<u32, String> = Ok(7);
        assert_eq!(
            ok.audit_log(AuditEventType::JwtValidation, SecretType::JwtToken, None, "validate"),
            Ok(7)
        );
        let err: Result<u32, String> = Err("Invalid signature".to_string());
        assert_eq!(
            err.audit_log(AuditEventType::JwtValidation, SecretType::JwtToken, None, "validate"),
            Err("Invalid signature".to_string())
        );
    }

    #[test]
    fn test_audit_entry_sizes_reasonable_for_serialization() {
        let max_entry = AuditEntry {
            event_type: AuditEventType::JwtValidation,
            secret_type: SecretType::JwtToken,
            subject: Some("a".repeat(bounds::MAX_SUBJECT_LEN)),
            operation: "validate".to_string(),
            success: false,
            error_message: Some("e".repeat(bounds::MAX_ERROR_MESSAGE_LEN)),
            context: Some("c".repeat(bounds::MAX_CONTEXT_LEN)),
        };
        let json = serde_json::to_string(&max_entry);
        assert!(json.is_ok(), "Failed to serialize maximum-size entry");
        let json_size = json.unwrap().len();
        assert!(
            json_size < bounds::BYTES_PER_ENTRY * 2,
            "JSON serialization too large: {} bytes",
            json_size
        );
    }
}