1use std::collections::HashMap;
4use std::sync::Arc;
5
6use serde::Serialize;
7use tokio::sync::RwLock;
8
/// Returns `true` when `text` matches the glob `pattern` in its entirety.
///
/// Supported syntax:
/// - `*` matches any run of bytes, including the empty run;
/// - `?` matches exactly one byte;
/// - any other byte matches only itself.
///
/// Matching is byte-oriented: `?` consumes a single byte, even in the middle
/// of a multi-byte UTF-8 character. The greedy single-backtrack algorithm
/// used here allocates nothing and runs in O(|pattern| * |text|) time in the
/// worst case.
fn glob_matches(pattern: &str, text: &str) -> bool {
    let pat = pattern.as_bytes();
    let txt = text.as_bytes();
    let (mut pi, mut ti) = (0usize, 0usize);
    // Most recent `*` seen in the pattern, paired with the text offset it
    // currently stops at. On a mismatch we return here and let that `*`
    // swallow one more byte. `None` until the first `*` is reached.
    let mut star: Option<(usize, usize)> = None;

    while let Some(&t) = txt.get(ti) {
        match pat.get(pi) {
            // The wildcard must be tested before the literal comparison:
            // otherwise a literal `*` in the text would consume the pattern's
            // `*` as an ordinary byte and disable backtracking.
            Some(b'*') => {
                star = Some((pi, ti));
                pi += 1;
            }
            Some(&p) if p == b'?' || p == t => {
                pi += 1;
                ti += 1;
            }
            _ => match star.as_mut() {
                Some((star_pi, star_ti)) => {
                    *star_ti += 1;
                    pi = *star_pi + 1;
                    ti = *star_ti;
                }
                None => return false,
            },
        }
    }

    // The text is exhausted; whatever is left of the pattern must consist of
    // `*`s only, each of which can match the empty run.
    pat.iter().skip(pi).all(|&b| b == b'*')
}
39
40#[derive(Clone)]
42pub struct MockHttp {
43 mocks: Arc<RwLock<Vec<MockHandler>>>,
44 requests: Arc<RwLock<Vec<RecordedRequest>>>,
45}
46
47struct MockHandler {
49 pattern: String,
50 handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
51}
52
53#[derive(Debug, Clone)]
55pub struct RecordedRequest {
56 pub method: String,
58 pub url: String,
60 pub headers: HashMap<String, String>,
62 pub body: serde_json::Value,
64}
65
66#[derive(Debug, Clone)]
68pub struct MockRequest {
69 pub method: String,
71 pub path: String,
73 pub url: String,
75 pub headers: HashMap<String, String>,
77 pub body: serde_json::Value,
79}
80
81#[derive(Debug, Clone)]
83pub struct MockResponse {
84 pub status: u16,
86 pub headers: HashMap<String, String>,
88 pub body: serde_json::Value,
90}
91
92impl MockResponse {
93 pub fn json<T: Serialize>(body: T) -> Self {
95 Self {
96 status: 200,
97 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
98 body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
99 }
100 }
101
102 pub fn error(status: u16, message: &str) -> Self {
104 Self {
105 status,
106 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
107 body: serde_json::json!({ "error": message }),
108 }
109 }
110
111 pub fn internal_error(message: &str) -> Self {
113 Self::error(500, message)
114 }
115
116 pub fn not_found(message: &str) -> Self {
118 Self::error(404, message)
119 }
120
121 pub fn unauthorized(message: &str) -> Self {
123 Self::error(401, message)
124 }
125
126 pub fn ok() -> Self {
128 Self::json(serde_json::json!({}))
129 }
130}
131
132impl MockHttp {
133 pub fn new() -> Self {
135 Self {
136 mocks: Arc::new(RwLock::new(Vec::new())),
137 requests: Arc::new(RwLock::new(Vec::new())),
138 }
139 }
140
141 pub fn add_mock<F>(&mut self, pattern: &str, handler: F)
143 where
144 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
145 {
146 let mocks = self.mocks.clone();
147 let pattern = pattern.to_string();
148 let handler = Arc::new(handler);
149 tokio::task::block_in_place(|| {
150 let rt = tokio::runtime::Handle::try_current();
151 if let Ok(rt) = rt {
152 rt.block_on(async {
153 let mut mocks = mocks.write().await;
154 mocks.push(MockHandler { pattern, handler });
155 });
156 }
157 });
158 }
159
160 pub async fn execute(&self, request: MockRequest) -> MockResponse {
162 {
164 let mut requests = self.requests.write().await;
165 requests.push(RecordedRequest {
166 method: request.method.clone(),
167 url: request.url.clone(),
168 headers: request.headers.clone(),
169 body: request.body.clone(),
170 });
171 }
172
173 let mocks = self.mocks.read().await;
175 for mock in mocks.iter() {
176 if glob_matches(&mock.pattern, &request.url)
177 || glob_matches(&mock.pattern, &request.path)
178 {
179 return (mock.handler)(&request);
180 }
181 }
182
183 MockResponse::error(500, &format!("No mock found for {}", request.url))
185 }
186
187 pub async fn requests(&self) -> Vec<RecordedRequest> {
189 self.requests.read().await.clone()
190 }
191
192 pub async fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
194 self.requests
195 .read()
196 .await
197 .iter()
198 .filter(|r| glob_matches(pattern, &r.url))
199 .cloned()
200 .collect()
201 }
202
203 pub async fn clear_requests(&self) {
205 self.requests.write().await.clear();
206 }
207
208 pub async fn clear_mocks(&self) {
210 self.mocks.write().await.clear();
211 }
212}
213
214impl Default for MockHttp {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220type MockHandlerFn = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
222
223pub struct MockHttpBuilder {
225 mocks: Vec<(String, MockHandlerFn)>,
226}
227
228impl MockHttpBuilder {
229 pub fn new() -> Self {
231 Self { mocks: Vec::new() }
232 }
233
234 pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
236 where
237 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
238 {
239 self.mocks.push((pattern.to_string(), Box::new(handler)));
240 self
241 }
242
243 pub fn build(self) -> MockHttp {
245 MockHttp::new()
247 }
248}
249
250impl Default for MockHttpBuilder {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Plain `GET https://example.com/test` request with no headers or body,
    /// shared by several tests below.
    fn get_test_request() -> MockRequest {
        MockRequest {
            method: String::from("GET"),
            path: String::from("/test"),
            url: String::from("https://example.com/test"),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let MockResponse { status, body, .. } = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(status, 200);
        assert_eq!(body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let MockResponse { status, body, .. } = MockResponse::error(404, "Not found");
        assert_eq!(status, 404);
        assert_eq!(body["error"], "Not found");
    }

    #[test]
    fn test_mock_response_internal_error() {
        assert_eq!(MockResponse::internal_error("Server error").status, 500);
    }

    #[test]
    fn test_mock_response_not_found() {
        assert_eq!(MockResponse::not_found("Resource not found").status, 404);
    }

    #[test]
    fn test_mock_response_unauthorized() {
        assert_eq!(MockResponse::unauthorized("Invalid token").status, 401);
    }

    #[tokio::test]
    async fn test_mock_http_no_handler() {
        let mock = MockHttp::new();
        let response = mock.execute(get_test_request()).await;
        assert_eq!(response.status, 500);
    }

    #[tokio::test]
    async fn test_mock_http_records_requests() {
        let mock = MockHttp::new();
        let mut headers = HashMap::new();
        headers.insert(String::from("authorization"), String::from("Bearer token"));
        let request = MockRequest {
            method: String::from("POST"),
            path: String::from("/api/users"),
            url: String::from("https://api.example.com/users"),
            headers,
            body: serde_json::json!({"name": "Test"}),
        };

        let _response = mock.execute(request).await;

        let recorded = mock.requests().await;
        assert_eq!(recorded.len(), 1);
        let first = &recorded[0];
        assert_eq!(first.method, "POST");
        assert_eq!(first.body["name"], "Test");
    }

    #[tokio::test]
    async fn test_mock_http_clear_requests() {
        let mock = MockHttp::new();

        let _response = mock.execute(get_test_request()).await;
        assert_eq!(mock.requests().await.len(), 1);

        mock.clear_requests().await;
        assert!(mock.requests().await.is_empty());
    }
}