use std::{
collections::HashMap,
sync::{Arc, Mutex, MutexGuard},
thread::{self, ThreadId},
};
/// Tagging front end for a `tracing` dispatcher.
///
/// Every thread that touches the logger gets its own stack of tags, which
/// [`TaggedLogging::format_message`] renders as `[tag]` prefixes. The stacks
/// live behind an `Arc`, so clones of one logger share them, while loggers
/// created by separate constructor calls never see each other's tags.
#[derive(Debug, Clone)]
pub struct TaggedLogging {
    /// Dispatcher handed to the constructor and exposed through `subscriber()`.
    subscriber: tracing::Dispatch,
    /// One tag stack per thread, keyed by the id of the thread that pushed it.
    tag_stacks: Arc<Mutex<HashMap<ThreadId, Vec<String>>>>,
}
impl TaggedLogging {
    /// Creates a logger wrapping `subscriber`, with no tags on any thread.
    ///
    /// Tag storage is allocated fresh here, so loggers built by separate
    /// `new` calls never see each other's tags; clones of one logger do.
    #[must_use]
    pub fn new(subscriber: tracing::Dispatch) -> Self {
        Self {
            subscriber,
            tag_stacks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the `tracing` dispatcher this logger was constructed with.
    #[must_use]
    pub fn subscriber(&self) -> &tracing::Dispatch {
        &self.subscriber
    }

    /// Appends `tags`, in order, to the calling thread's tag stack.
    ///
    /// Empty and whitespace-only tags are skipped. The remaining tags are
    /// stored verbatim, so surrounding whitespace is kept. Pushed tags stay
    /// until one of these removes them:
    /// - [`Self::pop_tags`];
    /// - [`Self::clear_tags`];
    /// - the end of an enclosing [`Self::tagged`] scope.
    ///
    /// Note: tags still on a thread's stack when that thread exits are never
    /// removed automatically. Their map entry lives as long as this logger's
    /// shared storage.
    pub fn push_tags<I, S>(&self, tags: I)
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let tags = normalize_tags(tags);
        if tags.is_empty() {
            return;
        }
        self.with_current_stack(|stack| stack.extend(tags));
    }

    /// Removes and returns the most recently pushed tag on the calling thread.
    ///
    /// Returns `None` when the thread has no tags.
    pub fn pop_tags(&self) -> Option<String> {
        self.with_current_stack(Vec::pop)
    }

    /// Removes every tag from the calling thread's stack.
    ///
    /// Other threads' stacks are untouched.
    pub fn clear_tags(&self) {
        self.tag_stacks().remove(&thread::current().id());
    }

    /// Runs `f` with `tags` pushed onto the calling thread's stack and returns
    /// its result.
    ///
    /// Blank tags are filtered exactly as in [`Self::push_tags`]. When `f`
    /// finishes, including by panicking and unwinding, a drop guard truncates
    /// the stack back to its length before this call. That also discards any
    /// tags `f` pushed itself. If `f` cleared the stack, nothing is restored.
    pub fn tagged<F, R, I, S>(&self, tags: I, f: F) -> R
    where
        F: FnOnce() -> R,
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let tags = normalize_tags(tags);
        let thread_id = thread::current().id();
        let previous_len = self.with_current_stack(|stack| {
            let previous_len = stack.len();
            stack.extend(tags);
            previous_len
        });
        // The guard is created after the lock above is released. `f` can
        // therefore use this logger freely, and the guard restores the stack
        // even if `f` panics.
        let _guard = TagScope {
            tag_stacks: Arc::clone(&self.tag_stacks),
            thread_id,
            previous_len,
        };
        f()
    }

