mockforge_recorder/
middleware.rs

1//! Recording middleware for HTTP requests
2
3use crate::recorder::Recorder;
4use axum::{
5    body::{Body, Bytes},
6    extract::{ConnectInfo, Request},
7    middleware::Next,
8    response::Response,
9};
10use http_body_util::BodyExt;
11use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Instant};
12use tracing::{debug, error};
13
14/// Middleware layer for recording HTTP requests and responses
15pub async fn recording_middleware(
16    recorder: Arc<Recorder>,
17    ConnectInfo(addr): ConnectInfo<SocketAddr>,
18    req: Request,
19    next: Next,
20) -> Response {
21    // Extract trace context if available
22    let trace_id = req
23        .headers()
24        .get("traceparent")
25        .and_then(|v| v.to_str().ok())
26        .and_then(extract_trace_id);
27
28    let span_id = req
29        .headers()
30        .get("traceparent")
31        .and_then(|v| v.to_str().ok())
32        .and_then(extract_span_id);
33
34    // Extract request details
35    let method = req.method().to_string();
36    let uri = req.uri().clone();
37    let path = uri.path().to_string();
38    let query = uri.query().map(|q| q.to_string());
39
40    // Clone headers
41    let headers: HashMap<String, String> = req
42        .headers()
43        .iter()
44        .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.as_str().to_string(), s.to_string())))
45        .collect();
46
47    // Extract body (need to consume and recreate the request)
48    let (parts, body) = req.into_parts();
49    let body_bytes = match body.collect().await {
50        Ok(collected) => collected.to_bytes(),
51        Err(e) => {
52            error!("Failed to read request body: {}", e);
53            Bytes::new()
54        }
55    };
56
57    // Record the request
58    let start = Instant::now();
59    let context = crate::models::RequestContext::new(
60        Some(&addr.ip().to_string()),
61        trace_id.as_deref(),
62        span_id.as_deref(),
63    );
64    let request_id = match recorder
65        .record_http_request(
66            &method,
67            &path,
68            query.as_deref(),
69            &headers,
70            if body_bytes.is_empty() {
71                None
72            } else {
73                Some(&body_bytes)
74            },
75            &context,
76        )
77        .await
78    {
79        Ok(id) => id,
80        Err(e) => {
81            error!("Failed to record request: {}", e);
82            // Continue processing even if recording fails
83            uuid::Uuid::new_v4().to_string()
84        }
85    };
86
87    debug!("Recorded request: {} {} {}", request_id, method, path);
88
89    // Reconstruct request with body
90    let req = Request::from_parts(parts, Body::from(body_bytes));
91
92    // Pass to next handler
93    let response = next.run(req).await;
94
95    // Extract response details
96    let (parts, body) = response.into_parts();
97    let status_code = parts.status.as_u16() as i32;
98
99    // Extract response headers
100    let response_headers: HashMap<String, String> = parts
101        .headers
102        .iter()
103        .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.as_str().to_string(), s.to_string())))
104        .collect();
105
106    // Extract response body
107    let response_body_bytes = match body.collect().await {
108        Ok(collected) => collected.to_bytes(),
109        Err(e) => {
110            error!("Failed to read response body: {}", e);
111            Bytes::new()
112        }
113    };
114
115    // Calculate duration
116    let duration_ms = start.elapsed().as_millis() as i64;
117
118    // Record the response
119    if let Err(e) = recorder
120        .record_http_response(
121            &request_id,
122            status_code,
123            &response_headers,
124            if response_body_bytes.is_empty() {
125                None
126            } else {
127                Some(&response_body_bytes)
128            },
129            duration_ms,
130        )
131        .await
132    {
133        error!("Failed to record response: {}", e);
134    }
135
136    debug!(
137        "Recorded response: {} status={} duration={}ms",
138        request_id, status_code, duration_ms
139    );
140
141    // Reconstruct response with body
142    Response::from_parts(parts, Body::from(response_body_bytes))
143}
144
/// Extract the trace ID from a W3C `traceparent` header.
///
/// Expected format: `{version}-{trace_id}-{parent_id}-{flags}`.
///
/// Returns `None` in any of these cases:
/// - the header has fewer than four `-`-separated fields;
/// - the trace-id field is not exactly 32 hex digits;
/// - the trace-id is all zeros, which the Trace Context spec defines as invalid.
fn extract_trace_id(traceparent: &str) -> Option<String> {
    let mut fields = traceparent.split('-');
    let trace_id = fields.nth(1)?;
    // Skip parent-id and require the flags field, so truncated headers are rejected.
    fields.nth(1)?;
    let well_formed = trace_id.len() == 32
        && trace_id.bytes().all(|b| b.is_ascii_hexdigit())
        && trace_id.bytes().any(|b| b != b'0');
    well_formed.then(|| trace_id.to_string())
}
155
/// Extract the span (parent) ID from a W3C `traceparent` header.
///
/// Expected format: `{version}-{trace_id}-{parent_id}-{flags}`.
///
/// Returns `None` in any of these cases:
/// - the header has fewer than four `-`-separated fields;
/// - the parent-id field is not exactly 16 hex digits;
/// - the parent-id is all zeros, which the Trace Context spec defines as invalid.
fn extract_span_id(traceparent: &str) -> Option<String> {
    let mut fields = traceparent.split('-');
    let span_id = fields.nth(2)?;
    // The flags field must follow the parent-id; reject truncated headers.
    fields.next()?;
    let well_formed = span_id.len() == 16
        && span_id.bytes().all(|b| b.is_ascii_hexdigit())
        && span_id.bytes().any(|b| b != b'0');
    well_formed.then(|| span_id.to_string())
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed version-00 `traceparent` header shared by the tests below.
    const TRACEPARENT: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    #[test]
    fn test_extract_trace_id() {
        let expected = String::from("4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(extract_trace_id(TRACEPARENT), Some(expected));
    }

    #[test]
    fn test_extract_span_id() {
        let expected = String::from("00f067aa0ba902b7");
        assert_eq!(extract_span_id(TRACEPARENT), Some(expected));
    }

    #[test]
    fn test_invalid_traceparent() {
        // A header with no `-` separators has no trace-id field at all.
        assert!(extract_trace_id("invalid").is_none());
    }
}