use std::sync::Arc;
use std::time::Instant;
use axum::extract::{MatchedPath, Request, State};
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use prometheus_client::encoding::{EncodeLabelSet, text::encode};
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::family::Family;
use prometheus_client::metrics::gauge::Gauge;
use prometheus_client::metrics::histogram::{Histogram, exponential_buckets};
use prometheus_client::registry::Registry;
use crate::state::AppState;
/// Label set for the `mnem_http_requests_total` counter family.
///
/// One counter series exists per distinct `(method, route, status)` combination,
/// so every field should take values from a small, bounded set.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct HttpRequestLabels {
/// HTTP method of the request (e.g. `GET`).
pub method: String,
/// Route label for the request (e.g. `/v1/healthz`); see `track_metrics` for how it is derived.
pub route: String,
/// Numeric response status code rendered as a decimal string (e.g. `200`).
pub status: String,
}
/// Label set for the `mnem_remote_advance_head_total` counter family.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct AdvanceHeadLabels {
/// Outcome of the advance-head call; the registered help text lists
/// `success`, `cas_mismatch` and `auth_fail`.
pub result: String,
}
/// Label set for the `mnem_leiden_mode_total` counter family.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct LeidenModeLabels {
/// Leiden community-cache serve mode; the registered help text lists
/// `full`, `full_debounced` and `fallback_stale`.
pub mode: String,
}
/// Label set for the `mnem_ppr_size_gate_skipped_total` counter family.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct PprSizeGateLabels {
/// Why the PPR request was skipped; the registered help text lists
/// `above_threshold` and `opted_out`.
pub reason: String,
}
/// All Prometheus metrics exported by mnem-http, plus the registry they live in.
///
/// `Metrics::new` registers a clone of each metric handle below in `registry`
/// and keeps the other handle here, so updating a field is what `encode` later
/// reports. The registry itself is behind an `Arc`, so cloning `Metrics` does
/// not duplicate it.
#[derive(Clone)]
pub struct Metrics {
/// Registry that owns the exported copies of every metric below.
registry: Arc<Registry>,
/// `mnem_http_requests_total`, labelled by [`HttpRequestLabels`].
pub http_requests: Family<HttpRequestLabels, Counter>,
/// `mnem_http_request_duration_seconds`.
pub http_duration: Histogram,
/// `mnem_retrieve_latency_seconds`.
pub retrieve_latency: Histogram,
/// `mnem_commit_duration_seconds`.
pub commit_duration: Histogram,
/// `mnem_ingest_duration_seconds`.
pub ingest_duration: Histogram,
/// `mnem_ingest_chunks_total`.
pub ingest_chunks: Counter,
/// `mnem_remote_fetch_blocks_total`.
pub remote_fetch_blocks: Counter,
/// `mnem_remote_push_blocks_total`.
pub remote_push_blocks: Counter,
/// `mnem_remote_advance_head_total`, labelled by [`AdvanceHeadLabels`].
pub remote_advance_head: Family<AdvanceHeadLabels, Counter>,
/// `mnem_leiden_mode_total`, labelled by [`LeidenModeLabels`].
pub leiden_mode: Family<LeidenModeLabels, Counter>,
/// `mnem_leiden_debounce_effective` (milliseconds, per its help text).
pub leiden_debounce_effective: Gauge,
/// `mnem_leiden_storm_cap_effective`.
pub leiden_storm_cap_effective: Gauge,
/// `mnem_leiden_delta_ratio_effective` (parts-per-ten-thousand, per its help text).
pub leiden_delta_ratio_effective: Gauge,
/// `mnem_leiden_mode_current` (0=full, 1=full_debounced, 2=fallback_stale, per its help text).
pub leiden_mode_current: Gauge,
/// `mnem_traverse_answer_hard_wall_ms_effective`.
pub traverse_answer_hard_wall_ms_effective: Gauge,
/// `mnem_traverse_answer_max_hops_effective`.
pub traverse_answer_max_hops_effective: Gauge,
/// `mnem_traverse_answer_hard_wall_exceeded_total`.
pub traverse_answer_hard_wall_exceeded: Counter,
/// `mnem_ppr_size_gate_skipped_total`, labelled by [`PprSizeGateLabels`].
pub ppr_size_gate_skipped: Family<PprSizeGateLabels, Counter>,
/// `mnem_ppr_size_gate_threshold`; initialised in `Metrics::new`.
pub ppr_size_gate_threshold: Gauge,
}
impl Metrics {
    /// Builds a fresh [`Registry`] and registers every mnem-http metric in it.
    ///
    /// Each metric is created once. A clone of its handle goes into the registry,
    /// and the other handle is stored in the returned struct, so updates through
    /// the struct fields show up in [`Metrics::encode`].
    ///
    /// `mnem_ppr_size_gate_threshold` is initialised here from
    /// `mnem_core::ppr::PPR_DEFAULT_MAX_NODES`. Every other gauge starts at its default.
    // NOTE(review): prometheus_client's text encoder appends `_total` to counter
    // sample lines. Confirm whether the counters registered below (whose names
    // already end in `_total`) render as `*_total_total` samples before writing
    // PromQL against them.
    #[must_use]
    pub fn new() -> Self {
        let mut registry = Registry::default();

        let http_requests = Family::<HttpRequestLabels, Counter>::default();
        registry.register(
            "mnem_http_requests_total",
            "Total HTTP requests handled by mnem-http, bucketed by method, route, and status.",
            http_requests.clone(),
        );
        // 14 buckets doubling from 1 ms: upper bounds 0.001 s .. 8.192 s.
        let http_duration = Histogram::new(exponential_buckets(0.001, 2.0, 14));
        registry.register(
            "mnem_http_request_duration_seconds",
            // Timing stops when the inner service returns its Response, so body
            // streaming is NOT included (the old text claimed "body sent").
            "HTTP request duration in seconds, measured in the HTTP middleware around the inner service call (response body streaming not included).",
            http_duration.clone(),
        );
        // 17 buckets doubling from 0.1 ms: upper bounds 0.0001 s .. ~6.55 s.
        let retrieve_latency = Histogram::new(exponential_buckets(0.0001, 2.0, 17));
        registry.register(
            "mnem_retrieve_latency_seconds",
            "Retrieval pipeline latency in seconds, measured around the `Retriever::execute` call.",
            retrieve_latency.clone(),
        );
        let commit_duration = Histogram::new(exponential_buckets(0.0001, 2.0, 17));
        registry.register(
            "mnem_commit_duration_seconds",
            "Transaction commit duration in seconds, measured around Transaction::commit_opts.",
            commit_duration.clone(),
        );
        let ingest_duration = Histogram::new(exponential_buckets(0.001, 2.0, 14));
        registry.register(
            "mnem_ingest_duration_seconds",
            "End-to-end ingest duration in seconds, measured around the full POST /v1/ingest run.",
            ingest_duration.clone(),
        );
        let ingest_chunks = Counter::default();
        registry.register(
            "mnem_ingest_chunks_total",
            "Total chunks produced across every successful POST /v1/ingest call.",
            ingest_chunks.clone(),
        );

        let remote_fetch_blocks = Counter::default();
        registry.register(
            "mnem_remote_fetch_blocks_total",
            "Total `/remote/v1/fetch-blocks` invocations that produced a CAR response.",
            remote_fetch_blocks.clone(),
        );
        let remote_push_blocks = Counter::default();
        registry.register(
            "mnem_remote_push_blocks_total",
            "Total `/remote/v1/push-blocks` invocations that completed an import.",
            remote_push_blocks.clone(),
        );
        let remote_advance_head = Family::<AdvanceHeadLabels, Counter>::default();
        registry.register(
            "mnem_remote_advance_head_total",
            "Total `/remote/v1/advance-head` invocations bucketed by result (success, cas_mismatch, auth_fail).",
            remote_advance_head.clone(),
        );

        let leiden_mode = Family::<LeidenModeLabels, Counter>::default();
        registry.register(
            "mnem_leiden_mode_total",
            "Total Leiden community-cache serves bucketed by mode (full, full_debounced, fallback_stale).",
            leiden_mode.clone(),
        );
        let leiden_debounce_effective = Gauge::default();
        registry.register(
            "mnem_leiden_debounce_effective",
            "Effective Leiden debounce window in ms (max(1000, rolling p75 commit latency)).",
            leiden_debounce_effective.clone(),
        );
        let leiden_storm_cap_effective = Gauge::default();
        registry.register(
            "mnem_leiden_storm_cap_effective",
            "Effective commit-storm cap per minute (floor-c tunable; default 60).",
            leiden_storm_cap_effective.clone(),
        );
        let leiden_delta_ratio_effective = Gauge::default();
        registry.register(
            "mnem_leiden_delta_ratio_effective",
            "Effective delta_ratio_force_full rendered as parts-per-ten-thousand.",
            leiden_delta_ratio_effective.clone(),
        );
        let leiden_mode_current = Gauge::default();
        registry.register(
            "mnem_leiden_mode_current",
            "Current Leiden mode: 0=full, 1=full_debounced, 2=fallback_stale.",
            leiden_mode_current.clone(),
        );

        let traverse_answer_hard_wall_ms_effective = Gauge::default();
        registry.register(
            "mnem_traverse_answer_hard_wall_ms_effective",
            "Effective hard-wall latency budget for /v1/traverse_answer in ms.",
            traverse_answer_hard_wall_ms_effective.clone(),
        );
        let traverse_answer_max_hops_effective = Gauge::default();
        registry.register(
            "mnem_traverse_answer_max_hops_effective",
            "Effective max-hops for /v1/traverse_answer.",
            traverse_answer_max_hops_effective.clone(),
        );
        let traverse_answer_hard_wall_exceeded = Counter::default();
        registry.register(
            "mnem_traverse_answer_hard_wall_exceeded_total",
            "Total /v1/traverse_answer requests that breached the hard-wall budget.",
            traverse_answer_hard_wall_exceeded.clone(),
        );

        let ppr_size_gate_skipped = Family::<PprSizeGateLabels, Counter>::default();
        registry.register(
            "mnem_ppr_size_gate_skipped_total",
            "Total PPR requests skipped by the default-on size gate, bucketed by reason (above_threshold, opted_out).",
            ppr_size_gate_skipped.clone(),
        );
        let ppr_size_gate_threshold = Gauge::default();
        registry.register(
            "mnem_ppr_size_gate_threshold",
            "Effective PPR size-gate node threshold (mirrors PPR_DEFAULT_MAX_NODES).",
            ppr_size_gate_threshold.clone(),
        );
        // The gauge is i64-valued. Saturate instead of wrapping (the old `as` cast)
        // so an oversized cap can never be exported as a negative threshold.
        let ppr_threshold =
            i64::try_from(mnem_core::ppr::PPR_DEFAULT_MAX_NODES).unwrap_or(i64::MAX);
        ppr_size_gate_threshold.set(ppr_threshold);

        Self {
            registry: Arc::new(registry),
            http_requests,
            http_duration,
            retrieve_latency,
            commit_duration,
            ingest_duration,
            ingest_chunks,
            remote_fetch_blocks,
            remote_push_blocks,
            remote_advance_head,
            leiden_mode,
            leiden_debounce_effective,
            leiden_storm_cap_effective,
            leiden_delta_ratio_effective,
            leiden_mode_current,
            traverse_answer_hard_wall_ms_effective,
            traverse_answer_max_hops_effective,
            traverse_answer_hard_wall_exceeded,
            ppr_size_gate_skipped,
            ppr_size_gate_threshold,
        }
    }

