#[cfg(feature = "structured-logging")]
use crate::{TrainError, TrainResult};
#[cfg(feature = "structured-logging")]
use tracing_subscriber::{
fmt::{self, format::FmtSpan},
layer::SubscriberExt,
util::SubscriberInitExt,
EnvFilter,
};
#[cfg(feature = "structured-logging")]
/// Output encoding used by the `fmt` layer that [`TracingLoggerBuilder::build`] installs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogFormat {
    /// Human-oriented output produced by `fmt::Layer::pretty`.
    Pretty,
    /// Single-line output produced by `fmt::Layer::compact`, with ANSI escape codes disabled.
    Compact,
    /// Structured records produced by `fmt::Layer::json`.
    Json,
}
#[cfg(feature = "structured-logging")]
/// Verbosity threshold used to build the fallback `EnvFilter`.
///
/// It only takes effect when no explicit filter string was configured and the
/// default environment filter (`RUST_LOG`) could not be used.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogLevel {
    /// Most verbose: everything, including trace-level events.
    Trace,
    /// Debug-level events and above.
    Debug,
    /// Info-level events and above.
    Info,
    /// Warnings and errors only.
    Warn,
    /// Errors only.
    Error,
}
#[cfg(feature = "structured-logging")]
impl LogLevel {
    /// Returns the lowercase level name (`"trace"` through `"error"`).
    ///
    /// `build` passes this string to `EnvFilter::new` as a bare level directive.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Error => "error",
            Self::Warn => "warn",
            Self::Info => "info",
            Self::Debug => "debug",
            Self::Trace => "trace",
        }
    }
}
#[cfg(feature = "structured-logging")]
/// Fluent configuration for a [`TracingLogger`], consumed by [`TracingLoggerBuilder::build`].
#[derive(Clone, Debug)]
pub struct TracingLoggerBuilder {
    /// Output encoding of the installed `fmt` layer.
    format: LogFormat,
    /// Fallback verbosity.
    ///
    /// Used only when `env_filter` is `None` and the default environment
    /// filter cannot be loaded.
    level: LogLevel,
    /// Explicit `EnvFilter` directive string. When set, it replaces both the
    /// environment filter and `level`.
    env_filter: Option<String>,
    /// Whether each record includes its target.
    with_targets: bool,
    /// Whether each record includes its source file and line number.
    with_file_location: bool,
    /// Whether each record includes the emitting thread's ID.
    with_thread_ids: bool,
    /// Whether span NEW and CLOSE events are emitted.
    with_span_events: bool,
}
#[cfg(feature = "structured-logging")]
impl Default for TracingLoggerBuilder {
    /// Default configuration:
    ///
    /// - pretty output at `Info`, with targets shown;
    /// - file locations, thread IDs and span events hidden;
    /// - no explicit filter string.
    fn default() -> Self {
        TracingLoggerBuilder {
            env_filter: None,
            format: LogFormat::Pretty,
            level: LogLevel::Info,
            with_file_location: false,
            with_span_events: false,
            with_targets: true,
            with_thread_ids: false,
        }
    }
}
#[cfg(feature = "structured-logging")]
impl TracingLoggerBuilder {
    /// Creates a builder populated with the [`Default`] settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Chooses the output encoding of the installed `fmt` layer.
    pub fn with_format(self, format: LogFormat) -> Self {
        Self { format, ..self }
    }

    /// Sets the fallback verbosity.
    ///
    /// It is honoured only when no explicit filter was given via
    /// [`with_env_filter`](Self::with_env_filter) and the default environment
    /// filter (`RUST_LOG`) is unset or unparsable.
    pub fn with_level(self, level: LogLevel) -> Self {
        Self { level, ..self }
    }

    /// Supplies explicit `EnvFilter` directives, e.g. `"my_crate=debug,other=info"`.
    ///
    /// These take precedence over both `RUST_LOG` and the configured level.
    /// They are parsed only in [`build`](Self::build).
    pub fn with_env_filter(self, filter: impl Into<String>) -> Self {
        Self {
            env_filter: Some(filter.into()),
            ..self
        }
    }

    /// Toggles printing each record's target.
    pub fn with_targets(self, enabled: bool) -> Self {
        Self {
            with_targets: enabled,
            ..self
        }
    }

    /// Toggles printing the source file and line number of each record.
    pub fn with_file_location(self, enabled: bool) -> Self {
        Self {
            with_file_location: enabled,
            ..self
        }
    }

    /// Toggles printing the emitting thread's ID.
    pub fn with_thread_ids(self, enabled: bool) -> Self {
        Self {
            with_thread_ids: enabled,
            ..self
        }
    }

    /// Toggles emitting events when spans are created and closed.
    pub fn with_span_events(self, enabled: bool) -> Self {
        Self {
            with_span_events: enabled,
            ..self
        }
    }

