1use std::cell::Cell;
2use std::error::Error;
3use std::fmt::{Display, Formatter, Result as FmtResult};
4use std::ptr;
5use std::sync::OnceLock;
6use std::sync::atomic::{AtomicPtr, Ordering};
7use std::time::{Duration, Instant};
8
9use crate::tracing::TraceContext;
10
11static SPAN_COLLECTOR_MARKER: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
12static SPAN_COLLECTOR_TOKEN: () = ();
13static GLOBAL_SPAN_COLLECTOR: OnceLock<Box<dyn SpanCollector>> = OnceLock::new();
14
15thread_local! {
16 static ACTIVE_TRACE_CONTEXT: Cell<Option<TraceContext>> = const { Cell::new(None) };
17}
18
19#[derive(Debug)]
20pub struct Span {
21 name: String,
22 context: TraceContext,
23 parent: Option<TraceContext>,
24 start: Instant,
25}
26
27impl Span {
28 #[must_use]
29 pub fn new(
30 name: impl Into<String>,
31 context: TraceContext,
32 parent: Option<TraceContext>,
33 ) -> Self {
34 Self {
35 name: name.into(),
36 context,
37 parent,
38 start: Instant::now(),
39 }
40 }
41
42 #[must_use]
43 pub fn name(&self) -> &str {
44 &self.name
45 }
46
47 #[must_use]
48 pub const fn context(&self) -> TraceContext {
49 self.context
50 }
51
52 #[must_use]
53 pub const fn parent(&self) -> Option<TraceContext> {
54 self.parent
55 }
56
57 #[must_use]
58 pub const fn start(&self) -> Instant {
59 self.start
60 }
61
62 #[must_use]
63 pub fn finish(self) -> FinishedSpan {
64 let duration = self.start.elapsed();
65 let finished = FinishedSpan::new(self.name, self.context, self.parent, duration);
66
67 if let Some(collector) = global_span_collector() {
68 collector.on_span(finished.clone());
69 }
70
71 finished
72 }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct FinishedSpan {
77 name: String,
78 context: TraceContext,
79 parent: Option<TraceContext>,
80 duration: Duration,
81}
82
83impl FinishedSpan {
84 const fn new(
85 name: String,
86 context: TraceContext,
87 parent: Option<TraceContext>,
88 duration: Duration,
89 ) -> Self {
90 Self {
91 name,
92 context,
93 parent,
94 duration,
95 }
96 }
97
98 #[must_use]
99 pub fn name(&self) -> &str {
100 &self.name
101 }
102
103 #[must_use]
104 pub const fn context(&self) -> TraceContext {
105 self.context
106 }
107
108 #[must_use]
109 pub const fn parent(&self) -> Option<TraceContext> {
110 self.parent
111 }
112
113 #[must_use]
114 pub const fn duration(&self) -> Duration {
115 self.duration
116 }
117}
118
119#[derive(Debug)]
120pub struct ConversationSpan {
121 span: Span,
122}
123
124impl ConversationSpan {
125 #[must_use]
126 pub fn new(conversation_id: impl Into<String>) -> Self {
127 Self::root(conversation_id)
128 }
129
130 #[must_use]
131 pub fn root(conversation_id: impl Into<String>) -> Self {
132 Self {
133 span: Span::new(conversation_id, TraceContext::new_root(), None),
134 }
135 }
136
137 #[must_use]
138 pub fn child(&self, conversation_id: impl Into<String>) -> Self {
139 Self::with_parent(conversation_id, self.context())
140 }
141
142 #[must_use]
143 pub fn with_parent(conversation_id: impl Into<String>, parent: TraceContext) -> Self {
144 Self {
145 span: Span::new(conversation_id, parent.child(), Some(parent)),
146 }
147 }
148
149 #[must_use]
150 pub fn name(&self) -> &str {
151 self.span.name()
152 }
153
154 #[must_use]
155 pub const fn context(&self) -> TraceContext {
156 self.span.context()
157 }
158
159 #[must_use]
160 pub const fn parent(&self) -> Option<TraceContext> {
161 self.span.parent()
162 }
163
164 #[must_use]
165 pub const fn message_context(&self) -> TraceContext {
166 self.context()
167 }
168
169 #[must_use]
170 pub fn finish(self) -> FinishedSpan {
171 self.span.finish()
172 }
173}
174
175#[derive(Debug)]
176pub struct SpanGuard {
177 span: Option<ConversationSpan>,
178 name: String,
179 context: TraceContext,
180 parent: Option<TraceContext>,
181 previous_context: Option<TraceContext>,
182 context_restored: bool,
183}
184
185impl SpanGuard {
186 #[must_use]
187 pub fn start_conversation(conversation_id: impl Into<String>) -> Self {
188 Self::new(conversation_id)
189 }
190
191 #[must_use]
192 pub fn new(conversation_id: impl Into<String>) -> Self {
193 Self::from_conversation(ConversationSpan::root(conversation_id))
194 }
195
196 #[must_use]
197 pub fn child_conversation(&self, conversation_id: impl Into<String>) -> Self {
198 Self::from_conversation(ConversationSpan::with_parent(conversation_id, self.context))
199 }
200
201 #[must_use]
202 pub fn name(&self) -> &str {
203 &self.name
204 }
205
206 #[must_use]
207 pub const fn context(&self) -> TraceContext {
208 self.context
209 }
210
211 #[must_use]
212 pub const fn parent(&self) -> Option<TraceContext> {
213 self.parent
214 }
215
216 #[must_use]
217 pub const fn message_context(&self) -> TraceContext {
218 self.context
219 }
220
221 #[must_use]
222 pub fn finish(mut self) -> FinishedSpan {
223 let finished = self.finish_active_span();
224 self.restore_context();
225 finished
226 }
227
228 fn from_conversation(conversation: ConversationSpan) -> Self {
229 let name = conversation.name().to_owned();
230 let context = conversation.context();
231 let parent = conversation.parent();
232 let previous_context = replace_current_trace_context(Some(context));
233
234 Self {
235 span: Some(conversation),
236 name,
237 context,
238 parent,
239 previous_context,
240 context_restored: false,
241 }
242 }
243
244 fn finish_active_span(&mut self) -> FinishedSpan {
245 match self.span.take() {
246 Some(span) => span.finish(),
247 None => Span::new(self.name.clone(), self.context, self.parent).finish(),
248 }
249 }
250
251 fn restore_context(&mut self) {
252 if !self.context_restored {
253 replace_current_trace_context(self.previous_context);
254 self.context_restored = true;
255 }
256 }
257}
258
259impl Drop for SpanGuard {
260 fn drop(&mut self) {
261 if self.span.is_some() {
262 drop(self.finish_active_span());
263 }
264 self.restore_context();
265 }
266}
267
268pub trait SpanCollector: std::fmt::Debug + Send + Sync + 'static {
269 fn on_span(&self, span: FinishedSpan);
270}
271
272#[derive(Debug, Clone, Copy, Default)]
273pub struct NoopCollector;
274
275impl SpanCollector for NoopCollector {
276 fn on_span(&self, span: FinishedSpan) {
277 drop(span);
278 }
279}
280
/// Reasons installing the global span collector can fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SpanCollectorInstallError {
    /// A collector has already been installed.
    AlreadyInstalled,
}

