// outlet_postgres/repository.rs

use bytes::Bytes;
use chrono::{DateTime, Utc};
use serde::de::Error;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::postgres::PgRow;
use sqlx::{QueryBuilder, Row};
use sqlx_pool_router::PoolProvider;
use uuid::Uuid;

use crate::error::PostgresHandlerError;

12#[derive(Debug)]
13pub struct HttpRequest<TReq> {
14    pub id: i64,
15    pub instance_id: Uuid,
16    pub correlation_id: i64,
17    pub timestamp: DateTime<Utc>,
18    pub method: String,
19    pub uri: String,
20    pub headers: Value,
21    pub body: Option<Result<TReq, Bytes>>,
22    pub created_at: DateTime<Utc>,
23    pub trace_id: Option<String>,
24    pub span_id: Option<String>,
25}
26
27#[derive(Debug)]
28pub struct HttpResponse<TRes> {
29    pub id: i64,
30    pub instance_id: Uuid,
31    pub correlation_id: i64,
32    pub timestamp: DateTime<Utc>,
33    pub status_code: i32,
34    pub headers: Value,
35    pub body: Option<Result<TRes, Bytes>>,
36    pub duration_to_first_byte_ms: i64,
37    pub duration_ms: i64,
38    pub created_at: DateTime<Utc>,
39}
40
41#[derive(Debug)]
42pub struct RequestResponsePair<TReq, TRes> {
43    pub request: HttpRequest<TReq>,
44    pub response: Option<HttpResponse<TRes>>,
45}
46
47#[derive(Debug, Default)]
48pub struct RequestFilter {
49    pub instance_id: Option<Uuid>,
50    pub correlation_id: Option<i64>,
51    pub method: Option<String>,
52    pub uri_pattern: Option<String>,
53    pub status_code: Option<i32>,
54    pub status_code_min: Option<i32>,
55    pub status_code_max: Option<i32>,
56    pub timestamp_after: Option<DateTime<Utc>>,
57    pub timestamp_before: Option<DateTime<Utc>>,
58    pub min_duration_ms: Option<i64>,
59    pub max_duration_ms: Option<i64>,
60    pub min_duration_to_first_byte_ms: Option<i64>,
61    pub max_duration_to_first_byte_ms: Option<i64>,
62    pub body_parsed: Option<bool>,
63    pub limit: Option<i64>,
64    pub offset: Option<i64>,
65    pub order_by_timestamp_desc: bool,
66}
67
68impl RequestFilter {
69    pub fn build_query(&self) -> QueryBuilder<'_, sqlx::Postgres> {
70        let mut query = QueryBuilder::new(
71            r#"
72            SELECT 
73                r.id as req_id, r.instance_id as req_instance_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp, 
74                r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
75                r.trace_id, r.span_id,
76                res.id as res_id, res.instance_id as res_instance_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
77                res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_to_first_byte_ms, res.duration_ms, res.created_at as res_created_at
78            FROM http_requests r
79            LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)
80            "#,
81        );
82
83        let mut where_added = false;
84
85        if let Some(instance_id) = self.instance_id {
86            query.push(" WHERE r.instance_id = ");
87            query.push_bind(instance_id);
88            where_added = true;
89        }
90
91        if let Some(correlation_id) = self.correlation_id {
92            if where_added {
93                query.push(" AND ");
94            } else {
95                query.push(" WHERE ");
96                where_added = true;
97            }
98            query.push("r.correlation_id = ");
99            query.push_bind(correlation_id);
100        }
101
102        if let Some(method) = &self.method {
103            if where_added {
104                query.push(" AND ");
105            } else {
106                query.push(" WHERE ");
107                where_added = true;
108            }
109            query.push("r.method = ");
110            query.push_bind(method);
111        }
112
113        if let Some(uri_pattern) = &self.uri_pattern {
114            if where_added {
115                query.push(" AND ");
116            } else {
117                query.push(" WHERE ");
118                where_added = true;
119            }
120            query.push("r.uri ILIKE ");
121            query.push_bind(uri_pattern);
122        }
123
124        if let Some(status_code) = self.status_code {
125            if where_added {
126                query.push(" AND ");
127            } else {
128                query.push(" WHERE ");
129                where_added = true;
130            }
131            query.push("res.status_code = ");
132            query.push_bind(status_code);
133        }
134
135        if let Some(min_status) = self.status_code_min {
136            if where_added {
137                query.push(" AND ");
138            } else {
139                query.push(" WHERE ");
140                where_added = true;
141            }
142            query.push("res.status_code >= ");
143            query.push_bind(min_status);
144        }
145
146        if let Some(max_status) = self.status_code_max {
147            if where_added {
148                query.push(" AND ");
149            } else {
150                query.push(" WHERE ");
151                where_added = true;
152            }
153            query.push("res.status_code <= ");
154            query.push_bind(max_status);
155        }
156
157        if let Some(timestamp_after) = self.timestamp_after {
158            if where_added {
159                query.push(" AND ");
160            } else {
161                query.push(" WHERE ");
162                where_added = true;
163            }
164            query.push("r.timestamp >= ");
165            query.push_bind(timestamp_after);
166        }
167
168        if let Some(timestamp_before) = self.timestamp_before {
169            if where_added {
170                query.push(" AND ");
171            } else {
172                query.push(" WHERE ");
173                where_added = true;
174            }
175            query.push("r.timestamp <= ");
176            query.push_bind(timestamp_before);
177        }
178
179        if let Some(min_duration) = self.min_duration_ms {
180            if where_added {
181                query.push(" AND ");
182            } else {
183                query.push(" WHERE ");
184                where_added = true;
185            }
186            query.push("res.duration_ms >= ");
187            query.push_bind(min_duration);
188        }
189
190        if let Some(max_duration) = self.max_duration_ms {
191            if where_added {
192                query.push(" AND ");
193            } else {
194                query.push(" WHERE ");
195                where_added = true;
196            }
197            query.push("res.duration_ms <= ");
198            query.push_bind(max_duration);
199        }
200
201        if let Some(min_duration_to_first_byte) = self.min_duration_to_first_byte_ms {
202            if where_added {
203                query.push(" AND ");
204            } else {
205                query.push(" WHERE ");
206                where_added = true;
207            }
208            query.push("res.duration_to_first_byte_ms >= ");
209            query.push_bind(min_duration_to_first_byte);
210        }
211
212        if let Some(max_duration_to_first_byte) = self.max_duration_to_first_byte_ms {
213            if where_added {
214                query.push(" AND ");
215            } else {
216                query.push(" WHERE ");
217            }
218            query.push("res.duration_to_first_byte_ms <= ");
219            query.push_bind(max_duration_to_first_byte);
220        }
221
222        if self.order_by_timestamp_desc {
223            query.push(" ORDER BY r.timestamp DESC");
224        } else {
225            query.push(" ORDER BY r.timestamp ASC");
226        }
227
228        if let Some(limit) = self.limit {
229            query.push(" LIMIT ");
230            query.push_bind(limit);
231        }
232
233        if let Some(offset) = self.offset {
234            query.push(" OFFSET ");
235            query.push_bind(offset);
236        }
237
238        query
239    }
240}
241
242#[derive(Clone)]
243pub struct RequestRepository<P, TReq, TRes>
244where
245    P: PoolProvider,
246{
247    pool: P,
248    _phantom_req: std::marker::PhantomData<TReq>,
249    _phantom_res: std::marker::PhantomData<TRes>,
250}
251
252impl<P, TReq, TRes> RequestRepository<P, TReq, TRes>
253where
254    P: PoolProvider,
255    TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
256    TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
257{
258    pub fn new(pool: P) -> Self {
259        Self {
260            pool,
261            _phantom_req: std::marker::PhantomData,
262            _phantom_res: std::marker::PhantomData,
263        }
264    }
265
266    pub async fn query(
267        &self,
268        filter: RequestFilter,
269    ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
270        let rows = filter
271            .build_query()
272            .build()
273            .fetch_all(self.pool.read())
274            .await
275            .map_err(PostgresHandlerError::Query)?;
276
277        let mut pairs = Vec::new();
278        for row in rows {
279            let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
280            let req_body_parsed = row
281                .try_get::<Option<bool>, _>("req_body_parsed")
282                .unwrap_or(Some(false));
283
284            let request_body = match req_body {
285                Some(json_value) => {
286                    if req_body_parsed == Some(true) {
287                        // Body was successfully parsed as TReq when stored
288                        Some(Ok(serde_json::from_value::<TReq>(json_value)
289                            .map_err(PostgresHandlerError::Json)?))
290                    } else {
291                        // Body is stored as UTF-8 string (raw content that failed to parse)
292                        if let Value::String(utf8_str) = json_value {
293                            Some(Err(Bytes::from(utf8_str.into_bytes())))
294                        } else {
295                            return Err(PostgresHandlerError::Json(Error::custom(
296                                "Invalid body format",
297                            )));
298                        }
299                    }
300                }
301                None => None,
302            };
303
304            let request = HttpRequest {
305                id: row.get("req_id"),
306                instance_id: row.get("req_instance_id"),
307                correlation_id: row.get("req_correlation_id"),
308                timestamp: row.get("req_timestamp"),
309                method: row.get("method"),
310                uri: row.get("uri"),
311                headers: row.get("req_headers"),
312                body: request_body,
313                created_at: row.get("req_created_at"),
314                trace_id: row.try_get("trace_id").unwrap_or(None),
315                span_id: row.try_get("span_id").unwrap_or(None),
316            };
317
318            let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
319                res_id
320                    .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
321                        let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
322                        let res_body_parsed = row
323                            .try_get::<Option<bool>, _>("res_body_parsed")
324                            .unwrap_or(Some(false));
325
326                        let response_body = match res_body {
327                            Some(json_value) => {
328                                if res_body_parsed == Some(true) {
329                                    // Body was successfully parsed as TRes when stored
330                                    Some(Ok(serde_json::from_value::<TRes>(json_value)
331                                        .map_err(PostgresHandlerError::Json)?))
332                                } else {
333                                    // Body is stored as UTF-8 string (raw content that failed to parse)
334                                    if let Value::String(utf8_str) = json_value {
335                                        Some(Err(Bytes::from(utf8_str.into_bytes())))
336                                    } else {
337                                        return Err(PostgresHandlerError::Json(Error::custom(
338                                            "Invalid body format",
339                                        )));
340                                    }
341                                }
342                            }
343                            None => None,
344                        };
345
346                        Ok(HttpResponse {
347                            id: row.get("res_id"),
348                            instance_id: row.get("res_instance_id"),
349                            correlation_id: row.get("res_correlation_id"),
350                            timestamp: row.get("res_timestamp"),
351                            status_code: row.get("status_code"),
352                            headers: row.get("res_headers"),
353                            body: response_body,
354                            duration_to_first_byte_ms: row.get("duration_to_first_byte_ms"),
355                            duration_ms: row.get("duration_ms"),
356                            created_at: row.get("res_created_at"),
357                        })
358                    })
359                    .transpose()?
360            } else {
361                None
362            };
363
364            pairs.push(RequestResponsePair { request, response });
365        }
366
367        Ok(pairs)
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    /// Parses `sql` with the PostgreSQL dialect, returning the parse error as text.
    fn validate_sql(sql: &str) -> Result<(), String> {
        let dialect = PostgreSqlDialect {};
        Parser::parse_sql(&dialect, sql)
            .map_err(|e| format!("SQL parse error: {e}"))
            .map(|_| ())
    }

    #[test]
    fn test_default_filter_generates_valid_sql() {
        let filter = RequestFilter::default();
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_instance_id_filter() {
        let filter = RequestFilter {
            instance_id: Some(Uuid::nil()),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.instance_id = $1"));
    }

    #[test]
    fn test_instance_and_correlation_filters_share_one_where() {
        let filter = RequestFilter {
            instance_id: Some(Uuid::nil()),
            correlation_id: Some(7),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.instance_id = $1"));
        assert!(sql.contains("AND r.correlation_id = $2"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_correlation_id_filter() {
        let filter = RequestFilter {
            correlation_id: Some(123),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.correlation_id = $1"));
    }

    #[test]
    fn test_method_filter() {
        let filter = RequestFilter {
            method: Some("POST".to_string()),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.method = $1"));
    }

    #[test]
    fn test_uri_pattern_filter() {
        let filter = RequestFilter {
            uri_pattern: Some("/api/%".to_string()),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.uri ILIKE $1"));
    }

    #[test]
    fn test_status_code_exact_filter() {
        let filter = RequestFilter {
            status_code: Some(404),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.status_code = $1"));
    }

    #[test]
    fn test_status_code_range_filters() {
        let filter = RequestFilter {
            status_code_min: Some(400),
            status_code_max: Some(499),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.status_code >= $1"));
        assert!(sql.contains("AND res.status_code <= $2"));
    }

    #[test]
    fn test_timestamp_filters() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let before = DateTime::parse_from_rfc3339("2023-12-31T23:59:59Z")
            .unwrap()
            .with_timezone(&Utc);

        let filter = RequestFilter {
            timestamp_after: Some(after),
            timestamp_before: Some(before),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.timestamp >= $1"));
        assert!(sql.contains("AND r.timestamp <= $2"));
    }

    #[test]
    fn test_duration_filters() {
        let filter = RequestFilter {
            min_duration_ms: Some(100),
            max_duration_ms: Some(5000),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_ms >= $1"));
        assert!(sql.contains("AND res.duration_ms <= $2"));
    }

    #[test]
    fn test_duration_to_first_byte_filters() {
        let filter = RequestFilter {
            min_duration_to_first_byte_ms: Some(10),
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_to_first_byte_ms >= $1"));
        assert!(sql.contains("AND res.duration_to_first_byte_ms <= $2"));
    }

    #[test]
    fn test_max_duration_to_first_byte_alone_opens_where() {
        let filter = RequestFilter {
            max_duration_to_first_byte_ms: Some(250),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE res.duration_to_first_byte_ms <= $1"));
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_ordering_desc() {
        let filter = RequestFilter {
            order_by_timestamp_desc: true,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
    }

    #[test]
    fn test_ordering_asc() {
        let filter = RequestFilter {
            order_by_timestamp_desc: false,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }

    #[test]
    fn test_pagination() {
        let filter = RequestFilter {
            limit: Some(10),
            offset: Some(20),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("LIMIT $1"));
        assert!(sql.contains("OFFSET $2"));
    }

    #[test]
    fn test_pagination_binds_follow_filter_binds() {
        let filter = RequestFilter {
            method: Some("GET".to_string()),
            limit: Some(5),
            offset: Some(10),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.method = $1"));
        assert!(sql.contains("LIMIT $2"));
        assert!(sql.contains("OFFSET $3"));
    }

    #[test]
    fn test_multiple_filters_use_and() {
        let filter = RequestFilter {
            correlation_id: Some(123),
            method: Some("POST".to_string()),
            status_code: Some(200),
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND res.status_code = $3"));

        // Should not have multiple WHERE clauses
        assert_eq!(sql.matches("WHERE").count(), 1);

        // Count only connectives after WHERE: the JOIN condition has its own AND,
        // which would otherwise satisfy a loose `>= 2` check on its own.
        let (_, conditions) = sql.split_once("WHERE").unwrap();
        assert_eq!(conditions.matches(" AND ").count(), 2);
    }

    #[test]
    fn test_complex_filter_combination() {
        let after = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);

        let filter = RequestFilter {
            correlation_id: Some(456),
            method: Some("GET".to_string()),
            uri_pattern: Some("/api/users%".to_string()),
            status_code_min: Some(200),
            status_code_max: Some(299),
            timestamp_after: Some(after),
            min_duration_ms: Some(50),
            max_duration_ms: Some(1000),
            limit: Some(100),
            offset: Some(0),
            order_by_timestamp_desc: true,
            ..Default::default()
        };
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();

        // Check all filters are present
        assert!(sql.contains("WHERE r.correlation_id = $1"));
        assert!(sql.contains("AND r.method = $2"));
        assert!(sql.contains("AND r.uri ILIKE $3"));
        assert!(sql.contains("AND res.status_code >= $4"));
        assert!(sql.contains("AND res.status_code <= $5"));
        assert!(sql.contains("AND r.timestamp >= $6"));
        assert!(sql.contains("AND res.duration_ms >= $7"));
        assert!(sql.contains("AND res.duration_ms <= $8"));
        assert!(sql.contains("ORDER BY r.timestamp DESC"));
        assert!(sql.contains("LIMIT $9"));
        assert!(sql.contains("OFFSET $10"));

        // Should have exactly one WHERE
        assert_eq!(sql.matches("WHERE").count(), 1);
    }

    #[test]
    fn test_no_filters_only_has_base_query() {
        let filter = RequestFilter::default();
        let query = filter.build_query();
        let sql = query.sql();

        validate_sql(sql).unwrap();

        // Should contain base SELECT and JOIN
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("FROM http_requests r"));
        assert!(
            sql.contains("LEFT JOIN http_responses res ON (r.instance_id = res.instance_id AND r.correlation_id = res.correlation_id)")
        );

        // Should not contain WHERE clause
        assert!(!sql.contains("WHERE"));

        // Should have default ordering
        assert!(sql.contains("ORDER BY r.timestamp ASC"));
    }
}