    /// Renders every registered metric with `prometheus_client`'s text encoder.
    ///
    /// # Errors
    ///
    /// Returns [`std::fmt::Error`] if the encoder fails to write into the output buffer.
    pub fn encode(&self) -> Result<String, std::fmt::Error> {
        let mut buf = String::new();
        encode(&mut buf, &self.registry)?;
        Ok(buf)
    }
}
impl Default for Metrics {
fn default() -> Self {
Self::new()
}
}
/// `route` label recorded for requests that matched no registered route.
///
/// Unmatched requests are collapsed into this one value instead of being labelled by
/// their raw URI path. The raw path is client-controlled, so labelling by it would let
/// arbitrary URLs (scanners, typos, probes) create an unbounded number of counter
/// series. Registered routes always start with `/`, so this value cannot collide with one.
const UNMATCHED_ROUTE_LABEL: &str = "unmatched";

/// `method` label recorded for HTTP methods outside the standard set.
const OTHER_METHOD_LABEL: &str = "OTHER";

/// Maps an HTTP method name to a bounded label value.
///
/// The standard methods keep their own name. Any extension method (which
/// `http::Method` can represent) is folded into [`OTHER_METHOD_LABEL`], so
/// client-supplied method tokens cannot grow the label space without bound.
fn method_label(method: &str) -> &'static str {
    match method {
        "GET" => "GET",
        "HEAD" => "HEAD",
        "POST" => "POST",
        "PUT" => "PUT",
        "DELETE" => "DELETE",
        "CONNECT" => "CONNECT",
        "OPTIONS" => "OPTIONS",
        "TRACE" => "TRACE",
        "PATCH" => "PATCH",
        _ => OTHER_METHOD_LABEL,
    }
}

