use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Default histogram upper bounds used by [`MetricsRegistry::histogram`].
///
/// The `_S` suffix marks these as seconds, matching what [`Histogram::time`]
/// records (5 ms up to 10 s).
pub const DEFAULT_BUCKETS_S: &[f64] = &[
    // sub-100 ms
    0.005, 0.01, 0.025, 0.05,
    // sub-second
    0.1, 0.25, 0.5,
    // seconds
    1.0, 2.5, 5.0, 10.0,
];
/// Registry of counters and histograms, keyed by metric name plus label set.
///
/// Cloning is cheap: every clone shares the same lock-protected storage.
#[derive(Default, Clone)]
pub struct MetricsRegistry {
    inner: Arc<RwLock<Inner>>,
}

/// Storage shared by all clones of a [`MetricsRegistry`].
///
/// Ordered maps are used so rendering walks series in a stable, sorted order.
#[derive(Default)]
struct Inner {
    /// Counter cells, one per (name, labels) series.
    counters: BTreeMap<MetricKey, Arc<CounterInner>>,
    /// Histogram state, one per (name, labels) series.
    histograms: BTreeMap<MetricKey, Arc<HistogramInner>>,
}
/// Identity of one time series: a metric name plus its label pairs.
///
/// Labels are stored sorted, so the same set supplied in any order yields an
/// equal key. Field order matters for the derived `Ord`: keys compare by name
/// first, which keeps every series of one metric adjacent in a `BTreeMap`.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
struct MetricKey {
    name: String,
    labels: Vec<(String, String)>,
}

impl MetricKey {
    /// Builds a key from borrowed parts, copying and sorting the labels.
    fn new(name: &str, labels: &[(&str, &str)]) -> Self {
        let mut owned: Vec<(String, String)> = labels
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect();
        owned.sort_unstable();
        Self {
            name: name.to_owned(),
            labels: owned,
        }
    }

    /// Formats the series as `name`, or `name{k1="v1",k2="v2"}` with escaped values.
    fn render(&self) -> String {
        if self.labels.is_empty() {
            return self.name.clone();
        }
        let pairs: Vec<String> = self
            .labels
            .iter()
            .map(|(k, v)| format!("{k}=\"{}\"", escape_label_value(v)))
            .collect();
        format!("{}{{{}}}", self.name, pairs.join(","))
    }
}

/// Escapes a label value for the text exposition format.
///
/// Backslash, double quote and line feed become `\\`, `\"` and `\n`;
/// every other character is copied through unchanged.
fn escape_label_value(v: &str) -> String {
    v.chars().fold(String::with_capacity(v.len()), |mut escaped, ch| {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            plain => escaped.push(plain),
        }
        escaped
    })
}
impl MetricsRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn counter(&self, name: &str, labels: &[(&str, &str)]) -> Counter {
let key = MetricKey::new(name, labels);
let inner = {
let read = self.inner.read().expect("metrics poisoned");
read.counters.get(&key).cloned()
};
if let Some(i) = inner {
return Counter { inner: i };
}
let mut write = self.inner.write().expect("metrics poisoned");
let i = write
.counters
.entry(key)
.or_insert_with(|| Arc::new(CounterInner::default()))
.clone();
Counter { inner: i }
}
pub fn histogram(&self, name: &str, labels: &[(&str, &str)]) -> Histogram {
self.histogram_with_buckets(name, labels, DEFAULT_BUCKETS_S)
}
pub fn histogram_with_buckets(
&self,
name: &str,
labels: &[(&str, &str)],
buckets: &[f64],
) -> Histogram {
let key = MetricKey::new(name, labels);
let inner = {
let read = self.inner.read().expect("metrics poisoned");
read.histograms.get(&key).cloned()
};
if let Some(i) = inner {
return Histogram { inner: i };
}
let mut write = self.inner.write().expect("metrics poisoned");
let i = write
.histograms
.entry(key)
.or_insert_with(|| Arc::new(HistogramInner::new(buckets)))
.clone();
Histogram { inner: i }
}
#[must_use]
pub fn render(&self) -> String {
let r = self.inner.read().expect("metrics poisoned");
let mut out = String::new();
let mut emitted_types = std::collections::HashSet::new();
for (key, c) in &r.counters {
if emitted_types.insert(("counter", key.name.clone())) {
out.push_str(&format!("# TYPE {} counter\n", key.name));
}
out.push_str(&format!("{} {}\n", key.render(), c.value()));
}
for (key, h) in &r.histograms {
if emitted_types.insert(("histogram", key.name.clone())) {
out.push_str(&format!("# TYPE {} histogram\n", key.name));
}
let buckets = &h.buckets;
let counts = h.bucket_counts();
for (i, le) in buckets.iter().enumerate() {
let mut bucket_labels = key.labels.clone();
bucket_labels.push(("le".into(), format!("{le}")));
bucket_labels.sort();
let bucket_key = MetricKey {
name: format!("{}_bucket", key.name),
labels: bucket_labels,
};
out.push_str(&format!("{} {}\n", bucket_key.render(), counts[i]));
}
let mut inf_labels = key.labels.clone();
inf_labels.push(("le".into(), "+Inf".into()));
inf_labels.sort();
let inf_key = MetricKey {
name: format!("{}_bucket", key.name),
labels: inf_labels,
};
out.push_str(&format!("{} {}\n", inf_key.render(), h.total_count()));
let sum_key = MetricKey {
name: format!("{}_sum", key.name),
labels: key.labels.clone(),
};
let count_key = MetricKey {
name: format!("{}_count", key.name),
labels: key.labels.clone(),
};
out.push_str(&format!("{} {}\n", sum_key.render(), h.sum()));
out.push_str(&format!("{} {}\n", count_key.render(), h.total_count()));
}
out
}
}
/// Storage cell behind one counter series.
#[derive(Default)]
struct CounterInner {
    value: AtomicU64,
}

