use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::provider::types::Usage;
/// Qualitative events on the "event" track of the dual-track observer.
///
/// Serialized via serde's adjacent tagging:
/// `{"event_type": "<Variant>", "payload": { ...fields... }}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "event_type", content = "payload")]
pub enum ObserverEvent {
    /// An agent run began.
    AgentStart {
        agent_name: String,
        provider: String,
        model: String,
        // Preview only; presumably truncated by the emitter — confirm.
        input_preview: String,
    },
    /// An agent run finished, with optional usage and cost accounting.
    AgentEnd {
        agent_name: String,
        steps: usize,
        usage: Option<Usage>,
        cost_usd: Option<f64>,
    },
    /// An LLM request is about to be sent.
    LlmRequestStart {
        provider: String,
        model: String,
        messages_count: usize,
        tools_count: usize,
    },
    /// An LLM response was received (success or failure).
    LlmResponse {
        provider: String,
        model: String,
        duration: Duration,
        success: bool,
        usage: Option<Usage>,
    },
    /// One streamed chunk of an LLM response.
    LlmStreamChunk {
        provider: String,
        model: String,
        chunk_index: usize,
        delta_len: usize,
    },
    /// A tool invocation started.
    ToolCallStart {
        tool: String,
        tool_call_id: String,
        arguments_preview: String,
    },
    /// A tool invocation finished.
    ToolCallEnd {
        tool: String,
        tool_call_id: String,
        duration: Duration,
        success: bool,
        output_preview: String,
    },
    /// A cached tool result was reused.
    ToolCacheHit {
        tool: String,
        tokens_saved: u64,
    },
    /// No cached result was available for a tool call.
    ToolCacheMiss {
        tool: String,
    },
    /// A tool's circuit breaker opened after consecutive failures.
    CircuitBreakerOpen {
        tool: String,
        consecutive_failures: u32,
    },
    /// A tool's circuit breaker closed again.
    CircuitBreakerClose {
        tool: String,
    },
    /// A failed tool call is being retried.
    ToolRetry {
        tool: String,
        attempt: u32,
        max_retries: u32,
    },
    /// A tool call exceeded its time budget.
    ToolTimeout {
        tool: String,
        timeout: Duration,
    },
    /// A policy rule denied a tool call.
    PolicyDenied {
        tool: String,
        rule_id: String,
        reason: String,
    },
    /// One agent step completed.
    StepComplete {
        step: usize,
        max_steps: usize,
        decision_type: String,
    },
    /// A repeating tool-call pattern was detected.
    LoopDetected {
        tool: String,
        repeat_count: usize,
        level: String,
    },
    /// A context overflow was recovered using the named strategy.
    ContextOverflowRecovered {
        strategy: String,
        messages_affected: usize,
    },
    /// A session opened on the given channel.
    SessionStart {
        session_id: String,
        channel: String,
    },
    /// A session closed after `duration`.
    SessionEnd {
        session_id: String,
        duration: Duration,
    },
    /// A component-level error.
    Error {
        component: String,
        error_type: String,
        message: String,
    },
    /// Periodic heartbeat carrying load counters.
    Heartbeat {
        active_sessions: u64,
        queue_depth: u64,
    },
    /// Escape hatch for events not covered by the variants above.
    Raw {
        name: String,
        data: serde_json::Value,
    },
}
/// Quantitative measurements on the "metric" track.
///
/// Serialized via serde's adjacent tagging:
/// `{"metric_type": "<Variant>", "value": ...}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "metric_type", content = "value")]
pub enum ObserverMetric {
    /// Request latency in milliseconds.
    RequestLatencyMs(f64),
    /// Token consumption for one request.
    TokensUsed {
        prompt: u64,
        completion: u64,
        total: u64,
    },
    /// Number of active sessions.
    ActiveSessions(u64),
    /// Depth of the work queue.
    QueueDepth(u64),
    /// Call count for one tool.
    ToolCallCount {
        tool: String,
        count: u64,
    },
    /// Average latency for one tool, in milliseconds.
    ToolLatencyMs {
        tool: String,
        avg_latency: f64,
    },
    /// Error rate keyed by error type.
    ErrorRate {
        error_type: String,
        rate: f64,
    },
    /// Cache hit rate.
    CacheHitRate(f64),
    /// Circuit breaker state for one tool.
    // NOTE(review): `state` is a raw u8 code; the open/closed mapping is not
    // defined in this file — confirm against the breaker implementation.
    CircuitBreakerState {
        tool: String,
        state: u8,
    },
    /// Memory usage in bytes.
    MemoryUsageBytes(u64),
    /// CPU usage as a percentage.
    CpuUsagePercent(f64),
    /// Generic named counter with optional labels.
    Counter {
        name: String,
        value: u64,
        labels: Option<HashMap<String, String>>,
    },
    /// Generic named gauge with optional labels.
    Gauge {
        name: String,
        value: f64,
        labels: Option<HashMap<String, String>>,
    },
    /// Generic histogram observation with its bucket boundaries.
    Histogram {
        name: String,
        value: f64,
        buckets: Vec<f64>,
    },
}
/// Label set attached to a metric (key/value dimensions).
pub type MetricLabels = HashMap<String, String>;
/// The dual-track observation interface: a qualitative event stream
/// (`observe_event`) plus a quantitative metric stream (`record_metric`).
///
/// Implementors must be `Send + Sync` so observers can be shared across
/// threads.
// NOTE(review): no method here is `async`; `#[async_trait]` looks vestigial —
// confirm before removing, since the impl blocks repeat the attribute.
#[async_trait::async_trait]
pub trait DualTrackObserver: Send + Sync {
    /// Delivers one qualitative event.
    fn observe_event(&self, event: ObserverEvent);
    /// Records one quantitative metric.
    fn record_metric(&self, metric: ObserverMetric);
    /// Records a metric with labels; the default implementation discards the
    /// labels and forwards to `record_metric`.
    fn record_metric_with_labels(&self, metric: ObserverMetric, labels: MetricLabels) {
        let _ = labels;
        self.record_metric(metric);
    }
    /// Flushes any buffered output; default is a no-op.
    fn flush(&self) {}
}
/// Fans events and metrics out to every registered observer.
pub struct MultiObserver {
    event_observers: Vec<Box<dyn DualTrackObserver>>,
}

