outlet_postgres/
repository.rs

1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{PgPool, QueryBuilder, Row};
7use uuid::Uuid;
8
9use crate::error::PostgresHandlerError;
10
/// One stored HTTP request, as produced by `RequestRepository::query` from the
/// `r.*` (`http_requests`) columns of the query built by `RequestFilter::build_query`.
///
/// `TReq` is the type a body is deserialized into when the row is flagged as parsed.
#[derive(Debug)]
pub struct HttpRequest<TReq> {
    /// Value of the `id` column of the request row.
    pub id: i64,
    /// Value of the `instance_id` column; together with `correlation_id` it is the
    /// key used to join a request to its response.
    pub instance_id: Uuid,
    /// Value of the `correlation_id` column (second half of the join key).
    pub correlation_id: i64,
    /// Value of the `timestamp` column; this is the column queries are ordered by.
    pub timestamp: DateTime<Utc>,
    /// Value of the `method` column.
    pub method: String,
    /// Value of the `uri` column.
    pub uri: String,
    /// Value of the `headers` column, kept as an untyped JSON value.
    pub headers: Value,
    /// Request body as read back from storage:
    /// - `None` when no body value is stored,
    /// - `Some(Ok(_))` when the row is flagged `body_parsed` and the JSON deserializes as `TReq`,
    /// - `Some(Err(bytes))` when the row is not flagged as parsed; `bytes` are the
    ///   UTF-8 bytes of the stored JSON string.
    pub body: Option<Result<TReq, Bytes>>,
    /// Value of the `created_at` column.
    pub created_at: DateTime<Utc>,
}
23
/// One stored HTTP response, as produced by `RequestRepository::query` from the
/// `res.*` (`http_responses`) columns of the query built by `RequestFilter::build_query`.
///
/// `TRes` is the type a body is deserialized into when the row is flagged as parsed.
#[derive(Debug)]
pub struct HttpResponse<TRes> {
    /// Value of the `id` column of the response row.
    pub id: i64,
    /// Value of the `instance_id` column; together with `correlation_id` it is the
    /// key used to join a response to its request.
    pub instance_id: Uuid,
    /// Value of the `correlation_id` column (second half of the join key).
    pub correlation_id: i64,
    /// Value of the response row's `timestamp` column.
    pub timestamp: DateTime<Utc>,
    /// Value of the `status_code` column.
    pub status_code: i32,
    /// Value of the `headers` column, kept as an untyped JSON value.
    pub headers: Value,
    /// Response body as read back from storage:
    /// - `None` when no body value is stored,
    /// - `Some(Ok(_))` when the row is flagged `body_parsed` and the JSON deserializes as `TRes`,
    /// - `Some(Err(bytes))` when the row is not flagged as parsed; `bytes` are the
    ///   UTF-8 bytes of the stored JSON string.
    pub body: Option<Result<TRes, Bytes>>,
    /// Value of the `duration_to_first_byte_ms` column (milliseconds, per the column name).
    pub duration_to_first_byte_ms: i64,
    /// Value of the `duration_ms` column (milliseconds, per the column name).
    pub duration_ms: i64,
    /// Value of the `created_at` column.
    pub created_at: DateTime<Utc>,
}
37
/// A request together with the response joined to it, one per row returned by
/// `RequestRepository::query`.
#[derive(Debug)]
pub struct RequestResponsePair<TReq, TRes> {
    /// The request side of the row; always present.
    pub request: HttpRequest<TReq>,
    /// The joined response, or `None` when the LEFT JOIN found no response row
    /// for the request's `(instance_id, correlation_id)`.
    pub response: Option<HttpResponse<TRes>>,
}
43
44#[derive(Debug, Default)]
45pub struct RequestFilter {
46    pub instance_id: Option<Uuid>,
47    pub correlation_id: Option<i64>,
48    pub method: Option<String>,
49    pub uri_pattern: Option<String>,
50    pub status_code: Option<i32>,
51    pub status_code_min: Option<i32>,
52    pub status_code_max: Option<i32>,
53    pub timestamp_after: Option<DateTime<Utc>>,
54    pub timestamp_before: Option<DateTime<Utc>>,
55    pub min_duration_ms: Option<i64>,
56    pub max_duration_ms: Option<i64>,
57    pub min_duration_to_first_byte_ms: Option<i64>,
58    pub max_duration_to_first_byte_ms: Option<i64>,
59    pub body_parsed: Option<bool>,
60    pub limit: Option<i64>,
61    pub offset: Option<i64>,
62    pub order_by_timestamp_desc: bool,
63}
64
65impl RequestFilter {
66    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
67        let mut query = QueryBuilder::new(
68            r#"
69            SELECT 
70                r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
71                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
72                res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
73                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
74            FROM http_requests r
75            LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
76            "#,
77        );
78
79        let mut where_added = false;
80
81        if let Some(instance_id) = self.instance_id {
82            query.push(" WHERE r.instance_id = ");
83            query.push_bind(instance_id);
84            where_added = true;
85        }
86
87        if let Some(correlation_id) = self.correlation_id {
88            if where_added {
89                query.push(" AND ");
90            } else {
91                query.push(" WHERE ");
92                where_added = true;
93            }
94            query.push("r.correlation_id = ");
95            query.push_bind(correlation_id);
96        }
97
98        if let Some(method) = &self.method {
99            if where_added {
100                query.push(" AND ");
101            } else {
102                query.push(" WHERE ");
103                where_added = true;
104            }
105            query.push("r.method = ");
106            query.push_bind(method);
107        }
108
109        if let Some(uri_pattern) = &self.uri_pattern {
110            if where_added {
111                query.push(" AND ");
112            } else {
113                query.push(" WHERE ");
114                where_added = true;
115            }
116            query.push("r.uri ILIKE ");
117            query.push_bind(uri_pattern);
118        }
119
120        if let Some(status_code) = self.status_code {
121            if where_added {
122                query.push(" AND ");
123            } else {
124                query.push(" WHERE ");
125                where_added = true;
126            }
127            query.push("res.status_code = ");
128            query.push_bind(status_code);
129        }
130
131        if let Some(min_status) = self.status_code_min {
132            if where_added {
133                query.push(" AND ");
134            } else {
135                query.push(" WHERE ");
136                where_added = true;
137            }
138            query.push("res.status_code >= ");
139            query.push_bind(min_status);
140        }
141
142        if let Some(max_status) = self.status_code_max {
143            if where_added {
144                query.push(" AND ");
145            } else {
146                query.push(" WHERE ");
147                where_added = true;
148            }
149            query.push("res.status_code <= ");
150            query.push_bind(max_status);
151        }
152
153        if let Some(timestamp_after) = self.timestamp_after {
154            if where_added {
155                query.push(" AND ");
156            } else {
157                query.push(" WHERE ");
158                where_added = true;
159            }
160            query.push("r.timestamp >= ");
161            query.push_bind(timestamp_after);
162        }
163
164        if let Some(timestamp_before) = self.timestamp_before {
165            if where_added {
166                query.push(" AND ");
167            } else {
168                query.push(" WHERE ");
169                where_added = true;
170            }
171            query.push("r.timestamp <= ");
172            query.push_bind(timestamp_before);
173        }
174
175        if let Some(min_duration) = self.min_duration_ms {
176            if where_added {
177                query.push(" AND ");
178            } else {
179                query.push(" WHERE ");
180                where_added = true;
181            }
182            query.push("res.duration_ms >= ");
183            query.push_bind(min_duration);
184        }
185
186        if let Some(max_duration) = self.max_duration_ms {
187            if where_added {
188                query.push(" AND ");
189            } else {
190                query.push(" WHERE ");
191                where_added = true;
192            }
193            query.push("res.duration_ms <= ");
194            query.push_bind(max_duration);
195        }
196
197        if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
198            if where_added {
199                query.push(" AND ");
200            } else {
201                query.push(" WHERE ");
202                where_added = true;
203            }
204            query.push("res.duration_to_first_byte_ms >= ");
205            query.push_bind(min_duration_to_first_byte);
206        }
207
208        if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
209            if where_added {
210                query.push(" AND ");
211            } else {
212                query.push(" WHERE ");
213            }
214            query.push("res.duration_to_first_byte_ms <= ");
215            query.push_bind(max_duration_to_first_byte);
216        }
217
218        if self.order_by_timestamp_desc {
219            query.push(" ORDER BY r.timestamp DESC");
220        } else {
221            query.push(" ORDER BY r.timestamp ASC");
222        }
223
224        if let Some(limit) = self.limit {
225            query.push(" LIMIT ");
226            query.push_bind(limit);
227        }
228
229        if let Some(offset) = self.offset {
230            query.push(" OFFSET ");
231            query.push_bind(offset);
232        }
233
234        query
235    }
236}
237
238#[derive(Clone)]
239pub struct RequestRepository<TReq, TRes> {
240    pool: PgPool,
241    _phantom_req: std::marker::PhantomData<TReq>,
242    _phantom_res: std::marker::PhantomData<TRes>,
243}
244
245impl<TReq, TRes> RequestRepository<TReq, TRes>
246where
247    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
248    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
249{
250    pub fn new(pool: PgPool) -> Self {
251        Self {
252            pool,
253            _phantom_req: std::marker::PhantomData,
254            _phantom_res: std::marker::PhantomData,
255        }
256    }
257
258    pub async fn query(
259        &self,
260        filter: RequestFilter,
261    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
262        let rows = filter
263            .build_query()
264            .build()
265            .fetch_all(&self.pool)
266            .await
267            .map_err(PostgresHandlerError::Query)?;
268
269        let mut pairs = Vec::new();
270        for row in rows {
271            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
272            let req_body_parsed = row
273                .try_get::<Option<bool>, _>("req_body_parsed")
274                .unwrap_or(Some(false));
275
276            let request_body = match req_body {
277                Some(json_value) => {
278                    if req_body_parsed == Some(true) {
279                        // Body was successfully parsed as TReq when stored
280                        Some(Ok(serde_json::from_value::<TReq>(json_value)
281                            .map_err(PostgresHandlerError::Json)?))
282                    } else {
283                        // Body is stored as UTF-8 string (raw content that failed to parse)
284                        if let Value::String(utf8_str) = json_value {
285                            Some(Err(Bytes::from(utf8_str.into_bytes())))
286                        } else {
287                            return Err(PostgresHandlerError::Json(Error::custom(
288                                "Invalid body format",
289                            )));
290                        }
291                    }
292                }
293                None => None,
294            };
295
296            let request = HttpRequest {
297                id: row.get("req_id"),
298                instance_id: row.get("req_instance_id"),
299                correlation_id: row.get("req_correlation_id"),
300                timestamp: row.get("req_timestamp"),
301                method: row.get("method"),
302                uri: row.get("uri"),
303                headers: row.get("req_headers"),
304                body: request_body,
305                created_at: row.get("req_created_at"),
306            };
307
308            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
309                res_id
310                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
311                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
312                        let res_body_parsed = row
313                            .try_get::<Option<bool>, _>("res_body_parsed")
314                            .unwrap_or(Some(false));
315
316                        let response_body = match res_body {
317                            Some(json_value) => {
318                                if res_body_parsed == Some(true) {
319                                    // Body was successfully parsed as TRes when stored
320                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
321                                        .map_err(PostgresHandlerError::Json)?))
322                                } else {
323                                    // Body is stored as UTF-8 string (raw content that failed to parse)
324                                    if let Value::String(utf8_str) = json_value {
325                                        Some(Err(Bytes::from(utf8_str.into_bytes())))
326                                    } else {
327                                        return Err(PostgresHandlerError::Json(Error::custom(
328                                            "Invalid body format",
329                                        )));
330                                    }
331                                }
332                            }
333                            None => None,
334                        };
335
336                        Ok(HttpResponse {
337                            id: row.get("res_id"),
338                            instance_id: row.get("res_instance_id"),
339                            correlation_id: row.get("res_correlation_id"),
340                            timestamp: row.get("res_timestamp"),
341                            status_code: row.get("status_code"),
342                            headers: row.get("res_headers"),
343                            body: response_body,
344                            duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
345                            duration_ms: row.get("duration_ms"),
346                            created_at: row.get("res_created_at"),
347                        })
348                    })
349                    .transpose()?
350            } else {
351                None
352            };
353
354            pairs.push(RequestResponsePair { request, response });
355        }
356
357        Ok(pairs)
358    }
359}
360
#[cfg(test)]
mod tests {
    //! SQL-shape tests for `RequestFilter::build_query`. No database is needed:
    //! the generated text is parsed with sqlparser's PostgreSQL dialect and
    //! checked for the expected fragments and placeholder numbers.

    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Succeeds when `sql` parses under the PostgreSQL dialect.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    /// Builds the SQL for `filter`, asserts that it parses, and returns it.
    fn build_sql(filter: &RequestFilter) -> String {
        let sql = filter.build_query().sql().to_owned();
        validate_sql(&sql).unwrap();
        sql
    }

