1use base64::Engine;
2use bytes::Bytes;
3use chrono::{DateTime, Utc};
4use serde::de::Error;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use sqlx::{PgPool, QueryBuilder, Row};
8
9use crate::error::PostgresHandlerError;
10
11#[derive(Debug)]
12pub struct HttpRequest<TReq> {
13 pub id: i64,
14 pub correlation_id: i64,
15 pub timestamp: DateTime<Utc>,
16 pub method: String,
17 pub uri: String,
18 pub headers: Value,
19 pub body: Option<Result<TReq, Bytes>>,
20 pub created_at: DateTime<Utc>,
21}
22
23#[derive(Debug)]
24pub struct HttpResponse<TRes> {
25 pub id: i64,
26 pub correlation_id: i64,
27 pub timestamp: DateTime<Utc>,
28 pub status_code: i32,
29 pub headers: Value,
30 pub body: Option<Result<TRes, Bytes>>,
31 pub duration_ms: i64,
32 pub created_at: DateTime<Utc>,
33}
34
35#[derive(Debug)]
36pub struct RequestResponsePair<TReq, TRes> {
37 pub request: HttpRequest<TReq>,
38 pub response: Option<HttpResponse<TRes>>,
39}
40
41#[derive(Debug, Default)]
42pub struct RequestFilter {
43 pub correlation_id: Option<i64>,
44 pub method: Option<String>,
45 pub uri_pattern: Option<String>,
46 pub status_code: Option<i32>,
47 pub status_code_min: Option<i32>,
48 pub status_code_max: Option<i32>,
49 pub timestamp_after: Option<DateTime<Utc>>,
50 pub timestamp_before: Option<DateTime<Utc>>,
51 pub min_duration_ms: Option<i64>,
52 pub max_duration_ms: Option<i64>,
53 pub body_parsed: Option<bool>,
54 pub limit: Option<i64>,
55 pub offset: Option<i64>,
56 pub order_by_timestamp_desc: bool,
57}
58
59#[derive(Clone)]
60pub struct RequestRepository<TReq, TRes> {
61 pool: PgPool,
62 _phantom_req: std::marker::PhantomData<TReq>,
63 _phantom_res: std::marker::PhantomData<TRes>,
64}
65
66impl<TReq, TRes> RequestRepository<TReq, TRes>
67where
68 TReq: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
69 TRes: for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
70{
71 pub fn new(pool: PgPool) -> Self {
72 Self {
73 pool,
74 _phantom_req: std::marker::PhantomData,
75 _phantom_res: std::marker::PhantomData,
76 }
77 }
78
79 pub async fn query(
80 &self,
81 filter: RequestFilter,
82 ) -> Result<Vec<RequestResponsePair<TReq, TRes>>, PostgresHandlerError> {
83 let mut query = QueryBuilder::new(
84 r#"
85 SELECT
86 r.id as req_id, r.correlation_id as req_correlation_id, r.timestamp as req_timestamp,
87 r.method, r.uri, r.headers as req_headers, r.body as req_body, r.body_parsed as req_body_parsed, r.created_at as req_created_at,
88 res.id as res_id, res.correlation_id as res_correlation_id, res.timestamp as res_timestamp,
89 res.status_code, res.headers as res_headers, res.body as res_body, res.body_parsed as res_body_parsed, res.duration_ms, res.created_at as res_created_at
90 FROM http_requests r
91 LEFT JOIN http_responses res ON r.correlation_id = res.correlation_id
92 "#,
93 );
94
95 let mut where_added = false;
96
97 if let Some(correlation_id) = filter.correlation_id {
98 query.push(" WHERE r.correlation_id = ");
99 query.push_bind(correlation_id);
100 where_added = true;
101 }
102
103 if let Some(method) = &filter.method {
104 if where_added {
105 query.push(" AND ");
106 } else {
107 query.push(" WHERE ");
108 where_added = true;
109 }
110 query.push("r.method = ");
111 query.push_bind(method);
112 }
113
114 if let Some(uri_pattern) = &filter.uri_pattern {
115 if where_added {
116 query.push(" AND ");
117 } else {
118 query.push(" WHERE ");
119 where_added = true;
120 }
121 query.push("r.uri ILIKE ");
122 query.push_bind(uri_pattern);
123 }
124
125 if let Some(status_code) = filter.status_code {
126 if where_added {
127 query.push(" AND ");
128 } else {
129 query.push(" WHERE ");
130 where_added = true;
131 }
132 query.push("res.status_code = ");
133 query.push_bind(status_code);
134 }
135
136 if let Some(min_status) = filter.status_code_min {
137 if where_added {
138 query.push(" AND ");
139 } else {
140 query.push(" WHERE ");
141 where_added = true;
142 }
143 query.push("res.status_code >= ");
144 query.push_bind(min_status);
145 }
146
147 if let Some(max_status) = filter.status_code_max {
148 if where_added {
149 query.push(" AND ");
150 } else {
151 query.push(" WHERE ");
152 where_added = true;
153 }
154 query.push("res.status_code <= ");
155 query.push_bind(max_status);
156 }
157
158 if let Some(timestamp_after) = filter.timestamp_after {
159 if where_added {
160 query.push(" AND ");
161 } else {
162 query.push(" WHERE ");
163 where_added = true;
164 }
165 query.push("r.timestamp >= ");
166 query.push_bind(timestamp_after);
167 }
168
169 if let Some(timestamp_before) = filter.timestamp_before {
170 if where_added {
171 query.push(" AND ");
172 } else {
173 query.push(" WHERE ");
174 where_added = true;
175 }
176 query.push("r.timestamp <= ");
177 query.push_bind(timestamp_before);
178 }
179
180 if let Some(min_duration) = filter.min_duration_ms {
181 if where_added {
182 query.push(" AND ");
183 } else {
184 query.push(" WHERE ");
185 where_added = true;
186 }
187 query.push("res.duration_ms >= ");
188 query.push_bind(min_duration);
189 }
190
191 if let Some(max_duration) = filter.max_duration_ms {
192 if where_added {
193 query.push(" AND ");
194 } else {
195 query.push(" WHERE ");
196 }
197 query.push("res.duration_ms <= ");
198 query.push_bind(max_duration);
199 }
200
201 if filter.order_by_timestamp_desc {
202 query.push(" ORDER BY r.timestamp DESC");
203 } else {
204 query.push(" ORDER BY r.timestamp ASC");
205 }
206
207 if let Some(limit) = filter.limit {
208 query.push(" LIMIT ");
209 query.push_bind(limit);
210 }
211
212 if let Some(offset) = filter.offset {
213 query.push(" OFFSET ");
214 query.push_bind(offset);
215 }
216
217 let rows = query
218 .build()
219 .fetch_all(&self.pool)
220 .await
221 .map_err(PostgresHandlerError::Query)?;
222
223 let mut pairs = Vec::new();
224 for row in rows {
225 let req_body = row.try_get::<Option<Value>, _>("req_body").unwrap_or(None);
226 let req_body_parsed = row
227 .try_get::<Option<bool>, _>("req_body_parsed")
228 .unwrap_or(Some(false));
229
230 let request_body = match req_body {
231 Some(json_value) => {
232 if req_body_parsed == Some(true) {
233 Some(Ok(serde_json::from_value::<TReq>(json_value)
235 .map_err(PostgresHandlerError::Json)?))
236 } else {
237 if let Value::String(base64_str) = json_value {
239 let decoded_bytes = base64::engine::general_purpose::STANDARD
240 .decode(&base64_str)
241 .map_err(|_| {
242 PostgresHandlerError::Json(Error::custom(
243 "Failed to decode base64",
244 ))
245 })?;
246 Some(Err(Bytes::from(decoded_bytes)))
247 } else {
248 return Err(PostgresHandlerError::Json(Error::custom(
249 "Invalid body format",
250 )));
251 }
252 }
253 }
254 None => None,
255 };
256
257 let request = HttpRequest {
258 id: row.get("req_id"),
259 correlation_id: row.get("req_correlation_id"),
260 timestamp: row.get("req_timestamp"),
261 method: row.get("method"),
262 uri: row.get("uri"),
263 headers: row.get("req_headers"),
264 body: request_body,
265 created_at: row.get("req_created_at"),
266 };
267
268 let response = if let Ok(res_id) = row.try_get::<Option<i64>, _>("res_id") {
269 res_id
270 .map(|_| -> Result<HttpResponse<TRes>, PostgresHandlerError> {
271 let res_body = row.try_get::<Option<Value>, _>("res_body").unwrap_or(None);
272 let res_body_parsed = row
273 .try_get::<Option<bool>, _>("res_body_parsed")
274 .unwrap_or(Some(false));
275
276 let response_body = match res_body {
277 Some(json_value) => {
278 if res_body_parsed == Some(true) {
279 Some(Ok(serde_json::from_value::<TRes>(json_value)
281 .map_err(PostgresHandlerError::Json)?))
282 } else {
283 if let Value::String(base64_str) = json_value {
285 let decoded_bytes =
286 base64::engine::general_purpose::STANDARD
287 .decode(&base64_str)
288 .map_err(|_| {
289 PostgresHandlerError::Json(Error::custom(
290 "Failed to decode base64",
291 ))
292 })?;
293 Some(Err(Bytes::from(decoded_bytes)))
294 } else {
295 return Err(PostgresHandlerError::Json(Error::custom(
296 "Invalid body format",
297 )));
298 }
299 }
300 }
301 None => None,
302 };
303
304 Ok(HttpResponse {
305 id: row.get("res_id"),
306 correlation_id: row.get("res_correlation_id"),
307 timestamp: row.get("res_timestamp"),
308 status_code: row.get("status_code"),
309 headers: row.get("res_headers"),
310 body: response_body,
311 duration_ms: row.get("duration_ms"),
312 created_at: row.get("res_created_at"),
313 })
314 })
315 .transpose()?
316 } else {
317 None
318 };
319
320 pairs.push(RequestResponsePair { request, response });
321 }
322
323 Ok(pairs)
324 }
325}