1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7use uuid::Uuid;
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13 pub id: i64,
14 pub instance_id: Uuid,
15 pub correlation_id: i64,
16 pub timestamp: DateTime<Utc>,
17 pub method: String,
18 pub uri: String,
19 pub headers: Value,
20 pub body: Option<Result<TReq, Bytes>>,
21 pub created_at: DateTime<Utc>,
22}
23
24#[derive(Debug)]
25pub struct HttpResponse<TRes> {
26 pub id: i64,
27 pub instance_id: Uuid,
28 pub correlation_id: i64,
29 pub timestamp: DateTime<Utc>,
30 pub status_code: i32,
31 pub headers: Value,
32 pub body: Option<Result<TRes, Bytes>>,
33 pub duration_ms: i64,
34 pub created_at: DateTime<Utc>,
35}
36
37#[derive(Debug)]
38pub struct RequestResponsePair<TReq, TRes> {
39 pub request: HttpRequest<TReq>,
40 pub response: Option<HttpResponse<TRes>>,
41}
42
43#[derive(Debug, Default)]
44pub struct RequestFilter {
45 pub instance_id: Option<Uuid>,
46 pub correlation_id: Option<i64>,
47 pub method: Option<String>,
48 pub uri_pattern: Option<String>,
49 pub status_code: Option<i32>,
50 pub status_code_min: Option<i32>,
51 pub status_code_max: Option<i32>,
52 pub timestamp_after: Option<DateTime<Utc>>,
53 pub timestamp_before: Option<DateTime<Utc>>,
54 pub min_duration_ms: Option<i64>,
55 pub max_duration_ms: Option<i64>,
56 pub body_parsed: Option<bool>,
57 pub limit: Option<i64>,
58 pub offset: Option<i64>,
59 pub order_by_timestamp_desc: bool,
60}
61
62impl RequestFilter {
63 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
64 let mut query = QueryBuilder::new(
65 r#"
66 SELECT
67 r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
68 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
69 res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
70 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
71 FROM http_requests r
72 LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
73 "#,
74 );
75
76 let mut where_added = false;
77
78 if let Some(instance_id) = self.instance_id {
79 query.push(" WHERE r.instance_id = ");
80 query.push_bind(instance_id);
81 where_added = true;
82 }
83
84 if let Some(correlation_id) = self.correlation_id {
85 if where_added {
86 query.push(" AND ");
87 } else {
88 query.push(" WHERE ");
89 where_added = true;
90 }
91 query.push("r.correlation_id = ");
92 query.push_bind(correlation_id);
93 }
94
95 if let Some(method) = &self.method {
96 if where_added {
97 query.push(" AND ");
98 } else {
99 query.push(" WHERE ");
100 where_added = true;
101 }
102 query.push("r.method = ");
103 query.push_bind(method);
104 }
105
106 if let Some(uri_pattern) = &self.uri_pattern {
107 if where_added {
108 query.push(" AND ");
109 } else {
110 query.push(" WHERE ");
111 where_added = true;
112 }
113 query.push("r.uri ILIKE ");
114 query.push_bind(uri_pattern);
115 }
116
117 if let Some(status_code) = self.status_code {
118 if where_added {
119 query.push(" AND ");
120 } else {
121 query.push(" WHERE ");
122 where_added = true;
123 }
124 query.push("res.status_code = ");
125 query.push_bind(status_code);
126 }
127
128 if let Some(min_status) = self.status_code_min {
129 if where_added {
130 query.push(" AND ");
131 } else {
132 query.push(" WHERE ");
133 where_added = true;
134 }
135 query.push("res.status_code >= ");
136 query.push_bind(min_status);
137 }
138
139 if let Some(max_status) = self.status_code_max {
140 if where_added {
141 query.push(" AND ");
142 } else {
143 query.push(" WHERE ");
144 where_added = true;
145 }
146 query.push("res.status_code <= ");
147 query.push_bind(max_status);
148 }
149
150 if let Some(timestamp_after) = self.timestamp_after {
151 if where_added {
152 query.push(" AND ");
153 } else {
154 query.push(" WHERE ");
155 where_added = true;
156 }
157 query.push("r.timestamp >= ");
158 query.push_bind(timestamp_after);
159 }
160
161 if let Some(timestamp_before) = self.timestamp_before {
162 if where_added {
163 query.push(" AND ");
164 } else {
165 query.push(" WHERE ");
166 where_added = true;
167 }
168 query.push("r.timestamp <= ");
169 query.push_bind(timestamp_before);
170 }
171
172 if let Some(min_duration) = self.min_duration_ms {
173 if where_added {
174 query.push(" AND ");
175 } else {
176 query.push(" WHERE ");
177 where_added = true;
178 }
179 query.push("res.duration_ms >= ");
180 query.push_bind(min_duration);
181 }
182
183 if let Some(max_duration) = self.max_duration_ms {
184 if where_added {
185 query.push(" AND ");
186 } else {
187 query.push(" WHERE ");
188 }
189 query.push("res.duration_ms <= ");
190 query.push_bind(max_duration);
191 }
192
193 if self.order_by_timestamp_desc {
194 query.push(" ORDER BY r.timestamp DESC");
195 } else {
196 query.push(" ORDER BY r.timestamp ASC");
197 }
198
199 if let Some(limit) = self.limit {
200 query.push(" LIMIT ");
201 query.push_bind(limit);
202 }
203
204 if let Some(offset) = self.offset {
205 query.push(" OFFSET ");
206 query.push_bind(offset);
207 }
208
209 query
210 }
211}
212
213#[derive(Clone)]
214pub struct RequestRepository<TReq, TRes> {
215 pool: PgPool,
216 _phantom_req: std::marker::PhantomData<TReq>,
217 _phantom_res: std::marker::PhantomData<TRes>,
218}
219
220impl<TReq, TRes> RequestRepository<TReq, TRes>
221where
222 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
223 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
224{
225 pub fn new(pool: PgPool) -> Self {
226 Self {
227 pool,
228 _phantom_req: std::marker::PhantomData,
229 _phantom_res: std::marker::PhantomData,
230 }
231 }
232
233 pub async fn query(
234 &self,
235 filter: RequestFilter,
236 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
237 let rows = filter
238 .build_query()
239 .build()
240 .fetch_all(&self.pool)
241 .await
242 .map_err(PostgresHandlerError::Query)?;
243
244 let mut pairs = Vec::new();
245 for row in rows {
246 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
247 let req_body_parsed = row
248 .try_get::<Option<bool>, _>("req_body_parsed")
249 .unwrap_or(Some(false));
250
251 let request_body = match req_body {
252 Some(json_value) => {
253 if req_body_parsed == Some(true) {
254 Some(Ok(serde_json::from_value::<TReq>(json_value)
256 .map_err(PostgresHandlerError::Json)?))
257 } else {
258 if let Value::String(utf8_str) = json_value {
260 Some(Err(Bytes::from(utf8_str.into_bytes())))
261 } else {
262 return Err(PostgresHandlerError::Json(Error::custom(
263 "Invalid body format",
264 )));
265 }
266 }
267 }
268 None => None,
269 };
270
271 let request = HttpRequest {
272 id: row.get("req_id"),
273 instance_id: row.get("req_instance_id"),
274 correlation_id: row.get("req_correlation_id"),
275 timestamp: row.get("req_timestamp"),
276 method: row.get("method"),
277 uri: row.get("uri"),
278 headers: row.get("req_headers"),
279 body: request_body,
280 created_at: row.get("req_created_at"),
281 };
282
283 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
284 res_id
285 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
286 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
287 let res_body_parsed = row
288 .try_get::<Option<bool>, _>("res_body_parsed")
289 .unwrap_or(Some(false));
290
291 let response_body = match res_body {
292 Some(json_value) => {
293 if res_body_parsed == Some(true) {
294 Some(Ok(serde_json::from_value::<TRes>(json_value)
296 .map_err(PostgresHandlerError::Json)?))
297 } else {
298 if let Value::String(utf8_str) = json_value {
300 Some(Err(Bytes::from(utf8_str.into_bytes())))
301 } else {
302 return Err(PostgresHandlerError::Json(Error::custom(
303 "Invalid body format",
304 )));
305 }
306 }
307 }
308 None => None,
309 };
310
311 Ok(HttpResponse {
312 id: row.get("res_id"),
313 instance_id: row.get("res_instance_id"),
314 correlation_id: row.get("res_correlation_id"),
315 timestamp: row.get("res_timestamp"),
316 status_code: row.get("status_code"),
317 headers: row.get("res_headers"),
318 body: response_body,
319 duration_ms: row.get("duration_ms"),
320 created_at: row.get("res_created_at"),
321 })
322 })
323 .transpose()?
324 } else {
325 None
326 };
327
328 pairs.push(RequestResponsePair { request, response });
329 }
330
331 Ok(pairs)
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Checks that `sql` parses under the PostgreSQL dialect.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    /// Renders `filter` to SQL text and asserts that the text parses.
    fn render(filter: &RequestFilter) -> String {
        let query = filter.build_query();
        let sql = query.sql().to_owned();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Parses an RFC 3339 timestamp into UTC.
    fn utc(timestamp: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(timestamp)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = render(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    // The instance_id branch is the only one that writes its own ` WHERE `
    // literal, so it gets dedicated coverage.
    #[test]
    fn test_instance_id_filter() {
        let sql = render(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_instance_id_combined_with_other_filters() {
        let sql = render(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            correlation_id: Some(7),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
        assert!(sql.contains("AND r.correlation_id = $2"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = render(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = render(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = render(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = render(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = render(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = render(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = render(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_ordering_desc() {
        let sql = render(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = render(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = render(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = render(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Exactly one WHERE; the join condition contributes one AND on top of
        // the two that link the predicates.
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = render(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Placeholders are numbered in the order the predicates are emitted.
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = render(&RequestFilter::default());

        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        assert!(!sql.contains("WHERE"));

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}