use crate::shared::context::RequestContext;
use crate::types::RequestId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{field, span, Level, Span};
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
use tracing_subscriber::layer::SubscriberExt;
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
use tracing_subscriber::util::SubscriberInitExt;
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
use tracing_subscriber::{EnvFilter, Layer};
/// Settings consumed by `init_logging` and carried by `CorrelationLayer`.
///
/// The type derives `Serialize`/`Deserialize`, so it can be round-tripped
/// through any serde format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig {
/// Verbosity; `init_logging` turns this into the `EnvFilter` directive.
pub level: LogLevel,
/// NOTE(review): not read by any code visible in this file — confirm where
/// (or whether) timestamp toggling is applied.
pub timestamps: bool,
/// Forwarded to the fmt layer's `with_target` in `init_logging`.
pub source_location: bool,
/// When `true`, `CorrelationLayer::on_new_span` stores the current
/// `RequestContext` in the extensions of each new span.
pub correlation_ids: bool,
/// Arbitrary JSON values keyed by name.
/// NOTE(review): not read by any code visible in this file — confirm usage.
pub custom_fields: HashMap<String, serde_json::Value>,
/// Output style chosen for the fmt layer in `init_logging`.
pub format: LogFormat,
}
/// Verbosity levels selectable through `LogConfig::level`.
///
/// Because of `rename_all = "lowercase"`, variants (de)serialize as
/// `"trace"`, `"debug"`, `"info"`, `"warn"` and `"error"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// Mapped to the `"trace"` filter directive by `init_logging`.
Trace,
/// Mapped to the `"debug"` filter directive by `init_logging`.
Debug,
/// Mapped to the `"info"` filter directive by `init_logging`.
Info,
/// Mapped to the `"warn"` filter directive by `init_logging`.
Warn,
/// Mapped to the `"error"` filter directive by `init_logging`.
Error,
}
/// Output styles selectable through `LogConfig::format`.
///
/// Because of `rename_all = "lowercase"`, variants (de)serialize as
/// `"json"`, `"pretty"` and `"compact"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogFormat {
/// In `init_logging` this arm only disables ANSI colours on the fmt layer.
/// NOTE(review): no JSON formatter is selected there — confirm intent.
Json,
/// Selects the fmt layer's `pretty()` formatter in `init_logging`.
Pretty,
/// Selects the fmt layer's `compact()` formatter in `init_logging`.
Compact,
}
impl Default for LogConfig {
fn default() -> Self {
Self {
level: LogLevel::Info,
timestamps: true,
source_location: true,
correlation_ids: true,
custom_fields: HashMap::new(),
format: LogFormat::Pretty,
}
}
}
/// Installs the global `tracing` subscriber described by `config`.
///
/// The subscriber is a registry stacked with:
/// 1. an `EnvFilter` whose directive is derived from `config.level`,
/// 2. a fmt layer (thread ids and names always on, target controlled by
///    `config.source_location`) formatted according to `config.format`,
/// 3. a `CorrelationLayer` that takes ownership of `config`.
///
/// `config.timestamps` and `config.custom_fields` are not read by this
/// function.
///
/// # Errors
///
/// Returns the error reported by `try_init` when the subscriber cannot be
/// installed, e.g. because a global default subscriber is already set. Earlier
/// versions of this function panicked in that situation.
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
pub fn init_logging(config: LogConfig) -> Result<(), Box<dyn std::error::Error>> {
    let directive = match config.level {
        LogLevel::Trace => "trace",
        LogLevel::Debug => "debug",
        LogLevel::Info => "info",
        LogLevel::Warn => "warn",
        LogLevel::Error => "error",
    };
    let env_filter = EnvFilter::new(directive);

    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_target(config.source_location)
        .with_thread_ids(true)
        .with_thread_names(true);

    // Each formatter has a different concrete type, so the arms are boxed to
    // unify them.
    let fmt_layer = match config.format {
        // NOTE(review): this arm only turns ANSI colours off; it does not
        // switch to a JSON formatter. Confirm whether `.json()` was intended.
        LogFormat::Json => fmt_layer.with_ansi(false).boxed(),
        LogFormat::Pretty => fmt_layer.pretty().boxed(),
        LogFormat::Compact => fmt_layer.compact().boxed(),
    };

    // `try_init` reports failure as an error instead of panicking like
    // `init`, which lets this function honour its `Result` return type.
    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .with(CorrelationLayer::new(config))
        .try_init()?;

    Ok(())
}
/// `tracing_subscriber` layer that stores the value returned by
/// `RequestContext::current()` in the extensions of newly created spans.
///
/// Only compiled on non-wasm32 targets with the `logging` feature enabled.
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
pub struct CorrelationLayer {
// Configuration supplied at construction; the `Layer` impl in this file
// reads only its `correlation_ids` flag.
config: LogConfig,
}
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
impl std::fmt::Debug for CorrelationLayer {
    /// Formats the layer as a `CorrelationLayer` struct exposing its `config`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { config } = self;
        let mut builder = f.debug_struct("CorrelationLayer");
        builder.field("config", config);
        builder.finish()
    }
}
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
impl CorrelationLayer {
    /// Builds a layer that takes ownership of `config`.
    pub fn new(config: LogConfig) -> Self {
        CorrelationLayer { config }
    }
}
#[cfg(all(not(target_arch = "wasm32"), feature = "logging"))]
impl<S> Layer<S> for CorrelationLayer
where
    S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
{
    /// Stores the current `RequestContext` in the new span's extensions.
    ///
    /// Nothing happens when `correlation_ids` is disabled, when
    /// `RequestContext::current()` yields `None`, or when the span cannot be
    /// looked up through `ctx`.
    fn on_new_span(
        &self,
        _attrs: &span::Attributes<'_>,
        id: &span::Id,
        ctx: tracing_subscriber::layer::Context<'_, S>,
    ) {
        if !self.config.correlation_ids {
            return;
        }

        let request_context = match RequestContext::current() {
            Some(found) => found,
            None => return,
        };

        if let Some(span_ref) = ctx.span(id) {
            span_ref.extensions_mut().insert(request_context);
        }
    }
}
/// Wrapper around a `tracing` span named `"operation"` that carries request
/// correlation fields (`request_id`, `trace_id`, `span_id`, `user_id`,
/// `session_id`).
pub struct CorrelatedLogger {
// The span created by `new` / `from_context`; entered via `enter` or `in_scope`.
span: Span,
}
impl std::fmt::Debug for CorrelatedLogger {
    /// Formats the logger with a fixed `"Span"` placeholder in place of the
    /// wrapped span.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CorrelatedLogger");
        builder.field("span", &"Span");
        builder.finish()
    }
}
impl CorrelatedLogger {
    /// Creates a logger for `operation` using the ambient request context.
    ///
    /// `RequestContext::current()` is consulted exactly once. When it yields a
    /// context, the span is built by [`CorrelatedLogger::from_context`];
    /// otherwise an INFO span named `"operation"` is created whose correlation
    /// fields are all left empty.
    pub fn new(operation: &str) -> Self {
        match RequestContext::current() {
            Some(context) => Self::from_context(operation, &context),
            None => Self {
                span: span!(
                    Level::INFO,
                    "operation",
                    name = operation,
                    request_id = field::Empty,
                    trace_id = field::Empty,
                    span_id = field::Empty,
                    user_id = field::Empty,
                    session_id = field::Empty,
                ),
            },
        }
    }

