1use base64::Engine;
2use bytes::Bytes;
3use chrono::{DateTime, Utc};
4use serde::de::Error;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use sqlx::{PgPool, QueryBuilder, Row};
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13 pub id: i64,
14 pub correlation_id: i64,
15 pub timestamp: DateTime<Utc>,
16 pub method: String,
17 pub uri: String,
18 pub headers: Value,
19 pub body: Option<Result<TReq, Bytes>>,
20 pub created_at: DateTime<Utc>,
21}
22
23#[derive(Debug)]
24pub struct HttpResponse<TRes> {
25 pub id: i64,
26 pub correlation_id: i64,
27 pub timestamp: DateTime<Utc>,
28 pub status_code: i32,
29 pub headers: Value,
30 pub body: Option<Result<TRes, Bytes>>,
31 pub duration_ms: i64,
32 pub created_at: DateTime<Utc>,
33}
34
35#[derive(Debug)]
36pub struct RequestResponsePair<TReq, TRes> {
37 pub request: HttpRequest<TReq>,
38 pub response: Option<HttpResponse<TRes>>,
39}
40
41#[derive(Debug, Default)]
42pub struct RequestFilter {
43 pub correlation_id: Option<i64>,
44 pub method: Option<String>,
45 pub uri_pattern: Option<String>,
46 pub status_code: Option<i32>,
47 pub status_code_min: Option<i32>,
48 pub status_code_max: Option<i32>,
49 pub timestamp_after: Option<DateTime<Utc>>,
50 pub timestamp_before: Option<DateTime<Utc>>,
51 pub min_duration_ms: Option<i64>,
52 pub max_duration_ms: Option<i64>,
53 pub body_parsed: Option<bool>,
54 pub limit: Option<i64>,
55 pub offset: Option<i64>,
56 pub order_by_timestamp_desc: bool,
57}
58
59impl RequestFilter {
60 pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
61 let mut query = QueryBuilder::new(
62 r#"
63 SELECT
64 r.id as req_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
65 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
66 res.id as res_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
67 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
68 FROM http_requests r
69 LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id
70 "#,
71 );
72
73 let mut where_added = false;
74
75 if let Some(correlation_id) = self.correlation_id {
76 query.push(" WHERE r.correlation_id = ");
77 query.push_bind(correlation_id);
78 where_added = true;
79 }
80
81 if let Some(method) = &self.method {
82 if where_added {
83 query.push(" AND ");
84 } else {
85 query.push(" WHERE ");
86 where_added = true;
87 }
88 query.push("r.method = ");
89 query.push_bind(method);
90 }
91
92 if let Some(uri_pattern) = &self.uri_pattern {
93 if where_added {
94 query.push(" AND ");
95 } else {
96 query.push(" WHERE ");
97 where_added = true;
98 }
99 query.push("r.uri ILIKE ");
100 query.push_bind(uri_pattern);
101 }
102
103 if let Some(status_code) = self.status_code {
104 if where_added {
105 query.push(" AND ");
106 } else {
107 query.push(" WHERE ");
108 where_added = true;
109 }
110 query.push("res.status_code = ");
111 query.push_bind(status_code);
112 }
113
114 if let Some(min_status) = self.status_code_min {
115 if where_added {
116 query.push(" AND ");
117 } else {
118 query.push(" WHERE ");
119 where_added = true;
120 }
121 query.push("res.status_code >= ");
122 query.push_bind(min_status);
123 }
124
125 if let Some(max_status) = self.status_code_max {
126 if where_added {
127 query.push(" AND ");
128 } else {
129 query.push(" WHERE ");
130 where_added = true;
131 }
132 query.push("res.status_code <= ");
133 query.push_bind(max_status);
134 }
135
136 if let Some(timestamp_after) = self.timestamp_after {
137 if where_added {
138 query.push(" AND ");
139 } else {
140 query.push(" WHERE ");
141 where_added = true;
142 }
143 query.push("r.timestamp >= ");
144 query.push_bind(timestamp_after);
145 }
146
147 if let Some(timestamp_before) = self.timestamp_before {
148 if where_added {
149 query.push(" AND ");
150 } else {
151 query.push(" WHERE ");
152 where_added = true;
153 }
154 query.push("r.timestamp <= ");
155 query.push_bind(timestamp_before);
156 }
157
158 if let Some(min_duration) = self.min_duration_ms {
159 if where_added {
160 query.push(" AND ");
161 } else {
162 query.push(" WHERE ");
163 where_added = true;
164 }
165 query.push("res.duration_ms >= ");
166 query.push_bind(min_duration);
167 }
168
169 if let Some(max_duration) = self.max_duration_ms {
170 if where_added {
171 query.push(" AND ");
172 } else {
173 query.push(" WHERE ");
174 }
175 query.push("res.duration_ms <= ");
176 query.push_bind(max_duration);
177 }
178
179 if self.order_by_timestamp_desc {
180 query.push(" ORDER BY r.timestamp DESC");
181 } else {
182 query.push(" ORDER BY r.timestamp ASC");
183 }
184
185 if let Some(limit) = self.limit {
186 query.push(" LIMIT ");
187 query.push_bind(limit);
188 }
189
190 if let Some(offset) = self.offset {
191 query.push(" OFFSET ");
192 query.push_bind(offset);
193 }
194
195 query
196 }
197}
198
199#[derive(Clone)]
200pub struct RequestRepository<TReq, TRes> {
201 pool: PgPool,
202 _phantom_req: std::marker::PhantomData<TReq>,
203 _phantom_res: std::marker::PhantomData<TRes>,
204}
205
206impl<TReq, TRes> RequestRepository<TReq, TRes>
207where
208 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
209 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
210{
211 pub fn new(pool: PgPool) -> Self {
212 Self {
213 pool,
214 _phantom_req: std::marker::PhantomData,
215 _phantom_res: std::marker::PhantomData,
216 }
217 }
218
219 pub async fn query(
220 &self,
221 filter: RequestFilter,
222 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
223 let rows = filter
224 .build_query()
225 .build()
226 .fetch_all(&self.pool)
227 .await
228 .map_err(PostgresHandlerError::Query)?;
229
230 let mut pairs = Vec::new();
231 for row in rows {
232 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
233 let req_body_parsed = row
234 .try_get::<Option<bool>, _>("req_body_parsed")
235 .unwrap_or(Some(false));
236
237 let request_body = match req_body {
238 Some(json_value) => {
239 if req_body_parsed == Some(true) {
240 Some(Ok(serde_json::from_value::<TReq>(json_value)
242 .map_err(PostgresHandlerError::Json)?))
243 } else {
244 if let Value::String(base64_str) = json_value {
246 let decoded_bytes = base64::engine::general_purpose::STANDARD
247 .decode(&base64_str)
248 .map_err(|_| {
249 PostgresHandlerError::Json(Error::custom(
250 "Failed to decode base64",
251 ))
252 })?;
253 Some(Err(Bytes::from(decoded_bytes)))
254 } else {
255 return Err(PostgresHandlerError::Json(Error::custom(
256 "Invalid body format",
257 )));
258 }
259 }
260 }
261 None => None,
262 };
263
264 let request = HttpRequest {
265 id: row.get("req_id"),
266 correlation_id: row.get("req_correlation_id"),
267 timestamp: row.get("req_timestamp"),
268 method: row.get("method"),
269 uri: row.get("uri"),
270 headers: row.get("req_headers"),
271 body: request_body,
272 created_at: row.get("req_created_at"),
273 };
274
275 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
276 res_id
277 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
278 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
279 let res_body_parsed = row
280 .try_get::<Option<bool>, _>("res_body_parsed")
281 .unwrap_or(Some(false));
282
283 let response_body = match res_body {
284 Some(json_value) => {
285 if res_body_parsed == Some(true) {
286 Some(Ok(serde_json::from_value::<TRes>(json_value)
288 .map_err(PostgresHandlerError::Json)?))
289 } else {
290 if let Value::String(base64_str) = json_value {
292 let decoded_bytes =
293 base64::engine::general_purpose::STANDARD
294 .decode(&base64_str)
295 .map_err(|_| {
296 PostgresHandlerError::Json(Error::custom(
297 "Failed to decode base64",
298 ))
299 })?;
300 Some(Err(Bytes::from(decoded_bytes)))
301 } else {
302 return Err(PostgresHandlerError::Json(Error::custom(
303 "Invalid body format",
304 )));
305 }
306 }
307 }
308 None => None,
309 };
310
311 Ok(HttpResponse {
312 id: row.get("res_id"),
313 correlation_id: row.get("res_correlation_id"),
314 timestamp: row.get("res_timestamp"),
315 status_code: row.get("status_code"),
316 headers: row.get("res_headers"),
317 body: response_body,
318 duration_ms: row.get("duration_ms"),
319 created_at: row.get("res_created_at"),
320 })
321 })
322 .transpose()?
323 } else {
324 None
325 };
326
327 pairs.push(RequestResponsePair { request, response });
328 }
329
330 Ok(pairs)
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect; succeeds when it is
    /// syntactically valid.
    fn validate_sql(sql: &str) -> Result<(), String> {
        match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
            Ok(_) => Ok(()),
            Err(e) => Err(format!("SQL parse error: {}", e)),
        }
    }

    /// Renders the SQL for `filter` and checks that it parses before handing
    /// it back for content assertions.
    fn checked_sql(filter: &RequestFilter) -> String {
        let rendered = filter.build_query().sql().to_string();
        validate_sql(&rendered).unwrap();
        rendered
    }

    /// Parses an RFC 3339 timestamp into a UTC `DateTime`.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = checked_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = checked_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = checked_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = checked_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = checked_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = checked_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = checked_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_ordering_desc() {
        let sql = checked_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = checked_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = checked_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Exactly one WHERE; every further condition is chained with AND.
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = checked_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Placeholders are numbered in the order the conditions are emitted.
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = checked_sql(&RequestFilter::default());

        // The base statement is always present.
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id")
        );

        // No filter set, so no WHERE clause.
        assert!(!sql.contains("WHERE"));

        // Ascending order is the default.
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}