vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Minimal tracing subscriber that records span enter/exit timing.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;

use tracing::field::{Field, Visit};
use tracing::span::{Attributes, Record};
use tracing::{Event, Id, Level, Metadata, Subscriber};

use super::trace::{TraceRecord, TraceSnapshot};

/// Subscriber meant to be installed explicitly, for a limited scope, to
/// observe the span tree.
///
/// Span creation, field updates, enters and exits are all forwarded to the
/// shared [`TraceRecorder`]; events and follows-from links are ignored.
pub struct ObservingSubscriber {
    /// Handle onto the recording state shared with the caller's recorder.
    recorder: TraceRecorder,
}

impl ObservingSubscriber {
    /// Build a subscriber plus a recorder handle that reads what it captures.
    ///
    /// The subscriber keeps a clone of the returned [`TraceRecorder`], so both
    /// values refer to the same recording state: install the subscriber for
    /// the scope of interest and call `snapshot` on the handle afterwards.
    #[inline]
    pub fn new() -> (Self, TraceRecorder) {
        let handle = TraceRecorder::new();
        let subscriber = Self {
            recorder: handle.clone(),
        };
        (subscriber, handle)
    }
}

impl Subscriber for ObservingSubscriber {
    /// Accept callsites at `Level::INFO` or lower in `tracing`'s verbosity
    /// ordering (INFO, WARN, ERROR); DEBUG and TRACE are filtered out.
    fn enabled(&self, metadata: &Metadata<'_>) -> bool {
        let threshold = Level::INFO;
        *metadata.level() <= threshold
    }

    /// Allocate an id for a new span and remember its initial fields.
    fn new_span(&self, attributes: &Attributes<'_>) -> Id {
        self.recorder.new_span(attributes)
    }

    /// Merge late-recorded field values into an existing span.
    fn record(&self, id: &Id, values: &Record<'_>) {
        self.recorder.record(id, values);
    }

    /// Mark the span as entered on the current thread.
    fn enter(&self, id: &Id) {
        self.recorder.enter(id);
    }

    /// Mark the span as exited on the current thread, producing a record.
    fn exit(&self, id: &Id) {
        self.recorder.exit(id);
    }

    /// Events are not captured by this subscriber.
    fn event(&self, _event: &Event<'_>) {}

    /// Follows-from relationships are not captured by this subscriber.
    fn record_follows_from(&self, _id: &Id, _follows: &Id) {}
}

/// Cloneable handle onto the state captured by an [`ObservingSubscriber`].
///
/// Cloning copies only the `Arc`, so every clone observes the same
/// mutex-protected recording state.
#[derive(Clone)]
pub struct TraceRecorder {
    /// Recording state shared by every clone of this handle.
    state: Arc<Mutex<RecorderState>>,
}

impl TraceRecorder {
    fn new() -> Self {
        Self {
            state: Arc::new(Mutex::new(RecorderState {
                next_id: 1,
                base: Instant::now(),
                spans: HashMap::new(),
                stacks: HashMap::new(),
                thread_numbers: HashMap::new(),
                next_thread_number: 1,
                completed: Vec::new(),
            })),
        }
    }

    /// Return a stable snapshot of completed spans.
    #[inline]
    pub fn snapshot(&self) -> TraceSnapshot {
        match self.state.lock() {
            Ok(state) => TraceSnapshot {
                records: state.completed.clone(),
            },
            Err(_) => TraceSnapshot::default(),
        }
    }

    fn new_span(&self, attrs: &Attributes<'_>) -> Id {
        let mut fields = FieldMap::default();
        attrs.record(&mut fields);
        let Ok(mut state) = self.state.lock() else {
            return Id::from_u64(0);
        };
        let id = state.next_id;
        state.next_id = state.next_id.saturating_add(1);
        state.spans.insert(
            id,
            SpanState {
                name: attrs.metadata().name().to_string(),
                layer_id: fields.layer_id,
                op_id: fields.op_id,
                verdict: fields.verdict,
                start_us: None,
                duration_us: None,
                thread_id: None,
            },
        );
        Id::from_u64(id)
    }

    fn record(&self, span: &Id, values: &Record<'_>) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        let Some(span_state) = state.spans.get_mut(&span.into_u64()) else {
            return;
        };
        let mut fields = FieldMap::default();
        values.record(&mut fields);
        if fields.layer_id.is_some() {
            span_state.layer_id = fields.layer_id;
        }
        if fields.op_id.is_some() {
            span_state.op_id = fields.op_id;
        }
        if fields.verdict.is_some() {
            span_state.verdict = fields.verdict;
        }
    }

    fn enter(&self, span: &Id) {
        let thread_id = std::thread::current().id();
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        let thread_number = state.thread_number(thread_id);
        let start_us = state.base.elapsed().as_micros();
        if let Some(span_state) = state.spans.get_mut(&span.into_u64()) {
            if span_state.start_us.is_none() {
                span_state.start_us = Some(start_us);
                span_state.thread_id = Some(thread_number);
            }
        }
        state
            .stacks
            .entry(thread_id)
            .or_default()
            .push(span.into_u64());
    }

