1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
4
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7
8use serde::Serialize;
9
10#[derive(Clone)]
19pub struct MockHttp {
20 mocks: Arc<RwLock<Vec<MockHandler>>>,
21 requests: Arc<RwLock<Vec<RecordedRequest>>>,
22}
23
24pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
25
26struct MockHandler {
27 pattern: String,
28 handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
29}
30
31#[derive(Debug, Clone)]
33pub struct RecordedRequest {
34 pub method: String,
35 pub url: String,
36 pub headers: HashMap<String, String>,
37 pub body: serde_json::Value,
38}
39
40#[derive(Debug, Clone)]
42pub struct MockRequest {
43 pub method: String,
44 pub path: String,
45 pub url: String,
46 pub headers: HashMap<String, String>,
47 pub body: serde_json::Value,
48}
49
50#[derive(Debug, Clone)]
52pub struct MockResponse {
53 pub status: u16,
54 pub headers: HashMap<String, String>,
55 pub body: serde_json::Value,
56}
57
58impl MockResponse {
59 pub fn json<T: Serialize>(body: T) -> Self {
60 Self {
61 status: 200,
62 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
63 body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
64 }
65 }
66
67 pub fn error(status: u16, message: &str) -> Self {
68 Self {
69 status,
70 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
71 body: serde_json::json!({ "error": message }),
72 }
73 }
74
75 pub fn internal_error(message: &str) -> Self {
76 Self::error(500, message)
77 }
78
79 pub fn not_found(message: &str) -> Self {
80 Self::error(404, message)
81 }
82
83 pub fn unauthorized(message: &str) -> Self {
84 Self::error(401, message)
85 }
86
87 pub fn ok() -> Self {
88 Self::json(serde_json::json!({}))
89 }
90}
91
92impl MockHttp {
93 pub fn new() -> Self {
94 Self {
95 mocks: Arc::new(RwLock::new(Vec::new())),
96 requests: Arc::new(RwLock::new(Vec::new())),
97 }
98 }
99
100 pub fn builder() -> MockHttpBuilder {
101 MockHttpBuilder::new()
102 }
103
104 pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
109 where
110 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
111 {
112 let mut mocks = self.mocks.write().unwrap();
113 mocks.push(MockHandler {
114 pattern: pattern.to_string(),
115 handler: Arc::new(handler),
116 });
117 }
118
119 pub fn mock_exact<F>(&self, url: &str, handler: F)
121 where
122 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
123 {
124 self.add_mock_sync(url, handler);
125 }
126
127 pub fn mock_glob<F>(&self, pattern: &str, handler: F)
131 where
132 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
133 {
134 self.add_mock_sync(pattern, handler);
135 }
136
137 pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
138 let mut mocks = self.mocks.write().unwrap();
139 mocks.push(MockHandler {
140 pattern: pattern.to_string(),
141 handler: Arc::from(handler),
142 });
143 }
144
145 pub async fn execute(&self, request: MockRequest) -> MockResponse {
146 {
147 let mut requests = self.requests.write().unwrap();
148 requests.push(RecordedRequest {
149 method: request.method.clone(),
150 url: request.url.clone(),
151 headers: request.headers.clone(),
152 body: request.body.clone(),
153 });
154 }
155
156 let mocks = self.mocks.read().unwrap();
157 for mock in mocks.iter() {
158 if self.matches_pattern(&request.url, &mock.pattern)
159 || self.matches_pattern(&request.path, &mock.pattern)
160 {
161 return (mock.handler)(&request);
162 }
163 }
164
165 MockResponse::error(500, &format!("No mock found for {}", request.url))
166 }
167
168 fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
169 let pattern_parts: Vec<&str> = pattern.split('*').collect();
170 if pattern_parts.len() == 1 {
171 return url == pattern;
172 }
173
174 let mut remaining = url;
175 for (i, part) in pattern_parts.iter().enumerate() {
176 if part.is_empty() {
177 continue;
178 }
179
180 if i == 0 {
181 if !remaining.starts_with(part) {
182 return false;
183 }
184 remaining = &remaining[part.len()..];
185 } else if i == pattern_parts.len() - 1 {
186 if !remaining.ends_with(part) {
187 return false;
188 }
189 } else if let Some(pos) = remaining.find(part) {
190 remaining = &remaining[pos + part.len()..];
191 } else {
192 return false;
193 }
194 }
195
196 true
197 }
198
199 pub fn requests(&self) -> Vec<RecordedRequest> {
200 self.requests.read().unwrap().clone()
201 }
202
203 pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
204 self.requests.read().unwrap().clone()
205 }
206
207 pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
208 self.requests
209 .read()
210 .unwrap()
211 .iter()
212 .filter(|r| self.matches_pattern(&r.url, pattern))
213 .cloned()
214 .collect()
215 }
216
217 pub fn clear_requests(&self) {
218 self.requests.write().unwrap().clear();
219 }
220
221 pub fn clear_mocks(&self) {
222 self.mocks.write().unwrap().clear();
223 }
224
225 pub fn assert_called(&self, pattern: &str) {
227 let requests = self.requests_blocking();
228 let matching = requests
229 .iter()
230 .filter(|r| self.matches_pattern(&r.url, pattern))
231 .count();
232 assert!(
233 matching > 0,
234 "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
235 pattern,
236 requests.iter().map(|r| &r.url).collect::<Vec<_>>()
237 );
238 }
239
240 pub fn assert_called_times(&self, pattern: &str, expected: usize) {
242 let requests = self.requests_blocking();
243 let matching = requests
244 .iter()
245 .filter(|r| self.matches_pattern(&r.url, pattern))
246 .count();
247 assert_eq!(
248 matching, expected,
249 "Expected {} HTTP calls matching '{}', but found {}",
250 expected, pattern, matching
251 );
252 }
253
254 pub fn assert_not_called(&self, pattern: &str) {
256 let requests = self.requests_blocking();
257 let matching = requests
258 .iter()
259 .filter(|r| self.matches_pattern(&r.url, pattern))
260 .count();
261 assert_eq!(
262 matching, 0,
263 "Expected no HTTP calls matching '{}', but found {}",
264 pattern, matching
265 );
266 }
267
268 pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
270 where
271 F: Fn(&serde_json::Value) -> bool,
272 {
273 let requests = self.requests_blocking();
274 let matching = requests
275 .iter()
276 .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
277 assert!(
278 matching.count() > 0,
279 "Expected HTTP call matching '{}' with matching body, but none found",
280 pattern
281 );
282 }
283}
284
285impl Default for MockHttp {
286 fn default() -> Self {
287 Self::new()
288 }
289}
290
291pub struct MockHttpBuilder {
292 mocks: Vec<(String, BoxedHandler)>,
293}
294
295impl MockHttpBuilder {
296 pub fn new() -> Self {
297 Self { mocks: Vec::new() }
298 }
299
300 pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
301 where
302 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
303 {
304 self.mocks.push((pattern.to_string(), Box::new(handler)));
305 self
306 }
307
308 pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
309 self,
310 pattern: &str,
311 response: T,
312 ) -> Self {
313 self.mock(pattern, move |_| MockResponse::json(response.clone()))
314 }
315
316 pub fn build(self) -> MockHttp {
317 let mut mock = MockHttp::new();
318 for (pattern, handler) in self.mocks {
319 mock.add_mock_boxed(&pattern, handler);
320 }
321 mock
322 }
323}
324
325impl Default for MockHttpBuilder {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given method, URL and path, no headers and
    /// a JSON `null` body.
    fn req(method: &str, url: &str, path: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        // Exact match.
        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        // Trailing wildcard.
        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        // Wildcard in the middle.
        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        // Different host must not match.
        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let response = mock
            .execute(req("GET", "https://api.example.com/users", "/users"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
            ..req("POST", "https://api.example.com/users", "/api/users")
        };

        let _ = mock.execute(request).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(req("GET", "https://api.example.com/test", "/test"))
            .await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[test]
    fn response_status_helpers_use_documented_codes() {
        assert_eq!(MockResponse::internal_error("boom").status, 500);
        assert_eq!(MockResponse::not_found("nope").status, 404);
        assert_eq!(MockResponse::unauthorized("nope").status, 401);
        assert_eq!(MockResponse::ok().status, 200);

        assert_eq!(MockResponse::ok().body, serde_json::json!({}));
    }

    #[test]
    fn response_json_sets_content_type_header() {
        let r = MockResponse::json(serde_json::json!({"ok": true}));
        assert_eq!(
            r.headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn pattern_matcher_handles_leading_and_double_wildcards() {
        let m = MockHttp::new();

        // Leading wildcard anchors the remainder at the end of the URL.
        assert!(m.matches_pattern("https://api.example.com/v1/users", "*/users"));
        assert!(!m.matches_pattern("https://api.example.com/v1/posts", "*/users"));

        // A lone wildcard matches everything, including the empty string.
        assert!(m.matches_pattern("anything", "*"));
        assert!(m.matches_pattern("", "*"));

        // Adjacent wildcards behave like a single one.
        assert!(m.matches_pattern("https://api.example.com/v1/users", "**/users"));
        assert!(m.matches_pattern("anything", "**"));

        // Two separate wildcards in one pattern.
        assert!(m.matches_pattern(
            "https://api.example.com/v1/users",
            "https://*.example.com/*/users"
        ));
        assert!(!m.matches_pattern(
            "https://api.example.org/v1/users",
            "https://*.example.com/*/users"
        ));
    }

    #[test]
    fn pattern_matcher_rejects_exact_pattern_with_extra_suffix() {
        let m = MockHttp::new();
        assert!(!m.matches_pattern(
            "https://api.example.com/users/extra",
            "https://api.example.com/users"
        ));
    }

    #[tokio::test]
    async fn execute_falls_back_to_500_when_no_mock_matches() {
        let mock = MockHttp::new();
        let r = mock.execute(req("GET", "https://nowhere/", "/")).await;
        assert_eq!(r.status, 500);
        assert!(
            r.body["error"]
                .as_str()
                .unwrap_or_default()
                .contains("No mock found"),
            "fallback should explain the failure, got {:?}",
            r.body
        );
    }

    #[tokio::test]
    async fn execute_records_request_even_when_no_mock_matches() {
        let mock = MockHttp::new();
        let _ = mock.execute(req("DELETE", "https://nowhere/x", "/x")).await;
        let recorded = mock.requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].method, "DELETE");
        assert_eq!(recorded[0].url, "https://nowhere/x");
    }

    #[tokio::test]
    async fn execute_matches_against_path_when_url_misses() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/health", |_| MockResponse::ok());
        let r = mock
            .execute(req("GET", "https://internal.svc:8080/health", "/health"))
            .await;
        assert_eq!(r.status, 200);
    }

    #[tokio::test]
    async fn execute_uses_first_registered_mock_on_overlapping_patterns() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"hit": "first"}))
        });
        mock.add_mock_sync("https://api.example.com/users", |_| {
            MockResponse::json(serde_json::json!({"hit": "second"}))
        });

        let r = mock
            .execute(req("GET", "https://api.example.com/users", "/users"))
            .await;
        assert_eq!(r.body["hit"], "first");
    }

    #[tokio::test]
    async fn requests_to_filters_by_pattern() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(req("GET", "https://api.example.com/a", "/a"))
            .await;
        let _ = mock.execute(req("GET", "https://other.com/b", "/b")).await;
        let _ = mock
            .execute(req("GET", "https://api.example.com/c", "/c"))
            .await;

        let api_calls = mock.requests_to("https://api.example.com/*");
        assert_eq!(api_calls.len(), 2);
        assert!(api_calls.iter().all(|r| r.url.contains("api.example.com")));
    }

    #[tokio::test]
    async fn clear_requests_and_clear_mocks_independently_reset_state() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());
        let _ = mock.execute(req("GET", "https://x/", "/")).await;
        assert_eq!(mock.requests().len(), 1);

        // Clearing the log must leave the handlers in place.
        mock.clear_requests();
        assert!(mock.requests().is_empty());
        let r = mock.execute(req("GET", "https://x/", "/")).await;
        assert_eq!(r.status, 200);

        mock.clear_mocks();
        let r = mock.execute(req("GET", "https://x/", "/")).await;
        assert_eq!(r.status, 500, "after clear_mocks, fallback should hit");
    }

    #[tokio::test]
    async fn assert_called_with_body_runs_predicate_against_recorded_body() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut request = req("POST", "https://api/upload", "/upload");
        request.body = serde_json::json!({"size": 42});
        let _ = mock.execute(request).await;

        mock.assert_called_with_body("https://api/*", |body| body["size"] == 42);
    }

    #[test]
    fn defaults_match_new() {
        let m1 = MockHttp::default();
        assert!(m1.requests().is_empty());
        let b1 = MockHttpBuilder::default();
        let m2 = b1.build();
        assert!(m2.requests().is_empty());
    }
}