Skip to main content

forge_core/testing/
mock_http.rs

//! HTTP mocking utilities for testing.
//!
//! Provides a mock HTTP client that intercepts requests and returns
//! predefined responses. Handlers are selected by glob-style patterns
//! (`*` matches any run of characters), and every request is recorded so
//! tests can verify what was sent.
6
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10use serde::Serialize;
11
12/// Mock HTTP client for testing.
13///
14/// # Example
15///
16/// ```ignore
17/// let mut mock = MockHttp::new();
18/// mock.add_mock_sync("https://api.example.com/*", |req| {
19///     MockResponse::json(json!({"status": "ok"}))
20/// });
21///
22/// let response = mock.execute(request).await;
23/// mock.assert_called("https://api.example.com/*");
24/// ```
25#[derive(Clone)]
26pub struct MockHttp {
27    mocks: Arc<RwLock<Vec<MockHandler>>>,
28    requests: Arc<RwLock<Vec<RecordedRequest>>>,
29}
30
31/// Type alias for mock handler closure.
32pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
33
34/// A mock handler.
35struct MockHandler {
36    pattern: String,
37    handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
38}
39
40/// A recorded request for verification.
41#[derive(Debug, Clone)]
42pub struct RecordedRequest {
43    /// Request method.
44    pub method: String,
45    /// Request URL.
46    pub url: String,
47    /// Request headers.
48    pub headers: HashMap<String, String>,
49    /// Request body.
50    pub body: serde_json::Value,
51}
52
53/// Mock HTTP request.
54#[derive(Debug, Clone)]
55pub struct MockRequest {
56    /// Request method.
57    pub method: String,
58    /// Request path.
59    pub path: String,
60    /// Request URL.
61    pub url: String,
62    /// Request headers.
63    pub headers: HashMap<String, String>,
64    /// Request body.
65    pub body: serde_json::Value,
66}
67
68/// Mock HTTP response.
69#[derive(Debug, Clone)]
70pub struct MockResponse {
71    /// Status code.
72    pub status: u16,
73    /// Response headers.
74    pub headers: HashMap<String, String>,
75    /// Response body.
76    pub body: serde_json::Value,
77}
78
79impl MockResponse {
80    /// Create a successful JSON response.
81    pub fn json<T: Serialize>(body: T) -> Self {
82        Self {
83            status: 200,
84            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
85            body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
86        }
87    }
88
89    /// Create an error response.
90    pub fn error(status: u16, message: &str) -> Self {
91        Self {
92            status,
93            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
94            body: serde_json::json!({ "error": message }),
95        }
96    }
97
98    /// Create a 500 internal error.
99    pub fn internal_error(message: &str) -> Self {
100        Self::error(500, message)
101    }
102
103    /// Create a 404 not found.
104    pub fn not_found(message: &str) -> Self {
105        Self::error(404, message)
106    }
107
108    /// Create a 401 unauthorized.
109    pub fn unauthorized(message: &str) -> Self {
110        Self::error(401, message)
111    }
112
113    /// Create an empty 200 OK response.
114    pub fn ok() -> Self {
115        Self::json(serde_json::json!({}))
116    }
117}
118
119impl MockHttp {
120    /// Create a new mock HTTP client.
121    pub fn new() -> Self {
122        Self {
123            mocks: Arc::new(RwLock::new(Vec::new())),
124            requests: Arc::new(RwLock::new(Vec::new())),
125        }
126    }
127
128    /// Create a builder.
129    pub fn builder() -> MockHttpBuilder {
130        MockHttpBuilder::new()
131    }
132
133    /// Add a mock handler (sync version for use in builders).
134    pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
135    where
136        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
137    {
138        let mut mocks = self.mocks.write().unwrap();
139        mocks.push(MockHandler {
140            pattern: pattern.to_string(),
141            handler: Arc::new(handler),
142        });
143    }
144
145    /// Add a mock handler from a boxed closure.
146    pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
147        let mut mocks = self.mocks.write().unwrap();
148        mocks.push(MockHandler {
149            pattern: pattern.to_string(),
150            handler: Arc::from(handler),
151        });
152    }
153
154    /// Execute a mock request.
155    pub async fn execute(&self, request: MockRequest) -> MockResponse {
156        // Record the request
157        {
158            let mut requests = self.requests.write().unwrap();
159            requests.push(RecordedRequest {
160                method: request.method.clone(),
161                url: request.url.clone(),
162                headers: request.headers.clone(),
163                body: request.body.clone(),
164            });
165        }
166
167        // Find matching mock
168        let mocks = self.mocks.read().unwrap();
169        for mock in mocks.iter() {
170            if self.matches_pattern(&request.url, &mock.pattern)
171                || self.matches_pattern(&request.path, &mock.pattern)
172            {
173                return (mock.handler)(&request);
174            }
175        }
176
177        // No mock found
178        MockResponse::error(500, &format!("No mock found for {}", request.url))
179    }
180
181    /// Check if a URL matches a pattern.
182    fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
183        // Convert glob pattern to simple matching
184        let pattern_parts: Vec<&str> = pattern.split('*').collect();
185        if pattern_parts.len() == 1 {
186            // No wildcards - exact match
187            return url == pattern;
188        }
189
190        let mut remaining = url;
191        for (i, part) in pattern_parts.iter().enumerate() {
192            if part.is_empty() {
193                continue;
194            }
195
196            if i == 0 {
197                // First part must match at start
198                if !remaining.starts_with(part) {
199                    return false;
200                }
201                remaining = &remaining[part.len()..];
202            } else if i == pattern_parts.len() - 1 {
203                // Last part must match at end
204                if !remaining.ends_with(part) {
205                    return false;
206                }
207            } else {
208                // Middle parts can match anywhere
209                if let Some(pos) = remaining.find(part) {
210                    remaining = &remaining[pos + part.len()..];
211                } else {
212                    return false;
213                }
214            }
215        }
216
217        true
218    }
219
220    /// Get recorded requests.
221    pub fn requests(&self) -> Vec<RecordedRequest> {
222        self.requests.read().unwrap().clone()
223    }
224
225    /// Get recorded requests (blocking version for use in sync contexts).
226    pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
227        self.requests.read().unwrap().clone()
228    }
229
230    /// Get requests matching a pattern.
231    pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
232        self.requests
233            .read()
234            .unwrap()
235            .iter()
236            .filter(|r| self.matches_pattern(&r.url, pattern))
237            .cloned()
238            .collect()
239    }
240
241    /// Clear recorded requests.
242    pub fn clear_requests(&self) {
243        self.requests.write().unwrap().clear();
244    }
245
246    /// Clear all mocks.
247    pub fn clear_mocks(&self) {
248        self.mocks.write().unwrap().clear();
249    }
250
251    /// Assert that a URL pattern was called.
252    pub fn assert_called(&self, pattern: &str) {
253        let requests = self.requests_blocking();
254        let matching = requests
255            .iter()
256            .filter(|r| self.matches_pattern(&r.url, pattern))
257            .count();
258        assert!(
259            matching > 0,
260            "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
261            pattern,
262            requests.iter().map(|r| &r.url).collect::<Vec<_>>()
263        );
264    }
265
266    /// Assert that a URL pattern was called a specific number of times.
267    pub fn assert_called_times(&self, pattern: &str, expected: usize) {
268        let requests = self.requests_blocking();
269        let matching = requests
270            .iter()
271            .filter(|r| self.matches_pattern(&r.url, pattern))
272            .count();
273        assert_eq!(
274            matching, expected,
275            "Expected {} HTTP calls matching '{}', but found {}",
276            expected, pattern, matching
277        );
278    }
279
280    /// Assert that a URL pattern was not called.
281    pub fn assert_not_called(&self, pattern: &str) {
282        let requests = self.requests_blocking();
283        let matching = requests
284            .iter()
285            .filter(|r| self.matches_pattern(&r.url, pattern))
286            .count();
287        assert_eq!(
288            matching, 0,
289            "Expected no HTTP calls matching '{}', but found {}",
290            pattern, matching
291        );
292    }
293
294    /// Assert that a request was made with specific body content.
295    pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
296    where
297        F: Fn(&serde_json::Value) -> bool,
298    {
299        let requests = self.requests_blocking();
300        let matching = requests
301            .iter()
302            .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
303        assert!(
304            matching.count() > 0,
305            "Expected HTTP call matching '{}' with matching body, but none found",
306            pattern
307        );
308    }
309}
310
311impl Default for MockHttp {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317/// Builder for MockHttp.
318pub struct MockHttpBuilder {
319    mocks: Vec<(String, BoxedHandler)>,
320}
321
322impl MockHttpBuilder {
323    /// Create a new builder.
324    pub fn new() -> Self {
325        Self { mocks: Vec::new() }
326    }
327
328    /// Add a mock with a custom handler.
329    pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
330    where
331        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
332    {
333        self.mocks.push((pattern.to_string(), Box::new(handler)));
334        self
335    }
336
337    /// Add a mock that returns a JSON response.
338    pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
339        self,
340        pattern: &str,
341        response: T,
342    ) -> Self {
343        self.mock(pattern, move |_| MockResponse::json(response.clone()))
344    }
345
346    /// Build the MockHttp.
347    pub fn build(self) -> MockHttp {
348        let mut mock = MockHttp::new();
349        for (pattern, handler) in self.mocks {
350            mock.add_mock_boxed(&pattern, handler);
351        }
352        mock
353    }
354}
355
356impl Default for MockHttpBuilder {
357    fn default() -> Self {
358        Self::new()
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with no headers and a `null` body.
    fn request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_mock_response_helpers() {
        assert_eq!(MockResponse::internal_error("boom").status, 500);
        assert_eq!(MockResponse::not_found("gone").status, 404);
        assert_eq!(MockResponse::unauthorized("nope").status, 401);

        let ok = MockResponse::ok();
        assert_eq!(ok.status, 200);
        assert_eq!(ok.body, serde_json::json!({}));
        assert_eq!(
            ok.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        // Exact match
        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        // Wildcard at end
        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        // Wildcard in middle
        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        // No match
        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[test]
    fn test_pattern_matching_edge_cases() {
        let mock = MockHttp::new();

        // Exact patterns must match the whole string.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/user"
        ));

        // Leading wildcard anchors the suffix.
        assert!(mock.matches_pattern("https://api.example.com/users", "*/users"));
        assert!(!mock.matches_pattern("https://api.example.com/users/1", "*/users"));

        // Several wildcards.
        assert!(mock.matches_pattern(
            "https://api.example.com/v1/users/42",
            "https://api.example.com/*/users/*"
        ));

        // Prefix and suffix may not overlap in the input.
        assert!(!mock.matches_pattern("ab", "ab*ab"));
        assert!(mock.matches_pattern("abab", "ab*ab"));

        // A lone wildcard matches everything, including the empty string.
        assert!(mock.matches_pattern("", "*"));
        assert!(mock.matches_pattern("anything", "*"));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let response = mock
            .execute(request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_unmatched_request_returns_500() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| MockResponse::ok());

        let response = mock
            .execute(request("GET", "/x", "https://other.com/x"))
            .await;
        assert_eq!(response.status, 500);
        assert_eq!(response.body["error"], "No mock found for https://other.com/x");

        // Unmatched requests are still recorded.
        mock.assert_called_times("https://other.com/*", 1);
    }

    #[tokio::test]
    async fn test_path_pattern_matches() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/users/*", |req| {
            MockResponse::json(serde_json::json!({ "path": req.path.as_str() }))
        });

        let response = mock
            .execute(request("GET", "/users/7", "https://api.example.com/users/7"))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["path"], "/users/7");
    }

    #[tokio::test]
    async fn test_first_matching_mock_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| MockResponse::ok());
        mock.add_mock_sync("https://api.example.com/users", |_| {
            MockResponse::not_found("shadowed")
        });

        let response = mock
            .execute(request("GET", "/users", "https://api.example.com/users"))
            .await;
        assert_eq!(response.status, 200);
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let req = MockRequest {
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
            ..request("POST", "/api/users", "https://api.example.com/users")
        };

        let _ = mock.execute(req).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
        assert_eq!(
            requests[0].headers.get("authorization").map(String::as_str),
            Some("Bearer token")
        );
        assert_eq!(mock.requests_blocking().len(), 1);

        mock.assert_called_with_body("https://api.example.com/*", |body| body["name"] == "Test");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(request("GET", "/test", "https://api.example.com/test"))
            .await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[test]
    #[should_panic(expected = "Expected HTTP call matching")]
    fn test_assert_called_panics_without_calls() {
        MockHttp::new().assert_called("https://api.example.com/*");
    }

    #[tokio::test]
    async fn test_clear_requests_and_mocks() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(request("GET", "/a", "https://api.example.com/a"))
            .await;
        assert_eq!(mock.requests_to("https://api.example.com/*").len(), 1);
        assert!(mock.requests_to("https://other.com/*").is_empty());

        mock.clear_requests();
        mock.assert_not_called("*");

        mock.clear_mocks();
        let response = mock
            .execute(request("GET", "/a", "https://api.example.com/a"))
            .await;
        assert_eq!(response.status, 500);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let mock = MockHttp::new();
        let clone = mock.clone();
        clone.add_mock_sync("*", |_| MockResponse::ok());

        let response = mock
            .execute(request("GET", "/a", "https://api.example.com/a"))
            .await;
        assert_eq!(response.status, 200);
        clone.assert_called_times("*", 1);
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mock_json_response() {
        let mock = MockHttp::builder()
            .mock_json("https://api.example.com/*", serde_json::json!({"id": 1}))
            .build();

        // The same mock must keep answering across requests.
        for _ in 0..2 {
            let response = mock
                .execute(request("GET", "/items", "https://api.example.com/items"))
                .await;
            assert_eq!(response.status, 200);
            assert_eq!(response.body["id"], 1);
        }
        mock.assert_called_times("https://api.example.com/*", 2);
    }
}