// outlet_postgres/repository.rs

1use base64::Engine;
2use bytes::Bytes;
3use chrono::{DateTime, Utc};
4use serde::de::Error;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use sqlx::{PgPool, QueryBuilder, Row};
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13    pub id: i64,
14    pub correlation_id: i64,
15    pub timestamp: DateTime<Utc>,
16    pub method: String,
17    pub uri: String,
18    pub headers: Value,
19    pub body: Option<Result<TReq, Bytes>>,
20    pub created_at: DateTime<Utc>,
21}
22
23#[derive(Debug)]
24pub struct HttpResponse<TRes> {
25    pub id: i64,
26    pub correlation_id: i64,
27    pub timestamp: DateTime<Utc>,
28    pub status_code: i32,
29    pub headers: Value,
30    pub body: Option<Result<TRes, Bytes>>,
31    pub duration_ms: i64,
32    pub created_at: DateTime<Utc>,
33}
34
35#[derive(Debug)]
36pub struct RequestResponsePair<TReq, TRes> {
37    pub request: HttpRequest<TReq>,
38    pub response: Option<HttpResponse<TRes>>,
39}
40
41#[derive(Debug, Default)]
42pub struct RequestFilter {
43    pub correlation_id: Option<i64>,
44    pub method: Option<String>,
45    pub uri_pattern: Option<String>,
46    pub status_code: Option<i32>,
47    pub status_code_min: Option<i32>,
48    pub status_code_max: Option<i32>,
49    pub timestamp_after: Option<DateTime<Utc>>,
50    pub timestamp_before: Option<DateTime<Utc>>,
51    pub min_duration_ms: Option<i64>,
52    pub max_duration_ms: Option<i64>,
53    pub body_parsed: Option<bool>,
54    pub limit: Option<i64>,
55    pub offset: Option<i64>,
56    pub order_by_timestamp_desc: bool,
57}
58
59impl RequestFilter {
60    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
61        let mut query = QueryBuilder::new(
62            r#"
63            SELECT 
64                r.id as req_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
65                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
66                res.id as res_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
67                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
68            FROM http_requests r
69            LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id
70            "#,
71        );
72
73        let mut where_added = false;
74
75        if let Some(correlation_id) = self.correlation_id {
76            query.push(" WHERE r.correlation_id = ");
77            query.push_bind(correlation_id);
78            where_added = true;
79        }
80
81        if let Some(method) = &self.method {
82            if where_added {
83                query.push(" AND ");
84            } else {
85                query.push(" WHERE ");
86                where_added = true;
87            }
88            query.push("r.method = ");
89            query.push_bind(method);
90        }
91
92        if let Some(uri_pattern) = &self.uri_pattern {
93            if where_added {
94                query.push(" AND ");
95            } else {
96                query.push(" WHERE ");
97                where_added = true;
98            }
99            query.push("r.uri ILIKE ");
100            query.push_bind(uri_pattern);
101        }
102
103        if let Some(status_code) = self.status_code {
104            if where_added {
105                query.push(" AND ");
106            } else {
107                query.push(" WHERE ");
108                where_added = true;
109            }
110            query.push("res.status_code = ");
111            query.push_bind(status_code);
112        }
113
114        if let Some(min_status) = self.status_code_min {
115            if where_added {
116                query.push(" AND ");
117            } else {
118                query.push(" WHERE ");
119                where_added = true;
120            }
121            query.push("res.status_code >= ");
122            query.push_bind(min_status);
123        }
124
125        if let Some(max_status) = self.status_code_max {
126            if where_added {
127                query.push(" AND ");
128            } else {
129                query.push(" WHERE ");
130                where_added = true;
131            }
132            query.push("res.status_code <= ");
133            query.push_bind(max_status);
134        }
135
136        if let Some(timestamp_after) = self.timestamp_after {
137            if where_added {
138                query.push(" AND ");
139            } else {
140                query.push(" WHERE ");
141                where_added = true;
142            }
143            query.push("r.timestamp >= ");
144            query.push_bind(timestamp_after);
145        }
146
147        if let Some(timestamp_before) = self.timestamp_before {
148            if where_added {
149                query.push(" AND ");
150            } else {
151                query.push(" WHERE ");
152                where_added = true;
153            }
154            query.push("r.timestamp <= ");
155            query.push_bind(timestamp_before);
156        }
157
158        if let Some(min_duration) = self.min_duration_ms {
159            if where_added {
160                query.push(" AND ");
161            } else {
162                query.push(" WHERE ");
163                where_added = true;
164            }
165            query.push("res.duration_ms >= ");
166            query.push_bind(min_duration);
167        }
168
169        if let Some(max_duration) = self.max_duration_ms {
170            if where_added {
171                query.push(" AND ");
172            } else {
173                query.push(" WHERE ");
174            }
175            query.push("res.duration_ms <= ");
176            query.push_bind(max_duration);
177        }
178
179        if self.order_by_timestamp_desc {
180            query.push(" ORDER BY r.timestamp DESC");
181        } else {
182            query.push(" ORDER BY r.timestamp ASC");
183        }
184
185        if let Some(limit) = self.limit {
186            query.push(" LIMIT ");
187            query.push_bind(limit);
188        }
189
190        if let Some(offset) = self.offset {
191            query.push(" OFFSET ");
192            query.push_bind(offset);
193        }
194
195        query
196    }
197}
198
199#[derive(Clone)]
200pub struct RequestRepository<TReq, TRes> {
201    pool: PgPool,
202    _phantom_req: std::marker::PhantomData<TReq>,
203    _phantom_res: std::marker::PhantomData<TRes>,
204}
205
206impl<TReq, TRes> RequestRepository<TReq, TRes>
207where
208    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
209    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
210{
211    pub fn new(pool: PgPool) -> Self {
212        Self {
213            pool,
214            _phantom_req: std::marker::PhantomData,
215            _phantom_res: std::marker::PhantomData,
216        }
217    }
218
219    pub async fn query(
220        &self,
221        filter: RequestFilter,
222    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
223        let rows = filter
224            .build_query()
225            .build()
226            .fetch_all(&self.pool)
227            .await
228            .map_err(PostgresHandlerError::Query)?;
229
230        let mut pairs = Vec::new();
231        for row in rows {
232            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
233            let req_body_parsed = row
234                .try_get::<Option<bool>, _>("req_body_parsed")
235                .unwrap_or(Some(false));
236
237            let request_body = match req_body {
238                Some(json_value) => {
239                    if req_body_parsed == Some(true) {
240                        // Body was successfully parsed as TReq when stored
241                        Some(Ok(serde_json::from_value::<TReq>(json_value)
242                            .map_err(PostgresHandlerError::Json)?))
243                    } else {
244                        // Body is stored as base64 string, decode it
245                        if let Value::String(base64_str) = json_value {
246                            let decoded_bytes = base64::engine::general_purpose::STANDARD
247                                .decode(&base64_str)
248                                .map_err(|_| {
249                                    PostgresHandlerError::Json(Error::custom(
250                                        "Failed to decode base64",
251                                    ))
252                                })?;
253                            Some(Err(Bytes::from(decoded_bytes)))
254                        } else {
255                            return Err(PostgresHandlerError::Json(Error::custom(
256                                "Invalid body format",
257                            )));
258                        }
259                    }
260                }
261                None => None,
262            };
263
264            let request = HttpRequest {
265                id: row.get("req_id"),
266                correlation_id: row.get("req_correlation_id"),
267                timestamp: row.get("req_timestamp"),
268                method: row.get("method"),
269                uri: row.get("uri"),
270                headers: row.get("req_headers"),
271                body: request_body,
272                created_at: row.get("req_created_at"),
273            };
274
275            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
276                res_id
277                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
278                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
279                        let res_body_parsed = row
280                            .try_get::<Option<bool>, _>("res_body_parsed")
281                            .unwrap_or(Some(false));
282
283                        let response_body = match res_body {
284                            Some(json_value) => {
285                                if res_body_parsed == Some(true) {
286                                    // Body was successfully parsed as TRes when stored
287                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
288                                        .map_err(PostgresHandlerError::Json)?))
289                                } else {
290                                    // Body is stored as base64 string, decode it
291                                    if let Value::String(base64_str) = json_value {
292                                        let decoded_bytes =
293                                            base64::engine::general_purpose::STANDARD
294                                                .decode(&base64_str)
295                                                .map_err(|_| {
296                                                    PostgresHandlerError::Json(Error::custom(
297                                                        "Failed to decode base64",
298                                                    ))
299                                                })?;
300                                        Some(Err(Bytes::from(decoded_bytes)))
301                                    } else {
302                                        return Err(PostgresHandlerError::Json(Error::custom(
303                                            "Invalid body format",
304                                        )));
305                                    }
306                                }
307                            }
308                            None => None,
309                        };
310
311                        Ok(HttpResponse {
312                            id: row.get("res_id"),
313                            correlation_id: row.get("res_correlation_id"),
314                            timestamp: row.get("res_timestamp"),
315                            status_code: row.get("status_code"),
316                            headers: row.get("res_headers"),
317                            body: response_body,
318                            duration_ms: row.get("duration_ms"),
319                            created_at: row.get("res_created_at"),
320                        })
321                    })
322                    .transpose()?
323            } else {
324                None
325            };
326
327            pairs.push(RequestResponsePair { request, response });
328        }
329
330        Ok(pairs)
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Checks that `sql` parses as PostgreSQL, returning the parser error as text otherwise.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {}", e))
            .map(|_| ())
    }

    /// Builds the SQL for `filter`, asserts that it parses, and returns it as an owned string.
    fn built_sql(filter: &RequestFilter) -> String {
        let sql = filter.build_query().sql().to_owned();
        if let Err(e) = validate_sql(&sql) {
            panic!("{}\n{}", e, sql);
        }
        sql
    }

    /// Parses an RFC 3339 timestamp and converts it to UTC.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = built_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = built_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = built_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = built_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = built_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = built_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = built_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    /// `max_duration_ms` is the last WHERE-able filter; on its own it must still open the
    /// WHERE clause.
    #[test]
    fn test_max_duration_only_filter() {
        let sql = built_sql(&RequestFilter {
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let sql = built_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = built_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = built_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Should not have multiple WHERE clauses
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    /// Every WHERE-able filter set at once: one WHERE, nine ANDs, binds numbered in field order.
    #[test]
    fn test_every_where_filter_shares_one_where() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(1),
            method: Some("PUT".to_string()),
            uri_pattern: Some("/v1/%".to_string()),
            status_code: Some(201),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            min_duration_ms: Some(1),
            max_duration_ms: Some(2),
            ..Default::default()
        });

        assert_eq!(sql.matches("WHERE").count(), 1);
        assert_eq!(sql.matches(" AND ").count(), 9);
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.timestamp <= $8"));
        assert!(sql.contains("AND res.duration_ms <= $10"));
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = built_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Check all filters are present
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        // Should have exactly one WHERE
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = built_sql(&RequestFilter::default());

        // Should contain base SELECT and JOIN
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id")
        );

        // Should not contain WHERE clause
        assert!(!sql.contains("WHERE"));

        // Should have default ordering
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}