/// Axum middleware that counts every request and records its latency.
///
/// For each request:
/// - it increments `http_requests` with the method, route and response status;
/// - it observes the time spent in `next.run(..)` into `http_duration`.
///
/// Requests whose matched route, or raw path when unmatched, is `/metrics` are
/// passed through without being recorded.
///
/// Label values stay bounded:
/// - the route is the [`MatchedPath`] template, or [`UNMATCHED_ROUTE_LABEL`] when
///   no route matched;
/// - the method is normalised by [`method_label`].
pub(crate) async fn track_metrics(
    State(state): State<AppState>,
    req: Request,
    next: Next,
) -> Response {
    let matched_route = req
        .extensions()
        .get::<MatchedPath>()
        .map(|m| m.as_str().to_owned());

    // Skip decision is unchanged: compare the matched template, or the raw path if
    // nothing matched. The raw path is used only for this check, never as a label.
    let is_metrics_endpoint =
        matched_route.as_deref().unwrap_or(req.uri().path()) == "/metrics";
    if is_metrics_endpoint {
        return next.run(req).await;
    }

    let method = method_label(req.method().as_str()).to_owned();
    let route = matched_route.unwrap_or_else(|| UNMATCHED_ROUTE_LABEL.to_owned());

    let start = Instant::now();
    let response = next.run(req).await;
    let elapsed = start.elapsed().as_secs_f64();
    let status = response.status().as_u16().to_string();

    state
        .metrics
        .http_requests
        .get_or_create(&HttpRequestLabels {
            method,
            route,
            status,
        })
        .inc();
    state.metrics.http_duration.observe(elapsed);
    response
}
/// Content type matching the payload of `prometheus_client::encoding::text::encode`.
///
/// That function writes the OpenMetrics text format, which ends with `# EOF`.
/// The Prometheus 0.0.4 text content type was previously declared instead, so
/// scrapers were told to parse the body with the wrong format.
const OPENMETRICS_CONTENT_TYPE: &str =
    "application/openmetrics-text; version=1.0.0; charset=utf-8";