    /// Installs the configured subscriber as the global default.
    ///
    /// The method builds the filter, assembles a `fmt` layer for the chosen
    /// format, and installs a registry holding both.
    ///
    /// Filter precedence:
    /// 1. the explicit directive string, if one was set;
    /// 2. otherwise `EnvFilter::try_from_default_env()` (i.e. `RUST_LOG`);
    /// 3. otherwise, if that is unset or invalid, the configured [`LogLevel`].
    ///
    /// # Errors
    ///
    /// Returns `TrainError::Other` in two cases:
    /// - the explicit filter directive does not parse;
    /// - `try_init` fails, for example because a global default subscriber has
    ///   already been installed.
    pub fn build(self) -> TrainResult<TracingLogger> {
        let Self {
            format: log_format,
            level,
            env_filter,
            with_targets,
            with_file_location,
            with_thread_ids,
            with_span_events,
        } = self;

        let filter = match env_filter {
            Some(directives) => EnvFilter::try_new(directives)
                .map_err(|e| TrainError::Other(format!("Invalid env filter: {}", e)))?,
            None => EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| EnvFilter::new(level.as_str())),
        };

        let span_events = if with_span_events {
            FmtSpan::NEW | FmtSpan::CLOSE
        } else {
            FmtSpan::NONE
        };

        // Each format yields a distinct layer type, so the match arms cannot
        // share a single value. The common setup is written once here and
        // expanded per arm instead. Only one arm runs, so moving `filter` and
        // `span_events` inside the expansions is fine.
        macro_rules! base_layer {
            () => {
                fmt::layer()
                    .with_target(with_targets)
                    .with_file(with_file_location)
                    .with_line_number(with_file_location)
                    .with_thread_ids(with_thread_ids)
                    .with_span_events(span_events)
            };
        }
        macro_rules! install {
            ($layer:expr) => {
                tracing_subscriber::registry()
                    .with(filter)
                    .with($layer)
                    .try_init()
                    .map_err(|e| {
                        TrainError::Other(format!("Failed to initialize tracing: {}", e))
                    })?
            };
        }

        match log_format {
            LogFormat::Pretty => install!(base_layer!().pretty()),
            LogFormat::Compact => install!(base_layer!().with_ansi(false).compact()),
            LogFormat::Json => install!(base_layer!().json()),
        }

        Ok(TracingLogger {
            _format: log_format,
        })
    }
}
#[cfg(feature = "structured-logging")]
/// Value returned by [`TracingLoggerBuilder::build`] once the global subscriber
/// has been installed.
#[derive(Debug)]
pub struct TracingLogger {
    /// Format the subscriber was built with. It is retained but not read by
    /// the code in this module, hence the leading underscore.
    _format: LogFormat,
}
#[cfg(feature = "structured-logging")]
impl TracingLogger {
    /// Starts a [`TracingLoggerBuilder`] holding the default settings.
    pub fn builder() -> TracingLoggerBuilder {
        TracingLoggerBuilder::default()
    }

    /// Installs a subscriber using the builder defaults.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TracingLoggerBuilder::build`].
    pub fn init() -> TrainResult<Self> {
        TracingLoggerBuilder::default().build()
    }

    /// Production preset: JSON output with targets omitted.
    ///
    /// The fallback level is `Info`; as with every preset, `RUST_LOG` takes
    /// precedence when it is set and valid.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TracingLoggerBuilder::build`].
    pub fn init_production() -> TrainResult<Self> {
        let preset = TracingLoggerBuilder::default()
            .with_targets(false)
            .with_level(LogLevel::Info)
            .with_format(LogFormat::Json);
        preset.build()
    }

    /// Development preset: pretty output including the file and line of each record.
    ///
    /// The fallback level is `Debug`; `RUST_LOG` takes precedence when it is
    /// set and valid.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TracingLoggerBuilder::build`].
    pub fn init_development() -> TrainResult<Self> {
        let preset = TracingLoggerBuilder::default()
            .with_file_location(true)
            .with_level(LogLevel::Debug)
            .with_format(LogFormat::Pretty);
        preset.build()
    }
}
#[cfg(feature = "structured-logging")]
pub mod training {
    //! Convenience macros for emitting training-loop events through `tracing`.
    //!
    //! Every macro takes its mandatory arguments followed by an optional list of
    //! `key = value` fields (`key` must be a plain identifier) and an optional
    //! trailing comma. The macros are `#[macro_export]`ed, so they live at the
    //! crate root rather than under this module's path.
    //!
    //! They expand to `tracing::…` paths, so the invoking crate must be able to
    //! resolve `tracing`.

