mockforge_http/
http_tracing_middleware.rs1use axum::{
6    extract::{MatchedPath, Request},
7    middleware::Next,
8    response::Response,
9};
10use mockforge_tracing::{
11    create_request_span, extract_from_axum_headers, inject_into_axum_headers, record_error,
12    record_success, Protocol,
13};
14use opentelemetry::{trace::TraceContextExt, KeyValue};
15use std::time::Instant;
16use tracing::debug;
17
18pub async fn http_tracing_middleware(
27    matched_path: Option<MatchedPath>,
28    req: Request,
29    next: Next,
30) -> Response {
31    let start_time = Instant::now();
32    let method = req.method().to_string();
33    let uri_path = req.uri().path().to_string();
34    let path = matched_path
35        .as_ref()
36        .map(|mp| mp.as_str().to_string())
37        .unwrap_or(uri_path.clone());
38
39    let parent_ctx = extract_from_axum_headers(req.headers());
41
42    let mut span = create_request_span(
44        Protocol::Http,
45        &format!("{} {}", method, path),
46        vec![
47            KeyValue::new("http.method", method.clone()),
48            KeyValue::new("http.route", path.clone()),
49            KeyValue::new("http.url", uri_path.clone()),
50        ],
51    );
52
53    debug!(
54        method = %method,
55        path = %path,
56        "Created trace span for HTTP request"
57    );
58
59    let mut response = next.run(req).await;
61
62    let duration = start_time.elapsed();
64    let status_code = response.status().as_u16();
65
66    let attributes = vec![
68        KeyValue::new("http.status_code", status_code as i64),
69        KeyValue::new("http.duration_ms", duration.as_millis() as i64),
70    ];
71
72    if status_code >= 400 {
74        record_error(
75            &mut span,
76            &format!(
77                "HTTP {}: {}",
78                status_code,
79                response.status().canonical_reason().unwrap_or("Unknown")
80            ),
81        );
82    } else {
83        record_success(&mut span, attributes);
84    }
85
86    let ctx = parent_ctx.with_span(span);
88
89    inject_into_axum_headers(&ctx, response.headers_mut());
91
92    debug!(
93        method = %method,
94        path = %path,
95        status = status_code,
96        duration_ms = duration.as_millis(),
97        "Completed trace span for HTTP request"
98    );
99
100    response
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        response::IntoResponse,
        Router,
    };
    use tower::ServiceExt;

    /// Trivial handler that always answers `200 OK` with a fixed body.
    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "test response")
    }

    /// Registers a `TraceContextPropagator` as the global text-map propagator
    /// and returns a router exposing `GET /test` behind the tracing middleware.
    fn traced_app() -> Router {
        use opentelemetry::global;
        use opentelemetry_sdk::propagation::TraceContextPropagator;

        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracing_layer = middleware::from_fn(http_tracing_middleware);
        Router::new()
            .route("/test", axum::routing::get(test_handler))
            .layer(tracing_layer)
    }

    /// A plain request should come back `200 OK` carrying a `traceparent` header.
    #[tokio::test]
    #[ignore]
    async fn test_tracing_middleware_creates_span() {
        let plain_request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let reply = traced_app().oneshot(plain_request).await.unwrap();

        assert_eq!(reply.status(), StatusCode::OK);
        assert!(reply.headers().contains_key("traceparent"));
    }

    /// A request that already carries a `traceparent` header should get a
    /// response whose `traceparent` contains the same trace id.
    #[tokio::test]
    #[ignore]
    async fn test_tracing_middleware_propagates_context() {
        let traced_request = Request::builder()
            .uri("/test")
            .header("traceparent", "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
            .body(Body::empty())
            .unwrap();

        let reply = traced_app().oneshot(traced_request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::OK);

        let traceparent = reply
            .headers()
            .get("traceparent")
            .and_then(|value| value.to_str().ok());

        assert!(traceparent.is_some());
        let echoed = traceparent.unwrap();
        assert!(echoed.contains("0af7651916cd43dd8448eb211c80319c"));
    }
}