mockforge_recorder/
database.rs

1//! SQLite database for storing recorded requests and responses
2
3use crate::{models::*, Result};
4use sqlx::{sqlite::SqlitePool, Pool, Sqlite};
5use std::{collections::HashMap, path::Path};
6use tracing::{debug, info};
7
8/// SQLite database for recorder
9#[derive(Clone)]
10pub struct RecorderDatabase {
11    pool: Pool<Sqlite>,
12}
13
14impl RecorderDatabase {
15    /// Create a new database connection
16    pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
17        let db_url = format!("sqlite:{}?mode=rwc", path.as_ref().display());
18        let pool = SqlitePool::connect(&db_url).await?;
19
20        let db = Self { pool };
21        db.initialize_schema().await?;
22
23        info!("Recorder database initialized at {:?}", path.as_ref());
24        Ok(db)
25    }
26
27    /// Create an in-memory database (for testing)
28    pub async fn new_in_memory() -> Result<Self> {
29        let pool = SqlitePool::connect("sqlite::memory:").await?;
30
31        let db = Self { pool };
32        db.initialize_schema().await?;
33
34        debug!("In-memory recorder database initialized");
35        Ok(db)
36    }
37
38    /// Initialize database schema
39    async fn initialize_schema(&self) -> Result<()> {
40        // Create requests table
41        sqlx::query(
42            r#"
43            CREATE TABLE IF NOT EXISTS requests (
44                id TEXT PRIMARY KEY,
45                protocol TEXT NOT NULL,
46                timestamp TEXT NOT NULL,
47                method TEXT NOT NULL,
48                path TEXT NOT NULL,
49                query_params TEXT,
50                headers TEXT NOT NULL,
51                body TEXT,
52                body_encoding TEXT NOT NULL DEFAULT 'utf8',
53                client_ip TEXT,
54                trace_id TEXT,
55                span_id TEXT,
56                duration_ms INTEGER,
57                status_code INTEGER,
58                tags TEXT,
59                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
60            )
61            "#,
62        )
63        .execute(&self.pool)
64        .await?;
65
66        // Create responses table
67        sqlx::query(
68            r#"
69            CREATE TABLE IF NOT EXISTS responses (
70                request_id TEXT PRIMARY KEY,
71                status_code INTEGER NOT NULL,
72                headers TEXT NOT NULL,
73                body TEXT,
74                body_encoding TEXT NOT NULL DEFAULT 'utf8',
75                size_bytes INTEGER NOT NULL,
76                timestamp TEXT NOT NULL,
77                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78                FOREIGN KEY (request_id) REFERENCES requests(id) ON DELETE CASCADE
79            )
80            "#,
81        )
82        .execute(&self.pool)
83        .await?;
84
85        // Create indexes for common queries
86        sqlx::query(
87            "CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC)",
88        )
89        .execute(&self.pool)
90        .await?;
91
92        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_protocol ON requests(protocol)")
93            .execute(&self.pool)
94            .await?;
95
96        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_method ON requests(method)")
97            .execute(&self.pool)
98            .await?;
99
100        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_path ON requests(path)")
101            .execute(&self.pool)
102            .await?;
103
104        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_trace_id ON requests(trace_id)")
105            .execute(&self.pool)
106            .await?;
107
108        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_status_code ON requests(status_code)")
109            .execute(&self.pool)
110            .await?;
111
112        debug!("Database schema initialized");
113        Ok(())
114    }
115
116    /// Insert a new request
117    pub async fn insert_request(&self, request: &RecordedRequest) -> Result<()> {
118        sqlx::query(
119            r#"
120            INSERT INTO requests (
121                id, protocol, timestamp, method, path, query_params,
122                headers, body, body_encoding, client_ip, trace_id, span_id,
123                duration_ms, status_code, tags
124            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
125            "#,
126        )
127        .bind(&request.id)
128        .bind(request.protocol)
129        .bind(request.timestamp)
130        .bind(&request.method)
131        .bind(&request.path)
132        .bind(&request.query_params)
133        .bind(&request.headers)
134        .bind(&request.body)
135        .bind(&request.body_encoding)
136        .bind(&request.client_ip)
137        .bind(&request.trace_id)
138        .bind(&request.span_id)
139        .bind(request.duration_ms)
140        .bind(request.status_code)
141        .bind(&request.tags)
142        .execute(&self.pool)
143        .await?;
144
145        debug!("Recorded request: {} {} {}", request.protocol, request.method, request.path);
146        Ok(())
147    }
148
149    /// Insert a response
150    pub async fn insert_response(&self, response: &RecordedResponse) -> Result<()> {
151        sqlx::query(
152            r#"
153            INSERT INTO responses (
154                request_id, status_code, headers, body, body_encoding,
155                size_bytes, timestamp
156            ) VALUES (?, ?, ?, ?, ?, ?, ?)
157            "#,
158        )
159        .bind(&response.request_id)
160        .bind(response.status_code)
161        .bind(&response.headers)
162        .bind(&response.body)
163        .bind(&response.body_encoding)
164        .bind(response.size_bytes)
165        .bind(response.timestamp)
166        .execute(&self.pool)
167        .await?;
168
169        debug!("Recorded response for request: {}", response.request_id);
170        Ok(())
171    }
172
173    /// Get a request by ID
174    pub async fn get_request(&self, id: &str) -> Result<Option<RecordedRequest>> {
175        let request = sqlx::query_as::<_, RecordedRequest>(
176            r#"
177            SELECT id, protocol, timestamp, method, path, query_params,
178                   headers, body, body_encoding, client_ip, trace_id, span_id,
179                   duration_ms, status_code, tags
180            FROM requests WHERE id = ?
181            "#,
182        )
183        .bind(id)
184        .fetch_optional(&self.pool)
185        .await?;
186
187        Ok(request)
188    }
189
190    /// Get a response by request ID
191    pub async fn get_response(&self, request_id: &str) -> Result<Option<RecordedResponse>> {
192        let response = sqlx::query_as::<_, RecordedResponse>(
193            r#"
194            SELECT request_id, status_code, headers, body, body_encoding,
195                   size_bytes, timestamp
196            FROM responses WHERE request_id = ?
197            "#,
198        )
199        .bind(request_id)
200        .fetch_optional(&self.pool)
201        .await?;
202
203        Ok(response)
204    }
205
206    /// Get an exchange (request + response) by request ID
207    pub async fn get_exchange(&self, id: &str) -> Result<Option<RecordedExchange>> {
208        let request = self.get_request(id).await?;
209        if let Some(request) = request {
210            let response = self.get_response(id).await?;
211            Ok(Some(RecordedExchange { request, response }))
212        } else {
213            Ok(None)
214        }
215    }
216
217    /// List recent requests
218    pub async fn list_recent(&self, limit: i32) -> Result<Vec<RecordedRequest>> {
219        let requests = sqlx::query_as::<_, RecordedRequest>(
220            r#"
221            SELECT id, protocol, timestamp, method, path, query_params,
222                   headers, body, body_encoding, client_ip, trace_id, span_id,
223                   duration_ms, status_code, tags
224            FROM requests
225            ORDER BY timestamp DESC
226            LIMIT ?
227            "#,
228        )
229        .bind(limit)
230        .fetch_all(&self.pool)
231        .await?;
232
233        Ok(requests)
234    }
235
236    /// Delete old requests
237    pub async fn delete_older_than(&self, days: i64) -> Result<u64> {
238        let result = sqlx::query(
239            r#"
240            DELETE FROM requests
241            WHERE timestamp < datetime('now', ? || ' days')
242            "#,
243        )
244        .bind(format!("-{}", days))
245        .execute(&self.pool)
246        .await?;
247
248        info!("Deleted {} old requests", result.rows_affected());
249        Ok(result.rows_affected())
250    }
251
252    /// Get database statistics
253    pub async fn get_stats(&self) -> Result<DatabaseStats> {
254        let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
255            .fetch_one(&self.pool)
256            .await?;
257
258        let total_responses: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM responses")
259            .fetch_one(&self.pool)
260            .await?;
261
262        let total_size: i64 =
263            sqlx::query_scalar("SELECT COALESCE(SUM(size_bytes), 0) FROM responses")
264                .fetch_one(&self.pool)
265                .await?;
266
267        Ok(DatabaseStats {
268            total_requests,
269            total_responses,
270            total_size_bytes: total_size,
271        })
272    }
273
274    /// Get detailed statistics for API
275    pub async fn get_statistics(&self) -> Result<DetailedStats> {
276        let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
277            .fetch_one(&self.pool)
278            .await?;
279
280        // Get count by protocol
281        let protocol_rows: Vec<(String, i64)> =
282            sqlx::query_as("SELECT protocol, COUNT(*) as count FROM requests GROUP BY protocol")
283                .fetch_all(&self.pool)
284                .await?;
285
286        let by_protocol: HashMap<String, i64> = protocol_rows.into_iter().collect();
287
288        // Get count by status code
289        let status_rows: Vec<(i32, i64)> = sqlx::query_as(
290            "SELECT status_code, COUNT(*) as count FROM requests WHERE status_code IS NOT NULL GROUP BY status_code"
291        )
292        .fetch_all(&self.pool)
293        .await?;
294
295        let by_status_code: HashMap<i32, i64> = status_rows.into_iter().collect();
296
297        // Get average duration
298        let avg_duration: Option<f64> = sqlx::query_scalar(
299            "SELECT AVG(duration_ms) FROM requests WHERE duration_ms IS NOT NULL",
300        )
301        .fetch_one(&self.pool)
302        .await?;
303
304        Ok(DetailedStats {
305            total_requests,
306            by_protocol,
307            by_status_code,
308            avg_duration_ms: avg_duration,
309        })
310    }
311
312    /// Clear all recordings
313    pub async fn clear_all(&self) -> Result<()> {
314        sqlx::query("DELETE FROM responses").execute(&self.pool).await?;
315        sqlx::query("DELETE FROM requests").execute(&self.pool).await?;
316        info!("Cleared all recordings");
317        Ok(())
318    }
319
320    /// Close the database connection
321    pub async fn close(self) {
322        self.pool.close().await;
323        debug!("Recorder database connection closed");
324    }
325}
326
/// Aggregate counters describing the contents of the recorder database.
///
/// `Default` yields all-zero counters, i.e. the stats of an empty database.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DatabaseStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Number of rows in the `responses` table.
    pub total_responses: i64,
    /// Sum of `size_bytes` over all stored responses.
    pub total_size_bytes: i64,
}
334
/// Detailed request statistics exposed through the API.
///
/// `Default` yields the statistics of an empty database: zero requests, empty
/// breakdowns and no average duration.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DetailedStats {
    /// Total number of recorded requests.
    pub total_requests: i64,
    /// Request count per protocol, keyed by the protocol's stored text value.
    pub by_protocol: HashMap<String, i64>,
    /// Request count per status code; requests without a status code are not counted.
    pub by_status_code: HashMap<i32, i64>,
    /// Mean `duration_ms` over requests that have one, or `None` if none do.
    pub avg_duration_ms: Option<f64>,
}
343
344// Implement FromRow for RecordedRequest
345impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedRequest {
346    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
347        use sqlx::Row;
348
349        Ok(RecordedRequest {
350            id: row.try_get("id")?,
351            protocol: row.try_get("protocol")?,
352            timestamp: row.try_get("timestamp")?,
353            method: row.try_get("method")?,
354            path: row.try_get("path")?,
355            query_params: row.try_get("query_params")?,
356            headers: row.try_get("headers")?,
357            body: row.try_get("body")?,
358            body_encoding: row.try_get("body_encoding")?,
359            client_ip: row.try_get("client_ip")?,
360            trace_id: row.try_get("trace_id")?,
361            span_id: row.try_get("span_id")?,
362            duration_ms: row.try_get("duration_ms")?,
363            status_code: row.try_get("status_code")?,
364            tags: row.try_get("tags")?,
365        })
366    }
367}
368
369// Implement FromRow for RecordedResponse
370impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedResponse {
371    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
372        use sqlx::Row;
373
374        Ok(RecordedResponse {
375            request_id: row.try_get("request_id")?,
376            status_code: row.try_get("status_code")?,
377            headers: row.try_get("headers")?,
378            body: row.try_get("body")?,
379            body_encoding: row.try_get("body_encoding")?,
380            size_bytes: row.try_get("size_bytes")?,
381            timestamp: row.try_get("timestamp")?,
382        })
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Build a minimal HTTP GET request record with the given `id`, timestamped now.
    fn sample_request(id: &str) -> RecordedRequest {
        RecordedRequest {
            id: id.to_string(),
            protocol: Protocol::Http,
            timestamp: Utc::now(),
            method: "GET".to_string(),
            path: "/api/test".to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: Some("127.0.0.1".to_string()),
            trace_id: None,
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: None,
        }
    }

    #[tokio::test]
    async fn test_database_creation() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let stats = db.get_stats().await.unwrap();
        assert_eq!(stats.total_requests, 0);
    }

    #[tokio::test]
    async fn test_insert_and_get_request() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        db.insert_request(&sample_request("test-123")).await.unwrap();

        let retrieved = db.get_request("test-123").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().path, "/api/test");
    }

    #[tokio::test]
    async fn test_missing_ids_return_none() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        assert!(db.get_request("missing").await.unwrap().is_none());
        assert!(db.get_response("missing").await.unwrap().is_none());
        assert!(db.get_exchange("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_exchange_without_response() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("req-1")).await.unwrap();

        let exchange = db.get_exchange("req-1").await.unwrap().expect("exchange should exist");
        assert_eq!(exchange.request.id, "req-1");
        assert!(exchange.response.is_none());
    }

    #[tokio::test]
    async fn test_delete_older_than() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        let mut old = sample_request("old");
        old.timestamp = Utc::now() - chrono::Duration::days(10);
        db.insert_request(&old).await.unwrap();
        db.insert_request(&sample_request("new")).await.unwrap();

        let deleted = db.delete_older_than(7).await.unwrap();
        assert_eq!(deleted, 1);
        assert!(db.get_request("old").await.unwrap().is_none());
        assert!(db.get_request("new").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_clear_all() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("a")).await.unwrap();
        db.insert_request(&sample_request("b")).await.unwrap();

        db.clear_all().await.unwrap();

        let stats = db.get_stats().await.unwrap();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_responses, 0);
    }

    #[tokio::test]
    async fn test_get_statistics() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("a")).await.unwrap();

        let stats = db.get_statistics().await.unwrap();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.by_protocol.values().sum::<i64>(), 1);
        assert_eq!(stats.by_status_code.get(&200), Some(&1));
        assert_eq!(stats.avg_duration_ms, Some(42.0));
    }
}