1use std::fmt;
2
3use rand::RngCore;
4use tokio::task_local;
5use tracing::{Instrument, Level, Span, span};
6
// Task-local slot holding the trace id of the currently running scope.
// It is installed by `TRACE_ID.scope` (async) / `TRACE_ID.sync_scope` (sync)
// and read back with `TRACE_ID.try_get()`; outside any such scope the read
// fails and callers fall back to `TraceId::default()`.
task_local! {
    static TRACE_ID: TraceId;
}
10
11#[derive(
12 Default,
13 Debug,
14 Clone,
15 Copy,
16 PartialEq,
17 Eq,
18 PartialOrd,
19 Ord,
20 Hash,
21 serde::Serialize,
22 serde::Deserialize,
23)]
24pub struct TraceId(u64);
25
26pub trait WithTraceIdExt: Future + Sized {
27 fn with_trace_id(self, trace_id: TraceId) -> impl Future<Output = Self::Output> {
28 trace_id.scope_async(self)
29 }
30}
31
32impl<F> WithTraceIdExt for F where F: Future + Sized {}
33
34impl TraceId {
35 pub fn random() -> Self {
36 rand::rng().next_u64().into()
37 }
38
39 pub fn current() -> Self {
40 TRACE_ID.try_get().ok().unwrap_or_default()
41 }
42
43 pub fn scope_async<F>(self, fut: F) -> impl Future<Output = F::Output>
44 where
45 F: Future,
46 {
47 TRACE_ID.scope(self, fut.instrument(self.span()))
48 }
49
50 pub fn scope_sync<F, R>(self, func: F) -> R
51 where
52 F: FnOnce() -> R,
53 {
54 TRACE_ID.sync_scope(self, || self.span().in_scope(func))
55 }
56
57 pub fn span(self) -> Span {
58 span!(
59 Level::TRACE,
60 "trace-id-async-scope",
61 trace_id = tracing::field::display(self)
62 )
63 }
64}
65
66impl From<u64> for TraceId {
67 fn from(value: u64) -> Self {
68 Self(value)
69 }
70}
71impl From<TraceId> for u64 {
72 fn from(value: TraceId) -> Self {
73 let TraceId(value) = value;
74 value
75 }
76}
77
78impl fmt::Display for TraceId {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "T#{}", self.0)
81 }
82}
83
#[cfg(test)]
mod tests {
    use crate::tracing::TraceId;

    /// A scope's id is visible via `TraceId::current` from both async and
    /// sync code.
    #[tokio::test]
    async fn scopes() {
        let t_0 = TraceId::current();
        let t_1 = TraceId::random();

        assert_ne!(t_0, t_1);

        assert_eq!(t_0, t_0.scope_async(async { TraceId::current() }).await);
        assert_eq!(t_0, t_0.scope_sync(TraceId::current));

        assert_eq!(t_1, t_1.scope_async(async { TraceId::current() }).await);
        assert_eq!(t_1, t_1.scope_sync(TraceId::current));
    }

    /// Outside every scope, `current` falls back to the default id.
    #[tokio::test]
    async fn current_defaults_outside_scope() {
        assert_eq!(TraceId::current(), TraceId::default());
    }

    /// An inner scope shadows the outer one, and leaving a scope restores the
    /// previously visible id (or the default once no scope remains). Checked
    /// for both the async and the sync variant.
    #[tokio::test]
    async fn nested_scopes_shadow_and_restore() {
        let outer = TraceId::from(1_u64);
        let inner = TraceId::from(2_u64);

        let (seen_inside, seen_after) = outer
            .scope_async(async move {
                let seen_inside = inner.scope_async(async { TraceId::current() }).await;
                (seen_inside, TraceId::current())
            })
            .await;
        assert_eq!(seen_inside, inner);
        assert_eq!(seen_after, outer);
        assert_eq!(TraceId::current(), TraceId::default());

        let (seen_inside, seen_after) = outer.scope_sync(|| {
            let seen_inside = inner.scope_sync(TraceId::current);
            (seen_inside, TraceId::current())
        });
        assert_eq!(seen_inside, inner);
        assert_eq!(seen_after, outer);
        assert_eq!(TraceId::current(), TraceId::default());
    }

    /// Round-trips through `u64` and pins the `Display` format.
    #[test]
    fn conversions_and_display() {
        assert_eq!(u64::from(TraceId::from(42_u64)), 42);
        assert_eq!(TraceId::from(7_u64).to_string(), "T#7");
        assert_eq!(TraceId::default(), TraceId::from(0_u64));
    }
}