use serde::Serialize;
use std::collections::HashMap;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
use tracing::{Level, debug, error, info, warn};
use uuid::Uuid;
/// One structured log record, as queued by `AsyncLogger` and later re-emitted
/// through `tracing`.
///
/// Field declaration order is also the order `serde` serializes the fields in,
/// so reordering them changes the serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct LogEntry {
    /// UTC time at which the record was built.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Level name as rendered by `tracing::Level`'s `Display` impl (e.g. `"INFO"`).
    pub level: String,
    /// Logical logger/category name supplied by the caller (e.g. `"security"`).
    pub logger: String,
    /// Message text; the producing logger may have truncated it.
    pub message: String,
    /// Extra structured key/value fields supplied by the caller.
    pub fields: HashMap<String, serde_json::Value>,
    /// Request identifier supplied by the caller, if any.
    pub request_id: Option<String>,
    /// User identifier supplied by the caller, if any.
    pub user_id: Option<Uuid>,
    /// Trace identifier; `AsyncLogger` currently always leaves this `None`.
    pub trace_id: Option<String>,
}
/// Tuning knobs for `AsyncLogger`.
///
/// NOTE(review): `AsyncLogger` uses an unbounded channel and never reads
/// `buffer_size` or `drop_on_overflow`; they are currently inert settings.
#[derive(Debug, Clone)]
#[allow(dead_code)] // several fields are configured but not yet consumed
pub struct AsyncLoggerConfig {
    /// Requested queue capacity (not currently enforced — see the type-level note).
    pub buffer_size: usize,
    /// Whether to drop records on overflow (not currently enforced — see the type-level note).
    pub drop_on_overflow: bool,
    /// Fraction of log calls to keep; `1.0` keeps every call.
    pub sample_rate: f64,
    /// Maximum message length, in bytes, before the message is truncated.
    pub max_message_length: usize,
}

impl Default for AsyncLoggerConfig {
    /// Keep every message, truncate past 1024 bytes, `buffer_size` 10 000,
    /// and `drop_on_overflow` disabled.
    fn default() -> Self {
        let max_message_length = 1024;
        let sample_rate = 1.0;
        Self {
            max_message_length,
            sample_rate,
            drop_on_overflow: false,
            buffer_size: 10_000,
        }
    }
}
#[allow(dead_code)]
pub struct AsyncLogger {
sender: mpsc::UnboundedSender<LogEntry>,
config: AsyncLoggerConfig,
sample_counter: AtomicU64,
}
#[allow(dead_code)]
impl AsyncLogger {
pub fn new(config: AsyncLoggerConfig) -> Self {
let (sender, mut receiver) = mpsc::unbounded_channel::<LogEntry>();
tokio::spawn(async move {
while let Some(entry) = receiver.recv().await {
Self::process_log_entry(entry).await;
}
});
Self {
sender,
config,
sample_counter: AtomicU64::new(0),
}
}
pub fn log_structured(
&self,
level: Level,
logger: &str,
message: &str,
fields: HashMap<String, serde_json::Value>,
request_id: Option<String>,
user_id: Option<Uuid>,
) {
if self.config.sample_rate < 1.0 {
let counter = self.sample_counter.fetch_add(1, Ordering::Relaxed);
let sample_threshold = (u64::MAX as f64 * self.config.sample_rate) as u64;
if counter % (u64::MAX / sample_threshold.max(1)) != 0 {
return;
}
}
let truncated_message = if message.len() > self.config.max_message_length {
format!("{}...", &message[..self.config.max_message_length - 3])
} else {
message.to_string()
};
let entry = LogEntry {
timestamp: chrono::Utc::now(),
level: level.to_string(),
logger: logger.to_string(),
message: truncated_message,
fields,
request_id,
user_id,
trace_id: Self::current_trace_id(),
};
if let Err(_) = self.sender.send(entry) {
eprintln!("Async logger channel closed, falling back to sync logging");
}
}
pub fn log(&self, level: Level, logger: &str, message: &str) {
self.log_structured(level, logger, message, HashMap::new(), None, None);
}
pub fn log_with_context(
&self,
level: Level,
logger: &str,
message: &str,
request_id: Option<String>,
user_id: Option<Uuid>,
) {
self.log_structured(level, logger, message, HashMap::new(), request_id, user_id);
}
async fn process_log_entry(entry: LogEntry) {
let level = match entry.level.as_str() {
"ERROR" => Level::ERROR,
"WARN" => Level::WARN,
"INFO" => Level::INFO,
"DEBUG" => Level::DEBUG,
_ => Level::INFO,
};
match level {
Level::ERROR => error!(
logger = entry.logger,
request_id = entry.request_id,
user_id = ?entry.user_id,
trace_id = entry.trace_id,
fields = ?entry.fields,
"{}",
entry.message
),
Level::WARN => warn!(
logger = entry.logger,
request_id = entry.request_id,
user_id = ?entry.user_id,
trace_id = entry.trace_id,
fields = ?entry.fields,
"{}",
entry.message
),
Level::INFO => info!(
logger = entry.logger,
request_id = entry.request_id,
user_id = ?entry.user_id,
trace_id = entry.trace_id,
fields = ?entry.fields,
"{}",
entry.message
),
Level::DEBUG => debug!(
logger = entry.logger,
request_id = entry.request_id,
user_id = ?entry.user_id,
trace_id = entry.trace_id,
fields = ?entry.fields,
"{}",
entry.message
),
_ => info!(
logger = entry.logger,
request_id = entry.request_id,
user_id = ?entry.user_id,
trace_id = entry.trace_id,
fields = ?entry.fields,
"{}",
entry.message
),
}
}
fn current_trace_id() -> Option<String> {
None
}
}
/// Process-wide logger slot, filled at most once by `init_async_logger`.
#[allow(dead_code)]
static ASYNC_LOGGER: OnceLock<AsyncLogger> = OnceLock::new();

