Skip to main content

mockforge_recorder/
query.rs

1//! Query API for recorded requests
2
3use crate::{database::RecorderDatabase, models::*, Result};
4use serde::{Deserialize, Serialize};
5
/// Query filter for searching recorded requests.
///
/// Every criterion is optional: a `None` field places no constraint on the
/// results, and all criteria that are set must hold for a request to match.
/// `limit` and `offset` control pagination rather than matching.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QueryFilter {
    /// Filter by protocol (exact match).
    pub protocol: Option<Protocol>,
    /// Filter by HTTP method or gRPC method (exact, case-sensitive match).
    pub method: Option<String>,
    /// Filter by path. Without `*` the path must match exactly; a `*` acts as
    /// a wildcard (see `request_matches_filter` for the matching rules).
    pub path: Option<String>,
    /// Filter by status code (exact match).
    pub status_code: Option<i32>,
    /// Filter by trace ID (exact match).
    pub trace_id: Option<String>,
    /// Filter by minimum duration in milliseconds (inclusive bound).
    pub min_duration_ms: Option<i64>,
    /// Filter by maximum duration in milliseconds (inclusive bound).
    pub max_duration_ms: Option<i64>,
    /// Filter by tags; a request must carry every listed tag.
    pub tags: Option<Vec<String>>,
    /// Maximum number of results to return (`execute_query` defaults to 100).
    pub limit: Option<i32>,
    /// Number of matching results to skip, for pagination
    /// (`execute_query` defaults to 0).
    pub offset: Option<i32>,
}
30
/// Query result: one page of matching exchanges plus pagination metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Number of requests that matched the filter before pagination.
    ///
    /// `execute_query` only searches a window of the most recent requests,
    /// so matches older than that window are not counted.
    pub total: i64,
    /// Number of matching requests skipped before this page.
    pub offset: i32,
    /// Maximum page size used for this query.
    pub limit: i32,
    /// The requests on this page, each paired with the response returned by
    /// `RecorderDatabase::get_response`.
    pub exchanges: Vec<RecordedExchange>,
}
39
40/// Execute a query against the database
41pub async fn execute_query(db: &RecorderDatabase, filter: QueryFilter) -> Result<QueryResult> {
42    let limit = filter.limit.unwrap_or(100);
43    let offset = filter.offset.unwrap_or(0);
44
45    // Fetch a sufficiently large recent window and apply filters in memory.
46    // This avoids the previous placeholder behavior where filters were ignored.
47    let fetch_window = std::cmp::max(limit + offset, 1000);
48    let requests = db.list_recent(fetch_window).await?;
49
50    let mut filtered: Vec<RecordedRequest> = requests
51        .into_iter()
52        .filter(|request| request_matches_filter(request, &filter))
53        .collect();
54
55    let total = filtered.len() as i64;
56    filtered = filtered.into_iter().skip(offset as usize).take(limit as usize).collect();
57
58    // Fetch responses for each request
59    let mut exchanges = Vec::new();
60    for request in filtered {
61        let response = db.get_response(&request.id).await?;
62        exchanges.push(RecordedExchange { request, response });
63    }
64
65    Ok(QueryResult {
66        total,
67        offset,
68        limit,
69        exchanges,
70    })
71}
72
73fn request_matches_filter(request: &RecordedRequest, filter: &QueryFilter) -> bool {
74    if let Some(protocol) = &filter.protocol {
75        if &request.protocol != protocol {
76            return false;
77        }
78    }
79
80    if let Some(method) = &filter.method {
81        if request.method != *method {
82            return false;
83        }
84    }
85
86    if let Some(path_filter) = &filter.path {
87        let request_path = request.path.as_str();
88        if path_filter.contains('*') {
89            let pattern = path_filter.replace('*', "");
90            if !request_path.contains(&pattern) {
91                return false;
92            }
93        } else if request_path != path_filter {
94            return false;
95        }
96    }
97
98    if let Some(status_code) = filter.status_code {
99        if request.status_code != Some(status_code) {
100            return false;
101        }
102    }
103
104    if let Some(trace_id) = &filter.trace_id {
105        if request.trace_id.as_deref() != Some(trace_id.as_str()) {
106            return false;
107        }
108    }
109
110    if let Some(min_duration) = filter.min_duration_ms {
111        let duration = request.duration_ms.unwrap_or_default();
112        if duration < min_duration {
113            return false;
114        }
115    }
116
117    if let Some(max_duration) = filter.max_duration_ms {
118        let duration = request.duration_ms.unwrap_or_default();
119        if duration > max_duration {
120            return false;
121        }
122    }
123
124    if let Some(required_tags) = &filter.tags {
125        let request_tags = request.tags_vec();
126        if required_tags
127            .iter()
128            .any(|required| !request_tags.iter().any(|actual| actual == required))
129        {
130            return false;
131        }
132    }
133
134    true
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the representative HTTP request shared by the matching tests.
    fn sample_request() -> RecordedRequest {
        RecordedRequest {
            id: "req-1".to_string(),
            protocol: Protocol::Http,
            timestamp: chrono::Utc::now(),
            method: "GET".to_string(),
            path: "/api/users/123".to_string(),
            query_params: None,
            headers: "{}".to_string(),
            body: None,
            body_encoding: "utf8".to_string(),
            client_ip: None,
            trace_id: Some("trace-1".to_string()),
            span_id: None,
            duration_ms: Some(42),
            status_code: Some(200),
            tags: Some(r#"["users","read"]"#.to_string()),
        }
    }

    #[test]
    fn test_query_filter_creation() {
        let filter = QueryFilter {
            protocol: Some(Protocol::Http),
            method: Some("GET".to_string()),
            path: Some("/api/*".to_string()),
            ..Default::default()
        };

        assert_eq!(filter.protocol, Some(Protocol::Http));
        assert_eq!(filter.method, Some("GET".to_string()));
    }

    #[test]
    fn test_request_matches_filter() {
        let request = sample_request();

        let filter = QueryFilter {
            protocol: Some(Protocol::Http),
            method: Some("GET".to_string()),
            path: Some("/api/users/*".to_string()),
            status_code: Some(200),
            trace_id: Some("trace-1".to_string()),
            min_duration_ms: Some(40),
            max_duration_ms: Some(100),
            tags: Some(vec!["users".to_string()]),
            limit: Some(10),
            offset: Some(0),
        };

        assert!(request_matches_filter(&request, &filter));
    }

    #[test]
    fn test_empty_filter_matches_any_request() {
        assert!(request_matches_filter(
            &sample_request(),
            &QueryFilter::default()
        ));
    }

    #[test]
    fn test_request_rejected_by_each_mismatched_criterion() {
        let request = sample_request();
        let rejecting = [
            QueryFilter {
                method: Some("POST".to_string()),
                ..Default::default()
            },
            QueryFilter {
                path: Some("/api/users".to_string()),
                ..Default::default()
            },
            QueryFilter {
                status_code: Some(404),
                ..Default::default()
            },
            QueryFilter {
                trace_id: Some("trace-2".to_string()),
                ..Default::default()
            },
            QueryFilter {
                min_duration_ms: Some(50),
                ..Default::default()
            },
            QueryFilter {
                max_duration_ms: Some(10),
                ..Default::default()
            },
            QueryFilter {
                tags: Some(vec!["users".to_string(), "write".to_string()]),
                ..Default::default()
            },
        ];

        for filter in &rejecting {
            assert!(
                !request_matches_filter(&request, filter),
                "filter should reject the request: {:?}",
                filter
            );
        }
    }
}