use crate::middleware::access_log::AvisoRootSpanBuilder;
use crate::middleware::request_id::RequestIdHeader;
use actix_web::{App, HttpResponse, HttpServer, dev::Server, web};
use prometheus::{
Encoder, IntCounterVec, IntGaugeVec, Registry, TextEncoder, opts,
register_int_counter_vec_with_registry, register_int_gauge_vec_with_registry,
};
#[cfg(feature = "ecpds")]
use prometheus::{
IntCounter, IntGauge, register_int_counter_with_registry, register_int_gauge_with_registry,
};
use std::collections::HashMap;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use tracing_actix_web::TracingLogger;
/// Prometheus collectors for the ECPDS integration.
///
/// Only compiled when the `ecpds` feature is enabled. The collectors are
/// created and registered on the application registry by `AppMetrics::new`.
#[cfg(feature = "ecpds")]
#[derive(Clone, Debug)]
pub struct EcpdsMetrics {
    /// `aviso_ecpds_cache_hits_total`: ECPDS destination cache hits.
    pub cache_hits_total: IntCounter,
    /// `aviso_ecpds_cache_misses_total`: ECPDS destination cache misses.
    pub cache_misses_total: IntCounter,
    /// `aviso_ecpds_cache_size`: number of usernames held in the ECPDS
    /// destination cache.
    pub cache_size: IntGauge,
    /// `aviso_ecpds_access_decisions_total`: ECPDS access check outcomes,
    /// labelled by `outcome`.
    pub access_decisions_total: IntCounterVec,
    /// `aviso_ecpds_fetch_total`: ECPDS upstream fetch outcomes, labelled by
    /// `outcome`.
    pub fetch_total: IntCounterVec,
}
/// Application-wide Prometheus collectors together with the registry they are
/// registered on.
///
/// The type is `Clone`; `SseConnectionGuard` stores a clone and updates the
/// gauges through it, and the user bookkeeping map is shared between clones
/// via `Arc`.
#[derive(Clone, Debug)]
pub struct AppMetrics {
    /// Registry that every collector below is registered on.
    pub registry: Registry,
    /// `aviso_notifications_total`: notification requests, labelled by
    /// `event_type` and `status`.
    pub notifications_total: IntCounterVec,
    /// `aviso_sse_connections_active`: currently active SSE connections,
    /// labelled by `endpoint` and `event_type`.
    pub sse_connections_active: IntGaugeVec,
    /// `aviso_sse_connections_total`: SSE connections opened, labelled by
    /// `endpoint` and `event_type`.
    pub sse_connections_total: IntCounterVec,
    /// `aviso_sse_unique_users_active`: distinct users with active SSE
    /// connections, labelled by `endpoint`.
    pub sse_unique_users_active: IntGaugeVec,
    /// `aviso_auth_requests_total`: authentication attempts, labelled by
    /// `mode` and `outcome`.
    pub auth_requests_total: IntCounterVec,
    /// ECPDS-specific collectors (only with the `ecpds` feature).
    #[cfg(feature = "ecpds")]
    pub ecpds: EcpdsMetrics,
    // endpoint -> (username -> number of open SSE connections). Used to decide
    // when `sse_unique_users_active` must be incremented or decremented.
    unique_users: Arc<Mutex<HashMap<String, HashMap<String, usize>>>>,
}
impl Default for AppMetrics {
fn default() -> Self {
Self::new()
}
}
impl AppMetrics {
    /// Creates a fresh [`Registry`] and registers every application collector
    /// on it.
    ///
    /// With the `ecpds` feature enabled, the ECPDS collectors are registered
    /// as well and the known `outcome` label values are touched once so the
    /// corresponding series exist before the first real event.
    ///
    /// # Panics
    ///
    /// Panics if a collector cannot be registered. The registry is created
    /// inside this function, so a failure indicates a programming error such
    /// as a duplicated metric name.
    pub fn new() -> Self {
        let registry = Registry::new();

        let notifications_total = register_int_counter_vec_with_registry!(
            opts!(
                "aviso_notifications_total",
                "Total notification requests by event type and outcome"
            ),
            &["event_type", "status"],
            registry
        )
        .expect("metric must register");

        let sse_connections_active = register_int_gauge_vec_with_registry!(
            opts!(
                "aviso_sse_connections_active",
                "Currently active SSE connections"
            ),
            &["endpoint", "event_type"],
            registry
        )
        .expect("metric must register");

        let sse_connections_total = register_int_counter_vec_with_registry!(
            opts!(
                "aviso_sse_connections_total",
                "Total SSE connections opened"
            ),
            &["endpoint", "event_type"],
            registry
        )
        .expect("metric must register");

        let sse_unique_users_active = register_int_gauge_vec_with_registry!(
            opts!(
                "aviso_sse_unique_users_active",
                "Distinct users with active SSE connections"
            ),
            &["endpoint"],
            registry
        )
        .expect("metric must register");

        let auth_requests_total = register_int_counter_vec_with_registry!(
            opts!(
                "aviso_auth_requests_total",
                "Authentication attempts by mode and outcome"
            ),
            &["mode", "outcome"],
            registry
        )
        .expect("metric must register");

        #[cfg(feature = "ecpds")]
        let ecpds = {
            let metrics = EcpdsMetrics {
                cache_hits_total: register_int_counter_with_registry!(
                    opts!(
                        "aviso_ecpds_cache_hits_total",
                        "ECPDS destination cache hits"
                    ),
                    registry
                )
                .expect("metric must register"),
                cache_misses_total: register_int_counter_with_registry!(
                    opts!(
                        "aviso_ecpds_cache_misses_total",
                        "ECPDS destination cache misses (request not served from cache; an upstream fetch ran for this caller or a concurrent caller via single-flight)"
                    ),
                    registry
                )
                .expect("metric must register"),
                cache_size: register_int_gauge_with_registry!(
                    opts!(
                        "aviso_ecpds_cache_size",
                        "Number of usernames held in the ECPDS destination cache (sampled from moka after eviction passes; may include not-yet-pruned expired entries until the next pending-tasks run)"
                    ),
                    registry
                )
                .expect("metric must register"),
                access_decisions_total: register_int_counter_vec_with_registry!(
                    opts!(
                        "aviso_ecpds_access_decisions_total",
                        "ECPDS access check outcomes"
                    ),
                    &["outcome"],
                    registry
                )
                .expect("metric must register"),
                fetch_total: register_int_counter_vec_with_registry!(
                    opts!(
                        "aviso_ecpds_fetch_total",
                        "ECPDS upstream fetch outcomes; incremented exactly once per upstream call (the request whose check actually ran the fetch). Coalesced waiters that joined an in-flight fetch are NOT counted, so this counter measures actual upstream call volume rather than per-request fetch attempts."
                    ),
                    &["outcome"],
                    registry
                )
                .expect("metric must register"),
            };

            // Touch every known label value once so each series is created
            // up front instead of on its first increment.
            for outcome in [
                "allow",
                "deny_destination",
                "deny_match_key_missing",
                "unavailable",
                "admin_bypass",
                "error",
            ] {
                let _ = metrics.access_decisions_total.with_label_values(&[outcome]);
            }
            for outcome in [
                "success",
                "http_401",
                "http_403",
                "http_4xx",
                "http_5xx",
                "invalid_response",
                "unreachable",
            ] {
                let _ = metrics.fetch_total.with_label_values(&[outcome]);
            }
            metrics
        };

        Self {
            registry,
            notifications_total,
            sse_connections_active,
            sse_connections_total,
            sse_unique_users_active,
            auth_requests_total,
            #[cfg(feature = "ecpds")]
            ecpds,
            unique_users: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Records a newly opened SSE connection and returns a guard that undoes
    /// the gauge updates when it is dropped.
    ///
    /// * `endpoint` / `event_type` - label values for the connection metrics.
    /// * `username` - when `Some`, the user is counted towards
    ///   `sse_unique_users_active` for `endpoint`; the gauge is incremented
    ///   only for the user's first concurrent connection on that endpoint.
    ///   `None` (anonymous) connections do not touch that gauge.
    ///
    /// The returned guard must be kept alive for as long as the connection is
    /// open; dropping it decrements `sse_connections_active` again.
    pub fn track_sse_connection(
        &self,
        endpoint: &str,
        event_type: &str,
        username: Option<&str>,
    ) -> SseConnectionGuard {
        self.sse_connections_active
            .with_label_values(&[endpoint, event_type])
            .inc();
        self.sse_connections_total
            .with_label_values(&[endpoint, event_type])
            .inc();

        if let Some(user) = username {
            // Metrics bookkeeping must not take down request handling: if the
            // lock was poisoned by a panic elsewhere, keep using the map. At
            // worst the unique-users gauge is slightly off.
            let mut users = self
                .unique_users
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let count = users
                .entry(endpoint.to_string())
                .or_default()
                .entry(user.to_string())
                .or_insert(0);
            *count += 1;
            // Only the transition 0 -> 1 makes the user newly "active".
            if *count == 1 {
                self.sse_unique_users_active
                    .with_label_values(&[endpoint])
                    .inc();
            }
        }

        SseConnectionGuard {
            metrics: self.clone(),
            endpoint: endpoint.to_string(),
            event_type: event_type.to_string(),
            username: username.map(str::to_string),
        }
    }
}
/// Handle for one tracked SSE connection.
///
/// Created by `AppMetrics::track_sse_connection`. When the guard is dropped,
/// its `Drop` implementation decrements the active-connection gauge and, for
/// authenticated connections, updates the unique-user bookkeeping. The guard
/// therefore has to live exactly as long as the connection it represents.
#[must_use = "dropping the guard immediately decrements the SSE gauges again"]
pub struct SseConnectionGuard {
    /// Clone of the metrics handle used to update the gauges on drop.
    metrics: AppMetrics,
    /// `endpoint` label value the connection was tracked under.
    endpoint: String,
    /// `event_type` label value the connection was tracked under.
    event_type: String,
    /// User the connection belongs to; `None` for anonymous connections.
    username: Option<String>,
}
impl Drop for SseConnectionGuard {
    /// Reverses the gauge updates made when the connection was tracked.
    ///
    /// Always decrements `sse_connections_active`. For connections with a
    /// username it also decrements that user's connection count for the
    /// endpoint and, once the count reaches zero, removes the user and
    /// decrements `sse_unique_users_active`.
    fn drop(&mut self) {
        self.metrics
            .sse_connections_active
            .with_label_values(&[&self.endpoint, &self.event_type])
            .dec();
        if let Some(username) = &self.username {
            // Never panic inside `drop`: a panic while the thread is already
            // unwinding aborts the process. A poisoned lock still holds a
            // usable map, so recover the guard instead of calling `expect`.
            let mut users = self
                .metrics
                .unique_users
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some(endpoint_users) = users.get_mut(&self.endpoint)
                && let Some(count) = endpoint_users.get_mut(username)
            {
                *count = count.saturating_sub(1);
                // Last connection of this user on the endpoint has closed.
                if *count == 0 {
                    endpoint_users.remove(username);
                    self.metrics
                        .sse_unique_users_active
                        .with_label_values(&[&self.endpoint])
                        .dec();
                }
            }
        }
    }
}
/// Stream wrapper that owns an `SseConnectionGuard` alongside the wrapped
/// stream.
///
/// Because the guard is a field of the wrapper, it is dropped together with
/// the wrapper, which is when the guard's `Drop` implementation runs.
pub struct GuardedSseStream<S> {
    // Never read: the field exists only so that the guard is owned by, and
    // dropped with, the stream wrapper.
    #[allow(dead_code)]
    guard: SseConnectionGuard,
    /// The wrapped stream that polls are forwarded to.
    inner: std::pin::Pin<Box<S>>,
}
impl<S> GuardedSseStream<S> {
pub fn new(inner: std::pin::Pin<Box<S>>, guard: SseConnectionGuard) -> Self {
Self { guard, inner }
}
}
/// Transparent `Stream` implementation: items come straight from the wrapped
/// stream, the guard only rides along.
impl<S> futures_util::Stream for GuardedSseStream<S>
where
    S: futures_util::Stream,
{
    type Item = S::Item;

    /// Forwards the poll to the wrapped stream unchanged.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint; without this the wrapper
    /// would report the trait default of `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        futures_util::Stream::size_hint(&*self.inner)
    }
}
/// Builds the HTTP server that exposes `registry` at `GET /metrics`.
///
/// The server is bound to the already-opened `listener`, uses a shutdown
/// timeout of 5 and has actix's signal handling disabled. The returned
/// [`Server`] is handed back to the caller.
///
/// # Errors
///
/// Returns the `std::io::Error` reported when attaching `listener` fails.
pub fn run_metrics_server(
    listener: TcpListener,
    registry: Registry,
) -> Result<Server, std::io::Error> {
    let shared_registry = web::Data::new(registry);

    // Each worker builds its own `App`; they all share the same registry.
    let app_factory = move || {
        App::new()
            .wrap(RequestIdHeader)
            .wrap(TracingLogger::<AvisoRootSpanBuilder>::new())
            .app_data(shared_registry.clone())
            .route("/metrics", web::get().to(metrics_handler))
    };

    let bound = HttpServer::new(app_factory).listen(listener)?;
    Ok(bound.shutdown_timeout(5).disable_signals().run())
}
/// Handler for `GET /metrics`: encodes everything gathered from `registry`
/// with the text encoder.
///
/// Responds with `200 OK` and the encoder's content type on success, or with
/// an empty `500 Internal Server Error` if encoding fails.
async fn metrics_handler(registry: web::Data<Registry>) -> HttpResponse {
    let text_encoder = TextEncoder::new();
    let families = registry.gather();
    let mut payload = Vec::new();

    match text_encoder.encode(&families, &mut payload) {
        Ok(()) => HttpResponse::Ok()
            .content_type(text_encoder.format_type())
            .body(payload),
        Err(_) => HttpResponse::InternalServerError().finish(),
    }
}
/// Registers the process collector for the current process on `registry`.
///
/// The collector is only registered when compiling for Linux; on other
/// targets the function does nothing. Registration is best effort: an error
/// from `register` is ignored.
pub fn register_process_metrics(registry: &Registry) {
    #[cfg(target_os = "linux")]
    {
        // `for_self` builds the collector for the calling process with an
        // empty namespace, which avoids casting the pid by hand.
        let collector = prometheus::process_collector::ProcessCollector::for_self();
        // Deliberately ignored: a failed registration must not be fatal.
        let _ = registry.register(Box::new(collector));
    }
    #[cfg(not(target_os = "linux"))]
    {
        // Silences the unused-parameter warning on targets without the collector.
        let _ = registry;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the current value of one of the gauge vectors, selected by a
    /// short name, for the given label values.
    fn gauge_value(metrics: &AppMetrics, name: &str, labels: &[&str]) -> i64 {
        match name {
            "sse_connections_active" => metrics
                .sse_connections_active
                .with_label_values(labels)
                .get(),
            "sse_unique_users_active" => metrics
                .sse_unique_users_active
                .with_label_values(labels)
                .get(),
            _ => panic!("unknown gauge: {name}"),
        }
    }

    /// Reads the current value of one of the counter vectors, selected by a
    /// short name, for the given label values.
    fn counter_value(metrics: &AppMetrics, name: &str, labels: &[&str]) -> u64 {
        match name {
            "sse_connections_total" => metrics
                .sse_connections_total
                .with_label_values(labels)
                .get(),
            "notifications_total" => metrics.notifications_total.with_label_values(labels).get(),
            "auth_requests_total" => metrics.auth_requests_total.with_label_values(labels).get(),
            _ => panic!("unknown counter: {name}"),
        }
    }

    #[test]
    fn new_metrics_start_at_zero() {
        let m = AppMetrics::new();
        assert_eq!(
            counter_value(&m, "sse_connections_total", &["watch", "mars"]),
            0
        );
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            0
        );
    }

    #[test]
    fn track_sse_connection_increments_and_guard_drop_decrements() {
        let m = AppMetrics::new();
        let guard = m.track_sse_connection("watch", "mars", None);
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            1
        );
        assert_eq!(
            counter_value(&m, "sse_connections_total", &["watch", "mars"]),
            1
        );
        drop(guard);
        // The gauge goes back down, the total counter never does.
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            0
        );
        assert_eq!(
            counter_value(&m, "sse_connections_total", &["watch", "mars"]),
            1
        );
    }

    #[test]
    fn multiple_connections_stack_on_active_gauge() {
        let m = AppMetrics::new();
        let g1 = m.track_sse_connection("watch", "mars", None);
        let g2 = m.track_sse_connection("watch", "mars", None);
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            2
        );
        drop(g1);
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            1
        );
        drop(g2);
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            0
        );
    }

    #[test]
    fn unique_users_gauge_tracks_distinct_users() {
        let m = AppMetrics::new();
        let g1 = m.track_sse_connection("watch", "mars", Some("alice"));
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 1);
        // A second connection of the same user does not add a unique user.
        let g2 = m.track_sse_connection("watch", "mars", Some("alice"));
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 1);
        let g3 = m.track_sse_connection("watch", "mars", Some("bob"));
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 2);
        // alice stays active until her last connection is gone.
        drop(g1);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 2);
        drop(g2);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 1);
        drop(g3);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 0);
    }

    #[test]
    fn anonymous_connections_do_not_affect_unique_users_gauge() {
        let m = AppMetrics::new();
        let guard = m.track_sse_connection("watch", "mars", None);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 0);
        drop(guard);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 0);
    }

    #[test]
    fn separate_endpoints_track_independently() {
        let m = AppMetrics::new();
        let g1 = m.track_sse_connection("watch", "mars", Some("alice"));
        let g2 = m.track_sse_connection("replay", "mars", Some("alice"));
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 1);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["replay"]), 1);
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["watch", "mars"]),
            1
        );
        assert_eq!(
            gauge_value(&m, "sse_connections_active", &["replay", "mars"]),
            1
        );
        drop(g1);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["watch"]), 0);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["replay"]), 1);
        drop(g2);
        assert_eq!(gauge_value(&m, "sse_unique_users_active", &["replay"]), 0);
    }

    /// Encodes the registry the same way the handler does and checks the
    /// text output for the metric name and label.
    #[test]
    fn metrics_handler_returns_prometheus_text() {
        let m = AppMetrics::new();
        m.notifications_total
            .with_label_values(&["mars", "success"])
            .inc();
        let encoder = TextEncoder::new();
        let families = m.registry.gather();
        let mut buf = Vec::new();
        encoder.encode(&families, &mut buf).expect("encode ok");
        let output = String::from_utf8(buf).expect("valid utf8");
        assert!(
            output.contains("aviso_notifications_total"),
            "output should contain metric name"
        );
        assert!(
            output.contains(r#"event_type="mars""#),
            "output should contain label"
        );
    }

    #[test]
    fn register_process_metrics_does_not_panic() {
        let registry = Registry::new();
        register_process_metrics(&registry);
        // The collector is only registered on Linux, so only check there.
        #[cfg(target_os = "linux")]
        {
            let families = registry.gather();
            assert!(
                !families.is_empty(),
                "process metrics should register at least one family"
            );
        }
    }

    #[cfg(feature = "ecpds")]
    #[test]
    fn ecpds_metrics_register_and_publish() {
        let m = AppMetrics::new();
        m.ecpds.cache_hits_total.inc();
        m.ecpds.cache_misses_total.inc();
        m.ecpds.cache_size.set(7);
        m.ecpds
            .access_decisions_total
            .with_label_values(&["allow"])
            .inc();
        m.ecpds
            .access_decisions_total
            .with_label_values(&["deny_destination"])
            .inc();
        let encoder = TextEncoder::new();
        let mut buf = Vec::new();
        encoder
            .encode(&m.registry.gather(), &mut buf)
            .expect("encode ok");
        let output = String::from_utf8(buf).expect("valid utf8");
        assert!(output.contains("aviso_ecpds_cache_hits_total"));
        assert!(output.contains("aviso_ecpds_cache_misses_total"));
        assert!(output.contains("aviso_ecpds_cache_size"));
        assert!(output.contains("aviso_ecpds_access_decisions_total"));
        assert!(output.contains(r#"outcome="allow""#));
        assert!(output.contains(r#"outcome="deny_destination""#));
    }

    /// Runs the same middleware stack as `run_metrics_server` and checks that
    /// the response carries a UUID-shaped `x-request-id` header.
    #[actix_web::test]
    async fn metrics_response_carries_x_request_id_header() {
        use actix_web::test::{TestRequest, call_service, init_service};
        let registry = Registry::new();
        let registry_data = web::Data::new(registry);
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .wrap(TracingLogger::<AvisoRootSpanBuilder>::new())
                .app_data(registry_data)
                .route("/metrics", web::get().to(metrics_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/metrics").to_request()).await;
        assert_eq!(res.status(), actix_web::http::StatusCode::OK);
        let value = res
            .headers()
            .get("x-request-id")
            .expect("metrics response should carry X-Request-ID")
            .to_str()
            .expect("header should be ascii");
        let uuid_re =
            regex::Regex::new(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")
                .expect("valid uuid regex");
        assert!(
            uuid_re.is_match(value),
            "metrics X-Request-ID should be a canonical UUID, got: {value}"
        );
    }
}