use std::sync::Once;
/// Settings consumed by `init_logging` when it builds the global subscriber.
///
/// All fields are public, so a value can be written as a struct literal or
/// produced through the constructors and builder-style methods on this type.
#[derive(Debug, Clone)]
// The flags are independent on/off switches that are forwarded one by one to
// the subscriber builder, so plain `bool` fields are kept.
#[allow(clippy::struct_excessive_bools)]
pub struct LogConfig {
    /// Level used to build the filter when the `RUST_LOG` environment
    /// variable cannot be read.
    pub default_level: LogLevel,
    /// When `false`, the subscriber is built with `without_time()`.
    pub with_timestamps: bool,
    /// Forwarded to the subscriber builder's `with_target`.
    pub with_target: bool,
    /// Forwarded to both `with_file` and `with_line_number`.
    pub with_file_line: bool,
    /// Forwarded to the subscriber builder's `with_ansi`.
    pub with_ansi: bool,
}
impl Default for LogConfig {
fn default() -> Self {
Self {
default_level: LogLevel::Info,
with_timestamps: true,
with_target: true,
with_file_line: false,
with_ansi: true,
}
}
}
impl LogConfig {
    /// Creates a configuration with the default settings.
    ///
    /// Equivalent to `LogConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `default_level` replaced by `level`.
    #[must_use]
    pub fn with_level(mut self, level: LogLevel) -> Self {
        self.default_level = level;
        self
    }

    /// Returns the configuration with the `with_timestamps` flag set to
    /// `enable`.
    #[must_use]
    pub fn with_timestamps(mut self, enable: bool) -> Self {
        self.with_timestamps = enable;
        self
    }

    /// Returns the configuration with the `with_ansi` flag set to `enable`.
    #[must_use]
    pub fn with_ansi(mut self, enable: bool) -> Self {
        self.with_ansi = enable;
        self
    }

    /// Returns the configuration with the `with_target` flag set to `enable`.
    ///
    /// Completes the builder so that every public flag can be set in a chain.
    #[must_use]
    pub fn with_target(mut self, enable: bool) -> Self {
        self.with_target = enable;
        self
    }

    /// Returns the configuration with the `with_file_line` flag set to
    /// `enable`.
    ///
    /// Completes the builder so that every public flag can be set in a chain.
    #[must_use]
    pub fn with_file_line(mut self, enable: bool) -> Self {
        self.with_file_line = enable;
        self
    }

    /// Preset for development: `LogLevel::Debug` with every flag enabled,
    /// including file/line information.
    #[must_use]
    pub fn development() -> Self {
        Self {
            default_level: LogLevel::Debug,
            with_timestamps: true,
            with_target: true,
            with_file_line: true,
            with_ansi: true,
        }
    }

    /// Preset for production: `LogLevel::Info` with timestamps only; targets,
    /// file/line information and ANSI are disabled.
    #[must_use]
    pub fn production() -> Self {
        Self {
            default_level: LogLevel::Info,
            with_timestamps: true,
            with_target: false,
            with_file_line: false,
            with_ansi: false,
        }
    }

    /// Preset for tests: `LogLevel::Warn` with every flag disabled.
    #[must_use]
    pub fn testing() -> Self {
        Self {
            default_level: LogLevel::Warn,
            with_timestamps: false,
            with_target: false,
            with_file_line: false,
            with_ansi: false,
        }
    }
}
/// Verbosity levels that can be selected as the default log filter.
///
/// `LogLevel::Info` is the value returned by `LogLevel::default()`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LogLevel {
    /// Corresponds to the `error` filter string.
    Error,
    /// Corresponds to the `warn` filter string.
    Warn,
    /// Corresponds to the `info` filter string; this is the default variant.
    #[default]
    Info,
    /// Corresponds to the `debug` filter string.
    Debug,
    /// Corresponds to the `trace` filter string.
    Trace,
}
impl LogLevel {
fn as_filter_str(self) -> &'static str {
match self {
Self::Error => "error",
Self::Warn => "warn",
Self::Info => "info",
Self::Debug => "debug",
Self::Trace => "trace",
}
}
}
/// Ensures the subscriber set-up in `init_logging` runs at most once per
/// process.
static INIT_LOGGING: Once = Once::new();