impl MultiObserver {
    /// Starts with no registered observers.
    pub fn new() -> Self {
        let event_observers = Vec::new();
        Self { event_observers }
    }

    /// Registers an additional observer (builder-style, returns `self`).
    pub fn add_observer(mut self, observer: Box<dyn DualTrackObserver>) -> Self {
        self.event_observers.push(observer);
        self
    }
}

impl Default for MultiObserver {
    fn default() -> Self {
        MultiObserver::new()
    }
}
#[async_trait::async_trait]
impl DualTrackObserver for MultiObserver {
    /// Delivers a clone of the event to each registered observer in order.
    fn observe_event(&self, event: ObserverEvent) {
        self.event_observers
            .iter()
            .for_each(|obs| obs.observe_event(event.clone()));
    }

    /// Delivers a clone of the metric to each registered observer in order.
    fn record_metric(&self, metric: ObserverMetric) {
        self.event_observers
            .iter()
            .for_each(|obs| obs.record_metric(metric.clone()));
    }

    /// Flushes every registered observer.
    fn flush(&self) {
        self.event_observers.iter().for_each(|obs| obs.flush());
    }
}
/// Observer that emits structured `tracing` log lines.
pub struct LoggingObserver {
    level: LogLevel,
}

impl LoggingObserver {
    /// Defaults the configured level to `Info`.
    pub fn new() -> Self {
        LoggingObserver {
            level: LogLevel::Info,
        }
    }

    /// Builder-style override of the configured level.
    pub fn with_level(self, level: LogLevel) -> Self {
        LoggingObserver { level }
    }
}