impl Display for SpanCollectorInstallError {
    /// Writes a fixed, human-readable message for the error variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let message = match *self {
            Self::AlreadyInstalled => "global span collector is already installed",
        };
        f.write_str(message)
    }
}

impl Error for SpanCollectorInstallError {}
297
298pub fn install_span_collector<Collector>(
302 collector: Collector,
303) -> Result<(), SpanCollectorInstallError>
304where
305 Collector: SpanCollector,
306{
307 install_boxed_span_collector(Box::new(collector))
308}
309
310pub fn install_boxed_span_collector(
314 collector: Box<dyn SpanCollector>,
315) -> Result<(), SpanCollectorInstallError> {
316 match GLOBAL_SPAN_COLLECTOR.set(collector) {
317 Ok(()) => {
318 SPAN_COLLECTOR_MARKER.store(
319 ptr::addr_of!(SPAN_COLLECTOR_TOKEN).cast_mut(),
320 Ordering::Release,
321 );
322 Ok(())
323 }
324 Err(_collector) => Err(SpanCollectorInstallError::AlreadyInstalled),
325 }
326}
327
328#[must_use]
329pub fn span_collector_enabled() -> bool {
330 !SPAN_COLLECTOR_MARKER.load(Ordering::Acquire).is_null()
331}
332
333#[must_use]
334pub fn global_span_collector() -> Option<&'static dyn SpanCollector> {
335 if span_collector_enabled() {
336 GLOBAL_SPAN_COLLECTOR.get().map(Box::as_ref)
337 } else {
338 None
339 }
340}
341
342#[must_use]
343pub fn current_trace_context() -> Option<TraceContext> {
344 ACTIVE_TRACE_CONTEXT.with(Cell::get)
345}
346
347fn replace_current_trace_context(context: Option<TraceContext>) -> Option<TraceContext> {
348 ACTIVE_TRACE_CONTEXT.with(|active| active.replace(context))
349}
350
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::{
        ConversationSpan, FinishedSpan, NoopCollector, Span, SpanCollector, SpanGuard,
        current_trace_context,
    };
    use crate::tracing::TraceContext;

    /// Finishing a span keeps its name and contexts and reports a duration no longer than
    /// the wall-clock time measured around it.
    #[test]
    fn span_finish_returns_finished_span() {
        let own_context = TraceContext::new_root();
        let parent_context = Some(TraceContext::new_root());
        let clock = Instant::now();

        let finished = Span::new("conversation-1", own_context, parent_context).finish();
        let upper_bound = clock.elapsed();

        assert_eq!(finished.name(), "conversation-1");
        assert_eq!(finished.context(), own_context);
        assert_eq!(finished.parent(), parent_context);
        assert!(finished.duration() <= upper_bound);
    }

    /// A child conversation shares the root's trace id but has its own span id.
    #[test]
    fn conversation_span_creates_root_and_child_contexts() {
        let root = ConversationSpan::new("parent");
        let nested = root.child("child");
        let (root_context, nested_context) = (root.context(), nested.context());

        assert_eq!(root.name(), "parent");
        assert_eq!(root.parent(), None);
        assert_eq!(root.message_context(), root_context);
        assert_eq!(nested.parent(), Some(root_context));
        assert_eq!(nested_context.trace_id(), root_context.trace_id());
        assert_ne!(nested_context.span_id(), root_context.span_id());
        assert_eq!(nested.message_context(), nested_context);
    }

    /// Guards install their context while alive and restore the previous one on drop.
    #[test]
    fn span_guard_sets_and_restores_current_context() {
        assert_eq!(current_trace_context(), None);

        let outer = SpanGuard::start_conversation("root");
        let outer_context = outer.context();
        assert_eq!(current_trace_context(), Some(outer_context));

        {
            let inner = outer.child_conversation("child");
            assert_eq!(inner.parent(), Some(outer_context));
            assert_eq!(current_trace_context(), Some(inner.context()));
        }

        assert_eq!(current_trace_context(), Some(outer_context));
        drop(outer);
        assert_eq!(current_trace_context(), None);
    }

    /// The no-op collector accepts a finished span without doing anything observable.
    #[test]
    fn noop_collector_discards_spans() {
        let finished = Span::new("discard", TraceContext::new_root(), None).finish();

        NoopCollector.on_span(finished);
    }

    /// Compile-time check that `FinishedSpan` implements `Clone` and `Debug`.
    #[test]
    fn finished_span_is_clone_debug() {
        fn assert_clone_debug<T>()
        where
            T: Clone + std::fmt::Debug,
        {
        }

        assert_clone_debug::<FinishedSpan>();
    }

    /// The recorded start instant lies between clock readings taken around construction.
    #[test]
    fn span_start_is_now() {
        let lower = Instant::now();
        let span = Span::new("timed", TraceContext::new_root(), None);
        let upper = Instant::now() + Duration::from_millis(1);

        let started = span.start();
        assert!(lower <= started);
        assert!(started <= upper);
    }
}