use {axum::{body::Body,
extract::Request,
response::Response},
futures::future::BoxFuture,
metrics::{counter,
gauge,
histogram},
std::{task::{Context,
Poll},
time::Instant},
tower::{Layer,
Service}};
/// Tower [`Layer`] that wraps a service in an [`HttpMetricsService`], which
/// records request count, latency and in-flight metrics via the `metrics`
/// crate macros.
#[derive(Clone, Default)]
pub struct HttpMetricsLayer;

impl HttpMetricsLayer {
    /// Creates the layer; identical to [`HttpMetricsLayer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// Wraps any inner service `S` in an [`HttpMetricsService`]; the layer itself
/// carries no configuration.
impl<S> Layer<S> for HttpMetricsLayer {
    type Service = HttpMetricsService<S>;

    fn layer(&self, service: S) -> Self::Service {
        HttpMetricsService { inner: service }
    }
}
/// Service produced by [`HttpMetricsLayer`]: forwards each request to the
/// wrapped service and records HTTP metrics around the call.
#[derive(Clone)]
pub struct HttpMetricsService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
impl<S> Service<Request<Body>> for HttpMetricsService<S>
where
    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
    type Response = S::Response;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service and records metrics for it.
    ///
    /// All metrics carry an `endpoint` label derived from the request path by
    /// `extract_endpoint`:
    /// - `http_active_requests` (gauge): incremented here and decremented when
    ///   the request finishes or its response future is dropped;
    /// - `http_requests_total` (counter): also labelled with `method` and `status`;
    /// - `http_request_duration_seconds` (histogram): also labelled with `method`.
    ///
    /// An `Err` from the inner service is reported with status `"500"`.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let method = req.method().to_string();
        let endpoint = extract_endpoint(req.uri().path());
        let start = Instant::now();
        gauge!("http_active_requests", "endpoint" => endpoint.clone()).increment(1.0);
        // The guard performs the matching decrement on drop. A future dropped
        // before completion therefore cannot leave the gauge permanently +1.
        let active = ActiveRequestGuard {
            endpoint: endpoint.clone(),
        };
        // Call the same instance that `poll_ready` drove to readiness. Calling a
        // fresh clone would bypass tower's readiness/backpressure contract.
        let response_future = self.inner.call(req);
        Box::pin(async move {
            let response = response_future.await;
            let duration = start.elapsed().as_secs_f64();
            let status = match &response {
                Ok(r) => r.status().as_u16().to_string(),
                Err(_) => "500".to_string(),
            };
            counter!(
                "http_requests_total",
                "method" => method.clone(),
                "endpoint" => endpoint.clone(),
                "status" => status
            )
            .increment(1);
            histogram!(
                "http_request_duration_seconds",
                "method" => method,
                "endpoint" => endpoint
            )
            .record(duration);
            // Explicit drop keeps the original ordering: the gauge is decremented
            // after the counter and histogram have been recorded.
            drop(active);
            response
        })
    }
}

/// Decrements the `http_active_requests` gauge for `endpoint` when dropped.
///
/// Tying the decrement to `Drop` balances the increment made in
/// `HttpMetricsService::call` on every exit path: normal completion, an
/// inner-service error, an unwinding panic, or the response future being
/// dropped before it completes.
struct ActiveRequestGuard {
    endpoint: String,
}

impl Drop for ActiveRequestGuard {
    fn drop(&mut self) {
        gauge!("http_active_requests", "endpoint" => std::mem::take(&mut self.endpoint))
            .decrement(1.0);
    }
}
/// Converts a request path into the `endpoint` metric label value.
///
/// - Leading and trailing `/` are ignored; an empty path maps to `"root"`.
/// - With two or more segments, the first segment (the service prefix) is
///   dropped; a lone segment is kept as-is.
/// - Remaining `/` separators become `_`, and every character that is not
///   alphanumeric or `_` is removed (so `my-thing` becomes `mything`).
///
/// NOTE(review): path segments are used verbatim, so dynamic segments (IDs,
/// slugs) each produce a distinct label value. Consider a matched-route
/// template if unbounded label cardinality is a concern.
fn extract_endpoint(path: &str) -> String {
    let segments = path.trim_matches('/');
    if segments.is_empty() {
        return String::from("root");
    }
    // Skip the service-prefix segment, unless it is the only segment.
    let tail = segments
        .split_once('/')
        .map(|(_, after)| after)
        .filter(|after| !after.is_empty())
        .unwrap_or(segments);
    tail.chars()
        .map(|ch| if ch == '/' { '_' } else { ch })
        .filter(|ch| ch.is_alphanumeric() || *ch == '_')
        .collect()
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_endpoint_with_service_prefix() {
        assert_eq!(extract_endpoint("/my-service/foo/bar"), "foo_bar");
        assert_eq!(extract_endpoint("/loyalty-api/campaigns/my"), "campaigns_my");
        assert_eq!(extract_endpoint("/loyalty-api/campaigns/detail"), "campaigns_detail");
        assert_eq!(extract_endpoint("/loyalty-admin/migration"), "migration");
    }

    #[test]
    fn test_extract_endpoint_single_segment() {
        assert_eq!(extract_endpoint("/health"), "health");
        assert_eq!(extract_endpoint("/metrics"), "metrics");
    }

    #[test]
    fn test_extract_endpoint_root() {
        assert_eq!(extract_endpoint("/"), "root");
        assert_eq!(extract_endpoint(""), "root");
        assert_eq!(extract_endpoint("///"), "root");
    }

    #[test]
    fn test_extract_endpoint_deep_path() {
        assert_eq!(extract_endpoint("/svc/a/b/c"), "a_b_c");
    }

    #[test]
    fn test_extract_endpoint_trailing_slash() {
        // Trailing slashes are trimmed and must not change the label.
        assert_eq!(extract_endpoint("/health/"), "health");
        assert_eq!(extract_endpoint("/svc/foo/bar/"), "foo_bar");
    }

    #[test]
    fn test_extract_endpoint_strips_non_label_chars() {
        // Only alphanumerics and '_' survive; '-' and '.' are removed.
        assert_eq!(extract_endpoint("/svc/my-thing"), "mything");
        assert_eq!(extract_endpoint("/svc/v1.2/items"), "v12_items");
    }
}