mm1_core/
tracing.rs

1use std::fmt;
2
3use rand::RngCore;
4use tokio::task_local;
5use tracing::{Instrument, Level, Span, span};
6
// Task-local slot holding the trace id of the currently running scope.
// It is populated by `TraceId::scope_async` (via `LocalKey::scope`) and by
// `TraceId::scope_sync` (via `LocalKey::sync_scope`), and read by
// `TraceId::current`, which falls back to `TraceId::default()` when unset.
task_local! {
    static TRACE_ID: TraceId;
}
10
11#[derive(
12    Default,
13    Debug,
14    Clone,
15    Copy,
16    PartialEq,
17    Eq,
18    PartialOrd,
19    Ord,
20    Hash,
21    serde::Serialize,
22    serde::Deserialize,
23)]
24pub struct TraceId(u64);
25
26pub trait WithTraceIdExt: Future + Sized {
27    fn with_trace_id(self, trace_id: TraceId) -> impl Future<Output = Self::Output> {
28        trace_id.scope_async(self)
29    }
30}
31
32impl<F> WithTraceIdExt for F where F: Future + Sized {}
33
34impl TraceId {
35    pub fn random() -> Self {
36        rand::rng().next_u64().into()
37    }
38
39    pub fn current() -> Self {
40        TRACE_ID.try_get().ok().unwrap_or_default()
41    }
42
43    pub fn scope_async<F>(self, fut: F) -> impl Future<Output = F::Output>
44    where
45        F: Future,
46    {
47        TRACE_ID.scope(self, fut.instrument(self.span()))
48    }
49
50    pub fn scope_sync<F, R>(self, func: F) -> R
51    where
52        F: FnOnce() -> R,
53    {
54        TRACE_ID.sync_scope(self, || self.span().in_scope(func))
55    }
56
57    pub fn span(self) -> Span {
58        span!(
59            Level::TRACE,
60            "trace-id-async-scope",
61            trace_id = tracing::field::display(self)
62        )
63    }
64}
65
66impl From<u64> for TraceId {
67    fn from(value: u64) -> Self {
68        Self(value)
69    }
70}
71impl From<TraceId> for u64 {
72    fn from(value: TraceId) -> Self {
73        let TraceId(value) = value;
74        value
75    }
76}
77
78impl fmt::Display for TraceId {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(f, "T#{}", self.0)
81    }
82}
83
#[cfg(test)]
mod tests {
    use crate::tracing::{TraceId, WithTraceIdExt};

    #[tokio::test]
    async fn scopes() {
        // Outside any scope `current` falls back to the default id.
        let t_0 = TraceId::current();
        assert_eq!(t_0, TraceId::default());

        let t_1 = TraceId::random();
        assert_ne!(t_0, t_1);

        assert_eq!(t_0, t_0.scope_async(async { TraceId::current() }).await);
        assert_eq!(t_0, t_0.scope_sync(TraceId::current));

        assert_eq!(t_1, t_1.scope_async(async { TraceId::current() }).await);
        assert_eq!(t_1, t_1.scope_sync(TraceId::current));

        // Leaving a scope must restore the previous (here: absent) id.
        assert_eq!(t_0, TraceId::current());
    }

    #[tokio::test]
    async fn nested_scopes() {
        // Fixed ids keep the test deterministic.
        let outer = TraceId::from(1u64);
        let inner = TraceId::from(2u64);

        let (before, nested_async, nested_sync, after) = outer
            .scope_async(async move {
                let before = TraceId::current();
                let nested_async = inner.scope_async(async { TraceId::current() }).await;
                let nested_sync = inner.scope_sync(TraceId::current);
                (before, nested_async, nested_sync, TraceId::current())
            })
            .await;

        assert_eq!(before, outer);
        assert_eq!(nested_async, inner);
        assert_eq!(nested_sync, inner);
        // Inner scopes must hand the outer id back once they finish.
        assert_eq!(after, outer);
        assert_eq!(TraceId::default(), TraceId::current());
    }

    #[tokio::test]
    async fn with_trace_id_ext() {
        let id = TraceId::from(42u64);
        assert_eq!(id, async { TraceId::current() }.with_trace_id(id).await);
        assert_eq!(TraceId::default(), TraceId::current());
    }
}
101}