/// Installs the global `AsyncLogger` built from `config` on the first call.
///
/// Later calls leave the existing logger in place and discard their `config`.
/// Building the logger calls `tokio::spawn`, so the first call must happen
/// inside a tokio runtime.
#[allow(dead_code)]
pub fn init_async_logger(config: AsyncLoggerConfig) {
    let _ = ASYNC_LOGGER.get_or_init(move || AsyncLogger::new(config));
}

/// Returns the global logger, or `None` if `init_async_logger` has not run yet.
#[allow(dead_code)]
pub fn async_logger() -> Option<&'static AsyncLogger> {
    let slot: &'static OnceLock<AsyncLogger> = &ASYNC_LOGGER;
    slot.get()
}
/// Deterministic per-category log sampler.
///
/// Each configured category has a keep-rate in `[0.0, 1.0]` and its own call
/// counter. Categories with no configured rate are always logged.
#[allow(dead_code)]
pub struct LogSampler {
    /// Keep-rate per category, already clamped to `[0.0, 1.0]`.
    sample_rates: HashMap<String, f64>,
    /// Number of sampled calls seen per category.
    counters: HashMap<String, AtomicU64>,
}

#[allow(dead_code)]
impl LogSampler {
    /// Creates a sampler with no configured categories (everything is logged).
    pub fn new() -> Self {
        Self {
            sample_rates: HashMap::new(),
            counters: HashMap::new(),
        }
    }

    /// Sets the keep-rate for `category` and resets its counter.
    ///
    /// `rate` is clamped to `[0.0, 1.0]`. NaN is treated as `1.0` (keep
    /// everything), because `f64::clamp` would otherwise pass NaN through.
    pub fn set_sample_rate(&mut self, category: &str, rate: f64) {
        let rate = if rate.is_nan() { 1.0 } else { rate.clamp(0.0, 1.0) };
        self.sample_rates.insert(category.to_string(), rate);
        self.counters.insert(category.to_string(), AtomicU64::new(0));
    }

    /// Returns whether the current call for `category` should be logged.
    ///
    /// Unconfigured categories and rates `>= 1.0` always log, and rates
    /// `<= 0.0` never log. Otherwise call number `n` is kept exactly when
    /// `ceil(n * rate) < ceil((n + 1) * rate)`. This keeps the first call and
    /// keeps exactly `ceil(N * rate)` of the first `N` calls. (Truncating
    /// `1 / rate` to an integer interval would keep every call for any rate
    /// above 0.5.)
    pub fn should_log(&self, category: &str) -> bool {
        let Some(&rate) = self.sample_rates.get(category) else {
            return true;
        };
        if rate.is_nan() || rate >= 1.0 {
            return true;
        }
        if rate <= 0.0 {
            return false;
        }
        let Some(counter) = self.counters.get(category) else {
            return true;
        };
        // Float arithmetic (not `n + 1` on u64) so a wrapped counter cannot overflow.
        let n = counter.fetch_add(1, Ordering::Relaxed) as f64;
        (n * rate).ceil() < ((n + 1.0) * rate).ceil()
    }
}
/// Helpers that record security-relevant events under the `"security"` logger name.
///
/// Every method is a no-op when the global async logger has not been
/// initialized (`async_logger()` returns `None`).
#[allow(dead_code)]
pub struct SecurityLogger;

