mockforge_recorder/
database.rs

1//! SQLite database for storing recorded requests and responses
2
3use crate::{models::*, Result};
4use sqlx::{sqlite::SqlitePool, Pool, Sqlite};
5use std::{collections::HashMap, path::Path};
6use tracing::{debug, info};
7
8/// SQLite database for recorder
9#[derive(Clone)]
10pub struct RecorderDatabase {
11    pool: Pool<Sqlite>,
12}
13
14impl RecorderDatabase {
15    /// Create a new database connection
16    pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
17        let db_url = format!("sqlite:{}?mode=rwc", path.as_ref().display());
18        let pool = SqlitePool::connect(&db_url).await?;
19
20        let db = Self { pool };
21        db.initialize_schema().await?;
22
23        info!("Recorder database initialized at {:?}", path.as_ref());
24        Ok(db)
25    }
26
27    /// Create an in-memory database (for testing)
28    pub async fn new_in_memory() -> Result<Self> {
29        let pool = SqlitePool::connect("sqlite::memory:").await?;
30
31        let db = Self { pool };
32        db.initialize_schema().await?;
33
34        debug!("In-memory recorder database initialized");
35        Ok(db)
36    }
37
38    /// Initialize database schema
39    async fn initialize_schema(&self) -> Result<()> {
40        // Create requests table
41        sqlx::query(
42            r#"
43            CREATE TABLE IF NOT EXISTS requests (
44                id TEXT PRIMARY KEY,
45                protocol TEXT NOT NULL,
46                timestamp TEXT NOT NULL,
47                method TEXT NOT NULL,
48                path TEXT NOT NULL,
49                query_params TEXT,
50                headers TEXT NOT NULL,
51                body TEXT,
52                body_encoding TEXT NOT NULL DEFAULT 'utf8',
53                client_ip TEXT,
54                trace_id TEXT,
55                span_id TEXT,
56                duration_ms INTEGER,
57                status_code INTEGER,
58                tags TEXT,
59                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
60            )
61            "#,
62        )
63        .execute(&self.pool)
64        .await?;
65
66        // Create responses table
67        sqlx::query(
68            r#"
69            CREATE TABLE IF NOT EXISTS responses (
70                request_id TEXT PRIMARY KEY,
71                status_code INTEGER NOT NULL,
72                headers TEXT NOT NULL,
73                body TEXT,
74                body_encoding TEXT NOT NULL DEFAULT 'utf8',
75                size_bytes INTEGER NOT NULL,
76                timestamp TEXT NOT NULL,
77                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78                FOREIGN KEY (request_id) REFERENCES requests(id) ON DELETE CASCADE
79            )
80            "#,
81        )
82        .execute(&self.pool)
83        .await?;
84
85        // Create indexes for common queries
86        sqlx::query(
87            "CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC)",
88        )
89        .execute(&self.pool)
90        .await?;
91
92        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_protocol ON requests(protocol)")
93            .execute(&self.pool)
94            .await?;
95
96        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_method ON requests(method)")
97            .execute(&self.pool)
98            .await?;
99
100        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_path ON requests(path)")
101            .execute(&self.pool)
102            .await?;
103
104        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_trace_id ON requests(trace_id)")
105            .execute(&self.pool)
106            .await?;
107
108        sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_status_code ON requests(status_code)")
109            .execute(&self.pool)
110            .await?;
111
112        debug!("Database schema initialized");
113        Ok(())
114    }
115
116    /// Insert a new request
117    pub async fn insert_request(&self, request: &RecordedRequest) -> Result<()> {
118        sqlx::query(
119            r#"
120            INSERT INTO requests (
121                id, protocol, timestamp, method, path, query_params,
122                headers, body, body_encoding, client_ip, trace_id, span_id,
123                duration_ms, status_code, tags
124            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
125            "#,
126        )
127        .bind(&request.id)
128        .bind(request.protocol)
129        .bind(request.timestamp)
130        .bind(&request.method)
131        .bind(&request.path)
132        .bind(&request.query_params)
133        .bind(&request.headers)
134        .bind(&request.body)
135        .bind(&request.body_encoding)
136        .bind(&request.client_ip)
137        .bind(&request.trace_id)
138        .bind(&request.span_id)
139        .bind(request.duration_ms)
140        .bind(request.status_code)
141        .bind(&request.tags)
142        .execute(&self.pool)
143        .await?;
144
145        debug!("Recorded request: {} {} {}", request.protocol, request.method, request.path);
146        Ok(())
147    }
148
149    /// Insert a response
150    pub async fn insert_response(&self, response: &RecordedResponse) -> Result<()> {
151        sqlx::query(
152            r#"
153            INSERT INTO responses (
154                request_id, status_code, headers, body, body_encoding,
155                size_bytes, timestamp
156            ) VALUES (?, ?, ?, ?, ?, ?, ?)
157            "#,
158        )
159        .bind(&response.request_id)
160        .bind(response.status_code)
161        .bind(&response.headers)
162        .bind(&response.body)
163        .bind(&response.body_encoding)
164        .bind(response.size_bytes)
165        .bind(response.timestamp)
166        .execute(&self.pool)
167        .await?;
168
169        debug!("Recorded response for request: {}", response.request_id);
170        Ok(())
171    }
172
173    /// Get a request by ID
174    pub async fn get_request(&self, id: &str) -> Result<Option<RecordedRequest>> {
175        let request = sqlx::query_as::<_, RecordedRequest>(
176            r#"
177            SELECT id, protocol, timestamp, method, path, query_params,
178                   headers, body, body_encoding, client_ip, trace_id, span_id,
179                   duration_ms, status_code, tags
180            FROM requests WHERE id = ?
181            "#,
182        )
183        .bind(id)
184        .fetch_optional(&self.pool)
185        .await?;
186
187        Ok(request)
188    }
189
190    /// Get a response by request ID
191    pub async fn get_response(&self, request_id: &str) -> Result<Option<RecordedResponse>> {
192        let response = sqlx::query_as::<_, RecordedResponse>(
193            r#"
194            SELECT request_id, status_code, headers, body, body_encoding,
195                   size_bytes, timestamp
196            FROM responses WHERE request_id = ?
197            "#,
198        )
199        .bind(request_id)
200        .fetch_optional(&self.pool)
201        .await?;
202
203        Ok(response)
204    }
205
206    /// Get an exchange (request + response) by request ID
207    pub async fn get_exchange(&self, id: &str) -> Result<Option<RecordedExchange>> {
208        let request = self.get_request(id).await?;
209        if let Some(request) = request {
210            let response = self.get_response(id).await?;
211            Ok(Some(RecordedExchange { request, response }))
212        } else {
213            Ok(None)
214        }
215    }
216
217    /// Update an existing response
218    pub async fn update_response(
219        &self,
220        request_id: &str,
221        status_code: i32,
222        headers: &str,
223        body: &str,
224        size_bytes: i64,
225    ) -> Result<()> {
226        sqlx::query(
227            r#"
228            UPDATE responses
229            SET status_code = ?,
230                headers = ?,
231                body = ?,
232                body_encoding = 'base64',
233                size_bytes = ?,
234                timestamp = datetime('now')
235            WHERE request_id = ?
236            "#,
237        )
238        .bind(status_code)
239        .bind(headers)
240        .bind(body)
241        .bind(size_bytes)
242        .bind(request_id)
243        .execute(&self.pool)
244        .await?;
245
246        debug!("Updated response for request {}", request_id);
247        Ok(())
248    }
249
250    /// List recent requests
251    pub async fn list_recent(&self, limit: i32) -> Result<Vec<RecordedRequest>> {
252        let requests = sqlx::query_as::<_, RecordedRequest>(
253            r#"
254            SELECT id, protocol, timestamp, method, path, query_params,
255                   headers, body, body_encoding, client_ip, trace_id, span_id,
256                   duration_ms, status_code, tags
257            FROM requests
258            ORDER BY timestamp DESC
259            LIMIT ?
260            "#,
261        )
262        .bind(limit)
263        .fetch_all(&self.pool)
264        .await?;
265
266        Ok(requests)
267    }
268
269    /// Delete old requests
270    pub async fn delete_older_than(&self, days: i64) -> Result<u64> {
271        let result = sqlx::query(
272            r#"
273            DELETE FROM requests
274            WHERE timestamp < datetime('now', ? || ' days')
275            "#,
276        )
277        .bind(format!("-{}", days))
278        .execute(&self.pool)
279        .await?;
280
281        info!("Deleted {} old requests", result.rows_affected());
282        Ok(result.rows_affected())
283    }
284
285    /// Get database statistics
286    pub async fn get_stats(&self) -> Result<DatabaseStats> {
287        let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
288            .fetch_one(&self.pool)
289            .await?;
290
291        let total_responses: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM responses")
292            .fetch_one(&self.pool)
293            .await?;
294
295        let total_size: i64 =
296            sqlx::query_scalar("SELECT COALESCE(SUM(size_bytes), 0) FROM responses")
297                .fetch_one(&self.pool)
298                .await?;
299
300        Ok(DatabaseStats {
301            total_requests,
302            total_responses,
303            total_size_bytes: total_size,
304        })
305    }
306
307    /// Get detailed statistics for API
308    pub async fn get_statistics(&self) -> Result<DetailedStats> {
309        let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
310            .fetch_one(&self.pool)
311            .await?;
312
313        // Get count by protocol
314        let protocol_rows: Vec<(String, i64)> =
315            sqlx::query_as("SELECT protocol, COUNT(*) as count FROM requests GROUP BY protocol")
316                .fetch_all(&self.pool)
317                .await?;
318
319        let by_protocol: HashMap<String, i64> = protocol_rows.into_iter().collect();
320
321        // Get count by status code
322        let status_rows: Vec<(i32, i64)> = sqlx::query_as(
323            "SELECT status_code, COUNT(*) as count FROM requests WHERE status_code IS NOT NULL GROUP BY status_code"
324        )
325        .fetch_all(&self.pool)
326        .await?;
327
328        let by_status_code: HashMap<i32, i64> = status_rows.into_iter().collect();
329
330        // Get average duration
331        let avg_duration: Option<f64> = sqlx::query_scalar(
332            "SELECT AVG(duration_ms) FROM requests WHERE duration_ms IS NOT NULL",
333        )
334        .fetch_one(&self.pool)
335        .await?;
336
337        Ok(DetailedStats {
338            total_requests,
339            by_protocol,
340            by_status_code,
341            avg_duration_ms: avg_duration,
342        })
343    }
344
345    /// Clear all recordings
346    pub async fn clear_all(&self) -> Result<()> {
347        sqlx::query("DELETE FROM responses").execute(&self.pool).await?;
348        sqlx::query("DELETE FROM requests").execute(&self.pool).await?;
349        info!("Cleared all recordings");
350        Ok(())
351    }
352
353    /// Close the database connection
354    pub async fn close(self) {
355        self.pool.close().await;
356        debug!("Recorder database connection closed");
357    }
358}
359
/// Aggregate size figures for the recorder database, as produced by
/// `RecorderDatabase::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DatabaseStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Number of rows in the `responses` table.
    pub total_responses: i64,
    /// Sum of `size_bytes` over all responses (0 when there are none).
    pub total_size_bytes: i64,
}
367
/// Request breakdown returned by `RecorderDatabase::get_statistics` for the API.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DetailedStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Request count keyed by the stored `protocol` text.
    pub by_protocol: HashMap<String, i64>,
    /// Request count keyed by status code; requests without one are omitted.
    pub by_status_code: HashMap<i32, i64>,
    /// Mean `duration_ms` over requests that have one; `None` if none do.
    pub avg_duration_ms: Option<f64>,
}
376
377// Implement FromRow for RecordedRequest
378impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedRequest {
379    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
380        use sqlx::Row;
381
382        Ok(RecordedRequest {
383            id: row.try_get("id")?,
384            protocol: row.try_get("protocol")?,
385            timestamp: row.try_get("timestamp")?,
386            method: row.try_get("method")?,
387            path: row.try_get("path")?,
388            query_params: row.try_get("query_params")?,
389            headers: row.try_get("headers")?,
390            body: row.try_get("body")?,
391            body_encoding: row.try_get("body_encoding")?,
392            client_ip: row.try_get("client_ip")?,
393            trace_id: row.try_get("trace_id")?,
394            span_id: row.try_get("span_id")?,
395            duration_ms: row.try_get("duration_ms")?,
396            status_code: row.try_get("status_code")?,
397            tags: row.try_get("tags")?,
398        })
399    }
400}
401
402// Implement FromRow for RecordedResponse
403impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedResponse {
404    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
405        use sqlx::Row;
406
407        Ok(RecordedResponse {
408            request_id: row.try_get("request_id")?,
409            status_code: row.try_get("status_code")?,
410            headers: row.try_get("headers")?,
411            body: row.try_get("body")?,
412            body_encoding: row.try_get("body_encoding")?,
413            size_bytes: row.try_get("size_bytes")?,
414            timestamp: row.try_get("timestamp")?,
415        })
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, Duration, Utc};

    /// Build an HTTP `GET /api/test` request fixture with the given id and
    /// timestamp (status 200, 42 ms).
    fn sample_request(id: &str, timestamp: DateTime<Utc>) -> RecordedRequest {
        RecordedRequest {
            id: id.to_string(),
            protocol: Protocol::Http,
            timestamp,
            method: "GET".to_string(),
            path: "/api/test".to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: Some("127.0.0.1".to_string()),
            trace_id: None,
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: None,
        }
    }

    #[tokio::test]
    async fn test_database_creation() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let stats = db.get_stats().await.unwrap();
        assert_eq!(stats.total_requests, 0);
    }

    #[tokio::test]
    async fn test_insert_and_get_request() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        db.insert_request(&sample_request("test-123", Utc::now())).await.unwrap();

        let retrieved = db.get_request("test-123").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().path, "/api/test");
    }

    #[tokio::test]
    async fn test_missing_request_yields_none() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        assert!(db.get_request("missing").await.unwrap().is_none());
        assert!(db.get_exchange("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_list_recent_is_newest_first_and_honours_limit() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let now = Utc::now();

        db.insert_request(&sample_request("older", now - Duration::minutes(5))).await.unwrap();
        db.insert_request(&sample_request("newer", now)).await.unwrap();

        let recent = db.list_recent(10).await.unwrap();
        let ids: Vec<&str> = recent.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["newer", "older"]);
        assert_eq!(db.list_recent(1).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_delete_older_than_keeps_recent_requests() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        db.insert_request(&sample_request("stale", Utc::now() - Duration::days(3))).await.unwrap();
        db.insert_request(&sample_request("fresh", Utc::now())).await.unwrap();

        assert_eq!(db.delete_older_than(1).await.unwrap(), 1);
        assert!(db.get_request("stale").await.unwrap().is_none());
        assert!(db.get_request("fresh").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_statistics_and_clear_all() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();

        db.insert_request(&sample_request("a", Utc::now())).await.unwrap();
        db.insert_request(&sample_request("b", Utc::now())).await.unwrap();

        let detailed = db.get_statistics().await.unwrap();
        assert_eq!(detailed.total_requests, 2);
        assert_eq!(detailed.by_status_code.get(&200), Some(&2));
        assert_eq!(detailed.avg_duration_ms, Some(42.0));

        db.clear_all().await.unwrap();
        let stats = db.get_stats().await.unwrap();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_responses, 0);
    }
}