    /// Emits an `INFO` event, "Epoch completed", with `epoch`, `loss` and any extra fields.
    ///
    /// Accepted forms include `log_epoch!(epoch, loss)`, `log_epoch!(epoch, loss,)`
    /// and `log_epoch!(epoch, loss, lr = 0.01, acc = 0.9)`.
    #[macro_export]
    macro_rules! log_epoch {
        ($epoch:expr, $loss:expr $(, $key:ident = $value:expr)* $(,)?) => {
            tracing::info!(
                epoch = $epoch,
                loss = $loss,
                $($key = $value,)*
                "Epoch completed"
            );
        };
    }

    /// Emits a `DEBUG` event, "Batch processed", with `batch`, `loss` and any extra fields.
    ///
    /// Accepted forms include `log_batch!(batch, loss)` and
    /// `log_batch!(batch, loss, lr = 0.01)`.
    #[macro_export]
    macro_rules! log_batch {
        ($batch:expr, $loss:expr $(, $key:ident = $value:expr)* $(,)?) => {
            tracing::debug!(
                batch = $batch,
                loss = $loss,
                $($key = $value,)*
                "Batch processed"
            );
        };
    }

    /// Emits a `TRACE` event, "Gradient statistics", with `gradient_norm` and any extra fields.
    ///
    /// Accepted forms include `log_gradients!(norm)` and
    /// `log_gradients!(norm, layer = 3)`.
    #[macro_export]
    macro_rules! log_gradients {
        ($norm:expr $(, $key:ident = $value:expr)* $(,)?) => {
            tracing::trace!(
                gradient_norm = $norm,
                $($key = $value,)*
                "Gradient statistics"
            );
        };
    }

    /// Creates (but does not enter) an `INFO`-level span named `$name` with any extra fields.
    ///
    /// Accepted forms include `training_span!("epoch")` and
    /// `training_span!("epoch", index = 3)`.
    #[macro_export]
    macro_rules! training_span {
        ($name:expr $(, $key:ident = $value:expr)* $(,)?) => {
            tracing::info_span!($name $(, $key = $value)*)
        };
    }
}
#[cfg(all(test, feature = "structured-logging"))]
mod tests {
    use super::*;

    /// A fresh builder carries every documented default.
    #[test]
    fn test_builder_creation() {
        let builder = TracingLoggerBuilder::new();
        assert_eq!(builder.format, LogFormat::Pretty);
        assert_eq!(builder.level, LogLevel::Info);
        assert_eq!(builder.env_filter, None);
        assert!(builder.with_targets);
        assert!(!builder.with_file_location);
        assert!(!builder.with_thread_ids);
        assert!(!builder.with_span_events);
    }

    /// `TracingLogger::builder()` is just another route to the defaults.
    #[test]
    fn test_logger_builder_matches_default() {
        let via_logger = TracingLogger::builder();
        let via_default = TracingLoggerBuilder::default();
        assert_eq!(via_logger.format, via_default.format);
        assert_eq!(via_logger.level, via_default.level);
        assert_eq!(via_logger.env_filter, via_default.env_filter);
        assert_eq!(via_logger.with_targets, via_default.with_targets);
        assert_eq!(via_logger.with_file_location, via_default.with_file_location);
        assert_eq!(via_logger.with_thread_ids, via_default.with_thread_ids);
        assert_eq!(via_logger.with_span_events, via_default.with_span_events);
    }

    #[test]
    fn test_builder_configuration() {
        let builder = TracingLoggerBuilder::new()
            .with_format(LogFormat::Json)
            .with_level(LogLevel::Debug)
            .with_targets(false)
            .with_file_location(true)
            .with_thread_ids(true)
            .with_span_events(true);
        assert_eq!(builder.format, LogFormat::Json);
        assert_eq!(builder.level, LogLevel::Debug);
        assert!(!builder.with_targets);
        assert!(builder.with_file_location);
        assert!(builder.with_thread_ids);
        assert!(builder.with_span_events);
    }

    #[test]
    fn test_log_level_as_str() {
        assert_eq!(LogLevel::Trace.as_str(), "trace");
        assert_eq!(LogLevel::Debug.as_str(), "debug");
        assert_eq!(LogLevel::Info.as_str(), "info");
        assert_eq!(LogLevel::Warn.as_str(), "warn");
        assert_eq!(LogLevel::Error.as_str(), "error");
    }

    /// The directive string is stored verbatim; it is only parsed in `build`.
    #[test]
    fn test_custom_env_filter() {
        let builder = TracingLoggerBuilder::new().with_env_filter("tensorlogic=debug,scirs2=info");
        assert_eq!(
            builder.env_filter.as_deref(),
            Some("tensorlogic=debug,scirs2=info")
        );

        let owned = TracingLoggerBuilder::new().with_env_filter(String::from("info"));
        assert_eq!(owned.env_filter.as_deref(), Some("info"));
    }
}