mockforge_recorder/
recorder.rs

1//! Core recording functionality
2
3use crate::{database::RecorderDatabase, models::*, scrubbing::*, Result};
4use chrono::Utc;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tracing::debug;
8use uuid::Uuid;
9
10/// Recorder for capturing API requests and responses
11#[derive(Clone)]
12pub struct Recorder {
13    db: Arc<RecorderDatabase>,
14    enabled: Arc<RwLock<bool>>,
15    scrubber: Arc<Scrubber>,
16    filter: Arc<CaptureFilter>,
17}
18
19impl Recorder {
20    /// Create a new recorder
21    pub fn new(db: RecorderDatabase) -> Self {
22        Self {
23            db: Arc::new(db),
24            enabled: Arc::new(RwLock::new(true)),
25            scrubber: Scrubber::global(),
26            filter: CaptureFilter::global(),
27        }
28    }
29
30    /// Create a new recorder with custom scrubber and filter
31    pub fn with_scrubbing(db: RecorderDatabase, scrubber: Scrubber, filter: CaptureFilter) -> Self {
32        Self {
33            db: Arc::new(db),
34            enabled: Arc::new(RwLock::new(true)),
35            scrubber: Arc::new(scrubber),
36            filter: Arc::new(filter),
37        }
38    }
39
40    /// Get the scrubber
41    pub fn scrubber(&self) -> &Arc<Scrubber> {
42        &self.scrubber
43    }
44
45    /// Get the filter
46    pub fn filter(&self) -> &Arc<CaptureFilter> {
47        &self.filter
48    }
49
50    /// Check if recording is enabled
51    pub async fn is_enabled(&self) -> bool {
52        *self.enabled.read().await
53    }
54
55    /// Enable recording
56    pub async fn enable(&self) {
57        *self.enabled.write().await = true;
58        debug!("Recording enabled");
59    }
60
61    /// Disable recording
62    pub async fn disable(&self) {
63        *self.enabled.write().await = false;
64        debug!("Recording disabled");
65    }
66
67    /// Record a request
68    pub async fn record_request(&self, mut request: RecordedRequest) -> Result<String> {
69        if !self.is_enabled().await {
70            return Ok(request.id);
71        }
72
73        // Apply scrubbing
74        self.scrubber.scrub_request(&mut request);
75
76        let request_id = request.id.clone();
77        self.db.insert_request(&request).await?;
78        Ok(request_id)
79    }
80
81    /// Record a response
82    pub async fn record_response(&self, mut response: RecordedResponse) -> Result<()> {
83        if !self.is_enabled().await {
84            return Ok(());
85        }
86
87        // Apply scrubbing
88        self.scrubber.scrub_response(&mut response);
89
90        self.db.insert_response(&response).await?;
91        Ok(())
92    }
93
94    /// Record an HTTP request
95    pub async fn record_http_request(
96        &self,
97        method: &str,
98        path: &str,
99        query_params: Option<&str>,
100        headers: &std::collections::HashMap<String, String>,
101        body: Option<&[u8]>,
102        context: &crate::models::RequestContext,
103    ) -> Result<String> {
104        let request_id = Uuid::new_v4().to_string();
105
106        let (body_str, body_encoding) = encode_body(body);
107
108        let request = RecordedRequest {
109            id: request_id.clone(),
110            protocol: Protocol::Http,
111            timestamp: Utc::now(),
112            method: method.to_string(),
113            path: path.to_string(),
114            query_params: query_params.map(|q| q.to_string()),
115            headers: serde_json::to_string(&headers)?,
116            body: body_str,
117            body_encoding,
118            client_ip: context.client_ip.clone(),
119            trace_id: context.trace_id.clone(),
120            span_id: context.span_id.clone(),
121            duration_ms: None,
122            status_code: None,
123            tags: None,
124        };
125
126        self.record_request(request).await
127    }
128
129    /// Record an HTTP response
130    pub async fn record_http_response(
131        &self,
132        request_id: &str,
133        status_code: i32,
134        headers: &std::collections::HashMap<String, String>,
135        body: Option<&[u8]>,
136        duration_ms: i64,
137    ) -> Result<()> {
138        // Check filter with status code now that we have it
139        // Get the request to check path and method
140        if let Some(request) = self.db.get_request(request_id).await? {
141            let should_capture = self.filter.should_capture(
142                &request.method,
143                &request.path,
144                Some(status_code as u16),
145            );
146
147            if !should_capture {
148                // Delete the request since it doesn't match the filter
149                // (We don't have a delete method, so we just skip the response)
150                debug!("Skipping response recording due to filter");
151                return Ok(());
152            }
153        }
154
155        let (body_str, body_encoding) = encode_body(body);
156        let size_bytes = body.map(|b| b.len()).unwrap_or(0) as i64;
157
158        let response = RecordedResponse {
159            request_id: request_id.to_string(),
160            status_code,
161            headers: serde_json::to_string(&headers)?,
162            body: body_str,
163            body_encoding,
164            size_bytes,
165            timestamp: Utc::now(),
166        };
167
168        self.record_response(response).await?;
169
170        // Update request with duration and status
171        self.update_request_completion(request_id, status_code, duration_ms).await?;
172
173        Ok(())
174    }
175
176    /// Update request with completion data
177    async fn update_request_completion(
178        &self,
179        _request_id: &str,
180        _status_code: i32,
181        _duration_ms: i64,
182    ) -> Result<()> {
183        // Note: This would need to access the pool through a public method
184        // For now, we'll skip this optimization and rely on separate inserts
185        Ok(())
186    }
187
188    /// Get database reference
189    pub fn database(&self) -> &Arc<RecorderDatabase> {
190        &self.db
191    }
192}
193
194/// Encode body for storage (binary data as base64)
195fn encode_body(body: Option<&[u8]>) -> (Option<String>, String) {
196    match body {
197        None => (None, "utf8".to_string()),
198        Some(bytes) => {
199            // Try to parse as UTF-8 first
200            if let Ok(text) = std::str::from_utf8(bytes) {
201                (Some(text.to_string()), "utf8".to_string())
202            } else {
203                // Binary data, encode as base64
204                let encoded =
205                    base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes);
206                (Some(encoded), "base64".to_string())
207            }
208        }
209    }
210}
211
212/// Decode body from storage
213pub fn decode_body(body: Option<&str>, encoding: &str) -> Option<Vec<u8>> {
214    body.map(|b| {
215        if encoding == "base64" {
216            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, b)
217                .unwrap_or_else(|_| b.as_bytes().to_vec())
218        } else {
219            b.as_bytes().to_vec()
220        }
221    })
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::RecorderDatabase;
    use std::collections::HashMap;

    /// Build a recorder on top of a fresh in-memory database.
    async fn in_memory_recorder() -> Recorder {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        Recorder::new(db)
    }

    #[tokio::test]
    async fn test_recorder_enable_disable() {
        let recorder = in_memory_recorder().await;

        // Freshly constructed recorders start enabled.
        assert!(recorder.is_enabled().await);

        recorder.disable().await;
        assert!(!recorder.is_enabled().await);

        recorder.enable().await;
        assert!(recorder.is_enabled().await);
    }

    #[tokio::test]
    async fn test_record_http_exchange() {
        let recorder = in_memory_recorder().await;

        let json_headers =
            HashMap::from([("content-type".to_string(), "application/json".to_string())]);
        let ctx = RequestContext::new(Some("127.0.0.1"), None, None);

        let id = recorder
            .record_http_request("GET", "/api/test", Some("foo=bar"), &json_headers, None, &ctx)
            .await
            .unwrap();

        let payload: &[u8] = b"{\"result\":\"ok\"}";
        recorder
            .record_http_response(&id, 200, &json_headers, Some(payload), 42)
            .await
            .unwrap();

        // Both halves of the exchange must be retrievable by the request id.
        let stored = recorder
            .database()
            .get_exchange(&id)
            .await
            .unwrap()
            .expect("exchange should have been recorded");

        assert_eq!(stored.request.path, "/api/test");
        assert_eq!(stored.response.unwrap().status_code, 200);
    }

    #[test]
    fn test_body_encoding() {
        // Valid UTF-8 is stored verbatim.
        let greeting: &[u8] = b"Hello, World!";
        let (stored_text, text_tag) = encode_body(Some(greeting));
        assert_eq!(text_tag, "utf8");
        assert_eq!(stored_text, Some("Hello, World!".to_string()));

        // Invalid UTF-8 falls back to base64 ...
        let raw: &[u8] = &[0xFF, 0xFE, 0xFD];
        let (stored_raw, raw_tag) = encode_body(Some(raw));
        assert_eq!(raw_tag, "base64");
        assert!(stored_raw.is_some());

        // ... and decodes back to the original bytes.
        let round_tripped = decode_body(stored_raw.as_deref(), &raw_tag);
        assert_eq!(round_tripped, Some(raw.to_vec()));
    }
}