1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7use uuid::Uuid;
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13 pub id: i64,
14 pub instance_id: Uuid,
15 pub correlation_id: i64,
16 pub timestamp: DateTime<Utc>,
17 pub method: String,
18 pub uri: String,
19 pub headers: Value,
20 pub body: Option<Result<TReq, Bytes>>,
21 pub created_at: DateTime<Utc>,
22}
23
24#[derive(Debug)]
25pub struct HttpResponse<TRes> {
26 pub id: i64,
27 pub instance_id: Uuid,
28 pub correlation_id: i64,
29 pub timestamp: DateTime<Utc>,
30 pub status_code: i32,
31 pub headers: Value,
32 pub body: Option<Result<TRes, Bytes>>,
33 pub duration_to_first_byte_ms: i64,
34 pub duration_ms: i64,
35 pub created_at: DateTime<Utc>,
36}
37
38#[derive(Debug)]
39pub struct RequestResponsePair<TReq, TRes> {
40 pub request: HttpRequest<TReq>,
41 pub response: Option<HttpResponse<TRes>>,
42}
43
44#[derive(Debug, Default)]
45pub struct RequestFilter {
46 pub instance_id: Option<Uuid>,
47 pub correlation_id: Option<i64>,
48 pub method: Option<String>,
49 pub uri_pattern: Option<String>,
50 pub status_code: Option<i32>,
51 pub status_code_min: Option<i32>,
52 pub status_code_max: Option<i32>,
53 pub timestamp_after: Option<DateTime<Utc>>,
54 pub timestamp_before: Option<DateTime<Utc>>,
55 pub min_duration_ms: Option<i64>,
56 pub max_duration_ms: Option<i64>,
57 pub min_duration_to_first_byte_ms: Option<i64>,
58 pub max_duration_to_first_byte_ms: Option<i64>,
59 pub body_parsed: Option<bool>,
60 pub limit: Option<i64>,
61 pub offset: Option<i64>,
62 pub order_by_timestamp_desc: bool,
63}
64
65impl RequestFilter {
66 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
67 let mut query = QueryBuilder::new(
68 r#"
69 SELECT
70 r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
71 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
72 res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
73 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
74 FROM http_requests r
75 LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
76 "#,
77 );
78
79 let mut where_added = false;
80
81 if let Some(instance_id) = self.instance_id {
82 query.push(" WHERE r.instance_id = ");
83 query.push_bind(instance_id);
84 where_added = true;
85 }
86
87 if let Some(correlation_id) = self.correlation_id {
88 if where_added {
89 query.push(" AND ");
90 } else {
91 query.push(" WHERE ");
92 where_added = true;
93 }
94 query.push("r.correlation_id = ");
95 query.push_bind(correlation_id);
96 }
97
98 if let Some(method) = &self.method {
99 if where_added {
100 query.push(" AND ");
101 } else {
102 query.push(" WHERE ");
103 where_added = true;
104 }
105 query.push("r.method = ");
106 query.push_bind(method);
107 }
108
109 if let Some(uri_pattern) = &self.uri_pattern {
110 if where_added {
111 query.push(" AND ");
112 } else {
113 query.push(" WHERE ");
114 where_added = true;
115 }
116 query.push("r.uri ILIKE ");
117 query.push_bind(uri_pattern);
118 }
119
120 if let Some(status_code) = self.status_code {
121 if where_added {
122 query.push(" AND ");
123 } else {
124 query.push(" WHERE ");
125 where_added = true;
126 }
127 query.push("res.status_code = ");
128 query.push_bind(status_code);
129 }
130
131 if let Some(min_status) = self.status_code_min {
132 if where_added {
133 query.push(" AND ");
134 } else {
135 query.push(" WHERE ");
136 where_added = true;
137 }
138 query.push("res.status_code >= ");
139 query.push_bind(min_status);
140 }
141
142 if let Some(max_status) = self.status_code_max {
143 if where_added {
144 query.push(" AND ");
145 } else {
146 query.push(" WHERE ");
147 where_added = true;
148 }
149 query.push("res.status_code <= ");
150 query.push_bind(max_status);
151 }
152
153 if let Some(timestamp_after) = self.timestamp_after {
154 if where_added {
155 query.push(" AND ");
156 } else {
157 query.push(" WHERE ");
158 where_added = true;
159 }
160 query.push("r.timestamp >= ");
161 query.push_bind(timestamp_after);
162 }
163
164 if let Some(timestamp_before) = self.timestamp_before {
165 if where_added {
166 query.push(" AND ");
167 } else {
168 query.push(" WHERE ");
169 where_added = true;
170 }
171 query.push("r.timestamp <= ");
172 query.push_bind(timestamp_before);
173 }
174
175 if let Some(min_duration) = self.min_duration_ms {
176 if where_added {
177 query.push(" AND ");
178 } else {
179 query.push(" WHERE ");
180 where_added = true;
181 }
182 query.push("res.duration_ms >= ");
183 query.push_bind(min_duration);
184 }
185
186 if let Some(max_duration) = self.max_duration_ms {
187 if where_added {
188 query.push(" AND ");
189 } else {
190 query.push(" WHERE ");
191 where_added = true;
192 }
193 query.push("res.duration_ms <= ");
194 query.push_bind(max_duration);
195 }
196
197 if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
198 if where_added {
199 query.push(" AND ");
200 } else {
201 query.push(" WHERE ");
202 where_added = true;
203 }
204 query.push("res.duration_to_first_byte_ms >= ");
205 query.push_bind(min_duration_to_first_byte);
206 }
207
208 if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
209 if where_added {
210 query.push(" AND ");
211 } else {
212 query.push(" WHERE ");
213 }
214 query.push("res.duration_to_first_byte_ms <= ");
215 query.push_bind(max_duration_to_first_byte);
216 }
217
218 if self.order_by_timestamp_desc {
219 query.push(" ORDER BY r.timestamp DESC");
220 } else {
221 query.push(" ORDER BY r.timestamp ASC");
222 }
223
224 if let Some(limit) = self.limit {
225 query.push(" LIMIT ");
226 query.push_bind(limit);
227 }
228
229 if let Some(offset) = self.offset {
230 query.push(" OFFSET ");
231 query.push_bind(offset);
232 }
233
234 query
235 }
236}
237
238#[derive(Clone)]
239pub struct RequestRepository<TReq, TRes> {
240 pool: PgPool,
241 _phantom_req: std::marker::PhantomData<TReq>,
242 _phantom_res: std::marker::PhantomData<TRes>,
243}
244
245impl<TReq, TRes> RequestRepository<TReq, TRes>
246where
247 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
248 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
249{
250 pub fn new(pool: PgPool) -> Self {
251 Self {
252 pool,
253 _phantom_req: std::marker::PhantomData,
254 _phantom_res: std::marker::PhantomData,
255 }
256 }
257
258 pub async fn query(
259 &self,
260 filter: RequestFilter,
261 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
262 let rows = filter
263 .build_query()
264 .build()
265 .fetch_all(&self.pool)
266 .await
267 .map_err(PostgresHandlerError::Query)?;
268
269 let mut pairs = Vec::new();
270 for row in rows {
271 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
272 let req_body_parsed = row
273 .try_get::<Option<bool>, _>("req_body_parsed")
274 .unwrap_or(Some(false));
275
276 let request_body = match req_body {
277 Some(json_value) => {
278 if req_body_parsed == Some(true) {
279 Some(Ok(serde_json::from_value::<TReq>(json_value)
281 .map_err(PostgresHandlerError::Json)?))
282 } else {
283 if let Value::String(utf8_str) = json_value {
285 Some(Err(Bytes::from(utf8_str.into_bytes())))
286 } else {
287 return Err(PostgresHandlerError::Json(Error::custom(
288 "Invalid body format",
289 )));
290 }
291 }
292 }
293 None => None,
294 };
295
296 let request = HttpRequest {
297 id: row.get("req_id"),
298 instance_id: row.get("req_instance_id"),
299 correlation_id: row.get("req_correlation_id"),
300 timestamp: row.get("req_timestamp"),
301 method: row.get("method"),
302 uri: row.get("uri"),
303 headers: row.get("req_headers"),
304 body: request_body,
305 created_at: row.get("req_created_at"),
306 };
307
308 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
309 res_id
310 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
311 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
312 let res_body_parsed = row
313 .try_get::<Option<bool>, _>("res_body_parsed")
314 .unwrap_or(Some(false));
315
316 let response_body = match res_body {
317 Some(json_value) => {
318 if res_body_parsed == Some(true) {
319 Some(Ok(serde_json::from_value::<TRes>(json_value)
321 .map_err(PostgresHandlerError::Json)?))
322 } else {
323 if let Value::String(utf8_str) = json_value {
325 Some(Err(Bytes::from(utf8_str.into_bytes())))
326 } else {
327 return Err(PostgresHandlerError::Json(Error::custom(
328 "Invalid body format",
329 )));
330 }
331 }
332 }
333 None => None,
334 };
335
336 Ok(HttpResponse {
337 id: row.get("res_id"),
338 instance_id: row.get("res_instance_id"),
339 correlation_id: row.get("res_correlation_id"),
340 timestamp: row.get("res_timestamp"),
341 status_code: row.get("status_code"),
342 headers: row.get("res_headers"),
343 body: response_body,
344 duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
345 duration_ms: row.get("duration_ms"),
346 created_at: row.get("res_created_at"),
347 })
348 })
349 .transpose()?
350 } else {
351 None
352 };
353
354 pairs.push(RequestResponsePair { request, response });
355 }
356
357 Ok(pairs)
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect, describing any parse failure.
    fn validate_sql(sql: &str) -> Result<(), String> {
        match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
            Ok(_) => Ok(()),
            Err(e) => Err(format!("SQL parse error: {e}")),
        }
    }

    /// Renders `filter` to SQL text, panicking unless it parses as PostgreSQL.
    fn render(filter: &RequestFilter) -> String {
        let sql = filter.build_query().sql().to_owned();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Parses an RFC 3339 timestamp and converts it to UTC.
    fn utc(text: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(text)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = render(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = render(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = render(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = render(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = render(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = render(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = render(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = render(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_ordering_desc() {
        let sql = render(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = render(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = render(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = render(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        for fragment in [
            "WHERE r.correlation_id = $1",
            "AND r.method = $2",
            "AND res.status_code = $3",
        ] {
            assert!(sql.contains(fragment));
        }

        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = render(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        for fragment in [
            "WHERE r.correlation_id = $1",
            "AND r.method = $2",
            "AND r.uri ILIKE $3",
            "AND res.status_code >= $4",
            "AND res.status_code <= $5",
            "AND r.timestamp >= $6",
            "AND res.duration_ms >= $7",
            "AND res.duration_ms <= $8",
            "ORDER BY r.timestamp DESC",
            "LIMIT $9",
            "OFFSET $10",
        ] {
            assert!(sql.contains(fragment));
        }

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = render(&RequestFilter::default());

        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(sql.contains(
            "LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)"
        ));

        assert!(!sql.contains("WHERE"));

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}