1use std::sync::{Arc, Mutex};
2use std::time::{Duration, SystemTime};
3
4use opentelemetry::global;
5use opentelemetry::trace::{Span, SpanKind, Tracer};
6use opentelemetry::KeyValue;
7use opentelemetry_otlp::WithExportConfig;
8use opentelemetry_sdk;
9use serde::{Deserialize, Serialize};
10use serde_json;
11use tokio::time::sleep;
12use tracing::{span, Level};
13use tracing_subscriber::layer::SubscriberExt;
14use tracing_subscriber::util::SubscriberInitExt;
15use tracing_subscriber::{EnvFilter, Registry};
16
17use crate::error::{AgnoError, Result};
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
20pub struct TelemetryLabels {
21 pub tenant: Option<String>,
22 pub tool: Option<String>,
23 pub workflow: Option<String>,
24}
25
26impl TelemetryLabels {
27 pub fn with_tenant(mut self, tenant: impl Into<String>) -> Self {
28 self.tenant = Some(tenant.into());
29 self
30 }
31
32 pub fn with_tool(mut self, tool: impl Into<String>) -> Self {
33 self.tool = Some(tool.into());
34 self
35 }
36
37 pub fn with_workflow(mut self, workflow: impl Into<String>) -> Self {
38 self.workflow = Some(workflow.into());
39 self
40 }
41
42 pub fn as_attributes(&self) -> Vec<KeyValue> {
43 let mut attrs = Vec::new();
44 if let Some(tenant) = &self.tenant {
45 attrs.push(KeyValue::new("tenant", tenant.clone()));
46 }
47 if let Some(tool) = &self.tool {
48 attrs.push(KeyValue::new("tool", tool.clone()));
49 }
50 if let Some(workflow) = &self.workflow {
51 attrs.push(KeyValue::new("workflow", workflow.clone()));
52 }
53 attrs
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TelemetryEvent {
59 pub kind: String,
60 pub timestamp: SystemTime,
61 pub detail: serde_json::Value,
62 pub labels: TelemetryLabels,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct FailureRecord {
67 pub context: String,
68 pub error: String,
69 pub attempt: u32,
70 pub labels: TelemetryLabels,
71}
72
73#[derive(Default, Clone)]
74pub struct TelemetryCollector {
75 events: Arc<Mutex<Vec<TelemetryEvent>>>,
76 failures: Arc<Mutex<Vec<FailureRecord>>>,
77}
78
79impl TelemetryCollector {
80 pub fn record(
81 &self,
82 kind: impl Into<String>,
83 detail: serde_json::Value,
84 labels: TelemetryLabels,
85 ) {
86 self.events.lock().unwrap().push(TelemetryEvent {
87 kind: kind.into(),
88 timestamp: SystemTime::now(),
89 detail,
90 labels,
91 });
92 }
93
94 pub fn record_failure(
95 &self,
96 context: impl Into<String>,
97 error: impl Into<String>,
98 attempt: u32,
99 labels: TelemetryLabels,
100 ) {
101 self.failures.lock().unwrap().push(FailureRecord {
102 context: context.into(),
103 error: error.into(),
104 attempt,
105 labels,
106 });
107 }
108
109 pub fn drain(&self) -> (Vec<TelemetryEvent>, Vec<FailureRecord>) {
110 let mut events = self.events.lock().unwrap();
111 let mut failures = self.failures.lock().unwrap();
112 (std::mem::take(&mut *events), std::mem::take(&mut *failures))
113 }
114}
115
116#[derive(Default, Clone)]
117pub struct TelemetrySink {
118 buffer: Arc<Mutex<Vec<TelemetryEvent>>>,
119}
120
121impl TelemetrySink {
122 pub fn push(&self, event: TelemetryEvent) {
123 self.buffer.lock().unwrap().push(event);
124 }
125
126 pub fn flush(&self) -> Vec<TelemetryEvent> {
127 let mut guard = self.buffer.lock().unwrap();
128 std::mem::take(&mut *guard)
129 }
130}
131
/// Retry budget and base backoff used by `RetryPolicy::retry`.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Number of retries allowed after the first attempt, so an operation is
    /// tried at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Base delay between attempts; `retry` multiplies it by the number of
    /// attempts made so far.
    pub backoff: Duration,
}
137
138impl RetryPolicy {
139 pub fn default_external_call() -> Self {
140 Self {
141 max_retries: 3,
142 backoff: Duration::from_millis(200),
143 }
144 }
145
146 pub async fn retry<F, Fut, T>(
147 &self,
148 mut f: F,
149 telemetry: Option<&TelemetryCollector>,
150 labels: TelemetryLabels,
151 ) -> Result<T>
152 where
153 F: FnMut(u32) -> Fut,
154 Fut: std::future::Future<Output = Result<T>>,
155 {
156 for attempt in 0..=self.max_retries {
157 match f(attempt).await {
158 Ok(value) => return Ok(value),
159 Err(err) => {
160 if let Some(t) = telemetry {
161 t.record_failure("retry", format!("{err}"), attempt, labels.clone());
162 }
163 let span = span!(
164 Level::INFO,
165 "retry_failure",
166 attempt,
167 tenant = labels.tenant.as_deref().unwrap_or(""),
168 tool = labels.tool.as_deref().unwrap_or(""),
169 workflow = labels.workflow.as_deref().unwrap_or("")
170 );
171 let _enter = span.enter();
172 tracing::warn!("retry attempt {} failed: {}", attempt, err);
173 if attempt == self.max_retries {
174 return Err(err);
175 }
176 sleep(self.backoff * (attempt + 1)).await;
177 }
178 }
179 }
180 Err(AgnoError::Protocol("retry exhausted".into()))
181 }
182}
183
184#[derive(Clone)]
185pub struct FallbackChain<T> {
186 steps: Vec<(String, Arc<dyn Fn() -> Result<T> + Send + Sync>)>,
187}
188
189impl<T> std::fmt::Debug for FallbackChain<T> {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 let labels: Vec<&str> = self.steps.iter().map(|(label, _)| label.as_str()).collect();
192 f.debug_struct("FallbackChain")
193 .field("steps", &labels)
194 .finish()
195 }
196}
197
198impl<T> FallbackChain<T> {
199 pub fn new() -> Self {
200 Self { steps: Vec::new() }
201 }
202
203 pub fn with_step(
204 mut self,
205 label: impl Into<String>,
206 handler: impl Fn() -> Result<T> + Send + Sync + 'static,
207 ) -> Self {
208 self.steps.push((label.into(), Arc::new(handler)));
209 self
210 }
211
212 pub fn execute(
213 &self,
214 telemetry: Option<&TelemetryCollector>,
215 labels: TelemetryLabels,
216 ) -> Result<T> {
217 let mut last_error: Option<AgnoError> = None;
218 for (label, handler) in self.steps.iter() {
219 let span = span!(
220 Level::DEBUG,
221 "fallback_step",
222 step = label.as_str(),
223 tenant = labels.tenant.as_deref().unwrap_or(""),
224 tool = labels.tool.as_deref().unwrap_or(""),
225 workflow = labels.workflow.as_deref().unwrap_or("")
226 );
227 let _guard = span.enter();
228 match handler() {
229 Ok(value) => {
230 if let Some(t) = telemetry {
231 t.record(
232 "fallback_success",
233 serde_json::json!({ "step": label }),
234 labels.clone(),
235 );
236 }
237 tracing::info!("fallback step succeeded");
238 return Ok(value);
239 }
240 Err(err) => {
241 if let Some(t) = telemetry {
242 t.record_failure(label.clone(), format!("{err}"), 0, labels.clone());
243 }
244 tracing::warn!("fallback step failed: {}", err);
245 last_error = Some(err);
246 }
247 }
248 }
249 Err(last_error.unwrap_or_else(|| AgnoError::Protocol("fallback exhausted".into())))
250 }
251}
252
253pub fn span_with_labels(_name: &str, labels: &TelemetryLabels) -> tracing::Span {
254 span!(
255 Level::INFO,
256 "labeled_span",
257 tenant = labels.tenant.as_deref().unwrap_or(""),
258 tool = labels.tool.as_deref().unwrap_or(""),
259 workflow = labels.workflow.as_deref().unwrap_or("")
260 )
261}
262
263pub fn init_tracing(service_name: &str, otlp_endpoint: Option<&str>) -> Result<()> {
264 let trace_config = opentelemetry_sdk::trace::config().with_resource(
265 opentelemetry_sdk::Resource::new(vec![KeyValue::new(
266 "service.name",
267 service_name.to_owned(),
268 )]),
269 );
270
271 let tracer = if let Some(endpoint) = otlp_endpoint {
272 opentelemetry_otlp::new_pipeline()
273 .tracing()
274 .with_trace_config(trace_config)
275 .with_exporter(
276 opentelemetry_otlp::new_exporter()
277 .tonic()
278 .with_endpoint(endpoint),
279 )
280 .install_batch(opentelemetry_sdk::runtime::Tokio)
281 .map_err(|e| AgnoError::Telemetry(e.to_string()))?
282 } else {
283 opentelemetry_otlp::new_pipeline()
284 .tracing()
285 .with_trace_config(trace_config)
286 .with_exporter(opentelemetry_otlp::new_exporter().tonic())
287 .install_batch(opentelemetry_sdk::runtime::Tokio)
288 .map_err(|e| AgnoError::Telemetry(e.to_string()))?
289 };
290
291 let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
292 let fmt_layer = tracing_subscriber::fmt::layer().json().with_target(true);
293 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
294 Registry::default()
295 .with(env_filter)
296 .with(fmt_layer)
297 .with(telemetry)
298 .try_init()
299 .map_err(|e| AgnoError::Telemetry(format!("failed to init tracing: {e}")))?;
300 Ok(())
301}
302
303pub fn current_span_attributes(labels: &TelemetryLabels) {
304 let tracer = global::tracer("agno-tracer");
305 let mut span = tracer
306 .span_builder("context")
307 .with_kind(SpanKind::Internal)
308 .with_attributes(labels.as_attributes())
309 .start(&tracer);
310 span.add_event("context attached".to_string(), labels.as_attributes());
311 span.end();
312}
313
314pub fn flush_tracer() {
315 global::shutdown_tracer_provider();
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the label set shared by the tests, varying only the tool name.
    fn labels_for(tool: &str) -> TelemetryLabels {
        TelemetryLabels::default()
            .with_tenant("tenant-a")
            .with_tool(tool)
            .with_workflow("test")
    }

    /// The first call fails and the second succeeds, so exactly one failure
    /// is recorded and it carries the labels passed to `retry`.
    #[tokio::test]
    async fn retries_until_success() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        let labels = labels_for("retry");
        let collector = TelemetryCollector::default();
        let call_count = Arc::new(Mutex::new(0u32));
        let policy = RetryPolicy {
            max_retries: 2,
            backoff: Duration::from_millis(1),
        };

        let outcome = policy
            .retry(
                |_: u32| {
                    let call_count = call_count.clone();
                    async move {
                        let mut seen = call_count.lock().await;
                        *seen += 1;
                        if *seen >= 2 {
                            Ok(42)
                        } else {
                            Err(AgnoError::Protocol("fail".into()))
                        }
                    }
                },
                Some(&collector),
                labels.clone(),
            )
            .await;

        assert_eq!(outcome.unwrap(), 42);
        let (_events, failures) = collector.drain();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].labels, labels);
    }

    /// The primary step fails and the secondary succeeds, so the chain returns
    /// the secondary value and records a single failure with the labels.
    #[test]
    fn runs_fallbacks() {
        let labels = labels_for("fallback");
        let collector = TelemetryCollector::default();
        let chain = FallbackChain::new()
            .with_step("primary", || Err(AgnoError::Protocol("nope".into())))
            .with_step("secondary", || Ok("ok"));

        let value = chain.execute(Some(&collector), labels.clone()).unwrap();
        assert_eq!(value, "ok");

        let (_events, failures) = collector.drain();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].labels, labels);
    }
}