#[allow(dead_code)]
impl SecurityLogger {
    /// Records an authentication attempt.
    ///
    /// Logged at INFO on success and WARN on failure. `ip_address` and
    /// `details` are stored as given. `user_agent` is limited to its first 200
    /// characters.
    pub fn log_auth_event(
        event_type: &str,
        user_id: Option<Uuid>,
        ip_address: Option<&str>,
        user_agent: Option<&str>,
        success: bool,
        details: Option<&str>,
    ) {
        let mut fields = HashMap::new();
        fields.insert(
            "event_type".to_string(),
            serde_json::Value::String(event_type.to_string()),
        );
        fields.insert("success".to_string(), serde_json::Value::Bool(success));
        if let Some(ip) = ip_address {
            fields.insert(
                "ip_address".to_string(),
                serde_json::Value::String(ip.to_string()),
            );
        }
        if let Some(ua) = user_agent {
            // Bound the size of an arbitrary client-controlled header.
            let safe_ua = ua.chars().take(200).collect::<String>();
            fields.insert("user_agent".to_string(), serde_json::Value::String(safe_ua));
        }
        if let Some(details) = details {
            fields.insert(
                "details".to_string(),
                serde_json::Value::String(details.to_string()),
            );
        }
        let level = if success { Level::INFO } else { Level::WARN };
        let message = format!(
            "Authentication {}: {}",
            if success { "success" } else { "failure" },
            event_type
        );
        if let Some(logger) = async_logger() {
            logger.log_structured(level, "security", &message, fields, None, user_id);
        }
    }

    /// Records an authorization decision for `user_id` performing `action` on `resource`.
    ///
    /// Granted decisions log at DEBUG and denials at WARN.
    pub fn log_authz_event(
        user_id: Uuid,
        resource: &str,
        action: &str,
        granted: bool,
        reason: Option<&str>,
    ) {
        let mut fields = HashMap::new();
        fields.insert(
            "resource".to_string(),
            serde_json::Value::String(resource.to_string()),
        );
        fields.insert(
            "action".to_string(),
            serde_json::Value::String(action.to_string()),
        );
        fields.insert("granted".to_string(), serde_json::Value::Bool(granted));
        if let Some(reason) = reason {
            fields.insert(
                "reason".to_string(),
                serde_json::Value::String(reason.to_string()),
            );
        }
        let level = if granted { Level::DEBUG } else { Level::WARN };
        let message = format!(
            "Authorization {}: {} on {}",
            if granted { "granted" } else { "denied" },
            action,
            resource
        );
        if let Some(logger) = async_logger() {
            logger.log_structured(level, "security", &message, fields, None, Some(user_id));
        }
    }

