use prometheus::{
Encoder, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, Opts, Registry,
TextEncoder,
};
use std::collections::HashSet;
use std::sync::Mutex;
use thiserror::Error;
/// Upper bound on distinct `tool` label values tracked by `tool_calls_total`;
/// tool names first seen after this many are counted under a shared overflow label.
const MAX_TOOL_LABEL_CARDINALITY: usize = 1_000;
#[derive(Debug, Error)]
pub enum MetricsError {
#[error("prometheus error: {0}")]
Prometheus(#[from] prometheus::Error),
}
/// Prometheus metrics exported by the arbiter.
///
/// All metrics below are registered in the private `registry`, which is what
/// gets gathered when the metrics are rendered.
pub struct ArbiterMetrics {
    /// `requests_total{decision}`: requests by authorization decision.
    pub requests_total: IntCounterVec,
    /// `tool_calls_total{tool}`: tool calls by (sanitized, cardinality-capped) tool name.
    pub tool_calls_total: IntCounterVec,
    /// `anomalies_total`: number of anomalies detected.
    pub anomalies_total: IntCounter,
    /// `request_duration_seconds`: end-to-end request latency histogram.
    pub request_duration_seconds: Histogram,
    /// `upstream_duration_seconds`: latency histogram of the forwarded upstream call.
    pub upstream_duration_seconds: Histogram,
    /// `active_sessions`: currently active task sessions.
    pub active_sessions: IntGauge,
    /// `registered_agents`: currently registered agents.
    pub registered_agents: IntGauge,
    /// Registry holding every metric above.
    registry: Registry,
    /// Tool labels admitted so far; used to cap the cardinality of `tool_calls_total`.
    known_tools: Mutex<HashSet<String>>,
}
impl ArbiterMetrics {
    /// Histogram upper bounds, in seconds, shared by `request_duration_seconds`
    /// and `upstream_duration_seconds` so the two can be compared bucket-for-bucket.
    const LATENCY_BUCKETS: [f64; 11] = [
        0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
    ];

    /// Number of leading characters of a decision considered before filtering.
    const MAX_DECISION_LABEL_CHARS: usize = 64;

    /// Number of leading characters kept from a tool name.
    const MAX_TOOL_LABEL_CHARS: usize = 128;

    /// Shared `tool` label for tools first seen after the cardinality cap is reached.
    const OTHER_TOOL_LABEL: &'static str = "__other__";

    /// Creates the full metric set in a new, empty [`Registry`].
    ///
    /// # Errors
    ///
    /// Returns [`MetricsError::Prometheus`] if any metric cannot be created or
    /// registered.
    pub fn new() -> Result<Self, MetricsError> {
        Self::with_registry(Registry::new())
    }

    /// Creates every metric and registers it in `registry`.
    ///
    /// # Errors
    ///
    /// Returns [`MetricsError::Prometheus`] if a metric cannot be created or if
    /// `registry` rejects its registration (for example because a metric with the
    /// same name is already registered there). Metrics registered before the
    /// failing one are not unregistered.
    pub fn with_registry(registry: Registry) -> Result<Self, MetricsError> {
        let requests_total = IntCounterVec::new(
            Opts::new("requests_total", "Total requests by authorization decision"),
            &["decision"],
        )?;
        registry.register(Box::new(requests_total.clone()))?;

        let tool_calls_total = IntCounterVec::new(
            Opts::new("tool_calls_total", "Total tool calls by tool name"),
            &["tool"],
        )?;
        registry.register(Box::new(tool_calls_total.clone()))?;

        let anomalies_total =
            IntCounter::with_opts(Opts::new("anomalies_total", "Total anomalies detected"))?;
        registry.register(Box::new(anomalies_total.clone()))?;

        let request_duration_seconds = Histogram::with_opts(
            HistogramOpts::new(
                "request_duration_seconds",
                "End-to-end request duration in seconds",
            )
            .buckets(Self::LATENCY_BUCKETS.to_vec()),
        )?;
        registry.register(Box::new(request_duration_seconds.clone()))?;

        let upstream_duration_seconds = Histogram::with_opts(
            HistogramOpts::new(
                "upstream_duration_seconds",
                "Duration of upstream (forwarded) call in seconds",
            )
            .buckets(Self::LATENCY_BUCKETS.to_vec()),
        )?;
        registry.register(Box::new(upstream_duration_seconds.clone()))?;

        let active_sessions = IntGauge::with_opts(Opts::new(
            "active_sessions",
            "Currently active task sessions",
        ))?;
        registry.register(Box::new(active_sessions.clone()))?;

        let registered_agents = IntGauge::with_opts(Opts::new(
            "registered_agents",
            "Currently registered agents",
        ))?;
        registry.register(Box::new(registered_agents.clone()))?;

        Ok(Self {
            registry,
            requests_total,
            tool_calls_total,
            anomalies_total,
            request_duration_seconds,
            upstream_duration_seconds,
            active_sessions,
            registered_agents,
            known_tools: Mutex::new(HashSet::new()),
        })
    }

