use std::sync::Arc;
use std::time::Instant;
use tracing::Subscriber;
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::LookupSpan;
use crate::metrics::MetricsCollector;
/// Span names timed by [`MetricsBridge`], each paired with the
/// `last_turn_timings` field that receives the span's accumulated duration
/// when the span closes.
///
/// Spans whose name is not listed here are ignored by the layer.
const WATCHED_SPANS: &[(&str, TimingField)] = &[
    // (span name,              destination field)
    ("agent.prepare_context", TimingField::PrepareContext),
    ("llm.chat", TimingField::LlmChat),
    ("agent.tool_loop", TimingField::ToolExec),
    ("agent.persist_message", TimingField::PersistMessage),
];
/// Selects which `last_turn_timings` field a watched span's duration is
/// written into.
#[derive(Copy, Clone)]
enum TimingField {
    /// Populated from the `agent.prepare_context` span (`prepare_context_ms`).
    PrepareContext,
    /// Populated from the `llm.chat` span (`llm_chat_ms`).
    LlmChat,
    /// Populated from the `agent.tool_loop` span (`tool_exec_ms`).
    ToolExec,
    /// Populated from the `agent.persist_message` span (`persist_message_ms`).
    PersistMessage,
}
struct WatchedSpan;
struct SpanEntry(Instant);
struct SpanTiming(u64);
pub struct MetricsBridge {
collector: Arc<MetricsCollector>,
}
impl MetricsBridge {
#[must_use]
pub fn new(collector: Arc<MetricsCollector>) -> Self {
Self { collector }
}
}
impl<S> Layer<S> for MetricsBridge
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
fn on_new_span(
&self,
attrs: &tracing::span::Attributes<'_>,
id: &tracing::span::Id,
ctx: Context<'_, S>,
) {
let name = attrs.metadata().name();
if WATCHED_SPANS.iter().any(|(n, _)| *n == name)
&& let Some(span) = ctx.span(id)
{
span.extensions_mut().insert(WatchedSpan);
}
}
fn on_enter(&self, id: &tracing::span::Id, ctx: Context<'_, S>) {
if let Some(span) = ctx.span(id) {
if span.extensions().get::<WatchedSpan>().is_some() {
span.extensions_mut().replace(SpanEntry(Instant::now()));
}
}
}
fn on_exit(&self, id: &tracing::span::Id, ctx: Context<'_, S>) {
if let Some(span) = ctx.span(id) {
let elapsed_ms = span
.extensions()
.get::<SpanEntry>()
.map(|e| u64::try_from(e.0.elapsed().as_millis()).unwrap_or(u64::MAX));
if let Some(elapsed_ms) = elapsed_ms {
let mut exts = span.extensions_mut();
if let Some(timing) = exts.get_mut::<SpanTiming>() {
timing.0 = timing.0.saturating_add(elapsed_ms);
} else {
exts.insert(SpanTiming(elapsed_ms));
}
}
}
}
fn on_close(&self, id: tracing::span::Id, ctx: Context<'_, S>) {
if let Some(span) = ctx.span(&id) {
let exts = span.extensions();
if let Some(timing) = exts.get::<SpanTiming>() {
let duration_ms = timing.0;
let name = span.name();
if let Some((_, field)) = WATCHED_SPANS.iter().find(|(n, _)| *n == name) {
let field = *field;
self.collector.update(|m| match field {
TimingField::PrepareContext => {
m.last_turn_timings.prepare_context_ms = duration_ms;
}
TimingField::LlmChat => {
m.last_turn_timings.llm_chat_ms = duration_ms;
}
TimingField::ToolExec => {
m.last_turn_timings.tool_exec_ms = duration_ms;
}
TimingField::PersistMessage => {
m.last_turn_timings.persist_message_ms = duration_ms;
}
});
}
}
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tracing_subscriber::Registry;
    use tracing_subscriber::layer::SubscriberExt;
    use super::MetricsBridge;
    use crate::metrics::MetricsCollector;

    /// Builds a bridge wired to a fresh collector, returning the bridge, the
    /// shared collector, and the snapshot receiver created alongside it.
    fn make_bridge() -> (
        MetricsBridge,
        Arc<MetricsCollector>,
        tokio::sync::watch::Receiver<crate::metrics::MetricsSnapshot>,
    ) {
        let (collector, rx) = MetricsCollector::new();
        let arc = Arc::new(collector);
        (MetricsBridge::new(Arc::clone(&arc)), arc, rx)
    }

    #[test]
    fn watched_span_updates_correct_field() {
        let (bridge, _collector, rx) = make_bridge();
        let subscriber = Registry::default().with(bridge);
        tracing::subscriber::with_default(subscriber, || {
            let span = tracing::span!(tracing::Level::INFO, "llm.chat");
            let guard = span.enter();
            // Sleep so the recorded duration is measurably non-zero; only a
            // lower bound is asserted, which a sleep guarantees.
            std::thread::sleep(Duration::from_millis(5));
            drop(guard);
        });
        let snapshot = rx.borrow().clone();
        assert!(
            snapshot.last_turn_timings.llm_chat_ms >= 5,
            "llm_chat_ms was {}, expected >= 5",
            snapshot.last_turn_timings.llm_chat_ms
        );
        assert_eq!(snapshot.last_turn_timings.prepare_context_ms, 0);
        assert_eq!(snapshot.last_turn_timings.tool_exec_ms, 0);
        assert_eq!(snapshot.last_turn_timings.persist_message_ms, 0);
    }

    #[test]
    fn repeated_entries_accumulate() {
        let (bridge, _collector, rx) = make_bridge();
        let subscriber = Registry::default().with(bridge);
        tracing::subscriber::with_default(subscriber, || {
            let span = tracing::span!(tracing::Level::INFO, "agent.tool_loop");
            // Two separate enter/exit cycles, as an instrumented future polled
            // twice would produce; each cycle floors to at least 3 ms.
            for _ in 0..2 {
                let _guard = span.enter();
                std::thread::sleep(Duration::from_millis(3));
            }
        });
        let snapshot = rx.borrow().clone();
        assert!(
            snapshot.last_turn_timings.tool_exec_ms >= 6,
            "tool_exec_ms was {}, expected >= 6",
            snapshot.last_turn_timings.tool_exec_ms
        );
        assert_eq!(snapshot.last_turn_timings.prepare_context_ms, 0);
        assert_eq!(snapshot.last_turn_timings.llm_chat_ms, 0);
        assert_eq!(snapshot.last_turn_timings.persist_message_ms, 0);
    }

    #[test]
    fn non_watched_span_produces_no_update() {
        let (bridge, _collector, rx) = make_bridge();
        let subscriber = Registry::default().with(bridge);
        tracing::subscriber::with_default(subscriber, || {
            let span = tracing::span!(tracing::Level::INFO, "some.other.span");
            let guard = span.enter();
            drop(guard);
        });
        let snapshot = rx.borrow().clone();
        assert_eq!(snapshot.last_turn_timings.prepare_context_ms, 0);
        assert_eq!(snapshot.last_turn_timings.llm_chat_ms, 0);
        assert_eq!(snapshot.last_turn_timings.tool_exec_ms, 0);
        assert_eq!(snapshot.last_turn_timings.persist_message_ms, 0);
    }

    #[test]
    fn all_watched_span_names_registered() {
        let expected = [
            "agent.prepare_context",
            "llm.chat",
            "agent.tool_loop",
            "agent.persist_message",
        ];
        for span_name in expected {
            assert!(
                super::WATCHED_SPANS.iter().any(|(n, _)| *n == span_name),
                "span '{span_name}' not in WATCHED_SPANS",
            );
        }
        assert_eq!(
            super::WATCHED_SPANS.len(),
            expected.len(),
            "unexpected extra spans in WATCHED_SPANS"
        );
    }
}