impl CounterInner {
    /// Current count.
    fn value(&self) -> u64 {
        self.value.load(Ordering::Relaxed)
    }
}

/// Cheaply cloneable handle to a registered counter.
///
/// All clones update the same shared cell.
#[derive(Clone)]
pub struct Counter {
    inner: Arc<CounterInner>,
}

impl Counter {
    /// Adds one to the counter.
    pub fn inc(&self) {
        self.inc_by(1);
    }

    /// Adds `n` to the counter (wrapping on overflow, as `AtomicU64::fetch_add` does).
    pub fn inc_by(&self, n: u64) {
        let cell = &self.inner.value;
        cell.fetch_add(n, Ordering::Relaxed);
    }

    /// Reads the current count.
    #[must_use]
    pub fn value(&self) -> u64 {
        self.inner.value()
    }
}
/// Shared state behind a [`Histogram`] handle.
///
/// Every observation increments exactly one slot of `counts`. Slot `i` holds
/// observations `v` with `buckets[i - 1] < v <= buckets[i]`. The extra last
/// slot holds observations above every finite bound, including NaN.
///
/// Cumulative bucket counts and the total are derived from the slots when read.
/// Because of this, `bucket_counts()` is always non-decreasing, and a
/// `total_count()` read after it is never below its last entry, even while
/// other threads are observing.
struct HistogramInner {
    /// Finite upper bounds (`le`), sorted ascending, without duplicates.
    buckets: Vec<f64>,
    /// Non-cumulative per-slot counts; `counts.len() == buckets.len() + 1`.
    counts: Vec<AtomicU64>,
    /// Running sum of observed values, stored as `f64::to_bits`.
    sum_bits: AtomicU64,
}

impl HistogramInner {
    /// Creates an empty histogram from caller-supplied bounds.
    ///
    /// Bounds are sorted and deduplicated, so each `le` label appears once.
    /// Non-finite bounds (NaN, ±Inf) are dropped because the `+Inf` bucket is
    /// always emitted implicitly.
    fn new(buckets: &[f64]) -> Self {
        let mut bounds: Vec<f64> = buckets.iter().copied().filter(|b| b.is_finite()).collect();
        bounds.sort_by(f64::total_cmp);
        bounds.dedup();
        let counts = (0..=bounds.len()).map(|_| AtomicU64::new(0)).collect();
        Self {
            buckets: bounds,
            counts,
            sum_bits: AtomicU64::new(0.0_f64.to_bits()),
        }
    }

