use super::layer::{BoxedNext, MiddlewareLayer};
use crate::request::Request;
use crate::response::Response;
use bytes::Bytes;
use prometheus::{
Encoder, GaugeVec, HistogramOpts, HistogramVec, IntCounterVec, Opts, Registry, TextEncoder,
};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
/// Histogram bucket upper bounds, in seconds.
///
/// Used for the built-in `http_request_duration_seconds` histogram and by
/// `CustomMetricsBuilder::histogram_with_default_buckets`. Spans 5 ms to 10 s.
const DEFAULT_BUCKETS: &[f64] = &[
    // sub-100ms
    0.005, 0.01, 0.025, 0.05, 0.1,
    // 100ms .. 1s
    0.25, 0.5, 1.0,
    // slow requests
    2.5, 5.0, 10.0,
];
/// Middleware layer that records Prometheus HTTP metrics for every request.
///
/// Cloning is cheap: every clone points at the same shared state through an
/// `Arc`, so metrics recorded by any clone show up in the same registry.
#[derive(Clone)]
pub struct MetricsLayer {
    /// Shared registry and built-in collectors.
    inner: Arc<MetricsInner>,
}
/// State shared by all clones of a `MetricsLayer` and by the
/// `CustomMetricsBuilder`s handed out from it.
struct MetricsInner {
    /// Registry holding the built-in collectors plus any custom ones.
    registry: Registry,
    /// `http_requests_total`, labelled by `method`, `path` and `status`.
    requests_total: IntCounterVec,
    /// `http_request_duration_seconds`, labelled by `method` and `path`.
    request_duration: HistogramVec,
    /// `rustapi_info`, labelled by `version`.
    // Only written once in `MetricsLayer::with_registry` and never read
    // afterwards (it is exported through `registry`), hence the allowance.
    #[allow(dead_code)]
    info_gauge: GaugeVec,
}
impl MetricsLayer {
    /// Creates a layer backed by a freshly created, private [`Registry`].
    pub fn new() -> Self {
        Self::with_registry(Registry::new())
    }

    /// Creates a layer whose built-in collectors are registered in `registry`.
    ///
    /// Three collectors are created and registered:
    /// * `http_requests_total` — counter labelled `method`, `path`, `status`;
    /// * `http_request_duration_seconds` — histogram labelled `method`, `path`,
    ///   using `DEFAULT_BUCKETS`;
    /// * `rustapi_info` — gauge labelled `version`, set to `1.0` for this
    ///   crate's `CARGO_PKG_VERSION`.
    ///
    /// # Panics
    ///
    /// Panics if a collector cannot be created or registered (presumably e.g.
    /// when `registry` already holds a collector with one of these names).
    pub fn with_registry(registry: Registry) -> Self {
        let requests_total = IntCounterVec::new(
            Opts::new("http_requests_total", "Total number of HTTP requests"),
            &["method", "path", "status"],
        )
        .expect("Failed to create http_requests_total metric");

        let duration_opts = HistogramOpts::new(
            "http_request_duration_seconds",
            "HTTP request duration in seconds",
        )
        .buckets(DEFAULT_BUCKETS.to_vec());
        let request_duration = HistogramVec::new(duration_opts, &["method", "path"])
            .expect("Failed to create http_request_duration_seconds metric");

        let info_gauge = GaugeVec::new(
            Opts::new("rustapi_info", "RustAPI version information"),
            &["version"],
        )
        .expect("Failed to create rustapi_info metric");

        // The registry receives clones; the originals are kept in `MetricsInner`
        // so requests can be recorded without going through the registry.
        registry
            .register(Box::new(requests_total.clone()))
            .expect("Failed to register http_requests_total");
        registry
            .register(Box::new(request_duration.clone()))
            .expect("Failed to register http_request_duration_seconds");
        registry
            .register(Box::new(info_gauge.clone()))
            .expect("Failed to register rustapi_info");

        info_gauge
            .with_label_values(&[env!("CARGO_PKG_VERSION")])
            .set(1.0);

        let shared = MetricsInner {
            registry,
            requests_total,
            request_duration,
            info_gauge,
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Borrows the registry that holds this layer's collectors.
    pub fn registry(&self) -> &Registry {
        &self.inner.registry
    }

    /// Returns a closure suitable as a `/metrics` endpoint handler.
    ///
    /// The closure captures a clone of the registry, gathers all metric
    /// families on every call and encodes them with `TextEncoder`.
    ///
    /// # Panics
    ///
    /// The returned closure panics if text encoding fails.
    pub fn handler(&self) -> impl Fn() -> MetricsResponse + Clone + Send + Sync + 'static {
        let registry = self.inner.registry.clone();
        move || {
            let mut exposition = Vec::new();
            TextEncoder::new()
                .encode(&registry.gather(), &mut exposition)
                .expect("Failed to encode metrics");
            MetricsResponse(exposition)
        }
    }

    /// Records one completed request.
    ///
    /// Increments `http_requests_total` and observes `duration_secs` in
    /// `http_request_duration_seconds`. `path` is first passed through
    /// `normalize_path`, so id-like segments share a single label value.
    fn record_request(&self, method: &str, path: &str, status: u16, duration_secs: f64) {
        let route = normalize_path(path);
        let status_label = status.to_string();
        let MetricsInner {
            requests_total,
            request_duration,
            ..
        } = &*self.inner;
        requests_total
            .with_label_values(&[method, &route, &status_label])
            .inc();
        request_duration
            .with_label_values(&[method, &route])
            .observe(duration_secs);
    }

    /// Returns a builder for registering user-defined metrics in this layer's
    /// registry.
    pub fn custom_metrics(&self) -> CustomMetricsBuilder {
        let inner = Arc::clone(&self.inner);
        CustomMetricsBuilder { inner }
    }
}
/// Factory for user-defined metrics registered in the same registry as a
/// `MetricsLayer`'s built-in HTTP metrics.
///
/// Obtained from `MetricsLayer::custom_metrics`.
pub struct CustomMetricsBuilder {
    /// State shared with the layer that created this builder.
    inner: Arc<MetricsInner>,
}
impl CustomMetricsBuilder {
    /// Registers a clone of `collector` in the shared registry and returns the
    /// original to the caller.
    ///
    /// Panics with `context` as the message if registration fails.
    fn register_collector<C>(&self, collector: C, context: &str) -> C
    where
        C: prometheus::core::Collector + Clone + 'static,
    {
        self.inner
            .registry
            .register(Box::new(collector.clone()))
            .expect(context);
        collector
    }

