1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, BTreeSet},
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex},
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use axum::{Router, routing::get};
12use http::{Method, Request, Response, StatusCode};
13use tower::{Layer, Service};
14
15pub fn metrics_layer<H>(hook: H) -> MetricsLayer<H>
20where
21 H: HttpMetricsHook,
22{
23 MetricsLayer::new(hook)
24}
25
26pub fn route_metrics_layer<H>(route: impl Into<Cow<'static, str>>, hook: H) -> MetricsLayer<H>
31where
32 H: HttpMetricsHook,
33{
34 MetricsLayer::new(hook).route(route)
35}
36
37pub trait HttpMetricsHook: Clone + Send + Sync + 'static {
44 fn on_request(&self, method: &Method, route: Option<&str>);
46
47 fn on_response(
49 &self,
50 method: &Method,
51 route: Option<&str>,
52 status: StatusCode,
53 latency: Duration,
54 );
55
56 fn on_error(&self, _method: &Method, _route: Option<&str>, _latency: Duration) {}
58}
59
60#[derive(Clone, Debug)]
91pub struct PrometheusMetrics {
92 state: Arc<Mutex<PrometheusState>>,
93 excluded_routes: Arc<BTreeSet<String>>,
94 max_series: Option<usize>,
95}
96
97impl PrometheusMetrics {
98 pub fn new() -> Self {
103 Self {
104 state: Arc::new(Mutex::new(PrometheusState::default())),
105 excluded_routes: Arc::new(BTreeSet::from([
106 "/health/live".to_owned(),
107 "/health/ready".to_owned(),
108 "/metrics".to_owned(),
109 ])),
110 max_series: None,
111 }
112 }
113
114 pub fn exclude_route(mut self, route: impl Into<String>) -> Self {
119 Arc::make_mut(&mut self.excluded_routes).insert(route.into());
120 self
121 }
122
123 pub fn with_max_series(mut self, max_series: usize) -> Self {
132 self.max_series = Some(max_series);
133 self
134 }
135
136 pub fn layer(&self) -> MetricsLayer<Self> {
142 MetricsLayer::new(self.clone())
143 }
144
145 pub fn routes(&self) -> Router {
147 self.routes_at("/metrics")
148 }
149
150 pub fn routes_at(&self, path: &'static str) -> Router {
152 let metrics = self.clone();
153 Router::new().route(path, get(move || async move { metrics.render() }))
154 }
155
156 pub fn render(&self) -> String {
163 let state = self.snapshot();
164 render_prometheus(&state)
165 }
166
167 fn snapshot(&self) -> PrometheusState {
168 self.state
169 .lock()
170 .unwrap_or_else(|poisoned| poisoned.into_inner())
171 .clone()
172 }
173
174 fn should_record(&self, route: Option<&str>) -> bool {
175 route
176 .map(|route| !self.excluded_routes.contains(route))
177 .unwrap_or(true)
178 }
179}
180
181fn render_prometheus(state: &PrometheusState) -> String {
182 let mut output = String::new();
183 output.push_str("# TYPE nidus_http_requests_total counter\n");
184 for ((method, route, status), count) in &state.requests_total {
185 output.push_str(&format!(
186 "nidus_http_requests_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
187 escape_label(method),
188 escape_label(route),
189 status,
190 count
191 ));
192 }
193 output.push_str("# TYPE nidus_http_request_duration_seconds histogram\n");
194 for ((method, route, status), histogram) in &state.durations {
195 for (bucket, count) in HTTP_DURATION_BUCKETS
196 .iter()
197 .zip(histogram.bucket_counts.iter())
198 {
199 output.push_str(&format!(
200 "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"{}\"}} {}\n",
201 escape_label(method),
202 escape_label(route),
203 status,
204 format_bucket(*bucket),
205 count
206 ));
207 }
208 output.push_str(&format!(
209 "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"+Inf\"}} {}\n",
210 escape_label(method),
211 escape_label(route),
212 status,
213 histogram.count
214 ));
215 output.push_str(&format!(
216 "nidus_http_request_duration_seconds_count{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
217 escape_label(method),
218 escape_label(route),
219 status,
220 histogram.count
221 ));
222 output.push_str(&format!(
223 "nidus_http_request_duration_seconds_sum{{method=\"{}\",route=\"{}\",status=\"{}\"}} {:.6}\n",
224 escape_label(method),
225 escape_label(route),
226 status,
227 histogram.sum
228 ));
229 }
230 output.push_str("# TYPE nidus_http_in_flight_requests gauge\n");
231 for ((method, route), count) in &state.in_flight {
232 output.push_str(&format!(
233 "nidus_http_in_flight_requests{{method=\"{}\",route=\"{}\"}} {}\n",
234 escape_label(method),
235 escape_label(route),
236 count
237 ));
238 }
239 output.push_str("# TYPE nidus_http_errors_total counter\n");
240 for ((method, route, status), count) in &state.errors_total {
241 output.push_str(&format!(
242 "nidus_http_errors_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
243 escape_label(method),
244 escape_label(route),
245 status,
246 count
247 ));
248 }
249 output
250}
251
252impl Default for PrometheusMetrics {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258impl HttpMetricsHook for PrometheusMetrics {
259 fn on_request(&self, method: &Method, route: Option<&str>) {
260 if !self.should_record(route) {
261 return;
262 }
263 let route = route.unwrap_or("<unknown>").to_owned();
264 let mut state = self
265 .state
266 .lock()
267 .unwrap_or_else(|poisoned| poisoned.into_inner());
268 let route = match self.max_series {
269 Some(max) => state.admit_route(route, max),
270 None => route,
271 };
272 *state
273 .in_flight
274 .entry((method.as_str().to_owned(), route))
275 .or_default() += 1;
276 }
277
278 fn on_response(
279 &self,
280 method: &Method,
281 route: Option<&str>,
282 status: StatusCode,
283 latency: Duration,
284 ) {
285 if !self.should_record(route) {
286 return;
287 }
288 let method = method.as_str().to_owned();
289 let route = route.unwrap_or("<unknown>").to_owned();
290 let status = status.as_u16();
291 let mut state = self
292 .state
293 .lock()
294 .unwrap_or_else(|poisoned| poisoned.into_inner());
295 let route = match self.max_series {
296 Some(max) => state.admit_route(route, max),
297 None => route,
298 };
299 *state
300 .requests_total
301 .entry((method.clone(), route.clone(), status))
302 .or_default() += 1;
303 state
304 .durations
305 .entry((method.clone(), route.clone(), status))
306 .or_default()
307 .observe(latency);
308 if StatusCode::from_u16(status)
309 .is_ok_and(|status| status.is_client_error() || status.is_server_error())
310 {
311 *state
312 .errors_total
313 .entry((method.clone(), route.clone(), status))
314 .or_default() += 1;
315 }
316 let key = (method, route);
317 if let Some(count) = state.in_flight.get_mut(&key) {
318 *count = count.saturating_sub(1);
319 }
320 }
321
322 fn on_error(&self, method: &Method, route: Option<&str>, latency: Duration) {
323 if !self.should_record(route) {
324 return;
325 }
326 let method = method.as_str().to_owned();
327 let route = route.unwrap_or("<unknown>").to_owned();
328 let mut state = self
329 .state
330 .lock()
331 .unwrap_or_else(|poisoned| poisoned.into_inner());
332 let route = match self.max_series {
333 Some(max) => state.admit_route(route, max),
334 None => route,
335 };
336 let status = StatusCode::INTERNAL_SERVER_ERROR.as_u16();
337 *state
338 .requests_total
339 .entry((method.clone(), route.clone(), status))
340 .or_default() += 1;
341 state
342 .durations
343 .entry((method.clone(), route.clone(), status))
344 .or_default()
345 .observe(latency);
346 *state
347 .errors_total
348 .entry((method.clone(), route.clone(), status))
349 .or_default() += 1;
350 let key = (method, route);
351 if let Some(count) = state.in_flight.get_mut(&key) {
352 *count = count.saturating_sub(1);
353 }
354 }
355}
356
357#[derive(Clone, Debug, Default)]
358struct PrometheusState {
359 requests_total: BTreeMap<(String, String, u16), u64>,
360 durations: BTreeMap<(String, String, u16), DurationHistogram>,
361 in_flight: BTreeMap<(String, String), u64>,
362 errors_total: BTreeMap<(String, String, u16), u64>,
363 known_routes: BTreeSet<String>,
364}
365
366impl PrometheusState {
367 fn admit_route(&mut self, route: String, max_series: usize) -> String {
372 if self.known_routes.contains(&route) {
373 route
374 } else if self.known_routes.len() < max_series {
375 self.known_routes.insert(route.clone());
376 route
377 } else {
378 "<overflow>".to_owned()
379 }
380 }
381}
382
/// Upper bounds of the latency histogram buckets, in seconds, ascending.
const HTTP_DURATION_BUCKETS: [f64; 11] = [
    0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.000, 2.500, 5.000, 10.000,
];

/// Cumulative latency histogram over [`HTTP_DURATION_BUCKETS`].
#[derive(Debug, Clone, Default)]
struct DurationHistogram {
    /// Total number of observations.
    count: u64,
    /// Sum of all observed latencies, in seconds.
    sum: f64,
    /// Per-bucket counts; slot `i` counts every observation that is less than
    /// or equal to `HTTP_DURATION_BUCKETS[i]`.
    bucket_counts: [u64; HTTP_DURATION_BUCKETS.len()],
}

impl DurationHistogram {
    /// Records one latency observation.
    ///
    /// Buckets are cumulative: every bucket whose upper bound is at least the
    /// observed value is incremented, not just the first one that fits.
    fn observe(&mut self, latency: Duration) {
        let observed = latency.as_secs_f64();
        self.count += 1;
        self.sum += observed;

        let slots = self.bucket_counts.iter_mut();
        for (slot, upper_bound) in slots.zip(HTTP_DURATION_BUCKETS.iter()) {
            *slot += u64::from(observed <= *upper_bound);
        }
    }
}
409
/// Tower layer that wraps a service in [`MetricsService`].
#[derive(Debug, Clone)]
pub struct MetricsLayer<H> {
    /// Hook that receives the request lifecycle callbacks.
    hook: H,
    /// Fixed route label; when `None`, the wrapped service looks for axum's
    /// `MatchedPath` on each request instead.
    route: Option<Cow<'static, str>>,
}
419
420impl<H> MetricsLayer<H>
421where
422 H: HttpMetricsHook,
423{
424 pub fn new(hook: H) -> Self {
426 Self { hook, route: None }
427 }
428
429 pub fn route(mut self, route: impl Into<Cow<'static, str>>) -> Self {
434 self.route = Some(route.into());
435 self
436 }
437}
438
439impl<S, H> Layer<S> for MetricsLayer<H>
440where
441 H: HttpMetricsHook,
442{
443 type Service = MetricsService<S, H>;
444
445 fn layer(&self, inner: S) -> Self::Service {
446 MetricsService {
447 inner,
448 hook: self.hook.clone(),
449 route: self.route.clone(),
450 }
451 }
452}
453
/// Service wrapper that reports each request to an [`HttpMetricsHook`].
#[derive(Debug, Clone)]
pub struct MetricsService<S, H> {
    /// The wrapped service.
    inner: S,
    /// Hook that receives the request lifecycle callbacks.
    hook: H,
    /// Fixed route label, if one was configured on the layer.
    route: Option<Cow<'static, str>>,
}
461
462impl<S, H, RequestBody, ResponseBody> Service<Request<RequestBody>> for MetricsService<S, H>
463where
464 S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
465 S::Future: Send + 'static,
466 S::Error: Send + 'static,
467 H: HttpMetricsHook,
468 RequestBody: Send + 'static,
469 ResponseBody: Send + 'static,
470{
471 type Response = Response<ResponseBody>;
472 type Error = S::Error;
473 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
474
475 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
476 self.inner.poll_ready(cx)
477 }
478
479 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
480 let method = request.method().clone();
481 let hook = self.hook.clone();
482 let route = self.route.clone().or_else(|| {
483 request
484 .extensions()
485 .get::<axum::extract::MatchedPath>()
486 .map(|path| Cow::Owned(path.as_str().to_owned()))
487 });
488 hook.on_request(&method, route.as_deref());
489 let started_at = Instant::now();
490 let future = self.inner.call(request);
491
492 Box::pin(async move {
493 match future.await {
494 Ok(response) => {
495 hook.on_response(
496 &method,
497 route.as_deref(),
498 response.status(),
499 started_at.elapsed(),
500 );
501 Ok(response)
502 }
503 Err(error) => {
504 hook.on_error(&method, route.as_deref(), started_at.elapsed());
505 Err(error)
506 }
507 }
508 })
509 }
510}
511
/// Escapes a label value for the Prometheus text format.
///
/// Backslash, line feed and double quote are replaced by `\\`, `\n` and `\"`
/// respectively; every other character is copied through unchanged.
fn escape_label(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str(r"\\"),
            '\n' => escaped.push_str(r"\n"),
            '"' => escaped.push_str(r#"\""#),
            other => escaped.push(other),
        }
    }
    escaped
}
518
/// Formats a histogram bucket bound for use as an `le` label value.
///
/// Uses `f64`'s `Display` implementation, which prints the shortest decimal
/// that round-trips and never uses exponent notation: `1.0` becomes `"1"`,
/// `0.010` becomes `"0.01"`. Unlike fixed three-decimal formatting followed
/// by zero trimming, this stays correct for bounds finer than one
/// millisecond (`0.0001` renders as `"0.0001"`, not `"0."`).
fn format_bucket(bucket: f64) -> String {
    bucket.to_string()
}