    fn exit(&self, span: &Id) {
        let thread_id = std::thread::current().id();
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        let end_us = state.base.elapsed().as_micros();
        let raw_id = span.into_u64();
        let stack_ids = state.stacks.get(&thread_id).cloned().unwrap_or_default();
        let stack = stack_ids
            .iter()
            .filter_map(|id| state.spans.get(id).map(span_stack_name))
            .collect::<Vec<_>>();
        let record = if let Some(span_state) = state.spans.get_mut(&raw_id) {
            let start_us = span_state.start_us.unwrap_or(end_us);
            let duration_us = end_us.saturating_sub(start_us);
            span_state.duration_us = Some(duration_us);
            Some(TraceRecord {
                name: span_state.name.clone(),
                layer_id: span_state.layer_id.clone(),
                op_id: span_state.op_id.clone(),
                verdict: span_state.verdict.clone(),
                start_us,
                duration_us,
                stack,
                thread_id: span_state.thread_id.unwrap_or(0),
            })
        } else {
            None
        };
        if let Some(record) = record {
            state.completed.push(record);
        }
        if let Some(stack) = state.stacks.get_mut(&thread_id) {
            if stack.last().copied() == Some(raw_id) {
                stack.pop();
            } else if let Some(position) = stack.iter().rposition(|id| *id == raw_id) {
                stack.remove(position);
            }
        }
    }
}

/// Mutable state behind a [`TraceRecorder`], guarded by its mutex.
struct RecorderState {
    /// Reference point; all recorded timestamps are microseconds since it.
    base: Instant,
    /// Raw id to hand to the next span created.
    next_id: u64,
    /// Per-span bookkeeping, keyed by raw span id.
    spans: HashMap<u64, SpanState>,
    /// Raw ids of currently entered spans per thread, innermost last.
    stacks: HashMap<std::thread::ThreadId, Vec<u64>>,
    /// Sequential number to assign to the next newly seen thread.
    next_thread_number: u64,
    /// Sequential numbers already assigned to threads.
    thread_numbers: HashMap<std::thread::ThreadId, u64>,
    /// One record per span exit, in exit order.
    completed: Vec<TraceRecord>,
}

impl RecorderState {
    /// Map `thread_id` to a small sequential number.
    ///
    /// The first time a thread is seen it gets the next free number; later
    /// calls return the same number. The counter saturates at `u64::MAX`.
    fn thread_number(&mut self, thread_id: std::thread::ThreadId) -> u64 {
        // Borrow the counter separately so the entry closure can update it
        // while `thread_numbers` is mutably borrowed.
        let counter = &mut self.next_thread_number;
        *self.thread_numbers.entry(thread_id).or_insert_with(|| {
            let assigned = *counter;
            *counter = assigned.saturating_add(1);
            assigned
        })
    }
}

/// Bookkeeping for one span between its creation and the end of recording.
struct SpanState {
    /// Name from the span's callsite metadata.
    name: String,
    /// Stringified value of the span's `layer_id` field, if recorded.
    layer_id: Option<String>,
    /// Stringified value of the span's `op_id` field, if recorded.
    op_id: Option<String>,
    /// Stringified value of the span's `verdict` field, if recorded.
    verdict: Option<String>,
    /// Sequential number of the thread that first entered the span.
    thread_id: Option<u64>,
    /// Microseconds from the recorder's base instant to the first enter.
    start_us: Option<u128>,
    /// Microseconds from the first enter to the most recent exit.
    duration_us: Option<u128>,
}

/// The subset of span field values this recorder keeps, as strings.
///
/// Starts empty (`Default`); any field other than these three is ignored.
#[derive(Default)]
struct FieldMap {
    /// Value of a field named `layer_id`.
    layer_id: Option<String>,
    /// Value of a field named `op_id`.
    op_id: Option<String>,
    /// Value of a field named `verdict`.
    verdict: Option<String>,
}

impl Visit for FieldMap {
    /// Store string values verbatim, without the quotes `Debug` would add.
    fn record_str(&mut self, field: &Field, value: &str) {
        self.record_value(field, value.to_owned());
    }

    /// Store booleans as `"true"` / `"false"`.
    fn record_bool(&mut self, field: &Field, value: bool) {
        self.record_value(field, value.to_string());
    }

    /// Store signed integers in decimal.
    fn record_i64(&mut self, field: &Field, value: i64) {
        self.record_value(field, value.to_string());
    }

    /// Store unsigned integers in decimal.
    fn record_u64(&mut self, field: &Field, value: u64) {
        self.record_value(field, value.to_string());
    }

    /// Fallback for every other value type: store its `Debug` rendering.
    fn record_debug(&mut self, field: &Field, value: &dyn core::fmt::Debug) {
        let rendered = format!("{:?}", value);
        self.record_value(field, rendered);
    }
}

impl FieldMap {
    /// Put `value` into the slot named by `field`; unrecognised names are dropped.
    fn record_value(&mut self, field: &Field, value: String) {
        let slot = match field.name() {
            "layer_id" => &mut self.layer_id,
            "op_id" => &mut self.op_id,
            "verdict" => &mut self.verdict,
            _ => return,
        };
        *slot = Some(value);
    }
}

/// Label used for `span` inside a record's stack.
///
/// Returns `"layer:op"` when both ids are known, whichever single id is known
/// otherwise, and the span's metadata name as the last resort.
fn span_stack_name(span: &SpanState) -> String {
    match (span.layer_id.as_deref(), span.op_id.as_deref()) {
        (Some(layer), Some(op)) => format!("{layer}:{op}"),
        (Some(label), None) | (None, Some(label)) => label.to_owned(),
        (None, None) => span.name.clone(),
    }
}