    /// Creates and registers an unlabelled floating-point counter.
    ///
    /// # Panics
    ///
    /// Panics if the counter cannot be created or registered.
    pub fn counter(&self, name: &str, help: &str) -> prometheus::Counter {
        let created = prometheus::Counter::new(name, help).expect("Failed to create counter");
        self.register_collector(created, "Failed to register counter")
    }

    /// Creates and registers an integer counter family with `label_names`.
    ///
    /// # Panics
    ///
    /// Panics if the counter vec cannot be created or registered.
    pub fn counter_vec(&self, name: &str, help: &str, label_names: &[&str]) -> IntCounterVec {
        let opts = Opts::new(name, help);
        let created = IntCounterVec::new(opts, label_names).expect("Failed to create counter vec");
        self.register_collector(created, "Failed to register counter vec")
    }

    /// Creates and registers an unlabelled gauge.
    ///
    /// # Panics
    ///
    /// Panics if the gauge cannot be created or registered.
    pub fn gauge(&self, name: &str, help: &str) -> prometheus::Gauge {
        let created = prometheus::Gauge::new(name, help).expect("Failed to create gauge");
        self.register_collector(created, "Failed to register gauge")
    }

    /// Creates and registers a gauge family with `label_names`.
    ///
    /// # Panics
    ///
    /// Panics if the gauge vec cannot be created or registered.
    pub fn gauge_vec(&self, name: &str, help: &str, label_names: &[&str]) -> GaugeVec {
        let opts = Opts::new(name, help);
        let created = GaugeVec::new(opts, label_names).expect("Failed to create gauge vec");
        self.register_collector(created, "Failed to register gauge vec")
    }

    /// Creates and registers an unlabelled histogram with explicit `buckets`.
    ///
    /// # Panics
    ///
    /// Panics if the histogram cannot be created or registered.
    pub fn histogram(&self, name: &str, help: &str, buckets: Vec<f64>) -> prometheus::Histogram {
        let opts = HistogramOpts::new(name, help).buckets(buckets);
        let created =
            prometheus::Histogram::with_opts(opts).expect("Failed to create histogram");
        self.register_collector(created, "Failed to register histogram")
    }

