Skip to main content

ai_agents_observability/
manager.rs

1use crate::aggregator::{
2    AggregatedMetrics, MetricsAggregator, aggregate_events, enrich_dimensions,
3};
4use crate::config::{AggregationDimension, ExportFormat, ObservabilityConfig, UnknownPricePolicy};
5use crate::context::{SpanContext, current_observation_context};
6use crate::cost::CostEstimator;
7use crate::event::{
8    CostEstimate, EventStatus, EventType, ObservationEvent, ObservationPurpose,
9    ObservationTokenUsage,
10};
11use crate::export::{ExportResult, export_observability};
12use crate::redaction::Redactor;
13use crate::report::{ObservabilityReport, generate_report};
14use crate::span::SpanGuard;
15use crate::{ObservabilityError, Result};
16use chrono::Utc;
17use parking_lot::{Mutex, RwLock};
18use serde_json::Value;
19use std::collections::{HashMap, VecDeque};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::Duration;
23use tokio::sync::mpsc;
24use uuid::Uuid;
25
26/// Central collector that receives events, applies privacy rules, aggregates metrics, and exports reports.
27pub struct ObservabilityManager {
28    config: ObservabilityConfig,
29    sender: mpsc::Sender<ObservationEvent>,
30    receiver: Mutex<mpsc::Receiver<ObservationEvent>>,
31    raw_events: RwLock<VecDeque<ObservationEvent>>,
32    pending_branch_events: RwLock<HashMap<String, Vec<ObservationEvent>>>,
33    aggregator: MetricsAggregator,
34    cost_estimator: CostEstimator,
35    redactor: Redactor,
36    dropped_events: AtomicU64,
37}
38
39impl ObservabilityManager {
40    /// Creates a shared manager with bounded event buffering.
41    pub fn new(config: ObservabilityConfig) -> Arc<Self> {
42        let _ = config.validate();
43        let (sender, receiver) = mpsc::channel(config.buffer.event_buffer.max(1));
44        Arc::new(Self {
45            cost_estimator: CostEstimator::new(config.cost.clone()),
46            redactor: Redactor::new(config.privacy.clone()),
47            aggregator: MetricsAggregator::new(config.aggregation.clone()),
48            sender,
49            receiver: Mutex::new(receiver),
50            raw_events: RwLock::new(VecDeque::new()),
51            pending_branch_events: RwLock::new(HashMap::new()),
52            dropped_events: AtomicU64::new(0),
53            config,
54        })
55    }
56
57    /// Returns the immutable configuration used by this manager.
58    pub fn config(&self) -> &ObservabilityConfig {
59        &self.config
60    }
61
62    /// Starts a measured span for an LLM or tool wrapper.
63    pub fn start_span(
64        self: &Arc<Self>,
65        event_type: EventType,
66        purpose: ObservationPurpose,
67    ) -> SpanGuard {
68        let mut context = current_observation_context()
69            .map(|ctx| ctx.child())
70            .unwrap_or_else(|| SpanContext::new_root("unknown"));
71        context.purpose = purpose;
72        SpanGuard::new(Arc::clone(self), context, event_type)
73    }
74
75    /// Records hook-style lifecycle events that are not LLM or tool wrapper calls.
76    pub fn record_lifecycle_event(
77        &self,
78        event_type: EventType,
79        purpose: ObservationPurpose,
80        status: EventStatus,
81        duration_ms: u64,
82        tags: HashMap<String, String>,
83        payload: Option<Value>,
84    ) {
85        let context = current_observation_context()
86            .map(|ctx| ctx.child())
87            .unwrap_or_else(|| SpanContext::new_root("unknown"));
88        let mut dimensions = context_dimension_map(&context);
89        for (key, value) in &tags {
90            if key.starts_with("runtime.") {
91                dimensions.insert(key.clone(), value.clone());
92                if let Some(short_key) = key.strip_prefix("runtime.") {
93                    dimensions.insert(short_key.to_string(), value.clone());
94                }
95            }
96        }
97        let event = ObservationEvent {
98            trace_id: context.trace_id,
99            span_id: context.span_id,
100            parent_span_id: context.parent_span_id,
101            turn_id: context.turn_id,
102            agent_id: context.agent_id,
103            actor_id: context.actor_id,
104            session_id: context.session_id,
105            event_type,
106            purpose,
107            status,
108            timestamp: Utc::now(),
109            duration_ms,
110            tokens: None,
111            cost: None,
112            error: None,
113            dimensions,
114            tags,
115            payload,
116        };
117        self.record_event(event);
118    }
119
120    /// Records an event that should be finalized when its runtime branch resolves.
121    pub fn record_pending_event(&self, branch_id: impl Into<String>, event: ObservationEvent) {
122        if !self.config.enabled {
123            return;
124        }
125        let mut pending = self.pending_branch_events.write();
126        let pending_count: usize = pending.values().map(Vec::len).sum();
127        if pending_count >= self.config.buffer.pending_branch_event_limit {
128            self.dropped_events.fetch_add(1, Ordering::Relaxed);
129            return;
130        }
131        pending.entry(branch_id.into()).or_default().push(event);
132    }
133
134    /// Finalizes all pending events for a runtime branch and ingests them normally.
135    pub fn finalize_pending_branch(
136        &self,
137        branch_id: &str,
138        branch_status: impl Into<String>,
139        winner: bool,
140        extra_tags: HashMap<String, String>,
141    ) -> usize {
142        let mut events = self
143            .pending_branch_events
144            .write()
145            .remove(branch_id)
146            .unwrap_or_default();
147        let status = branch_status.into();
148        let count = events.len();
149        for event in &mut events {
150            event
151                .tags
152                .insert("runtime.branch_status".to_string(), status.clone());
153            event
154                .tags
155                .insert("runtime.winner".to_string(), winner.to_string());
156            event.tags.insert("winner".to_string(), winner.to_string());
157            event
158                .dimensions
159                .insert("branch_status".to_string(), status.clone());
160            event
161                .dimensions
162                .insert("runtime.branch_status".to_string(), status.clone());
163            event
164                .dimensions
165                .insert("runtime.winner".to_string(), winner.to_string());
166            event
167                .dimensions
168                .insert("winner".to_string(), winner.to_string());
169            for (key, value) in &extra_tags {
170                event.tags.insert(key.clone(), value.clone());
171                event.dimensions.insert(key.clone(), value.clone());
172            }
173            self.record_event(event.clone());
174        }
175        count
176    }
177
178    /// Queues a completed event without blocking the observed call path.
179    pub fn record_event(&self, event: ObservationEvent) {
180        if !self.config.enabled {
181            return;
182        }
183        match self.sender.try_send(event) {
184            Ok(()) => {}
185            Err(mpsc::error::TrySendError::Full(event)) => {
186                if self.config.buffer.drop_on_full {
187                    self.dropped_events.fetch_add(1, Ordering::Relaxed);
188                } else {
189                    self.ingest_event(event);
190                }
191            }
192            Err(mpsc::error::TrySendError::Closed(event)) => {
193                self.ingest_event(event);
194            }
195        }
196    }
197
198    /// Drains pending queued events into aggregation and raw buffers.
199    pub async fn flush(&self) -> Result<()> {
200        self.drain_pending();
201        Ok(())
202    }
203
204    /// Returns configured aggregate metrics after draining pending events.
205    pub fn get_metrics(&self) -> Vec<AggregatedMetrics> {
206        self.drain_pending();
207        self.aggregator.aggregate_configured()
208    }
209
210    /// Returns retained raw events after redaction and queue draining.
211    pub fn raw_events(&self) -> Vec<ObservationEvent> {
212        self.drain_pending();
213        self.raw_events.read().iter().cloned().collect()
214    }
215
216    /// Builds the user-facing report from the current rolling event window.
217    pub fn generate_report(&self) -> ObservabilityReport {
218        self.drain_pending();
219        let events = self.aggregator.events();
220        generate_report(
221            &events,
222            self.aggregator.aggregate_configured(),
223            self.dropped_events(),
224        )
225    }
226
227    /// Writes configured report, aggregate, raw event, and Prometheus files.
228    pub async fn export(&self) -> Result<ExportResult> {
229        export_observability(self).map_err(ObservabilityError::Io)
230    }
231
232    /// Returns the total number of events dropped by bounded buffers.
233    pub fn dropped_events(&self) -> u64 {
234        self.dropped_events.load(Ordering::Relaxed)
235    }
236
237    /// Returns the redactor used by wrappers for safe payload summaries.
238    pub fn redactor(&self) -> &Redactor {
239        &self.redactor
240    }
241
242    /// Converts a completed SpanGuard into an ObservationEvent.
243    pub fn build_event_from_span(
244        &self,
245        context: SpanContext,
246        event_type: EventType,
247        duration: Duration,
248        status: EventStatus,
249        tokens: Option<crate::event::ObservationTokenUsage>,
250        error: Option<crate::event::ObservationError>,
251        tags: HashMap<String, String>,
252        payload: Option<Value>,
253    ) -> ObservationEvent {
254        let dimensions = context_dimension_map(&context);
255        ObservationEvent {
256            trace_id: context.trace_id,
257            span_id: context.span_id,
258            parent_span_id: context.parent_span_id,
259            turn_id: context.turn_id,
260            agent_id: context.agent_id,
261            actor_id: context.actor_id,
262            session_id: context.session_id,
263            event_type,
264            purpose: context.purpose,
265            status,
266            timestamp: Utc::now(),
267            duration_ms: duration.as_millis() as u64,
268            tokens,
269            cost: None::<CostEstimate>,
270            error,
271            dimensions,
272            tags,
273            payload,
274        }
275    }
276
277    /// Drains queued events into the synchronous aggregation path.
278    fn drain_pending(&self) {
279        let mut receiver = self.receiver.lock();
280        loop {
281            match receiver.try_recv() {
282                Ok(event) => self.ingest_event(event),
283                Err(mpsc::error::TryRecvError::Empty)
284                | Err(mpsc::error::TryRecvError::Disconnected) => break,
285            }
286        }
287    }
288
289    /// Enriches, costs, redacts, aggregates, and optionally stores one event.
290    fn ingest_event(&self, mut event: ObservationEvent) {
291        enrich_dimensions(&mut event);
292        event.tokens = event
293            .tokens
294            .take()
295            .map(|tokens| self.apply_token_config(tokens));
296        if event.cost.is_none() {
297            let (provider, model) = match &event.event_type {
298                EventType::LlmCall {
299                    provider, model, ..
300                } => (Some(provider.as_str()), Some(model.as_str())),
301                _ => (None, None),
302            };
303            event.cost = self
304                .cost_estimator
305                .estimate(provider, model, event.tokens.as_ref());
306            if matches!(
307                self.config.cost.unknown_price_policy,
308                UnknownPricePolicy::Error
309            ) && event.tokens.is_some()
310                && event.cost.is_none()
311                && matches!(&event.event_type, EventType::LlmCall { .. })
312            {
313                event
314                    .tags
315                    .insert("cost_error".to_string(), "unknown_price".to_string());
316            }
317        }
318        let event = self.redactor.redact_event(event);
319        self.aggregator.record(event.clone());
320        self.store_raw_event(event);
321    }
322
323    /// Applies token count switches before reports and cost estimates read usage.
324    fn apply_token_config(&self, mut tokens: ObservationTokenUsage) -> ObservationTokenUsage {
325        if !self.config.tokens.count_input {
326            tokens.input_tokens = 0;
327        }
328        if !self.config.tokens.count_output {
329            tokens.output_tokens = 0;
330        }
331        tokens.total_tokens = tokens.input_tokens + tokens.output_tokens;
332        tokens
333    }
334
335    /// Retains a redacted raw event when raw event export is enabled.
336    fn store_raw_event(&self, event: ObservationEvent) {
337        if !self.config.export.write_raw_events {
338            return;
339        }
340        if self.config.buffer.raw_event_limit == 0 {
341            self.dropped_events.fetch_add(1, Ordering::Relaxed);
342            return;
343        }
344        let mut raw_events = self.raw_events.write();
345        if raw_events.len() >= self.config.buffer.raw_event_limit {
346            if self.config.buffer.drop_on_full {
347                self.dropped_events.fetch_add(1, Ordering::Relaxed);
348                return;
349            }
350            raw_events.pop_front();
351        }
352        raw_events.push_back(event);
353    }
354
355    /// Renders current aggregate metrics in Prometheus text exposition format.
356    pub fn render_prometheus(&self) -> String {
357        let report = self.generate_report();
358        let events = self.aggregator.events();
359        let llm_events: Vec<_> = events
360            .iter()
361            .filter(|event| matches!(&event.event_type, EventType::LlmCall { .. }))
362            .cloned()
363            .collect();
364        let tool_events: Vec<_> = events
365            .iter()
366            .filter(|event| matches!(&event.event_type, EventType::ToolCall { .. }))
367            .cloned()
368            .collect();
369        let by_model_purpose = aggregate_events(
370            &llm_events,
371            &[AggregationDimension::Model, AggregationDimension::Purpose],
372        );
373        let by_tool = aggregate_events(&tool_events, &[AggregationDimension::Tool]);
374        let mut output = String::new();
375        output.push_str(
376            "# HELP ai_agents_observation_events_total Total recorded observation events\n",
377        );
378        output.push_str("# TYPE ai_agents_observation_events_total counter\n");
379        output.push_str(&format!(
380            "ai_agents_observation_events_total {}\n",
381            report.summary.total_events
382        ));
383        output.push_str("# HELP ai_agents_observation_errors_total Total observation events with error status\n");
384        output.push_str("# TYPE ai_agents_observation_errors_total counter\n");
385        output.push_str(&format!(
386            "ai_agents_observation_errors_total {}\n",
387            report.summary.total_errors
388        ));
389        output.push_str(
390            "# HELP ai_agents_observation_cost_usd_total Estimated total LLM cost in USD\n",
391        );
392        output.push_str("# TYPE ai_agents_observation_cost_usd_total counter\n");
393        output.push_str(&format!(
394            "ai_agents_observation_cost_usd_total {:.8}\n",
395            report.summary.total_cost_usd
396        ));
397        output.push_str("# HELP ai_agents_observation_tokens_total Total observed LLM tokens\n");
398        output.push_str("# TYPE ai_agents_observation_tokens_total counter\n");
399        output.push_str(&format!(
400            "ai_agents_observation_tokens_total {}\n",
401            report.summary.total_tokens
402        ));
403        output.push_str("# HELP ai_agents_llm_calls_total LLM calls grouped by safe labels\n");
404        output.push_str("# TYPE ai_agents_llm_calls_total counter\n");
405        for metric in by_model_purpose {
406            let model = metric
407                .dimensions
408                .get("model")
409                .map(String::as_str)
410                .unwrap_or("unknown");
411            let purpose = metric
412                .dimensions
413                .get("purpose")
414                .map(String::as_str)
415                .unwrap_or("unknown");
416            output.push_str(&format!(
417                "ai_agents_llm_calls_total{{model=\"{}\",purpose=\"{}\"}} {}\n",
418                prometheus_label(model),
419                prometheus_label(purpose),
420                metric.count
421            ));
422        }
423        output.push_str("# HELP ai_agents_tool_calls_total Tool calls grouped by tool ID\n");
424        output.push_str("# TYPE ai_agents_tool_calls_total counter\n");
425        for metric in by_tool {
426            let tool = metric
427                .dimensions
428                .get("tool")
429                .map(String::as_str)
430                .unwrap_or("unknown");
431            if tool != "unknown" {
432                output.push_str(&format!(
433                    "ai_agents_tool_calls_total{{tool=\"{}\"}} {}\n",
434                    prometheus_label(tool),
435                    metric.count
436                ));
437            }
438        }
439        output
440    }
441
442    /// Returns true when a format is enabled in export.formats.
443    pub fn wants_format(&self, format: ExportFormat) -> bool {
444        self.config.export.formats.contains(&format)
445    }
446}
447
/// Escapes a label value for the Prometheus text exposition format.
///
/// Backslashes and double quotes are backslash-escaped. Newlines, carriage returns, and tabs are
/// replaced with `_` so each sample stays on one line. All other characters pass through.
fn prometheus_label(value: &str) -> String {
    // One pre-sized buffer instead of a temporary Vec per character.
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' | '\r' | '\t' => escaped.push('_'),
            other => escaped.push(other),
        }
    }
    escaped
}
460
461/// Builds the base event dimensions from the current span context.
462fn context_dimension_map(context: &SpanContext) -> HashMap<String, String> {
463    let mut dimensions = HashMap::new();
464    dimensions.insert("agent".to_string(), context.agent_id.clone());
465    dimensions.insert("purpose".to_string(), context.purpose.as_label());
466    if let Some(actor) = &context.actor_id {
467        dimensions.insert("actor".to_string(), actor.clone());
468    }
469    if let Some(state) = &context.state {
470        dimensions.insert("state".to_string(), state.clone());
471    }
472    if let Some(language) = &context.language {
473        dimensions.insert("language".to_string(), language.clone());
474    }
475    dimensions.extend(context.tags.clone());
476    dimensions
477}
478
479/// Resolves the language dimension by checking configured context paths in order.
480pub fn resolve_language_from_context(
481    config: &ObservabilityConfig,
482    context: &HashMap<String, Value>,
483) -> String {
484    for path in &config.language.paths {
485        if let Some(value) = get_dotted(context, path) {
486            if let Some(language) = value.as_str() {
487                if !language.trim().is_empty() {
488                    return language.to_string();
489                }
490            }
491        }
492    }
493    config.language.fallback.clone()
494}
495
496/// Looks up a top-level or dotted path in a JSON context map.
497fn get_dotted<'a>(context: &'a HashMap<String, Value>, path: &str) -> Option<&'a Value> {
498    if let Some(value) = context.get(path) {
499        return Some(value);
500    }
501    let mut parts = path.split('.');
502    let first = parts.next()?;
503    let mut current = context.get(first)?;
504    for part in parts {
505        current = current.get(part)?;
506    }
507    Some(current)
508}
509
510/// Generates a session ID for observed runtime sessions that do not have one yet.
511pub fn new_session_id() -> String {
512    Uuid::new_v4().to_string()
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::{ObservationTokenUsage, TokenUsageSource};

    /// Builds a successful, uncosted LLM call event reporting 100 input and 25 output tokens.
    fn test_event() -> ObservationEvent {
        ObservationEvent {
            trace_id: "trace".to_string(),
            span_id: Uuid::new_v4().to_string(),
            parent_span_id: None,
            turn_id: "turn".to_string(),
            agent_id: "agent".to_string(),
            actor_id: None,
            session_id: None,
            event_type: EventType::LlmCall {
                provider: "openai".to_string(),
                model: "test".to_string(),
                alias: Some("default".to_string()),
                streaming: false,
            },
            purpose: ObservationPurpose::MainResponse,
            status: EventStatus::Success,
            timestamp: Utc::now(),
            duration_ms: 10,
            tokens: Some(ObservationTokenUsage::new(
                100,
                25,
                TokenUsageSource::Provider,
            )),
            cost: None,
            error: None,
            dimensions: HashMap::new(),
            tags: HashMap::new(),
            payload: None,
        }
    }

    #[test]
    fn token_count_flags_are_applied_before_report() {
        let mut config = ObservabilityConfig::default();
        config.enabled = true;
        config.tokens.count_input = false;
        config.tokens.count_output = true;
        config.cost.enabled = false;
        let manager = ObservabilityManager::new(config);
        manager.record_event(test_event());

        let report = manager.generate_report();
        assert_eq!(report.token_breakdown.total_input, 0);
        assert_eq!(report.token_breakdown.total_output, 25);
        assert_eq!(report.token_breakdown.total_tokens, 25);
    }

    #[test]
    fn pending_branch_event_is_hidden_until_finalized() {
        let mut config = ObservabilityConfig::default();
        config.enabled = true;
        config.export.write_raw_events = true;
        let manager = ObservabilityManager::new(config);
        manager.record_pending_event("branch", test_event());

        let mut tags = HashMap::new();
        tags.insert("runtime.speculative".to_string(), "true".to_string());
        tags.insert("speculative".to_string(), "true".to_string());

        assert_eq!(manager.generate_report().summary.total_events, 0);
        assert_eq!(
            manager.finalize_pending_branch("branch", "discarded", false, tags),
            1
        );
        let report = manager.generate_report();
        assert_eq!(report.summary.total_events, 1);

        // Snapshot once and check the length so a missing event fails with a clear assertion
        // rather than an out-of-bounds index panic.
        let raw = manager.raw_events();
        assert_eq!(raw.len(), 1);
        let dimensions = &raw[0].dimensions;
        assert_eq!(
            dimensions.get("branch_status"),
            Some(&"discarded".to_string())
        );
        assert_eq!(
            dimensions.get("runtime.winner"),
            Some(&"false".to_string())
        );
        assert_eq!(dimensions.get("speculative"), Some(&"true".to_string()));
        assert_eq!(
            dimensions.get("runtime.speculative"),
            Some(&"true".to_string())
        );
    }

    #[test]
    fn pending_branch_events_are_bounded() {
        let mut config = ObservabilityConfig::default();
        config.enabled = true;
        config.buffer.pending_branch_event_limit = 1;
        let manager = ObservabilityManager::new(config);
        manager.record_pending_event("branch-a", test_event());
        manager.record_pending_event("branch-b", test_event());

        manager.finalize_pending_branch("branch-a", "committed", true, HashMap::new());
        manager.finalize_pending_branch("branch-b", "committed", true, HashMap::new());
        let report = manager.generate_report();
        assert_eq!(report.summary.total_events, 1);
        assert_eq!(report.dropped_events, 1);
    }

    #[test]
    fn prometheus_label_escapes_quotes_backslashes_and_control_chars() {
        assert_eq!(prometheus_label("gpt-4o"), "gpt-4o");
        assert_eq!(prometheus_label(r#"a"b\c"#), r#"a\"b\\c"#);
        assert_eq!(prometheus_label("x\ny\rz\tw"), "x_y_z_w");
    }

    #[test]
    fn get_dotted_prefers_exact_keys_then_walks_nested_objects() {
        let mut context = HashMap::new();
        context.insert("user".to_string(), serde_json::json!({ "language": "de" }));
        context.insert("flat.key".to_string(), Value::String("direct".to_string()));

        assert_eq!(
            get_dotted(&context, "user.language").and_then(Value::as_str),
            Some("de")
        );
        assert_eq!(
            get_dotted(&context, "flat.key").and_then(Value::as_str),
            Some("direct")
        );
        assert!(get_dotted(&context, "user.missing").is_none());
        assert!(get_dotted(&context, "missing").is_none());
    }
}