mockforge_recorder/
middleware.rs1use crate::recorder::Recorder;
4use axum::{
5 body::{Body, Bytes},
6 extract::{ConnectInfo, Request},
7 middleware::Next,
8 response::Response,
9};
10use http_body_util::BodyExt;
11use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Instant};
12use tracing::{debug, error};
13
14pub async fn recording_middleware(
16 recorder: Arc<Recorder>,
17 ConnectInfo(addr): ConnectInfo<SocketAddr>,
18 req: Request,
19 next: Next,
20) -> Response {
21 let trace_id = req
23 .headers()
24 .get("traceparent")
25 .and_then(|v| v.to_str().ok())
26 .and_then(extract_trace_id);
27
28 let span_id = req
29 .headers()
30 .get("traceparent")
31 .and_then(|v| v.to_str().ok())
32 .and_then(extract_span_id);
33
34 let method = req.method().to_string();
36 let uri = req.uri().clone();
37 let path = uri.path().to_string();
38 let query = uri.query().map(|q| q.to_string());
39
40 let headers: HashMap<String, String> = req
42 .headers()
43 .iter()
44 .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.as_str().to_string(), s.to_string())))
45 .collect();
46
47 let (parts, body) = req.into_parts();
49 let body_bytes = match body.collect().await {
50 Ok(collected) => collected.to_bytes(),
51 Err(e) => {
52 error!("Failed to read request body: {}", e);
53 Bytes::new()
54 }
55 };
56
57 let start = Instant::now();
59 let context = crate::models::RequestContext::new(
60 Some(&addr.ip().to_string()),
61 trace_id.as_deref(),
62 span_id.as_deref(),
63 );
64 let request_id = match recorder
65 .record_http_request(
66 &method,
67 &path,
68 query.as_deref(),
69 &headers,
70 if body_bytes.is_empty() {
71 None
72 } else {
73 Some(&body_bytes)
74 },
75 &context,
76 )
77 .await
78 {
79 Ok(id) => id,
80 Err(e) => {
81 error!("Failed to record request: {}", e);
82 uuid::Uuid::new_v4().to_string()
84 }
85 };
86
87 debug!("Recorded request: {} {} {}", request_id, method, path);
88
89 let req = Request::from_parts(parts, Body::from(body_bytes));
91
92 let response = next.run(req).await;
94
95 let (parts, body) = response.into_parts();
97 let status_code = parts.status.as_u16() as i32;
98
99 let response_headers: HashMap<String, String> = parts
101 .headers
102 .iter()
103 .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.as_str().to_string(), s.to_string())))
104 .collect();
105
106 let response_body_bytes = match body.collect().await {
108 Ok(collected) => collected.to_bytes(),
109 Err(e) => {
110 error!("Failed to read response body: {}", e);
111 Bytes::new()
112 }
113 };
114
115 let duration_ms = start.elapsed().as_millis() as i64;
117
118 if let Err(e) = recorder
120 .record_http_response(
121 &request_id,
122 status_code,
123 &response_headers,
124 if response_body_bytes.is_empty() {
125 None
126 } else {
127 Some(&response_body_bytes)
128 },
129 duration_ms,
130 )
131 .await
132 {
133 error!("Failed to record response: {}", e);
134 }
135
136 debug!(
137 "Recorded response: {} status={} duration={}ms",
138 request_id, status_code, duration_ms
139 );
140
141 Response::from_parts(parts, Body::from(response_body_bytes))
143}
144
/// Extracts the trace-id field from a W3C `traceparent` header value.
///
/// A `traceparent` value has the form `version-traceid-parentid-flags`, e.g.
/// `00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01`.
///
/// Returns the second field only if it is a well-formed trace-id: exactly 32
/// lowercase hex digits and not all zeros. The W3C Trace Context spec treats
/// an all-zero id as invalid. Any other input yields `None`, so a malformed
/// header is never recorded as a bogus trace id.
fn extract_trace_id(traceparent: &str) -> Option<String> {
    let candidate = traceparent.split('-').nth(1)?;
    let well_formed = candidate.len() == 32
        && candidate
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        && candidate.bytes().any(|b| b != b'0');
    well_formed.then(|| candidate.to_string())
}
155
/// Extracts the parent-id (span id) field from a W3C `traceparent` header value.
///
/// A `traceparent` value has the form `version-traceid-parentid-flags`, e.g.
/// `00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01`.
///
/// Returns the third field only if it is a well-formed parent-id: exactly 16
/// lowercase hex digits and not all zeros. The W3C Trace Context spec treats
/// an all-zero id as invalid. Any other input yields `None`.
fn extract_span_id(traceparent: &str) -> Option<String> {
    let candidate = traceparent.split('-').nth(2)?;
    let well_formed = candidate.len() == 16
        && candidate
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        && candidate.bytes().any(|b| b != b'0');
    well_formed.then(|| candidate.to_string())
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical example header from the W3C Trace Context specification.
    const SAMPLE_TRACEPARENT: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    #[test]
    fn test_extract_trace_id() {
        assert_eq!(
            extract_trace_id(SAMPLE_TRACEPARENT),
            Some("4bf92f3577b34da6a3ce929d0e0e4736".to_string())
        );
    }

    #[test]
    fn test_extract_span_id() {
        assert_eq!(
            extract_span_id(SAMPLE_TRACEPARENT),
            Some("00f067aa0ba902b7".to_string())
        );
    }

    #[test]
    fn test_invalid_traceparent() {
        // A value without any `-` separators carries neither id.
        assert_eq!(extract_trace_id("invalid"), None);
        assert_eq!(extract_span_id("invalid"), None);
    }

    #[test]
    fn test_truncated_traceparent_has_no_span_id() {
        // Only version and trace-id are present, so there is no field to take a span id from.
        assert_eq!(extract_span_id("00-4bf92f3577b34da6a3ce929d0e0e4736"), None);
    }
}