    /// Same as [`CustomMetricsBuilder::histogram`], using `DEFAULT_BUCKETS`.
    pub fn histogram_with_default_buckets(&self, name: &str, help: &str) -> prometheus::Histogram {
        let buckets = DEFAULT_BUCKETS.to_vec();
        self.histogram(name, help, buckets)
    }

    /// Creates and registers a histogram family with `label_names` and
    /// explicit `buckets`.
    ///
    /// # Panics
    ///
    /// Panics if the histogram vec cannot be created or registered.
    pub fn histogram_vec(
        &self,
        name: &str,
        help: &str,
        label_names: &[&str],
        buckets: Vec<f64>,
    ) -> HistogramVec {
        let opts = HistogramOpts::new(name, help).buckets(buckets);
        let created =
            HistogramVec::new(opts, label_names).expect("Failed to create histogram vec");
        self.register_collector(created, "Failed to register histogram vec")
    }
}
impl Default for MetricsLayer {
fn default() -> Self {
Self::new()
}
}
impl MiddlewareLayer for MetricsLayer {
    /// Times the downstream call and records one request sample.
    ///
    /// The measured duration runs from just before `next(req)` is awaited until
    /// it yields the response; the method, raw path and response status are
    /// then handed to `record_request`.
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        // Take owned label values now: `req` is moved into `next` below.
        let (method, path) = (req.method().to_string(), req.uri().path().to_string());
        let layer = self.clone();
        Box::pin(async move {
            let started = Instant::now();
            let response = next(req).await;
            let elapsed = started.elapsed().as_secs_f64();
            layer.record_request(&method, &path, response.status().as_u16(), elapsed);
            response
        })
    }

    /// Boxes a clone of this layer (the clone shares the same metrics state).
    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        let copy: Self = Self::clone(self);
        Box::new(copy)
    }
}
/// Prometheus text-format payload produced by `MetricsLayer::handler`.
///
/// Wraps the raw bytes written by `TextEncoder`; turned into an HTTP response
/// via its `IntoResponse` implementation.
pub struct MetricsResponse(Vec<u8>);
impl crate::response::IntoResponse for MetricsResponse {
    /// Wraps the encoded metrics in a `200 OK` response with the Prometheus
    /// text exposition content type (`text/plain; version=0.0.4; charset=utf-8`).
    ///
    /// Built with `http::Response::new` and a static header value instead of
    /// the fallible builder, so there is no `unwrap` on the scrape path.
    fn into_response(self) -> Response {
        use crate::response::Body;
        let body = Body::Full(http_body_util::Full::new(Bytes::from(self.0)));
        // `Response::new` starts out as 200 OK with no headers.
        let mut response = http::Response::new(body);
        response.headers_mut().insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
        );
        response
    }
}
/// Collapses identifier-looking path segments into `:id` so requests to the
/// same route share one `path` label value.
///
/// The path is split on `/`. Every non-empty segment for which `is_id_like`
/// returns true becomes `:id`. All other segments are kept verbatim, including
/// the empty ones produced by leading, trailing or doubled slashes. The pieces
/// are then re-joined with `/`.
///
/// Example: `/users/12345/posts` becomes `/users/:id/posts`.
fn normalize_path(path: &str) -> String {
    path.split('/')
        .map(|segment| {
            if !segment.is_empty() && is_id_like(segment) {
                ":id"
            } else {
                segment
            }
        })
        .collect::<Vec<&str>>()
        .join("/")
}
/// Heuristically decides whether a single path segment is a resource
/// identifier rather than a fixed route component.
///
/// A segment is id-like when it is any of:
/// * a canonical hyphenated UUID (`8-4-4-4-12` ASCII hex digits, any case);
/// * a non-empty run of ASCII decimal digits;
/// * at least 8 ASCII hex digits.
///
/// The empty segment is never id-like.
fn is_id_like(segment: &str) -> bool {
    if segment.is_empty() {
        return false;
    }
    if is_hyphenated_uuid(segment) {
        return true;
    }
    if segment.bytes().all(|b| b.is_ascii_digit()) {
        return true;
    }
    segment.len() >= 8 && segment.bytes().all(|b| b.is_ascii_hexdigit())
}