    /// Prefixes `message` with the calling thread's tags.
    ///
    /// Each tag is rendered as `[tag]` and tags are separated by single
    /// spaces, e.g. `"[Request] [User] hello"`. The message is returned
    /// unchanged when the thread has no tags.
    #[must_use]
    pub fn format_message(&self, message: &str) -> String {
        let tag_stacks = self.tag_stacks();
        let Some(tags) = tag_stacks.get(&thread::current().id()) else {
            return message.to_owned();
        };
        // Build the output in one allocation. Each tag contributes `[`, the
        // tag text, and `] `, which is 3 extra bytes.
        let capacity = tags.iter().map(|tag| tag.len() + 3).sum::<usize>() + message.len();
        let mut formatted = String::with_capacity(capacity);
        for tag in tags {
            formatted.push('[');
            formatted.push_str(tag);
            formatted.push_str("] ");
        }
        formatted.push_str(message);
        formatted
    }

    /// Runs `f` on the calling thread's tag stack, creating the stack if it
    /// is absent, and returns `f`'s result.
    ///
    /// A stack left empty by `f` is removed, so the map only holds non-empty
    /// stacks. The lock is held while `f` runs, so `f` must not call back into
    /// methods that lock the map: `std::sync::Mutex` is not reentrant.
    fn with_current_stack<R>(&self, f: impl FnOnce(&mut Vec<String>) -> R) -> R {
        let thread_id = thread::current().id();
        let mut tag_stacks = self.tag_stacks();
        let stack = tag_stacks.entry(thread_id).or_default();
        let result = f(&mut *stack);
        // Check emptiness through the borrow we already hold instead of doing
        // a second hash lookup.
        let now_empty = stack.is_empty();
        if now_empty {
            tag_stacks.remove(&thread_id);
        }
        result
    }

    /// Locks the shared tag map, ignoring mutex poisoning.
    fn tag_stacks(&self) -> MutexGuard<'_, HashMap<ThreadId, Vec<String>>> {
        lock_tag_stacks(&self.tag_stacks)
    }
}
/// Guard built by [`TaggedLogging::tagged`].
///
/// When it is dropped, it rolls one thread's tag stack back to the length the
/// stack had before the scope's tags were pushed.
struct TagScope {
    /// Stack length to truncate back to when the guard is dropped.
    previous_len: usize,
    /// Thread whose stack this guard restores.
    thread_id: ThreadId,
    /// Shared map holding every thread's tag stack.
    tag_stacks: Arc<Mutex<HashMap<ThreadId, Vec<String>>>>,
}
impl Drop for TagScope {
    /// Truncates the owning thread's stack to `previous_len`, discarding every
    /// tag pushed since the scope began.
    ///
    /// The map entry is removed entirely once the stack is empty.
    fn drop(&mut self) {
        let mut stacks = lock_tag_stacks(&self.tag_stacks);
        let Some(stack) = stacks.get_mut(&self.thread_id) else {
            // The stack is already gone (cleared or emptied inside the
            // scope), so there is nothing to restore.
            return;
        };
        stack.truncate(self.previous_len);
        let now_empty = stack.is_empty();
        if now_empty {
            stacks.remove(&self.thread_id);
        }
    }
}
/// Converts each tag into an owned `String`, dropping tags that are empty or
/// whitespace-only.
///
/// Surviving tags are kept verbatim (not trimmed) and in their original order.
fn normalize_tags<I, S>(tags: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut kept = Vec::new();
    for raw in tags {
        let tag: String = raw.into();
        if tag.trim().is_empty() {
            continue;
        }
        kept.push(tag);
    }
    kept
}
/// Locks the shared tag-stack map and returns the guard.
///
/// Poisoning is ignored rather than propagated. A panic on one thread while it
/// holds the lock therefore does not make tagging unusable for everyone else.
fn lock_tag_stacks(
    tag_stacks: &Mutex<HashMap<ThreadId, Vec<String>>>,
) -> MutexGuard<'_, HashMap<ThreadId, Vec<String>>> {
    match tag_stacks.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
#[cfg(test)]
mod tests {
    use super::TaggedLogging;
    use std::thread;

