1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7
8use crate::error::PostgresHandlerError;
9
10#[derive(Debug)]
11pub struct HttpRequest<TReq> {
12 pub id: i64,
13 pub correlation_id: i64,
14 pub timestamp: DateTime<Utc>,
15 pub method: String,
16 pub uri: String,
17 pub headers: Value,
18 pub body: Option<Result<TReq, Bytes>>,
19 pub created_at: DateTime<Utc>,
20}
21
22#[derive(Debug)]
23pub struct HttpResponse<TRes> {
24 pub id: i64,
25 pub correlation_id: i64,
26 pub timestamp: DateTime<Utc>,
27 pub status_code: i32,
28 pub headers: Value,
29 pub body: Option<Result<TRes, Bytes>>,
30 pub duration_ms: i64,
31 pub created_at: DateTime<Utc>,
32}
33
34#[derive(Debug)]
35pub struct RequestResponsePair<TReq, TRes> {
36 pub request: HttpRequest<TReq>,
37 pub response: Option<HttpResponse<TRes>>,
38}
39
40#[derive(Debug, Default)]
41pub struct RequestFilter {
42 pub correlation_id: Option<i64>,
43 pub method: Option<String>,
44 pub uri_pattern: Option<String>,
45 pub status_code: Option<i32>,
46 pub status_code_min: Option<i32>,
47 pub status_code_max: Option<i32>,
48 pub timestamp_after: Option<DateTime<Utc>>,
49 pub timestamp_before: Option<DateTime<Utc>>,
50 pub min_duration_ms: Option<i64>,
51 pub max_duration_ms: Option<i64>,
52 pub body_parsed: Option<bool>,
53 pub limit: Option<i64>,
54 pub offset: Option<i64>,
55 pub order_by_timestamp_desc: bool,
56}
57
58impl RequestFilter {
59 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
60 let mut query = QueryBuilder::new(
61 r#"
62 SELECT
63 r.id as req_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
64 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
65 res.id as res_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
66 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
67 FROM http_requests r
68 LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id
69 "#,
70 );
71
72 let mut where_added = false;
73
74 if let Some(correlation_id) = self.correlation_id {
75 query.push(" WHERE r.correlation_id = ");
76 query.push_bind(correlation_id);
77 where_added = true;
78 }
79
80 if let Some(method) = &self.method {
81 if where_added {
82 query.push(" AND ");
83 } else {
84 query.push(" WHERE ");
85 where_added = true;
86 }
87 query.push("r.method = ");
88 query.push_bind(method);
89 }
90
91 if let Some(uri_pattern) = &self.uri_pattern {
92 if where_added {
93 query.push(" AND ");
94 } else {
95 query.push(" WHERE ");
96 where_added = true;
97 }
98 query.push("r.uri ILIKE ");
99 query.push_bind(uri_pattern);
100 }
101
102 if let Some(status_code) = self.status_code {
103 if where_added {
104 query.push(" AND ");
105 } else {
106 query.push(" WHERE ");
107 where_added = true;
108 }
109 query.push("res.status_code = ");
110 query.push_bind(status_code);
111 }
112
113 if let Some(min_status) = self.status_code_min {
114 if where_added {
115 query.push(" AND ");
116 } else {
117 query.push(" WHERE ");
118 where_added = true;
119 }
120 query.push("res.status_code >= ");
121 query.push_bind(min_status);
122 }
123
124 if let Some(max_status) = self.status_code_max {
125 if where_added {
126 query.push(" AND ");
127 } else {
128 query.push(" WHERE ");
129 where_added = true;
130 }
131 query.push("res.status_code <= ");
132 query.push_bind(max_status);
133 }
134
135 if let Some(timestamp_after) = self.timestamp_after {
136 if where_added {
137 query.push(" AND ");
138 } else {
139 query.push(" WHERE ");
140 where_added = true;
141 }
142 query.push("r.timestamp >= ");
143 query.push_bind(timestamp_after);
144 }
145
146 if let Some(timestamp_before) = self.timestamp_before {
147 if where_added {
148 query.push(" AND ");
149 } else {
150 query.push(" WHERE ");
151 where_added = true;
152 }
153 query.push("r.timestamp <= ");
154 query.push_bind(timestamp_before);
155 }
156
157 if let Some(min_duration) = self.min_duration_ms {
158 if where_added {
159 query.push(" AND ");
160 } else {
161 query.push(" WHERE ");
162 where_added = true;
163 }
164 query.push("res.duration_ms >= ");
165 query.push_bind(min_duration);
166 }
167
168 if let Some(max_duration) = self.max_duration_ms {
169 if where_added {
170 query.push(" AND ");
171 } else {
172 query.push(" WHERE ");
173 }
174 query.push("res.duration_ms <= ");
175 query.push_bind(max_duration);
176 }
177
178 if self.order_by_timestamp_desc {
179 query.push(" ORDER BY r.timestamp DESC");
180 } else {
181 query.push(" ORDER BY r.timestamp ASC");
182 }
183
184 if let Some(limit) = self.limit {
185 query.push(" LIMIT ");
186 query.push_bind(limit);
187 }
188
189 if let Some(offset) = self.offset {
190 query.push(" OFFSET ");
191 query.push_bind(offset);
192 }
193
194 query
195 }
196}
197
198#[derive(Clone)]
199pub struct RequestRepository<TReq, TRes> {
200 pool: PgPool,
201 _phantom_req: std::marker::PhantomData<TReq>,
202 _phantom_res: std::marker::PhantomData<TRes>,
203}
204
205impl<TReq, TRes> RequestRepository<TReq, TRes>
206where
207 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
208 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
209{
210 pub fn new(pool: PgPool) -> Self {
211 Self {
212 pool,
213 _phantom_req: std::marker::PhantomData,
214 _phantom_res: std::marker::PhantomData,
215 }
216 }
217
218 pub async fn query(
219 &self,
220 filter: RequestFilter,
221 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
222 let rows = filter
223 .build_query()
224 .build()
225 .fetch_all(&self.pool)
226 .await
227 .map_err(PostgresHandlerError::Query)?;
228
229 let mut pairs = Vec::new();
230 for row in rows {
231 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
232 let req_body_parsed = row
233 .try_get::<Option<bool>, _>("req_body_parsed")
234 .unwrap_or(Some(false));
235
236 let request_body = match req_body {
237 Some(json_value) => {
238 if req_body_parsed == Some(true) {
239 Some(Ok(serde_json::from_value::<TReq>(json_value)
241 .map_err(PostgresHandlerError::Json)?))
242 } else {
243 if let Value::String(utf8_str) = json_value {
245 Some(Err(Bytes::from(utf8_str.into_bytes())))
246 } else {
247 return Err(PostgresHandlerError::Json(Error::custom(
248 "Invalid body format",
249 )));
250 }
251 }
252 }
253 None => None,
254 };
255
256 let request = HttpRequest {
257 id: row.get("req_id"),
258 correlation_id: row.get("req_correlation_id"),
259 timestamp: row.get("req_timestamp"),
260 method: row.get("method"),
261 uri: row.get("uri"),
262 headers: row.get("req_headers"),
263 body: request_body,
264 created_at: row.get("req_created_at"),
265 };
266
267 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
268 res_id
269 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
270 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
271 let res_body_parsed = row
272 .try_get::<Option<bool>, _>("res_body_parsed")
273 .unwrap_or(Some(false));
274
275 let response_body = match res_body {
276 Some(json_value) => {
277 if res_body_parsed == Some(true) {
278 Some(Ok(serde_json::from_value::<TRes>(json_value)
280 .map_err(PostgresHandlerError::Json)?))
281 } else {
282 if let Value::String(utf8_str) = json_value {
284 Some(Err(Bytes::from(utf8_str.into_bytes())))
285 } else {
286 return Err(PostgresHandlerError::Json(Error::custom(
287 "Invalid body format",
288 )));
289 }
290 }
291 }
292 None => None,
293 };
294
295 Ok(HttpResponse {
296 id: row.get("res_id"),
297 correlation_id: row.get("res_correlation_id"),
298 timestamp: row.get("res_timestamp"),
299 status_code: row.get("status_code"),
300 headers: row.get("res_headers"),
301 body: response_body,
302 duration_ms: row.get("duration_ms"),
303 created_at: row.get("res_created_at"),
304 })
305 })
306 .transpose()?
307 } else {
308 None
309 };
310
311 pairs.push(RequestResponsePair { request, response });
312 }
313
314 Ok(pairs)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Checks that `sql` parses under the PostgreSQL dialect.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    /// Builds the SQL for `filter`, asserts that it parses, and returns it as an owned string.
    fn rendered_sql(filter: &RequestFilter) -> String {
        let query = filter.build_query();
        let sql = query.sql().to_owned();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Parses an RFC 3339 literal into a UTC timestamp.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = rendered_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = rendered_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = rendered_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = rendered_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = rendered_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = rendered_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = rendered_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    /// `max_duration_ms` is the last predicate; on its own it must still open the WHERE clause.
    #[test]
    fn test_max_duration_alone_opens_where_clause() {
        let sql = rendered_sql(&RequestFilter {
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let sql = rendered_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = rendered_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = rendered_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    /// LIMIT/OFFSET placeholders are numbered after all filter placeholders.
    #[test]
    fn test_pagination_binds_follow_filter_binds() {
        let sql = rendered_sql(&RequestFilter {
            status_code: Some(500),
            limit: Some(5),
            offset: Some(10),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
        assert!(sql.contains("LIMIT $2"));
        assert!(sql.contains("OFFSET $3"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = rendered_sql(&RequestFilter::default());

        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id")
        );

        assert!(!sql.contains("WHERE"));

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}