    /// Increments `requests_total` for the given authorization decision.
    ///
    /// Only the first 64 characters of `decision` are considered, and of those
    /// only ASCII alphanumerics and `_` are kept. An input with no such
    /// characters therefore yields an empty label value.
    pub fn record_request(&self, decision: &str) {
        let sanitized: String = decision
            .chars()
            .take(Self::MAX_DECISION_LABEL_CHARS)
            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
            .collect();
        self.requests_total
            .with_label_values(&[sanitized.as_str()])
            .inc();
    }

    /// Increments `tool_calls_total` for `tool`, bounding the label's cardinality.
    ///
    /// The name is truncated to its first 128 characters, and every character
    /// that is neither printable ASCII nor a space is replaced with `_`. The
    /// first `MAX_TOOL_LABEL_CARDINALITY` distinct sanitized names get their own
    /// label value; names first seen after that are counted under `__other__`.
    /// A tool literally named `__other__` shares that overflow series.
    pub fn record_tool_call(&self, tool: &str) {
        let sanitized: String = tool
            .chars()
            .take(Self::MAX_TOOL_LABEL_CHARS)
            .map(|c| {
                if c.is_ascii_graphic() || c == ' ' {
                    c
                } else {
                    '_'
                }
            })
            .collect();

        let admitted = {
            // A poisoned lock is recovered rather than propagated: this type only
            // ever inserts into the set, so its contents remain usable.
            let mut known = self.known_tools.lock().unwrap_or_else(|e| e.into_inner());
            if known.contains(&sanitized) {
                // Common case: an already-known tool needs no allocation or re-insert.
                true
            } else if known.len() < MAX_TOOL_LABEL_CARDINALITY {
                known.insert(sanitized.clone());
                true
            } else {
                false
            }
        };

        let label = if admitted {
            sanitized.as_str()
        } else {
            Self::OTHER_TOOL_LABEL
        };
        self.tool_calls_total.with_label_values(&[label]).inc();
    }

    /// Increments `anomalies_total` by one.
    pub fn record_anomaly(&self) {
        self.anomalies_total.inc();
    }

    /// Records one end-to-end request duration, in seconds.
    pub fn observe_request_duration(&self, seconds: f64) {
        self.request_duration_seconds.observe(seconds);
    }

    /// Records one upstream (forwarded) call duration, in seconds.
    pub fn observe_upstream_duration(&self, seconds: f64) {
        self.upstream_duration_seconds.observe(seconds);
    }

    /// Renders everything gathered from this instance's registry in the
    /// Prometheus text exposition format.
    ///
    /// If the registry passed to [`ArbiterMetrics::with_registry`] is shared,
    /// the output also includes collectors registered by other code.
    ///
    /// # Errors
    ///
    /// Returns [`MetricsError::Prometheus`] if encoding fails.
    pub fn render(&self) -> Result<String, MetricsError> {
        let mut buffer = Vec::new();
        TextEncoder::new().encode(&self.registry.gather(), &mut buffer)?;
        Ok(String::from_utf8_lossy(&buffer).into_owned())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counters accumulate independently per label value.
    #[test]
    fn counter_increments() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.record_request("allow");
        metrics.record_request("allow");
        metrics.record_request("deny");
        metrics.record_tool_call("read_file");
        metrics.record_tool_call("read_file");
        metrics.record_tool_call("write_file");
        metrics.record_anomaly();
        assert_eq!(
            metrics.requests_total.with_label_values(&["allow"]).get(),
            2
        );
        assert_eq!(metrics.requests_total.with_label_values(&["deny"]).get(), 1);
        assert_eq!(
            metrics
                .tool_calls_total
                .with_label_values(&["read_file"])
                .get(),
            2
        );
        assert_eq!(
            metrics
                .tool_calls_total
                .with_label_values(&["write_file"])
                .get(),
            1
        );
        assert_eq!(metrics.anomalies_total.get(), 1);
    }