impl Default for LoggingObserver {
    fn default() -> Self {
        LoggingObserver::new()
    }
}
#[async_trait::async_trait]
impl DualTrackObserver for LoggingObserver {
    /// Maps well-known events to `tracing` calls at kind-appropriate levels;
    /// everything else falls through to a debug-level catch-all that is only
    /// emitted when the configured level is `Trace` or `Debug`.
    #[allow(clippy::cognitive_complexity)]
    fn observe_event(&self, event: ObserverEvent) {
        match event {
            ObserverEvent::AgentStart { ref agent_name, .. } => {
                tracing::info!(agent_name, "agent.start");
            }
            ObserverEvent::AgentEnd {
                ref agent_name,
                steps,
                ..
            } => {
                tracing::info!(agent_name, steps, "agent.end");
            }
            ObserverEvent::LlmRequestStart {
                ref provider,
                ref model,
                ..
            } => {
                tracing::debug!(provider, model, "llm.request.start");
            }
            // Failures are promoted to warn; successes stay at debug.
            ObserverEvent::LlmResponse {
                ref provider,
                success,
                ..
            } => {
                if success {
                    tracing::debug!(provider, "llm.response.success");
                } else {
                    tracing::warn!(provider, "llm.response.failure");
                }
            }
            ObserverEvent::ToolCallStart { ref tool, .. } => {
                tracing::debug!(tool, "tool.call.start");
            }
            // Same success/failure severity split as LlmResponse.
            ObserverEvent::ToolCallEnd {
                ref tool, success, ..
            } => {
                if success {
                    tracing::debug!(tool, "tool.call.end");
                } else {
                    tracing::warn!(tool, "tool.call.failure");
                }
            }
            ObserverEvent::ToolCacheHit { ref tool, .. } => {
                tracing::debug!(tool, "tool.cache.hit");
            }
            ObserverEvent::ToolCacheMiss { ref tool } => {
                tracing::debug!(tool, "tool.cache.miss");
            }
            ObserverEvent::Error {
                ref component,
                ref message,
                ..
            } => {
                tracing::error!(component, message, "observer.error");
            }
            ObserverEvent::LoopDetected {
                ref tool,
                repeat_count,
                ..
            } => {
                tracing::warn!(tool, repeat_count, "loop.detected");
            }
            // Catch-all: dump the whole event at debug, gated on the
            // configured level via the derived Ord on LogLevel.
            _ => {
                if self.level <= LogLevel::Debug {
                    tracing::debug!(event = ?event, "observer.event");
                }
            }
        }
    }
    /// Logs well-known metrics with named fields; all others are dumped via
    /// `Debug` at debug level. Note: unlike `observe_event`, no branch here
    /// consults `self.level`.
    #[allow(clippy::cognitive_complexity)]
    fn record_metric(&self, metric: ObserverMetric) {
        match metric {
            ObserverMetric::RequestLatencyMs(latency) => {
                tracing::debug!(latency_ms = latency, "metric.request_latency");
            }
            ObserverMetric::TokensUsed {
                prompt,
                completion,
                total,
            } => {
                tracing::debug!(prompt, completion, total, "metric.tokens_used");
            }
            ObserverMetric::ActiveSessions(count) => {
                tracing::debug!(count, "metric.active_sessions");
            }
            ObserverMetric::ToolCallCount { ref tool, count } => {
                tracing::debug!(tool, count, "metric.tool_call_count");
            }
            ObserverMetric::CacheHitRate(rate) => {
                tracing::debug!(rate, "metric.cache_hit_rate");
            }
            _ => {
                tracing::debug!(metric = ?metric, "metric.recorded");
            }
        }
    }
}
/// Stateless observer that traces everything it receives.
pub struct VerboseObserver;

impl VerboseObserver {
    /// Creates the observer; it carries no state.
    pub fn new() -> Self {
        VerboseObserver
    }
}

impl Default for VerboseObserver {
    fn default() -> Self {
        VerboseObserver::new()
    }
}
#[async_trait::async_trait]
impl DualTrackObserver for VerboseObserver {
    /// Logs the full event via its `Debug` form at TRACE level.
    fn observe_event(&self, event: ObserverEvent) {
        tracing::trace!(event = ?event, "verbose.event");
    }
    /// Logs the full metric via its `Debug` form at TRACE level.
    fn record_metric(&self, metric: ObserverMetric) {
        tracing::trace!(metric = ?metric, "verbose.metric");
    }
}
/// Severity levels for `LoggingObserver`, from most to least verbose.
///
/// The derived `Ord` follows declaration order (`Trace < Debug < Info < Warn
/// < Error`); `LoggingObserver` relies on this when gating its catch-all
/// branch with `self.level <= LogLevel::Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    Trace,
    Debug,
    Info,
    Warn,
    Error,
}
/// Accumulates raw `f64` samples and computes summary statistics on demand.
///
/// Samples are kept in insertion order; `percentile` sorts a copy, so the
/// aggregator itself is never reordered.
pub struct MetricAggregator {
    values: Vec<f64>,
}

impl MetricAggregator {
    /// Creates an empty aggregator.
    pub fn new() -> Self {
        Self { values: Vec::new() }
    }

    /// Records one sample.
    pub fn add(&mut self, value: f64) {
        self.values.push(value);
    }

    /// Arithmetic mean of all samples; `0.0` when no samples were recorded.
    pub fn average(&self) -> f64 {
        if self.values.is_empty() {
            0.0
        } else {
            self.values.iter().sum::<f64>() / self.values.len() as f64
        }
    }