    /// Records one observation.
    ///
    /// NaN is counted (it lands only in `+Inf`/`_count`) but is kept out of the
    /// sum, so a single bad value cannot make `_sum` NaN forever.
    fn observe(&self, v: f64) {
        if v.is_nan() {
            self.counts[self.buckets.len()].fetch_add(1, Ordering::Relaxed);
            return;
        }
        // Index of the first bound with `v <= le`; `buckets.len()` means overflow.
        let slot = self.buckets.partition_point(|&le| le < v);
        self.counts[slot].fetch_add(1, Ordering::Relaxed);
        // std has no atomic f64 add: retry a CAS on the bit pattern. Unlike the old
        // fixed-point sum, this keeps negatives and sub-microsecond values, and it
        // cannot wrap on huge inputs.
        let _ = self
            .sum_bits
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                Some((f64::from_bits(bits) + v).to_bits())
            });
    }

    /// Cumulative count for each finite bound, in `buckets` order.
    fn bucket_counts(&self) -> Vec<u64> {
        let mut running = 0_u64;
        self.counts[..self.buckets.len()]
            .iter()
            .map(|slot| {
                running += slot.load(Ordering::Relaxed);
                running
            })
            .collect()
    }

    /// Total number of observations, i.e. the `+Inf` bucket and `_count`.
    fn total_count(&self) -> u64 {
        self.counts
            .iter()
            .map(|slot| slot.load(Ordering::Relaxed))
            .sum()
    }

    /// Sum of all non-NaN observations.
    fn sum(&self) -> f64 {
        f64::from_bits(self.sum_bits.load(Ordering::Relaxed))
    }
}

/// Cheaply cloneable handle to a registered histogram.
///
/// All clones record into the same shared state.
#[derive(Clone)]
pub struct Histogram {
    inner: Arc<HistogramInner>,
}

impl Histogram {
    /// Records one observation `v`.
    pub fn observe(&self, v: f64) {
        self.inner.observe(v);
    }

