1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
8
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12use serde::Serialize;
13
14#[derive(Clone)]
28pub struct MockHttp {
29 mocks: Arc<RwLock<Vec<MockHandler>>>,
30 requests: Arc<RwLock<Vec<RecordedRequest>>>,
31}
32
33pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
35
36struct MockHandler {
38 pattern: String,
39 handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
40}
41
42#[derive(Debug, Clone)]
44pub struct RecordedRequest {
45 pub method: String,
47 pub url: String,
49 pub headers: HashMap<String, String>,
51 pub body: serde_json::Value,
53}
54
55#[derive(Debug, Clone)]
57pub struct MockRequest {
58 pub method: String,
60 pub path: String,
62 pub url: String,
64 pub headers: HashMap<String, String>,
66 pub body: serde_json::Value,
68}
69
70#[derive(Debug, Clone)]
72pub struct MockResponse {
73 pub status: u16,
75 pub headers: HashMap<String, String>,
77 pub body: serde_json::Value,
79}
80
81impl MockResponse {
82 pub fn json<T: Serialize>(body: T) -> Self {
84 Self {
85 status: 200,
86 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
87 body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
88 }
89 }
90
91 pub fn error(status: u16, message: &str) -> Self {
93 Self {
94 status,
95 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
96 body: serde_json::json!({ "error": message }),
97 }
98 }
99
100 pub fn internal_error(message: &str) -> Self {
102 Self::error(500, message)
103 }
104
105 pub fn not_found(message: &str) -> Self {
107 Self::error(404, message)
108 }
109
110 pub fn unauthorized(message: &str) -> Self {
112 Self::error(401, message)
113 }
114
115 pub fn ok() -> Self {
117 Self::json(serde_json::json!({}))
118 }
119}
120
121impl MockHttp {
122 pub fn new() -> Self {
124 Self {
125 mocks: Arc::new(RwLock::new(Vec::new())),
126 requests: Arc::new(RwLock::new(Vec::new())),
127 }
128 }
129
130 pub fn builder() -> MockHttpBuilder {
132 MockHttpBuilder::new()
133 }
134
135 pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
137 where
138 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
139 {
140 let mut mocks = self.mocks.write().unwrap();
141 mocks.push(MockHandler {
142 pattern: pattern.to_string(),
143 handler: Arc::new(handler),
144 });
145 }
146
147 pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
149 let mut mocks = self.mocks.write().unwrap();
150 mocks.push(MockHandler {
151 pattern: pattern.to_string(),
152 handler: Arc::from(handler),
153 });
154 }
155
156 pub async fn execute(&self, request: MockRequest) -> MockResponse {
158 {
160 let mut requests = self.requests.write().unwrap();
161 requests.push(RecordedRequest {
162 method: request.method.clone(),
163 url: request.url.clone(),
164 headers: request.headers.clone(),
165 body: request.body.clone(),
166 });
167 }
168
169 let mocks = self.mocks.read().unwrap();
171 for mock in mocks.iter() {
172 if self.matches_pattern(&request.url, &mock.pattern)
173 || self.matches_pattern(&request.path, &mock.pattern)
174 {
175 return (mock.handler)(&request);
176 }
177 }
178
179 MockResponse::error(500, &format!("No mock found for {}", request.url))
181 }
182
183 fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
185 let pattern_parts: Vec<&str> = pattern.split('*').collect();
187 if pattern_parts.len() == 1 {
188 return url == pattern;
190 }
191
192 let mut remaining = url;
193 for (i, part) in pattern_parts.iter().enumerate() {
194 if part.is_empty() {
195 continue;
196 }
197
198 if i == 0 {
199 if !remaining.starts_with(part) {
201 return false;
202 }
203 remaining = &remaining[part.len()..];
204 } else if i == pattern_parts.len() - 1 {
205 if !remaining.ends_with(part) {
207 return false;
208 }
209 } else {
210 if let Some(pos) = remaining.find(part) {
212 remaining = &remaining[pos + part.len()..];
213 } else {
214 return false;
215 }
216 }
217 }
218
219 true
220 }
221
222 pub fn requests(&self) -> Vec<RecordedRequest> {
224 self.requests.read().unwrap().clone()
225 }
226
227 pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
229 self.requests.read().unwrap().clone()
230 }
231
232 pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
234 self.requests
235 .read()
236 .unwrap()
237 .iter()
238 .filter(|r| self.matches_pattern(&r.url, pattern))
239 .cloned()
240 .collect()
241 }
242
243 pub fn clear_requests(&self) {
245 self.requests.write().unwrap().clear();
246 }
247
248 pub fn clear_mocks(&self) {
250 self.mocks.write().unwrap().clear();
251 }
252
253 pub fn assert_called(&self, pattern: &str) {
255 let requests = self.requests_blocking();
256 let matching = requests
257 .iter()
258 .filter(|r| self.matches_pattern(&r.url, pattern))
259 .count();
260 assert!(
261 matching > 0,
262 "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
263 pattern,
264 requests.iter().map(|r| &r.url).collect::<Vec<_>>()
265 );
266 }
267
268 pub fn assert_called_times(&self, pattern: &str, expected: usize) {
270 let requests = self.requests_blocking();
271 let matching = requests
272 .iter()
273 .filter(|r| self.matches_pattern(&r.url, pattern))
274 .count();
275 assert_eq!(
276 matching, expected,
277 "Expected {} HTTP calls matching '{}', but found {}",
278 expected, pattern, matching
279 );
280 }
281
282 pub fn assert_not_called(&self, pattern: &str) {
284 let requests = self.requests_blocking();
285 let matching = requests
286 .iter()
287 .filter(|r| self.matches_pattern(&r.url, pattern))
288 .count();
289 assert_eq!(
290 matching, 0,
291 "Expected no HTTP calls matching '{}', but found {}",
292 pattern, matching
293 );
294 }
295
296 pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
298 where
299 F: Fn(&serde_json::Value) -> bool,
300 {
301 let requests = self.requests_blocking();
302 let matching = requests
303 .iter()
304 .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
305 assert!(
306 matching.count() > 0,
307 "Expected HTTP call matching '{}' with matching body, but none found",
308 pattern
309 );
310 }
311}
312
313impl Default for MockHttp {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
319pub struct MockHttpBuilder {
321 mocks: Vec<(String, BoxedHandler)>,
322}
323
324impl MockHttpBuilder {
325 pub fn new() -> Self {
327 Self { mocks: Vec::new() }
328 }
329
330 pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
332 where
333 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
334 {
335 self.mocks.push((pattern.to_string(), Box::new(handler)));
336 self
337 }
338
339 pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
341 self,
342 pattern: &str,
343 response: T,
344 ) -> Self {
345 self.mock(pattern, move |_| MockResponse::json(response.clone()))
346 }
347
348 pub fn build(self) -> MockHttp {
350 let mut mock = MockHttp::new();
351 for (pattern, handler) in self.mocks {
352 mock.add_mock_boxed(&pattern, handler);
353 }
354 mock
355 }
356}
357
358impl Default for MockHttpBuilder {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with no headers and a `null` body.
    fn request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[test]
    fn test_pattern_matching_edge_cases() {
        let mock = MockHttp::new();

        // A lone wildcard matches everything, including the empty string.
        assert!(mock.matches_pattern("", "*"));
        assert!(mock.matches_pattern("https://api.example.com/users", "*"));

        // A leading wildcard anchors the trailing literal at the end.
        assert!(mock.matches_pattern("https://api.example.com/users", "*/users"));
        assert!(!mock.matches_pattern("https://api.example.com/users/1", "*/users"));

        // Prefix and suffix literals cannot share characters.
        assert!(mock.matches_pattern("aa", "a*a"));
        assert!(!mock.matches_pattern("a", "a*a"));

        // Without a wildcard the match must be exact.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users/",
            "https://api.example.com/users"
        ));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let response = mock
            .execute(request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_mock_matches_request_path() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/users/*", |_| MockResponse::not_found("missing"));

        // The URL does not match, but the path does.
        let response = mock
            .execute(request("GET", "/users/7", "https://api.example.com/v1/users/7"))
            .await;
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "missing");
    }

    #[tokio::test]
    async fn test_first_registered_mock_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::unauthorized("denied")
        });
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let response = mock
            .execute(request("GET", "/me", "https://api.example.com/me"))
            .await;
        assert_eq!(response.status, 401);
    }

    #[tokio::test]
    async fn test_unmatched_request_is_recorded_and_fails() {
        let mock = MockHttp::new();

        let response = mock
            .execute(request("GET", "/missing", "https://api.example.com/missing"))
            .await;
        assert_eq!(response.status, 500);
        assert_eq!(
            response.body["error"],
            "No mock found for https://api.example.com/missing"
        );
        mock.assert_called_times("https://api.example.com/missing", 1);
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "POST".to_string(),
            path: "/api/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
        };

        let _ = mock.execute(request).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(request("GET", "/test", "https://api.example.com/test"))
            .await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[tokio::test]
    async fn test_body_assertions_and_clearing() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut create = request("POST", "/users", "https://api.example.com/users");
        create.body = serde_json::json!({"name": "Test"});
        let _ = mock.execute(create).await;
        let _ = mock
            .execute(request("GET", "/ping", "https://other.com/ping"))
            .await;

        assert_eq!(mock.requests_to("https://api.example.com/*").len(), 1);
        mock.assert_called_with_body("https://api.example.com/*", |body| {
            body["name"] == "Test"
        });

        mock.clear_requests();
        assert!(mock.requests().is_empty());
        mock.assert_not_called("*");

        // With every handler removed, requests fall through to the 500 response.
        mock.clear_mocks();
        let response = mock
            .execute(request("GET", "/ping", "https://other.com/ping"))
            .await;
        assert_eq!(response.status, 500);
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mock_json_response() {
        let mock = MockHttp::builder()
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        let response = mock
            .execute(request("GET", "/item", "https://other.com/item"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 1);
        assert_eq!(
            response.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }
}