1use crate::{models::*, Result};
4use sqlx::{sqlite::SqlitePool, Pool, Sqlite};
5use std::{collections::HashMap, path::Path};
6use tracing::{debug, info};
7
8#[derive(Clone)]
10pub struct RecorderDatabase {
11 pool: Pool<Sqlite>,
12}
13
14impl RecorderDatabase {
15 pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
17 let db_url = format!("sqlite:{}?mode=rwc", path.as_ref().display());
18 let pool = SqlitePool::connect(&db_url).await?;
19
20 let db = Self { pool };
21 db.initialize_schema().await?;
22
23 info!("Recorder database initialized at {:?}", path.as_ref());
24 Ok(db)
25 }
26
27 pub async fn new_in_memory() -> Result<Self> {
29 let pool = SqlitePool::connect("sqlite::memory:").await?;
30
31 let db = Self { pool };
32 db.initialize_schema().await?;
33
34 debug!("In-memory recorder database initialized");
35 Ok(db)
36 }
37
38 async fn initialize_schema(&self) -> Result<()> {
40 sqlx::query(
42 r#"
43 CREATE TABLE IF NOT EXISTS requests (
44 id TEXT PRIMARY KEY,
45 protocol TEXT NOT NULL,
46 timestamp TEXT NOT NULL,
47 method TEXT NOT NULL,
48 path TEXT NOT NULL,
49 query_params TEXT,
50 headers TEXT NOT NULL,
51 body TEXT,
52 body_encoding TEXT NOT NULL DEFAULT 'utf8',
53 client_ip TEXT,
54 trace_id TEXT,
55 span_id TEXT,
56 duration_ms INTEGER,
57 status_code INTEGER,
58 tags TEXT,
59 created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
60 )
61 "#,
62 )
63 .execute(&self.pool)
64 .await?;
65
66 sqlx::query(
68 r#"
69 CREATE TABLE IF NOT EXISTS responses (
70 request_id TEXT PRIMARY KEY,
71 status_code INTEGER NOT NULL,
72 headers TEXT NOT NULL,
73 body TEXT,
74 body_encoding TEXT NOT NULL DEFAULT 'utf8',
75 size_bytes INTEGER NOT NULL,
76 timestamp TEXT NOT NULL,
77 created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78 FOREIGN KEY (request_id) REFERENCES requests(id) ON DELETE CASCADE
79 )
80 "#,
81 )
82 .execute(&self.pool)
83 .await?;
84
85 sqlx::query(
87 "CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp DESC)",
88 )
89 .execute(&self.pool)
90 .await?;
91
92 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_protocol ON requests(protocol)")
93 .execute(&self.pool)
94 .await?;
95
96 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_method ON requests(method)")
97 .execute(&self.pool)
98 .await?;
99
100 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_path ON requests(path)")
101 .execute(&self.pool)
102 .await?;
103
104 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_trace_id ON requests(trace_id)")
105 .execute(&self.pool)
106 .await?;
107
108 sqlx::query("CREATE INDEX IF NOT EXISTS idx_requests_status_code ON requests(status_code)")
109 .execute(&self.pool)
110 .await?;
111
112 debug!("Database schema initialized");
113 Ok(())
114 }
115
116 pub async fn insert_request(&self, request: &RecordedRequest) -> Result<()> {
118 sqlx::query(
119 r#"
120 INSERT INTO requests (
121 id, protocol, timestamp, method, path, query_params,
122 headers, body, body_encoding, client_ip, trace_id, span_id,
123 duration_ms, status_code, tags
124 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
125 "#,
126 )
127 .bind(&request.id)
128 .bind(request.protocol)
129 .bind(request.timestamp)
130 .bind(&request.method)
131 .bind(&request.path)
132 .bind(&request.query_params)
133 .bind(&request.headers)
134 .bind(&request.body)
135 .bind(&request.body_encoding)
136 .bind(&request.client_ip)
137 .bind(&request.trace_id)
138 .bind(&request.span_id)
139 .bind(request.duration_ms)
140 .bind(request.status_code)
141 .bind(&request.tags)
142 .execute(&self.pool)
143 .await?;
144
145 debug!("Recorded request: {} {} {}", request.protocol, request.method, request.path);
146 Ok(())
147 }
148
149 pub async fn insert_response(&self, response: &RecordedResponse) -> Result<()> {
151 sqlx::query(
152 r#"
153 INSERT INTO responses (
154 request_id, status_code, headers, body, body_encoding,
155 size_bytes, timestamp
156 ) VALUES (?, ?, ?, ?, ?, ?, ?)
157 "#,
158 )
159 .bind(&response.request_id)
160 .bind(response.status_code)
161 .bind(&response.headers)
162 .bind(&response.body)
163 .bind(&response.body_encoding)
164 .bind(response.size_bytes)
165 .bind(response.timestamp)
166 .execute(&self.pool)
167 .await?;
168
169 debug!("Recorded response for request: {}", response.request_id);
170 Ok(())
171 }
172
173 pub async fn get_request(&self, id: &str) -> Result<Option<RecordedRequest>> {
175 let request = sqlx::query_as::<_, RecordedRequest>(
176 r#"
177 SELECT id, protocol, timestamp, method, path, query_params,
178 headers, body, body_encoding, client_ip, trace_id, span_id,
179 duration_ms, status_code, tags
180 FROM requests WHERE id = ?
181 "#,
182 )
183 .bind(id)
184 .fetch_optional(&self.pool)
185 .await?;
186
187 Ok(request)
188 }
189
190 pub async fn get_response(&self, request_id: &str) -> Result<Option<RecordedResponse>> {
192 let response = sqlx::query_as::<_, RecordedResponse>(
193 r#"
194 SELECT request_id, status_code, headers, body, body_encoding,
195 size_bytes, timestamp
196 FROM responses WHERE request_id = ?
197 "#,
198 )
199 .bind(request_id)
200 .fetch_optional(&self.pool)
201 .await?;
202
203 Ok(response)
204 }
205
206 pub async fn get_exchange(&self, id: &str) -> Result<Option<RecordedExchange>> {
208 let request = self.get_request(id).await?;
209 if let Some(request) = request {
210 let response = self.get_response(id).await?;
211 Ok(Some(RecordedExchange { request, response }))
212 } else {
213 Ok(None)
214 }
215 }
216
217 pub async fn list_recent(&self, limit: i32) -> Result<Vec<RecordedRequest>> {
219 let requests = sqlx::query_as::<_, RecordedRequest>(
220 r#"
221 SELECT id, protocol, timestamp, method, path, query_params,
222 headers, body, body_encoding, client_ip, trace_id, span_id,
223 duration_ms, status_code, tags
224 FROM requests
225 ORDER BY timestamp DESC
226 LIMIT ?
227 "#,
228 )
229 .bind(limit)
230 .fetch_all(&self.pool)
231 .await?;
232
233 Ok(requests)
234 }
235
236 pub async fn delete_older_than(&self, days: i64) -> Result<u64> {
238 let result = sqlx::query(
239 r#"
240 DELETE FROM requests
241 WHERE timestamp < datetime('now', ? || ' days')
242 "#,
243 )
244 .bind(format!("-{}", days))
245 .execute(&self.pool)
246 .await?;
247
248 info!("Deleted {} old requests", result.rows_affected());
249 Ok(result.rows_affected())
250 }
251
252 pub async fn get_stats(&self) -> Result<DatabaseStats> {
254 let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
255 .fetch_one(&self.pool)
256 .await?;
257
258 let total_responses: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM responses")
259 .fetch_one(&self.pool)
260 .await?;
261
262 let total_size: i64 =
263 sqlx::query_scalar("SELECT COALESCE(SUM(size_bytes), 0) FROM responses")
264 .fetch_one(&self.pool)
265 .await?;
266
267 Ok(DatabaseStats {
268 total_requests,
269 total_responses,
270 total_size_bytes: total_size,
271 })
272 }
273
274 pub async fn get_statistics(&self) -> Result<DetailedStats> {
276 let total_requests: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM requests")
277 .fetch_one(&self.pool)
278 .await?;
279
280 let protocol_rows: Vec<(String, i64)> =
282 sqlx::query_as("SELECT protocol, COUNT(*) as count FROM requests GROUP BY protocol")
283 .fetch_all(&self.pool)
284 .await?;
285
286 let by_protocol: HashMap<String, i64> = protocol_rows.into_iter().collect();
287
288 let status_rows: Vec<(i32, i64)> = sqlx::query_as(
290 "SELECT status_code, COUNT(*) as count FROM requests WHERE status_code IS NOT NULL GROUP BY status_code"
291 )
292 .fetch_all(&self.pool)
293 .await?;
294
295 let by_status_code: HashMap<i32, i64> = status_rows.into_iter().collect();
296
297 let avg_duration: Option<f64> = sqlx::query_scalar(
299 "SELECT AVG(duration_ms) FROM requests WHERE duration_ms IS NOT NULL",
300 )
301 .fetch_one(&self.pool)
302 .await?;
303
304 Ok(DetailedStats {
305 total_requests,
306 by_protocol,
307 by_status_code,
308 avg_duration_ms: avg_duration,
309 })
310 }
311
312 pub async fn clear_all(&self) -> Result<()> {
314 sqlx::query("DELETE FROM responses").execute(&self.pool).await?;
315 sqlx::query("DELETE FROM requests").execute(&self.pool).await?;
316 info!("Cleared all recordings");
317 Ok(())
318 }
319
320 pub async fn close(self) {
322 self.pool.close().await;
323 debug!("Recorder database connection closed");
324 }
325}
326
/// Aggregate counters reported by `RecorderDatabase::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatabaseStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Number of rows in the `responses` table.
    pub total_responses: i64,
    /// Sum of `size_bytes` over all stored responses (0 when there are none).
    pub total_size_bytes: i64,
}
334
/// Breakdown of recorded traffic reported by `RecorderDatabase::get_statistics`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DetailedStats {
    /// Number of rows in the `requests` table.
    pub total_requests: i64,
    /// Request count per stored `protocol` value.
    pub by_protocol: HashMap<String, i64>,
    /// Request count per status code; requests without a status code are omitted.
    pub by_status_code: HashMap<i32, i64>,
    /// Mean `duration_ms` over requests that recorded one; `None` when none did.
    pub avg_duration_ms: Option<f64>,
}
343
344impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedRequest {
346 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
347 use sqlx::Row;
348
349 Ok(RecordedRequest {
350 id: row.try_get("id")?,
351 protocol: row.try_get("protocol")?,
352 timestamp: row.try_get("timestamp")?,
353 method: row.try_get("method")?,
354 path: row.try_get("path")?,
355 query_params: row.try_get("query_params")?,
356 headers: row.try_get("headers")?,
357 body: row.try_get("body")?,
358 body_encoding: row.try_get("body_encoding")?,
359 client_ip: row.try_get("client_ip")?,
360 trace_id: row.try_get("trace_id")?,
361 span_id: row.try_get("span_id")?,
362 duration_ms: row.try_get("duration_ms")?,
363 status_code: row.try_get("status_code")?,
364 tags: row.try_get("tags")?,
365 })
366 }
367}
368
369impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for RecordedResponse {
371 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> sqlx::Result<Self> {
372 use sqlx::Row;
373
374 Ok(RecordedResponse {
375 request_id: row.try_get("request_id")?,
376 status_code: row.try_get("status_code")?,
377 headers: row.try_get("headers")?,
378 body: row.try_get("body")?,
379 body_encoding: row.try_get("body_encoding")?,
380 size_bytes: row.try_get("size_bytes")?,
381 timestamp: row.try_get("timestamp")?,
382 })
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, Duration, Utc};

    /// Builds a minimal HTTP `GET /api/test` record with the given id and timestamp.
    fn sample_request(id: &str, timestamp: DateTime<Utc>) -> RecordedRequest {
        RecordedRequest {
            id: id.to_string(),
            protocol: Protocol::Http,
            timestamp,
            method: "GET".to_string(),
            path: "/api/test".to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: Some("127.0.0.1".to_string()),
            trace_id: None,
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: None,
        }
    }

    #[tokio::test]
    async fn test_database_creation() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let stats = db.get_stats().await.unwrap();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_responses, 0);
        assert_eq!(stats.total_size_bytes, 0);
    }

    #[tokio::test]
    async fn test_insert_and_get_request() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("test-123", Utc::now()))
            .await
            .unwrap();

        let retrieved = db
            .get_request("test-123")
            .await
            .unwrap()
            .expect("inserted request should be retrievable");
        assert_eq!(retrieved.path, "/api/test");
        assert_eq!(retrieved.method, "GET");
        assert_eq!(retrieved.client_ip.as_deref(), Some("127.0.0.1"));
        assert_eq!(retrieved.duration_ms, Some(42));
        assert_eq!(retrieved.status_code, Some(200));
    }

    #[tokio::test]
    async fn test_missing_request_returns_none() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        assert!(db.get_request("missing").await.unwrap().is_none());
        assert!(db.get_exchange("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_exchange_without_response() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("req-1", Utc::now()))
            .await
            .unwrap();

        let exchange = db
            .get_exchange("req-1")
            .await
            .unwrap()
            .expect("exchange should exist for a recorded request");
        assert_eq!(exchange.request.id, "req-1");
        assert!(exchange.response.is_none());
    }

    #[tokio::test]
    async fn test_list_recent_respects_limit() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        for id in ["a", "b", "c"] {
            db.insert_request(&sample_request(id, Utc::now()))
                .await
                .unwrap();
        }

        assert_eq!(db.list_recent(2).await.unwrap().len(), 2);
        assert_eq!(db.list_recent(10).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_delete_older_than() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("old", Utc::now() - Duration::days(3)))
            .await
            .unwrap();
        db.insert_request(&sample_request("new", Utc::now()))
            .await
            .unwrap();

        assert_eq!(db.delete_older_than(1).await.unwrap(), 1);
        assert!(db.get_request("old").await.unwrap().is_none());
        assert!(db.get_request("new").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_detailed_statistics() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        for id in ["s1", "s2"] {
            db.insert_request(&sample_request(id, Utc::now()))
                .await
                .unwrap();
        }

        let stats = db.get_statistics().await.unwrap();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.by_protocol.values().sum::<i64>(), 2);
        assert_eq!(stats.by_status_code.get(&200), Some(&2));
        assert_eq!(stats.avg_duration_ms, Some(42.0));
    }

    #[tokio::test]
    async fn test_clear_all() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        db.insert_request(&sample_request("c1", Utc::now()))
            .await
            .unwrap();

        db.clear_all().await.unwrap();

        assert_eq!(db.get_stats().await.unwrap().total_requests, 0);
        assert!(db.get_request("c1").await.unwrap().is_none());
    }
}