Skip to main content

kora_lib/metrics/
middleware.rs

1use crate::rpc_server::middleware_utils::{extract_parts_and_body_bytes, get_jsonrpc_method};
2use http::{Request, Response};
3use jsonrpsee::server::logger::Body;
4use prometheus::{CounterVec, HistogramVec, Opts};
5use std::{sync::OnceLock, time::Instant};
6use tower::Layer;
7
8static HTTP_METRICS: OnceLock<HttpMetrics> = OnceLock::new();
9
10const UNKNOWN_METHOD: &str = "unknown";
11const ERROR_STATUS: &str = "error";
12
13pub struct HttpMetrics {
14    pub requests_total: CounterVec,
15    pub request_duration_seconds: HistogramVec,
16}
17
18impl HttpMetrics {
19    fn new() -> Self {
20        let requests_total = CounterVec::new(
21            Opts::new("http_requests_total", "Total number of HTTP requests").namespace("kora"),
22            &["method", "status"],
23        )
24        .unwrap_or_else(|e| {
25            log::error!("Failed to create http_requests_total metric: {e:?}");
26            panic!("Metrics initialization failed - cannot continue")
27        });
28
29        let request_duration_seconds = HistogramVec::new(
30            prometheus::HistogramOpts::new(
31                "http_request_duration_seconds",
32                "HTTP request duration in seconds",
33            )
34            .namespace("kora")
35            .buckets(vec![0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]),
36            &["method"],
37        )
38        .unwrap_or_else(|e| {
39            log::error!("Failed to create http_request_duration_seconds metric: {e:?}");
40            panic!("Metrics initialization failed - cannot continue")
41        });
42
43        prometheus::register(Box::new(requests_total.clone())).unwrap_or_else(|e| {
44            log::error!("Failed to register http_requests_total metric: {e:?}");
45            panic!("Metrics initialization failed - cannot continue")
46        });
47        prometheus::register(Box::new(request_duration_seconds.clone())).unwrap_or_else(|e| {
48            log::error!("Failed to register http_request_duration_seconds metric: {e:?}");
49            panic!("Metrics initialization failed - cannot continue")
50        });
51
52        Self { requests_total, request_duration_seconds }
53    }
54
55    pub fn get() -> &'static HttpMetrics {
56        HTTP_METRICS.get_or_init(HttpMetrics::new)
57    }
58}
/// Tower layer for collecting HTTP metrics.
///
/// Wraps each service it is applied to in an [`HttpMetricsService`]. The layer
/// itself is stateless (a zero-sized type); recorded values go to the shared
/// [`HttpMetrics`] instance.
#[derive(Clone, Copy, Debug, Default)]
pub struct HttpMetricsLayer;

impl HttpMetricsLayer {
    /// Creates the layer. Equivalent to `HttpMetricsLayer::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
74
75impl<S> Layer<S> for HttpMetricsLayer {
76    type Service = HttpMetricsService<S>;
77
78    fn layer(&self, service: S) -> Self::Service {
79        HttpMetricsService { inner: service }
80    }
81}
82
/// Tower service for collecting HTTP metrics.
///
/// Created by `HttpMetricsLayer`; forwards every request to `inner` and records
/// request count and latency in the shared `HttpMetrics`.
#[derive(Clone, Debug)]
pub struct HttpMetricsService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
}
88
89impl<S> tower::Service<Request<Body>> for HttpMetricsService<S>
90where
91    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
92    S::Future: Send + 'static,
93{
94    type Response = S::Response;
95    type Error = S::Error;
96    type Future = std::pin::Pin<
97        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
98    >;
99
100    fn poll_ready(
101        &mut self,
102        cx: &mut std::task::Context<'_>,
103    ) -> std::task::Poll<Result<(), Self::Error>> {
104        self.inner.poll_ready(cx)
105    }
106
107    fn call(&mut self, request: Request<Body>) -> Self::Future {
108        let start = Instant::now();
109        let mut inner = self.inner.clone();
110
111        Box::pin(async move {
112            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
113            let method = get_jsonrpc_method(&body_bytes).unwrap_or(UNKNOWN_METHOD.to_string());
114
115            // Reconstruct the request with the consumed body
116            let new_body = Body::from(body_bytes);
117            let new_request = Request::from_parts(parts, new_body);
118
119            // Call the inner service
120            let result = inner.call(new_request).await;
121
122            // Record metrics
123            let metrics = HttpMetrics::get();
124            let duration = start.elapsed();
125
126            match &result {
127                Ok(response) => {
128                    let status = response.status().as_u16().to_string();
129                    metrics.requests_total.with_label_values(&[&method, &status]).inc();
130                    metrics
131                        .request_duration_seconds
132                        .with_label_values(&[&method])
133                        .observe(duration.as_secs_f64());
134                }
135                Err(_) => {
136                    metrics
137                        .requests_total
138                        .with_label_values(&[&method, &ERROR_STATUS.to_string()])
139                        .inc();
140                }
141            }
142
143            result
144        })
145    }
146}