1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{QueryBuilder, Row};
7use sqlx_pool_router::PoolProvider;
8use uuid::Uuid;
9
10use crate::error::PostgresHandlerError;
11
/// One logged HTTP request, as read back from the `http_requests` table.
///
/// `TReq` is the type a body flagged as parsed is deserialized into.
#[derive(Debug)]
pub struct HttpRequest<TReq> {
    /// `id` column of the `http_requests` row.
    pub id: i64,
    /// Instance that recorded the request; together with `correlation_id` it is the key used
    /// to join the request to its `http_responses` row.
    pub instance_id: Uuid,
    /// Correlation id shared with the matching `http_responses` row.
    pub correlation_id: i64,
    /// Request timestamp; this is the column queries sort and range-filter on.
    pub timestamp: DateTime<Utc>,
    /// HTTP method as stored.
    pub method: String,
    /// Request URI as stored.
    pub uri: String,
    /// Request headers as a JSON value.
    // NOTE(review): the JSON shape of the headers is decided by the writer — not visible here.
    pub headers: Value,
    /// Request body:
    /// * `None` — no body was stored;
    /// * `Some(Ok(_))` — the stored body was flagged as parsed and deserialized into `TReq`;
    /// * `Some(Err(bytes))` — the body was not flagged as parsed; `bytes` are the UTF-8 bytes
    ///   of the stored JSON string.
    pub body: Option<Result<TReq, Bytes>>,
    /// `created_at` column of the row (presumably insertion time — confirm with the writer).
    pub created_at: DateTime<Utc>,
    /// Trace id, if one was recorded.
    pub trace_id: Option<String>,
    /// Span id, if one was recorded.
    pub span_id: Option<String>,
}
26
/// One logged HTTP response, as read back from the `http_responses` table.
///
/// `TRes` is the type a body flagged as parsed is deserialized into.
#[derive(Debug)]
pub struct HttpResponse<TRes> {
    /// `id` column of the `http_responses` row.
    pub id: i64,
    /// Instance that recorded the response; matches the request's `instance_id`.
    pub instance_id: Uuid,
    /// Correlation id; matches the request's `correlation_id`.
    pub correlation_id: i64,
    /// Response timestamp.
    pub timestamp: DateTime<Utc>,
    /// HTTP status code.
    pub status_code: i32,
    /// Response headers as a JSON value.
    // NOTE(review): the JSON shape of the headers is decided by the writer — not visible here.
    pub headers: Value,
    /// Response body, with the same `None` / `Ok(TRes)` / `Err(raw bytes)` meaning as
    /// `HttpRequest::body`.
    pub body: Option<Result<TRes, Bytes>>,
    /// Time to first byte, in milliseconds per the `_ms` suffix.
    pub duration_to_first_byte_ms: i64,
    /// Total duration, in milliseconds per the `_ms` suffix.
    pub duration_ms: i64,
    /// `created_at` column of the row (presumably insertion time — confirm with the writer).
    pub created_at: DateTime<Utc>,
}
40
/// A logged request together with its response, if one was recorded.
#[derive(Debug)]
pub struct RequestResponsePair<TReq, TRes> {
    /// The request row.
    pub request: HttpRequest<TReq>,
    /// The response joined on `(instance_id, correlation_id)`; `None` when the `LEFT JOIN`
    /// found no matching `http_responses` row.
    pub response: Option<HttpResponse<TRes>>,
}
46
47#[derive(Debug, Default)]
48pub struct RequestFilter {
49 pub instance_id: Option<Uuid>,
50 pub correlation_id: Option<i64>,
51 pub method: Option<String>,
52 pub uri_pattern: Option<String>,
53 pub status_code: Option<i32>,
54 pub status_code_min: Option<i32>,
55 pub status_code_max: Option<i32>,
56 pub timestamp_after: Option<DateTime<Utc>>,
57 pub timestamp_before: Option<DateTime<Utc>>,
58 pub min_duration_ms: Option<i64>,
59 pub max_duration_ms: Option<i64>,
60 pub min_duration_to_first_byte_ms: Option<i64>,
61 pub max_duration_to_first_byte_ms: Option<i64>,
62 pub body_parsed: Option<bool>,
63 pub limit: Option<i64>,
64 pub offset: Option<i64>,
65 pub order_by_timestamp_desc: bool,
66}
67
68impl RequestFilter {
69 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
70 let mut query = QueryBuilder::new(
71 r#"
72 SELECT
73 r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
74 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
75 r.trace_id, r.span_id,
76 res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
77 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
78 FROM http_requests r
79 LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
80 "#,
81 );
82
83 let mut where_added = false;
84
85 if let Some(instance_id) = self.instance_id {
86 query.push(" WHERE r.instance_id = ");
87 query.push_bind(instance_id);
88 where_added = true;
89 }
90
91 if let Some(correlation_id) = self.correlation_id {
92 if where_added {
93 query.push(" AND ");
94 } else {
95 query.push(" WHERE ");
96 where_added = true;
97 }
98 query.push("r.correlation_id = ");
99 query.push_bind(correlation_id);
100 }
101
102 if let Some(method) = &self.method {
103 if where_added {
104 query.push(" AND ");
105 } else {
106 query.push(" WHERE ");
107 where_added = true;
108 }
109 query.push("r.method = ");
110 query.push_bind(method);
111 }
112
113 if let Some(uri_pattern) = &self.uri_pattern {
114 if where_added {
115 query.push(" AND ");
116 } else {
117 query.push(" WHERE ");
118 where_added = true;
119 }
120 query.push("r.uri ILIKE ");
121 query.push_bind(uri_pattern);
122 }
123
124 if let Some(status_code) = self.status_code {
125 if where_added {
126 query.push(" AND ");
127 } else {
128 query.push(" WHERE ");
129 where_added = true;
130 }
131 query.push("res.status_code = ");
132 query.push_bind(status_code);
133 }
134
135 if let Some(min_status) = self.status_code_min {
136 if where_added {
137 query.push(" AND ");
138 } else {
139 query.push(" WHERE ");
140 where_added = true;
141 }
142 query.push("res.status_code >= ");
143 query.push_bind(min_status);
144 }
145
146 if let Some(max_status) = self.status_code_max {
147 if where_added {
148 query.push(" AND ");
149 } else {
150 query.push(" WHERE ");
151 where_added = true;
152 }
153 query.push("res.status_code <= ");
154 query.push_bind(max_status);
155 }
156
157 if let Some(timestamp_after) = self.timestamp_after {
158 if where_added {
159 query.push(" AND ");
160 } else {
161 query.push(" WHERE ");
162 where_added = true;
163 }
164 query.push("r.timestamp >= ");
165 query.push_bind(timestamp_after);
166 }
167
168 if let Some(timestamp_before) = self.timestamp_before {
169 if where_added {
170 query.push(" AND ");
171 } else {
172 query.push(" WHERE ");
173 where_added = true;
174 }
175 query.push("r.timestamp <= ");
176 query.push_bind(timestamp_before);
177 }
178
179 if let Some(min_duration) = self.min_duration_ms {
180 if where_added {
181 query.push(" AND ");
182 } else {
183 query.push(" WHERE ");
184 where_added = true;
185 }
186 query.push("res.duration_ms >= ");
187 query.push_bind(min_duration);
188 }
189
190 if let Some(max_duration) = self.max_duration_ms {
191 if where_added {
192 query.push(" AND ");
193 } else {
194 query.push(" WHERE ");
195 where_added = true;
196 }
197 query.push("res.duration_ms <= ");
198 query.push_bind(max_duration);
199 }
200
201 if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
202 if where_added {
203 query.push(" AND ");
204 } else {
205 query.push(" WHERE ");
206 where_added = true;
207 }
208 query.push("res.duration_to_first_byte_ms >= ");
209 query.push_bind(min_duration_to_first_byte);
210 }
211
212 if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
213 if where_added {
214 query.push(" AND ");
215 } else {
216 query.push(" WHERE ");
217 }
218 query.push("res.duration_to_first_byte_ms <= ");
219 query.push_bind(max_duration_to_first_byte);
220 }
221
222 if self.order_by_timestamp_desc {
223 query.push(" ORDER BY r.timestamp DESC");
224 } else {
225 query.push(" ORDER BY r.timestamp ASC");
226 }
227
228 if let Some(limit) = self.limit {
229 query.push(" LIMIT ");
230 query.push_bind(limit);
231 }
232
233 if let Some(offset) = self.offset {
234 query.push(" OFFSET ");
235 query.push_bind(offset);
236 }
237
238 query
239 }
240}
241
242#[derive(Clone)]
243pub struct RequestRepository<P, TReq, TRes>
244where
245 P: PoolProvider,
246{
247 pool: P,
248 _phantom_req: std::marker::PhantomData<TReq>,
249 _phantom_res: std::marker::PhantomData<TRes>,
250}
251
252impl<P, TReq, TRes> RequestRepository<P, TReq, TRes>
253where
254 P: PoolProvider,
255 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
256 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
257{
258 pub fn new(pool: P) -> Self {
259 Self {
260 pool,
261 _phantom_req: std::marker::PhantomData,
262 _phantom_res: std::marker::PhantomData,
263 }
264 }
265
266 pub async fn query(
267 &self,
268 filter: RequestFilter,
269 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
270 let rows = filter
271 .build_query()
272 .build()
273 .fetch_all(self.pool.read())
274 .await
275 .map_err(PostgresHandlerError::Query)?;
276
277 let mut pairs = Vec::new();
278 for row in rows {
279 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
280 let req_body_parsed = row
281 .try_get::<Option<bool>, _>("req_body_parsed")
282 .unwrap_or(Some(false));
283
284 let request_body = match req_body {
285 Some(json_value) => {
286 if req_body_parsed == Some(true) {
287 Some(Ok(serde_json::from_value::<TReq>(json_value)
289 .map_err(PostgresHandlerError::Json)?))
290 } else {
291 if let Value::String(utf8_str) = json_value {
293 Some(Err(Bytes::from(utf8_str.into_bytes())))
294 } else {
295 return Err(PostgresHandlerError::Json(Error::custom(
296 "Invalid body format",
297 )));
298 }
299 }
300 }
301 None => None,
302 };
303
304 let request = HttpRequest {
305 id: row.get("req_id"),
306 instance_id: row.get("req_instance_id"),
307 correlation_id: row.get("req_correlation_id"),
308 timestamp: row.get("req_timestamp"),
309 method: row.get("method"),
310 uri: row.get("uri"),
311 headers: row.get("req_headers"),
312 body: request_body,
313 created_at: row.get("req_created_at"),
314 trace_id: row.try_get("trace_id").unwrap_or(None),
315 span_id: row.try_get("span_id").unwrap_or(None),
316 };
317
318 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
319 res_id
320 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
321 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
322 let res_body_parsed = row
323 .try_get::<Option<bool>, _>("res_body_parsed")
324 .unwrap_or(Some(false));
325
326 let response_body = match res_body {
327 Some(json_value) => {
328 if res_body_parsed == Some(true) {
329 Some(Ok(serde_json::from_value::<TRes>(json_value)
331 .map_err(PostgresHandlerError::Json)?))
332 } else {
333 if let Value::String(utf8_str) = json_value {
335 Some(Err(Bytes::from(utf8_str.into_bytes())))
336 } else {
337 return Err(PostgresHandlerError::Json(Error::custom(
338 "Invalid body format",
339 )));
340 }
341 }
342 }
343 None => None,
344 };
345
346 Ok(HttpResponse {
347 id: row.get("res_id"),
348 instance_id: row.get("res_instance_id"),
349 correlation_id: row.get("res_correlation_id"),
350 timestamp: row.get("res_timestamp"),
351 status_code: row.get("status_code"),
352 headers: row.get("res_headers"),
353 body: response_body,
354 duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
355 duration_ms: row.get("duration_ms"),
356 created_at: row.get("res_created_at"),
357 })
358 })
359 .transpose()?
360 } else {
361 None
362 };
363
364 pairs.push(RequestResponsePair { request, response });
365 }
366
367 Ok(pairs)
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the Postgres dialect and returns the parse error message on failure.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    /// Builds `filter`'s query, asserts that it parses as Postgres SQL, and returns its text.
    fn built_sql(filter: &RequestFilter) -> String {
        let query = filter.build_query();
        let sql = query.sql().to_string();
        validate_sql(&sql).unwrap();
        sql
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = built_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_instance_id_filter() {
        let sql = built_sql(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = built_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = built_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = built_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = built_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let before = DateTime::parse_from_rfc3339("2023-12-31T23:59:59Z")
            .unwrap()
            .with_timezone(&Utc);

        let sql = built_sql(&RequestFilter {
            timestamp_after: Some(after),
            timestamp_before: Some(before),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = built_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_duration_to_first_byte_filters() {
        let sql = built_sql(&RequestFilter {
            min_duration_to_first_byte_ms: Some(10),
            max_duration_to_first_byte_ms: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms >= $1"));
        assert!(sql.contains("AND res.duration_to_first_byte_ms <= $2"));
    }

    #[test]
    fn test_max_duration_to_first_byte_alone_opens_where() {
        let sql = built_sql(&RequestFilter {
            max_duration_to_first_byte_ms: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let sql = built_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = built_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = built_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        assert_eq!(sql.matches("WHERE").count(), 1);
        // One ` AND ` comes from the JOIN condition and two join the three filters.
        assert_eq!(sql.matches(" AND ").count(), 3);
    }

    #[test]
    fn test_complex_filter_combination() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);

        let sql = built_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(after),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = built_sql(&RequestFilter::default());

        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(sql.contains(
            "LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)"
        ));

        assert!(!sql.contains("WHERE"));

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}