1use std::str::FromStr;
2use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
3
/// Output format used for log lines installed by `init_logging`.
///
/// Obtained from a string via `FromStr`: `"json"` (compared case-insensitively)
/// selects [`LogFormat::Json`]; any other input selects [`LogFormat::Text`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogFormat {
    /// Structured JSON output.
    Json,
    /// Human-readable plain-text output.
    Text,
}
10
11impl FromStr for LogFormat {
12 type Err = std::convert::Infallible;
13
14 fn from_str(s: &str) -> Result<Self, Self::Err> {
15 match s.to_lowercase().as_str() {
16 "json" => Ok(LogFormat::Json),
17 _ => Ok(LogFormat::Text),
18 }
19 }
20}
21
22pub fn init_logging(format: LogFormat, trace_id: Option<String>) {
24 let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
25
26 match format {
27 LogFormat::Json => {
28 let fmt_layer = fmt::layer()
29 .json()
30 .with_current_span(true)
31 .with_span_list(false)
32 .with_writer(std::io::stderr);
33
34 tracing_subscriber::registry()
35 .with(filter)
36 .with(fmt_layer)
37 .init();
38 }
39 LogFormat::Text => {
40 let fmt_layer = fmt::layer().with_writer(std::io::stderr).with_target(false);
41
42 tracing_subscriber::registry()
43 .with(filter)
44 .with(fmt_layer)
45 .init();
46 }
47 }
48
49 if let Some(tid) = trace_id {
51 tracing::info!(traceId = %tid, "Request started");
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// "json" is recognised regardless of case; everything else becomes `Text`.
    #[test]
    fn test_log_format_from_str() {
        let cases = [
            ("json", LogFormat::Json),
            ("JSON", LogFormat::Json),
            ("text", LogFormat::Text),
            ("anything", LogFormat::Text),
        ];

        for &(input, expected) in &cases {
            let parsed = LogFormat::from_str(input).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}