1use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10use serde::Serialize;
11
12#[derive(Clone)]
26pub struct MockHttp {
27 mocks: Arc<RwLock<Vec<MockHandler>>>,
28 requests: Arc<RwLock<Vec<RecordedRequest>>>,
29}
30
31pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
33
34struct MockHandler {
36 pattern: String,
37 handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
38}
39
40#[derive(Debug, Clone)]
42pub struct RecordedRequest {
43 pub method: String,
45 pub url: String,
47 pub headers: HashMap<String, String>,
49 pub body: serde_json::Value,
51}
52
53#[derive(Debug, Clone)]
55pub struct MockRequest {
56 pub method: String,
58 pub path: String,
60 pub url: String,
62 pub headers: HashMap<String, String>,
64 pub body: serde_json::Value,
66}
67
68#[derive(Debug, Clone)]
70pub struct MockResponse {
71 pub status: u16,
73 pub headers: HashMap<String, String>,
75 pub body: serde_json::Value,
77}
78
79impl MockResponse {
80 pub fn json<T: Serialize>(body: T) -> Self {
82 Self {
83 status: 200,
84 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
85 body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
86 }
87 }
88
89 pub fn error(status: u16, message: &str) -> Self {
91 Self {
92 status,
93 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
94 body: serde_json::json!({ "error": message }),
95 }
96 }
97
98 pub fn internal_error(message: &str) -> Self {
100 Self::error(500, message)
101 }
102
103 pub fn not_found(message: &str) -> Self {
105 Self::error(404, message)
106 }
107
108 pub fn unauthorized(message: &str) -> Self {
110 Self::error(401, message)
111 }
112
113 pub fn ok() -> Self {
115 Self::json(serde_json::json!({}))
116 }
117}
118
119impl MockHttp {
120 pub fn new() -> Self {
122 Self {
123 mocks: Arc::new(RwLock::new(Vec::new())),
124 requests: Arc::new(RwLock::new(Vec::new())),
125 }
126 }
127
128 pub fn builder() -> MockHttpBuilder {
130 MockHttpBuilder::new()
131 }
132
133 pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
135 where
136 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
137 {
138 let mut mocks = self.mocks.write().unwrap();
139 mocks.push(MockHandler {
140 pattern: pattern.to_string(),
141 handler: Arc::new(handler),
142 });
143 }
144
145 pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
147 let mut mocks = self.mocks.write().unwrap();
148 mocks.push(MockHandler {
149 pattern: pattern.to_string(),
150 handler: Arc::from(handler),
151 });
152 }
153
154 pub async fn execute(&self, request: MockRequest) -> MockResponse {
156 {
158 let mut requests = self.requests.write().unwrap();
159 requests.push(RecordedRequest {
160 method: request.method.clone(),
161 url: request.url.clone(),
162 headers: request.headers.clone(),
163 body: request.body.clone(),
164 });
165 }
166
167 let mocks = self.mocks.read().unwrap();
169 for mock in mocks.iter() {
170 if self.matches_pattern(&request.url, &mock.pattern)
171 || self.matches_pattern(&request.path, &mock.pattern)
172 {
173 return (mock.handler)(&request);
174 }
175 }
176
177 MockResponse::error(500, &format!("No mock found for {}", request.url))
179 }
180
181 fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
183 let pattern_parts: Vec<&str> = pattern.split('*').collect();
185 if pattern_parts.len() == 1 {
186 return url == pattern;
188 }
189
190 let mut remaining = url;
191 for (i, part) in pattern_parts.iter().enumerate() {
192 if part.is_empty() {
193 continue;
194 }
195
196 if i == 0 {
197 if !remaining.starts_with(part) {
199 return false;
200 }
201 remaining = &remaining[part.len()..];
202 } else if i == pattern_parts.len() - 1 {
203 if !remaining.ends_with(part) {
205 return false;
206 }
207 } else {
208 if let Some(pos) = remaining.find(part) {
210 remaining = &remaining[pos + part.len()..];
211 } else {
212 return false;
213 }
214 }
215 }
216
217 true
218 }
219
220 pub fn requests(&self) -> Vec<RecordedRequest> {
222 self.requests.read().unwrap().clone()
223 }
224
225 pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
227 self.requests.read().unwrap().clone()
228 }
229
230 pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
232 self.requests
233 .read()
234 .unwrap()
235 .iter()
236 .filter(|r| self.matches_pattern(&r.url, pattern))
237 .cloned()
238 .collect()
239 }
240
241 pub fn clear_requests(&self) {
243 self.requests.write().unwrap().clear();
244 }
245
246 pub fn clear_mocks(&self) {
248 self.mocks.write().unwrap().clear();
249 }
250
251 pub fn assert_called(&self, pattern: &str) {
257 let requests = self.requests_blocking();
258 let matching = requests
259 .iter()
260 .filter(|r| self.matches_pattern(&r.url, pattern))
261 .count();
262 assert!(
263 matching > 0,
264 "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
265 pattern,
266 requests.iter().map(|r| &r.url).collect::<Vec<_>>()
267 );
268 }
269
270 pub fn assert_called_times(&self, pattern: &str, expected: usize) {
272 let requests = self.requests_blocking();
273 let matching = requests
274 .iter()
275 .filter(|r| self.matches_pattern(&r.url, pattern))
276 .count();
277 assert_eq!(
278 matching, expected,
279 "Expected {} HTTP calls matching '{}', but found {}",
280 expected, pattern, matching
281 );
282 }
283
284 pub fn assert_not_called(&self, pattern: &str) {
286 let requests = self.requests_blocking();
287 let matching = requests
288 .iter()
289 .filter(|r| self.matches_pattern(&r.url, pattern))
290 .count();
291 assert_eq!(
292 matching, 0,
293 "Expected no HTTP calls matching '{}', but found {}",
294 pattern, matching
295 );
296 }
297
298 pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
300 where
301 F: Fn(&serde_json::Value) -> bool,
302 {
303 let requests = self.requests_blocking();
304 let matching = requests
305 .iter()
306 .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
307 assert!(
308 matching.count() > 0,
309 "Expected HTTP call matching '{}' with matching body, but none found",
310 pattern
311 );
312 }
313}
314
315impl Default for MockHttp {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321pub struct MockHttpBuilder {
323 mocks: Vec<(String, BoxedHandler)>,
324}
325
326impl MockHttpBuilder {
327 pub fn new() -> Self {
329 Self { mocks: Vec::new() }
330 }
331
332 pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
334 where
335 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
336 {
337 self.mocks.push((pattern.to_string(), Box::new(handler)));
338 self
339 }
340
341 pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
343 self,
344 pattern: &str,
345 response: T,
346 ) -> Self {
347 self.mock(pattern, move |_| MockResponse::json(response.clone()))
348 }
349
350 pub fn build(self) -> MockHttp {
352 let mut mock = MockHttp::new();
353 for (pattern, handler) in self.mocks {
354 mock.add_mock_boxed(&pattern, handler);
355 }
356 mock
357 }
358}
359
360impl Default for MockHttpBuilder {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with empty headers and a `null` body.
    fn request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_mock_response_helpers() {
        assert_eq!(MockResponse::internal_error("boom").status, 500);
        assert_eq!(MockResponse::not_found("gone").status, 404);
        assert_eq!(MockResponse::unauthorized("no").status, 401);

        let ok = MockResponse::ok();
        assert_eq!(ok.status, 200);
        assert_eq!(ok.body, serde_json::json!({}));
        assert_eq!(
            ok.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[test]
    fn test_pattern_matching_wildcards() {
        let mock = MockHttp::new();

        // Leading wildcard only constrains the suffix.
        assert!(mock.matches_pattern("https://api.example.com/users", "*/users"));
        assert!(!mock.matches_pattern("https://api.example.com/users/1", "*/users"));

        // A lone `*` matches anything, including the empty string.
        assert!(mock.matches_pattern("anything", "*"));
        assert!(mock.matches_pattern("", "*"));

        // Multiple wildcards must match their pieces in order.
        assert!(mock.matches_pattern("https://a.com/x/y/z", "https://a.com/*/*/z"));

        // Without a wildcard the match is exact.
        assert!(!mock.matches_pattern("https://a.com/x", "https://a.com/y"));

        // The prefix before `*` is required in full.
        assert!(!mock.matches_pattern("https://api.example.com", "https://api.example.com/*"));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let request = MockRequest {
            method: "GET".to_string(),
            path: "/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        };

        let response = mock.execute(request).await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_unmatched_request_returns_500() {
        let mock = MockHttp::new();

        let response = mock
            .execute(request("GET", "/missing", "https://api.example.com/missing"))
            .await;

        assert_eq!(response.status, 500);
        assert_eq!(
            response.body["error"],
            "No mock found for https://api.example.com/missing"
        );
        // Unmatched requests are still recorded.
        assert_eq!(mock.requests().len(), 1);
    }

    #[tokio::test]
    async fn test_path_pattern_and_first_match_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/users", |_| {
            MockResponse::json(serde_json::json!({"matched": "path"}))
        });
        mock.add_mock_sync("*", |_| MockResponse::not_found("fallback"));

        let response = mock
            .execute(request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["matched"], "path");

        let response = mock
            .execute(request("GET", "/other", "https://api.example.com/other"))
            .await;
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "fallback");
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "POST".to_string(),
            path: "/api/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
        };

        let _ = mock.execute(request).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_requests_to_and_clear() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock.execute(request("GET", "/x", "https://a.com/x")).await;
        let _ = mock.execute(request("GET", "/y", "https://b.com/y")).await;

        let to_a = mock.requests_to("https://a.com/*");
        assert_eq!(to_a.len(), 1);
        assert_eq!(to_a[0].url, "https://a.com/x");

        mock.clear_requests();
        assert!(mock.requests().is_empty());

        mock.clear_mocks();
        let response = mock.execute(request("GET", "/x", "https://a.com/x")).await;
        assert_eq!(response.status, 500);
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "GET".to_string(),
            path: "/test".to_string(),
            url: "https://api.example.com/test".to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        };

        let _ = mock.execute(request).await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[tokio::test]
    async fn test_assert_called_with_body() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut req = request("POST", "/users", "https://api.example.com/users");
        req.body = serde_json::json!({"name": "Test"});
        let _ = mock.execute(req).await;

        mock.assert_called_with_body("https://api.example.com/*", |body| body["name"] == "Test");
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mock_json_response() {
        let mock = MockHttp::builder()
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        let response = mock.execute(request("GET", "/a", "https://other.com/a")).await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 1);
    }
}