    /// Creates a logger for `operation` from an explicit `context`.
    ///
    /// The INFO span named `"operation"` records `request_id`, `trace_id` and
    /// `span_id` via their `Display` output. `user_id` and `session_id` are
    /// declared empty and only recorded when the context provides them.
    pub fn from_context(operation: &str, context: &RequestContext) -> Self {
        let span = span!(
            Level::INFO,
            "operation",
            name = operation,
            request_id = %context.request_id,
            trace_id = %context.trace_id,
            span_id = %context.span_id,
            user_id = field::Empty,
            session_id = field::Empty,
        );

        // Optional identifiers are filled in after creation so that absent
        // values stay as empty fields.
        if let Some(user_id) = &context.user_id {
            span.record("user_id", field::display(user_id));
        }
        if let Some(session_id) = &context.session_id {
            span.record("session_id", field::display(session_id));
        }

        Self { span }
    }

    /// Enters the wrapped span and returns the guard produced by `Span::enter`.
    pub fn enter(&self) -> span::Entered<'_> {
        self.span.enter()
    }

    /// Runs `f` inside the wrapped span and returns its result.
    pub fn in_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        self.span.in_scope(f)
    }
}
/// Structured log record that `LogEntry::log` serializes to JSON and emits
/// through `tracing`.
#[derive(Debug, Clone, Serialize)]
pub struct LogEntry {
/// Creation time; `LogEntry::new` sets it to `chrono::Utc::now()`.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Level name. `LogEntry::log` matches it case-insensitively against
/// `trace`, `debug`, `info`, `warn` and `error`; anything else logs at info.
pub level: String,
/// Message text; also the fallback output of `log` if serialization fails.
pub message: String,
/// Copied from the current `RequestContext` by `LogEntry::new`, if any.
pub request_id: Option<RequestId>,
/// Copied from the current `RequestContext` by `LogEntry::new`, if any.
pub trace_id: Option<String>,
/// Copied from the current `RequestContext` by `LogEntry::new`, if any.
pub span_id: Option<String>,
/// Copied from the current `RequestContext` by `LogEntry::new`, if any.
pub user_id: Option<String>,
/// Copied from the current `RequestContext` by `LogEntry::new`, if any.
pub session_id: Option<String>,
/// Extra key/value pairs added through `with_field`.
pub fields: HashMap<String, serde_json::Value>,
/// Error information attached through `with_error`.
pub error: Option<ErrorDetails>,
}
/// Error information that can be attached to a `LogEntry` via `with_error`.
#[derive(Debug, Clone, Serialize)]
pub struct ErrorDetails {
/// Name of the error kind (the tests in this file use `"ValidationError"`).
pub error_type: String,
/// Error message text.
pub message: String,
/// Optional list of strings — presumably one stack frame per element;
/// confirm against producers.
pub stack_trace: Option<Vec<String>>,
/// Optional error code (the tests in this file use `"E001"`).
pub code: Option<String>,
}
impl LogEntry {
pub fn new(level: impl Into<String>, message: impl Into<String>) -> Self {
let mut entry = Self {
timestamp: chrono::Utc::now(),
level: level.into(),
message: message.into(),
request_id: None,
trace_id: None,
span_id: None,
user_id: None,
session_id: None,
fields: HashMap::new(),
error: None,
};
if let Some(context) = RequestContext::current() {
entry.request_id = Some(context.request_id.clone());
entry.trace_id = Some(context.trace_id.clone());
entry.span_id = Some(context.span_id.clone());
entry.user_id.clone_from(&context.user_id);
entry.session_id.clone_from(&context.session_id);
}
entry
}
pub fn with_field(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.fields.insert(key.into(), value);
self
}
pub fn with_error(mut self, error: ErrorDetails) -> Self {
self.error = Some(error);
self
}
#[allow(clippy::cognitive_complexity)]
pub fn log(self) {
let json = serde_json::to_string(&self).unwrap_or_else(|_| self.message.clone());
match self.level.to_lowercase().as_str() {
"trace" => tracing::trace!("{}", json),
"debug" => tracing::debug!("{}", json),
"info" => tracing::info!("{}", json),
"warn" => tracing::warn!("{}", json),
"error" => tracing::error!("{}", json),
_ => tracing::info!("{}", json),
}
}
}
/// Builds a `LogEntry` with the given level and message and logs it at once.
///
/// Optional `key => value` pairs (trailing comma allowed) are attached as
/// fields; every value is converted with `serde_json::json!`. The expansion
/// names `serde_json` by its bare path, so that path must resolve at the call
/// site.
#[macro_export]
macro_rules! log_correlated {
    ($level:expr, $msg:expr) => {
        $crate::shared::logging::LogEntry::new($level, $msg).log()
    };
    ($level:expr, $msg:expr, $($field:expr => $val:expr),* $(,)?) => {
        $crate::shared::logging::LogEntry::new($level, $msg)
            $(.with_field($field, serde_json::json!($val)))*
            .log()
    };
}
/// Shorthand for `log_correlated!` with the level fixed to `"info"`.
///
/// Accepts a message, optionally followed by `key => value` field pairs.
#[macro_export]
macro_rules! info_correlated {
    ($message:expr) => {
        $crate::log_correlated!("info", $message)
    };
    ($message:expr, $($field:expr => $val:expr),* $(,)?) => {
        $crate::log_correlated!("info", $message, $($field => $val),*)
    };
}
/// Shorthand for `log_correlated!` with the level fixed to `"error"`.
///
/// Accepts a message, optionally followed by `key => value` field pairs.
#[macro_export]
macro_rules! error_correlated {
    ($message:expr) => {
        $crate::log_correlated!("error", $message)
    };
    ($message:expr, $($field:expr => $val:expr),* $(,)?) => {
        $crate::log_correlated!("error", $message, $($field => $val),*)
    };
}
/// Shorthand for `log_correlated!` with the level fixed to `"warn"`.
///
/// Accepts a message, optionally followed by `key => value` field pairs.
#[macro_export]
macro_rules! warn_correlated {
    ($message:expr) => {
        $crate::log_correlated!("warn", $message)
    };
    ($message:expr, $($field:expr => $val:expr),* $(,)?) => {
        $crate::log_correlated!("warn", $message, $($field => $val),*)
    };
}
/// Shorthand for `log_correlated!` with the level fixed to `"debug"`.
///
/// Accepts a message, optionally followed by `key => value` field pairs.
#[macro_export]
macro_rules! debug_correlated {
    ($message:expr) => {
        $crate::log_correlated!("debug", $message)
    };
    ($message:expr, $($field:expr => $val:expr),* $(,)?) => {
        $crate::log_correlated!("debug", $message, $($field => $val),*)
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the level and message and stores every added field.
    #[test]
    fn test_log_entry_creation() {
        let built = LogEntry::new("info", "Test message")
            .with_field("key1", serde_json::json!("value1"))
            .with_field("count", serde_json::json!(42));

        assert_eq!(built.level, "info");
        assert_eq!(built.message, "Test message");

        let expected_fields = [
            ("key1", serde_json::json!("value1")),
            ("count", serde_json::json!(42)),
        ];
        for (key, value) in &expected_fields {
            assert_eq!(built.fields.get(*key), Some(value));
        }
    }

    /// Smoke test: a logger created inside a request context can emit events.
    #[tokio::test]
    async fn test_correlated_logger() {
        let request_id = RequestId::from(123i64);
        let context = RequestContext::new(request_id).with_user_id("user123".to_string());

        context
            .run(async {
                let logger = CorrelatedLogger::new("test_operation");
                logger.in_scope(|| {
                    tracing::info!("Test log message");
                });
            })
            .await;
    }

    /// `with_error` attaches the supplied details unchanged.
    #[test]
    fn test_error_details() {
        let details = ErrorDetails {
            error_type: String::from("ValidationError"),
            message: String::from("Invalid input"),
            stack_trace: Some(vec![String::from("line1"), String::from("line2")]),
            code: Some(String::from("E001")),
        };

        let entry = LogEntry::new("error", "Validation failed").with_error(details);

        let attached = entry.error.expect("error details should be attached");
        assert_eq!(attached.error_type, "ValidationError");
        assert_eq!(attached.code.as_deref(), Some("E001"));
    }
}