/// Returns true for the canonical 36-character UUID text form
/// `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`, where each `x` is an ASCII hex digit.
// Checks dash positions and hex digits rather than just "36 bytes with four
// dashes", so ordinary hyphenated slugs of that length are not taken for ids.
fn is_hyphenated_uuid(segment: &str) -> bool {
    segment.len() == 36
        && segment.bytes().enumerate().all(|(i, b)| match i {
            8 | 13 | 18 | 23 => b == b'-',
            _ => b.is_ascii_hexdigit(),
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a framework request with the given method and path and an empty
    /// buffered body.
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);
        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();
        crate::request::Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            HashMap::new().into(),
        )
    }

    #[test]
    fn test_metrics_layer_creation() {
        let metrics = MetricsLayer::new();
        assert!(!metrics.registry().gather().is_empty());
    }

    #[test]
    fn test_metrics_handler_returns_prometheus_format() {
        let metrics = MetricsLayer::new();
        let handler = metrics.handler();
        let response = handler();
        let http_response = crate::response::IntoResponse::into_response(response);
        assert_eq!(http_response.status(), StatusCode::OK);
        let content_type = http_response
            .headers()
            .get(http::header::CONTENT_TYPE)
            .unwrap();
        assert!(content_type.to_str().unwrap().contains("text/plain"));
    }

    #[test]
    fn test_normalize_path_with_uuid() {
        let path = "/users/550e8400-e29b-41d4-a716-446655440000/posts";
        let normalized = normalize_path(path);
        assert_eq!(normalized, "/users/:id/posts");
    }

    #[test]
    fn test_normalize_path_with_numeric_id() {
        let path = "/users/12345/posts";
        let normalized = normalize_path(path);
        assert_eq!(normalized, "/users/:id/posts");
    }

    #[test]
    fn test_normalize_path_without_ids() {
        let path = "/users/profile/settings";
        let normalized = normalize_path(path);
        assert_eq!(normalized, "/users/profile/settings");
    }

    #[test]
    fn test_is_id_like() {
        assert!(is_id_like("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_id_like("12345"));
        assert!(is_id_like("1"));
        assert!(is_id_like("deadbeef"));
        assert!(is_id_like("abc123def456"));
        assert!(!is_id_like("users"));
        assert!(!is_id_like("profile"));
        assert!(!is_id_like(""));
    }

    #[test]
    fn test_rustapi_info_gauge_set() {
        let metrics = MetricsLayer::new();

        // The info gauge must be registered with value 1 and the crate version label.
        let families = metrics.registry().gather();
        let info = families
            .iter()
            .find(|mf| mf.get_name() == "rustapi_info")
            .expect("rustapi_info metric should be registered");
        let sample = info
            .get_metric()
            .first()
            .expect("rustapi_info should have one sample");
        let version = sample
            .get_label()
            .iter()
            .find(|l| l.get_name() == "version")
            .map(|l| l.get_value());
        assert_eq!(version, Some(env!("CARGO_PKG_VERSION")));
        assert_eq!(sample.get_gauge().get_value(), 1.0);

        // It must also be part of what the scrape handler exports.
        let handler = metrics.handler();
        let response = handler();
        let exported =
            String::from_utf8(response.0).expect("Prometheus text output should be UTF-8");
        assert!(exported.contains("rustapi_info"));
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        fn prop_request_metrics_recording(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}",
            status_code in 200u16..600u16,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let metrics = MetricsLayer::new();
                let mut stack = LayerStack::new();
                stack.push(Box::new(metrics.clone()));
                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();
                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                // The layer labels requests with the *normalized* path: generated
                // segments of 8+ letters drawn only from a-f (e.g. "/abcdefab") are
                // hex-like and become "/:id", so compare against the normalized form.
                let expected_path = normalize_path(&path);
                let expected_status = status_code.to_string();
                let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| {
                    let status = response_status;
                    Box::pin(async move {
                        http::Response::builder()
                            .status(status)
                            .body(crate::response::Body::Full(http_body_util::Full::new(Bytes::from("test"))))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });
                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;
                prop_assert_eq!(response.status(), response_status);
                let metric_families = metrics.registry().gather();
                let requests_total = metric_families
                    .iter()
                    .find(|mf| mf.get_name() == "http_requests_total");
                prop_assert!(
                    requests_total.is_some(),
                    "http_requests_total metric should exist"
                );
                let requests_total = requests_total.unwrap();
                let metrics_vec = requests_total.get_metric();
                let matching_metric = metrics_vec.iter().find(|m| {
                    let labels = m.get_label();
                    let method_label = labels.iter().find(|l| l.get_name() == "method");
                    let path_label = labels.iter().find(|l| l.get_name() == "path");
                    let status_label = labels.iter().find(|l| l.get_name() == "status");
                    method_label.map(|l| l.get_value()) == Some(method.as_str())
                        && path_label.map(|l| l.get_value()) == Some(expected_path.as_str())
                        && status_label.map(|l| l.get_value()) == Some(expected_status.as_str())
                });
                prop_assert!(
                    matching_metric.is_some(),
                    "Should have metric with method={}, path={}, status={}. Available metrics: {:?}",
                    method.as_str(),
                    expected_path,
                    status_code,
                    metrics_vec.iter().map(|m| m.get_label()).collect::<Vec<_>>()
                );
                let counter_value = matching_metric.unwrap().get_counter().get_value();
                prop_assert!(
                    counter_value >= 1.0,
                    "Counter should be at least 1, got {}",
                    counter_value
                );
                let duration_metric = metric_families
                    .iter()
                    .find(|mf| mf.get_name() == "http_request_duration_seconds");
                prop_assert!(
                    duration_metric.is_some(),
                    "http_request_duration_seconds metric should exist"
                );
                let duration_metric = duration_metric.unwrap();
                let duration_vec = duration_metric.get_metric();
                let matching_histogram = duration_vec.iter().find(|m| {
                    let labels = m.get_label();
                    let method_label = labels.iter().find(|l| l.get_name() == "method");
                    let path_label = labels.iter().find(|l| l.get_name() == "path");
                    method_label.map(|l| l.get_value()) == Some(method.as_str())
                        && path_label.map(|l| l.get_value()) == Some(expected_path.as_str())
                });
                prop_assert!(
                    matching_histogram.is_some(),
                    "Should have histogram with method={}, path={}",
                    method.as_str(),
                    expected_path
                );
                let histogram = matching_histogram.unwrap().get_histogram();
                prop_assert!(
                    histogram.get_sample_count() >= 1,
                    "Histogram should have at least 1 sample, got {}",
                    histogram.get_sample_count()
                );
                let sum = histogram.get_sample_sum();
                prop_assert!(
                    sum < 10.0,
                    "Duration sum should be less than 10 seconds, got {}",
                    sum
                );
                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_metrics_layer_records_request() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let metrics = MetricsLayer::new();
            let mut stack = LayerStack::new();
            stack.push(Box::new(metrics.clone()));
            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(crate::response::Body::Full(http_body_util::Full::new(
                            Bytes::from("ok"),
                        )))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });
            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;
            assert_eq!(response.status(), StatusCode::OK);
            let metric_families = metrics.registry().gather();
            let requests_total = metric_families
                .iter()
                .find(|mf| mf.get_name() == "http_requests_total");
            assert!(requests_total.is_some());
        });
    }

    #[test]
    fn test_metrics_layer_with_multiple_requests() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let metrics = MetricsLayer::new();
            let mut stack = LayerStack::new();
            stack.push(Box::new(metrics.clone()));
            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(crate::response::Body::Full(http_body_util::Full::new(
                            Bytes::from("ok"),
                        )))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });
            for _ in 0..5 {
                let request = create_test_request(Method::GET, "/test");
                let _ = stack.execute(request, handler.clone()).await;
            }
            let metric_families = metrics.registry().gather();
            let requests_total = metric_families
                .iter()
                .find(|mf| mf.get_name() == "http_requests_total")
                .unwrap();
            let metrics_vec = requests_total.get_metric();
            let matching_metric = metrics_vec.iter().find(|m| {
                let labels = m.get_label();
                labels
                    .iter()
                    .any(|l| l.get_name() == "method" && l.get_value() == "GET")
                    && labels
                        .iter()
                        .any(|l| l.get_name() == "path" && l.get_value() == "/test")
                    && labels
                        .iter()
                        .any(|l| l.get_name() == "status" && l.get_value() == "200")
            });
            assert!(matching_metric.is_some());
            assert_eq!(matching_metric.unwrap().get_counter().get_value(), 5.0);
        });
    }
}