1use crate::{models::*, Result};
4use sqlx::{sqlite::SqlitePool, Pool, Sqlite};
5use std::{collections::HashMap, path::Path};
6use tracing::{debug, info};
7
8#[derive(Clone)]
10pub struct RecorderDatabase {
11 pool: Pool<Sqlite>,
12}
13
14impl RecorderDatabase {
15 pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
17 let db_url = format!("sqlite:{}?mode=rwc", path.as_ref().display());
18 let pool = SqlitePool::connect(&db_url).await?;
19
20 let db = Self { pool };
21 db.initialize_schema().await?;
22
23 info!("Recorder database initialized at {:?}", path.as_ref());
24 Ok(db)
25 }
26
27 pub async fn new_in_memory() -> Result<Self> {
29 let pool = SqlitePool::connect("sqlite::memory:").await?;
30
31 let db = Self { pool };
32 db.initialize_schema().await?;
33
34 debug!("In-memory recorder database initialized");
35 Ok(db)
36 }
37
38 async fn initialize_schema(&self) -> Result<()> {
40 sqlx::query(
42 r#"
43 CREATE TABLE IF NOT EXISTS requests (
44 id TEXT PRIMARY KEY,
45 protocol TEXT NOT NULL,
46 timestamp TEXT NOT NULL,
47 method TEXT NOT NULL,
48 path TEXT NOT NULL,
49 query_params TEXT,
50 headers TEXT NOT NULL,
51 body TEXT,
52 body_encoding TEXT NOT NULL DEFAULT 'utf8',
53 client_ip TEXT,
54 trace_id TEXT,
55 span_id TEXT,
56 duration_ms INTEGER,
57 status_code INTEGER,
58 tags TEXT,
59 created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
60 )
61 "#,
62 )
63 .execute(&self.pool)
64 .await?;
65
66 sqlx::query(
68 r#"
69 CREATE TABLE IF NOT EXISTS responses (
70 request_id TEXT PRIMARY KEY,
71 status_code INTEGER NOT NULL,
72 headers TEXT NOT NULL,
73 body TEXT,
74 body_encoding TEXT NOT NULL DEFAULT 'utf8',
75 size_bytes INTEGER NOT NULL,
76 timestamp TEXT NOT NULL,
77 created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78 FOREIGN KEY (request_id) REFERENCES requests(id) ON DELETE CASCADE
79 )
80 "#,
81 )
82 .execute(&self.pool)
83 .await?;
84
85 sqlx::query(
87 "CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC)",
88 )
89 .execute(&self.pool)
90 .await?;
91
92 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_protocol ON requests(protocol)")
93 .execute(&self.pool)
94 .await?;
95
96 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_method ON requests(method)")
97 .execute(&self.pool)
98 .await?;
99
100 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_path ON requests(path)")
101 .execute(&self.pool)
102 .await?;
103
104 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_trace_id ON requests(trace_id)")
105 .execute(&self.pool)
106 .await?;
107
108 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_status_code ON requests(status_code)")
109 .execute(&self.pool)
110 .await?;
111
112 debug!("Database schema initialized");
113 Ok(())
114 }
115
116 pub async fn insert_request(&self, request: &RecordedRequest) -> Result<()> {
118 sqlx::query(
119 r#"
120 INSERT INTO requests (
121 id, protocol, timestamp, method, path, query_params,
122 headers, body, body_encoding, client_ip, trace_id, span_id,
123 duration_ms, status_code, tags
124 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
125 "#,
126 )
127 .bind(&request.id)
128 .bind(request.protocol)
129 .bind(request.timestamp)
130 .bind(&request.method)
131 .bind(&request.path)
132 .bind(&request.query_params)
133 .bind(&request.headers)
134 .bind(&request.body)
135 .bind(&request.body_encoding)
136 .bind(&request.client_ip)
137 .bind(&request.trace_id)
138 .bind(&request.span_id)
139 .bind(request.duration_ms)
140 .bind(request.status_code)
141 .bind(&request.tags)
142 .execute(&self.pool)
143 .await?;
144
145 debug!("Recorded request: {} {} {}", request.protocol, request.method, request.path);
146 Ok(())
147 }
148
149 pub async fn insert_response(&self, response: &RecordedResponse) -> Result<()> {
151 sqlx::query(
152 r#"
153 INSERT INTO responses (
154 request_id, status_code, headers, body, body_encoding,
155 size_bytes, timestamp
156 ) VALUES (?, ?, ?, ?, ?, ?, ?)
157 "#,
158 )
159 .bind(&response.request_id)
160 .bind(response.status_code)
161 .bind(&response.headers)
162 .bind(&response.body)
163 .bind(&response.body_encoding)
164 .bind(response.size_bytes)
165 .bind(response.timestamp)
166 .execute(&self.pool)
167 .await?;
168
169 debug!("Recorded response for request: {}", response.request_id);
170 Ok(())
171 }
172
173 pub async fn get_request(&self, id: &str) -> Result<Option<RecordedRequest>> {
175 let request = sqlx::query_as::<_, RecordedRequest>(
176 r#"
177 SELECT id, protocol, timestamp, method, path, query_params,
178 headers, body, body_encoding, client_ip, trace_id, span_id,
179 duration_ms, status_code, tags
180 FROM requests WHERE id = ?
181 "#,
182 )
183 .bind(id)
184 .fetch_optional(&self.pool)
185 .await?;
186
187 Ok(request)
188 }
189
190 pub async fn get_response(&self, request_id: &str) -> Result<Option<RecordedResponse>> {
192 let response = sqlx::query_as::<_, RecordedResponse>(
193 r#"
194 SELECT request_id, status_code, headers, body, body_encoding,
195 size_bytes, timestamp
196 FROM responses WHERE request_id = ?
197 "#,
198 )
199 .bind(request_id)
200 .fetch_optional(&self.pool)
201 .await?;
202
203 Ok(response)
204 }
205
206 pub async fn get_exchange(&self, id: &str) -> Result<Option<RecordedExchange>> {
208 let request = self.get_request(id).await?;
209 if let Some(request) = request {
210 let response = self.get_response(id).await?;
211 Ok(Some(RecordedExchange { request, response }))
212 } else {
213 Ok(None)
214 }
215 }
216
217 pub async fn update_response(
219 &self,
220 request_id: &str,
221 status_code: i32,
222 headers: &str,
223 body: &str,
224 size_bytes: i64,
225 ) -> Result<()> {
226 sqlx::query(
227 r#"
228 UPDATE responses
229 SET status_code = ?,
230 headers = ?,
231 body = ?,
232 body_encoding = 'base64',
233 size_bytes = ?,
234 timestamp = datetime('now')
235 WHERE request_id = ?
236 "#,
237 )
238 .bind(status_code)
239 .bind(headers)
240 .bind(body)
241 .bind(size_bytes)
242 .bind(request_id)
243 .execute(&self.pool)
244 .await?;
245
246 debug!("Updated response for request {}", request_id);
247 Ok(())
248 }
249
250 pub async fn list_recent(&self, limit: i32) -> Result<Vec<RecordedRequest>> {
252 let requests = sqlx::query_as::<_, RecordedRequest>(
253 r#"
254 SELECT id, protocol, timestamp, method, path, query_params,
255 headers, body, body_encoding, client_ip, trace_id, span_id,
256 duration_ms, status_code, tags
257 FROM requests
258 ORDER BY timestamp DESC
259 LIMIT ?
260 "#,
261 )
262 .bind(limit)
263 .fetch_all(&self.pool)
264 .await?;
265
266 Ok(requests)
267 }
268
269 pub async fn delete_older_than(&self, days: i64) -> Result<u64> {
271 let result = sqlx::query(
272 r#"
273 DELETE FROM requests
274 WHERE timestamp < datetime('now', ? || ' days')
275 "#,
276 )
277 .bind(format!("-{}", days))
278 .execute(&self.pool)
279 .await?;
280
281 info!("Deleted {} old requests", result.rows_affected());
282 Ok(result.rows_affected())
283 }
284
285 pub async fn get_stats(&self) -> Result<DatabaseStats> {
287 let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
288 .fetch_one(&self.pool)
289 .await?;
290
291 let total_responses: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM responses")
292 .fetch_one(&self.pool)
293 .await?;
294
295 let total_size: i64 =
296 sqlx::query_scalar("SELECT COALESCE(SUM(size_bytes), 0) FROM responses")
297 .fetch_one(&self.pool)
298 .await?;
299
300 Ok(DatabaseStats {
301 total_requests,
302 total_responses,
303 total_size_bytes: total_size,
304 })
305 }
306
307 pub async fn get_statistics(&self) -> Result<DetailedStats> {
309 let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
310 .fetch_one(&self.pool)
311 .await?;
312
313 let protocol_rows: Vec<(String, i64)> =
315 sqlx::query_as("SELECT protocol, COUNT(*) as count FROM requests GROUP BY protocol")
316 .fetch_all(&self.pool)
317 .await?;
318
319 let by_protocol: HashMap<String, i64> = protocol_rows.into_iter().collect();
320
321 let status_rows: Vec<(i32, i64)> = sqlx::query_as(
323 "SELECT status_code, COUNT(*) as count FROM requests WHERE status_code IS NOT NULL GROUP BY status_code"
324 )
325 .fetch_all(&self.pool)
326 .await?;
327
328 let by_status_code: HashMap<i32, i64> = status_rows.into_iter().collect();
329
330 let avg_duration: Option<f64> = sqlx::query_scalar(
332 "SELECT AVG(duration_ms) FROM requests WHERE duration_ms IS NOT NULL",
333 )
334 .fetch_one(&self.pool)
335 .await?;
336
337 Ok(DetailedStats {
338 total_requests,
339 by_protocol,
340 by_status_code,
341 avg_duration_ms: avg_duration,
342 })
343 }
344
345 pub async fn clear_all(&self) -> Result<()> {
347 sqlx::query("DELETE FROM responses").execute(&self.pool).await?;
348 sqlx::query("DELETE FROM requests").execute(&self.pool).await?;
349 info!("Cleared all recordings");
350 Ok(())
351 }
352
353 pub async fn close(self) {
355 self.pool.close().await;
356 debug!("Recorder database connection closed");
357 }
358}
359
/// Aggregate counters over the whole recorder database, as returned by
/// `RecorderDatabase::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatabaseStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Number of rows in the `responses` table.
    pub total_responses: i64,
    /// Sum of `size_bytes` over all stored responses (0 when there are none).
    pub total_size_bytes: i64,
}
367
/// Per-dimension breakdown of recorded traffic, as returned by
/// `RecorderDatabase::get_statistics`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DetailedStats {
    /// Total number of recorded requests.
    pub total_requests: i64,
    /// Request count keyed by the stored `protocol` value.
    pub by_protocol: HashMap<String, i64>,
    /// Request count keyed by status code; requests without a status code are omitted.
    pub by_status_code: HashMap<i32, i64>,
    /// Mean `duration_ms` over requests that have one; `None` if none do.
    pub avg_duration_ms: Option<f64>,
}
376
377impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedRequest {
379 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
380 use sqlx::Row;
381
382 Ok(RecordedRequest {
383 id: row.try_get("id")?,
384 protocol: row.try_get("protocol")?,
385 timestamp: row.try_get("timestamp")?,
386 method: row.try_get("method")?,
387 path: row.try_get("path")?,
388 query_params: row.try_get("query_params")?,
389 headers: row.try_get("headers")?,
390 body: row.try_get("body")?,
391 body_encoding: row.try_get("body_encoding")?,
392 client_ip: row.try_get("client_ip")?,
393 trace_id: row.try_get("trace_id")?,
394 span_id: row.try_get("span_id")?,
395 duration_ms: row.try_get("duration_ms")?,
396 status_code: row.try_get("status_code")?,
397 tags: row.try_get("tags")?,
398 })
399 }
400}
401
402impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedResponse {
404 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
405 use sqlx::Row;
406
407 Ok(RecordedResponse {
408 request_id: row.try_get("request_id")?,
409 status_code: row.try_get("status_code")?,
410 headers: row.try_get("headers")?,
411 body: row.try_get("body")?,
412 body_encoding: row.try_get("body_encoding")?,
413 size_bytes: row.try_get("size_bytes")?,
414 timestamp: row.try_get("timestamp")?,
415 })
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a minimal HTTP GET request fixture with the given id and path.
    fn sample_request(id: &str, path: &str) -> RecordedRequest {
        RecordedRequest {
            id: id.to_string(),
            protocol: Protocol::Http,
            timestamp: Utc::now(),
            method: "GET".to_string(),
            path: path.to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: Some("127.0.0.1".to_string()),
            trace_id: None,
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: None,
        }
    }

    #[tokio::test]
    async fn test_database_creation() {
        let db = RecorderDatabase::new_in_memory()
            .await
            .expect("in-memory database should open");
        let stats = db.get_stats().await.expect("stats query should succeed");
        assert_eq!(stats.total_requests, 0);
    }

    #[tokio::test]
    async fn test_insert_and_get_request() {
        let db = RecorderDatabase::new_in_memory()
            .await
            .expect("in-memory database should open");

        let fixture = sample_request("test-123", "/api/test");
        db.insert_request(&fixture)
            .await
            .expect("insert should succeed");

        let stored = db
            .get_request("test-123")
            .await
            .expect("lookup should succeed")
            .expect("inserted request should be found");
        assert_eq!(stored.path, "/api/test");
    }
}