use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tracing::field::{Field, Visit};
use tracing::span::{Attributes, Record};
use tracing::{Event, Id, Level, Metadata, Subscriber};
use super::trace::{TraceRecord, TraceSnapshot};
/// [`Subscriber`] that forwards span lifecycle callbacks to a [`TraceRecorder`].
pub struct ObservingSubscriber {
    recorder: TraceRecorder,
}

impl ObservingSubscriber {
    /// Builds a subscriber together with a recorder handle.
    ///
    /// The returned `TraceRecorder` is a clone of the subscriber's own recorder,
    /// so both share the same underlying state.
    #[inline]
    pub fn new() -> (Self, TraceRecorder) {
        let shared = TraceRecorder::new();
        let subscriber = Self {
            recorder: shared.clone(),
        };
        (subscriber, shared)
    }
}
/// Routes span lifecycle callbacks to the shared [`TraceRecorder`].
///
/// Events and `follows_from` relationships are accepted but discarded.
impl Subscriber for ObservingSubscriber {
    /// Enables callsites at `INFO` verbosity or less verbose (`WARN`, `ERROR`);
    /// `tracing` orders more verbose levels as "greater".
    fn enabled(&self, metadata: &Metadata<'_>) -> bool {
        *metadata.level() <= Level::INFO
    }

    /// Allocates the span's id and captures its initial fields.
    fn new_span(&self, attrs: &Attributes<'_>) -> Id {
        self.recorder.new_span(attrs)
    }

    /// Merges values recorded after span creation into the stored span.
    fn record(&self, span: &Id, values: &Record<'_>) {
        self.recorder.record(span, values)
    }

    /// Tracks entry onto the current thread's span stack.
    fn enter(&self, span: &Id) {
        self.recorder.enter(span)
    }

    /// Tracks exit and emits a completed record.
    fn exit(&self, span: &Id) {
        self.recorder.exit(span)
    }

    /// Events are not recorded.
    fn event(&self, _event: &Event<'_>) {}

    /// Causal links between spans are not recorded.
    fn record_follows_from(&self, _span: &Id, _follows: &Id) {}
}
/// Cloneable handle to the span data collected by an [`ObservingSubscriber`].
///
/// All clones share one `Arc<Mutex<RecorderState>>`, so any clone can take a
/// [`snapshot`](TraceRecorder::snapshot) of what the subscriber has recorded.
///
/// Nothing in this type removes entries from the per-span map, so memory use
/// grows with the number of spans created over the recorder's lifetime.
#[derive(Clone)]
pub struct TraceRecorder {
    state: Arc<Mutex<RecorderState>>,
}

impl TraceRecorder {
    /// Creates an empty recorder; all `*_us` timestamps are measured from now.
    fn new() -> Self {
        Self {
            state: Arc::new(Mutex::new(RecorderState {
                next_id: 1,
                base: Instant::now(),
                spans: HashMap::new(),
                stacks: HashMap::new(),
                thread_numbers: HashMap::new(),
                next_thread_number: 1,
                completed: Vec::new(),
            })),
        }
    }

    /// Locks the shared state, recovering the guard if the mutex was poisoned.
    ///
    /// Field values are formatted before the lock is taken, so while it is held
    /// this impl only performs map/vector updates and clones of owned strings.
    /// Recovering the data is preferable to silently dropping trace data, or,
    /// in `new_span`, to having no valid `Id` to return.
    fn lock(&self) -> std::sync::MutexGuard<'_, RecorderState> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns a copy of every record completed so far, in the order the exits
    /// were observed.
    #[inline]
    pub fn snapshot(&self) -> TraceSnapshot {
        TraceSnapshot {
            records: self.lock().completed.clone(),
        }
    }

