use std::collections::BTreeMap;
use std::future::Future;
use std::sync::{Arc, RwLock};
use serde::Serialize;
use tracing::Instrument as _;
use crate::log::filter::{FILTERED_PLACEHOLDER, ParameterFilter};
// Task-local slot holding the log context of the currently running task. It is installed by
// `scope`, `scoped` and `sync_scope`. Tasks started with `tokio::spawn` do not see it unless
// the future is wrapped with `in_current_context`.
tokio::task_local! {
    static CURRENT: LogContext;
}
/// Keys that `LogContext::insert_field` refuses to store as custom fields.
///
/// Each one names a dedicated correlation-id field of `LogFields`. Because custom fields are
/// serialized flat next to those ids, a custom entry with one of these names would collide
/// with them.
const RESERVED_FIELD_KEYS: &[&str] = &["request_id", "user_id", "tenant_id"];
/// Plain-data copy of a request's correlation ids and custom log fields, as returned by
/// `LogContext::snapshot`.
///
/// Serializes as a single flat object:
/// - each id appears under its own key only when it is `Some`;
/// - every entry of `fields` is emitted as a top-level key next to the ids (via
///   `#[serde(flatten)]`), in the `BTreeMap`'s sorted key order.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub struct LogFields {
    /// Request correlation id; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// User id, if one was set on the context; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Tenant id, if one was set on the context; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Custom key/value pairs, serialized as top-level keys.
    ///
    /// Entries added through `LogContext::insert_field` never use the reserved id keys and
    /// have sensitive values replaced by the filter placeholder.
    #[serde(flatten)]
    pub fields: BTreeMap<String, String>,
}
impl LogFields {
    /// Reports whether nothing at all has been captured: no correlation id is set and there
    /// are no custom fields.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let has_any_id =
            self.request_id.is_some() || self.user_id.is_some() || self.tenant_id.is_some();
        !has_any_id && self.fields.is_empty()
    }
}
/// Per-request logging context: correlation ids, custom fields, a key filter and a span.
///
/// Cloning is cheap. It clones `Arc`s and the span handle, and every clone refers to the
/// *same* mutable state, so updates made through any clone are visible to all of them.
#[derive(Clone)]
pub struct LogContext {
    // Correlation ids and custom fields, shared between clones.
    inner: Arc<RwLock<Inner>>,
    // Decides which custom-field keys get their value replaced by the placeholder.
    filter: Arc<ParameterFilter>,
    // Span that user/tenant ids are also recorded onto; `Span::none()` unless set via
    // `with_span`.
    span: tracing::Span,
}
/// Mutable state behind a `LogContext`, guarded by its `RwLock`.
#[derive(Default)]
struct Inner {
    // Seeded at construction.
    request_id: Option<String>,
    // Set later through the context's setters.
    user_id: Option<String>,
    tenant_id: Option<String>,
    // Custom fields. When written via `insert_field` they are already filtered and contain
    // no reserved keys.
    fields: BTreeMap<String, String>,
}
impl LogContext {
#[must_use]
pub fn new(request_id: Option<String>) -> Self {
Self::with_filter(request_id, Arc::new(ParameterFilter::default()))
}
#[must_use]
pub fn with_filter(request_id: Option<String>, filter: Arc<ParameterFilter>) -> Self {
Self {
inner: Arc::new(RwLock::new(Inner {
request_id,
..Inner::default()
})),
filter,
span: tracing::Span::none(),
}
}
#[must_use]
pub fn with_span(mut self, span: tracing::Span) -> Self {
self.span = span;
self
}
pub fn set_user_id(&self, user_id: impl Into<String>) {
let user_id = user_id.into();
self.span
.record("user_id", tracing::field::display(&user_id));
if let Ok(mut guard) = self.inner.write() {
guard.user_id = Some(user_id);
}
}
pub fn set_tenant_id(&self, tenant_id: impl Into<String>) {
let tenant_id = tenant_id.into();
self.span
.record("tenant_id", tracing::field::display(&tenant_id));
if let Ok(mut guard) = self.inner.write() {
guard.tenant_id = Some(tenant_id);
}
}
pub fn insert_field(&self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
if RESERVED_FIELD_KEYS.contains(&key.as_str()) {
return;
}
let value = if self.filter.matches_key(&key) {
FILTERED_PLACEHOLDER.to_owned()
} else {
value.into()
};
if let Ok(mut guard) = self.inner.write() {
guard.fields.insert(key, value);
}
}
#[must_use]
pub fn span(&self) -> tracing::Span {
self.span.clone()
}
#[must_use]
pub fn snapshot(&self) -> LogFields {
self.inner.read().map_or_else(
|_| LogFields::default(),
|guard| LogFields {
request_id: guard.request_id.clone(),
user_id: guard.user_id.clone(),
tenant_id: guard.tenant_id.clone(),
fields: guard.fields.clone(),
},
)
}
}
impl std::fmt::Debug for LogContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the snapshotted ids and custom fields are shown; the filter and span are not.
        let fields = self.snapshot();
        let mut builder = f.debug_struct("LogContext");
        builder.field("fields", &fields);
        builder.finish()
    }
}
pub async fn scope<F: Future>(ctx: LogContext, future: F) -> F::Output {
CURRENT.scope(ctx, future).await
}
/// Returns a future that runs `future` with `ctx` installed as the task-local log context.
///
/// Unlike [`scope`], the returned future has a nameable concrete type (`TaskLocalFuture`).
/// That is useful when it must be stored in a struct or written out in a signature.
pub fn scoped<F: Future>(
    ctx: LogContext,
    future: F,
) -> tokio::task::futures::TaskLocalFuture<LogContext, F> {
    CURRENT.scope(ctx, future)
}
/// Runs the synchronous closure `f` with `ctx` installed as the log context and returns the
/// closure's result.
///
/// This is the non-async counterpart of [`scope`]; no future or runtime is involved.
pub fn sync_scope<R>(ctx: LogContext, f: impl FnOnce() -> R) -> R {
    CURRENT.sync_scope(ctx, f)
}
/// Returns a handle to the log context installed for the current task.
///
/// Returns `None` when no context is installed, e.g. outside [`scope`], [`scoped`] and
/// [`sync_scope`].
///
/// The handle is a clone of the installed [`LogContext`]. Clones share the same underlying
/// state, so updates made through it are visible everywhere else in the request.
#[must_use]
pub fn current() -> Option<LogContext> {
    let installed = CURRENT.try_with(|ctx| ctx.clone());
    installed.ok()
}
#[must_use]
pub fn snapshot() -> Option<LogFields> {
current().map(|ctx| ctx.snapshot())
}
/// Adds a custom field to the current task's log context.
///
/// Does nothing when no context is installed. Reserved keys are ignored and sensitive keys
/// are scrubbed, as in [`LogContext::insert_field`].
pub fn with_log_field(key: impl Into<String>, value: impl Into<String>) {
    let Some(ctx) = current() else {
        return;
    };
    ctx.insert_field(key, value);
}
/// Sets the user id on the current task's log context and records it on the context's span.
///
/// Does nothing when no context is installed.
pub fn set_user_id(user_id: impl Into<String>) {
    let Some(ctx) = current() else {
        return;
    };
    ctx.set_user_id(user_id);
}
/// Sets the tenant id on the current task's log context and records it on the context's span.
///
/// Does nothing when no context is installed.
pub fn set_tenant_id(tenant_id: impl Into<String>) {
    let Some(ctx) = current() else {
        return;
    };
    ctx.set_tenant_id(tenant_id);
}
/// Wraps `future` so that it runs with the caller's log context installed and inside that
/// context's span.
///
/// The context is captured when this function is *called*, not when the returned future is
/// first polled. That is what makes it suitable for `tokio::spawn`: a spawned task does not
/// inherit task-locals, so the capture must happen on the spawning task.
///
/// If no context is installed at call time, `future` is awaited unchanged.
pub fn in_current_context<F: Future>(future: F) -> impl Future<Output = F::Output> {
    let captured = current();
    async move {
        if let Some(ctx) = captured {
            let span = ctx.span();
            return scope(ctx, future.instrument(span)).await;
        }
        future.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn seeds_request_id_and_exposes_it_via_current() {
        let ctx = LogContext::new(Some("req-123".to_owned()));
        scope(ctx, async {
            let snap = snapshot().expect("context should be installed");
            assert_eq!(snap.request_id.as_deref(), Some("req-123"));
            assert_eq!(snap.user_id, None);
        })
        .await;
    }

    #[tokio::test]
    async fn user_and_tenant_are_added_to_current_context() {
        let ctx = LogContext::new(Some("req-1".to_owned()));
        scope(ctx, async {
            set_user_id("42");
            set_tenant_id("acme");
            let snap = snapshot().unwrap();
            assert_eq!(snap.user_id.as_deref(), Some("42"));
            assert_eq!(snap.tenant_id.as_deref(), Some("acme"));
        })
        .await;
    }

    #[tokio::test]
    async fn custom_fields_appear_on_subsequent_snapshots() {
        let ctx = LogContext::new(Some("req-1".to_owned()));
        scope(ctx, async {
            with_log_field("order_id", "A-1001");
            let snap = snapshot().unwrap();
            assert_eq!(
                snap.fields.get("order_id").map(String::as_str),
                Some("A-1001")
            );
        })
        .await;
    }

    #[tokio::test]
    async fn sensitive_custom_fields_are_scrubbed() {
        let ctx = LogContext::new(Some("req-1".to_owned()));
        scope(ctx, async {
            with_log_field("password", "hunter2");
            with_log_field("order_id", "ok");
            let snap = snapshot().unwrap();
            assert_eq!(
                snap.fields.get("password").map(String::as_str),
                Some(FILTERED_PLACEHOLDER)
            );
            assert_eq!(snap.fields.get("order_id").map(String::as_str), Some("ok"));
        })
        .await;
    }

    #[tokio::test]
    async fn custom_fields_cannot_shadow_core_correlation_ids() {
        let ctx = LogContext::new(Some("real-req".to_owned()));
        scope(ctx, async {
            set_user_id("real-user");
            with_log_field("request_id", "spoofed");
            with_log_field("user_id", "spoofed");
            with_log_field("tenant_id", "spoofed");
            with_log_field("order_id", "kept");
            let snap = snapshot().unwrap();
            assert_eq!(snap.request_id.as_deref(), Some("real-req"));
            assert_eq!(snap.user_id.as_deref(), Some("real-user"));
            assert!(!snap.fields.contains_key("request_id"));
            assert!(!snap.fields.contains_key("user_id"));
            assert!(!snap.fields.contains_key("tenant_id"));
            assert_eq!(
                snap.fields.get("order_id").map(String::as_str),
                Some("kept")
            );
            let v = serde_json::to_value(&snap).unwrap();
            assert_eq!(v["request_id"], "real-req");
        })
        .await;
    }

    #[tokio::test]
    async fn no_context_outside_a_request() {
        assert!(current().is_none());
        assert!(snapshot().is_none());
        // Without an installed context the setters must be silent no-ops: they must not panic
        // and must not implicitly create a context.
        with_log_field("k", "v");
        set_user_id("u");
        set_tenant_id("t");
        assert!(current().is_none());
    }

    #[tokio::test]
    async fn contexts_do_not_leak_between_requests() {
        let first = LogContext::new(Some("req-A".to_owned()));
        scope(first, async {
            with_log_field("k", "from-A");
        })
        .await;
        let second = LogContext::new(Some("req-B".to_owned()));
        scope(second, async {
            let snap = snapshot().unwrap();
            assert_eq!(snap.request_id.as_deref(), Some("req-B"));
            assert!(
                snap.fields.is_empty(),
                "fields from request A leaked into B"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn spawned_task_does_not_inherit_context_unless_propagated() {
        let ctx = LogContext::new(Some("req-1".to_owned()));
        scope(ctx, async {
            let bare = tokio::spawn(async { current().is_some() }).await.unwrap();
            assert!(!bare, "spawned task silently inherited request context");
            let propagated = tokio::spawn(in_current_context(async {
                snapshot().and_then(|s| s.request_id)
            }))
            .await
            .unwrap();
            assert_eq!(propagated.as_deref(), Some("req-1"));
        })
        .await;
    }

    #[test]
    fn sync_scope_installs_context_only_for_the_closure() {
        let ctx = LogContext::new(Some("req-sync".to_owned()));
        let seen = sync_scope(ctx, || snapshot().and_then(|s| s.request_id));
        assert_eq!(seen.as_deref(), Some("req-sync"));
        assert!(current().is_none());
    }

    #[test]
    fn clones_share_state() {
        let ctx = LogContext::new(None);
        let handle = ctx.clone();
        handle.set_tenant_id("acme");
        handle.insert_field("order_id", "A-7");
        let snap = ctx.snapshot();
        assert_eq!(snap.tenant_id.as_deref(), Some("acme"));
        assert_eq!(snap.fields.get("order_id").map(String::as_str), Some("A-7"));
    }

    #[test]
    fn debug_output_shows_snapshot_fields() {
        let ctx = LogContext::new(Some("req-dbg".to_owned()));
        let out = format!("{ctx:?}");
        assert!(out.starts_with("LogContext"));
        assert!(out.contains("req-dbg"));
    }

    #[test]
    fn log_fields_is_empty_tracks_every_field() {
        assert!(LogFields::default().is_empty());
        let with_request = LogFields {
            request_id: Some("r".to_owned()),
            ..LogFields::default()
        };
        assert!(!with_request.is_empty());
        let with_user = LogFields {
            user_id: Some("u".to_owned()),
            ..LogFields::default()
        };
        assert!(!with_user.is_empty());
        let with_tenant = LogFields {
            tenant_id: Some("t".to_owned()),
            ..LogFields::default()
        };
        assert!(!with_tenant.is_empty());
        let mut with_field = LogFields::default();
        with_field
            .fields
            .insert("k".to_owned(), "v".to_owned());
        assert!(!with_field.is_empty());
    }

    #[test]
    fn log_fields_serialize_flat() {
        let mut fields = BTreeMap::new();
        fields.insert("order_id".to_owned(), "A-1".to_owned());
        let f = LogFields {
            request_id: Some("r".to_owned()),
            user_id: Some("42".to_owned()),
            tenant_id: None,
            fields,
        };
        let v = serde_json::to_value(&f).unwrap();
        assert_eq!(v["request_id"], "r");
        assert_eq!(v["user_id"], "42");
        assert_eq!(v["order_id"], "A-1");
        assert!(v.get("tenant_id").is_none());
    }
}