Skip to main content

agent_block_core/bridge/
http.rs

1//! http.* — Async HTTP client bridge.
2//!
3//! Provides `http.request(url, opts)` as an async Rust function.
4//! When called from Lua via `coroutine_eval`, the coroutine yields
5//! during the HTTP request and other coroutines can make progress.
6//!
7//! # Streaming (SSE)
8//!
9//! When `stream = true`, the response body is read as Server-Sent
10//! Events.  Each `data:` line is passed to the `on_data(data_string)`
11//! Lua callback.  The `[DONE]` sentinel terminates the stream.
12//!
13//! # Security
14//!
15//! No URL restrictions during development.  The trust boundary is
16//! the Lua script author.  A security model will be designed
17//! separately before production use.
18
19use mlua::prelude::*;
20use std::collections::HashSet;
21use std::time::Duration;
22
23use crate::host::HostContext;
24use agent_block_types::obs;
25
/// Per-request timeout, in seconds, used when the caller supplies no usable
/// `timeout` option.
const DEFAULT_TIMEOUT_SECS: u64 = 2 * 60;

/// Upper bound on a buffered (non-streaming) response body: 10 MiB.
const MAX_BODY_SIZE: usize = 10 << 20;
31
32pub fn register(lua: &Lua, ctx: &HostContext) -> LuaResult<()> {
33    let http_tbl = lua.create_table()?;
34
35    let script_name: String = lua
36        .globals()
37        .get::<Option<String>>("_SCRIPT_NAME")?
38        .unwrap_or_else(|| "unknown".to_string());
39    let client = ctx.http_client.clone();
40    let fallback_agent_id = ctx.mesh_agent.as_ref().map(|a| a.agent_id().to_string());
41    http_tbl.set(
42        "request",
43        lua.create_async_function(move |lua, (url, opts): (String, Option<LuaTable>)| {
44            let client = client.clone();
45            let fallback_agent_id = fallback_agent_id.clone();
46            let script_name = script_name.clone();
47            async move {
48                // ── Parse options ─────────────────────────────────
49                let method = opts
50                    .as_ref()
51                    .and_then(|t| t.get::<Option<String>>("method").ok().flatten())
52                    .unwrap_or_else(|| "GET".to_string());
53
54                let timeout_secs: u64 = opts
55                    .as_ref()
56                    .and_then(|t| t.get::<Option<u64>>("timeout").ok().flatten())
57                    .unwrap_or(DEFAULT_TIMEOUT_SECS);
58
59                let body: Option<String> = opts
60                    .as_ref()
61                    .and_then(|t| t.get::<Option<String>>("body").ok().flatten());
62
63                let stream_mode: bool = opts
64                    .as_ref()
65                    .and_then(|t| t.get::<Option<bool>>("stream").ok().flatten())
66                    .unwrap_or(false);
67
68                let on_data: Option<LuaFunction> = if stream_mode {
69                    opts.as_ref()
70                        .and_then(|t| t.get::<Option<LuaFunction>>("on_data").ok().flatten())
71                } else {
72                    None
73                };
74
75                // ── Build request ─────────────────────────────────
76                let reqwest_method = method.parse::<reqwest::Method>().map_err(|e| {
77                    LuaError::external(format!("invalid HTTP method '{method}': {e}"))
78                })?;
79
80                let mut req = client
81                    .request(reqwest_method, &url)
82                    .timeout(Duration::from_secs(timeout_secs));
83
84                let mut explicit_headers = HashSet::<String>::new();
85                if let Some(ref opts_tbl) = opts {
86                    if let Some(headers_tbl) = opts_tbl.get::<Option<LuaTable>>("headers")? {
87                        for pair in headers_tbl.pairs::<String, String>() {
88                            let (k, v) = pair?;
89                            explicit_headers.insert(k.to_ascii_lowercase());
90                            req = req.header(&k, &v);
91                        }
92                    }
93                }
94
95                // Auto-propagate trace context to outbound HTTP requests.
96                // User-provided headers always win (no override).
97                let trace_headers = [
98                    ("x-trace-id", std::env::var("AGENT_BLOCK_TRACE_ID").ok()),
99                    ("x-run-id", std::env::var("AGENT_BLOCK_RUN_ID").ok()),
100                    (
101                        "x-agent-id",
102                        std::env::var("AGENT_BLOCK_AGENT_ID")
103                            .ok()
104                            .or_else(|| fallback_agent_id.clone()),
105                    ),
106                    ("x-agent-name", std::env::var("AGENT_BLOCK_AGENT_NAME").ok()),
107                ];
108                for (name, val_opt) in trace_headers {
109                    if explicit_headers.contains(name) {
110                        continue;
111                    }
112                    if let Some(v) = val_opt {
113                        if !v.is_empty() {
114                            req = req.header(name, v);
115                        }
116                    }
117                }
118
119                if let Some(b) = body {
120                    req = req.body(b);
121                }
122
123                // ── Send (yields here) ────────────────────────────
124                tracing::info!(
125                    target: "lua",
126                    script = %script_name,
127                    "{}",
128                    obs::obs_line(
129                        "http",
130                        "http_request",
131                        &obs::obs_context(fallback_agent_id.as_deref()),
132                        &[("method", method.as_str()), ("url", url.as_str())],
133                    )
134                );
135                let resp = req.send().await.map_err(|e| {
136                    if e.is_timeout() {
137                        LuaError::external(format!("http timeout after {timeout_secs}s: {e}"))
138                    } else if e.is_connect() {
139                        LuaError::external(format!("http connection error: {e}"))
140                    } else {
141                        LuaError::external(format!("http request error: {e}"))
142                    }
143                })?;
144
145                let status = resp.status().as_u16();
146                let status_s = status.to_string();
147                tracing::info!(
148                    target: "lua",
149                    script = %script_name,
150                    "{}",
151                    obs::obs_line(
152                        "http",
153                        "http_response",
154                        &obs::obs_context(fallback_agent_id.as_deref()),
155                        &[("method", method.as_str()), ("url", url.as_str()), ("status", status_s.as_str())],
156                    )
157                );
158
159                let resp_headers = lua.create_table()?;
160                for (k, v) in resp.headers() {
161                    if let Ok(vs) = v.to_str() {
162                        resp_headers.set(k.as_str(), vs.to_string())?;
163                    }
164                }
165
166                if stream_mode {
167                    // ── SSE streaming mode ────────────────────────
168                    read_sse(resp, &on_data).await?;
169
170                    let result = lua.create_table()?;
171                    result.set("status", status)?;
172                    result.set("headers", resp_headers)?;
173                    Ok(result)
174                } else {
175                    // ── Buffered mode ─────────────────────────────
176                    let body_bytes = resp
177                        .bytes()
178                        .await
179                        .map_err(|e| LuaError::external(format!("http read body error: {e}")))?;
180
181                    if body_bytes.len() > MAX_BODY_SIZE {
182                        return Err(LuaError::external(format!(
183                            "response body too large: {} bytes (max {MAX_BODY_SIZE})",
184                            body_bytes.len()
185                        )));
186                    }
187
188                    let body_str = String::from_utf8_lossy(&body_bytes).to_string();
189
190                    let result = lua.create_table()?;
191                    result.set("status", status)?;
192                    result.set("headers", resp_headers)?;
193                    result.set("body", body_str)?;
194                    Ok(result)
195                }
196            }
197        })?,
198    )?;
199
200    lua.globals().set("http", http_tbl)?;
201    Ok(())
202}
203
204/// Read SSE stream and dispatch `data:` lines to the Lua callback.
205///
206/// SSE format:
207/// ```text
208/// event: message_start
209/// data: {"type":"message_start",...}
210///
211/// data: {"type":"content_block_delta",...}
212///
213/// data: [DONE]
214/// ```
215///
216/// Each `data:` value is passed as a string to `on_data`.
217/// The `[DONE]` sentinel terminates the stream.
218async fn read_sse(mut resp: reqwest::Response, on_data: &Option<LuaFunction>) -> LuaResult<()> {
219    let mut buffer = String::new();
220
221    // Read chunks as they arrive (yields between chunks).
222    loop {
223        let chunk = resp
224            .chunk()
225            .await
226            .map_err(|e| LuaError::external(format!("http stream read error: {e}")))?;
227
228        let chunk = match chunk {
229            Some(c) => c,
230            None => break, // EOF
231        };
232
233        buffer.push_str(&String::from_utf8_lossy(&chunk));
234
235        // Process complete SSE events (delimited by blank lines).
236        while let Some(pos) = buffer.find("\n\n") {
237            let event_block = buffer[..pos].to_string();
238            buffer = buffer[pos + 2..].to_string();
239
240            for line in event_block.lines() {
241                if let Some(data) = line
242                    .strip_prefix("data: ")
243                    .or_else(|| line.strip_prefix("data:"))
244                {
245                    let data = data.trim();
246                    if data == "[DONE]" {
247                        return Ok(());
248                    }
249                    if let Some(ref cb) = on_data {
250                        cb.call::<()>(data.to_string())?;
251                    }
252                }
253                // `event:`, `id:`, `retry:` lines are silently skipped.
254            }
255        }
256    }
257
258    Ok(())
259}