    /// Registers a new span, capturing its name and any `layer_id`, `op_id` or
    /// `verdict` fields supplied at creation.
    ///
    /// Ids start at 1 and only grow (saturating), so the returned `Id` is never
    /// zero, which `Id::from_u64` requires.
    fn new_span(&self, attrs: &Attributes<'_>) -> Id {
        // Visit fields before locking: this runs arbitrary `Debug`/`Display`
        // impls, which must not panic or re-enter tracing while the lock is held.
        let mut fields = FieldMap::default();
        attrs.record(&mut fields);

        let mut state = self.lock();
        let id = state.next_id;
        state.next_id = state.next_id.saturating_add(1);
        state.spans.insert(
            id,
            SpanState {
                name: attrs.metadata().name().to_string(),
                layer_id: fields.layer_id,
                op_id: fields.op_id,
                verdict: fields.verdict,
                start_us: None,
                duration_us: None,
                thread_id: None,
            },
        );
        Id::from_u64(id)
    }

    /// Overwrites the span's `layer_id`, `op_id` or `verdict` with any of those
    /// fields present in `values`; absent fields keep their previous value.
    /// Unknown span ids are ignored.
    fn record(&self, span: &Id, values: &Record<'_>) {
        // Same reasoning as `new_span`: format user values outside the lock.
        let mut fields = FieldMap::default();
        values.record(&mut fields);

        let mut state = self.lock();
        let Some(span_state) = state.spans.get_mut(&span.into_u64()) else {
            return;
        };
        if let Some(layer_id) = fields.layer_id {
            span_state.layer_id = Some(layer_id);
        }
        if let Some(op_id) = fields.op_id {
            span_state.op_id = Some(op_id);
        }
        if let Some(verdict) = fields.verdict {
            span_state.verdict = Some(verdict);
        }
    }

    /// Marks `span` as entered on the current thread.
    ///
    /// Only the first entry fixes the span's start time and thread number; every
    /// entry pushes the span id onto the current thread's stack.
    fn enter(&self, span: &Id) {
        let thread_id = std::thread::current().id();
        let mut state = self.lock();
        let thread_number = state.thread_number(thread_id);
        let start_us = state.base.elapsed().as_micros();
        let raw_id = span.into_u64();
        if let Some(span_state) = state.spans.get_mut(&raw_id) {
            if span_state.start_us.is_none() {
                span_state.start_us = Some(start_us);
                span_state.thread_id = Some(thread_number);
            }
        }
        state.stacks.entry(thread_id).or_default().push(raw_id);
    }