    /// Smallest sample; `f64::INFINITY` when empty (the fold identity).
    pub fn min(&self) -> f64 {
        self.values.iter().copied().fold(f64::INFINITY, f64::min)
    }

    /// Largest sample; `f64::NEG_INFINITY` when empty (the fold identity).
    pub fn max(&self) -> f64 {
        self.values
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Nearest-rank percentile of the recorded samples; `0.0` when empty.
    ///
    /// # Panics
    /// Panics if `p` is outside `0.0..=100.0`.
    pub fn percentile(&self, p: f64) -> f64 {
        assert!(
            (0.0..=100.0).contains(&p),
            "percentile p must be between 0.0 and 100.0, got {}",
            p
        );
        if self.values.is_empty() {
            return 0.0;
        }
        let mut sorted = self.values.clone();
        // total_cmp is a total order over f64 (NaN sorts after all numbers),
        // so this cannot panic on NaN samples the way the previous
        // `partial_cmp(..).unwrap()` comparator did.
        sorted.sort_unstable_by(f64::total_cmp);
        let index = ((p / 100.0) * (sorted.len() - 1) as f64).round() as usize;
        sorted[index.min(sorted.len() - 1)]
    }

    /// Number of recorded samples.
    pub fn count(&self) -> usize {
        self.values.len()
    }

    /// Discards all samples.
    pub fn clear(&mut self) {
        self.values.clear();
    }
}

impl Default for MetricAggregator {
    fn default() -> Self {
        Self::new()
    }
}
/// Incremental builder for `ObserverEvent` values.
///
/// Holds at most one staged event; calling a second constructor method
/// replaces the previous one. `build` yields `None` if nothing was staged.
pub struct EventBuilder {
    event: Option<ObserverEvent>,
}

impl EventBuilder {
    /// Starts with no staged event.
    pub fn new() -> Self {
        Self { event: None }
    }

    /// Stages an `AgentStart` event (with an empty input preview).
    pub fn agent_start(
        self,
        agent_name: impl Into<String>,
        provider: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let staged = ObserverEvent::AgentStart {
            agent_name: agent_name.into(),
            provider: provider.into(),
            model: model.into(),
            input_preview: String::new(),
        };
        Self {
            event: Some(staged),
        }
    }

    /// Stages a `ToolCallStart` event (with an empty arguments preview).
    pub fn tool_call_start(
        self,
        tool: impl Into<String>,
        tool_call_id: impl Into<String>,
    ) -> Self {
        let staged = ObserverEvent::ToolCallStart {
            tool: tool.into(),
            tool_call_id: tool_call_id.into(),
            arguments_preview: String::new(),
        };
        Self {
            event: Some(staged),
        }
    }

    /// Consumes the builder, returning the staged event if any.
    pub fn build(self) -> Option<ObserverEvent> {
        self.event
    }
}

impl Default for EventBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Events serialize with the adjacent-tag layout: the variant name appears
    // under "event_type" and the fields under "payload".
    #[test]
    fn test_observer_event_serialization() {
        let event = ObserverEvent::AgentStart {
            agent_name: "test".to_string(),
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            input_preview: "hello".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("AgentStart"));
        assert!(json.contains("test"));
    }
    // Metrics serialize the same way under "metric_type"/"value".
    #[test]
    fn test_observer_metric_serialization() {
        let metric = ObserverMetric::TokensUsed {
            prompt: 100,
            completion: 50,
            total: 150,
        };
        let json = serde_json::to_string(&metric).unwrap();
        assert!(json.contains("TokensUsed"));
        assert!(json.contains("100"));
    }
    // Sanity-checks every aggregate statistic on a small known sample.
    #[test]
    fn test_metric_aggregator() {
        let mut agg = MetricAggregator::new();
        agg.add(1.0);
        agg.add(2.0);
        agg.add(3.0);
        agg.add(4.0);
        agg.add(5.0);
        assert_eq!(agg.average(), 3.0);
        assert_eq!(agg.min(), 1.0);
        assert_eq!(agg.max(), 5.0);
        assert_eq!(agg.count(), 5);
        assert_eq!(agg.percentile(50.0), 3.0);
    }
    // The builder stages an AgentStart event and build() returns it.
    #[test]
    fn test_event_builder() {
        let event = EventBuilder::new()
            .agent_start("my_agent", "openai", "gpt-4")
            .build();
        assert!(matches!(event, Some(ObserverEvent::AgentStart { .. })));
    }
}