use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use axum::body::Body;
use http::{Request, Response};
use prometheus::{HistogramVec, IntCounterVec};
use scion_sdk_observability::metrics::registry::MetricsRegistry;
use tower::{BoxError, Layer, Service};
#[derive(Clone)]
pub struct PrometheusMiddlewareLayer {
metrics: Metrics,
}
impl PrometheusMiddlewareLayer {
pub fn new(metrics: Metrics) -> Self {
Self { metrics }
}
}
impl<S> Layer<S> for PrometheusMiddlewareLayer {
type Service = PrometheusMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
PrometheusMiddleware::new(inner, self.metrics.clone())
}
}
/// Tower service that records Prometheus metrics around the wrapped service `S`.
///
/// For every request it:
/// - increments `control_plane_started_total` when [`Service::call`] is invoked;
/// - once the inner future resolves, increments `control_plane_handled_total` and
///   observes the elapsed time in `control_plane_latency_seconds`, labelled with the
///   response status code, or with `"error"` if the inner service failed.
///
/// NOTE(review): if the returned future is dropped before completion (e.g. the
/// client disconnects and the request is cancelled), only the "started" counter is
/// incremented.
#[derive(Clone)]
pub struct PrometheusMiddleware<S> {
    inner: S,
    metrics: Metrics,
}

impl<S> PrometheusMiddleware<S> {
    /// Wraps `inner` so that its requests are recorded into `metrics`.
    pub fn new(inner: S, metrics: Metrics) -> Self {
        Self { inner, metrics }
    }
}

impl<S> Service<Request<Body>> for PrometheusMiddleware<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
    S::Error: Into<BoxError>,
    S::Future: Send + 'static,
{
    type Response = Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the inner service, boxing its error.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    /// Forwards `request` to the inner service and records metrics for it.
    ///
    /// # Errors
    ///
    /// Returns the inner service's error converted into a [`BoxError`]. Metrics are
    /// still recorded for the failed request, with the status label `"error"`.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // The "method" label holds the request URI path (for gRPC-style APIs this is
        // `/package.Service/Method`). NOTE(review): if paths can embed IDs, this label
        // has unbounded cardinality. Confirm that the routed paths form a fixed set.
        let method = request.uri().path().to_string();
        let metrics = self.metrics.clone();

        metrics
            .control_plane_started_total
            .with_label_values(&[&method])
            .inc();

        // Start the clock before invoking the inner service so that any work it
        // performs synchronously inside `call` is included in the latency.
        let start = Instant::now();
        let fut = self.inner.call(request);

        Box::pin(async move {
            let result: Result<Response<Body>, BoxError> = fut.await.map_err(Into::into);
            let elapsed = start.elapsed().as_secs_f64();

            // Record completion for failures too. Otherwise `handled_total` would
            // drift away from `started_total`, and failing requests would never
            // appear in the latency histogram.
            let status = match &result {
                Ok(response) => response.status().as_str().to_string(),
                Err(_) => "error".to_string(),
            };

            metrics
                .control_plane_handled_total
                .with_label_values(&[&method, &status])
                .inc();
            metrics
                .control_plane_latency_seconds
                .with_label_values(&[&method, &status])
                .observe(elapsed);

            result
        })
    }
}
/// Prometheus metric handles for control-plane API requests.
#[derive(Debug, Clone)]
pub struct Metrics {
    /// Requests started on the server, labelled by `method`.
    pub control_plane_started_total: IntCounterVec,
    /// Requests handled on the server, labelled by `method` and `status`.
    pub control_plane_handled_total: IntCounterVec,
    /// Request latency in seconds, labelled by `method` and `status`.
    pub control_plane_latency_seconds: HistogramVec,
}

impl Metrics {
    /// Registers the control-plane request metrics with `metrics_registry` and
    /// returns the handles to them.
    ///
    /// The metrics are registered in this order: started counter, handled
    /// counter, latency histogram.
    pub fn new(metrics_registry: &MetricsRegistry) -> Self {
        // Latency histogram buckets, in seconds (5 ms up to 5 s).
        let latency_buckets = vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0];

        let control_plane_started_total = metrics_registry.int_counter_vec(
            "control_plane_requests_started_total",
            "Total number of control plane API requests started on the server.",
            &["method"],
        );
        let control_plane_handled_total = metrics_registry.int_counter_vec(
            "control_plane_requests_handled_total",
            "Total number of control plane API requests handled on the server.",
            &["method", "status"],
        );
        let control_plane_latency_seconds = metrics_registry.histogram_vec(
            "control_plane_requests_latency_seconds",
            "Latency of control plane API requests in seconds.",
            latency_buckets,
            &["method", "status"],
        );

        Self {
            control_plane_started_total,
            control_plane_handled_total,
            control_plane_latency_seconds,
        }
    }
}