outlet_postgres/
repository.rs

1use bytes::Bytes;
2use chrono::{DateTime, Utc};
3use serde::de::Error;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{QueryBuilder, Row};
7use sqlx_pool_router::PoolProvider;
8use uuid::Uuid;
9
10use crate::error::PostgresHandlerError;
11
12#[derive(Debug)]
13pub struct HttpRequest<TReq> {
14    pub id: i64,
15    pub instance_id: Uuid,
16    pub correlation_id: i64,
17    pub timestamp: DateTime<Utc>,
18    pub method: String,
19    pub uri: String,
20    pub headers: Value,
21    pub body: Option<Result<TReq, Bytes>>,
22    pub created_at: DateTime<Utc>,
23}
24
25#[derive(Debug)]
26pub struct HttpResponse<TRes> {
27    pub id: i64,
28    pub instance_id: Uuid,
29    pub correlation_id: i64,
30    pub timestamp: DateTime<Utc>,
31    pub status_code: i32,
32    pub headers: Value,
33    pub body: Option<Result<TRes, Bytes>>,
34    pub duration_to_first_byte_ms: i64,
35    pub duration_ms: i64,
36    pub created_at: DateTime<Utc>,
37}
38
39#[derive(Debug)]
40pub struct RequestResponsePair<TReq, TRes> {
41    pub request: HttpRequest<TReq>,
42    pub response: Option<HttpResponse<TRes>>,
43}
44
45#[derive(Debug, Default)]
46pub struct RequestFilter {
47    pub instance_id: Option<Uuid>,
48    pub correlation_id: Option<i64>,
49    pub method: Option<String>,
50    pub uri_pattern: Option<String>,
51    pub status_code: Option<i32>,
52    pub status_code_min: Option<i32>,
53    pub status_code_max: Option<i32>,
54    pub timestamp_after: Option<DateTime<Utc>>,
55    pub timestamp_before: Option<DateTime<Utc>>,
56    pub min_duration_ms: Option<i64>,
57    pub max_duration_ms: Option<i64>,
58    pub min_duration_to_first_byte_ms: Option<i64>,
59    pub max_duration_to_first_byte_ms: Option<i64>,
60    pub body_parsed: Option<bool>,
61    pub limit: Option<i64>,
62    pub offset: Option<i64>,
63    pub order_by_timestamp_desc: bool,
64}
65
66impl RequestFilter {
67    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
68        let mut query = QueryBuilder::new(
69            r#"
70            SELECT 
71                r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
72                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
73                res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
74                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
75            FROM http_requests r
76            LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
77            "#,
78        );
79
80        let mut where_added = false;
81
82        if let Some(instance_id) = self.instance_id {
83            query.push(" WHERE r.instance_id = ");
84            query.push_bind(instance_id);
85            where_added = true;
86        }
87
88        if let Some(correlation_id) = self.correlation_id {
89            if where_added {
90                query.push(" AND ");
91            } else {
92                query.push(" WHERE ");
93                where_added = true;
94            }
95            query.push("r.correlation_id = ");
96            query.push_bind(correlation_id);
97        }
98
99        if let Some(method) = &self.method {
100            if where_added {
101                query.push(" AND ");
102            } else {
103                query.push(" WHERE ");
104                where_added = true;
105            }
106            query.push("r.method = ");
107            query.push_bind(method);
108        }
109
110        if let Some(uri_pattern) = &self.uri_pattern {
111            if where_added {
112                query.push(" AND ");
113            } else {
114                query.push(" WHERE ");
115                where_added = true;
116            }
117            query.push("r.uri ILIKE ");
118            query.push_bind(uri_pattern);
119        }
120
121        if let Some(status_code) = self.status_code {
122            if where_added {
123                query.push(" AND ");
124            } else {
125                query.push(" WHERE ");
126                where_added = true;
127            }
128            query.push("res.status_code = ");
129            query.push_bind(status_code);
130        }
131
132        if let Some(min_status) = self.status_code_min {
133            if where_added {
134                query.push(" AND ");
135            } else {
136                query.push(" WHERE ");
137                where_added = true;
138            }
139            query.push("res.status_code >= ");
140            query.push_bind(min_status);
141        }
142
143        if let Some(max_status) = self.status_code_max {
144            if where_added {
145                query.push(" AND ");
146            } else {
147                query.push(" WHERE ");
148                where_added = true;
149            }
150            query.push("res.status_code <= ");
151            query.push_bind(max_status);
152        }
153
154        if let Some(timestamp_after) = self.timestamp_after {
155            if where_added {
156                query.push(" AND ");
157            } else {
158                query.push(" WHERE ");
159                where_added = true;
160            }
161            query.push("r.timestamp >= ");
162            query.push_bind(timestamp_after);
163        }
164
165        if let Some(timestamp_before) = self.timestamp_before {
166            if where_added {
167                query.push(" AND ");
168            } else {
169                query.push(" WHERE ");
170                where_added = true;
171            }
172            query.push("r.timestamp <= ");
173            query.push_bind(timestamp_before);
174        }
175
176        if let Some(min_duration) = self.min_duration_ms {
177            if where_added {
178                query.push(" AND ");
179            } else {
180                query.push(" WHERE ");
181                where_added = true;
182            }
183            query.push("res.duration_ms >= ");
184            query.push_bind(min_duration);
185        }
186
187        if let Some(max_duration) = self.max_duration_ms {
188            if where_added {
189                query.push(" AND ");
190            } else {
191                query.push(" WHERE ");
192                where_added = true;
193            }
194            query.push("res.duration_ms <= ");
195            query.push_bind(max_duration);
196        }
197
198        if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
199            if where_added {
200                query.push(" AND ");
201            } else {
202                query.push(" WHERE ");
203                where_added = true;
204            }
205            query.push("res.duration_to_first_byte_ms >= ");
206            query.push_bind(min_duration_to_first_byte);
207        }
208
209        if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
210            if where_added {
211                query.push(" AND ");
212            } else {
213                query.push(" WHERE ");
214            }
215            query.push("res.duration_to_first_byte_ms <= ");
216            query.push_bind(max_duration_to_first_byte);
217        }
218
219        if self.order_by_timestamp_desc {
220            query.push(" ORDER BY r.timestamp DESC");
221        } else {
222            query.push(" ORDER BY r.timestamp ASC");
223        }
224
225        if let Some(limit) = self.limit {
226            query.push(" LIMIT ");
227            query.push_bind(limit);
228        }
229
230        if let Some(offset) = self.offset {
231            query.push(" OFFSET ");
232            query.push_bind(offset);
233        }
234
235        query
236    }
237}
238
239#[derive(Clone)]
240pub struct RequestRepository<P, TReq, TRes>
241where
242    P: PoolProvider,
243{
244    pool: P,
245    _phantom_req: std::marker::PhantomData<TReq>,
246    _phantom_res: std::marker::PhantomData<TRes>,
247}
248
249impl<P, TReq, TRes> RequestRepository<P, TReq, TRes>
250where
251    P: PoolProvider,
252    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
253    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
254{
255    pub fn new(pool: P) -> Self {
256        Self {
257            pool,
258            _phantom_req: std::marker::PhantomData,
259            _phantom_res: std::marker::PhantomData,
260        }
261    }
262
263    pub async fn query(
264        &self,
265        filter: RequestFilter,
266    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
267        let rows = filter
268            .build_query()
269            .build()
270            .fetch_all(self.pool.read())
271            .await
272            .map_err(PostgresHandlerError::Query)?;
273
274        let mut pairs = Vec::new();
275        for row in rows {
276            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
277            let req_body_parsed = row
278                .try_get::<Option<bool>, _>("req_body_parsed")
279                .unwrap_or(Some(false));
280
281            let request_body = match req_body {
282                Some(json_value) => {
283                    if req_body_parsed == Some(true) {
284                        // Body was successfully parsed as TReq when stored
285                        Some(Ok(serde_json::from_value::<TReq>(json_value)
286                            .map_err(PostgresHandlerError::Json)?))
287                    } else {
288                        // Body is stored as UTF-8 string (raw content that failed to parse)
289                        if let Value::String(utf8_str) = json_value {
290                            Some(Err(Bytes::from(utf8_str.into_bytes())))
291                        } else {
292                            return Err(PostgresHandlerError::Json(Error::custom(
293                                "Invalid body format",
294                            )));
295                        }
296                    }
297                }
298                None => None,
299            };
300
301            let request = HttpRequest {
302                id: row.get("req_id"),
303                instance_id: row.get("req_instance_id"),
304                correlation_id: row.get("req_correlation_id"),
305                timestamp: row.get("req_timestamp"),
306                method: row.get("method"),
307                uri: row.get("uri"),
308                headers: row.get("req_headers"),
309                body: request_body,
310                created_at: row.get("req_created_at"),
311            };
312
313            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
314                res_id
315                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
316                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
317                        let res_body_parsed = row
318                            .try_get::<Option<bool>, _>("res_body_parsed")
319                            .unwrap_or(Some(false));
320
321                        let response_body = match res_body {
322                            Some(json_value) => {
323                                if res_body_parsed == Some(true) {
324                                    // Body was successfully parsed as TRes when stored
325                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
326                                        .map_err(PostgresHandlerError::Json)?))
327                                } else {
328                                    // Body is stored as UTF-8 string (raw content that failed to parse)
329                                    if let Value::String(utf8_str) = json_value {
330                                        Some(Err(Bytes::from(utf8_str.into_bytes())))
331                                    } else {
332                                        return Err(PostgresHandlerError::Json(Error::custom(
333                                            "Invalid body format",
334                                        )));
335                                    }
336                                }
337                            }
338                            None => None,
339                        };
340
341                        Ok(HttpResponse {
342                            id: row.get("res_id"),
343                            instance_id: row.get("res_instance_id"),
344                            correlation_id: row.get("res_correlation_id"),
345                            timestamp: row.get("res_timestamp"),
346                            status_code: row.get("status_code"),
347                            headers: row.get("res_headers"),
348                            body: response_body,
349                            duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
350                            duration_ms: row.get("duration_ms"),
351                            created_at: row.get("res_created_at"),
352                        })
353                    })
354                    .transpose()?
355            } else {
356                None
357            };
358
359            pairs.push(RequestResponsePair { request, response });
360        }
361
362        Ok(pairs)
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect to make sure the builder emits valid SQL.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let filter = RequestFilter::default();
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_instance_id_filter() {
        let filter = RequestFilter {
            instance_id: Some(Uuid::nil()),
            correlation_id: Some(7),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.instance_id = $1"));
        assert!(sql.contains("AND r.correlation_id = $2"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_correlation_id_filter() {
        let filter = RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let filter = RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let filter = RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let filter = RequestFilter {
            status_code: Some(404),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let filter = RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let before = DateTime::parse_from_rfc3339("2023-12-31T23:59:59Z")
            .unwrap()
            .with_timezone(&Utc);

        let filter = RequestFilter {
            timestamp_after: Some(after),
            timestamp_before: Some(before),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let filter = RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_duration_to_first_byte_filters() {
        let filter = RequestFilter {
            min_duration_to_first_byte_ms: Some(10),
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_to_first_byte_ms >= $1"));
        assert!(sql.contains("AND res.duration_to_first_byte_ms <= $2"));
    }

    #[test]
    fn test_max_duration_to_first_byte_alone_starts_where_clause() {
        // The last filter in `build_query` must still open the WHERE clause on its own.
        let filter = RequestFilter {
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_to_first_byte_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let filter = RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let filter = RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let filter = RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let filter = RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Should not have multiple WHERE clauses
        assert_eq!(sql.matches("WHERE").count(), 1);
        assert!(sql.matches("AND").count() >= 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);

        let filter = RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(after),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();

        // Check all filters are present
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        // Should have exactly one WHERE
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let filter = RequestFilter::default();
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();

        // Should contain base SELECT and JOIN
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        // Should not contain WHERE clause
        assert!(!sql.contains("WHERE"));

        // Should have default ordering
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}