agent_block_core/bridge/
http.rs1use mlua::prelude::*;
20use std::collections::HashSet;
21use std::time::Duration;
22
23use crate::host::HostContext;
24use agent_block_types::obs;
25
26const DEFAULT_TIMEOUT_SECS: u64 = 120;
28
29const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
31
32pub fn register(lua: &Lua, ctx: &HostContext) -> LuaResult<()> {
33 let http_tbl = lua.create_table()?;
34
35 let script_name: String = lua
36 .globals()
37 .get::<Option<String>>("_SCRIPT_NAME")?
38 .unwrap_or_else(|| "unknown".to_string());
39 let client = ctx.http_client.clone();
40 let fallback_agent_id = ctx.mesh_agent.as_ref().map(|a| a.agent_id().to_string());
41 http_tbl.set(
42 "request",
43 lua.create_async_function(move |lua, (url, opts): (String, Option<LuaTable>)| {
44 let client = client.clone();
45 let fallback_agent_id = fallback_agent_id.clone();
46 let script_name = script_name.clone();
47 async move {
48 let method = opts
50 .as_ref()
51 .and_then(|t| t.get::<Option<String>>("method").ok().flatten())
52 .unwrap_or_else(|| "GET".to_string());
53
54 let timeout_secs: u64 = opts
55 .as_ref()
56 .and_then(|t| t.get::<Option<u64>>("timeout").ok().flatten())
57 .unwrap_or(DEFAULT_TIMEOUT_SECS);
58
59 let body: Option<String> = opts
60 .as_ref()
61 .and_then(|t| t.get::<Option<String>>("body").ok().flatten());
62
63 let stream_mode: bool = opts
64 .as_ref()
65 .and_then(|t| t.get::<Option<bool>>("stream").ok().flatten())
66 .unwrap_or(false);
67
68 let on_data: Option<LuaFunction> = if stream_mode {
69 opts.as_ref()
70 .and_then(|t| t.get::<Option<LuaFunction>>("on_data").ok().flatten())
71 } else {
72 None
73 };
74
75 let reqwest_method = method.parse::<reqwest::Method>().map_err(|e| {
77 LuaError::external(format!("invalid HTTP method '{method}': {e}"))
78 })?;
79
80 let mut req = client
81 .request(reqwest_method, &url)
82 .timeout(Duration::from_secs(timeout_secs));
83
84 let mut explicit_headers = HashSet::<String>::new();
85 if let Some(ref opts_tbl) = opts {
86 if let Some(headers_tbl) = opts_tbl.get::<Option<LuaTable>>("headers")? {
87 for pair in headers_tbl.pairs::<String, String>() {
88 let (k, v) = pair?;
89 explicit_headers.insert(k.to_ascii_lowercase());
90 req = req.header(&k, &v);
91 }
92 }
93 }
94
95 let trace_headers = [
98 ("x-trace-id", std::env::var("AGENT_BLOCK_TRACE_ID").ok()),
99 ("x-run-id", std::env::var("AGENT_BLOCK_RUN_ID").ok()),
100 (
101 "x-agent-id",
102 std::env::var("AGENT_BLOCK_AGENT_ID")
103 .ok()
104 .or_else(|| fallback_agent_id.clone()),
105 ),
106 ("x-agent-name", std::env::var("AGENT_BLOCK_AGENT_NAME").ok()),
107 ];
108 for (name, val_opt) in trace_headers {
109 if explicit_headers.contains(name) {
110 continue;
111 }
112 if let Some(v) = val_opt {
113 if !v.is_empty() {
114 req = req.header(name, v);
115 }
116 }
117 }
118
119 if let Some(b) = body {
120 req = req.body(b);
121 }
122
123 tracing::info!(
125 target: "lua",
126 script = %script_name,
127 "{}",
128 obs::obs_line(
129 "http",
130 "http_request",
131 &obs::obs_context(fallback_agent_id.as_deref()),
132 &[("method", method.as_str()), ("url", url.as_str())],
133 )
134 );
135 let resp = req.send().await.map_err(|e| {
136 if e.is_timeout() {
137 LuaError::external(format!("http timeout after {timeout_secs}s: {e}"))
138 } else if e.is_connect() {
139 LuaError::external(format!("http connection error: {e}"))
140 } else {
141 LuaError::external(format!("http request error: {e}"))
142 }
143 })?;
144
145 let status = resp.status().as_u16();
146 let status_s = status.to_string();
147 tracing::info!(
148 target: "lua",
149 script = %script_name,
150 "{}",
151 obs::obs_line(
152 "http",
153 "http_response",
154 &obs::obs_context(fallback_agent_id.as_deref()),
155 &[("method", method.as_str()), ("url", url.as_str()), ("status", status_s.as_str())],
156 )
157 );
158
159 let resp_headers = lua.create_table()?;
160 for (k, v) in resp.headers() {
161 if let Ok(vs) = v.to_str() {
162 resp_headers.set(k.as_str(), vs.to_string())?;
163 }
164 }
165
166 if stream_mode {
167 read_sse(resp, &on_data).await?;
169
170 let result = lua.create_table()?;
171 result.set("status", status)?;
172 result.set("headers", resp_headers)?;
173 Ok(result)
174 } else {
175 let body_bytes = resp
177 .bytes()
178 .await
179 .map_err(|e| LuaError::external(format!("http read body error: {e}")))?;
180
181 if body_bytes.len() > MAX_BODY_SIZE {
182 return Err(LuaError::external(format!(
183 "response body too large: {} bytes (max {MAX_BODY_SIZE})",
184 body_bytes.len()
185 )));
186 }
187
188 let body_str = String::from_utf8_lossy(&body_bytes).to_string();
189
190 let result = lua.create_table()?;
191 result.set("status", status)?;
192 result.set("headers", resp_headers)?;
193 result.set("body", body_str)?;
194 Ok(result)
195 }
196 }
197 })?,
198 )?;
199
200 lua.globals().set("http", http_tbl)?;
201 Ok(())
202}
203
204async fn read_sse(mut resp: reqwest::Response, on_data: &Option<LuaFunction>) -> LuaResult<()> {
219 let mut buffer = String::new();
220
221 loop {
223 let chunk = resp
224 .chunk()
225 .await
226 .map_err(|e| LuaError::external(format!("http stream read error: {e}")))?;
227
228 let chunk = match chunk {
229 Some(c) => c,
230 None => break, };
232
233 buffer.push_str(&String::from_utf8_lossy(&chunk));
234
235 while let Some(pos) = buffer.find("\n\n") {
237 let event_block = buffer[..pos].to_string();
238 buffer = buffer[pos + 2..].to_string();
239
240 for line in event_block.lines() {
241 if let Some(data) = line
242 .strip_prefix("data: ")
243 .or_else(|| line.strip_prefix("data:"))
244 {
245 let data = data.trim();
246 if data == "[DONE]" {
247 return Ok(());
248 }
249 if let Some(ref cb) = on_data {
250 cb.call::<()>(data.to_string())?;
251 }
252 }
253 }
255 }
256 }
257
258 Ok(())
259}