1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, BTreeSet},
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex},
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use axum::{Router, routing::get};
12use http::{Method, Request, Response, StatusCode};
13use tower::{Layer, Service};
14
15pub fn metrics_layer<H>(hook: H) -> MetricsLayer<H>
20where
21 H: HttpMetricsHook,
22{
23 MetricsLayer::new(hook)
24}
25
26pub fn route_metrics_layer<H>(route: impl Into<Cow<'static, str>>, hook: H) -> MetricsLayer<H>
31where
32 H: HttpMetricsHook,
33{
34 MetricsLayer::new(hook).route(route)
35}
36
37pub trait HttpMetricsHook: Clone + Send + Sync + 'static {
44 fn on_request(&self, method: &Method, route: Option<&str>);
46
47 fn on_response(
49 &self,
50 method: &Method,
51 route: Option<&str>,
52 status: StatusCode,
53 latency: Duration,
54 );
55
56 fn on_error(&self, _method: &Method, _route: Option<&str>, _latency: Duration) {}
58}
59
60#[derive(Clone, Debug)]
89pub struct PrometheusMetrics {
90 state: Arc<Mutex<PrometheusState>>,
91 excluded_routes: Arc<BTreeSet<String>>,
92 max_series: Option<usize>,
93}
94
95impl PrometheusMetrics {
96 pub fn new() -> Self {
101 Self {
102 state: Arc::new(Mutex::new(PrometheusState::default())),
103 excluded_routes: Arc::new(BTreeSet::from([
104 "/health/live".to_owned(),
105 "/health/ready".to_owned(),
106 "/metrics".to_owned(),
107 ])),
108 max_series: None,
109 }
110 }
111
112 pub fn exclude_route(mut self, route: impl Into<String>) -> Self {
117 Arc::make_mut(&mut self.excluded_routes).insert(route.into());
118 self
119 }
120
121 pub fn with_max_series(mut self, max_series: usize) -> Self {
130 self.max_series = Some(max_series);
131 self
132 }
133
134 pub fn layer(&self) -> MetricsLayer<Self> {
140 MetricsLayer::new(self.clone())
141 }
142
143 pub fn routes(&self) -> Router {
145 self.routes_at("/metrics")
146 }
147
148 pub fn routes_at(&self, path: &'static str) -> Router {
150 let metrics = self.clone();
151 Router::new().route(path, get(move || async move { metrics.render() }))
152 }
153
154 pub fn render(&self) -> String {
161 let state = self.snapshot();
162 render_prometheus(&state)
163 }
164
165 fn snapshot(&self) -> PrometheusState {
166 self.state
167 .lock()
168 .unwrap_or_else(|poisoned| poisoned.into_inner())
169 .clone()
170 }
171
172 fn should_record(&self, route: Option<&str>) -> bool {
173 route
174 .map(|route| !self.excluded_routes.contains(route))
175 .unwrap_or(true)
176 }
177}
178
179fn render_prometheus(state: &PrometheusState) -> String {
180 let mut output = String::new();
181 output.push_str("# TYPE nidus_http_requests_total counter\n");
182 for ((method, route, status), count) in &state.requests_total {
183 output.push_str(&format!(
184 "nidus_http_requests_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
185 escape_label(method),
186 escape_label(route),
187 status,
188 count
189 ));
190 }
191 output.push_str("# TYPE nidus_http_request_duration_seconds histogram\n");
192 for ((method, route, status), histogram) in &state.durations {
193 for (bucket, count) in HTTP_DURATION_BUCKETS
194 .iter()
195 .zip(histogram.bucket_counts.iter())
196 {
197 output.push_str(&format!(
198 "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"{}\"}} {}\n",
199 escape_label(method),
200 escape_label(route),
201 status,
202 format_bucket(*bucket),
203 count
204 ));
205 }
206 output.push_str(&format!(
207 "nidus_http_request_duration_seconds_bucket{{method=\"{}\",route=\"{}\",status=\"{}\",le=\"+Inf\"}} {}\n",
208 escape_label(method),
209 escape_label(route),
210 status,
211 histogram.count
212 ));
213 output.push_str(&format!(
214 "nidus_http_request_duration_seconds_count{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
215 escape_label(method),
216 escape_label(route),
217 status,
218 histogram.count
219 ));
220 output.push_str(&format!(
221 "nidus_http_request_duration_seconds_sum{{method=\"{}\",route=\"{}\",status=\"{}\"}} {:.6}\n",
222 escape_label(method),
223 escape_label(route),
224 status,
225 histogram.sum
226 ));
227 }
228 output.push_str("# TYPE nidus_http_in_flight_requests gauge\n");
229 for ((method, route), count) in &state.in_flight {
230 output.push_str(&format!(
231 "nidus_http_in_flight_requests{{method=\"{}\",route=\"{}\"}} {}\n",
232 escape_label(method),
233 escape_label(route),
234 count
235 ));
236 }
237 output.push_str("# TYPE nidus_http_errors_total counter\n");
238 for ((method, route, status), count) in &state.errors_total {
239 output.push_str(&format!(
240 "nidus_http_errors_total{{method=\"{}\",route=\"{}\",status=\"{}\"}} {}\n",
241 escape_label(method),
242 escape_label(route),
243 status,
244 count
245 ));
246 }
247 output
248}
249
250impl Default for PrometheusMetrics {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
256impl HttpMetricsHook for PrometheusMetrics {
257 fn on_request(&self, method: &Method, route: Option<&str>) {
258 if !self.should_record(route) {
259 return;
260 }
261 let route = route.unwrap_or("<unknown>").to_owned();
262 let mut state = self
263 .state
264 .lock()
265 .unwrap_or_else(|poisoned| poisoned.into_inner());
266 let route = match self.max_series {
267 Some(max) => state.admit_route(route, max),
268 None => route,
269 };
270 *state
271 .in_flight
272 .entry((method.as_str().to_owned(), route))
273 .or_default() += 1;
274 }
275
276 fn on_response(
277 &self,
278 method: &Method,
279 route: Option<&str>,
280 status: StatusCode,
281 latency: Duration,
282 ) {
283 if !self.should_record(route) {
284 return;
285 }
286 let method = method.as_str().to_owned();
287 let route = route.unwrap_or("<unknown>").to_owned();
288 let status = status.as_u16();
289 let mut state = self
290 .state
291 .lock()
292 .unwrap_or_else(|poisoned| poisoned.into_inner());
293 let route = match self.max_series {
294 Some(max) => state.admit_route(route, max),
295 None => route,
296 };
297 *state
298 .requests_total
299 .entry((method.clone(), route.clone(), status))
300 .or_default() += 1;
301 state
302 .durations
303 .entry((method.clone(), route.clone(), status))
304 .or_default()
305 .observe(latency);
306 if StatusCode::from_u16(status)
307 .is_ok_and(|status| status.is_client_error() || status.is_server_error())
308 {
309 *state
310 .errors_total
311 .entry((method.clone(), route.clone(), status))
312 .or_default() += 1;
313 }
314 let key = (method, route);
315 if let Some(count) = state.in_flight.get_mut(&key) {
316 *count = count.saturating_sub(1);
317 }
318 }
319
320 fn on_error(&self, method: &Method, route: Option<&str>, latency: Duration) {
321 if !self.should_record(route) {
322 return;
323 }
324 let method = method.as_str().to_owned();
325 let route = route.unwrap_or("<unknown>").to_owned();
326 let mut state = self
327 .state
328 .lock()
329 .unwrap_or_else(|poisoned| poisoned.into_inner());
330 let route = match self.max_series {
331 Some(max) => state.admit_route(route, max),
332 None => route,
333 };
334 let status = StatusCode::INTERNAL_SERVER_ERROR.as_u16();
335 *state
336 .requests_total
337 .entry((method.clone(), route.clone(), status))
338 .or_default() += 1;
339 state
340 .durations
341 .entry((method.clone(), route.clone(), status))
342 .or_default()
343 .observe(latency);
344 *state
345 .errors_total
346 .entry((method.clone(), route.clone(), status))
347 .or_default() += 1;
348 let key = (method, route);
349 if let Some(count) = state.in_flight.get_mut(&key) {
350 *count = count.saturating_sub(1);
351 }
352 }
353}
354
355#[derive(Clone, Debug, Default)]
356struct PrometheusState {
357 requests_total: BTreeMap<(String, String, u16), u64>,
358 durations: BTreeMap<(String, String, u16), DurationHistogram>,
359 in_flight: BTreeMap<(String, String), u64>,
360 errors_total: BTreeMap<(String, String, u16), u64>,
361 known_routes: BTreeSet<String>,
362}
363
364impl PrometheusState {
365 fn admit_route(&mut self, route: String, max_series: usize) -> String {
370 if self.known_routes.contains(&route) {
371 route
372 } else if self.known_routes.len() < max_series {
373 self.known_routes.insert(route.clone());
374 route
375 } else {
376 "<overflow>".to_owned()
377 }
378 }
379}
380
/// Upper bounds, in seconds, of the finite latency buckets, in ascending
/// order.
const HTTP_DURATION_BUCKETS: [f64; 11] = [
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
];

/// Latency histogram with cumulative bucket counts.
#[derive(Clone, Debug, Default)]
struct DurationHistogram {
    /// Number of observations.
    count: u64,
    /// Sum of all observed latencies, in seconds.
    sum: f64,
    /// `bucket_counts[i]` counts observations `<= HTTP_DURATION_BUCKETS[i]`.
    bucket_counts: [u64; HTTP_DURATION_BUCKETS.len()],
}

impl DurationHistogram {
    /// Records one observation of `latency`.
    fn observe(&mut self, latency: Duration) {
        let seconds = latency.as_secs_f64();
        self.count += 1;
        self.sum += seconds;

        // The bounds ascend, so the sample falls into every bucket from the
        // first bound that is not below it through the last one.
        let first_hit = HTTP_DURATION_BUCKETS.partition_point(|bound| *bound < seconds);
        for slot in &mut self.bucket_counts[first_hit..] {
            *slot += 1;
        }
    }
}
407
408#[derive(Clone, Debug)]
413pub struct MetricsLayer<H> {
414 hook: H,
415 route: Option<Cow<'static, str>>,
416}
417
418impl<H> MetricsLayer<H>
419where
420 H: HttpMetricsHook,
421{
422 pub fn new(hook: H) -> Self {
424 Self { hook, route: None }
425 }
426
427 pub fn route(mut self, route: impl Into<Cow<'static, str>>) -> Self {
432 self.route = Some(route.into());
433 self
434 }
435}
436
437impl<S, H> Layer<S> for MetricsLayer<H>
438where
439 H: HttpMetricsHook,
440{
441 type Service = MetricsService<S, H>;
442
443 fn layer(&self, inner: S) -> Self::Service {
444 MetricsService {
445 inner,
446 hook: self.hook.clone(),
447 route: self.route.clone(),
448 }
449 }
450}
451
452#[derive(Clone, Debug)]
454pub struct MetricsService<S, H> {
455 inner: S,
456 hook: H,
457 route: Option<Cow<'static, str>>,
458}
459
460impl<S, H, RequestBody, ResponseBody> Service<Request<RequestBody>> for MetricsService<S, H>
461where
462 S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
463 S::Future: Send + 'static,
464 S::Error: Send + 'static,
465 H: HttpMetricsHook,
466 RequestBody: Send + 'static,
467 ResponseBody: Send + 'static,
468{
469 type Response = Response<ResponseBody>;
470 type Error = S::Error;
471 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
472
473 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
474 self.inner.poll_ready(cx)
475 }
476
477 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
478 let method = request.method().clone();
479 let hook = self.hook.clone();
480 let route = self.route.clone().or_else(|| {
481 request
482 .extensions()
483 .get::<axum::extract::MatchedPath>()
484 .map(|path| Cow::Owned(path.as_str().to_owned()))
485 });
486 hook.on_request(&method, route.as_deref());
487 let started_at = Instant::now();
488 let future = self.inner.call(request);
489
490 Box::pin(async move {
491 match future.await {
492 Ok(response) => {
493 hook.on_response(
494 &method,
495 route.as_deref(),
496 response.status(),
497 started_at.elapsed(),
498 );
499 Ok(response)
500 }
501 Err(error) => {
502 hook.on_error(&method, route.as_deref(), started_at.elapsed());
503 Err(error)
504 }
505 }
506 })
507 }
508}
509
/// Escapes `value` for use inside a double-quoted Prometheus label value.
///
/// Backslash becomes `\\`, a line feed becomes `\n` (backslash followed by
/// `n`) and a double quote becomes `\"`; all other characters are copied
/// unchanged.
fn escape_label(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str(r"\\"),
            '\n' => escaped.push_str(r"\n"),
            '"' => escaped.push_str(r#"\""#),
            other => escaped.push(other),
        }
    }
    escaped
}
516
/// Formats a histogram bucket bound for use as an `le` label value.
///
/// Uses `f64`'s `Display` output, which prints the shortest decimal that
/// round-trips to the same value, with no trailing zeros and no exponent:
/// `0.010` becomes `0.01`, `1.0` becomes `1`, `2.5` stays `2.5`.
///
/// The earlier implementation formatted with three decimals and then
/// trimmed zeros, which turned a bound such as `0.0001` into `0.` and
/// rounded finer bounds to a different value than the one being compared
/// against.
fn format_bucket(bucket: f64) -> String {
    bucket.to_string()
}