mockforge_http/
http_tracing_middleware.rs1use axum::{
6 extract::{MatchedPath, Request},
7 middleware::Next,
8 response::Response,
9};
10use mockforge_tracing::{
11 create_request_span, extract_from_axum_headers, inject_into_axum_headers, record_error,
12 record_success, Protocol,
13};
14use opentelemetry::{trace::TraceContextExt, KeyValue};
15use std::time::Instant;
16use tracing::debug;
17
18pub async fn http_tracing_middleware(
27 matched_path: Option<MatchedPath>,
28 req: Request,
29 next: Next,
30) -> Response {
31 let start_time = Instant::now();
32 let method = req.method().to_string();
33 let uri_path = req.uri().path().to_string();
34 let path = matched_path
35 .as_ref()
36 .map(|mp| mp.as_str().to_string())
37 .unwrap_or(uri_path.clone());
38
39 let parent_ctx = extract_from_axum_headers(req.headers());
41
42 let mut span = create_request_span(
44 Protocol::Http,
45 &format!("{} {}", method, path),
46 vec![
47 KeyValue::new("http.method", method.clone()),
48 KeyValue::new("http.route", path.clone()),
49 KeyValue::new("http.url", uri_path.clone()),
50 ],
51 );
52
53 debug!(
54 method = %method,
55 path = %path,
56 "Created trace span for HTTP request"
57 );
58
59 let mut response = next.run(req).await;
61
62 let duration = start_time.elapsed();
64 let status_code = response.status().as_u16();
65
66 let attributes = vec![
68 KeyValue::new("http.status_code", status_code as i64),
69 KeyValue::new("http.duration_ms", duration.as_millis() as i64),
70 ];
71
72 if status_code >= 400 {
74 record_error(
75 &mut span,
76 &format!(
77 "HTTP {}: {}",
78 status_code,
79 response.status().canonical_reason().unwrap_or("Unknown")
80 ),
81 );
82 } else {
83 record_success(&mut span, attributes);
84 }
85
86 let ctx = parent_ctx.with_span(span);
88
89 inject_into_axum_headers(&ctx, response.headers_mut());
91
92 debug!(
93 method = %method,
94 path = %path,
95 status = status_code,
96 duration_ms = duration.as_millis(),
97 "Completed trace span for HTTP request"
98 );
99
100 response
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        response::IntoResponse,
        Router,
    };
    use opentelemetry::global;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use tower::ServiceExt;

    /// Trace id embedded in `INCOMING_TRACEPARENT`; the propagation test expects it back.
    const INCOMING_TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    /// W3C `traceparent` header value sent by the propagation test.
    const INCOMING_TRACEPARENT: &str =
        "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    /// Trivial handler that always answers `200 OK` with a fixed body.
    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "test response")
    }

    /// Installs the W3C Trace Context propagator globally, then returns a router serving
    /// `GET /test` wrapped in `http_tracing_middleware`.
    fn traced_router() -> Router {
        global::set_text_map_propagator(TraceContextPropagator::new());
        Router::new()
            .route("/test", axum::routing::get(test_handler))
            .layer(middleware::from_fn(http_tracing_middleware))
    }

    #[tokio::test]
    #[ignore]
    async fn test_tracing_middleware_creates_span() {
        let app = traced_router();
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("traceparent"));
    }

    #[tokio::test]
    #[ignore]
    async fn test_tracing_middleware_propagates_context() {
        let app = traced_router();
        let request = Request::builder()
            .uri("/test")
            .header("traceparent", INCOMING_TRACEPARENT)
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let traceparent = response
            .headers()
            .get("traceparent")
            .and_then(|value| value.to_str().ok());
        assert!(traceparent.is_some());
        assert!(traceparent.unwrap().contains(INCOMING_TRACE_ID));
    }
}