Skip to main content

harn_vm/
http.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3use std::rc::Rc;
4
5use crate::value::{VmError, VmValue};
6use crate::vm::Vm;
7
8// =============================================================================
9// Mock HTTP framework (thread-local, mirrors the mock LLM pattern)
10// =============================================================================
11
12struct HttpMock {
13    method: String,
14    url_pattern: String,
15    status: i64,
16    body: String,
17    headers: BTreeMap<String, VmValue>,
18}
19
/// A request that was answered by a registered mock.
///
/// These records are returned by the `http_mock_calls` builtin.
#[derive(Clone)]
struct HttpMockCall {
    /// Method string handed to the request executor.
    method: String,
    /// URL that was requested.
    url: String,
    /// Displayed `body` option, or `None` when the request had no body option.
    body: Option<String>,
}
26
27thread_local! {
28    static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
29    static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
30}
31
32/// Reset thread-local HTTP mock state. Call between test runs.
33pub fn reset_http_state() {
34    HTTP_MOCKS.with(|m| m.borrow_mut().clear());
35    HTTP_MOCK_CALLS.with(|c| c.borrow_mut().clear());
36}
37
/// Returns `true` when `url` satisfies the mock `pattern`.
///
/// Matching rules:
/// - A pattern with no `*` must equal the URL exactly.
/// - Each `*` stands for any (possibly empty) run of characters.
/// - The text before the first `*` must be a prefix of the URL.
/// - Every text between two `*`s must appear, in order, in what remains.
/// - The text after the last `*` must be a suffix of what remains.
fn url_matches(pattern: &str, url: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == url;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last_index = segments.len() - 1;
    let mut rest = url;

    for (index, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            // Leading/trailing/doubled `*` impose no constraint of their own.
            continue;
        }
        if index == 0 {
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if index == last_index {
            // Final literal: nothing may follow it in the URL.
            return rest.ends_with(segment);
        } else {
            // Leftmost occurrence leaves the most room for later segments.
            match rest.find(segment) {
                Some(at) => rest = &rest[at + segment.len()..],
                None => return false,
            }
        }
    }
    true
}
75
76/// Build a standard HTTP response dict with status, headers, body, and ok fields.
77fn build_http_response(status: i64, headers: BTreeMap<String, VmValue>, body: String) -> VmValue {
78    let mut result = BTreeMap::new();
79    result.insert("status".to_string(), VmValue::Int(status));
80    result.insert("headers".to_string(), VmValue::Dict(Rc::new(headers)));
81    result.insert("body".to_string(), VmValue::String(Rc::from(body)));
82    result.insert(
83        "ok".to_string(),
84        VmValue::Bool((200..300).contains(&(status as u16))),
85    );
86    VmValue::Dict(Rc::new(result))
87}
88
89/// Extract URL, validate it, and pull an options dict from `args`.
90/// For methods with a body (POST/PUT/PATCH), the body is at index 1 and
91/// options at index 2; for methods without (GET/DELETE), options are at index 1.
92async fn http_verb_handler(
93    method: &str,
94    has_body: bool,
95    args: Vec<VmValue>,
96) -> Result<VmValue, VmError> {
97    let url = args.first().map(|a| a.display()).unwrap_or_default();
98    if url.is_empty() {
99        return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
100            "http_{}: URL is required",
101            method.to_ascii_lowercase()
102        )))));
103    }
104    let mut options = if has_body {
105        match args.get(2) {
106            Some(VmValue::Dict(d)) => (**d).clone(),
107            _ => BTreeMap::new(),
108        }
109    } else {
110        match args.get(1) {
111            Some(VmValue::Dict(d)) => (**d).clone(),
112            _ => BTreeMap::new(),
113        }
114    };
115    if has_body {
116        let body = args.get(1).map(|a| a.display()).unwrap_or_default();
117        options.insert("body".to_string(), VmValue::String(Rc::from(body)));
118    }
119    vm_execute_http_request(method, &url, &options).await
120}
121
122/// Register HTTP builtins on a VM.
123pub fn register_http_builtins(vm: &mut Vm) {
124    vm.register_async_builtin("http_get", |args| async move {
125        http_verb_handler("GET", false, args).await
126    });
127    vm.register_async_builtin("http_post", |args| async move {
128        http_verb_handler("POST", true, args).await
129    });
130    vm.register_async_builtin("http_put", |args| async move {
131        http_verb_handler("PUT", true, args).await
132    });
133    vm.register_async_builtin("http_patch", |args| async move {
134        http_verb_handler("PATCH", true, args).await
135    });
136    vm.register_async_builtin("http_delete", |args| async move {
137        http_verb_handler("DELETE", false, args).await
138    });
139
140    // --- Mock HTTP builtins ---
141
142    // http_mock(method, url_pattern, response) -> nil
143    vm.register_builtin("http_mock", |args, _out| {
144        let method = args.first().map(|a| a.display()).unwrap_or_default();
145        let url_pattern = args.get(1).map(|a| a.display()).unwrap_or_default();
146        let response = args
147            .get(2)
148            .and_then(|a| a.as_dict())
149            .cloned()
150            .unwrap_or_default();
151
152        let status = response
153            .get("status")
154            .and_then(|v| v.as_int())
155            .unwrap_or(200);
156        let body = response
157            .get("body")
158            .map(|v| v.display())
159            .unwrap_or_default();
160        let headers = response
161            .get("headers")
162            .and_then(|v| v.as_dict())
163            .cloned()
164            .unwrap_or_default();
165
166        HTTP_MOCKS.with(|mocks| {
167            mocks.borrow_mut().push(HttpMock {
168                method,
169                url_pattern,
170                status,
171                body,
172                headers,
173            });
174        });
175        Ok(VmValue::Nil)
176    });
177
178    // http_mock_clear() -> nil
179    vm.register_builtin("http_mock_clear", |_args, _out| {
180        HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
181        HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
182        Ok(VmValue::Nil)
183    });
184
185    // http_mock_calls() -> list of {method, url, body}
186    vm.register_builtin("http_mock_calls", |_args, _out| {
187        let calls = HTTP_MOCK_CALLS.with(|calls| calls.borrow().clone());
188        let result: Vec<VmValue> = calls
189            .iter()
190            .map(|c| {
191                let mut dict = BTreeMap::new();
192                dict.insert(
193                    "method".to_string(),
194                    VmValue::String(Rc::from(c.method.as_str())),
195                );
196                dict.insert("url".to_string(), VmValue::String(Rc::from(c.url.as_str())));
197                dict.insert(
198                    "body".to_string(),
199                    match &c.body {
200                        Some(b) => VmValue::String(Rc::from(b.as_str())),
201                        None => VmValue::Nil,
202                    },
203                );
204                VmValue::Dict(Rc::new(dict))
205            })
206            .collect();
207        Ok(VmValue::List(Rc::new(result)))
208    });
209
210    vm.register_async_builtin("http_request", |args| async move {
211        let method = args
212            .first()
213            .map(|a| a.display())
214            .unwrap_or_default()
215            .to_uppercase();
216        if method.is_empty() {
217            return Err(VmError::Thrown(VmValue::String(Rc::from(
218                "http_request: method is required",
219            ))));
220        }
221        let url = args.get(1).map(|a| a.display()).unwrap_or_default();
222        if url.is_empty() {
223            return Err(VmError::Thrown(VmValue::String(Rc::from(
224                "http_request: URL is required",
225            ))));
226        }
227        let options = match args.get(2) {
228            Some(VmValue::Dict(d)) => (**d).clone(),
229            _ => BTreeMap::new(),
230        };
231        vm_execute_http_request(&method, &url, &options).await
232    });
233}
234
235// =============================================================================
236// HTTP request helpers
237// =============================================================================
238
239fn vm_get_int_option(options: &BTreeMap<String, VmValue>, key: &str, default: i64) -> i64 {
240    options.get(key).and_then(|v| v.as_int()).unwrap_or(default)
241}
242
243fn vm_get_bool_option(options: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
244    match options.get(key) {
245        Some(VmValue::Bool(b)) => *b,
246        _ => default,
247    }
248}
249
250async fn vm_execute_http_request(
251    method: &str,
252    url: &str,
253    options: &BTreeMap<String, VmValue>,
254) -> Result<VmValue, VmError> {
255    use std::time::Duration;
256
257    // Check mock responses first
258    let mock_match = HTTP_MOCKS.with(|mocks| {
259        let mocks = mocks.borrow();
260        for mock in mocks.iter() {
261            if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
262                && url_matches(&mock.url_pattern, url)
263            {
264                return Some((mock.status, mock.body.clone(), mock.headers.clone()));
265            }
266        }
267        None
268    });
269
270    if let Some((status, body, headers)) = mock_match {
271        // Record the call
272        let body_str = options.get("body").map(|v| v.display());
273        HTTP_MOCK_CALLS.with(|calls| {
274            calls.borrow_mut().push(HttpMockCall {
275                method: method.to_string(),
276                url: url.to_string(),
277                body: body_str,
278            });
279        });
280        return Ok(build_http_response(status, headers, body));
281    }
282
283    if !url.starts_with("http://") && !url.starts_with("https://") {
284        return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
285            "http: URL must start with http:// or https://, got '{url}'"
286        )))));
287    }
288
289    let timeout_ms = vm_get_int_option(options, "timeout", 30_000).max(0) as u64;
290    let retries = vm_get_int_option(options, "retries", 0).max(0) as u32;
291    let backoff_ms = vm_get_int_option(options, "backoff", 1000).max(0) as u64;
292    let follow_redirects = vm_get_bool_option(options, "follow_redirects", true);
293    let max_redirects = vm_get_int_option(options, "max_redirects", 10).max(0) as usize;
294
295    let redirect_policy = if follow_redirects {
296        reqwest::redirect::Policy::limited(max_redirects)
297    } else {
298        reqwest::redirect::Policy::none()
299    };
300
301    let client = reqwest::Client::builder()
302        .timeout(Duration::from_millis(timeout_ms))
303        .redirect(redirect_policy)
304        .build()
305        .map_err(|e| {
306            VmError::Thrown(VmValue::String(Rc::from(format!(
307                "http: failed to build client: {e}"
308            ))))
309        })?;
310
311    let req_method = method.parse::<reqwest::Method>().map_err(|e| {
312        VmError::Thrown(VmValue::String(Rc::from(format!(
313            "http: invalid method '{method}': {e}"
314        ))))
315    })?;
316
317    let mut header_map = reqwest::header::HeaderMap::new();
318
319    // Apply auth
320    if let Some(auth_val) = options.get("auth") {
321        match auth_val {
322            VmValue::String(s) => {
323                let hv = reqwest::header::HeaderValue::from_str(s).map_err(|e| {
324                    VmError::Thrown(VmValue::String(Rc::from(format!(
325                        "http: invalid auth header value: {e}"
326                    ))))
327                })?;
328                header_map.insert(reqwest::header::AUTHORIZATION, hv);
329            }
330            VmValue::Dict(d) => {
331                if let Some(bearer) = d.get("bearer") {
332                    let token = bearer.display();
333                    let hv = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
334                        .map_err(|e| {
335                            VmError::Thrown(VmValue::String(Rc::from(format!(
336                                "http: invalid bearer token: {e}"
337                            ))))
338                        })?;
339                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
340                } else if let Some(VmValue::Dict(basic)) = d.get("basic") {
341                    let user = basic.get("user").map(|v| v.display()).unwrap_or_default();
342                    let password = basic
343                        .get("password")
344                        .map(|v| v.display())
345                        .unwrap_or_default();
346                    use base64::Engine;
347                    let encoded = base64::engine::general_purpose::STANDARD
348                        .encode(format!("{user}:{password}"));
349                    let hv = reqwest::header::HeaderValue::from_str(&format!("Basic {encoded}"))
350                        .map_err(|e| {
351                            VmError::Thrown(VmValue::String(Rc::from(format!(
352                                "http: invalid basic auth: {e}"
353                            ))))
354                        })?;
355                    header_map.insert(reqwest::header::AUTHORIZATION, hv);
356                }
357            }
358            _ => {}
359        }
360    }
361
362    // Apply explicit headers
363    if let Some(VmValue::Dict(hdrs)) = options.get("headers") {
364        for (k, v) in hdrs.iter() {
365            let name = reqwest::header::HeaderName::from_bytes(k.as_bytes()).map_err(|e| {
366                VmError::Thrown(VmValue::String(Rc::from(format!(
367                    "http: invalid header name '{k}': {e}"
368                ))))
369            })?;
370            let val = reqwest::header::HeaderValue::from_str(&v.display()).map_err(|e| {
371                VmError::Thrown(VmValue::String(Rc::from(format!(
372                    "http: invalid header value for '{k}': {e}"
373                ))))
374            })?;
375            header_map.insert(name, val);
376        }
377    }
378
379    let body_str = options.get("body").map(|v| v.display());
380
381    let mut last_err: Option<VmError> = None;
382    let total_attempts = 1 + retries;
383
384    for attempt in 0..total_attempts {
385        if attempt > 0 {
386            use rand::Rng;
387            let base_delay = backoff_ms.saturating_mul(1u64 << (attempt - 1).min(30));
388            let jitter: f64 = rand::thread_rng().gen_range(0.75..=1.25);
389            let delay_ms = ((base_delay as f64 * jitter) as u64).min(60_000);
390            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
391        }
392
393        let mut req = client.request(req_method.clone(), url);
394        req = req.headers(header_map.clone());
395        if let Some(ref b) = body_str {
396            req = req.body(b.clone());
397        }
398
399        match req.send().await {
400            Ok(response) => {
401                let status = response.status().as_u16() as i64;
402
403                let mut resp_headers = BTreeMap::new();
404                for (name, value) in response.headers() {
405                    if let Ok(v) = value.to_str() {
406                        resp_headers
407                            .insert(name.as_str().to_string(), VmValue::String(Rc::from(v)));
408                    }
409                }
410
411                let body_text = response.text().await.map_err(|e| {
412                    VmError::Thrown(VmValue::String(Rc::from(format!(
413                        "http: failed to read response body: {e}"
414                    ))))
415                })?;
416
417                if status >= 500 && attempt + 1 < total_attempts {
418                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
419                        "http: server error {status}"
420                    )))));
421                    continue;
422                }
423
424                return Ok(build_http_response(status, resp_headers, body_text));
425            }
426            Err(e) => {
427                let retryable = e.is_timeout() || e.is_connect();
428                if retryable && attempt + 1 < total_attempts {
429                    last_err = Some(VmError::Thrown(VmValue::String(Rc::from(format!(
430                        "http: request failed: {e}"
431                    )))));
432                    continue;
433                }
434                return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
435                    "http: request failed: {e}"
436                )))));
437            }
438        }
439    }
440
441    Err(last_err
442        .unwrap_or_else(|| VmError::Thrown(VmValue::String(Rc::from("http: request failed")))))
443}