adk_telemetry/
span_exporter.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use tracing::{Id, Subscriber, debug};
4use tracing_subscriber::{Layer, layer::Context, registry::LookupSpan};
5
/// In-memory span store modeled on ADK-Go's `APIServerSpanExporter`.
///
/// Every captured span is kept as a flat `String -> String` attribute map and
/// indexed by the event id it carries. The storage sits behind an
/// `Arc<RwLock<..>>`, so clones of an exporter share the same map.
#[derive(Clone, Debug, Default)]
pub struct AdkSpanExporter {
    /// `event_id` -> attributes of the span stored for that event
    /// (same layout as the ADK-Go trace dictionary).
    trace_dict: Arc<RwLock<HashMap<String, HashMap<String, String>>>>,
}
13
14impl AdkSpanExporter {
15    pub fn new() -> Self {
16        Self { trace_dict: Arc::new(RwLock::new(HashMap::new())) }
17    }
18
19    /// Get trace dict (following ADK-Go GetTraceDict method)
20    pub fn get_trace_dict(&self) -> HashMap<String, HashMap<String, String>> {
21        self.trace_dict.read().unwrap().clone()
22    }
23
24    /// Get trace by event_id (following ADK-Go pattern)
25    pub fn get_trace_by_event_id(&self, event_id: &str) -> Option<HashMap<String, String>> {
26        debug!("AdkSpanExporter::get_trace_by_event_id called with event_id: {}", event_id);
27        let trace_dict = self.trace_dict.read().unwrap();
28        let result = trace_dict.get(event_id).cloned();
29        debug!("get_trace_by_event_id result for event_id '{}': {:?}", event_id, result.is_some());
30        result
31    }
32
33    /// Get all spans for a session (by filtering spans that have matching session_id)
34    pub fn get_session_trace(&self, session_id: &str) -> Vec<HashMap<String, String>> {
35        debug!("AdkSpanExporter::get_session_trace called with session_id: {}", session_id);
36        let trace_dict = self.trace_dict.read().unwrap();
37
38        let mut spans = Vec::new();
39        for (_event_id, attributes) in trace_dict.iter() {
40            // Check if this span belongs to the session
41            if let Some(span_session_id) = attributes.get("gcp.vertex.agent.session_id") {
42                if span_session_id == session_id {
43                    spans.push(attributes.clone());
44                }
45            }
46        }
47
48        debug!("get_session_trace result for session_id '{}': {} spans", session_id, spans.len());
49        spans
50    }
51
52    /// Internal method to store span (following ADK-Go ExportSpans pattern)
53    fn export_span(&self, span_name: &str, attributes: HashMap<String, String>) {
54        // Only capture specific span names (following ADK-Go pattern)
55        if span_name == "agent.execute"
56            || span_name == "call_llm"
57            || span_name == "send_data"
58            || span_name.starts_with("execute_tool")
59        {
60            if let Some(event_id) = attributes.get("gcp.vertex.agent.event_id") {
61                debug!(
62                    "AdkSpanExporter: Storing span '{}' with event_id '{}'",
63                    span_name, event_id
64                );
65                let mut trace_dict = self.trace_dict.write().unwrap();
66                trace_dict.insert(event_id.clone(), attributes);
67                debug!("AdkSpanExporter: Span stored, total event_ids: {}", trace_dict.len());
68            } else {
69                debug!("AdkSpanExporter: Skipping span '{}' - no event_id found", span_name);
70            }
71        } else {
72            debug!("AdkSpanExporter: Skipping span '{}' - not in allowed list", span_name);
73        }
74    }
75}
76
77/// Tracing layer that captures spans and exports them via AdkSpanExporter
78pub struct AdkSpanLayer {
79    exporter: Arc<AdkSpanExporter>,
80}
81
82impl AdkSpanLayer {
83    pub fn new(exporter: Arc<AdkSpanExporter>) -> Self {
84        Self { exporter }
85    }
86}
87
/// Span extension holding the span's fields as `name -> stringified value`.
#[derive(Clone)]
struct SpanFields(HashMap<String, String>);
90
/// Span extension remembering when the span was created.
#[derive(Clone)]
struct SpanTiming {
    /// Monotonic timestamp taken when the span was first seen.
    start_time: std::time::Instant,
}
95
96impl<S> Layer<S> for AdkSpanLayer
97where
98    S: Subscriber + for<'a> LookupSpan<'a>,
99{
100    fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
101        let span = ctx.span(id).expect("Span not found");
102        let mut extensions = span.extensions_mut();
103
104        // Record start time
105        extensions.insert(SpanTiming { start_time: std::time::Instant::now() });
106
107        // Capture fields
108        let mut visitor = StringVisitor::default();
109        attrs.record(&mut visitor);
110        let mut fields_map = visitor.0;
111
112        // Propagate fields from parent span (for context inheritance)
113        if let Some(parent) = span.parent() {
114            if let Some(parent_fields) = parent.extensions().get::<SpanFields>() {
115                let context_keys = [
116                    "gcp.vertex.agent.session_id",
117                    "gcp.vertex.agent.invocation_id",
118                    "gcp.vertex.agent.event_id",
119                ];
120
121                for key in context_keys {
122                    if !fields_map.contains_key(key) {
123                        if let Some(val) = parent_fields.0.get(key) {
124                            fields_map.insert(key.to_string(), val.clone());
125                        }
126                    }
127                }
128            }
129        }
130
131        extensions.insert(SpanFields(fields_map));
132    }
133
134    fn on_record(&self, id: &Id, values: &tracing::span::Record<'_>, ctx: Context<'_, S>) {
135        let span = ctx.span(id).expect("Span not found");
136        let mut extensions = span.extensions_mut();
137        if let Some(fields) = extensions.get_mut::<SpanFields>() {
138            let mut visitor = StringVisitor::default();
139            values.record(&mut visitor);
140            for (k, v) in visitor.0 {
141                fields.0.insert(k, v);
142            }
143        }
144    }
145
146    fn on_close(&self, id: Id, ctx: Context<'_, S>) {
147        let span = ctx.span(&id).expect("Span not found");
148        let extensions = span.extensions();
149
150        // Calculate actual duration
151        let timing = extensions.get::<SpanTiming>();
152        let end_time = std::time::Instant::now();
153        let duration_nanos =
154            timing.map(|t| end_time.duration_since(t.start_time).as_nanos() as u64).unwrap_or(0);
155
156        // Get captured fields
157        let mut attributes =
158            extensions.get::<SpanFields>().map(|f| f.0.clone()).unwrap_or_default();
159
160        // Get span name - prefer otel.name attribute (for dynamic names), fallback to metadata
161        let metadata = span.metadata();
162        let span_name =
163            attributes.get("otel.name").cloned().unwrap_or_else(|| metadata.name().to_string());
164
165        // Add span metadata and actual timing with unique IDs
166        let now_nanos = std::time::SystemTime::now()
167            .duration_since(std::time::UNIX_EPOCH)
168            .unwrap_or_default()
169            .as_nanos() as u64;
170
171        // Use invocation_id as trace_id (for grouping in UI)
172        // Use event_id as span_id (for uniqueness)
173        let invocation_id = attributes
174            .get("gcp.vertex.agent.invocation_id")
175            .cloned()
176            .unwrap_or_else(|| format!("{:016x}", id.into_u64()));
177        let event_id = attributes
178            .get("gcp.vertex.agent.event_id")
179            .cloned()
180            .unwrap_or_else(|| format!("{:016x}", id.into_u64()));
181
182        attributes.insert("span_name".to_string(), span_name.clone());
183        attributes.insert("trace_id".to_string(), invocation_id); // Group by invocation
184        attributes.insert("span_id".to_string(), event_id); // Unique per span
185        attributes.insert("start_time".to_string(), (now_nanos - duration_nanos).to_string());
186        attributes.insert("end_time".to_string(), now_nanos.to_string());
187
188        // Don't set parent_span_id to keep all spans at same level like ADK-Go
189
190        // Export the span
191        self.exporter.export_span(&span_name, attributes);
192    }
193}
194
/// Field visitor that collects every visited field as `name -> string value`.
#[derive(Default)]
struct StringVisitor(HashMap<String, String>);
197
198impl tracing::field::Visit for StringVisitor {
199    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
200        self.0.insert(field.name().to_string(), format!("{:?}", value));
201    }
202
203    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
204        self.0.insert(field.name().to_string(), value.to_string());
205    }
206
207    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
208        self.0.insert(field.name().to_string(), value.to_string());
209    }
210
211    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
212        self.0.insert(field.name().to_string(), value.to_string());
213    }
214
215    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
216        self.0.insert(field.name().to_string(), value.to_string());
217    }
218
219    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
220        self.0.insert(field.name().to_string(), value.to_string());
221    }
222}