outlet_postgres/
repository.rs

1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7
8use crate::error::PostgresHandlerError;
9
10#[derive(Debug)]
11pub struct HttpRequest<TReq> {
12    pub id: i64,
13    pub correlation_id: i64,
14    pub timestamp: DateTime<Utc>,
15    pub method: String,
16    pub uri: String,
17    pub headers: Value,
18    pub body: Option<Result<TReq, Bytes>>,
19    pub created_at: DateTime<Utc>,
20}
21
22#[derive(Debug)]
23pub struct HttpResponse<TRes> {
24    pub id: i64,
25    pub correlation_id: i64,
26    pub timestamp: DateTime<Utc>,
27    pub status_code: i32,
28    pub headers: Value,
29    pub body: Option<Result<TRes, Bytes>>,
30    pub duration_ms: i64,
31    pub created_at: DateTime<Utc>,
32}
33
34#[derive(Debug)]
35pub struct RequestResponsePair<TReq, TRes> {
36    pub request: HttpRequest<TReq>,
37    pub response: Option<HttpResponse<TRes>>,
38}
39
40#[derive(Debug, Default)]
41pub struct RequestFilter {
42    pub correlation_id: Option<i64>,
43    pub method: Option<String>,
44    pub uri_pattern: Option<String>,
45    pub status_code: Option<i32>,
46    pub status_code_min: Option<i32>,
47    pub status_code_max: Option<i32>,
48    pub timestamp_after: Option<DateTime<Utc>>,
49    pub timestamp_before: Option<DateTime<Utc>>,
50    pub min_duration_ms: Option<i64>,
51    pub max_duration_ms: Option<i64>,
52    pub body_parsed: Option<bool>,
53    pub limit: Option<i64>,
54    pub offset: Option<i64>,
55    pub order_by_timestamp_desc: bool,
56}
57
58impl RequestFilter {
59    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
60        let mut query = QueryBuilder::new(
61            r#"
62            SELECT 
63                r.id as req_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
64                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
65                res.id as res_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
66                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
67            FROM http_requests r
68            LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id
69            "#,
70        );
71
72        let mut where_added = false;
73
74        if let Some(correlation_id) = self.correlation_id {
75            query.push(" WHERE r.correlation_id = ");
76            query.push_bind(correlation_id);
77            where_added = true;
78        }
79
80        if let Some(method) = &self.method {
81            if where_added {
82                query.push(" AND ");
83            } else {
84                query.push(" WHERE ");
85                where_added = true;
86            }
87            query.push("r.method = ");
88            query.push_bind(method);
89        }
90
91        if let Some(uri_pattern) = &self.uri_pattern {
92            if where_added {
93                query.push(" AND ");
94            } else {
95                query.push(" WHERE ");
96                where_added = true;
97            }
98            query.push("r.uri ILIKE ");
99            query.push_bind(uri_pattern);
100        }
101
102        if let Some(status_code) = self.status_code {
103            if where_added {
104                query.push(" AND ");
105            } else {
106                query.push(" WHERE ");
107                where_added = true;
108            }
109            query.push("res.status_code = ");
110            query.push_bind(status_code);
111        }
112
113        if let Some(min_status) = self.status_code_min {
114            if where_added {
115                query.push(" AND ");
116            } else {
117                query.push(" WHERE ");
118                where_added = true;
119            }
120            query.push("res.status_code >= ");
121            query.push_bind(min_status);
122        }
123
124        if let Some(max_status) = self.status_code_max {
125            if where_added {
126                query.push(" AND ");
127            } else {
128                query.push(" WHERE ");
129                where_added = true;
130            }
131            query.push("res.status_code <= ");
132            query.push_bind(max_status);
133        }
134
135        if let Some(timestamp_after) = self.timestamp_after {
136            if where_added {
137                query.push(" AND ");
138            } else {
139                query.push(" WHERE ");
140                where_added = true;
141            }
142            query.push("r.timestamp >= ");
143            query.push_bind(timestamp_after);
144        }
145
146        if let Some(timestamp_before) = self.timestamp_before {
147            if where_added {
148                query.push(" AND ");
149            } else {
150                query.push(" WHERE ");
151                where_added = true;
152            }
153            query.push("r.timestamp <= ");
154            query.push_bind(timestamp_before);
155        }
156
157        if let Some(min_duration) = self.min_duration_ms {
158            if where_added {
159                query.push(" AND ");
160            } else {
161                query.push(" WHERE ");
162                where_added = true;
163            }
164            query.push("res.duration_ms >= ");
165            query.push_bind(min_duration);
166        }
167
168        if let Some(max_duration) = self.max_duration_ms {
169            if where_added {
170                query.push(" AND ");
171            } else {
172                query.push(" WHERE ");
173            }
174            query.push("res.duration_ms <= ");
175            query.push_bind(max_duration);
176        }
177
178        if self.order_by_timestamp_desc {
179            query.push(" ORDER BY r.timestamp DESC");
180        } else {
181            query.push(" ORDER BY r.timestamp ASC");
182        }
183
184        if let Some(limit) = self.limit {
185            query.push(" LIMIT ");
186            query.push_bind(limit);
187        }
188
189        if let Some(offset) = self.offset {
190            query.push(" OFFSET ");
191            query.push_bind(offset);
192        }
193
194        query
195    }
196}
197
198#[derive(Clone)]
199pub struct RequestRepository<TReq, TRes> {
200    pool: PgPool,
201    _phantom_req: std::marker::PhantomData<TReq>,
202    _phantom_res: std::marker::PhantomData<TRes>,
203}
204
205impl<TReq, TRes> RequestRepository<TReq, TRes>
206where
207    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
208    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
209{
210    pub fn new(pool: PgPool) -> Self {
211        Self {
212            pool,
213            _phantom_req: std::marker::PhantomData,
214            _phantom_res: std::marker::PhantomData,
215        }
216    }
217
218    pub async fn query(
219        &self,
220        filter: RequestFilter,
221    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
222        let rows = filter
223            .build_query()
224            .build()
225            .fetch_all(&self.pool)
226            .await
227            .map_err(PostgresHandlerError::Query)?;
228
229        let mut pairs = Vec::new();
230        for row in rows {
231            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
232            let req_body_parsed = row
233                .try_get::<Option<bool>, _>("req_body_parsed")
234                .unwrap_or(Some(false));
235
236            let request_body = match req_body {
237                Some(json_value) => {
238                    if req_body_parsed == Some(true) {
239                        // Body was successfully parsed as TReq when stored
240                        Some(Ok(serde_json::from_value::<TReq>(json_value)
241                            .map_err(PostgresHandlerError::Json)?))
242                    } else {
243                        // Body is stored as UTF-8 string (raw content that failed to parse)
244                        if let Value::String(utf8_str) = json_value {
245                            Some(Err(Bytes::from(utf8_str.into_bytes())))
246                        } else {
247                            return Err(PostgresHandlerError::Json(Error::custom(
248                                "Invalid body format",
249                            )));
250                        }
251                    }
252                }
253                None => None,
254            };
255
256            let request = HttpRequest {
257                id: row.get("req_id"),
258                correlation_id: row.get("req_correlation_id"),
259                timestamp: row.get("req_timestamp"),
260                method: row.get("method"),
261                uri: row.get("uri"),
262                headers: row.get("req_headers"),
263                body: request_body,
264                created_at: row.get("req_created_at"),
265            };
266
267            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
268                res_id
269                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
270                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
271                        let res_body_parsed = row
272                            .try_get::<Option<bool>, _>("res_body_parsed")
273                            .unwrap_or(Some(false));
274
275                        let response_body = match res_body {
276                            Some(json_value) => {
277                                if res_body_parsed == Some(true) {
278                                    // Body was successfully parsed as TRes when stored
279                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
280                                        .map_err(PostgresHandlerError::Json)?))
281                                } else {
282                                    // Body is stored as UTF-8 string (raw content that failed to parse)
283                                    if let Value::String(utf8_str) = json_value {
284                                        Some(Err(Bytes::from(utf8_str.into_bytes())))
285                                    } else {
286                                        return Err(PostgresHandlerError::Json(Error::custom(
287                                            "Invalid body format",
288                                        )));
289                                    }
290                                }
291                            }
292                            None => None,
293                        };
294
295                        Ok(HttpResponse {
296                            id: row.get("res_id"),
297                            correlation_id: row.get("res_correlation_id"),
298                            timestamp: row.get("res_timestamp"),
299                            status_code: row.get("status_code"),
300                            headers: row.get("res_headers"),
301                            body: response_body,
302                            duration_ms: row.get("duration_ms"),
303                            created_at: row.get("res_created_at"),
304                        })
305                    })
306                    .transpose()?
307            } else {
308                None
309            };
310
311            pairs.push(RequestResponsePair { request, response });
312        }
313
314        Ok(pairs)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect and reports a syntax error as text.
    fn validate_sql(sql: &str) -> Result<(), String> {
        match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
            Ok(_) => Ok(()),
            Err(e) => Err(format!("SQL parse error: {e}")),
        }
    }

    /// Renders the SQL text for `filter` after checking that it parses.
    fn rendered_sql(filter: &RequestFilter) -> String {
        let sql = filter.build_query().sql().to_string();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Asserts that every fragment occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(fragment));
        }
    }

    /// Parses an RFC 3339 timestamp into a UTC `DateTime`.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = rendered_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = rendered_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = rendered_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = rendered_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = rendered_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert_contains_all(
            &sql,
            &["WHERE res.status_code >= $1", "AND res.status_code <= $2"],
        );
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = rendered_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert_contains_all(
            &sql,
            &["WHERE r.timestamp >= $1", "AND r.timestamp <= $2"],
        );
    }

    #[test]
    fn test_duration_filters() {
        let sql = rendered_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert_contains_all(
            &sql,
            &["WHERE res.duration_ms >= $1", "AND res.duration_ms <= $2"],
        );
    }

    #[test]
    fn test_ordering_desc() {
        let sql = rendered_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = rendered_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = rendered_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert_contains_all(&sql, &["LIMIT $1", "OFFSET $2"]);
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert_contains_all(
            &sql,
            &[
                "WHERE r.correlation_id = $1",
                "AND r.method = $2",
                "AND res.status_code = $3",
            ],
        );

        // A single WHERE introduces the clause; later predicates are chained with AND.
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = rendered_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Every criterion shows up, with placeholders numbered in build order.
        assert_contains_all(
            &sql,
            &[
                "WHERE r.correlation_id = $1",
                "AND r.method = $2",
                "AND r.uri ILIKE $3",
                "AND res.status_code >= $4",
                "AND res.status_code <= $5",
                "AND r.timestamp >= $6",
                "AND res.duration_ms >= $7",
                "AND res.duration_ms <= $8",
                "ORDER BY r.timestamp DESC",
                "LIMIT $9",
                "OFFSET $10",
            ],
        );

        // Exactly one WHERE keyword in the whole statement.
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = rendered_sql(&RequestFilter::default());

        // The base SELECT and JOIN are always present.
        assert_contains_all(
            &sql,
            &[
                "SELECT",
                "FROM http_requests r",
                "LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id",
            ],
        );

        // No criteria means no WHERE clause.
        assert!(!sql.contains("WHERE"));

        // Ascending timestamp order is the default.
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}