    /// The rendered output carries HELP/TYPE headers and current gauge values.
    #[test]
    fn metrics_endpoint_returns_valid_prometheus_format() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.record_request("allow");
        metrics.record_tool_call("list_dir");
        metrics.observe_request_duration(0.042);
        metrics.observe_upstream_duration(0.035);
        metrics.active_sessions.set(3);
        metrics.registered_agents.set(5);
        let output = metrics.render().unwrap();
        assert!(output.contains("requests_total"));
        assert!(output.contains("tool_calls_total"));
        assert!(output.contains("anomalies_total"));
        assert!(output.contains("request_duration_seconds"));
        assert!(output.contains("upstream_duration_seconds"));
        assert!(output.contains("active_sessions 3"));
        assert!(output.contains("registered_agents 5"));
        assert!(output.contains("# HELP requests_total"));
        assert!(output.contains("# TYPE requests_total counter"));
        assert!(output.contains("# HELP request_duration_seconds"));
        assert!(output.contains("# TYPE request_duration_seconds histogram"));
    }

    /// A histogram exports its buckets, sum, and count series.
    #[test]
    fn histogram_buckets_are_present() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.observe_request_duration(0.05);
        let output = metrics.render().unwrap();
        assert!(output.contains("request_duration_seconds_bucket"));
        assert!(output.contains("request_duration_seconds_sum"));
        assert!(output.contains("request_duration_seconds_count"));
    }

    /// The upstream histogram counts each observation exactly once.
    #[test]
    fn upstream_histogram_counts_observations() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.observe_upstream_duration(0.2);
        let output = metrics.render().unwrap();
        assert!(output.contains("upstream_duration_seconds_bucket"));
        assert!(output.contains("upstream_duration_seconds_count 1"));
    }

    /// Gauges support both absolute sets and relative adjustments.
    #[test]
    fn gauges_can_increase_and_decrease() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.active_sessions.set(10);
        assert_eq!(metrics.active_sessions.get(), 10);
        metrics.active_sessions.dec();
        assert_eq!(metrics.active_sessions.get(), 9);
        metrics.registered_agents.inc();
        metrics.registered_agents.inc();
        assert_eq!(metrics.registered_agents.get(), 2);
    }

    /// Decision labels keep only ASCII alphanumerics/`_` from the first 64 chars.
    #[test]
    fn decision_labels_are_sanitized_and_truncated() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.record_request("al-low!\n");
        assert_eq!(
            metrics.requests_total.with_label_values(&["allow"]).get(),
            1,
            "disallowed characters should be dropped"
        );

        metrics.record_request(&"a".repeat(100));
        let truncated = "a".repeat(64);
        assert_eq!(
            metrics
                .requests_total
                .with_label_values(&[truncated.as_str()])
                .get(),
            1,
            "decision labels should be truncated to 64 characters"
        );
    }

    /// Tool labels replace non-printable/non-ASCII characters and truncate to 128.
    #[test]
    fn tool_labels_are_sanitized_and_truncated() {
        let metrics = ArbiterMetrics::new().unwrap();
        metrics.record_tool_call("read\tfile\u{7}");
        metrics.record_tool_call("caf\u{e9} menu");
        metrics.record_tool_call(&"x".repeat(200));
        assert_eq!(
            metrics
                .tool_calls_total
                .with_label_values(&["read_file_"])
                .get(),
            1,
            "control characters should become '_'"
        );
        assert_eq!(
            metrics
                .tool_calls_total
                .with_label_values(&["caf_ menu"])
                .get(),
            1,
            "non-ASCII characters should become '_' while spaces are kept"
        );
        let truncated = "x".repeat(128);
        assert_eq!(
            metrics
                .tool_calls_total
                .with_label_values(&[truncated.as_str()])
                .get(),
            1,
            "tool labels should be truncated to 128 characters"
        );
    }

    /// Once the cap is reached, new tools share `__other__` while known tools keep
    /// their own label.
    #[test]
    fn cardinality_limiting_works() {
        let metrics = ArbiterMetrics::new().unwrap();
        for i in 0..MAX_TOOL_LABEL_CARDINALITY {
            metrics.record_tool_call(&format!("tool_{i}"));
        }
        metrics.record_tool_call("tool_overflow_a");
        metrics.record_tool_call("tool_overflow_b");
        let other_count = metrics
            .tool_calls_total
            .with_label_values(&["__other__"])
            .get();
        assert_eq!(
            other_count, 2,
            "overflow tool calls should be bucketed under __other__"
        );
        let first_count = metrics
            .tool_calls_total
            .with_label_values(&["tool_0"])
            .get();
        assert_eq!(first_count, 1, "original tool should still have its label");
        let known = metrics.known_tools.lock().unwrap();
        assert_eq!(
            known.len(),
            MAX_TOOL_LABEL_CARDINALITY,
            "known tools should be capped at MAX_TOOL_LABEL_CARDINALITY"
        );
        drop(known);
        metrics.record_tool_call("tool_0");
        let first_count = metrics
            .tool_calls_total
            .with_label_values(&["tool_0"])
            .get();
        assert_eq!(
            first_count, 2,
            "repeated calls to known tools should still use original label"
        );
    }
}