1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use tracing::{Id, Subscriber, debug};
4use tracing_subscriber::{Layer, layer::Context, registry::LookupSpan};
5
/// In-memory store of exported span attributes, keyed by event id.
///
/// The map lives behind an `Arc`, so every clone of an exporter shares the same
/// underlying store: data written through one clone is visible through all others.
#[derive(Clone, Debug, Default)]
pub struct AdkSpanExporter {
    /// event id -> that span's attributes, with every value stored as a `String`.
    trace_dict: Arc<RwLock<HashMap<String, HashMap<String, String>>>>,
}
13
14impl AdkSpanExporter {
15 pub fn new() -> Self {
16 Self { trace_dict: Arc::new(RwLock::new(HashMap::new())) }
17 }
18
19 pub fn get_trace_dict(&self) -> HashMap<String, HashMap<String, String>> {
21 self.trace_dict.read().unwrap_or_else(|e| e.into_inner()).clone()
22 }
23
24 pub fn get_trace_by_event_id(&self, event_id: &str) -> Option<HashMap<String, String>> {
26 debug!("AdkSpanExporter::get_trace_by_event_id called with event_id: {}", event_id);
27 let trace_dict = self.trace_dict.read().unwrap_or_else(|e| e.into_inner());
28 let result = trace_dict.get(event_id).cloned();
29 debug!("get_trace_by_event_id result for event_id '{}': {:?}", event_id, result.is_some());
30 result
31 }
32
33 pub fn get_session_trace(&self, session_id: &str) -> Vec<HashMap<String, String>> {
35 debug!("AdkSpanExporter::get_session_trace called with session_id: {}", session_id);
36 let trace_dict = self.trace_dict.read().unwrap_or_else(|e| e.into_inner());
37
38 let mut spans = Vec::new();
39 for (_event_id, attributes) in trace_dict.iter() {
40 if let Some(span_session_id) = attributes.get("gcp.vertex.agent.session_id") {
42 if span_session_id == session_id {
43 spans.push(attributes.clone());
44 }
45 }
46 }
47
48 debug!("get_session_trace result for session_id '{}': {} spans", session_id, spans.len());
49 spans
50 }
51
52 fn export_span(&self, span_name: &str, attributes: HashMap<String, String>) {
54 if span_name == "agent.execute"
56 || span_name == "call_llm"
57 || span_name == "send_data"
58 || span_name.starts_with("execute_tool")
59 {
60 if let Some(event_id) = attributes.get("gcp.vertex.agent.event_id") {
61 debug!(
62 "AdkSpanExporter: Storing span '{}' with event_id '{}'",
63 span_name, event_id
64 );
65 let mut trace_dict = self.trace_dict.write().unwrap_or_else(|e| e.into_inner());
66 trace_dict.insert(event_id.clone(), attributes);
67 debug!("AdkSpanExporter: Span stored, total event_ids: {}", trace_dict.len());
68 } else {
69 debug!("AdkSpanExporter: Skipping span '{}' - no event_id found", span_name);
70 }
71 } else {
72 debug!("AdkSpanExporter: Skipping span '{}' - not in allowed list", span_name);
73 }
74 }
75}
76
77pub struct AdkSpanLayer {
79 exporter: Arc<AdkSpanExporter>,
80}
81
82impl AdkSpanLayer {
83 pub fn new(exporter: Arc<AdkSpanExporter>) -> Self {
84 Self { exporter }
85 }
86}
87
/// Span extension holding the span's fields (name -> stringified value),
/// including any context keys inherited from its parent.
#[derive(Clone)]
struct SpanFields(HashMap<String, String>);
90
/// Span extension holding the monotonic instant at which the span was created.
/// The span's duration is measured from it when the span closes.
#[derive(Clone)]
struct SpanTiming {
    /// Captured in `on_new_span`.
    start_time: std::time::Instant,
}
95
96impl<S> Layer<S> for AdkSpanLayer
97where
98 S: Subscriber + for<'a> LookupSpan<'a>,
99{
100 fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
101 let Some(span) = ctx.span(id) else { return };
102 let mut extensions = span.extensions_mut();
103
104 extensions.insert(SpanTiming { start_time: std::time::Instant::now() });
106
107 let mut visitor = StringVisitor::default();
109 attrs.record(&mut visitor);
110 let mut fields_map = visitor.0;
111
112 if let Some(parent) = span.parent() {
114 if let Some(parent_fields) = parent.extensions().get::<SpanFields>() {
115 let context_keys = [
116 "gcp.vertex.agent.session_id",
117 "gcp.vertex.agent.invocation_id",
118 "gcp.vertex.agent.event_id",
119 "gen_ai.conversation.id",
120 #[cfg(feature = "genai-semconv")]
121 "gen_ai.provider.name",
122 #[cfg(feature = "genai-semconv")]
123 "gen_ai.system",
124 ];
125
126 for key in context_keys {
127 if !fields_map.contains_key(key) {
128 if let Some(val) = parent_fields.0.get(key) {
129 fields_map.insert(key.to_string(), val.clone());
130 }
131 }
132 }
133 }
134 }
135
136 extensions.insert(SpanFields(fields_map));
137 }
138
139 fn on_record(&self, id: &Id, values: &tracing::span::Record<'_>, ctx: Context<'_, S>) {
140 let Some(span) = ctx.span(id) else { return };
141 let mut extensions = span.extensions_mut();
142 if let Some(fields) = extensions.get_mut::<SpanFields>() {
143 let mut visitor = StringVisitor::default();
144 values.record(&mut visitor);
145 for (k, v) in visitor.0 {
146 fields.0.insert(k, v);
147 }
148 }
149 }
150
151 fn on_close(&self, id: Id, ctx: Context<'_, S>) {
152 let Some(span) = ctx.span(&id) else { return };
153 let extensions = span.extensions();
154
155 let timing = extensions.get::<SpanTiming>();
157 let end_time = std::time::Instant::now();
158 let duration_nanos =
159 timing.map(|t| end_time.duration_since(t.start_time).as_nanos() as u64).unwrap_or(0);
160
161 let mut attributes =
163 extensions.get::<SpanFields>().map(|f| f.0.clone()).unwrap_or_default();
164
165 let metadata = span.metadata();
167 let span_name =
168 attributes.get("otel.name").cloned().unwrap_or_else(|| metadata.name().to_string());
169
170 let now_nanos = std::time::SystemTime::now()
172 .duration_since(std::time::UNIX_EPOCH)
173 .unwrap_or_default()
174 .as_nanos() as u64;
175
176 let invocation_id = attributes
179 .get("gcp.vertex.agent.invocation_id")
180 .cloned()
181 .unwrap_or_else(|| format!("{:016x}", id.into_u64()));
182 let event_id = attributes
183 .get("gcp.vertex.agent.event_id")
184 .cloned()
185 .unwrap_or_else(|| format!("{:016x}", id.into_u64()));
186
187 attributes.insert("span_name".to_string(), span_name.clone());
188 attributes.insert("trace_id".to_string(), invocation_id); attributes.insert("span_id".to_string(), event_id); attributes.insert("start_time".to_string(), (now_nanos - duration_nanos).to_string());
191 attributes.insert("end_time".to_string(), now_nanos.to_string());
192
193 self.exporter.export_span(&span_name, attributes);
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tracing_subscriber::layer::SubscriberExt;

    #[test]
    fn test_conversation_id_propagates_to_child_spans() {
        let exporter = Arc::new(AdkSpanExporter::new());
        let layer = AdkSpanLayer::new(exporter.clone());
        let subscriber = tracing_subscriber::registry().with(layer);

        tracing::subscriber::with_default(subscriber, || {
            let parent = tracing::info_span!(
                "agent.execute",
                "gcp.vertex.agent.event_id" = "evt-parent",
                "gcp.vertex.agent.invocation_id" = "inv-1",
                "gcp.vertex.agent.session_id" = "session-1",
                "gen_ai.conversation.id" = "session-1",
                "agent.name" = "test-agent"
            );

            let _parent_guard = parent.enter();

            let child = tracing::info_span!(
                "call_llm",
                "gcp.vertex.agent.event_id" = "evt-child",
                "gcp.vertex.agent.llm_request" = "{}"
            );
            let _child_guard = child.enter();
            tracing::info!("child span body");
        });

        let child_trace =
            exporter.get_trace_by_event_id("evt-child").expect("child span should be exported");
        assert_eq!(
            child_trace.get("gen_ai.conversation.id").map(String::as_str),
            Some("session-1")
        );
    }

    /// `get_session_trace` returns only spans whose session id matches exactly.
    #[test]
    fn test_session_trace_filters_by_session_id() {
        let exporter = Arc::new(AdkSpanExporter::new());
        let subscriber = tracing_subscriber::registry().with(AdkSpanLayer::new(exporter.clone()));

        tracing::subscriber::with_default(subscriber, || {
            // Neither span is entered, so both are roots and nothing is inherited.
            let first = tracing::info_span!(
                "agent.execute",
                "gcp.vertex.agent.event_id" = "evt-a",
                "gcp.vertex.agent.session_id" = "session-a"
            );
            let second = tracing::info_span!(
                "agent.execute",
                "gcp.vertex.agent.event_id" = "evt-b",
                "gcp.vertex.agent.session_id" = "session-b"
            );
            drop(first);
            drop(second);
        });

        let session_a = exporter.get_session_trace("session-a");
        assert_eq!(session_a.len(), 1);
        assert_eq!(
            session_a[0].get("gcp.vertex.agent.event_id").map(String::as_str),
            Some("evt-a")
        );
        assert!(exporter.get_session_trace("session-missing").is_empty());
    }

    /// Spans whose name is outside the exporter's allow list are never stored,
    /// even when they carry an event id.
    #[test]
    fn test_spans_outside_allow_list_are_not_exported() {
        let exporter = Arc::new(AdkSpanExporter::new());
        let subscriber = tracing_subscriber::registry().with(AdkSpanLayer::new(exporter.clone()));

        tracing::subscriber::with_default(subscriber, || {
            let span =
                tracing::info_span!("unrelated.work", "gcp.vertex.agent.event_id" = "evt-ignored");
            drop(span);
        });

        assert!(exporter.get_trace_by_event_id("evt-ignored").is_none());
        assert!(exporter.get_trace_dict().is_empty());
    }
}
241
/// Field visitor that collects every visited field as `name -> String`.
/// A later value for the same field name replaces an earlier one.
#[derive(Default)]
struct StringVisitor(HashMap<String, String>);
244
245impl tracing::field::Visit for StringVisitor {
246 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
247 self.0.insert(field.name().to_string(), format!("{:?}", value));
248 }
249
250 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
251 self.0.insert(field.name().to_string(), value.to_string());
252 }
253
254 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
255 self.0.insert(field.name().to_string(), value.to_string());
256 }
257
258 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
259 self.0.insert(field.name().to_string(), value.to_string());
260 }
261
262 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
263 self.0.insert(field.name().to_string(), value.to_string());
264 }
265
266 fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
267 self.0.insert(field.name().to_string(), value.to_string());
268 }
269}