    /// Parses an RFC 3339 timestamp into UTC.
    fn utc(rfc3339: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let sql = build_sql(&RequestFilter::default());

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_instance_id_filter() {
        let sql = build_sql(&RequestFilter {
            instance_id: Some(Uuid::nil()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.instance_id = $1"));
    }

    #[test]
    fn test_correlation_id_filter() {
        let sql = build_sql(&RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let sql = build_sql(&RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let sql = build_sql(&RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let sql = build_sql(&RequestFilter {
            status_code: Some(404),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let sql = build_sql(&RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let sql = build_sql(&RequestFilter {
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            timestamp_before: Some(utc("2023-12-31T23:59:59Z")),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let sql = build_sql(&RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_duration_to_first_byte_filters() {
        let sql = build_sql(&RequestFilter {
            min_duration_to_first_byte_ms: Some(10),
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms >= $1"));
        assert!(sql.contains("AND res.duration_to_first_byte_ms <= $2"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_max_duration_to_first_byte_alone_starts_where() {
        let sql = build_sql(&RequestFilter {
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        });

        assert!(sql.contains("WHERE res.duration_to_first_byte_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let sql = build_sql(&RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let sql = build_sql(&RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        });

        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let sql = build_sql(&RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        });

        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let sql = build_sql(&RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        });

        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Should not have multiple WHERE clauses
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let sql = build_sql(&RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(utc("2023-01-01T00:00:00Z")),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        });

        // Check all filters are present, with placeholders numbered in emission order
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        // Should have exactly one WHERE
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let sql = build_sql(&RequestFilter::default());

        // Should contain base SELECT and JOIN
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        // Should not contain WHERE clause
        assert!(!sql.contains("WHERE"));

        // Should have default ordering
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}