1use mlua::{Lua, Table};
21use std::time::Duration;
22
23use crate::llm_command::retry::{
24 classify_reqwest_error, read_body_limited, truncate_for_error, ReadBodyError,
25};
26
/// Request timeout, in whole seconds, used when the caller supplies no usable
/// `timeout` option.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Upper bound on the response body size, in bytes (10 MiB).
const MAX_BODY_SIZE: u64 = 10 << 20;

33pub fn register_http_deny_stub(lua: &Lua, orcs_table: &Table) -> Result<(), mlua::Error> {
38 if orcs_table.get::<mlua::Function>("http").is_err() {
39 let http_fn = lua.create_function(|lua, _args: mlua::MultiValue| {
40 let result = lua.create_table()?;
41 result.set("ok", false)?;
42 result.set(
43 "error",
44 "http denied: no execution context (ChildContext with Capability::HTTP required)",
45 )?;
46 result.set("error_kind", "permission_denied")?;
47 Ok(result)
48 })?;
49 orcs_table.set("http", http_fn)?;
50 }
51 Ok(())
52}
53
54pub fn http_request_impl(lua: &Lua, args: (String, String, Option<Table>)) -> mlua::Result<Table> {
75 let (method, url, opts) = args;
76
77 if !url.starts_with("http://") && !url.starts_with("https://") {
79 let result = lua.create_table()?;
80 result.set("ok", false)?;
81 result.set(
82 "error",
83 format!(
84 "invalid URL scheme: URL must start with http:// or https://, got: {}",
85 truncate_for_error(&url, 100)
86 ),
87 )?;
88 result.set("error_kind", "invalid_url")?;
89 return Ok(result);
90 }
91
92 let timeout_secs = opts
94 .as_ref()
95 .and_then(|o| o.get::<u64>("timeout").ok())
96 .unwrap_or(DEFAULT_TIMEOUT_SECS);
97
98 let body: Option<String> = opts.as_ref().and_then(|o| o.get::<String>("body").ok());
99
100 let mut extra_headers: Vec<(String, String)> = Vec::new();
102 if let Some(ref o) = opts {
103 if let Ok(headers) = o.get::<Table>("headers") {
104 for (name, value) in headers.pairs::<String, String>().flatten() {
105 extra_headers.push((name, value));
106 }
107 }
108 }
109
110 let has_content_type = extra_headers
112 .iter()
113 .any(|(k, _)| k.to_lowercase() == "content-type");
114
115 let method_upper = method.to_uppercase();
117 let reqwest_method = match method_upper.as_str() {
118 "GET" => reqwest::Method::GET,
119 "POST" => reqwest::Method::POST,
120 "PUT" => reqwest::Method::PUT,
121 "DELETE" => reqwest::Method::DELETE,
122 "PATCH" => reqwest::Method::PATCH,
123 "HEAD" => reqwest::Method::HEAD,
124 _ => {
125 let result = lua.create_table()?;
126 result.set("ok", false)?;
127 result.set("error", format!("unsupported HTTP method: {method_upper}"))?;
128 result.set("error_kind", "invalid_method")?;
129 return Ok(result);
130 }
131 };
132
133 let client = crate::llm_command::get_or_init_http_client(lua)?;
135
136 let handle = tokio::runtime::Handle::try_current().map_err(|_| {
138 mlua::Error::RuntimeError("no tokio runtime available for async HTTP".into())
139 })?;
140
141 let mut req = client
143 .request(reqwest_method, &url)
144 .timeout(Duration::from_secs(timeout_secs));
145 for (name, value) in &extra_headers {
146 req = req.header(name.as_str(), value.as_str());
147 }
148
149 if !has_content_type && body.is_some() {
151 req = req.header("Content-Type", "application/json");
152 }
153
154 if let Some(ref body_str) = body {
155 req = req.body(body_str.clone());
156 }
157
158 match tokio::task::block_in_place(|| handle.block_on(req.send())) {
160 Ok(resp) => build_success_response(lua, resp),
161 Err(e) => build_error_response(lua, e),
162 }
163}
164
165fn build_success_response(lua: &Lua, resp: reqwest::Response) -> mlua::Result<Table> {
170 let status = resp.status().as_u16();
171
172 let headers_table = lua.create_table()?;
174 for (name, value) in resp.headers() {
175 if let Ok(v) = value.to_str() {
176 headers_table.set(name.as_str(), v)?;
177 }
178 }
179
180 let result = lua.create_table()?;
182 result.set("ok", true)?;
183 result.set("status", status)?;
184 result.set("headers", headers_table)?;
185
186 match read_body_limited(resp, MAX_BODY_SIZE) {
187 Ok(body_str) => {
188 result.set("body", body_str)?;
189 }
190 Err(ReadBodyError::TooLarge) => {
191 result.set("body", "")?;
192 result.set("error", "response body exceeds size limit")?;
193 result.set("error_kind", "too_large")?;
194 }
195 Err(ReadBodyError::InvalidUtf8) => {
196 result.set("body", "")?;
197 result.set("error", "response body is not valid UTF-8")?;
198 result.set("error_kind", "network")?;
199 }
200 Err(ReadBodyError::NoRuntime) => {
201 result.set("body", "")?;
202 result.set("error", "no tokio runtime available for reading body")?;
203 result.set("error_kind", "network")?;
204 }
205 Err(ReadBodyError::Network(msg)) => {
206 result.set("body", "")?;
207 result.set("error", format!("failed to read response body: {msg}"))?;
208 result.set("error_kind", "network")?;
209 }
210 }
211
212 Ok(result)
213}
214
215fn build_error_response(lua: &Lua, error: reqwest::Error) -> mlua::Result<Table> {
217 let (error_kind, error_msg) = classify_reqwest_error(&error);
218
219 let result = lua.create_table()?;
220 result.set("ok", false)?;
221 result.set("error", error_msg)?;
222 result.set("error_kind", error_kind)?;
223 Ok(result)
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orcs_helpers::ensure_orcs_table;

    /// Starts a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server thread:
    /// * accepts a single connection;
    /// * reads the request head through the first blank line;
    /// * answers `200 OK` with body `ok`;
    /// * returns the lowercased request head through its join handle.
    fn spawn_one_shot_server() -> (u16, std::thread::JoinHandle<String>) {
        use std::io::{Read, Write};

        const RESPONSE: &[u8] =
            b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok";

        let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind listener");
        let port = listener.local_addr().expect("listener addr").port();
        let server = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept connection");
            stream
                .set_read_timeout(Some(Duration::from_secs(5)))
                .expect("set read timeout");

            let mut head: Vec<u8> = Vec::new();
            let mut chunk = [0u8; 1024];
            while !head.windows(4).any(|w| w == b"\r\n\r\n") {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                head.extend_from_slice(&chunk[..n]);
            }

            stream.write_all(RESPONSE).expect("write response");
            String::from_utf8_lossy(&head).to_lowercase()
        });
        (port, server)
    }

    #[test]
    fn deny_stub_returns_permission_denied() {
        let lua = Lua::new();
        let orcs = ensure_orcs_table(&lua).expect("create orcs table");
        register_http_deny_stub(&lua, &orcs).expect("register stub");

        let result: Table = lua
            .load(r#"return orcs.http("GET", "http://example.com")"#)
            .eval()
            .expect("should return deny table");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error: String = result.get("error").expect("get error");
        assert!(
            error.contains("http denied"),
            "expected permission denied, got: {error}"
        );
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "permission_denied"
        );
    }

    #[test]
    fn deny_stub_does_not_replace_existing_http() {
        let lua = Lua::new();
        let orcs = ensure_orcs_table(&lua).expect("create orcs table");
        let existing = lua
            .create_function(|_, _args: mlua::MultiValue| Ok("real"))
            .expect("create existing http fn");
        orcs.set("http", existing).expect("set existing http");

        register_http_deny_stub(&lua, &orcs).expect("register stub");

        let out: String = lua
            .load(r#"return orcs.http("GET", "http://example.com")"#)
            .eval()
            .expect("existing http should still be callable");
        assert_eq!(out, "real");
    }

    #[test]
    fn invalid_url_scheme_returns_error() {
        let lua = Lua::new();
        let result = http_request_impl(&lua, ("GET".into(), "ftp://example.com".into(), None))
            .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "invalid_url"
        );
    }

    #[test]
    fn unsupported_method_returns_error() {
        let lua = Lua::new();
        let result = http_request_impl(&lua, ("CONNECT".into(), "http://localhost".into(), None))
            .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        assert_eq!(
            result.get::<String>("error_kind").expect("get error_kind"),
            "invalid_method"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn connection_refused_returns_error_kind() {
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 2).expect("set timeout");

        let result = http_request_impl(
            &lua,
            ("GET".into(), "http://127.0.0.1:1/test".into(), Some(opts)),
        )
        .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error_kind: String = result.get("error_kind").expect("get error_kind");
        assert!(
            error_kind == "connection_refused"
                || error_kind == "network"
                || error_kind == "timeout",
            "expected connection error kind, got: {error_kind}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn dns_failure_returns_error_kind() {
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 3).expect("set timeout");

        let result = http_request_impl(
            &lua,
            (
                "GET".into(),
                "http://this-domain-does-not-exist-12345.invalid/test".into(),
                Some(opts),
            ),
        )
        .expect("should not panic");

        assert!(!result.get::<bool>("ok").expect("get ok"));
        let error_kind: String = result.get("error_kind").expect("get error_kind");
        assert!(
            error_kind == "dns" || error_kind == "network" || error_kind == "timeout",
            "expected dns/network error kind, got: {error_kind}"
        );
    }

    #[test]
    fn truncate_for_error_handles_ascii() {
        assert_eq!(truncate_for_error("hello", 10), "hello");
        assert_eq!(truncate_for_error("hello world", 5), "hello");
    }

    #[test]
    fn truncate_for_error_handles_utf8() {
        // Each character is 3 bytes in UTF-8; a 4-byte limit keeps only the first.
        let s = "あいう";
        let t = truncate_for_error(s, 4);
        assert_eq!(t, "あ");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn opts_timeout_is_respected() {
        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        opts.set("timeout", 1).expect("set timeout");

        // 192.0.2.1 is in TEST-NET-1 (RFC 5737), which should not be routable.
        let start = std::time::Instant::now();
        let result = http_request_impl(
            &lua,
            ("GET".into(), "http://192.0.2.1/test".into(), Some(opts)),
        )
        .expect("should not panic");

        let elapsed = start.elapsed();
        assert!(!result.get::<bool>("ok").expect("get ok"));
        assert!(
            elapsed.as_secs() < 5,
            "should timeout quickly, took: {:?}",
            elapsed
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn headers_are_passed_through() {
        let (port, server) = spawn_one_shot_server();

        let lua = Lua::new();
        let opts = lua.create_table().expect("create opts");
        let headers = lua.create_table().expect("create headers");
        headers
            .set("Authorization", "Bearer test-token")
            .expect("set auth");
        headers.set("X-Custom", "custom-value").expect("set custom");
        opts.set("headers", headers).expect("set headers");
        opts.set("timeout", 5).expect("set timeout");

        let result = http_request_impl(
            &lua,
            (
                "POST".into(),
                format!("http://127.0.0.1:{port}/test"),
                Some(opts),
            ),
        )
        .expect("should not panic on header processing");

        // Check the response before joining: if the request never reached the
        // server, `join` would block on `accept` indefinitely.
        assert!(
            result.get::<bool>("ok").expect("get ok"),
            "request to local server failed: {:?}",
            result.get::<String>("error").ok()
        );
        assert_eq!(result.get::<u16>("status").expect("get status"), 200);
        assert_eq!(result.get::<String>("body").expect("get body"), "ok");

        let request_head = server.join().expect("server thread panicked");
        assert!(
            request_head.contains("authorization: bearer test-token"),
            "Authorization header missing from request:\n{request_head}"
        );
        assert!(
            request_head.contains("x-custom: custom-value"),
            "X-Custom header missing from request:\n{request_head}"
        );
    }
}