1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{QueryBuilder, Row};
7use sqlx_pool_router::PoolProvider;
8use uuid::Uuid;
9
10use crate::error::PostgresHandlerError;
11
12#[derive(Debug)]
13pub struct HttpRequest<TReq> {
14 pub id: i64,
15 pub instance_id: Uuid,
16 pub correlation_id: i64,
17 pub timestamp: DateTime<Utc>,
18 pub method: String,
19 pub uri: String,
20 pub headers: Value,
21 pub body: Option<Result<TReq, Bytes>>,
22 pub created_at: DateTime<Utc>,
23}
24
25#[derive(Debug)]
26pub struct HttpResponse<TRes> {
27 pub id: i64,
28 pub instance_id: Uuid,
29 pub correlation_id: i64,
30 pub timestamp: DateTime<Utc>,
31 pub status_code: i32,
32 pub headers: Value,
33 pub body: Option<Result<TRes, Bytes>>,
34 pub duration_to_first_byte_ms: i64,
35 pub duration_ms: i64,
36 pub created_at: DateTime<Utc>,
37}
38
39#[derive(Debug)]
40pub struct RequestResponsePair<TReq, TRes> {
41 pub request: HttpRequest<TReq>,
42 pub response: Option<HttpResponse<TRes>>,
43}
44
45#[derive(Debug, Default)]
46pub struct RequestFilter {
47 pub instance_id: Option<Uuid>,
48 pub correlation_id: Option<i64>,
49 pub method: Option<String>,
50 pub uri_pattern: Option<String>,
51 pub status_code: Option<i32>,
52 pub status_code_min: Option<i32>,
53 pub status_code_max: Option<i32>,
54 pub timestamp_after: Option<DateTime<Utc>>,
55 pub timestamp_before: Option<DateTime<Utc>>,
56 pub min_duration_ms: Option<i64>,
57 pub max_duration_ms: Option<i64>,
58 pub min_duration_to_first_byte_ms: Option<i64>,
59 pub max_duration_to_first_byte_ms: Option<i64>,
60 pub body_parsed: Option<bool>,
61 pub limit: Option<i64>,
62 pub offset: Option<i64>,
63 pub order_by_timestamp_desc: bool,
64}
65
66impl RequestFilter {
67 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
68 let mut query = QueryBuilder::new(
69 r#"
70 SELECT
71 r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
72 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
73 res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
74 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
75 FROM http_requests r
76 LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
77 "#,
78 );
79
80 let mut where_added = false;
81
82 if let Some(instance_id) = self.instance_id {
83 query.push(" WHERE r.instance_id = ");
84 query.push_bind(instance_id);
85 where_added = true;
86 }
87
88 if let Some(correlation_id) = self.correlation_id {
89 if where_added {
90 query.push(" AND ");
91 } else {
92 query.push(" WHERE ");
93 where_added = true;
94 }
95 query.push("r.correlation_id = ");
96 query.push_bind(correlation_id);
97 }
98
99 if let Some(method) = &self.method {
100 if where_added {
101 query.push(" AND ");
102 } else {
103 query.push(" WHERE ");
104 where_added = true;
105 }
106 query.push("r.method = ");
107 query.push_bind(method);
108 }
109
110 if let Some(uri_pattern) = &self.uri_pattern {
111 if where_added {
112 query.push(" AND ");
113 } else {
114 query.push(" WHERE ");
115 where_added = true;
116 }
117 query.push("r.uri ILIKE ");
118 query.push_bind(uri_pattern);
119 }
120
121 if let Some(status_code) = self.status_code {
122 if where_added {
123 query.push(" AND ");
124 } else {
125 query.push(" WHERE ");
126 where_added = true;
127 }
128 query.push("res.status_code = ");
129 query.push_bind(status_code);
130 }
131
132 if let Some(min_status) = self.status_code_min {
133 if where_added {
134 query.push(" AND ");
135 } else {
136 query.push(" WHERE ");
137 where_added = true;
138 }
139 query.push("res.status_code >= ");
140 query.push_bind(min_status);
141 }
142
143 if let Some(max_status) = self.status_code_max {
144 if where_added {
145 query.push(" AND ");
146 } else {
147 query.push(" WHERE ");
148 where_added = true;
149 }
150 query.push("res.status_code <= ");
151 query.push_bind(max_status);
152 }
153
154 if let Some(timestamp_after) = self.timestamp_after {
155 if where_added {
156 query.push(" AND ");
157 } else {
158 query.push(" WHERE ");
159 where_added = true;
160 }
161 query.push("r.timestamp >= ");
162 query.push_bind(timestamp_after);
163 }
164
165 if let Some(timestamp_before) = self.timestamp_before {
166 if where_added {
167 query.push(" AND ");
168 } else {
169 query.push(" WHERE ");
170 where_added = true;
171 }
172 query.push("r.timestamp <= ");
173 query.push_bind(timestamp_before);
174 }
175
176 if let Some(min_duration) = self.min_duration_ms {
177 if where_added {
178 query.push(" AND ");
179 } else {
180 query.push(" WHERE ");
181 where_added = true;
182 }
183 query.push("res.duration_ms >= ");
184 query.push_bind(min_duration);
185 }
186
187 if let Some(max_duration) = self.max_duration_ms {
188 if where_added {
189 query.push(" AND ");
190 } else {
191 query.push(" WHERE ");
192 where_added = true;
193 }
194 query.push("res.duration_ms <= ");
195 query.push_bind(max_duration);
196 }
197
198 if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
199 if where_added {
200 query.push(" AND ");
201 } else {
202 query.push(" WHERE ");
203 where_added = true;
204 }
205 query.push("res.duration_to_first_byte_ms >= ");
206 query.push_bind(min_duration_to_first_byte);
207 }
208
209 if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
210 if where_added {
211 query.push(" AND ");
212 } else {
213 query.push(" WHERE ");
214 }
215 query.push("res.duration_to_first_byte_ms <= ");
216 query.push_bind(max_duration_to_first_byte);
217 }
218
219 if self.order_by_timestamp_desc {
220 query.push(" ORDER BY r.timestamp DESC");
221 } else {
222 query.push(" ORDER BY r.timestamp ASC");
223 }
224
225 if let Some(limit) = self.limit {
226 query.push(" LIMIT ");
227 query.push_bind(limit);
228 }
229
230 if let Some(offset) = self.offset {
231 query.push(" OFFSET ");
232 query.push_bind(offset);
233 }
234
235 query
236 }
237}
238
239#[derive(Clone)]
240pub struct RequestRepository<P, TReq, TRes>
241where
242 P: PoolProvider,
243{
244 pool: P,
245 _phantom_req: std::marker::PhantomData<TReq>,
246 _phantom_res: std::marker::PhantomData<TRes>,
247}
248
249impl<P, TReq, TRes> RequestRepository<P, TReq, TRes>
250where
251 P: PoolProvider,
252 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
253 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
254{
255 pub fn new(pool: P) -> Self {
256 Self {
257 pool,
258 _phantom_req: std::marker::PhantomData,
259 _phantom_res: std::marker::PhantomData,
260 }
261 }
262
263 pub async fn query(
264 &self,
265 filter: RequestFilter,
266 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
267 let rows = filter
268 .build_query()
269 .build()
270 .fetch_all(self.pool.read())
271 .await
272 .map_err(PostgresHandlerError::Query)?;
273
274 let mut pairs = Vec::new();
275 for row in rows {
276 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
277 let req_body_parsed = row
278 .try_get::<Option<bool>, _>("req_body_parsed")
279 .unwrap_or(Some(false));
280
281 let request_body = match req_body {
282 Some(json_value) => {
283 if req_body_parsed == Some(true) {
284 Some(Ok(serde_json::from_value::<TReq>(json_value)
286 .map_err(PostgresHandlerError::Json)?))
287 } else {
288 if let Value::String(utf8_str) = json_value {
290 Some(Err(Bytes::from(utf8_str.into_bytes())))
291 } else {
292 return Err(PostgresHandlerError::Json(Error::custom(
293 "Invalid body format",
294 )));
295 }
296 }
297 }
298 None => None,
299 };
300
301 let request = HttpRequest {
302 id: row.get("req_id"),
303 instance_id: row.get("req_instance_id"),
304 correlation_id: row.get("req_correlation_id"),
305 timestamp: row.get("req_timestamp"),
306 method: row.get("method"),
307 uri: row.get("uri"),
308 headers: row.get("req_headers"),
309 body: request_body,
310 created_at: row.get("req_created_at"),
311 };
312
313 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
314 res_id
315 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
316 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
317 let res_body_parsed = row
318 .try_get::<Option<bool>, _>("res_body_parsed")
319 .unwrap_or(Some(false));
320
321 let response_body = match res_body {
322 Some(json_value) => {
323 if res_body_parsed == Some(true) {
324 Some(Ok(serde_json::from_value::<TRes>(json_value)
326 .map_err(PostgresHandlerError::Json)?))
327 } else {
328 if let Value::String(utf8_str) = json_value {
330 Some(Err(Bytes::from(utf8_str.into_bytes())))
331 } else {
332 return Err(PostgresHandlerError::Json(Error::custom(
333 "Invalid body format",
334 )));
335 }
336 }
337 }
338 None => None,
339 };
340
341 Ok(HttpResponse {
342 id: row.get("res_id"),
343 instance_id: row.get("res_instance_id"),
344 correlation_id: row.get("res_correlation_id"),
345 timestamp: row.get("res_timestamp"),
346 status_code: row.get("status_code"),
347 headers: row.get("res_headers"),
348 body: response_body,
349 duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
350 duration_ms: row.get("duration_ms"),
351 created_at: row.get("res_created_at"),
352 })
353 })
354 .transpose()?
355 } else {
356 None
357 };
358
359 pairs.push(RequestResponsePair { request, response });
360 }
361
362 Ok(pairs)
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect, returning any parser error as text.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    /// Builds the SQL for `filter`, asserts that it parses, and returns it.
    fn checked_sql(filter: &RequestFilter) -> String {
        let sql = filter.build_query().sql().to_owned();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Parses an RFC 3339 timestamp into UTC.
    fn utc(ts: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(ts)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = checked_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_instance_id_filter() {
        let sql = checked_sql(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
    }

    #[test]
    fn test_instance_id_followed_by_other_filter_uses_and() {
        let sql = checked_sql(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            correlation_id: Some(7),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
        assert!(sql.contains("AND r.correlation_id = $2"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = checked_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = checked_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = checked_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = checked_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = checked_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = checked_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_first_byte_duration_filters() {
        let sql = checked_sql(&RequestFilter {
            min_duration_to_first_byte_ms: Some(10),
            max_duration_to_first_byte_ms: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms >= $1"));
        assert!(sql.contains("AND res.duration_to_first_byte_ms <= $2"));
    }

    #[test]
    fn test_max_first_byte_duration_alone_opens_where() {
        let sql = checked_sql(&RequestFilter {
            max_duration_to_first_byte_ms: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms <= $1"));
    }

    #[test]
    fn test_ordering_desc() {
        let sql = checked_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = checked_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = checked_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        assert_eq!(sql.matches("WHERE").count(), 1);
        // The base JOIN condition already contains an ` AND `, so compare against
        // the unfiltered query: exactly two filter conjunctions must be added.
        let base_ands = checked_sql(&RequestFilter::default())
            .matches(" AND ")
            .count();
        assert_eq!(sql.matches(" AND ").count(), base_ands + 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = checked_sql(&RequestFilter::default());

        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        assert!(!sql.contains("WHERE"));

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}