Skip to main content

harn_vm/
http.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3use std::rc::Rc;
4
5use crate::value::{VmError, VmValue};
6use crate::vm::Vm;
7
// Mock HTTP framework (thread-local, mirrors the mock LLM pattern).
9
10struct HttpMock {
11    method: String,
12    url_pattern: String,
13    status: i64,
14    body: String,
15    headers: BTreeMap<String, VmValue>,
16}
17
/// Record of one request that was answered by a mock instead of the network.
///
/// Entries are appended when a mock matches and are exposed to scripts by the
/// `http_mock_calls` builtin.
#[derive(Debug, Clone)]
struct HttpMockCall {
    /// Method of the intercepted request.
    method: String,
    /// URL of the intercepted request.
    url: String,
    /// Request body taken from the call's options, or `None` when no body was set.
    body: Option<String>,
}
24
25thread_local! {
26    static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
27    static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
28}
29
30/// Reset thread-local HTTP mock state. Call between test runs.
31pub fn reset_http_state() {
32    HTTP_MOCKS.with(|m| m.borrow_mut().clear());
33    HTTP_MOCK_CALLS.with(|c| c.borrow_mut().clear());
34}
35
/// Reports whether `url` satisfies `pattern`.
///
/// A pattern without `*` must equal the URL exactly. Otherwise each `*` stands
/// for any run of characters (including none): the text before the first `*`
/// must be a prefix of the URL, the text after the last `*` must be a suffix of
/// what is left, and every piece in between must appear in order.
fn url_matches(pattern: &str, url: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == url;
    }

    // Literal pieces between the wildcards. Empty pieces (leading, trailing or
    // doubled `*`) impose no constraint, so a lone "*" accepts everything.
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;

    let mut rest = url;
    for (idx, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            continue;
        }
        rest = if idx == 0 {
            // Anchored at the start of the URL.
            match rest.strip_prefix(segment) {
                Some(tail) => tail,
                None => return false,
            }
        } else if idx == last {
            // Anchored at the end; nothing remains to match afterwards.
            if !rest.ends_with(segment) {
                return false;
            }
            ""
        } else {
            // Floating piece: consume through its leftmost occurrence.
            match rest.split_once(segment) {
                Some((_, tail)) => tail,
                None => return false,
            }
        };
    }
    true
}
70
71/// Build a standard HTTP response dict with status, headers, body, and ok fields.
72fn build_http_response(status: i64, headers: BTreeMap<String, VmValue>, body: String) -> VmValue {
73    let mut result = BTreeMap::new();
74    result.insert("status".to_string(), VmValue::Int(status));
75    result.insert("headers".to_string(), VmValue::Dict(Rc::new(headers)));
76    result.insert("body".to_string(), VmValue::String(Rc::from(body)));
77    result.insert(
78        "ok".to_string(),
79        VmValue::Bool((200..300).contains(&(status as u16))),
80    );
81    VmValue::Dict(Rc::new(result))
82}
83
84/// Extract URL, validate it, and pull an options dict from `args`.
85/// For methods with a body (POST/PUT/PATCH), the body is at index 1 and
86/// options at index 2; for methods without (GET/DELETE), options are at index 1.
87async fn http_verb_handler(
88    method: &str,
89    has_body: bool,
90    args: Vec<VmValue>,
91) -> Result<VmValue, VmError> {
92    let url = args.first().map(|a| a.display()).unwrap_or_default();
93    if url.is_empty() {
94        return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
95            "http_{}: URL is required",
96            method.to_ascii_lowercase()
97        )))));
98    }
99    let mut options = if has_body {
100        match args.get(2) {
101            Some(VmValue::Dict(d)) => (**d).clone(),
102            _ => BTreeMap::new(),
103        }
104    } else {
105        match args.get(1) {
106            Some(VmValue::Dict(d)) => (**d).clone(),
107            _ => BTreeMap::new(),
108        }
109    };
110    if has_body {
111        let body = args.get(1).map(|a| a.display()).unwrap_or_default();
112        options.insert("body".to_string(), VmValue::String(Rc::from(body)));
113    }
114    vm_execute_http_request(method, &url, &options).await
115}
116
117/// Register HTTP builtins on a VM.
118pub fn register_http_builtins(vm: &mut Vm) {
119    vm.register_async_builtin("http_get", |args| async move {
120        http_verb_handler("GET", false, args).await
121    });
122    vm.register_async_builtin("http_post", |args| async move {
123        http_verb_handler("POST", true, args).await
124    });
125    vm.register_async_builtin("http_put", |args| async move {
126        http_verb_handler("PUT", true, args).await
127    });
128    vm.register_async_builtin("http_patch", |args| async move {
129        http_verb_handler("PATCH", true, args).await
130    });
131    vm.register_async_builtin("http_delete", |args| async move {
132        http_verb_handler("DELETE", false, args).await
133    });
134
135    // --- Mock HTTP builtins ---
136
137    // http_mock(method, url_pattern, response) -> nil
138    vm.register_builtin("http_mock", |args, _out| {
139        let method = args.first().map(|a| a.display()).unwrap_or_default();
140        let url_pattern = args.get(1).map(|a| a.display()).unwrap_or_default();
141        let response = args
142            .get(2)
143            .and_then(|a| a.as_dict())
144            .cloned()
145            .unwrap_or_default();
146
147        let status = response
148            .get("status")
149            .and_then(|v| v.as_int())
150            .unwrap_or(200);
151        let body = response
152            .get("body")
153            .map(|v| v.display())
154            .unwrap_or_default();
155        let headers = response
156            .get("headers")
157            .and_then(|v| v.as_dict())
158            .cloned()
159            .unwrap_or_default();
160
161        HTTP_MOCKS.with(|mocks| {
162            mocks.borrow_mut().push(HttpMock {
163                method,
164                url_pattern,
165                status,
166                body,
167                headers,
168            });
169        });
170        Ok(VmValue::Nil)
171    });
172
173    // http_mock_clear() -> nil
174    vm.register_builtin("http_mock_clear", |_args, _out| {
175        HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
176        HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
177        Ok(VmValue::Nil)
178    });
179
180    // http_mock_calls() -> list of {method, url, body}
181    vm.register_builtin("http_mock_calls", |_args, _out| {
182        let calls = HTTP_MOCK_CALLS.with(|calls| calls.borrow().clone());
183        let result: Vec<VmValue> = calls
184            .iter()
185            .map(|c| {
186                let mut dict = BTreeMap::new();
187                dict.insert(
188                    "method".to_string(),
189                    VmValue::String(Rc::from(c.method.as_str())),
190                );
191                dict.insert("url".to_string(), VmValue::String(Rc::from(c.url.as_str())));
192                dict.insert(
193                    "body".to_string(),
194                    match &c.body {
195                        Some(b) => VmValue::String(Rc::from(b.as_str())),
196                        None => VmValue::Nil,
197                    },
198                );
199                VmValue::Dict(Rc::new(dict))
200            })
201            .collect();
202        Ok(VmValue::List(Rc::new(result)))
203    });
204
205    vm.register_async_builtin("http_request", |args| async move {
206        let method = args
207            .first()
208            .map(|a| a.display())
209            .unwrap_or_default()
210            .to_uppercase();
211        if method.is_empty() {
212            return Err(VmError::Thrown(VmValue::String(Rc::from(
213                "http_request: method is required",
214            ))));
215        }
216        let url = args.get(1).map(|a| a.display()).unwrap_or_default();
217        if url.is_empty() {
218            return Err(VmError::Thrown(VmValue::String(Rc::from(
219                "http_request: URL is required",
220            ))));
221        }
222        let options = match args.get(2) {
223            Some(VmValue::Dict(d)) => (**d).clone(),
224            _ => BTreeMap::new(),
225        };
226        vm_execute_http_request(&method, &url, &options).await
227    });
228}
229
230fn vm_get_int_option(options: &BTreeMap<String, VmValue>, key: &str, default: i64) -> i64 {
231    options.get(key).and_then(|v| v.as_int()).unwrap_or(default)
232}
233
234fn vm_get_bool_option(options: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
235    match options.get(key) {
236        Some(VmValue::Bool(b)) => *b,
237        _ => default,
238    }
239}
240
241async fn vm_execute_http_request(
242    method: &str,
243    url: &str,
244    options: &BTreeMap<String, VmValue>,
245) -> Result<VmValue, VmError> {
246    use std::time::Duration;
247
248    // Check mock responses first
249    let mock_match = HTTP_MOCKS.with(|mocks| {
250        let mocks = mocks.borrow();
251        for mock in mocks.iter() {
252            if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
253                && url_matches(&mock.url_pattern, url)
254            {
255                return Some((mock.status, mock.body.clone(), mock.headers.clone()));
256            }
257        }
258        None
259    });
260
261    if let Some((status, body, headers)) = mock_match {
262        let body_str = options.get("body").map(|v| v.display());
263        HTTP_MOCK_CALLS.with(|calls| {
264            calls.borrow_mut().push(HttpMockCall {
265                method: method.to_string(),
266                url: url.to_string(),
267                body: body_str,
268            });
269        });
270        return Ok(build_http_response(status, headers, body));
271    }
272
273    if !url.starts_with("http://") && !url.starts_with("https://") {
274        return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
275            "http: URL must start with http:// or https://, got '{url}'"
276        )))));
277    }
278
279    let timeout_ms = vm_get_int_option(options, "timeout", 30_000).max(0) as u64;
280    let retries = vm_get_int_option(options, "retries", 0).max(0) as u32;
281    let backoff_ms = vm_get_int_option(options, "backoff", 1000).max(0) as u64;
282    let follow_redirects = vm_get_bool_option(options, "follow_redirects", true);
283    let max_redirects = vm_get_int_option(options, "max_redirects", 10).max(0) as usize;
284
285    let redirect_policy = if follow_redirects {
286        reqwest::redirect::Policy::limited(max_redirects)
287    } else {
288        reqwest::redirect::Policy::none()
289    };
290
291    let client = reqwest::Client::builder()
292        .timeout(Duration::from_millis(timeout_ms))
293        .redirect(redirect_policy)
294        .build()
295        .map_err(|e| {
296            VmError::Thrown(VmValue::String(Rc::from(format!(
297                "http: failed to build client: {e}"
298            ))))
299        })?;
300
301    let req_method = method.parse::<reqwest::Method>().map_err(|e| {
302        VmError::Thrown(VmValue::String(Rc::from(format!(
303            "http: invalid method '{method}': {e}"
304        ))))
305    })?;
306
307    let mut header_map = reqwest::header::HeaderMap::new();
308
309    if let Some(auth_val) = options.get("auth") {
310        match auth_val {
311            VmValue::String(s) => {
312                let hv = reqwest::header::HeaderValue::from_str(s).map_err(|e| {
313                    VmError::Thrown(VmValue::String(Rc::from(format!(
314                        "http: invalid auth header value: {e}"
315                    ))))
316                })?;
317                header_map.insert(reqwest::header::AUTHORIZATION, hv);
318            }
319            VmValue::Dict(d) => {
320                if let Some(bearer) = d.get("bearer") {
321                    let token = bearer.display();
322                    let hv = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
323                        .map_err(|e| {
324                            VmError::Thrown(VmValue::String(Rc::from(format!(
325                                "http: invalid bearer token: {e}"
326                            ))))
327                        })?;
328                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
329                } else if let Some(VmValue::Dict(basic)) = d.get("basic") {
330                    let user = basic.get("user").map(|v| v.display()).unwrap_or_default();
331                    let password = basic
332                        .get("password")
333                        .map(|v| v.display())
334                        .unwrap_or_default();
335                    use base64::Engine;
336                    let encoded = base64::engine::general_purpose::STANDARD
337                        .encode(format!("{user}:{password}"));
338                    let hv = reqwest::header::HeaderValue::from_str(&format!("Basic {encoded}"))
339                        .map_err(|e| {
340                            VmError::Thrown(VmValue::String(Rc::from(format!(
341                                "http: invalid basic auth: {e}"
342                            ))))
343                        })?;
344                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
345                }
346            }
347            _ => {}
348        }
349    }
350
351    if let Some(VmValue::Dict(hdrs)) = options.get("headers") {
352        for (k, v) in hdrs.iter() {
353            let name = reqwest::header::HeaderName::from_bytes(k.as_bytes()).map_err(|e| {
354                VmError::Thrown(VmValue::String(Rc::from(format!(
355                    "http: invalid header name '{k}': {e}"
356                ))))
357            })?;
358            let val = reqwest::header::HeaderValue::from_str(&v.display()).map_err(|e| {
359                VmError::Thrown(VmValue::String(Rc::from(format!(
360                    "http: invalid header value for '{k}': {e}"
361                ))))
362            })?;
363            header_map.insert(name, val);
364        }
365    }
366
367    let body_str = options.get("body").map(|v| v.display());
368
369    let mut last_err: Option<VmError> = None;
370    let total_attempts = 1 + retries;
371
372    for attempt in 0..total_attempts {
373        if attempt > 0 {
374            use rand::RngExt;
375            let base_delay = backoff_ms.saturating_mul(1u64 << (attempt - 1).min(30));
376            let jitter: f64 = rand::rng().random_range(0.75..=1.25);
377            let delay_ms = ((base_delay as f64 * jitter) as u64).min(60_000);
378            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
379        }
380
381        let mut req = client.request(req_method.clone(), url);
382        req = req.headers(header_map.clone());
383        if let Some(ref b) = body_str {
384            req = req.body(b.clone());
385        }
386
387        match req.send().await {
388            Ok(response) => {
389                let status = response.status().as_u16() as i64;
390
391                let mut resp_headers = BTreeMap::new();
392                for (name, value) in response.headers() {
393                    if let Ok(v) = value.to_str() {
394                        resp_headers
395                            .insert(name.as_str().to_string(), VmValue::String(Rc::from(v)));
396                    }
397                }
398
399                let body_text = response.text().await.map_err(|e| {
400                    VmError::Thrown(VmValue::String(Rc::from(format!(
401                        "http: failed to read response body: {e}"
402                    ))))
403                })?;
404
405                if status >= 500 && attempt + 1 < total_attempts {
406                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
407                        "http: server error {status}"
408                    )))));
409                    continue;
410                }
411
412                return Ok(build_http_response(status, resp_headers, body_text));
413            }
414            Err(e) => {
415                let retryable = e.is_timeout() || e.is_connect();
416                if retryable && attempt + 1 < total_attempts {
417                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
418                        "http: request failed: {e}"
419                    )))));
420                    continue;
421                }
422                return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
423                    "http: request failed: {e}"
424                )))));
425            }
426        }
427    }
428
429    Err(last_err
430        .unwrap_or_else(|| VmError::Thrown(VmValue::String(Rc::from("http: request failed")))))
431}