/// Installs the global `tracing` subscriber described by `config`.
///
/// The filter string is read from the `RUST_LOG` environment variable; when
/// that variable cannot be read (unset or not valid Unicode) the filter falls
/// back to `config.default_level`. The remaining flags of `config` are
/// forwarded to the `tracing_subscriber::fmt()` builder.
///
/// Only the first call does any work; later calls return immediately and
/// their `config` is ignored.
///
/// If a global subscriber has already been installed elsewhere, the existing
/// one is kept and a message is written to standard error. This function does
/// not panic in that case.
pub fn init_logging(config: &LogConfig) {
    INIT_LOGGING.call_once(|| {
        let filter = std::env::var("RUST_LOG")
            .unwrap_or_else(|_| config.default_level.as_filter_str().to_string());
        let builder = tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_ansi(config.with_ansi)
            .with_target(config.with_target)
            .with_file(config.with_file_line)
            .with_line_number(config.with_file_line);
        // `try_init` is used instead of `init`: `init` panics when a global
        // subscriber is already set, and a panic inside `call_once` poisons
        // `INIT_LOGGING`, which would make every later call panic as well.
        let installed = if config.with_timestamps {
            builder.try_init()
        } else {
            builder.without_time().try_init()
        };
        if let Err(err) = installed {
            eprintln!("init_logging: global subscriber was not installed: {err}");
        }
    });
}
/// Emits an `info`-level `tracing` event under the `rust_ai::metrics` target.
///
/// Takes one or more comma-separated `name = value` pairs, where `name` is an
/// identifier and `value` is an expression; a trailing comma is accepted.
///
/// The expansion refers to `tracing::info!` by path, so `tracing` must be
/// resolvable at the place where the macro is invoked.
#[macro_export]
macro_rules! log_metric {
    ($($name:ident = $val:expr),+ $(,)?) => {
        tracing::info!(
            target: "rust_ai::metrics",
            $($name = $val),+
        );
    };
}
/// Emits an `info`-level event under the `rust_ai::training` target describing
/// one training step.
///
/// The event carries `step`, `total_steps`, the progress percentage (one
/// decimal place), `loss` (six decimal places) and `lr` (exponent notation
/// with two decimal places), together with the message `Training step`.
/// The progress is reported as `0.0` when `total_steps` is zero.
// `usize` to `f64` conversion can lose precision for very large counts; the
// value is only used for a percentage printed with one decimal place.
#[allow(clippy::cast_precision_loss)]
pub fn log_training_step(step: usize, total_steps: usize, loss: f64, lr: f64) {
    let progress_pct = match total_steps {
        // Nothing to divide by: report no progress.
        0 => 0.0,
        total => (step as f64 / total as f64) * 100.0,
    };

    tracing::info!(
        target: "rust_ai::training",
        step,
        total_steps,
        progress_pct = format!("{progress_pct:.1}"),
        loss = format!("{loss:.6}"),
        lr = format!("{lr:.2e}"),
        "Training step"
    );
}
/// Emits a `debug`-level event under the `rust_ai::memory` target reporting
/// memory figures for `context`.
///
/// Both byte counts are divided by 1024 * 1024 and formatted with two decimal
/// places before being attached as the `allocated_mb` and `peak_mb` fields;
/// the event message is `Memory usage`.
// `usize` to `f64` conversion can lose precision for very large byte counts;
// the values are only used for a figure printed with two decimal places.
#[allow(clippy::cast_precision_loss)]
pub fn log_memory_usage(allocated_bytes: usize, peak_bytes: usize, context: &str) {
    const BYTES_PER_UNIT: f64 = 1024.0 * 1024.0;

    let scale = |bytes: usize| bytes as f64 / BYTES_PER_UNIT;
    let allocated_mb = scale(allocated_bytes);
    let peak_mb = scale(peak_bytes);

    tracing::debug!(
        target: "rust_ai::memory",
        allocated_mb = format!("{allocated_mb:.2}"),
        peak_mb = format!("{peak_mb:.2}"),
        context,
        "Memory usage"
    );
}
pub use tracing::{debug, error, info, trace, warn};
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses `Info` with timestamps and ANSI on.
    #[test]
    fn test_log_config_default() {
        let LogConfig {
            default_level,
            with_timestamps,
            with_ansi,
            ..
        } = LogConfig::default();

        assert_eq!(default_level, LogLevel::Info);
        assert!(with_timestamps);
        assert!(with_ansi);
    }

    /// Each builder method overrides exactly the setting it is named after.
    #[test]
    fn test_log_config_builder() {
        let built = LogConfig::new()
            .with_level(LogLevel::Debug)
            .with_timestamps(false)
            .with_ansi(false);

        assert_eq!(built.default_level, LogLevel::Debug);
        assert!(!built.with_timestamps);
        assert!(!built.with_ansi);
    }

    /// Spot-checks the distinguishing settings of each preset.
    #[test]
    fn test_log_config_presets() {
        let development = LogConfig::development();
        assert_eq!(development.default_level, LogLevel::Debug);
        assert!(development.with_file_line);

        let production = LogConfig::production();
        assert_eq!(production.default_level, LogLevel::Info);
        assert!(!production.with_ansi);

        let testing = LogConfig::testing();
        assert_eq!(testing.default_level, LogLevel::Warn);
        assert!(!testing.with_timestamps);
    }

    /// Every level maps to its lowercase filter string.
    #[test]
    fn test_log_level_filter_str() {
        let expected = [
            (LogLevel::Error, "error"),
            (LogLevel::Warn, "warn"),
            (LogLevel::Info, "info"),
            (LogLevel::Debug, "debug"),
            (LogLevel::Trace, "trace"),
        ];

        for &(level, name) in &expected {
            assert_eq!(level.as_filter_str(), name);
        }
    }
}