Skip to main content

orcs_lua/
http_command.rs

1//! Lua-exposed HTTP client for `orcs.http()`.
2//!
3//! Provides a blocking HTTP client (via `ureq`) exposed to Lua as
4//! `orcs.http(method, url, opts)`. Gated by `Capability::HTTP`.
5//!
6//! # Design
7//!
8//! Rust owns the transport layer (TLS, timeout, error classification).
9//! Lua owns the application logic (request construction, response parsing).
10//!
11//! ```text
12//! Lua: orcs.http("POST", url, { headers={...}, body="...", timeout=30 })
13//!   → Capability::HTTP gate (ctx_fns / child)
14//!   → http_request_impl (Rust/ureq)
15//!   → { ok, status, headers, body, error, error_kind }
16//! ```
17
18use mlua::{Lua, Table};
19
/// Request timeout, in seconds, used when the Lua caller does not pass `opts.timeout`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Response-body size limit in bytes: 10 MiB (10 × 2^20).
const MAX_BODY_SIZE: u64 = 10 << 20;
25
26/// Registers `orcs.http` as a deny-by-default stub.
27///
28/// The real implementation is injected by `ctx_fns.rs` / `child.rs`
29/// when a `ChildContext` with `Capability::HTTP` is available.
30pub fn register_http_deny_stub(lua: &Lua, orcs_table: &Table) -> Result<(), mlua::Error> {
31    if orcs_table.get::<mlua::Function>("http").is_err() {
32        let http_fn = lua.create_function(|lua, _args: mlua::MultiValue| {
33            let result = lua.create_table()?;
34            result.set("ok", false)?;
35            result.set(
36                "error",
37                "http denied: no execution context (ChildContext with Capability::HTTP required)",
38            )?;
39            result.set("error_kind", "permission_denied")?;
40            Ok(result)
41        })?;
42        orcs_table.set("http", http_fn)?;
43    }
44    Ok(())
45}
46
47/// Executes an HTTP request using ureq. Called from capability-gated context.
48///
49/// # Arguments (from Lua)
50///
51/// * `method` - HTTP method: "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"
52/// * `url` - Request URL (must be http:// or https://)
53/// * `opts` - Optional table:
54///   - `headers` - Table of {name = value} pairs
55///   - `body` - Request body string
56///   - `timeout` - Timeout in seconds (default: 30)
57///
58/// # Returns (Lua table)
59///
60/// * `ok` - boolean, true if HTTP response received (even 4xx/5xx)
61/// * `status` - HTTP status code (number)
62/// * `headers` - Response headers as {name = value} table
63/// * `body` - Response body as string
64/// * `error` - Error message (only when ok=false)
65/// * `error_kind` - Error classification: "timeout", "dns", "connection_refused",
66///   "tls", "too_large", "invalid_url", "network", "unknown"
67pub fn http_request_impl(lua: &Lua, args: (String, String, Option<Table>)) -> mlua::Result<Table> {
68    let (method, url, opts) = args;
69
70    // Validate URL scheme
71    if !url.starts_with("http://") && !url.starts_with("https://") {
72        let result = lua.create_table()?;
73        result.set("ok", false)?;
74        result.set(
75            "error",
76            format!(
77                "invalid URL scheme: URL must start with http:// or https://, got: {}",
78                truncate_for_error(&url, 100)
79            ),
80        )?;
81        result.set("error_kind", "invalid_url")?;
82        return Ok(result);
83    }
84
85    // Parse options
86    let timeout_secs = opts
87        .as_ref()
88        .and_then(|o| o.get::<u64>("timeout").ok())
89        .unwrap_or(DEFAULT_TIMEOUT_SECS);
90
91    let body: Option<String> = opts.as_ref().and_then(|o| o.get::<String>("body").ok());
92
93    // Collect headers from opts
94    let mut extra_headers: Vec<(String, String)> = Vec::new();
95    if let Some(ref o) = opts {
96        if let Ok(headers) = o.get::<Table>("headers") {
97            for (name, value) in headers.pairs::<String, String>().flatten() {
98                extra_headers.push((name, value));
99            }
100        }
101    }
102
103    // Check if Content-Type is explicitly set
104    let has_content_type = extra_headers
105        .iter()
106        .any(|(k, _)| k.to_lowercase() == "content-type");
107
108    // Build ureq agent with timeout
109    let config = ureq::Agent::config_builder()
110        .timeout_global(Some(std::time::Duration::from_secs(timeout_secs)))
111        .build();
112    let agent = ureq::Agent::new_with_config(config);
113
114    // Execute request based on method.
115    // ureq v3 has separate types for WithBody (POST/PUT/PATCH) and WithoutBody (GET/DELETE/HEAD),
116    // so we handle them in separate branches.
117    let method_upper = method.to_uppercase();
118
119    let response: Result<ureq::http::Response<ureq::Body>, ureq::Error> =
120        match method_upper.as_str() {
121            "GET" | "DELETE" | "HEAD" => {
122                let mut req = match method_upper.as_str() {
123                    "GET" => agent.get(&url),
124                    "DELETE" => agent.delete(&url),
125                    "HEAD" => agent.head(&url),
126                    other => {
127                        return Err(mlua::Error::runtime(format!(
128                            "internal error: unexpected method {other} in no-body branch"
129                        )));
130                    }
131                };
132                for (name, value) in &extra_headers {
133                    req = req.header(name.as_str(), value.as_str());
134                }
135                req.call()
136            }
137            "POST" | "PUT" | "PATCH" => {
138                let mut req = match method_upper.as_str() {
139                    "POST" => agent.post(&url),
140                    "PUT" => agent.put(&url),
141                    "PATCH" => agent.patch(&url),
142                    other => {
143                        return Err(mlua::Error::runtime(format!(
144                            "internal error: unexpected method {other} in with-body branch"
145                        )));
146                    }
147                };
148                for (name, value) in &extra_headers {
149                    req = req.header(name.as_str(), value.as_str());
150                }
151                // Default Content-Type to application/json when body is present
152                if !has_content_type && body.is_some() {
153                    req = req.header("Content-Type", "application/json");
154                }
155                match body {
156                    Some(ref body_str) => req.send(body_str.as_bytes()),
157                    None => req.send_empty(),
158                }
159            }
160            _ => {
161                let result = lua.create_table()?;
162                result.set("ok", false)?;
163                result.set("error", format!("unsupported HTTP method: {method_upper}"))?;
164                result.set("error_kind", "invalid_method")?;
165                return Ok(result);
166            }
167        };
168
169    match response {
170        Ok(resp) => build_success_response(lua, resp),
171        Err(e) => build_error_response(lua, e),
172    }
173}
174
175/// Builds a Lua table from a successful ureq response.
176fn build_success_response(
177    lua: &Lua,
178    mut resp: ureq::http::Response<ureq::Body>,
179) -> mlua::Result<Table> {
180    let status = resp.status().as_u16();
181
182    // Collect response headers
183    let headers_table = lua.create_table()?;
184    for (name, value) in resp.headers() {
185        if let Ok(v) = value.to_str() {
186            headers_table.set(name.as_str(), v)?;
187        }
188    }
189
190    // Read body with size limit
191    let body = {
192        use std::io::Read;
193        let mut buf = Vec::new();
194        let reader = resp.body_mut().as_reader();
195        match reader.take(MAX_BODY_SIZE).read_to_end(&mut buf) {
196            Ok(n) if n as u64 >= MAX_BODY_SIZE => Err("response body exceeds size limit"),
197            Ok(_) => String::from_utf8(buf).map_err(|_| "response body is not valid UTF-8"),
198            Err(_) => Err("failed to read response body"),
199        }
200    };
201
202    let result = lua.create_table()?;
203    result.set("ok", true)?;
204    result.set("status", status)?;
205    result.set("headers", headers_table)?;
206
207    match body {
208        Ok(body_str) => {
209            result.set("body", body_str)?;
210        }
211        Err(reason) => {
212            let is_too_large = reason.contains("size limit");
213            result.set("body", "")?;
214            result.set("error", reason)?;
215            result.set(
216                "error_kind",
217                if is_too_large { "too_large" } else { "network" },
218            )?;
219        }
220    }
221
222    Ok(result)
223}
224
225/// Builds a Lua error table from a ureq error.
226fn build_error_response(lua: &Lua, error: ureq::Error) -> mlua::Result<Table> {
227    let (error_kind, error_msg) = classify_error(&error);
228
229    let result = lua.create_table()?;
230    result.set("ok", false)?;
231    result.set("error", error_msg)?;
232    result.set("error_kind", error_kind)?;
233    Ok(result)
234}
235
236/// Classifies a ureq error into a kind string and message.
237fn classify_error(error: &ureq::Error) -> (&'static str, String) {
238    let msg = error.to_string();
239
240    // Check the error chain for specific IO errors
241    let io_err = find_io_error(error);
242
243    if let Some(io) = io_err {
244        let kind = io.kind();
245        match kind {
246            std::io::ErrorKind::TimedOut => return ("timeout", msg),
247            std::io::ErrorKind::ConnectionRefused => return ("connection_refused", msg),
248            std::io::ErrorKind::ConnectionReset => return ("connection_reset", msg),
249            std::io::ErrorKind::ConnectionAborted => return ("connection_aborted", msg),
250            _ => {}
251        }
252    }
253
254    // String-based heuristics for cases not covered by io::ErrorKind
255    let lower = msg.to_lowercase();
256    if lower.contains("timeout") || lower.contains("timed out") {
257        ("timeout", msg)
258    } else if lower.contains("dns")
259        || lower.contains("resolve")
260        || lower.contains("name resolution")
261    {
262        ("dns", msg)
263    } else if lower.contains("connection refused") {
264        ("connection_refused", msg)
265    } else if lower.contains("tls") || lower.contains("ssl") || lower.contains("certificate") {
266        ("tls", msg)
267    } else {
268        ("network", msg)
269    }
270}
271
272/// Walks the error source chain to find an `io::Error`.
273fn find_io_error(error: &ureq::Error) -> Option<&std::io::Error> {
274    let mut source: Option<&dyn std::error::Error> = Some(error);
275    while let Some(err) = source {
276        if let Some(io) = err.downcast_ref::<std::io::Error>() {
277            return Some(io);
278        }
279        source = err.source();
280    }
281    None
282}
283
/// Shortens `s` to at most `max` bytes for inclusion in error messages.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so a multi-byte
/// character is never split. The result may therefore be shorter than `max` bytes.
fn truncate_for_error(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Byte index 0 is always a char boundary, so the search always succeeds.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orcs_helpers::ensure_orcs_table;

    /// With only the deny stub registered, calling `orcs.http` from Lua must return
    /// a table with ok=false and error_kind="permission_denied".
    #[test]
    fn deny_stub_returns_permission_denied() {
        let lua = Lua::new();
        let orcs = ensure_orcs_table(&lua).expect("create orcs table");
        register_http_deny_stub(&lua, &orcs).expect("register stub");

        let result: Table = lua
            .load(r#"return orcs.http("GET", "http://example.com")"#)
            .eval()
            .expect("should return deny table");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error: String = result.get("error").expect("get error");
        assert!(
            error.contains("http denied"),
            "expected permission denied, got: {error}"
        );
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "permission_denied"
        );
    }

    /// Non-http(s) schemes are rejected before any network activity.
    #[test]
    fn invalid_url_scheme_returns_error() {
        let lua = Lua::new();
        let result = http_request_impl(&lua, ("GET".into(), "ftp://example.com".into(), None))
            .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "invalid_url"
        );
    }

    /// Methods outside GET/POST/PUT/DELETE/PATCH/HEAD yield error_kind="invalid_method".
    #[test]
    fn unsupported_method_returns_error() {
        let lua = Lua::new();
        let result = http_request_impl(&lua, ("CONNECT".into(), "http://localhost".into(), None))
            .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "invalid_method"
        );
    }

    /// Connecting to a closed local port must produce ok=false with a connection-type
    /// error kind. Several kinds are accepted because the exact failure is platform-dependent.
    #[test]
    fn connection_refused_returns_error_kind() {
        let lua = Lua::new();
        // Target below is 127.0.0.1 port 1, which is very unlikely to be open.
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 2).expect("set timeout");

        let result = http_request_impl(
            &lua,
            ("GET".into(), "http://127.0.0.1:1/test".into(), Some(opts)),
        )
        .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error_kind: String = result.get("error_kind").expect("get error_kind");
        assert!(
            error_kind == "connection_refused"
                || error_kind == "network"
                || error_kind == "timeout",
            "expected connection error kind, got: {error_kind}"
        );
    }

    /// A host under the reserved `.invalid` TLD must fail to resolve.
    #[test]
    fn dns_failure_returns_error_kind() {
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 3).expect("set timeout");

        let result = http_request_impl(
            &lua,
            (
                "GET".into(),
                "http://this-domain-does-not-exist-12345.invalid/test".into(),
                Some(opts),
            ),
        )
        .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error_kind: String = result.get("error_kind").expect("get error_kind");
        // Resolver error texts differ between systems, so "network" and "timeout"
        // are accepted alongside "dns".
        assert!(
            error_kind == "dns" || error_kind == "network" || error_kind == "timeout",
            "expected dns/network error kind, got: {error_kind}"
        );
    }

    /// ASCII input: short strings pass through, long ones are cut at `max` bytes.
    #[test]
    fn truncate_for_error_handles_ascii() {
        assert_eq!(truncate_for_error("hello", 10), "hello");
        assert_eq!(truncate_for_error("hello world", 5), "hello");
    }

    /// Multi-byte input: the cut moves back to the previous char boundary.
    #[test]
    fn truncate_for_error_handles_utf8() {
        // "あいう" is 9 bytes (3 chars × 3 bytes)
        let s = "あいう";
        let t = truncate_for_error(s, 4);
        assert_eq!(t, "あ"); // 3 bytes: byte 4 is mid-character, so the cut backs up to 3
    }

    /// A 1-second `timeout` option must bound the total request time.
    #[test]
    fn opts_timeout_is_respected() {
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 1).expect("set timeout");

        // The target address is expected to be unreachable, so the request should end
        // via the 1s timeout rather than a fast connection error.
        let start = std::time::Instant::now();
        let result = http_request_impl(
            &lua,
            (
                "GET".into(),
                "http://192.0.2.1/test".into(), // TEST-NET-1 (RFC 5737), documentation-only range
                Some(opts),
            ),
        )
        .expect("should not panic");

        let elapsed = start.elapsed();
        assert!(!result.get::<bool>("ok").expect("get ok"));
        // 1s timeout plus overhead; the 5s bound leaves headroom so the test is not flaky.
        assert!(
            elapsed.as_secs() < 5,
            "should timeout quickly, took: {:?}",
            elapsed
        );
    }

    /// Smoke test for the header-collection path on a POST.
    #[test]
    fn headers_are_passed_through() {
        // This only verifies that header parsing and attachment do not error or panic.
        // The request cannot connect, so it does not observe the headers on the wire.
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        let headers = lua.create_table().expect("create headers");
        headers
            .set("Authorization", "Bearer test-token")
            .expect("set auth");
        headers.set("X-Custom", "custom-value").expect("set custom");
        opts.set("headers", headers).expect("set headers");
        opts.set("timeout", 1).expect("set timeout");

        // Will fail to connect but shouldn't panic on header processing
        let result = http_request_impl(
            &lua,
            ("POST".into(), "http://127.0.0.1:1/test".into(), Some(opts)),
        )
        .expect("should not panic on header processing");

        assert!(!result.get::<bool>("ok").expect("get ok"));
    }
}