1use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10use serde::Serialize;
11
12#[derive(Clone)]
26pub struct MockHttp {
27 mocks: Arc<RwLock<Vec<MockHandler>>>,
28 requests: Arc<RwLock<Vec<RecordedRequest>>>,
29}
30
31pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
33
34struct MockHandler {
36 pattern: String,
37 handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
38}
39
40#[derive(Debug, Clone)]
42pub struct RecordedRequest {
43 pub method: String,
45 pub url: String,
47 pub headers: HashMap<String, String>,
49 pub body: serde_json::Value,
51}
52
53#[derive(Debug, Clone)]
55pub struct MockRequest {
56 pub method: String,
58 pub path: String,
60 pub url: String,
62 pub headers: HashMap<String, String>,
64 pub body: serde_json::Value,
66}
67
68#[derive(Debug, Clone)]
70pub struct MockResponse {
71 pub status: u16,
73 pub headers: HashMap<String, String>,
75 pub body: serde_json::Value,
77}
78
79impl MockResponse {
80 pub fn json<T: Serialize>(body: T) -> Self {
82 Self {
83 status: 200,
84 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
85 body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
86 }
87 }
88
89 pub fn error(status: u16, message: &str) -> Self {
91 Self {
92 status,
93 headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
94 body: serde_json::json!({ "error": message }),
95 }
96 }
97
98 pub fn internal_error(message: &str) -> Self {
100 Self::error(500, message)
101 }
102
103 pub fn not_found(message: &str) -> Self {
105 Self::error(404, message)
106 }
107
108 pub fn unauthorized(message: &str) -> Self {
110 Self::error(401, message)
111 }
112
113 pub fn ok() -> Self {
115 Self::json(serde_json::json!({}))
116 }
117}
118
119impl MockHttp {
120 pub fn new() -> Self {
122 Self {
123 mocks: Arc::new(RwLock::new(Vec::new())),
124 requests: Arc::new(RwLock::new(Vec::new())),
125 }
126 }
127
128 pub fn builder() -> MockHttpBuilder {
130 MockHttpBuilder::new()
131 }
132
133 pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
135 where
136 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
137 {
138 let mut mocks = self.mocks.write().unwrap();
139 mocks.push(MockHandler {
140 pattern: pattern.to_string(),
141 handler: Arc::new(handler),
142 });
143 }
144
145 pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
147 let mut mocks = self.mocks.write().unwrap();
148 mocks.push(MockHandler {
149 pattern: pattern.to_string(),
150 handler: Arc::from(handler),
151 });
152 }
153
154 pub async fn execute(&self, request: MockRequest) -> MockResponse {
156 {
158 let mut requests = self.requests.write().unwrap();
159 requests.push(RecordedRequest {
160 method: request.method.clone(),
161 url: request.url.clone(),
162 headers: request.headers.clone(),
163 body: request.body.clone(),
164 });
165 }
166
167 let mocks = self.mocks.read().unwrap();
169 for mock in mocks.iter() {
170 if self.matches_pattern(&request.url, &mock.pattern)
171 || self.matches_pattern(&request.path, &mock.pattern)
172 {
173 return (mock.handler)(&request);
174 }
175 }
176
177 MockResponse::error(500, &format!("No mock found for {}", request.url))
179 }
180
181 fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
183 let pattern_parts: Vec<&str> = pattern.split('*').collect();
185 if pattern_parts.len() == 1 {
186 return url == pattern;
188 }
189
190 let mut remaining = url;
191 for (i, part) in pattern_parts.iter().enumerate() {
192 if part.is_empty() {
193 continue;
194 }
195
196 if i == 0 {
197 if !remaining.starts_with(part) {
199 return false;
200 }
201 remaining = &remaining[part.len()..];
202 } else if i == pattern_parts.len() - 1 {
203 if !remaining.ends_with(part) {
205 return false;
206 }
207 } else {
208 if let Some(pos) = remaining.find(part) {
210 remaining = &remaining[pos + part.len()..];
211 } else {
212 return false;
213 }
214 }
215 }
216
217 true
218 }
219
220 pub fn requests(&self) -> Vec<RecordedRequest> {
222 self.requests.read().unwrap().clone()
223 }
224
225 pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
227 self.requests.read().unwrap().clone()
228 }
229
230 pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
232 self.requests
233 .read()
234 .unwrap()
235 .iter()
236 .filter(|r| self.matches_pattern(&r.url, pattern))
237 .cloned()
238 .collect()
239 }
240
241 pub fn clear_requests(&self) {
243 self.requests.write().unwrap().clear();
244 }
245
246 pub fn clear_mocks(&self) {
248 self.mocks.write().unwrap().clear();
249 }
250
251 pub fn assert_called(&self, pattern: &str) {
253 let requests = self.requests_blocking();
254 let matching = requests
255 .iter()
256 .filter(|r| self.matches_pattern(&r.url, pattern))
257 .count();
258 assert!(
259 matching > 0,
260 "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
261 pattern,
262 requests.iter().map(|r| &r.url).collect::<Vec<_>>()
263 );
264 }
265
266 pub fn assert_called_times(&self, pattern: &str, expected: usize) {
268 let requests = self.requests_blocking();
269 let matching = requests
270 .iter()
271 .filter(|r| self.matches_pattern(&r.url, pattern))
272 .count();
273 assert_eq!(
274 matching, expected,
275 "Expected {} HTTP calls matching '{}', but found {}",
276 expected, pattern, matching
277 );
278 }
279
280 pub fn assert_not_called(&self, pattern: &str) {
282 let requests = self.requests_blocking();
283 let matching = requests
284 .iter()
285 .filter(|r| self.matches_pattern(&r.url, pattern))
286 .count();
287 assert_eq!(
288 matching, 0,
289 "Expected no HTTP calls matching '{}', but found {}",
290 pattern, matching
291 );
292 }
293
294 pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
296 where
297 F: Fn(&serde_json::Value) -> bool,
298 {
299 let requests = self.requests_blocking();
300 let matching = requests
301 .iter()
302 .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
303 assert!(
304 matching.count() > 0,
305 "Expected HTTP call matching '{}' with matching body, but none found",
306 pattern
307 );
308 }
309}
310
311impl Default for MockHttp {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317pub struct MockHttpBuilder {
319 mocks: Vec<(String, BoxedHandler)>,
320}
321
322impl MockHttpBuilder {
323 pub fn new() -> Self {
325 Self { mocks: Vec::new() }
326 }
327
328 pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
330 where
331 F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
332 {
333 self.mocks.push((pattern.to_string(), Box::new(handler)));
334 self
335 }
336
337 pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
339 self,
340 pattern: &str,
341 response: T,
342 ) -> Self {
343 self.mock(pattern, move |_| MockResponse::json(response.clone()))
344 }
345
346 pub fn build(self) -> MockHttp {
348 let mut mock = MockHttp::new();
349 for (pattern, handler) in self.mocks {
350 mock.add_mock_boxed(&pattern, handler);
351 }
352 mock
353 }
354}
355
356impl Default for MockHttpBuilder {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with no headers and a `null` body.
    fn make_request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_mock_response_helpers() {
        assert_eq!(MockResponse::internal_error("boom").status, 500);
        assert_eq!(MockResponse::not_found("missing").status, 404);
        assert_eq!(MockResponse::unauthorized("nope").status, 401);

        let ok = MockResponse::ok();
        assert_eq!(ok.status, 200);
        assert_eq!(ok.body, serde_json::json!({}));
        assert_eq!(
            ok.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));

        // A bare wildcard matches anything, including the empty string.
        assert!(mock.matches_pattern("", "*"));
        // Leading wildcard anchors the remainder as a suffix.
        assert!(mock.matches_pattern("https://a.com/v1/users", "*/users"));
        assert!(!mock.matches_pattern("https://a.com/users/1", "*/users"));
        // Without a wildcard the match must be exact.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users/",
            "https://api.example.com/users"
        ));
        // Middle segments are matched in order.
        assert!(mock.matches_pattern("abc", "a*b*c"));
        assert!(!mock.matches_pattern("acb", "a*b*c"));
        // Prefix and suffix segments may not overlap.
        assert!(!mock.matches_pattern("ab", "a*ab"));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let request = make_request("GET", "/users", "https://api.example.com/users");

        let response = mock.execute(request).await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_path_pattern_matches() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/api/*", |req| {
            MockResponse::json(serde_json::json!({ "path": req.path }))
        });

        // The URL does not match "/api/*", but the path does.
        let response = mock
            .execute(make_request(
                "GET",
                "/api/users",
                "https://service.internal/v1/users",
            ))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["path"], "/api/users");
    }

    #[tokio::test]
    async fn test_first_registered_mock_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({ "which": "first" }))
        });
        mock.add_mock_sync("*", |_| {
            MockResponse::json(serde_json::json!({ "which": "second" }))
        });

        let response = mock
            .execute(make_request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.body["which"], "first");
    }

    #[tokio::test]
    async fn test_unmatched_request_returns_500_and_is_recorded() {
        let mock = MockHttp::new();

        let response = mock
            .execute(make_request("GET", "/missing", "https://api.example.com/missing"))
            .await;
        assert_eq!(response.status, 500);
        assert_eq!(
            response.body["error"],
            "No mock found for https://api.example.com/missing"
        );
        mock.assert_called_times("https://api.example.com/missing", 1);
    }

    #[tokio::test]
    async fn test_clear_mocks() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());
        mock.clear_mocks();

        let response = mock
            .execute(make_request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.status, 500);
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "POST".to_string(),
            path: "/api/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
        };

        let _ = mock.execute(request).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_requests_to_and_clear() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(make_request("GET", "/a", "https://api.example.com/a"))
            .await;
        let _ = mock
            .execute(make_request("GET", "/b", "https://other.com/b"))
            .await;

        let to_api = mock.requests_to("https://api.example.com/*");
        assert_eq!(to_api.len(), 1);
        assert_eq!(to_api[0].url, "https://api.example.com/a");

        mock.clear_requests();
        assert!(mock.requests().is_empty());
        mock.assert_not_called("*");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = make_request("GET", "/test", "https://api.example.com/test");

        let _ = mock.execute(request).await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[test]
    #[should_panic(expected = "Expected HTTP call matching")]
    fn test_assert_called_panics_without_calls() {
        MockHttp::new().assert_called("https://api.example.com/*");
    }

    #[tokio::test]
    async fn test_assert_called_with_body() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut request = make_request("POST", "/users", "https://api.example.com/users");
        request.body = serde_json::json!({"name": "Test"});
        let _ = mock.execute(request).await;

        mock.assert_called_with_body("https://api.example.com/*", |body| body["name"] == "Test");
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mock_json_response() {
        let mock = MockHttp::builder()
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        let response = mock
            .execute(make_request("GET", "/x", "https://other.com/x"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 1);
    }
}