Skip to main content

harn_vm/
http.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3use std::rc::Rc;
4
5use crate::value::{VmError, VmValue};
6use crate::vm::Vm;
7
8// =============================================================================
9// Mock HTTP framework (thread-local, mirrors the mock LLM pattern)
10// =============================================================================
11
12struct HttpMock {
13    method: String,
14    url_pattern: String,
15    status: i64,
16    body: String,
17    headers: BTreeMap<String, VmValue>,
18}
19
/// A request that was answered by a registered mock, recorded so scripts can
/// inspect it afterwards through `http_mock_calls()`.
///
/// Derives `Debug`/`PartialEq`/`Eq` (in addition to `Clone`) so recorded
/// calls can be printed and compared directly in Rust-side tests.
#[derive(Clone, Debug, PartialEq, Eq)]
struct HttpMockCall {
    /// HTTP method exactly as passed to `vm_execute_http_request`.
    method: String,
    /// Full request URL.
    url: String,
    /// Display string of the request's `body` option, or `None` when the
    /// request carried no `body` option.
    body: Option<String>,
}
26
27thread_local! {
28    static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
29    static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
30}
31
32/// Reset thread-local HTTP mock state. Call between test runs.
33pub fn reset_http_state() {
34    HTTP_MOCKS.with(|m| m.borrow_mut().clear());
35    HTTP_MOCK_CALLS.with(|c| c.borrow_mut().clear());
36}
37
/// Check if a URL matches a mock pattern (exact or glob with `*`).
///
/// Without any `*`, the pattern must equal the URL exactly. Otherwise each
/// `*` stands for any (possibly empty) run of characters:
/// - the text before the first `*` must be a prefix of the URL;
/// - the text after the last `*` must be a suffix of what remains;
/// - the pieces in between must appear in order, without overlapping.
///
/// A lone `"*"` matches every URL.
fn url_matches(pattern: &str, url: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let pieces: Vec<&str> = pattern.split('*').collect();
    match pieces.as_slice() {
        [head, middle @ .., tail] => {
            // Strip the literal prefix, then consume each middle piece at its
            // earliest occurrence; leftmost matching leaves the most room for
            // the pieces that follow.
            let after_middle = url.strip_prefix(*head).and_then(|after_head| {
                middle.iter().try_fold(after_head, |rest, piece| {
                    rest.find(*piece).map(|at| &rest[at + piece.len()..])
                })
            });
            // An empty tail (pattern ending in `*`) is trivially a suffix.
            matches!(after_middle, Some(rest) if rest.ends_with(*tail))
        }
        // A single piece means the pattern has no `*`: exact match only.
        _ => pattern == url,
    }
}
75
76/// Register HTTP builtins on a VM.
77pub fn register_http_builtins(vm: &mut Vm) {
78    vm.register_async_builtin("http_get", |args| async move {
79        let url = args.first().map(|a| a.display()).unwrap_or_default();
80        if url.is_empty() {
81            return Err(VmError::Thrown(VmValue::String(Rc::from(
82                "http_get: URL is required",
83            ))));
84        }
85        let options = match args.get(1) {
86            Some(VmValue::Dict(d)) => (**d).clone(),
87            _ => BTreeMap::new(),
88        };
89        vm_execute_http_request("GET", &url, &options).await
90    });
91
92    vm.register_async_builtin("http_post", |args| async move {
93        let url = args.first().map(|a| a.display()).unwrap_or_default();
94        if url.is_empty() {
95            return Err(VmError::Thrown(VmValue::String(Rc::from(
96                "http_post: URL is required",
97            ))));
98        }
99        let body = args.get(1).map(|a| a.display()).unwrap_or_default();
100        let mut options = match args.get(2) {
101            Some(VmValue::Dict(d)) => (**d).clone(),
102            _ => BTreeMap::new(),
103        };
104        options.insert("body".to_string(), VmValue::String(Rc::from(body)));
105        vm_execute_http_request("POST", &url, &options).await
106    });
107
108    vm.register_async_builtin("http_put", |args| async move {
109        let url = args.first().map(|a| a.display()).unwrap_or_default();
110        if url.is_empty() {
111            return Err(VmError::Thrown(VmValue::String(Rc::from(
112                "http_put: URL is required",
113            ))));
114        }
115        let body = args.get(1).map(|a| a.display()).unwrap_or_default();
116        let mut options = match args.get(2) {
117            Some(VmValue::Dict(d)) => (**d).clone(),
118            _ => BTreeMap::new(),
119        };
120        options.insert("body".to_string(), VmValue::String(Rc::from(body)));
121        vm_execute_http_request("PUT", &url, &options).await
122    });
123
124    vm.register_async_builtin("http_patch", |args| async move {
125        let url = args.first().map(|a| a.display()).unwrap_or_default();
126        if url.is_empty() {
127            return Err(VmError::Thrown(VmValue::String(Rc::from(
128                "http_patch: URL is required",
129            ))));
130        }
131        let body = args.get(1).map(|a| a.display()).unwrap_or_default();
132        let mut options = match args.get(2) {
133            Some(VmValue::Dict(d)) => (**d).clone(),
134            _ => BTreeMap::new(),
135        };
136        options.insert("body".to_string(), VmValue::String(Rc::from(body)));
137        vm_execute_http_request("PATCH", &url, &options).await
138    });
139
140    vm.register_async_builtin("http_delete", |args| async move {
141        let url = args.first().map(|a| a.display()).unwrap_or_default();
142        if url.is_empty() {
143            return Err(VmError::Thrown(VmValue::String(Rc::from(
144                "http_delete: URL is required",
145            ))));
146        }
147        let options = match args.get(1) {
148            Some(VmValue::Dict(d)) => (**d).clone(),
149            _ => BTreeMap::new(),
150        };
151        vm_execute_http_request("DELETE", &url, &options).await
152    });
153
154    // --- Mock HTTP builtins ---
155
156    // http_mock(method, url_pattern, response) -> nil
157    vm.register_builtin("http_mock", |args, _out| {
158        let method = args.first().map(|a| a.display()).unwrap_or_default();
159        let url_pattern = args.get(1).map(|a| a.display()).unwrap_or_default();
160        let response = args
161            .get(2)
162            .and_then(|a| a.as_dict())
163            .cloned()
164            .unwrap_or_default();
165
166        let status = response
167            .get("status")
168            .and_then(|v| v.as_int())
169            .unwrap_or(200);
170        let body = response
171            .get("body")
172            .map(|v| v.display())
173            .unwrap_or_default();
174        let headers = response
175            .get("headers")
176            .and_then(|v| v.as_dict())
177            .cloned()
178            .unwrap_or_default();
179
180        HTTP_MOCKS.with(|mocks| {
181            mocks.borrow_mut().push(HttpMock {
182                method,
183                url_pattern,
184                status,
185                body,
186                headers,
187            });
188        });
189        Ok(VmValue::Nil)
190    });
191
192    // http_mock_clear() -> nil
193    vm.register_builtin("http_mock_clear", |_args, _out| {
194        HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
195        HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
196        Ok(VmValue::Nil)
197    });
198
199    // http_mock_calls() -> list of {method, url, body}
200    vm.register_builtin("http_mock_calls", |_args, _out| {
201        let calls = HTTP_MOCK_CALLS.with(|calls| calls.borrow().clone());
202        let result: Vec<VmValue> = calls
203            .iter()
204            .map(|c| {
205                let mut dict = BTreeMap::new();
206                dict.insert(
207                    "method".to_string(),
208                    VmValue::String(Rc::from(c.method.as_str())),
209                );
210                dict.insert("url".to_string(), VmValue::String(Rc::from(c.url.as_str())));
211                dict.insert(
212                    "body".to_string(),
213                    match &c.body {
214                        Some(b) => VmValue::String(Rc::from(b.as_str())),
215                        None => VmValue::Nil,
216                    },
217                );
218                VmValue::Dict(Rc::new(dict))
219            })
220            .collect();
221        Ok(VmValue::List(Rc::new(result)))
222    });
223
224    vm.register_async_builtin("http_request", |args| async move {
225        let method = args
226            .first()
227            .map(|a| a.display())
228            .unwrap_or_default()
229            .to_uppercase();
230        if method.is_empty() {
231            return Err(VmError::Thrown(VmValue::String(Rc::from(
232                "http_request: method is required",
233            ))));
234        }
235        let url = args.get(1).map(|a| a.display()).unwrap_or_default();
236        if url.is_empty() {
237            return Err(VmError::Thrown(VmValue::String(Rc::from(
238                "http_request: URL is required",
239            ))));
240        }
241        let options = match args.get(2) {
242            Some(VmValue::Dict(d)) => (**d).clone(),
243            _ => BTreeMap::new(),
244        };
245        vm_execute_http_request(&method, &url, &options).await
246    });
247}
248
249// =============================================================================
250// HTTP request helpers
251// =============================================================================
252
253fn vm_get_int_option(options: &BTreeMap<String, VmValue>, key: &str, default: i64) -> i64 {
254    options.get(key).and_then(|v| v.as_int()).unwrap_or(default)
255}
256
257fn vm_get_bool_option(options: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
258    match options.get(key) {
259        Some(VmValue::Bool(b)) => *b,
260        _ => default,
261    }
262}
263
264async fn vm_execute_http_request(
265    method: &str,
266    url: &str,
267    options: &BTreeMap<String, VmValue>,
268) -> Result<VmValue, VmError> {
269    use std::time::Duration;
270
271    // Check mock responses first
272    let mock_match = HTTP_MOCKS.with(|mocks| {
273        let mocks = mocks.borrow();
274        for mock in mocks.iter() {
275            if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
276                && url_matches(&mock.url_pattern, url)
277            {
278                return Some((mock.status, mock.body.clone(), mock.headers.clone()));
279            }
280        }
281        None
282    });
283
284    if let Some((status, body, headers)) = mock_match {
285        // Record the call
286        let body_str = options.get("body").map(|v| v.display());
287        HTTP_MOCK_CALLS.with(|calls| {
288            calls.borrow_mut().push(HttpMockCall {
289                method: method.to_string(),
290                url: url.to_string(),
291                body: body_str,
292            });
293        });
294        // Return mock response
295        let mut result = BTreeMap::new();
296        result.insert("status".to_string(), VmValue::Int(status));
297        result.insert("headers".to_string(), VmValue::Dict(Rc::new(headers)));
298        result.insert("body".to_string(), VmValue::String(Rc::from(body)));
299        result.insert(
300            "ok".to_string(),
301            VmValue::Bool((200..300).contains(&(status as u16))),
302        );
303        return Ok(VmValue::Dict(Rc::new(result)));
304    }
305
306    if !url.starts_with("http://") && !url.starts_with("https://") {
307        return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
308            "http: URL must start with http:// or https://, got '{url}'"
309        )))));
310    }
311
312    let timeout_ms = vm_get_int_option(options, "timeout", 30_000).max(0) as u64;
313    let retries = vm_get_int_option(options, "retries", 0).max(0) as u32;
314    let backoff_ms = vm_get_int_option(options, "backoff", 1000).max(0) as u64;
315    let follow_redirects = vm_get_bool_option(options, "follow_redirects", true);
316    let max_redirects = vm_get_int_option(options, "max_redirects", 10).max(0) as usize;
317
318    let redirect_policy = if follow_redirects {
319        reqwest::redirect::Policy::limited(max_redirects)
320    } else {
321        reqwest::redirect::Policy::none()
322    };
323
324    let client = reqwest::Client::builder()
325        .timeout(Duration::from_millis(timeout_ms))
326        .redirect(redirect_policy)
327        .build()
328        .map_err(|e| {
329            VmError::Thrown(VmValue::String(Rc::from(format!(
330                "http: failed to build client: {e}"
331            ))))
332        })?;
333
334    let req_method = method.parse::<reqwest::Method>().map_err(|e| {
335        VmError::Thrown(VmValue::String(Rc::from(format!(
336            "http: invalid method '{method}': {e}"
337        ))))
338    })?;
339
340    let mut header_map = reqwest::header::HeaderMap::new();
341
342    // Apply auth
343    if let Some(auth_val) = options.get("auth") {
344        match auth_val {
345            VmValue::String(s) => {
346                let hv = reqwest::header::HeaderValue::from_str(s).map_err(|e| {
347                    VmError::Thrown(VmValue::String(Rc::from(format!(
348                        "http: invalid auth header value: {e}"
349                    ))))
350                })?;
351                header_map.insert(reqwest::header::AUTHORIZATION, hv);
352            }
353            VmValue::Dict(d) => {
354                if let Some(bearer) = d.get("bearer") {
355                    let token = bearer.display();
356                    let hv = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
357                        .map_err(|e| {
358                            VmError::Thrown(VmValue::String(Rc::from(format!(
359                                "http: invalid bearer token: {e}"
360                            ))))
361                        })?;
362                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
363                } else if let Some(VmValue::Dict(basic)) = d.get("basic") {
364                    let user = basic.get("user").map(|v| v.display()).unwrap_or_default();
365                    let password = basic
366                        .get("password")
367                        .map(|v| v.display())
368                        .unwrap_or_default();
369                    use base64::Engine;
370                    let encoded = base64::engine::general_purpose::STANDARD
371                        .encode(format!("{user}:{password}"));
372                    let hv = reqwest::header::HeaderValue::from_str(&format!("Basic {encoded}"))
373                        .map_err(|e| {
374                            VmError::Thrown(VmValue::String(Rc::from(format!(
375                                "http: invalid basic auth: {e}"
376                            ))))
377                        })?;
378                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
379                }
380            }
381            _ => {}
382        }
383    }
384
385    // Apply explicit headers
386    if let Some(VmValue::Dict(hdrs)) = options.get("headers") {
387        for (k, v) in hdrs.iter() {
388            let name = reqwest::header::HeaderName::from_bytes(k.as_bytes()).map_err(|e| {
389                VmError::Thrown(VmValue::String(Rc::from(format!(
390                    "http: invalid header name '{k}': {e}"
391                ))))
392            })?;
393            let val = reqwest::header::HeaderValue::from_str(&v.display()).map_err(|e| {
394                VmError::Thrown(VmValue::String(Rc::from(format!(
395                    "http: invalid header value for '{k}': {e}"
396                ))))
397            })?;
398            header_map.insert(name, val);
399        }
400    }
401
402    let body_str = options.get("body").map(|v| v.display());
403
404    let mut last_err: Option<VmError> = None;
405    let total_attempts = 1 + retries;
406
407    for attempt in 0..total_attempts {
408        if attempt > 0 {
409            use rand::Rng;
410            let base_delay = backoff_ms.saturating_mul(1u64 << (attempt - 1).min(30));
411            let jitter: f64 = rand::thread_rng().gen_range(0.75..=1.25);
412            let delay_ms = ((base_delay as f64 * jitter) as u64).min(60_000);
413            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
414        }
415
416        let mut req = client.request(req_method.clone(), url);
417        req = req.headers(header_map.clone());
418        if let Some(ref b) = body_str {
419            req = req.body(b.clone());
420        }
421
422        match req.send().await {
423            Ok(response) => {
424                let status_code = response.status().as_u16();
425                let ok = (200..300).contains(&status_code);
426                let status = status_code as i64;
427
428                let mut resp_headers = BTreeMap::new();
429                for (name, value) in response.headers() {
430                    if let Ok(v) = value.to_str() {
431                        resp_headers
432                            .insert(name.as_str().to_string(), VmValue::String(Rc::from(v)));
433                    }
434                }
435
436                let body_text = response.text().await.map_err(|e| {
437                    VmError::Thrown(VmValue::String(Rc::from(format!(
438                        "http: failed to read response body: {e}"
439                    ))))
440                })?;
441
442                if status >= 500 && attempt + 1 < total_attempts {
443                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
444                        "http: server error {status}"
445                    )))));
446                    continue;
447                }
448
449                let mut result = BTreeMap::new();
450                result.insert("status".to_string(), VmValue::Int(status));
451                result.insert("headers".to_string(), VmValue::Dict(Rc::new(resp_headers)));
452                result.insert("body".to_string(), VmValue::String(Rc::from(body_text)));
453                result.insert("ok".to_string(), VmValue::Bool(ok));
454                return Ok(VmValue::Dict(Rc::new(result)));
455            }
456            Err(e) => {
457                let retryable = e.is_timeout() || e.is_connect();
458                if retryable && attempt + 1 < total_attempts {
459                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
460                        "http: request failed: {e}"
461                    )))));
462                    continue;
463                }
464                return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
465                    "http: request failed: {e}"
466                )))));
467            }
468        }
469    }
470
471    Err(last_err
472        .unwrap_or_else(|| VmError::Thrown(VmValue::String(Rc::from("http: request failed")))))
473}