/// Handler for the metrics scrape endpoint.
///
/// Responds `200 OK` with the encoded registry and the OpenMetrics content type.
/// If encoding fails, responds `500 Internal Server Error` with a plain-text
/// description of the failure.
pub(crate) async fn metrics_handler(State(state): State<AppState>) -> Response {
    match state.metrics.encode() {
        Ok(body) => (
            StatusCode::OK,
            [(axum::http::header::CONTENT_TYPE, OPENMETRICS_CONTENT_TYPE)],
            body,
        )
            .into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("metrics encoding failure: {e}"),
        )
            .into_response(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every name in `families` appears in the encoded exposition `text`.
    fn assert_families_present(text: &str, families: &[&str]) {
        for family in families {
            assert!(
                text.contains(*family),
                "expected metric family `{family}` in output:\n{text}"
            );
        }
    }

    #[test]
    fn metrics_encode_is_well_formed() {
        let m = Metrics::new();
        m.http_requests
            .get_or_create(&HttpRequestLabels {
                method: "GET".into(),
                route: "/v1/healthz".into(),
                status: "200".into(),
            })
            .inc();
        m.http_duration.observe(0.002);
        m.retrieve_latency.observe(0.015);
        m.commit_duration.observe(0.050);
        let text = m.encode().expect("encode");
        assert!(
            text.contains("# HELP mnem_http_requests_total"),
            "missing HELP for mnem_http_requests_total in:\n{text}"
        );
        assert!(
            text.contains("# TYPE mnem_http_requests_total counter"),
            "missing TYPE for mnem_http_requests_total"
        );
        assert!(
            text.contains("# HELP mnem_http_request_duration_seconds"),
            "missing HELP for mnem_http_request_duration_seconds"
        );
        assert!(
            text.contains("# HELP mnem_retrieve_latency_seconds"),
            "missing HELP for mnem_retrieve_latency_seconds"
        );
        assert!(
            text.contains("# HELP mnem_commit_duration_seconds"),
            "missing HELP for mnem_commit_duration_seconds"
        );
        assert!(
            text.contains("method=\"GET\""),
            "counter label `method=GET` missing in:\n{text}"
        );
        assert!(
            text.contains("route=\"/v1/healthz\""),
            "counter label `route=/v1/healthz` missing"
        );
        assert!(
            text.contains("status=\"200\""),
            "counter label `status=200` missing"
        );
    }

    /// Covers the HTTP, retrieve, commit and ingest families. The previous name
    /// ("all_four_families") no longer matched the six families checked here.
    #[test]
    fn metrics_new_registers_core_families() {
        let text = Metrics::new().encode().unwrap();
        assert_families_present(
            &text,
            &[
                "mnem_http_requests_total",
                "mnem_http_request_duration_seconds",
                "mnem_retrieve_latency_seconds",
                "mnem_commit_duration_seconds",
                "mnem_ingest_duration_seconds",
                "mnem_ingest_chunks_total",
            ],
        );
    }

    #[test]
    fn metrics_new_registers_all_remote_families() {
        let text = Metrics::new().encode().unwrap();
        assert_families_present(
            &text,
            &[
                "mnem_remote_fetch_blocks_total",
                "mnem_remote_push_blocks_total",
                "mnem_remote_advance_head_total",
            ],
        );
    }

    #[test]
    fn metrics_new_registers_leiden_traverse_and_ppr_families() {
        let text = Metrics::new().encode().unwrap();
        assert_families_present(
            &text,
            &[
                "mnem_leiden_mode_total",
                "mnem_leiden_debounce_effective",
                "mnem_leiden_storm_cap_effective",
                "mnem_leiden_delta_ratio_effective",
                "mnem_leiden_mode_current",
                "mnem_traverse_answer_hard_wall_ms_effective",
                "mnem_traverse_answer_max_hops_effective",
                "mnem_traverse_answer_hard_wall_exceeded_total",
                "mnem_ppr_size_gate_skipped_total",
                "mnem_ppr_size_gate_threshold",
            ],
        );
    }

    #[test]
    fn remote_counters_increment_and_render() {
        let m = Metrics::new();
        m.remote_fetch_blocks.inc();
        m.remote_push_blocks.inc();
        for r in ["success", "cas_mismatch", "auth_fail"] {
            m.remote_advance_head
                .get_or_create(&AdvanceHeadLabels { result: r.into() })
                .inc();
        }
        let text = m.encode().unwrap();
        assert!(text.contains("mnem_remote_fetch_blocks_total"));
        assert!(text.contains("mnem_remote_push_blocks_total"));
        for r in ["success", "cas_mismatch", "auth_fail"] {
            assert!(
                text.contains(&format!("result=\"{r}\"")),
                "missing advance-head result `{r}` in:\n{text}"
            );
        }
    }

    #[test]
    fn leiden_and_ppr_labels_render() {
        let m = Metrics::new();
        m.leiden_mode
            .get_or_create(&LeidenModeLabels {
                mode: "full_debounced".into(),
            })
            .inc();
        m.ppr_size_gate_skipped
            .get_or_create(&PprSizeGateLabels {
                reason: "above_threshold".into(),
            })
            .inc();
        let text = m.encode().unwrap();
        assert!(
            text.contains("mode=\"full_debounced\""),
            "missing leiden mode label in:\n{text}"
        );
        assert!(
            text.contains("reason=\"above_threshold\""),
            "missing ppr size-gate reason label in:\n{text}"
        );
    }
}