use std::{env, io, str::FromStr};
use anyhow::anyhow;
#[cfg(doc)]
use lexe_api::trace::TraceId;
use lexe_api::{define_trace_id_fns, trace};
use tracing::{Level, level_filters::LevelFilter};
use tracing_subscriber::{
Registry,
filter::{Filtered, Targets},
fmt::{
Layer as FmtLayer,
format::{Compact, DefaultFields, Format},
},
layer::{Layer as LayerTrait, Layered, SubscriberExt},
util::SubscriberInitExt,
};
/// Initializes the global logger, panicking if initialization fails.
///
/// This is a thin wrapper around [`try_init`].
///
/// # Panics
///
/// Panics if [`try_init`] returns an error.
pub fn init() {
    let outcome = try_init();
    outcome.expect("Failed to setup logger");
}
/// Initializes the global logger, using `rust_log_default` as the filter
/// directives when `RUST_LOG` cannot be read from the environment.
///
/// This is a thin wrapper around [`try_init_with_default`].
///
/// # Panics
///
/// Panics if [`try_init_with_default`] returns an error.
pub fn init_with_default(rust_log_default: &str) {
    let outcome = try_init_with_default(rust_log_default);
    outcome.expect("Failed to set up logger");
}
/// Initializes the logger for tests via [`try_init`], discarding the result.
///
/// Any error (for example, the logger having been initialized already) is
/// deliberately ignored, so this function never panics on its own.
#[cfg(any(test, feature = "test-utils"))]
pub fn init_for_testing() {
    // Best-effort: a failed initialization is not an error for callers here.
    try_init().ok();
}
/// Tries to initialize the global logger, falling back to the `"off"`
/// directive when `RUST_LOG` cannot be read from the environment.
///
/// # Errors
///
/// Propagates any error returned by [`try_init_with_default`].
pub fn try_init() -> anyhow::Result<()> {
    // Directive string used when `RUST_LOG` is unavailable.
    const FALLBACK_DIRECTIVES: &str = "off";

    try_init_with_default(FALLBACK_DIRECTIVES)
}
pub fn try_init_with_default(rust_log_default: &str) -> anyhow::Result<()> {
let rust_log_env = env::var("RUST_LOG");
let rust_log = rust_log_env.as_deref().unwrap_or(rust_log_default);
subscriber(rust_log)
.try_init()
.context("Logger already initialized")?;
define_trace_id_fns!(SubscriberType);
trace::GET_TRACE_ID_FN
.set(get_trace_id_from_span)
.map_err(|_| anyhow!("GET_TRACE_ID_FN already set"))?;
trace::INSERT_TRACE_ID_FN
.set(insert_trace_id_into_span)
.map_err(|_| anyhow!("INSERT_TRACE_ID_FN already set"))?;
Ok(())
}
/// Concrete type of the subscriber built by `subscriber`.
///
/// It is a [`Registry`] layered with a single fmt layer that uses the compact
/// event format, writes through a `fn() -> io::Stderr` writer, and is filtered
/// by [`Targets`].
///
/// The alias is also passed to `define_trace_id_fns!`, so it must stay in sync
/// with what `subscriber` actually returns.
type SubscriberType = Layered<
Filtered<
FmtLayer<Registry, DefaultFields, Format<Compact>, fn() -> io::Stderr>,
Targets,
Registry,
>,
Registry,
>;
fn subscriber(rust_log: &str) -> SubscriberType {
let targets = Targets::from_str(rust_log)
.inspect_err(|e| eprintln!("Invalid RUST_LOG; using INFO: {e}"))
.unwrap_or_else(|_| Targets::new().with_default(Level::INFO));
let clamped_targets =
if cfg!(any(test, debug_assertions, feature = "test-utils")) {
targets
} else {
enforce_log_policy(targets)
};
let stderr_log = tracing_subscriber::fmt::Layer::default()
.compact()
.with_level(true)
.with_target(true)
.with_writer(io::stderr as fn() -> io::Stderr)
.with_ansi(true)
.with_filter(clamped_targets);
tracing_subscriber::registry().with(stderr_log)
}
/// Rewrites `targets` so that no directive is more verbose than `DEBUG`.
///
/// - Every per-target level equal to `TRACE` is replaced with `DEBUG`; all
///   other levels are kept unchanged.
/// - The default level is treated the same way; if `targets` has no default
///   level, the result uses `INFO` as its default.
fn enforce_log_policy(targets: Targets) -> Targets {
    // Lowers `TRACE` to `DEBUG` and passes every other level through.
    let cap_at_debug = |lvl: LevelFilter| {
        if lvl != LevelFilter::TRACE {
            lvl
        } else {
            LevelFilter::DEBUG
        }
    };

    // Read the default before `targets` is consumed below.
    let default_level = targets
        .default_level()
        .map_or(LevelFilter::INFO, cap_at_debug);

    targets
        .into_iter()
        .map(|(name, lvl)| (name, cap_at_debug(lvl)))
        .collect::<Targets>()
        .with_default(default_level)
}
#[cfg(test)]
mod test {
    use std::env;

    use lexe_api::trace::TraceId;

    use super::*;

    /// Runs `TraceId::get_and_insert_test_impl` with the logger initialized.
    ///
    /// The test returns early (and trivially passes) unless `RUST_LOG` is set
    /// to a value that does not start with "off".
    #[test]
    fn get_and_insert_trace_ids() {
        // Unset (or unreadable) `RUST_LOG`: nothing to check.
        let Ok(rust_log) = env::var("RUST_LOG") else {
            return;
        };
        // Logging explicitly turned off: nothing to check either.
        if rust_log.starts_with("off") {
            return;
        }

        init_for_testing();
        TraceId::get_and_insert_test_impl();
    }
}