    /// Runs `f`, records its wall-clock duration in seconds, and returns its result.
    ///
    /// Nothing is recorded if `f` panics.
    pub fn time<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let start = std::time::Instant::now();
        let result = f();
        self.observe(start.elapsed().as_secs_f64());
        result
    }
}
/// Builds an axum router that serves `reg.render()` at `GET /metrics`.
///
/// `MetricsRegistry` is already a cheap, shared `Clone`, so it is used directly
/// as router state without another `Arc` around it.
#[cfg(feature = "admin")]
pub fn metrics_router(reg: MetricsRegistry) -> axum::Router {
    use axum::extract::State;
    use axum::http::header;
    use axum::response::IntoResponse;
    use axum::routing::get;

    /// Responds 200 with the text exposition body and its versioned content type.
    async fn handler(State(reg): State<MetricsRegistry>) -> impl IntoResponse {
        (
            [(
                header::CONTENT_TYPE,
                "text/plain; version=0.0.4; charset=utf-8",
            )],
            reg.render(),
        )
    }

    axum::Router::new()
        .route("/metrics", get(handler))
        .with_state(reg)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers the three-bound `dur` histogram used by several tests.
    fn small_histogram(reg: &MetricsRegistry) -> Histogram {
        reg.histogram_with_buckets("dur", &[], &[0.1, 1.0, 10.0])
    }

    #[test]
    fn counter_increments() {
        let reg = MetricsRegistry::new();
        let hits = reg.counter("hits", &[]);
        hits.inc();
        hits.inc();
        hits.inc_by(5);
        assert_eq!(hits.value(), 7);
    }

    #[test]
    fn same_label_set_returns_same_counter_instance() {
        let reg = MetricsRegistry::new();
        let labels = [("path", "/x")];
        reg.counter("hits", &labels).inc_by(3);
        reg.counter("hits", &labels).inc_by(4);
        assert_eq!(reg.counter("hits", &labels).value(), 7);
    }

    #[test]
    fn different_label_sets_are_independent() {
        let reg = MetricsRegistry::new();
        reg.counter("hits", &[("path", "/a")]).inc();
        reg.counter("hits", &[("path", "/b")]).inc_by(5);
        for (path, expected) in [("/a", 1), ("/b", 5)] {
            assert_eq!(reg.counter("hits", &[("path", path)]).value(), expected);
        }
    }

    #[test]
    fn label_order_does_not_create_separate_metrics() {
        let reg = MetricsRegistry::new();
        reg.counter("hits", &[("a", "1"), ("b", "2")]).inc();
        reg.counter("hits", &[("b", "2"), ("a", "1")]).inc();
        assert_eq!(reg.counter("hits", &[("a", "1"), ("b", "2")]).value(), 2);
    }

    #[test]
    fn histogram_observe_increments_buckets_and_count() {
        let reg = MetricsRegistry::new();
        let h = small_histogram(&reg);
        for v in [0.05, 0.5, 5.0, 20.0] {
            h.observe(v);
        }
        assert_eq!(h.inner.total_count(), 4);
        assert_eq!(h.inner.bucket_counts(), vec![1, 2, 3]);
    }

    #[test]
    fn histogram_sum_reflects_observed_values() {
        let reg = MetricsRegistry::new();
        let h = small_histogram(&reg);
        h.observe(0.5);
        h.observe(0.75);
        assert!((h.inner.sum() - 1.25).abs() < 0.001);
    }

    #[test]
    fn histogram_unsorted_buckets_are_sorted() {
        let reg = MetricsRegistry::new();
        let h = reg.histogram_with_buckets("dur", &[], &[10.0, 0.1, 1.0]);
        assert_eq!(h.inner.buckets, vec![0.1, 1.0, 10.0]);
    }

    #[test]
    fn histogram_time_observes_elapsed() {
        let reg = MetricsRegistry::new();
        let h = reg.histogram("op_duration_seconds", &[]);
        h.time(|| std::thread::sleep(std::time::Duration::from_millis(5)));
        assert_eq!(h.inner.total_count(), 1);
        assert!(h.inner.sum() >= 0.004);
    }

    #[test]
    fn render_emits_counter_lines() {
        let reg = MetricsRegistry::new();
        reg.counter("requests_total", &[("status", "200")]).inc_by(3);
        reg.counter("requests_total", &[("status", "500")]).inc();
        let text = reg.render();
        for needle in [
            "# TYPE requests_total counter",
            r#"requests_total{status="200"} 3"#,
            r#"requests_total{status="500"} 1"#,
        ] {
            assert!(text.contains(needle), "missing {needle:?} in: {text}");
        }
        assert_eq!(text.matches("# TYPE requests_total").count(), 1);
    }

    #[test]
    fn render_emits_bare_counter_without_braces() {
        let reg = MetricsRegistry::new();
        reg.counter("uptime_seconds", &[]).inc_by(42);
        assert!(reg.render().contains("uptime_seconds 42"));
    }

    #[test]
    fn render_emits_histogram_buckets_sum_count_and_inf() {
        let reg = MetricsRegistry::new();
        let h = reg.histogram_with_buckets("dur", &[("op", "ping")], &[0.1, 1.0]);
        h.observe(0.05);
        h.observe(0.5);
        let text = reg.render();
        for needle in [
            "# TYPE dur histogram",
            r#"dur_bucket{le="0.1",op="ping"} 1"#,
            r#"dur_bucket{le="1",op="ping"} 2"#,
            r#"dur_bucket{le="+Inf",op="ping"} 2"#,
            r#"dur_count{op="ping"} 2"#,
            "dur_sum",
        ] {
            assert!(text.contains(needle), "missing {needle:?} in: {text}");
        }
    }

    #[test]
    fn render_escapes_label_values() {
        let reg = MetricsRegistry::new();
        reg.counter("custom", &[("path", r#"/a"b\c"#)]).inc();
        let text = reg.render();
        assert!(text.contains(r#"custom{path="/a\"b\\c"} 1"#), "got: {text}");
    }

    #[cfg(feature = "admin")]
    #[tokio::test]
    async fn metrics_endpoint_returns_text_format() {
        use axum::body::{to_bytes, Body};
        use axum::http::Request;
        use tower::ServiceExt;

        let reg = MetricsRegistry::new();
        reg.counter("hits", &[]).inc_by(7);
        let request = Request::builder()
            .uri("/metrics")
            .body(Body::empty())
            .unwrap();
        let resp = metrics_router(reg).oneshot(request).await.unwrap();
        assert_eq!(resp.status(), 200);
        let content_type = resp
            .headers()
            .get("content-type")
            .and_then(|value| value.to_str().ok());
        assert_eq!(
            content_type,
            Some("text/plain; version=0.0.4; charset=utf-8")
        );
        let bytes = to_bytes(resp.into_body(), 1 << 16).await.unwrap();
        let body = std::str::from_utf8(&bytes).unwrap();
        assert!(body.contains("hits 7"));
    }
}