    fn logger() -> TaggedLogging {
        TaggedLogging::new(tracing::Dispatch::none())
    }

    /// Runs `test` on a freshly spawned thread and returns its result.
    ///
    /// Any panic is re-raised, so assertion failures inside `test` still fail
    /// the calling test.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        match thread::spawn(test).join() {
            Ok(result) => result,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    #[test]
    fn format_message_returns_plain_message_without_tags() {
        run_isolated(|| {
            let logger = logger();
            assert_eq!(logger.format_message("hello"), "hello");
        });
    }

    #[test]
    fn push_tags_prepends_tags_in_order() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Request", "User"]);
            assert_eq!(logger.format_message("hello"), "[Request] [User] hello");
        });
    }

    #[test]
    fn push_tags_ignores_empty_and_whitespace_tags() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Request", "", " ", "User"]);
            assert_eq!(logger.format_message("hello"), "[Request] [User] hello");
            assert_eq!(logger.pop_tags(), Some(String::from("User")));
            assert_eq!(logger.pop_tags(), Some(String::from("Request")));
            assert_eq!(logger.pop_tags(), None);
        });
    }

    #[test]
    fn push_tags_keeps_surrounding_whitespace_of_tags() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags([" Padded "]);
            assert_eq!(logger.format_message("hello"), "[ Padded ] hello");
        });
    }

    #[test]
    fn format_message_with_empty_message_keeps_separator() {
        run_isolated(|| {
            let logger = logger();
            assert_eq!(logger.format_message(""), "");
            logger.push_tags(["Request"]);
            assert_eq!(logger.format_message(""), "[Request] ");
        });
    }

    #[test]
    fn pop_tags_removes_the_last_tag() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Request", "User"]);
            assert_eq!(logger.pop_tags(), Some(String::from("User")));
            assert_eq!(logger.format_message("hello"), "[Request] hello");
        });
    }

    #[test]
    fn clear_tags_removes_all_tags() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Request", "User"]);
            logger.clear_tags();
            assert_eq!(logger.format_message("hello"), "hello");
        });
    }

    #[test]
    fn tagged_scopes_tags_temporarily() {
        run_isolated(|| {
            let logger = logger();
            let formatted = logger.tagged(["Request"], || logger.format_message("hello"));
            assert_eq!(formatted, "[Request] hello");
            assert_eq!(logger.format_message("after"), "after");
        });
    }

    #[test]
    fn tagged_ignores_empty_and_whitespace_tags() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            logger.tagged(["", " ", "Inner"], || {
                assert_eq!(logger.format_message("hello"), "[Outer] [Inner] hello");
            });
            assert_eq!(logger.format_message("after"), "[Outer] after");
        });
    }

    #[test]
    fn tagged_with_only_empty_tags_is_a_no_op() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            logger.tagged(["", " "], || {
                assert_eq!(logger.format_message("hello"), "[Outer] hello");
            });
            assert_eq!(logger.format_message("after"), "[Outer] after");
        });
    }

    #[test]
    fn tagged_restores_previous_tags_after_scope() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            logger.tagged(["Inner"], || {
                assert_eq!(logger.format_message("hello"), "[Outer] [Inner] hello");
            });
            assert_eq!(logger.format_message("after"), "[Outer] after");
        });
    }

    #[test]
    fn tagged_discards_tags_pushed_inside_the_scope() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            logger.tagged(["Inner"], || logger.push_tags(["Leaked"]));
            assert_eq!(logger.format_message("after"), "[Outer] after");
        });
    }

    #[test]
    fn tagged_does_not_undo_clear_tags_inside_the_scope() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            logger.tagged(["Inner"], || logger.clear_tags());
            assert_eq!(logger.format_message("after"), "after");
            assert_eq!(logger.pop_tags(), None);
        });
    }

    #[test]
    fn tagged_can_be_nested() {
        run_isolated(|| {
            let logger = logger();
            logger.tagged(["Outer"], || {
                logger.tagged(["Inner"], || {
                    assert_eq!(logger.format_message("hello"), "[Outer] [Inner] hello");
                });
                assert_eq!(logger.format_message("middle"), "[Outer] middle");
            });
            assert_eq!(logger.format_message("after"), "after");
        });
    }

    #[test]
    fn tag_stack_is_isolated_per_thread() {
        let tagged_logger = logger();
        tagged_logger.push_tags(["Main"]);
        let child_message = run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Child"]);
            logger.format_message("hello")
        });
        assert_eq!(child_message, "[Child] hello");
        assert_eq!(tagged_logger.format_message("hello"), "[Main] hello");
    }

    #[test]
    fn same_logger_keeps_tags_isolated_per_thread() {
        let logger = logger();
        logger.push_tags(["Main"]);
        let child_message = {
            let logger = logger.clone();
            run_isolated(move || {
                assert_eq!(logger.format_message("before"), "before");
                logger.push_tags(["Child"]);
                let during = logger.format_message("hello");
                logger.clear_tags();
                let after = logger.format_message("after");
                (during, after)
            })
        };
        assert_eq!(
            child_message,
            (String::from("[Child] hello"), String::from("after"))
        );
        assert_eq!(logger.format_message("hello"), "[Main] hello");
    }

    #[test]
    fn clones_share_tags_on_the_same_thread() {
        run_isolated(|| {
            let logger = logger();
            let clone = logger.clone();
            logger.push_tags(["Shared"]);
            assert_eq!(clone.format_message("hello"), "[Shared] hello");
            assert_eq!(clone.pop_tags(), Some(String::from("Shared")));
            assert_eq!(logger.format_message("hello"), "hello");
        });
    }

    #[test]
    fn tags_are_isolated_per_instance_on_the_same_thread() {
        run_isolated(|| {
            let first = logger();
            let second = logger();
            first.push_tags(["First"]);
            assert_eq!(first.format_message("hello"), "[First] hello");
            assert_eq!(second.format_message("hello"), "hello");
        });
    }

    #[test]
    fn clear_tags_only_affects_the_current_instance() {
        run_isolated(|| {
            let first = logger();
            let second = logger();
            first.push_tags(["First"]);
            second.push_tags(["Second"]);
            first.clear_tags();
            assert_eq!(first.format_message("hello"), "hello");
            assert_eq!(second.format_message("hello"), "[Second] hello");
        });
    }

    #[test]
    fn nested_tagged_scopes_restore_outer_tags_in_lifo_order() {
        run_isolated(|| {
            let logger = logger();
            logger.tagged(["Outer"], || {
                assert_eq!(logger.format_message("start"), "[Outer] start");
                logger.tagged(["Middle"], || {
                    assert_eq!(logger.format_message("middle"), "[Outer] [Middle] middle");
                    logger.tagged(["Inner"], || {
                        assert_eq!(
                            logger.format_message("inner"),
                            "[Outer] [Middle] [Inner] inner"
                        );
                    });
                    assert_eq!(
                        logger.format_message("after inner"),
                        "[Outer] [Middle] after inner"
                    );
                });
                assert_eq!(
                    logger.format_message("after middle"),
                    "[Outer] after middle"
                );
            });
            assert_eq!(logger.format_message("after all"), "after all");
        });
    }

    #[test]
    fn pop_tags_returns_none_when_stack_is_empty() {
        run_isolated(|| {
            let logger = logger();
            assert_eq!(logger.pop_tags(), None);
        });
    }

    #[test]
    fn tagged_restores_state_after_panic() {
        run_isolated(|| {
            let logger = logger();
            logger.push_tags(["Outer"]);
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                logger.tagged(["Inner"], || panic!("boom"));
            }));
            assert!(result.is_err());
            assert_eq!(logger.format_message("hello"), "[Outer] hello");
        });
    }
}