1use crate::{database::RecorderDatabase, models::*, scrubbing::*, Result};
4use chrono::Utc;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tracing::debug;
8use uuid::Uuid;
9
/// Records request/response traffic into a [`RecorderDatabase`].
///
/// Every field is held behind an [`Arc`], so cloning a `Recorder` is cheap and all clones
/// share the same database handle, enabled flag, scrubber and capture filter.
#[derive(Clone)]
pub struct Recorder {
    /// Shared handle to the backing store.
    db: Arc<RecorderDatabase>,
    /// Recording on/off switch; shared by every clone of this recorder.
    enabled: Arc<RwLock<bool>>,
    /// Applied to requests and responses before they are stored.
    scrubber: Arc<Scrubber>,
    /// Consulted by `record_http_response` to decide whether a response is stored.
    filter: Arc<CaptureFilter>,
}
18
19impl Recorder {
20 pub fn new(db: RecorderDatabase) -> Self {
22 Self {
23 db: Arc::new(db),
24 enabled: Arc::new(RwLock::new(true)),
25 scrubber: Scrubber::global(),
26 filter: CaptureFilter::global(),
27 }
28 }
29
30 pub fn with_scrubbing(db: RecorderDatabase, scrubber: Scrubber, filter: CaptureFilter) -> Self {
32 Self {
33 db: Arc::new(db),
34 enabled: Arc::new(RwLock::new(true)),
35 scrubber: Arc::new(scrubber),
36 filter: Arc::new(filter),
37 }
38 }
39
40 pub fn scrubber(&self) -> &Arc<Scrubber> {
42 &self.scrubber
43 }
44
45 pub fn filter(&self) -> &Arc<CaptureFilter> {
47 &self.filter
48 }
49
50 pub async fn is_enabled(&self) -> bool {
52 *self.enabled.read().await
53 }
54
55 pub async fn enable(&self) {
57 *self.enabled.write().await = true;
58 debug!("Recording enabled");
59 }
60
61 pub async fn disable(&self) {
63 *self.enabled.write().await = false;
64 debug!("Recording disabled");
65 }
66
67 pub async fn record_request(&self, mut request: RecordedRequest) -> Result<String> {
69 if !self.is_enabled().await {
70 return Ok(request.id);
71 }
72
73 self.scrubber.scrub_request(&mut request);
75
76 let request_id = request.id.clone();
77 self.db.insert_request(&request).await?;
78 Ok(request_id)
79 }
80
81 pub async fn record_response(&self, mut response: RecordedResponse) -> Result<()> {
83 if !self.is_enabled().await {
84 return Ok(());
85 }
86
87 self.scrubber.scrub_response(&mut response);
89
90 self.db.insert_response(&response).await?;
91 Ok(())
92 }
93
94 pub async fn record_http_request(
96 &self,
97 method: &str,
98 path: &str,
99 query_params: Option<&str>,
100 headers: &std::collections::HashMap<String, String>,
101 body: Option<&[u8]>,
102 context: &crate::models::RequestContext,
103 ) -> Result<String> {
104 let request_id = Uuid::new_v4().to_string();
105
106 let (body_str, body_encoding) = encode_body(body);
107
108 let request = RecordedRequest {
109 id: request_id.clone(),
110 protocol: Protocol::Http,
111 timestamp: Utc::now(),
112 method: method.to_string(),
113 path: path.to_string(),
114 query_params: query_params.map(|q| q.to_string()),
115 headers: serde_json::to_string(&headers)?,
116 body: body_str,
117 body_encoding,
118 client_ip: context.client_ip.clone(),
119 trace_id: context.trace_id.clone(),
120 span_id: context.span_id.clone(),
121 duration_ms: None,
122 status_code: None,
123 tags: None,
124 };
125
126 self.record_request(request).await
127 }
128
129 pub async fn record_http_response(
131 &self,
132 request_id: &str,
133 status_code: i32,
134 headers: &std::collections::HashMap<String, String>,
135 body: Option<&[u8]>,
136 duration_ms: i64,
137 ) -> Result<()> {
138 if let Some(request) = self.db.get_request(request_id).await? {
141 let should_capture = self.filter.should_capture(
142 &request.method,
143 &request.path,
144 Some(status_code as u16),
145 );
146
147 if !should_capture {
148 debug!("Skipping response recording due to filter");
151 return Ok(());
152 }
153 }
154
155 let (body_str, body_encoding) = encode_body(body);
156 let size_bytes = body.map(|b| b.len()).unwrap_or(0) as i64;
157
158 let response = RecordedResponse {
159 request_id: request_id.to_string(),
160 status_code,
161 headers: serde_json::to_string(&headers)?,
162 body: body_str,
163 body_encoding,
164 size_bytes,
165 timestamp: Utc::now(),
166 };
167
168 self.record_response(response).await?;
169
170 self.update_request_completion(request_id, status_code, duration_ms).await?;
172
173 Ok(())
174 }
175
176 async fn update_request_completion(
178 &self,
179 _request_id: &str,
180 _status_code: i32,
181 _duration_ms: i64,
182 ) -> Result<()> {
183 Ok(())
186 }
187
188 pub fn database(&self) -> &Arc<RecorderDatabase> {
190 &self.db
191 }
192}
193
194fn encode_body(body: Option<&[u8]>) -> (Option<String>, String) {
196 match body {
197 None => (None, "utf8".to_string()),
198 Some(bytes) => {
199 if let Ok(text) = std::str::from_utf8(bytes) {
201 (Some(text.to_string()), "utf8".to_string())
202 } else {
203 let encoded =
205 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes);
206 (Some(encoded), "base64".to_string())
207 }
208 }
209 }
210}
211
212pub fn decode_body(body: Option<&str>, encoding: &str) -> Option<Vec<u8>> {
214 body.map(|b| {
215 if encoding == "base64" {
216 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, b)
217 .unwrap_or_else(|_| b.as_bytes().to_vec())
218 } else {
219 b.as_bytes().to_vec()
220 }
221 })
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::RecorderDatabase;

    #[tokio::test]
    async fn test_recorder_enable_disable() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        assert!(recorder.is_enabled().await);

        recorder.disable().await;
        assert!(!recorder.is_enabled().await);

        recorder.enable().await;
        assert!(recorder.is_enabled().await);
    }

    #[tokio::test]
    async fn test_record_http_exchange() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let headers = std::collections::HashMap::from([(
            "content-type".to_string(),
            "application/json".to_string(),
        )]);

        let context = RequestContext::new(Some("127.0.0.1"), None, None);
        let request_id = recorder
            .record_http_request("GET", "/api/test", Some("foo=bar"), &headers, None, &context)
            .await
            .unwrap();

        let body = b"{\"result\":\"ok\"}";
        recorder
            .record_http_response(&request_id, 200, &headers, Some(body), 42)
            .await
            .unwrap();

        let exchange = recorder.database().get_exchange(&request_id).await.unwrap();
        assert!(exchange.is_some());

        let exchange = exchange.unwrap();
        assert_eq!(exchange.request.path, "/api/test");
        assert_eq!(exchange.response.unwrap().status_code, 200);
    }

    /// A disabled recorder still hands out an id but persists nothing.
    #[tokio::test]
    async fn test_disabled_recorder_stores_nothing() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);
        recorder.disable().await;

        let headers = std::collections::HashMap::new();
        let context = RequestContext::new(Some("127.0.0.1"), None, None);
        let request_id = recorder
            .record_http_request("GET", "/api/test", None, &headers, None, &context)
            .await
            .unwrap();
        assert!(!request_id.is_empty());

        recorder
            .record_http_response(&request_id, 200, &headers, None, 1)
            .await
            .unwrap();

        let stored = recorder.database().get_request(&request_id).await.unwrap();
        assert!(stored.is_none());
    }

    #[test]
    fn test_body_encoding() {
        let text = b"Hello, World!";
        let (encoded, encoding) = encode_body(Some(text));
        assert_eq!(encoding, "utf8");
        assert_eq!(encoded, Some("Hello, World!".to_string()));

        let binary = &[0xFF, 0xFE, 0xFD];
        let (encoded, encoding) = encode_body(Some(binary));
        assert_eq!(encoding, "base64");
        assert!(encoded.is_some());

        let decoded = decode_body(encoded.as_deref(), &encoding);
        assert_eq!(decoded, Some(binary.to_vec()));
    }

    #[test]
    fn test_body_encoding_none_and_empty() {
        assert_eq!(encode_body(None), (None, "utf8".to_string()));
        assert_eq!(
            encode_body(Some(b"")),
            (Some(String::new()), "utf8".to_string())
        );
    }

    #[test]
    fn test_decode_body_edge_cases() {
        assert_eq!(decode_body(None, "base64"), None);
        assert_eq!(decode_body(Some("plain"), "utf8"), Some(b"plain".to_vec()));
        // Undecodable base64 falls back to the stored text's own bytes.
        assert_eq!(
            decode_body(Some("not base64!"), "base64"),
            Some(b"not base64!".to_vec())
        );
    }
}