    /// Records a security policy violation, using `description` as the message.
    ///
    /// `severity` is matched case-insensitively:
    /// * `"critical"` or `"high"` log at ERROR;
    /// * `"medium"` logs at WARN;
    /// * anything else logs at INFO.
    ///
    /// Entries in `additional_data` become extra fields. They cannot overwrite
    /// `violation_type`, `severity`, or (when provided) `ip_address`, which
    /// always reflect the explicit arguments.
    pub fn log_security_violation(
        violation_type: &str,
        severity: &str,
        description: &str,
        user_id: Option<Uuid>,
        ip_address: Option<&str>,
        additional_data: Option<HashMap<String, serde_json::Value>>,
    ) {
        // Start from the caller's extras so the authoritative fields inserted
        // below win any key collision, rather than being masked by it.
        let mut fields = additional_data.unwrap_or_default();
        fields.insert(
            "violation_type".to_string(),
            serde_json::Value::String(violation_type.to_string()),
        );
        fields.insert(
            "severity".to_string(),
            serde_json::Value::String(severity.to_string()),
        );
        if let Some(ip) = ip_address {
            fields.insert(
                "ip_address".to_string(),
                serde_json::Value::String(ip.to_string()),
            );
        }
        let level = match severity.to_lowercase().as_str() {
            "critical" | "high" => Level::ERROR,
            "medium" => Level::WARN,
            _ => Level::INFO,
        };
        if let Some(logger) = async_logger() {
            logger.log_structured(level, "security", description, fields, None, user_id);
        }
    }
}
#[allow(dead_code)]
pub struct PerformanceLogger;
#[allow(dead_code)]
impl PerformanceLogger {
pub fn log_request_metrics(
method: &str,
path: &str,
status_code: u16,
duration_ms: u64,
request_size: u64,
response_size: u64,
user_id: Option<Uuid>,
request_id: Option<String>,
) {
let mut fields = HashMap::new();
fields.insert(
"method".to_string(),
serde_json::Value::String(method.to_string()),
);
fields.insert(
"path".to_string(),
serde_json::Value::String(path.to_string()),
);
fields.insert(
"status_code".to_string(),
serde_json::Value::Number(status_code.into()),
);
fields.insert(
"duration_ms".to_string(),
serde_json::Value::Number(duration_ms.into()),
);
fields.insert(
"request_size".to_string(),
serde_json::Value::Number(request_size.into()),
);
fields.insert(
"response_size".to_string(),
serde_json::Value::Number(response_size.into()),
);
let message = format!("{} {} {} {}ms", method, path, status_code, duration_ms);
let level = if duration_ms > 5000 {
Level::WARN } else if duration_ms > 1000 {
Level::INFO } else {
Level::DEBUG };
if let Some(logger) = async_logger() {
logger.log_structured(level, "performance", &message, fields, request_id, user_id);
}
}
pub fn log_provider_metrics(
provider: &str,
model: &str,
duration_ms: u64,
token_count: Option<u32>,
success: bool,
error: Option<&str>,
) {
let mut fields = HashMap::new();
fields.insert(
"provider".to_string(),
serde_json::Value::String(provider.to_string()),
);
fields.insert(
"model".to_string(),
serde_json::Value::String(model.to_string()),
);
fields.insert(
"duration_ms".to_string(),
serde_json::Value::Number(duration_ms.into()),
);
fields.insert("success".to_string(), serde_json::Value::Bool(success));
if let Some(tokens) = token_count {
fields.insert(
"token_count".to_string(),
serde_json::Value::Number(tokens.into()),
);
}
if let Some(err) = error {
fields.insert(
"error".to_string(),
serde_json::Value::String(err.to_string()),
);
}
let level = if success { Level::DEBUG } else { Level::WARN };
let message = format!(
"Provider {} {} {}ms {}",
provider,
model,
duration_ms,
if success { "success" } else { "failed" }
);
if let Some(logger) = async_logger() {
logger.log_structured(level, "performance", &message, fields, None, None);
}
}
}
/// Logs a message through the global async logger with ad-hoc structured fields.
///
/// Usage: `log_structured!(level, "logger", "message", "key" => value, ...)`.
/// The field list may be empty, and a trailing comma is accepted. Each value
/// is converted with `serde_json::to_value`; values that fail to serialize are
/// recorded as `null`.
///
/// If the global logger is not initialized, the call is a no-op. In that case
/// none of the arguments are evaluated, including the field values.
///
/// NOTE(review): expands to `$crate::utils::logging::async_logger()`, so this
/// module is assumed to live at `crate::utils::logging`. Call sites also need
/// `serde_json` available as a dependency.
#[macro_export]
macro_rules! log_structured {
    ($level:expr, $logger:expr, $message:expr $(, $key:expr => $value:expr)* $(,)?) => {{
        if let ::std::option::Option::Some(logger) = $crate::utils::logging::async_logger() {
            let mut fields = ::std::collections::HashMap::new();
            $(
                fields.insert(
                    $key.to_string(),
                    serde_json::to_value($value).unwrap_or(serde_json::Value::Null),
                );
            )*
            logger.log_structured(
                $level,
                $logger,
                $message,
                fields,
                ::std::option::Option::None,
                ::std::option::Option::None,
            );
        }
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// At a 50% rate, roughly half of 1000 calls should be kept.
    #[test]
    fn test_log_sampler() {
        let mut sampler = LogSampler::new();
        sampler.set_sample_rate("test", 0.5);
        let sampled_count = (0..1000).filter(|_| sampler.should_log("test")).count();
        assert!(
            sampled_count > 400 && sampled_count < 600,
            "expected roughly half of 1000 calls at rate 0.5, got {sampled_count}"
        );
    }

    /// Boundary rates and out-of-range rates clamp as documented, and
    /// unconfigured categories always log.
    #[test]
    fn test_log_sampler_bounds() {
        let mut sampler = LogSampler::new();
        sampler.set_sample_rate("always", 1.0);
        sampler.set_sample_rate("never", 0.0);
        sampler.set_sample_rate("above", 2.0);
        sampler.set_sample_rate("below", -1.0);
        for _ in 0..10 {
            assert!(sampler.should_log("always"));
            assert!(sampler.should_log("above"));
            assert!(!sampler.should_log("never"));
            assert!(!sampler.should_log("below"));
            assert!(sampler.should_log("unconfigured"));
        }
    }

    #[test]
    fn test_async_logger_config() {
        let config = AsyncLoggerConfig {
            buffer_size: 5000,
            drop_on_overflow: true,
            sample_rate: 0.8,
            max_message_length: 512,
        };
        assert_eq!(config.buffer_size, 5000);
        assert!(config.drop_on_overflow);
        assert_eq!(config.sample_rate, 0.8);
        assert_eq!(config.max_message_length, 512);
    }

    /// Smoke test: building the logger inside a runtime and logging must not panic.
    #[tokio::test]
    async fn test_async_logger_creation() {
        let config = AsyncLoggerConfig::default();
        let logger = AsyncLogger::new(config);
        logger.log(Level::INFO, "test", "test message");
        // Give the background task a chance to drain the queued entry.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
}