Skip to main content

forge_core/testing/
mock_http.rs

1//! HTTP mocking utilities for testing.
2//!
3//! Provides a mock HTTP client that intercepts requests and returns
4//! predefined responses. Supports pattern matching and request recording
5//! for verification.
6
7#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
8
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12use serde::Serialize;
13
14/// Mock HTTP client for testing.
15///
16/// # Example
17///
18/// ```ignore
19/// let mut mock = MockHttp::new();
20/// mock.add_mock_sync("https://api.example.com/*", |req| {
21///     MockResponse::json(json!({"status": "ok"}))
22/// });
23///
24/// let response = mock.execute(request).await;
25/// mock.assert_called("https://api.example.com/*");
26/// ```
27#[derive(Clone)]
28pub struct MockHttp {
29    mocks: Arc<RwLock<Vec<MockHandler>>>,
30    requests: Arc<RwLock<Vec<RecordedRequest>>>,
31}
32
33/// Type alias for mock handler closure.
34pub type BoxedHandler = Box<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>;
35
36/// A mock handler.
37struct MockHandler {
38    pattern: String,
39    handler: Arc<dyn Fn(&MockRequest) -> MockResponse + Send + Sync>,
40}
41
42/// A recorded request for verification.
43#[derive(Debug, Clone)]
44pub struct RecordedRequest {
45    /// Request method.
46    pub method: String,
47    /// Request URL.
48    pub url: String,
49    /// Request headers.
50    pub headers: HashMap<String, String>,
51    /// Request body.
52    pub body: serde_json::Value,
53}
54
55/// Mock HTTP request.
56#[derive(Debug, Clone)]
57pub struct MockRequest {
58    /// Request method.
59    pub method: String,
60    /// Request path.
61    pub path: String,
62    /// Request URL.
63    pub url: String,
64    /// Request headers.
65    pub headers: HashMap<String, String>,
66    /// Request body.
67    pub body: serde_json::Value,
68}
69
70/// Mock HTTP response.
71#[derive(Debug, Clone)]
72pub struct MockResponse {
73    /// Status code.
74    pub status: u16,
75    /// Response headers.
76    pub headers: HashMap<String, String>,
77    /// Response body.
78    pub body: serde_json::Value,
79}
80
81impl MockResponse {
82    /// Create a successful JSON response.
83    pub fn json<T: Serialize>(body: T) -> Self {
84        Self {
85            status: 200,
86            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
87            body: serde_json::to_value(body).unwrap_or(serde_json::Value::Null),
88        }
89    }
90
91    /// Create an error response.
92    pub fn error(status: u16, message: &str) -> Self {
93        Self {
94            status,
95            headers: HashMap::from([("content-type".to_string(), "application/json".to_string())]),
96            body: serde_json::json!({ "error": message }),
97        }
98    }
99
100    /// Create a 500 internal error.
101    pub fn internal_error(message: &str) -> Self {
102        Self::error(500, message)
103    }
104
105    /// Create a 404 not found.
106    pub fn not_found(message: &str) -> Self {
107        Self::error(404, message)
108    }
109
110    /// Create a 401 unauthorized.
111    pub fn unauthorized(message: &str) -> Self {
112        Self::error(401, message)
113    }
114
115    /// Create an empty 200 OK response.
116    pub fn ok() -> Self {
117        Self::json(serde_json::json!({}))
118    }
119}
120
121impl MockHttp {
122    /// Create a new mock HTTP client.
123    pub fn new() -> Self {
124        Self {
125            mocks: Arc::new(RwLock::new(Vec::new())),
126            requests: Arc::new(RwLock::new(Vec::new())),
127        }
128    }
129
130    /// Create a builder.
131    pub fn builder() -> MockHttpBuilder {
132        MockHttpBuilder::new()
133    }
134
135    /// Add a mock handler (sync version for use in builders).
136    pub fn add_mock_sync<F>(&self, pattern: &str, handler: F)
137    where
138        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
139    {
140        let mut mocks = self.mocks.write().unwrap();
141        mocks.push(MockHandler {
142            pattern: pattern.to_string(),
143            handler: Arc::new(handler),
144        });
145    }
146
147    /// Add a mock handler from a boxed closure.
148    pub fn add_mock_boxed(&mut self, pattern: &str, handler: BoxedHandler) {
149        let mut mocks = self.mocks.write().unwrap();
150        mocks.push(MockHandler {
151            pattern: pattern.to_string(),
152            handler: Arc::from(handler),
153        });
154    }
155
156    /// Execute a mock request.
157    pub async fn execute(&self, request: MockRequest) -> MockResponse {
158        // Record the request
159        {
160            let mut requests = self.requests.write().unwrap();
161            requests.push(RecordedRequest {
162                method: request.method.clone(),
163                url: request.url.clone(),
164                headers: request.headers.clone(),
165                body: request.body.clone(),
166            });
167        }
168
169        // Find matching mock
170        let mocks = self.mocks.read().unwrap();
171        for mock in mocks.iter() {
172            if self.matches_pattern(&request.url, &mock.pattern)
173                || self.matches_pattern(&request.path, &mock.pattern)
174            {
175                return (mock.handler)(&request);
176            }
177        }
178
179        // No mock found
180        MockResponse::error(500, &format!("No mock found for {}", request.url))
181    }
182
183    /// Check if a URL matches a pattern.
184    fn matches_pattern(&self, url: &str, pattern: &str) -> bool {
185        // Convert glob pattern to simple matching
186        let pattern_parts: Vec<&str> = pattern.split('*').collect();
187        if pattern_parts.len() == 1 {
188            // No wildcards - exact match
189            return url == pattern;
190        }
191
192        let mut remaining = url;
193        for (i, part) in pattern_parts.iter().enumerate() {
194            if part.is_empty() {
195                continue;
196            }
197
198            if i == 0 {
199                // First part must match at start
200                if !remaining.starts_with(part) {
201                    return false;
202                }
203                remaining = &remaining[part.len()..];
204            } else if i == pattern_parts.len() - 1 {
205                // Last part must match at end
206                if !remaining.ends_with(part) {
207                    return false;
208                }
209            } else {
210                // Middle parts can match anywhere
211                if let Some(pos) = remaining.find(part) {
212                    remaining = &remaining[pos + part.len()..];
213                } else {
214                    return false;
215                }
216            }
217        }
218
219        true
220    }
221
222    /// Get recorded requests.
223    pub fn requests(&self) -> Vec<RecordedRequest> {
224        self.requests.read().unwrap().clone()
225    }
226
227    /// Get recorded requests (blocking version for use in sync contexts).
228    pub fn requests_blocking(&self) -> Vec<RecordedRequest> {
229        self.requests.read().unwrap().clone()
230    }
231
232    /// Get requests matching a pattern.
233    pub fn requests_to(&self, pattern: &str) -> Vec<RecordedRequest> {
234        self.requests
235            .read()
236            .unwrap()
237            .iter()
238            .filter(|r| self.matches_pattern(&r.url, pattern))
239            .cloned()
240            .collect()
241    }
242
243    /// Clear recorded requests.
244    pub fn clear_requests(&self) {
245        self.requests.write().unwrap().clear();
246    }
247
248    /// Clear all mocks.
249    pub fn clear_mocks(&self) {
250        self.mocks.write().unwrap().clear();
251    }
252
253    /// Assert that a URL pattern was called.
254    pub fn assert_called(&self, pattern: &str) {
255        let requests = self.requests_blocking();
256        let matching = requests
257            .iter()
258            .filter(|r| self.matches_pattern(&r.url, pattern))
259            .count();
260        assert!(
261            matching > 0,
262            "Expected HTTP call matching '{}', but none found. Recorded requests: {:?}",
263            pattern,
264            requests.iter().map(|r| &r.url).collect::<Vec<_>>()
265        );
266    }
267
268    /// Assert that a URL pattern was called a specific number of times.
269    pub fn assert_called_times(&self, pattern: &str, expected: usize) {
270        let requests = self.requests_blocking();
271        let matching = requests
272            .iter()
273            .filter(|r| self.matches_pattern(&r.url, pattern))
274            .count();
275        assert_eq!(
276            matching, expected,
277            "Expected {} HTTP calls matching '{}', but found {}",
278            expected, pattern, matching
279        );
280    }
281
282    /// Assert that a URL pattern was not called.
283    pub fn assert_not_called(&self, pattern: &str) {
284        let requests = self.requests_blocking();
285        let matching = requests
286            .iter()
287            .filter(|r| self.matches_pattern(&r.url, pattern))
288            .count();
289        assert_eq!(
290            matching, 0,
291            "Expected no HTTP calls matching '{}', but found {}",
292            pattern, matching
293        );
294    }
295
296    /// Assert that a request was made with specific body content.
297    pub fn assert_called_with_body<F>(&self, pattern: &str, predicate: F)
298    where
299        F: Fn(&serde_json::Value) -> bool,
300    {
301        let requests = self.requests_blocking();
302        let matching = requests
303            .iter()
304            .filter(|r| self.matches_pattern(&r.url, pattern) && predicate(&r.body));
305        assert!(
306            matching.count() > 0,
307            "Expected HTTP call matching '{}' with matching body, but none found",
308            pattern
309        );
310    }
311}
312
313impl Default for MockHttp {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
319/// Builder for MockHttp.
320pub struct MockHttpBuilder {
321    mocks: Vec<(String, BoxedHandler)>,
322}
323
324impl MockHttpBuilder {
325    /// Create a new builder.
326    pub fn new() -> Self {
327        Self { mocks: Vec::new() }
328    }
329
330    /// Add a mock with a custom handler.
331    pub fn mock<F>(mut self, pattern: &str, handler: F) -> Self
332    where
333        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
334    {
335        self.mocks.push((pattern.to_string(), Box::new(handler)));
336        self
337    }
338
339    /// Add a mock that returns a JSON response.
340    pub fn mock_json<T: Serialize + Clone + Send + Sync + 'static>(
341        self,
342        pattern: &str,
343        response: T,
344    ) -> Self {
345        self.mock(pattern, move |_| MockResponse::json(response.clone()))
346    }
347
348    /// Build the MockHttp.
349    pub fn build(self) -> MockHttp {
350        let mut mock = MockHttp::new();
351        for (pattern, handler) in self.mocks {
352            mock.add_mock_boxed(&pattern, handler);
353        }
354        mock
355    }
356}
357
358impl Default for MockHttpBuilder {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header-less `MockRequest` for the tests below.
    fn request(method: &str, path: &str, url: &str, body: serde_json::Value) -> MockRequest {
        MockRequest {
            method: method.to_string(),
            path: path.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            body,
        }
    }

    #[test]
    fn test_mock_response_json() {
        let response = MockResponse::json(serde_json::json!({"id": 123}));
        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 123);
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(404, "Not found");
        assert_eq!(response.status, 404);
        assert_eq!(response.body["error"], "Not found");
    }

    #[test]
    fn test_pattern_matching() {
        let mock = MockHttp::new();

        // Exact match
        assert!(mock.matches_pattern(
            "https://api.example.com/users",
            "https://api.example.com/users"
        ));

        // Wildcard at end
        assert!(mock.matches_pattern(
            "https://api.example.com/users/123",
            "https://api.example.com/*"
        ));

        // Wildcard in middle
        assert!(mock.matches_pattern(
            "https://api.example.com/v2/users",
            "https://api.example.com/*/users"
        ));

        // No match
        assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*"));
    }

    #[test]
    fn test_pattern_matching_leading_and_multiple_wildcards() {
        let mock = MockHttp::new();

        assert!(mock.matches_pattern("https://api.example.com/users/42", "*/users/*"));
        assert!(mock.matches_pattern("https://cdn.example.com/data.json", "*.json"));
        assert!(!mock.matches_pattern("https://cdn.example.com/data.xml", "*.json"));

        // Literal segments must appear in order.
        assert!(!mock.matches_pattern("https://api.example.com/users/v2/", "*/v2/*/users"));

        // A pattern without wildcards rejects longer URLs.
        assert!(!mock.matches_pattern(
            "https://api.example.com/users/1",
            "https://api.example.com/users"
        ));
    }

    #[tokio::test]
    async fn test_mock_execution() {
        let mock = MockHttp::new();
        mock.add_mock_sync("https://api.example.com/*", |_| {
            MockResponse::json(serde_json::json!({"status": "ok"}))
        });

        let response = mock
            .execute(request(
                "GET",
                "/users",
                "https://api.example.com/users",
                serde_json::Value::Null,
            ))
            .await;
        assert_eq!(response.status, 200);
        assert_eq!(response.body["status"], "ok");
    }

    #[tokio::test]
    async fn test_unmatched_request_returns_500_and_is_recorded() {
        let mock = MockHttp::new();

        let response = mock
            .execute(request(
                "GET",
                "/missing",
                "https://api.example.com/missing",
                serde_json::Value::Null,
            ))
            .await;

        assert_eq!(response.status, 500);
        assert_eq!(
            response.body["error"],
            "No mock found for https://api.example.com/missing"
        );
        assert_eq!(mock.requests().len(), 1);
    }

    #[tokio::test]
    async fn test_path_pattern_and_first_match_wins() {
        let mock = MockHttp::new();
        mock.add_mock_sync("/users", |_| {
            MockResponse::json(serde_json::json!({"via": "path"}))
        });
        mock.add_mock_sync("*", |_| MockResponse::not_found("fallback"));

        let response = mock
            .execute(request(
                "GET",
                "/users",
                "https://api.example.com/users",
                serde_json::Value::Null,
            ))
            .await;

        assert_eq!(response.status, 200);
        assert_eq!(response.body["via"], "path");
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let mut req = request(
            "POST",
            "/api/users",
            "https://api.example.com/users",
            serde_json::json!({"name": "Test"}),
        );
        req.headers
            .insert("authorization".to_string(), "Bearer token".to_string());

        let _ = mock.execute(req).await;

        let requests = mock.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body["name"], "Test");
        assert_eq!(requests[0].headers["authorization"], "Bearer token");
    }

    #[tokio::test]
    async fn test_assert_called() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(request(
                "GET",
                "/test",
                "https://api.example.com/test",
                serde_json::Value::Null,
            ))
            .await;

        mock.assert_called("https://api.example.com/*");
        mock.assert_called_times("https://api.example.com/*", 1);
        mock.assert_not_called("https://other.com/*");
    }

    #[tokio::test]
    async fn test_assert_called_with_body() {
        let mock = MockHttp::new();
        mock.add_mock_sync("*", |_| MockResponse::ok());

        let _ = mock
            .execute(request(
                "POST",
                "/users",
                "https://api.example.com/users",
                serde_json::json!({"name": "Test"}),
            ))
            .await;

        mock.assert_called_with_body("https://api.example.com/*", |body| body["name"] == "Test");
    }

    #[test]
    #[should_panic(expected = "Expected HTTP call matching")]
    fn test_assert_called_panics_without_requests() {
        MockHttp::new().assert_called("https://api.example.com/*");
    }

    #[test]
    fn test_builder() {
        let mock = MockHttpBuilder::new()
            .mock("https://api.example.com/*", |_| MockResponse::ok())
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        assert_eq!(mock.mocks.read().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_mocks_are_executed() {
        let mock = MockHttpBuilder::new()
            .mock_json("https://other.com/*", serde_json::json!({"id": 1}))
            .build();

        let response = mock
            .execute(request(
                "GET",
                "/thing",
                "https://other.com/thing",
                serde_json::Value::Null,
            ))
            .await;

        assert_eq!(response.status, 200);
        assert_eq!(response.body["id"], 1);
        mock.assert_called_times("https://other.com/*", 1);
    }
}