// outlet_postgres/repository.rs

1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7use uuid::Uuid;
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13    pub id: i64,
14    pub instance_id: Uuid,
15    pub correlation_id: i64,
16    pub timestamp: DateTime<Utc>,
17    pub method: String,
18    pub uri: String,
19    pub headers: Value,
20    pub body: Option<Result<TReq, Bytes>>,
21    pub created_at: DateTime<Utc>,
22}
23
24#[derive(Debug)]
25pub struct HttpResponse<TRes> {
26    pub id: i64,
27    pub instance_id: Uuid,
28    pub correlation_id: i64,
29    pub timestamp: DateTime<Utc>,
30    pub status_code: i32,
31    pub headers: Value,
32    pub body: Option<Result<TRes, Bytes>>,
33    pub duration_ms: i64,
34    pub created_at: DateTime<Utc>,
35}
36
/// A stored request together with the response that shares its
/// `instance_id` and `correlation_id`, if one was found by the LEFT JOIN.
#[derive(Debug)]
pub struct RequestResponsePair<TReq, TRes> {
    /// The request row.
    pub request: HttpRequest<TReq>,
    /// The joined response row; `None` when the join produced no response id.
    pub response: Option<HttpResponse<TRes>>,
}
42
43#[derive(Debug, Default)]
44pub struct RequestFilter {
45    pub instance_id: Option<Uuid>,
46    pub correlation_id: Option<i64>,
47    pub method: Option<String>,
48    pub uri_pattern: Option<String>,
49    pub status_code: Option<i32>,
50    pub status_code_min: Option<i32>,
51    pub status_code_max: Option<i32>,
52    pub timestamp_after: Option<DateTime<Utc>>,
53    pub timestamp_before: Option<DateTime<Utc>>,
54    pub min_duration_ms: Option<i64>,
55    pub max_duration_ms: Option<i64>,
56    pub body_parsed: Option<bool>,
57    pub limit: Option<i64>,
58    pub offset: Option<i64>,
59    pub order_by_timestamp_desc: bool,
60}
61
62impl RequestFilter {
63    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
64        let mut query = QueryBuilder::new(
65            r#"
66            SELECT 
67                r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
68                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
69                res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
70                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
71            FROM http_requests r
72            LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
73            "#,
74        );
75
76        let mut where_added = false;
77
78        if let Some(instance_id) = self.instance_id {
79            query.push(" WHERE r.instance_id = ");
80            query.push_bind(instance_id);
81            where_added = true;
82        }
83
84        if let Some(correlation_id) = self.correlation_id {
85            if where_added {
86                query.push(" AND ");
87            } else {
88                query.push(" WHERE ");
89                where_added = true;
90            }
91            query.push("r.correlation_id = ");
92            query.push_bind(correlation_id);
93        }
94
95        if let Some(method) = &self.method {
96            if where_added {
97                query.push(" AND ");
98            } else {
99                query.push(" WHERE ");
100                where_added = true;
101            }
102            query.push("r.method = ");
103            query.push_bind(method);
104        }
105
106        if let Some(uri_pattern) = &self.uri_pattern {
107            if where_added {
108                query.push(" AND ");
109            } else {
110                query.push(" WHERE ");
111                where_added = true;
112            }
113            query.push("r.uri ILIKE ");
114            query.push_bind(uri_pattern);
115        }
116
117        if let Some(status_code) = self.status_code {
118            if where_added {
119                query.push(" AND ");
120            } else {
121                query.push(" WHERE ");
122                where_added = true;
123            }
124            query.push("res.status_code = ");
125            query.push_bind(status_code);
126        }
127
128        if let Some(min_status) = self.status_code_min {
129            if where_added {
130                query.push(" AND ");
131            } else {
132                query.push(" WHERE ");
133                where_added = true;
134            }
135            query.push("res.status_code >= ");
136            query.push_bind(min_status);
137        }
138
139        if let Some(max_status) = self.status_code_max {
140            if where_added {
141                query.push(" AND ");
142            } else {
143                query.push(" WHERE ");
144                where_added = true;
145            }
146            query.push("res.status_code <= ");
147            query.push_bind(max_status);
148        }
149
150        if let Some(timestamp_after) = self.timestamp_after {
151            if where_added {
152                query.push(" AND ");
153            } else {
154                query.push(" WHERE ");
155                where_added = true;
156            }
157            query.push("r.timestamp >= ");
158            query.push_bind(timestamp_after);
159        }
160
161        if let Some(timestamp_before) = self.timestamp_before {
162            if where_added {
163                query.push(" AND ");
164            } else {
165                query.push(" WHERE ");
166                where_added = true;
167            }
168            query.push("r.timestamp <= ");
169            query.push_bind(timestamp_before);
170        }
171
172        if let Some(min_duration) = self.min_duration_ms {
173            if where_added {
174                query.push(" AND ");
175            } else {
176                query.push(" WHERE ");
177                where_added = true;
178            }
179            query.push("res.duration_ms >= ");
180            query.push_bind(min_duration);
181        }
182
183        if let Some(max_duration) = self.max_duration_ms {
184            if where_added {
185                query.push(" AND ");
186            } else {
187                query.push(" WHERE ");
188            }
189            query.push("res.duration_ms <= ");
190            query.push_bind(max_duration);
191        }
192
193        if self.order_by_timestamp_desc {
194            query.push(" ORDER BY r.timestamp DESC");
195        } else {
196            query.push(" ORDER BY r.timestamp ASC");
197        }
198
199        if let Some(limit) = self.limit {
200            query.push(" LIMIT ");
201            query.push_bind(limit);
202        }
203
204        if let Some(offset) = self.offset {
205            query.push(" OFFSET ");
206            query.push_bind(offset);
207        }
208
209        query
210    }
211}
212
213#[derive(Clone)]
214pub struct RequestRepository<TReq, TRes> {
215    pool: PgPool,
216    _phantom_req: std::marker::PhantomData<TReq>,
217    _phantom_res: std::marker::PhantomData<TRes>,
218}
219
220impl<TReq, TRes> RequestRepository<TReq, TRes>
221where
222    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
223    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
224{
225    pub fn new(pool: PgPool) -> Self {
226        Self {
227            pool,
228            _phantom_req: std::marker::PhantomData,
229            _phantom_res: std::marker::PhantomData,
230        }
231    }
232
233    pub async fn query(
234        &self,
235        filter: RequestFilter,
236    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
237        let rows = filter
238            .build_query()
239            .build()
240            .fetch_all(&self.pool)
241            .await
242            .map_err(PostgresHandlerError::Query)?;
243
244        let mut pairs = Vec::new();
245        for row in rows {
246            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
247            let req_body_parsed = row
248                .try_get::<Option<bool>, _>("req_body_parsed")
249                .unwrap_or(Some(false));
250
251            let request_body = match req_body {
252                Some(json_value) => {
253                    if req_body_parsed == Some(true) {
254                        // Body was successfully parsed as TReq when stored
255                        Some(Ok(serde_json::from_value::<TReq>(json_value)
256                            .map_err(PostgresHandlerError::Json)?))
257                    } else {
258                        // Body is stored as UTF-8 string (raw content that failed to parse)
259                        if let Value::String(utf8_str) = json_value {
260                            Some(Err(Bytes::from(utf8_str.into_bytes())))
261                        } else {
262                            return Err(PostgresHandlerError::Json(Error::custom(
263                                "Invalid body format",
264                            )));
265                        }
266                    }
267                }
268                None => None,
269            };
270
271            let request = HttpRequest {
272                id: row.get("req_id"),
273                instance_id: row.get("req_instance_id"),
274                correlation_id: row.get("req_correlation_id"),
275                timestamp: row.get("req_timestamp"),
276                method: row.get("method"),
277                uri: row.get("uri"),
278                headers: row.get("req_headers"),
279                body: request_body,
280                created_at: row.get("req_created_at"),
281            };
282
283            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
284                res_id
285                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
286                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
287                        let res_body_parsed = row
288                            .try_get::<Option<bool>, _>("res_body_parsed")
289                            .unwrap_or(Some(false));
290
291                        let response_body = match res_body {
292                            Some(json_value) => {
293                                if res_body_parsed == Some(true) {
294                                    // Body was successfully parsed as TRes when stored
295                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
296                                        .map_err(PostgresHandlerError::Json)?))
297                                } else {
298                                    // Body is stored as UTF-8 string (raw content that failed to parse)
299                                    if let Value::String(utf8_str) = json_value {
300                                        Some(Err(Bytes::from(utf8_str.into_bytes())))
301                                    } else {
302                                        return Err(PostgresHandlerError::Json(Error::custom(
303                                            "Invalid body format",
304                                        )));
305                                    }
306                                }
307                            }
308                            None => None,
309                        };
310
311                        Ok(HttpResponse {
312                            id: row.get("res_id"),
313                            instance_id: row.get("res_instance_id"),
314                            correlation_id: row.get("res_correlation_id"),
315                            timestamp: row.get("res_timestamp"),
316                            status_code: row.get("status_code"),
317                            headers: row.get("res_headers"),
318                            body: response_body,
319                            duration_ms: row.get("duration_ms"),
320                            created_at: row.get("res_created_at"),
321                        })
322                    })
323                    .transpose()?
324            } else {
325                None
326            };
327
328            pairs.push(RequestResponsePair { request, response });
329        }
330
331        Ok(pairs)
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Checks that `sql` parses under sqlparser's PostgreSQL dialect.
    fn validate_sql(sql: &str) -> Result<(), String> {
        match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
            Ok(_) => Ok(()),
            Err(e) => Err(format!("SQL parse error: {e}")),
        }
    }

    /// Builds the query for `filter`, asserts that it parses, and returns the SQL text.
    fn parsed_sql(filter: &RequestFilter) -> String {
        let rendered = filter.build_query().sql().to_owned();
        validate_sql(&rendered).unwrap();
        rendered
    }

    /// Parses an RFC 3339 timestamp and converts it to UTC.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = parsed_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = parsed_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = parsed_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = parsed_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = parsed_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = parsed_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = parsed_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = parsed_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_ordering_desc() {
        let sql = parsed_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = parsed_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = parsed_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = parsed_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Only the first condition may open the clause; the rest are chained.
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = parsed_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Every populated criterion shows up, numbered in emission order.
        let expected_fragments = [
            "WHERE r.correlation_id = $1",
            "AND r.method = $2",
            "AND r.uri ILIKE $3",
            "AND res.status_code >= $4",
            "AND res.status_code <= $5",
            "AND r.timestamp >= $6",
            "AND res.duration_ms >= $7",
            "AND res.duration_ms <= $8",
            "ORDER BY r.timestamp DESC",
            "LIMIT $9",
            "OFFSET $10",
        ];
        for fragment in expected_fragments {
            assert!(sql.contains(fragment));
        }

        // Exactly one WHERE, no matter how many conditions were added.
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = parsed_sql(&RequestFilter::default());

        // The base SELECT and JOIN are always present.
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        // No criteria means no WHERE clause at all.
        assert!(!sql.contains("WHERE"));

        // Ascending order is the default.
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}