use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use axum::body::Body;
use axum::http::{Request, Response};
use tower::{Layer, Service};
use super::metrics::{HttpLabels, HttpLabelsWithStatus, HttpMethod, Metrics};
struct InFlightGuard {
metrics: Arc<Metrics>,
}
impl Drop for InFlightGuard {
fn drop(&mut self) {
self.metrics.http_requests_in_flight.dec();
}
}
#[derive(Clone)]
pub struct MetricsLayer {
metrics: Arc<Metrics>,
}
impl MetricsLayer {
pub fn new(metrics: Arc<Metrics>) -> Self {
Self { metrics }
}
}
impl<S> Layer<S> for MetricsLayer {
type Service = MetricsService<S>;
fn layer(&self, inner: S) -> Self::Service {
MetricsService {
inner,
metrics: self.metrics.clone(),
}
}
}
#[derive(Clone)]
pub struct MetricsService<S> {
inner: S,
metrics: Arc<Metrics>,
}
impl<S, ResBody> Service<Request<Body>> for MetricsService<S>
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send,
ResBody: Send,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let method = HttpMethod::from(request.method());
let endpoint = normalize_endpoint(request.uri().path());
let metrics = self.metrics.clone();
metrics.http_requests_in_flight.inc();
let _guard = InFlightGuard {
metrics: metrics.clone(),
};
let start = Instant::now();
let future = self.inner.call(request);
Box::pin(async move {
let _guard = _guard; let response = future.await?;
let status = response.status().as_u16();
let duration = start.elapsed().as_secs_f64();
metrics
.http_requests_total
.get_or_create(&HttpLabelsWithStatus {
method: method.clone(),
endpoint: endpoint.clone(),
status,
})
.inc();
metrics
.http_request_duration_seconds
.get_or_create(&HttpLabels { method, endpoint })
.observe(duration);
Ok(response)
})
}
}
fn normalize_endpoint(path: &str) -> String {
path.to_string()
}
#[derive(Clone)]
pub struct TracingLayer;
impl TracingLayer {
pub fn new() -> Self {
Self
}
}
impl Default for TracingLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for TracingLayer {
type Service = TracingService<S>;
fn layer(&self, inner: S) -> Self::Service {
TracingService { inner }
}
}
#[derive(Clone)]
pub struct TracingService<S> {
inner: S,
}
impl<S, ResBody> Service<Request<Body>> for TracingService<S>
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send,
ResBody: Send,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let http_method = request.method().clone();
let uri = request.uri().clone();
let user_agent = request
.headers()
.get("user-agent")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
let content_length = request
.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::debug!(
method = %http_method,
uri = %uri,
user_agent = %user_agent,
content_length = %content_length,
"HTTP request received"
);
let start_time = Instant::now();
let future = self.inner.call(request);
Box::pin(async move {
let response = future.await?;
let status = response.status().as_u16();
let elapsed = start_time.elapsed();
tracing::debug!(
method = %http_method,
uri = %uri,
status = %status,
duration_ms = %elapsed.as_millis(),
"HTTP request completed"
);
Ok(response)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::Method;
use tower::service_fn;
fn test_request(uri: &str) -> Request<Body> {
Request::builder()
.method(Method::GET)
.uri(uri)
.body(Body::empty())
.unwrap()
}
#[tokio::test]
async fn should_decrement_in_flight_gauge_after_successful_request() {
let metrics = Arc::new(Metrics::new());
let mut service = MetricsService {
inner: service_fn(|_req: Request<Body>| async {
Ok::<_, std::convert::Infallible>(
Response::builder().status(200).body(Body::empty()).unwrap(),
)
}),
metrics: metrics.clone(),
};
assert_eq!(metrics.http_requests_in_flight.get(), 0);
let future = service.call(test_request("/test"));
assert_eq!(metrics.http_requests_in_flight.get(), 1);
let response = future.await.unwrap();
assert_eq!(response.status().as_u16(), 200);
assert_eq!(metrics.http_requests_in_flight.get(), 0);
}
#[tokio::test]
async fn should_decrement_in_flight_gauge_when_inner_service_errors() {
let metrics = Arc::new(Metrics::new());
let mut service = MetricsService {
inner: service_fn(|_req: Request<Body>| async {
Err::<Response<Body>, _>(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"connection reset",
))
}),
metrics: metrics.clone(),
};
assert_eq!(metrics.http_requests_in_flight.get(), 0);
let future = service.call(test_request("/test"));
assert_eq!(metrics.http_requests_in_flight.get(), 1);
let result = future.await;
assert!(result.is_err());
assert_eq!(metrics.http_requests_in_flight.get(), 0);
}
#[tokio::test]
async fn should_decrement_in_flight_gauge_when_future_is_cancelled() {
let metrics = Arc::new(Metrics::new());
let mut service = MetricsService {
inner: service_fn(|_req: Request<Body>| {
std::future::pending::<Result<Response<Body>, std::convert::Infallible>>()
}),
metrics: metrics.clone(),
};
assert_eq!(metrics.http_requests_in_flight.get(), 0);
let future = service.call(test_request("/api/v1/log/scan"));
assert_eq!(metrics.http_requests_in_flight.get(), 1);
drop(future);
assert_eq!(metrics.http_requests_in_flight.get(), 0);
}
#[test]
fn should_preserve_api_endpoints() {
let path = "/api/v1/log/scan";
let normalized = normalize_endpoint(path);
assert_eq!(normalized, "/api/v1/log/scan");
}
#[test]
fn should_preserve_metrics_endpoint() {
let path = "/metrics";
let normalized = normalize_endpoint(path);
assert_eq!(normalized, "/metrics");
}
#[tokio::test]
async fn should_log_request_and_response_with_tracing_middleware() {
use tower::service_fn;
let test_service = service_fn(|_req: Request<Body>| async {
Ok::<_, std::convert::Infallible>(
Response::builder().status(200).body(Body::empty()).unwrap(),
)
});
let mut service = TracingService {
inner: test_service,
};
let request = Request::builder()
.method(Method::GET)
.uri("/api/v1/log/scan?key=test")
.header("user-agent", "test-client")
.header("content-length", "0")
.body(Body::empty())
.unwrap();
let response = service.call(request).await.unwrap();
assert_eq!(response.status().as_u16(), 200);
}
}