Skip to main content

harn_vm/http/
mock.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3use std::rc::Rc;
4
5use crate::value::VmValue;
6
7#[derive(Clone)]
8pub(super) struct MockResponse {
9    pub(super) status: i64,
10    pub(super) body: String,
11    pub(super) headers: BTreeMap<String, VmValue>,
12}
13
/// A canned HTTP response that embedders can script through `push_http_mock`.
///
/// Headers are plain strings here; they are turned into VM string values when
/// the response is converted for registration.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpMockResponse {
    /// HTTP status code to report.
    pub status: i64,
    /// Response body text.
    pub body: String,
    /// Response headers keyed by exact (case-sensitive) header name.
    pub headers: BTreeMap<String, String>,
}

impl HttpMockResponse {
    /// Creates a response with the given status and body and no headers.
    pub fn new(status: i64, body: impl Into<String>) -> Self {
        let body = body.into();
        Self {
            status,
            body,
            headers: BTreeMap::default(),
        }
    }

    /// Sets header `name` to `value` (overwriting an existing entry with the
    /// same name) and returns the response so calls can be chained.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, value) = (name.into(), value.into());
        self.headers.insert(name, value);
        self
    }
}
35
36impl From<HttpMockResponse> for MockResponse {
37    fn from(value: HttpMockResponse) -> Self {
38        Self {
39            status: value.status,
40            body: value.body,
41            headers: value
42                .headers
43                .into_iter()
44                .map(|(key, value)| (key, VmValue::String(Rc::from(value))))
45                .collect(),
46        }
47    }
48}
49
50struct HttpMock {
51    method: String,
52    url_pattern: String,
53    responses: Vec<MockResponse>,
54    next_response: usize,
55}
56
57#[derive(Clone)]
58struct HttpMockCall {
59    method: String,
60    url: String,
61    headers: BTreeMap<String, VmValue>,
62    body: Option<String>,
63}
64
/// Plain-Rust copy of a recorded mock call, for embedders and host-side tests.
///
/// Header values are the display form of the recorded VM values. No
/// redaction is applied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpMockCallSnapshot {
    /// Request method as supplied by the caller.
    pub method: String,
    /// Full request URL.
    pub url: String,
    /// Request headers rendered as strings.
    pub headers: BTreeMap<String, String>,
    /// Request body, if one was sent.
    pub body: Option<String>,
}
72
73thread_local! {
74    static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
75    static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
76}
77
78pub(super) fn reset_http_mocks() {
79    HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
80    HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
81}
82
83pub(super) fn clear_http_mocks() {
84    reset_http_mocks();
85}
86
87pub fn push_http_mock(
88    method: impl Into<String>,
89    url_pattern: impl Into<String>,
90    responses: Vec<HttpMockResponse>,
91) {
92    let responses = if responses.is_empty() {
93        vec![MockResponse::from(HttpMockResponse::new(200, ""))]
94    } else {
95        responses.into_iter().map(MockResponse::from).collect()
96    };
97    register_http_mock(method.into(), url_pattern.into(), responses);
98}
99
100pub(super) fn register_http_mock(
101    method: impl Into<String>,
102    url_pattern: impl Into<String>,
103    responses: Vec<MockResponse>,
104) {
105    let method = method.into();
106    let url_pattern = url_pattern.into();
107    HTTP_MOCKS.with(|mocks| {
108        let mut mocks = mocks.borrow_mut();
109        // Re-registering the same (method, url_pattern) replaces the prior
110        // mock so tests can override per-case responses without first calling
111        // http_mock_clear(). Without this, the original mock keeps matching
112        // forever and the new one is dead.
113        mocks.retain(|mock| !(mock.method == method && mock.url_pattern == url_pattern));
114        mocks.push(HttpMock {
115            method,
116            url_pattern,
117            responses,
118            next_response: 0,
119        });
120    });
121}
122
123pub fn http_mock_calls_snapshot() -> Vec<HttpMockCallSnapshot> {
124    HTTP_MOCK_CALLS.with(|calls| {
125        calls
126            .borrow()
127            .iter()
128            .map(|call| HttpMockCallSnapshot {
129                method: call.method.clone(),
130                url: call.url.clone(),
131                headers: call
132                    .headers
133                    .iter()
134                    .map(|(key, value)| (key.clone(), value.display()))
135                    .collect(),
136                body: call.body.clone(),
137            })
138            .collect()
139    })
140}
141
142pub(super) fn http_mock_calls_value(redact_sensitive: bool) -> Vec<VmValue> {
143    HTTP_MOCK_CALLS.with(|calls| {
144        calls
145            .borrow()
146            .iter()
147            .map(|call| {
148                let mut dict = BTreeMap::new();
149                dict.insert(
150                    "method".to_string(),
151                    VmValue::String(Rc::from(call.method.as_str())),
152                );
153                dict.insert(
154                    "url".to_string(),
155                    VmValue::String(Rc::from(redact_mock_call_url(&call.url, redact_sensitive))),
156                );
157                dict.insert(
158                    "headers".to_string(),
159                    VmValue::Dict(Rc::new(mock_call_headers_value(
160                        &call.headers,
161                        redact_sensitive,
162                    ))),
163                );
164                dict.insert(
165                    "body".to_string(),
166                    match &call.body {
167                        Some(body) => VmValue::String(Rc::from(body.as_str())),
168                        None => VmValue::Nil,
169                    },
170                );
171                VmValue::Dict(Rc::new(dict))
172            })
173            .collect()
174    })
175}
176
177pub(super) fn parse_mock_responses(response: &BTreeMap<String, VmValue>) -> Vec<MockResponse> {
178    let scripted = response
179        .get("responses")
180        .and_then(|value| match value {
181            VmValue::List(items) => Some(
182                items
183                    .iter()
184                    .filter_map(|item| item.as_dict().map(parse_mock_response_dict))
185                    .collect::<Vec<_>>(),
186            ),
187            _ => None,
188        })
189        .unwrap_or_default();
190
191    if scripted.is_empty() {
192        vec![parse_mock_response_dict(response)]
193    } else {
194        scripted
195    }
196}
197
198fn parse_mock_response_dict(response: &BTreeMap<String, VmValue>) -> MockResponse {
199    let status = response
200        .get("status")
201        .and_then(|v| v.as_int())
202        .unwrap_or(200);
203    let body = response
204        .get("body")
205        .map(|v| v.display())
206        .unwrap_or_default();
207    let headers = response
208        .get("headers")
209        .and_then(|v| v.as_dict())
210        .cloned()
211        .unwrap_or_default();
212    MockResponse {
213        status,
214        body,
215        headers,
216    }
217}
218
219pub(super) fn consume_http_mock(
220    method: &str,
221    url: &str,
222    headers: BTreeMap<String, VmValue>,
223    body: Option<String>,
224) -> Option<MockResponse> {
225    let response = HTTP_MOCKS.with(|mocks| {
226        let mut mocks = mocks.borrow_mut();
227        for mock in mocks.iter_mut() {
228            if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
229                && url_matches(&mock.url_pattern, url)
230            {
231                let Some(last_index) = mock.responses.len().checked_sub(1) else {
232                    continue;
233                };
234                let index = mock.next_response.min(last_index);
235                let response = mock.responses[index].clone();
236                if mock.next_response < last_index {
237                    mock.next_response += 1;
238                }
239                return Some(response);
240            }
241        }
242        None
243    })?;
244
245    HTTP_MOCK_CALLS.with(|calls| {
246        calls.borrow_mut().push(HttpMockCall {
247            method: method.to_string(),
248            url: url.to_string(),
249            headers,
250            body,
251        });
252    });
253
254    Some(response)
255}
256
257/// Check if a URL matches a mock pattern (exact or glob with `*`).
258pub(super) fn url_matches(pattern: &str, url: &str) -> bool {
259    if pattern == "*" {
260        return true;
261    }
262    if !pattern.contains('*') {
263        return pattern == url;
264    }
265    // Multi-glob: split on `*` and match segments in order.
266    let parts: Vec<&str> = pattern.split('*').collect();
267    let mut remaining = url;
268    for (i, part) in parts.iter().enumerate() {
269        if part.is_empty() {
270            continue;
271        }
272        if i == 0 {
273            if !remaining.starts_with(part) {
274                return false;
275            }
276            remaining = &remaining[part.len()..];
277        } else if i == parts.len() - 1 {
278            if !remaining.ends_with(part) {
279                return false;
280            }
281            remaining = "";
282        } else {
283            match remaining.find(part) {
284                Some(pos) => remaining = &remaining[pos + part.len()..],
285                None => return false,
286            }
287        }
288    }
289    true
290}
291
/// Returns true when `name` (compared ASCII-case-insensitively) is a header
/// that typically carries credentials or session state.
fn is_sensitive_http_header(name: &str) -> bool {
    // All entries are lowercase ASCII, so a case-insensitive comparison is
    // equivalent to lowercasing `name` first, without allocating.
    const SENSITIVE_HEADERS: [&str; 9] = [
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
        "x-api-key",
        "api-key",
        "x-auth-token",
        "x-csrf-token",
        "x-xsrf-token",
    ];
    SENSITIVE_HEADERS
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
306
/// Returns true when a query-parameter name (compared ASCII-case-insensitively)
/// looks like it carries a credential.
///
/// Matches a fixed set of well-known names plus anything ending in `_token`
/// or `_secret`.
fn is_sensitive_url_param(name: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    let key = lowered.as_str();
    matches!(
        key,
        "api_key"
            | "apikey"
            | "access_token"
            | "refresh_token"
            | "id_token"
            | "client_secret"
            | "password"
            | "secret"
            | "token"
    ) || key.ends_with("_token")
        || key.ends_with("_secret")
}
321
322pub(super) fn redact_mock_call_url(url: &str, redact: bool) -> String {
323    if !redact {
324        return url.to_string();
325    }
326    let Ok(mut parsed) = url::Url::parse(url) else {
327        return url.to_string();
328    };
329    let mut redacted_any = false;
330    let pairs: Vec<(String, String)> = parsed
331        .query_pairs()
332        .map(|(key, value)| {
333            let value = if is_sensitive_url_param(&key) {
334                redacted_any = true;
335                "[redacted]".to_string()
336            } else {
337                value.into_owned()
338            };
339            (key.into_owned(), value)
340        })
341        .collect();
342    if !redacted_any {
343        return url.to_string();
344    }
345    parsed.set_query(None);
346    {
347        let mut query = parsed.query_pairs_mut();
348        for (key, value) in pairs {
349            query.append_pair(&key, &value);
350        }
351    }
352    parsed.to_string()
353}
354
355pub(super) fn mock_call_headers_value(
356    headers: &BTreeMap<String, VmValue>,
357    redact_headers: bool,
358) -> BTreeMap<String, VmValue> {
359    headers
360        .iter()
361        .map(|(key, value)| {
362            let value = if redact_headers && is_sensitive_http_header(key) {
363                VmValue::String(Rc::from("[redacted]"))
364            } else {
365                value.clone()
366            };
367            (key.clone(), value)
368        })
369        .collect()
370}