    /// Marks `span` as exited on the current thread and appends a
    /// [`TraceRecord`] for it.
    ///
    /// A record is appended on every exit, so a span entered several times
    /// produces several records. Each record's duration is measured from the
    /// span's first entry.
    fn exit(&self, span: &Id) {
        let thread_id = std::thread::current().id();
        let mut guard = self.lock();
        // Reborrow through the guard once so disjoint fields can be borrowed
        // independently below.
        let state = &mut *guard;
        let end_us = state.base.elapsed().as_micros();
        let raw_id = span.into_u64();

        // Labels of the spans on this thread's stack in entry order; this
        // normally includes `span` itself, since it has not been popped yet.
        let stack: Vec<String> = state
            .stacks
            .get(&thread_id)
            .map(|ids| {
                ids.iter()
                    .filter_map(|id| state.spans.get(id).map(span_stack_name))
                    .collect()
            })
            .unwrap_or_default();

        if let Some(span_state) = state.spans.get_mut(&raw_id) {
            // A span exited without a recorded entry gets a zero duration.
            let start_us = span_state.start_us.unwrap_or(end_us);
            let duration_us = end_us.saturating_sub(start_us);
            span_state.duration_us = Some(duration_us);
            state.completed.push(TraceRecord {
                name: span_state.name.clone(),
                layer_id: span_state.layer_id.clone(),
                op_id: span_state.op_id.clone(),
                verdict: span_state.verdict.clone(),
                start_us,
                duration_us,
                stack,
                // 0 marks a span that was never entered.
                thread_id: span_state.thread_id.unwrap_or(0),
            });
        }

        // Remove the innermost occurrence of this span. That is normally the
        // top of the stack, but out-of-order exits are tolerated.
        if let Some(ids) = state.stacks.get_mut(&thread_id) {
            if let Some(position) = ids.iter().rposition(|id| *id == raw_id) {
                ids.remove(position);
            }
        }
    }
}
/// Mutable bookkeeping shared by all clones of a [`TraceRecorder`].
struct RecorderState {
    /// Raw id handed to the next span that is created.
    next_id: u64,
    /// Reference instant; `*_us` timestamps are microseconds elapsed since it.
    base: Instant,
    /// Per-span data keyed by raw span id.
    spans: HashMap<u64, SpanState>,
    /// Raw ids of the spans currently entered on each thread, in entry order.
    stacks: HashMap<std::thread::ThreadId, Vec<u64>>,
    /// Sequential numbers assigned to threads in order of first use.
    thread_numbers: HashMap<std::thread::ThreadId, u64>,
    /// Number the next previously-unseen thread will receive.
    next_thread_number: u64,
    /// Records appended on span exit, in the order they were produced.
    completed: Vec<TraceRecord>,
}

impl RecorderState {
    /// Returns the stable sequential number for `thread_id`.
    ///
    /// If this thread has not been seen before, it is assigned the next number.
    /// The counter saturates at `u64::MAX`.
    fn thread_number(&mut self, thread_id: std::thread::ThreadId) -> u64 {
        let counter = &mut self.next_thread_number;
        *self.thread_numbers.entry(thread_id).or_insert_with(|| {
            let assigned = *counter;
            *counter = counter.saturating_add(1);
            assigned
        })
    }
}
/// Everything the recorder tracks for one span (stored keyed by raw span id).
struct SpanState {
    /// Static span name from the callsite metadata; fallback stack label.
    name: String,
    /// Sequential number of the thread that first entered the span.
    thread_id: Option<u64>,
    /// Microseconds from the recorder's base instant to the span's first entry.
    start_us: Option<u128>,
    /// Duration in microseconds computed at the most recent exit.
    duration_us: Option<u128>,
    /// Latest recorded `layer_id` field value, if any.
    layer_id: Option<String>,
    /// Latest recorded `op_id` field value, if any.
    op_id: Option<String>,
    /// Latest recorded `verdict` field value, if any.
    verdict: Option<String>,
}
/// Field visitor that keeps only the `layer_id`, `op_id` and `verdict` values.
///
/// Values are stored as strings:
/// - `&str` values are stored verbatim.
/// - `bool` and integer values are stored via `to_string`.
/// - Anything routed through `record_debug` is stored via `Debug` formatting.
///
/// Fields with any other name are ignored. A repeated field keeps its last value.
#[derive(Default)]
struct FieldMap {
    layer_id: Option<String>,
    op_id: Option<String>,
    verdict: Option<String>,
}

impl Visit for FieldMap {
    fn record_debug(&mut self, field: &Field, value: &dyn core::fmt::Debug) {
        let rendered = format!("{value:?}");
        self.record_value(field, rendered);
    }

    fn record_str(&mut self, field: &Field, value: &str) {
        self.record_value(field, value.to_owned());
    }

    fn record_bool(&mut self, field: &Field, value: bool) {
        self.record_value(field, value.to_string());
    }

    fn record_i64(&mut self, field: &Field, value: i64) {
        self.record_value(field, value.to_string());
    }

    fn record_u64(&mut self, field: &Field, value: u64) {
        self.record_value(field, value.to_string());
    }
}

impl FieldMap {
    /// Stores `value` in the slot named by `field`, or drops it if the name is
    /// not one this map tracks.
    fn record_value(&mut self, field: &Field, value: String) {
        let slot = match field.name() {
            "layer_id" => &mut self.layer_id,
            "op_id" => &mut self.op_id,
            "verdict" => &mut self.verdict,
            _ => return,
        };
        *slot = Some(value);
    }
}
/// Label used for a span inside a record's `stack`.
///
/// Produces `layer:op` when both are known, otherwise whichever one is
/// present, and finally falls back to the span's static name.
fn span_stack_name(span: &SpanState) -> String {
    match (span.layer_id.as_deref(), span.op_id.as_deref()) {
        (Some(layer), Some(op)) => format!("{layer}:{op}"),
        (Some(only), None) | (None, Some(only)) => only.to_owned(),
        (None, None) => span.name.clone(),
    }
}