Skip to main content

forge_core/testing/
mock_http.rs

1//! HTTP mocking utilities for testing.
2//!
3//! Provides a mock HTTP client that intercepts requests and returns
4//! predefined responses. Supports pattern matching and request recording
5//! for verification.
6
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10use serde::Serialize;
11
12/// Mock HTTP client for testing.
13///
14/// # Example
15///
16/// ```ignore
17/// let mut mock = MockHttp::new();
18/// mock.add_mock_sync("https://api.example.com/*", |req| {
19///     MockResponse::json(json!({"status": "ok"}))
20/// });
21///
22/// let response = mock.execute(request).await;
23/// mock.assert_called("https://api.example.com/*");
24/// ```
25#[derive(Clone)]
26pub struct MockHttp {
27    mocks: Arc<RwLock<Vec<MockHandler>>>,
28    requests: Arc<RwLock<Vec<RecordedRequest>>>,
29}
30
31/// Type alias for mock handler closure.
32pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
33
34/// A mock handler.
35struct MockHandler {
36    pattern: String,
37    handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
38}
39
40/// A recorded request for verification.
41#[derive(Debug, Clone)]
42pub struct RecordedRequest {
43    /// Request method.
44    pub method: String,
45    /// Request URL.
46    pub url: String,
47    /// Request headers.
48    pub headers: HashMap<String, String>,
49    /// Request body.
50    pub body: serde_json::Value,
51}
52
53/// Mock HTTP request.
54#[derive(Debug, Clone)]
55pub struct MockRequest {
56    /// Request method.
57    pub method: String,
58    /// Request path.
59    pub path: String,
60    /// Request URL.
61    pub url: String,
62    /// Request headers.
63    pub headers: HashMap<String, String>,
64    /// Request body.
65    pub body: serde_json::Value,
66}
67
68/// Mock HTTP response.
69#[derive(Debug, Clone)]
70pub struct MockResponse {
71    /// Status code.
72    pub status: u16,
73    /// Response headers.
74    pub headers: HashMap<String, String>,
75    /// Response body.
76    pub body: serde_json::Value,
77}
78
79impl MockResponse {
80    /// Create a successful JSON response.
81    pub fn json<T: Serialize>(body: T) -> Self {
82        Self {
83            status: 200,
84            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
85            body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
86        }
87    }
88
89    /// Create an error response.
90    pub fn error(status: u16, message: &str) -> Self {
91        Self {
92            status,
93            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
94            body: serde_json::json!({ "error": message }),
95        }
96    }
97
98    /// Create a 500 internal error.
99    pub fn internal_error(message: &str) -> Self {
100        Self::error(500, message)
101    }
102
103    /// Create a 404 not found.
104    pub fn not_found(message: &str) -> Self {
105        Self::error(404, message)
106    }
107
108    /// Create a 401 unauthorized.
109    pub fn unauthorized(message: &str) -> Self {
110        Self::error(401, message)
111    }
112
113    /// Create an empty 200 OK response.
114    pub fn ok() -> Self {
115        Self::json(serde_json::json!({}))
116    }
117}
118
119impl MockHttp {
120    /// Create a new mock HTTP client.
121    pub fn new() -> Self {
122        Self {
123            mocks: Arc::new(RwLock::new(Vec::new())),
124            requests: Arc::new(RwLock::new(Vec::new())),
125        }
126    }
127
128    /// Create a builder.
129    pub fn builder() -> MockHttpBuilder {
130        MockHttpBuilder::new()
131    }
132
133    /// Add a mock handler (sync version for use in builders).
134    pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
135    where
136        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
137    {
138        let mut mocks = self.mocks.write().unwrap();
139        mocks.push(MockHandler {
140            pattern: pattern.to_string(),
141            handler: Arc::new(handler),
142        });
143    }
144
145    /// Add a mock handler from a boxed closure.
146    pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
147        let mut mocks = self.mocks.write().unwrap();
148        mocks.push(MockHandler {
149            pattern: pattern.to_string(),
150            handler: Arc::from(handler),
151        });
152    }
153
154    /// Execute a mock request.
155    pub async fn execute(&self, request: MockRequest) -> MockResponse {
156        // Record the request
157        {
158            let mut requests = self.requests.write().unwrap();
159            requests.push(RecordedRequest {
160                method: request.method.clone(),
161                url: request.url.clone(),
162                headers: request.headers.clone(),
163                body: request.body.clone(),
164            });
165        }
166
167        // Find matching mock
168        let mocks = self.mocks.read().unwrap();
169        for mock in mocks.iter() {
170            if self.matches_pattern(&request.url, &mock.pattern)
171                || self.matches_pattern(&request.path, &mock.pattern)
172            {
173                return (mock.handler)(&request);
174            }
175        }
176
177        // No mock found
178        MockResponse::error(500, &format!("No mock found for {}", request.url))
179    }
180
181    /// Check if a URL matches a pattern.
182    fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
183        // Convert glob pattern to simple matching
184        let pattern_parts: Vec<&str> = pattern.split('*').collect();
185        if pattern_parts.len() == 1 {
186            // No wildcards - exact match
187            return url == pattern;
188        }
189
190        let mut remaining = url;
191        for (i, part) in pattern_parts.iter().enumerate() {
192            if part.is_empty() {
193                continue;
194            }
195
196            if i == 0 {
197                // First part must match at start
198                if !remaining.starts_with(part) {
199                    return false;
200                }
201                remaining = &remaining[part.len()..];
202            } else if i == pattern_parts.len() - 1 {
203                // Last part must match at end
204                if !remaining.ends_with(part) {
205                    return false;
206                }
207            } else {
208                // Middle parts can match anywhere
209                if let Some(pos) = remaining.find(part) {
210                    remaining = &remaining[pos + part.len()..];
211                } else {
212                    return false;
213                }
214            }
215        }
216
217        true
218    }
219
220    /// Get recorded requests.
221    pub fn requests(&self) -> Vec<RecordedRequest> {
222        self.requests.read().unwrap().clone()
223    }
224
225    /// Get recorded requests (blocking version for use in sync contexts).
226    pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
227        self.requests.read().unwrap().clone()
228    }
229
230    /// Get requests matching a pattern.
231    pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
232        self.requests
233            .read()
234            .unwrap()
235            .iter()
236            .filter(|r| self.matches_pattern(&r.url, pattern))
237            .cloned()
238            .collect()
239    }
240
241    /// Clear recorded requests.
242    pub fn clear_requests(&self) {
243        self.requests.write().unwrap().clear();
244    }
245
246    /// Clear all mocks.
247    pub fn clear_mocks(&self) {
248        self.mocks.write().unwrap().clear();
249    }
250
251    // =========================================================================
252    // VERIFICATION METHODS
253    // =========================================================================
254
255    /// Assert that a URL pattern was called.
256    pub fn assert_called(&self, pattern: &str) {
257        let requests = self.requests_blocking();
258        let matching = requests
259            .iter()
260            .filter(|r| self.matches_pattern(&r.url, pattern))
261            .count();
262        assert!(
263            matching > 0,
264            "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
265            pattern,
266            requests.iter().map(|r| &r.url).collect::<Vec<_>>()
267        );
268    }
269
270    /// Assert that a URL pattern was called a specific number of times.
271    pub fn assert_called_times(&self, pattern: &str, expected: usize) {
272        let requests = self.requests_blocking();
273        let matching = requests
274            .iter()
275            .filter(|r| self.matches_pattern(&r.url, pattern))
276            .count();
277        assert_eq!(
278            matching, expected,
279            "Expected {} HTTP calls matching '{}', but found {}",
280            expected, pattern, matching
281        );
282    }
283
284    /// Assert that a URL pattern was not called.
285    pub fn assert_not_called(&self, pattern: &str) {
286        let requests = self.requests_blocking();
287        let matching = requests
288            .iter()
289            .filter(|r| self.matches_pattern(&r.url, pattern))
290            .count();
291        assert_eq!(
292            matching, 0,
293            "Expected no HTTP calls matching '{}', but found {}",
294            pattern, matching
295        );
296    }
297
298    /// Assert that a request was made with specific body content.
299    pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
300    where
301        F: Fn(&serde_json::Value) -> bool,
302    {
303        let requests = self.requests_blocking();
304        let matching = requests
305            .iter()
306            .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
307        assert!(
308            matching.count() > 0,
309            "Expected HTTP call matching '{}' with matching body, but none found",
310            pattern
311        );
312    }
313}
314
315impl Default for MockHttp {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321/// Builder for MockHttp.
322pub struct MockHttpBuilder {
323    mocks: Vec<(String, BoxedHandler)>,
324}
325
326impl MockHttpBuilder {
327    /// Create a new builder.
328    pub fn new() -> Self {
329        Self { mocks: Vec::new() }
330    }
331
332    /// Add a mock with a custom handler.
333    pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
334    where
335        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
336    {
337        self.mocks.push((pattern.to_string(), Box::new(handler)));
338        self
339    }
340
341    /// Add a mock that returns a JSON response.
342    pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
343        self,
344        pattern: &str,
345        response: T,
346    ) -> Self {
347        self.mock(pattern, move |_| MockResponse::json(response.clone()))
348    }
349
350    /// Build the MockHttp.
351    pub fn build(self) -> MockHttp {
352        let mut mock = MockHttp::new();
353        for (pattern, handler) in self.mocks {
354            mock.add_mock_boxed(&pattern, handler);
355        }
356        mock
357    }
358}
359
360impl Default for MockHttpBuilder {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with the given method, path and URL; no headers and a `null` body.
    fn make_request(method: &str, path: &str, url: &str) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        // Exact match
        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        // Wildcard at end
        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        // Wildcard in middle
        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        // No match
        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[test]
    fn test_pattern_matching_edge_cases() {
        let mock = MockHttp::new();

        // A lone `*` matches anything, including the empty string.
        assert!(mock.matches_pattern("", "*"));
        assert!(mock.matches_pattern("https://api.example.com/users", "*"));

        // A leading wildcard anchors only the suffix.
        assert!(mock.matches_pattern("https://api.example.com/users", "*/users"));
        assert!(!mock.matches_pattern("https://api.example.com/users/1", "*/users"));

        // A middle wildcard still requires both literal segments to be present.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/*/users"
        ));

        // Without a wildcard the match must be exact.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users/1",
            "https://api.example.com/users"
        ));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let request = MockRequest {
            method: "GET".to_string(),
            path: "/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        };

        let response = mock.execute(request).await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_unmatched_request_returns_500() {
        let mock = MockHttp::new();

        let response = mock
            .execute(make_request("GET", "/missing", "https://api.example.com/missing"))
            .await;

        assert_eq!(response.status, 500);
        assert_eq!(
            response.body["error"],
            "No mock found for https://api.example.com/missing"
        );
        // Unmatched requests are still recorded.
        mock.assert_called_times("https://api.example.com/missing", 1);
    }

    #[tokio::test]
    async fn test_pattern_can_match_request_path() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/users/*", |_| MockResponse::ok());

        let response = mock
            .execute(make_request("GET", "/users/1", "https://api.example.com/users/1"))
            .await;

        assert_eq!(response.status, 200);
    }

    #[tokio::test]
    async fn test_first_registered_mock_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| MockResponse::not_found("gone"));
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let response = mock
            .execute(make_request("GET", "/a", "https://api.example.com/a"))
            .await;

        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "gone");
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "POST".to_string(),
            path: "/api/users".to_string(),
            url: "https://api.example.com/users".to_string(),
            headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]),
            body: serde_json::json!({"name": "Test"}),
        };

        let _ = mock.execute(request).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let request = MockRequest {
            method: "GET".to_string(),
            path: "/test".to_string(),
            url: "https://api.example.com/test".to_string(),
            headers: HashMap::new(),
            body: serde_json::Value::Null,
        };

        let _ = mock.execute(request).await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[tokio::test]
    async fn test_assert_called_with_body_and_clear() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut req = make_request("POST", "/users", "https://api.example.com/users");
        req.body = serde_json::json!({"name": "Test"});
        let _ = mock.execute(req).await;

        mock.assert_called_with_body("https://api.example.com/*", |body| body["name"] == "Test");
        assert_eq!(mock.requests_to("https://api.example.com/*").len(), 1);

        mock.clear_requests();
        assert!(mock.requests().is_empty());
        mock.assert_not_called("*");
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mock_json_response() {
        let mock = MockHttp::builder()
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        let response = mock
            .execute(make_request("GET", "/x", "https://other.com/x"))
            .await;

        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 1);
        assert_eq!(
            response.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }
}
480}