mockforge_recorder/
query.rs1use crate::{database::RecorderDatabase, models::*, Result};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8pub struct QueryFilter {
9 pub protocol: Option<Protocol>,
11 pub method: Option<String>,
13 pub path: Option<String>,
15 pub status_code: Option<i32>,
17 pub trace_id: Option<String>,
19 pub min_duration_ms: Option<i64>,
21 pub max_duration_ms: Option<i64>,
23 pub tags: Option<Vec<String>>,
25 pub limit: Option<i32>,
27 pub offset: Option<i32>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct QueryResult {
34 pub total: i64,
35 pub offset: i32,
36 pub limit: i32,
37 pub exchanges: Vec<RecordedExchange>,
38}
39
40pub async fn execute_query(db: &RecorderDatabase, filter: QueryFilter) -> Result<QueryResult> {
42 let limit = filter.limit.unwrap_or(100);
43 let offset = filter.offset.unwrap_or(0);
44
45 let fetch_window = std::cmp::max(limit + offset, 1000);
48 let requests = db.list_recent(fetch_window).await?;
49
50 let mut filtered: Vec<RecordedRequest> = requests
51 .into_iter()
52 .filter(|request| request_matches_filter(request, &filter))
53 .collect();
54
55 let total = filtered.len() as i64;
56 filtered = filtered.into_iter().skip(offset as usize).take(limit as usize).collect();
57
58 let mut exchanges = Vec::new();
60 for request in filtered {
61 let response = db.get_response(&request.id).await?;
62 exchanges.push(RecordedExchange { request, response });
63 }
64
65 Ok(QueryResult {
66 total,
67 offset,
68 limit,
69 exchanges,
70 })
71}
72
73fn request_matches_filter(request: &RecordedRequest, filter: &QueryFilter) -> bool {
74 if let Some(protocol) = &filter.protocol {
75 if &request.protocol != protocol {
76 return false;
77 }
78 }
79
80 if let Some(method) = &filter.method {
81 if request.method != *method {
82 return false;
83 }
84 }
85
86 if let Some(path_filter) = &filter.path {
87 let request_path = request.path.as_str();
88 if path_filter.contains('*') {
89 let pattern = path_filter.replace('*', "");
90 if !request_path.contains(&pattern) {
91 return false;
92 }
93 } else if request_path != path_filter {
94 return false;
95 }
96 }
97
98 if let Some(status_code) = filter.status_code {
99 if request.status_code != Some(status_code) {
100 return false;
101 }
102 }
103
104 if let Some(trace_id) = &filter.trace_id {
105 if request.trace_id.as_deref() != Some(trace_id.as_str()) {
106 return false;
107 }
108 }
109
110 if let Some(min_duration) = filter.min_duration_ms {
111 let duration = request.duration_ms.unwrap_or_default();
112 if duration < min_duration {
113 return false;
114 }
115 }
116
117 if let Some(max_duration) = filter.max_duration_ms {
118 let duration = request.duration_ms.unwrap_or_default();
119 if duration > max_duration {
120 return false;
121 }
122 }
123
124 if let Some(required_tags) = &filter.tags {
125 let request_tags = request.tags_vec();
126 if required_tags
127 .iter()
128 .any(|required| !request_tags.iter().any(|actual| actual == required))
129 {
130 return false;
131 }
132 }
133
134 true
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the representative HTTP request shared by the filter tests.
    fn sample_request() -> RecordedRequest {
        RecordedRequest {
            id: "req-1".to_string(),
            protocol: Protocol::Http,
            timestamp: chrono::Utc::now(),
            method: "GET".to_string(),
            path: "/api/users/123".to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: None,
            trace_id: Some("trace-1".to_string()),
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: Some(r#"["users","read"]"#.to_string()),
        }
    }

    /// A filter that `sample_request()` satisfies on every criterion.
    fn matching_filter() -> QueryFilter {
        QueryFilter {
            protocol: Some(Protocol::Http),
            method: Some("GET".to_string()),
            path: Some("/api/users/*".to_string()),
            status_code: Some(200),
            trace_id: Some("trace-1".to_string()),
            min_duration_ms: Some(40),
            max_duration_ms: Some(100),
            tags: Some(vec!["users".to_string()]),
            limit: Some(10),
            offset: Some(0),
        }
    }

    #[test]
    fn test_query_filter_creation() {
        let filter = QueryFilter {
            protocol: Some(Protocol::Http),
            method: Some("GET".to_string()),
            path: Some("/api/*".to_string()),
            ..Default::default()
        };

        assert_eq!(filter.protocol, Some(Protocol::Http));
        assert_eq!(filter.method, Some("GET".to_string()));
    }

    #[test]
    fn test_request_matches_filter() {
        assert!(request_matches_filter(&sample_request(), &matching_filter()));
    }

    #[test]
    fn test_default_filter_matches_everything() {
        assert!(request_matches_filter(&sample_request(), &QueryFilter::default()));
    }

    #[test]
    fn test_duration_bounds_are_inclusive() {
        let filter = QueryFilter {
            min_duration_ms: Some(42),
            max_duration_ms: Some(42),
            ..matching_filter()
        };
        assert!(request_matches_filter(&sample_request(), &filter));
    }

    #[test]
    fn test_each_criterion_can_reject() {
        let request = sample_request();
        // Each filter differs from `matching_filter()` in exactly one criterion.
        let rejecting = [
            QueryFilter {
                method: Some("POST".to_string()),
                ..matching_filter()
            },
            QueryFilter {
                path: Some("/other/*".to_string()),
                ..matching_filter()
            },
            QueryFilter {
                path: Some("/api/users".to_string()),
                ..matching_filter()
            },
            QueryFilter {
                status_code: Some(404),
                ..matching_filter()
            },
            QueryFilter {
                trace_id: Some("trace-2".to_string()),
                ..matching_filter()
            },
            QueryFilter {
                min_duration_ms: Some(43),
                ..matching_filter()
            },
            QueryFilter {
                max_duration_ms: Some(41),
                ..matching_filter()
            },
            QueryFilter {
                tags: Some(vec!["users".to_string(), "write".to_string()]),
                ..matching_filter()
            },
        ];

        for filter in &rejecting {
            assert!(
                !request_matches_filter(&request, filter),
                "filter should reject the sample request: {:?}",
                filter
            );
        }
    }
}