use std::cell::RefCell;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Request timeout applied by `agent_fetch_blocking` when the caller does not
/// supply `timeout_ms` (10 seconds).
const DEFAULT_FETCH_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum number of response-body bytes returned from a fetch (1 MiB);
/// longer bodies are cut at this length and reported as truncated.
const FETCH_BODY_MAX_BYTES: usize = 1024 * 1024;
/// Maximum size in bytes (1 MiB) of a file read from or written to a
/// `SandboxedDir`.
pub(crate) const SANDBOX_FILE_MAX_BYTES: usize = 1024 * 1024;
thread_local! {
    /// Per-thread byte store backing every `OutputBuffer` handle on that thread.
    static SCRIPT_OUTPUT: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
}

/// Zero-sized handle onto the current thread's script-output buffer.
///
/// The handle carries no state of its own: all copies on one thread read and
/// write the same thread-local byte vector.
#[derive(Clone, Copy, Default)]
pub(crate) struct OutputBuffer;

impl OutputBuffer {
    /// Empties the current thread's buffer and returns a handle to it.
    pub(crate) fn new() -> Self {
        SCRIPT_OUTPUT.with(|slot| slot.borrow_mut().clear());
        OutputBuffer
    }

    /// Appends `s` to the buffer.
    ///
    /// If the buffer is already mutably borrowed the text is silently dropped.
    pub(crate) fn write_str(&self, s: &str) {
        SCRIPT_OUTPUT.with(|slot| match slot.try_borrow_mut() {
            Ok(mut pending) => pending.extend_from_slice(s.as_bytes()),
            // Best effort: a busy buffer loses this write rather than panicking.
            Err(_) => {}
        });
    }

    /// Removes everything accumulated so far and returns it as a `String`.
    ///
    /// Invalid UTF-8 is converted lossily. An empty string is returned when
    /// the buffer is empty or cannot be borrowed.
    pub(crate) fn drain_to_string(&self) -> String {
        let taken = SCRIPT_OUTPUT.with(|slot| {
            slot.try_borrow_mut()
                .map(|mut pending| std::mem::take(&mut *pending))
                .ok()
        });
        match taken {
            None => String::new(),
            Some(raw) => String::from_utf8(raw)
                .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()),
        }
    }
}
/// Optional parameters for a scripted HTTP fetch.
///
/// Every field may be omitted from the serialized input; `#[serde(default)]`
/// fills missing fields with `None`.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct FetchRequest {
/// HTTP method name; `agent_fetch_blocking` uses `GET` when absent and
/// upper-cases the value before parsing it.
pub method: Option<String>,
/// Request headers to send, as name/value pairs.
pub headers: Option<std::collections::HashMap<String, String>>,
/// Request body, sent as-is when present.
pub body: Option<String>,
/// Request timeout in milliseconds; `DEFAULT_FETCH_TIMEOUT` applies when absent.
pub timeout_ms: Option<u64>,
}
/// Result of a scripted HTTP fetch, as produced by `agent_fetch_blocking`.
#[derive(Debug, Serialize)]
pub(crate) struct FetchResponse {
/// Numeric HTTP status code of the response.
pub status: u16,
/// Whether the status code is a success status (`StatusCode::is_success`).
pub ok: bool,
/// URL reported by the response object (may differ from the requested URL).
pub url: String,
/// Response headers whose values could be represented as strings; a repeated
/// header name keeps only the last value inserted.
pub headers: std::collections::HashMap<String, String>,
/// Response body decoded as UTF-8 (lossily), limited to `FETCH_BODY_MAX_BYTES`.
pub body: String,
/// `true` when the body was longer than `FETCH_BODY_MAX_BYTES` and was cut.
pub truncated: bool,
}
/// Performs an HTTP request on behalf of a script and blocks until it completes.
///
/// The request is driven on `runtime` via `Handle::block_on`, so this must be
/// called from a thread where blocking on that runtime is permitted.
///
/// * `client` - reqwest client used to issue the request (cloned internally).
/// * `runtime` - Tokio runtime handle the request future is run on.
/// * `interrupt` - cancellation flag; it is checked once, before the request
///   is started.
/// * `usage` - counters updated with call, error and received-byte totals.
/// * `url` - target URL.
/// * `req` - optional method, headers, body and timeout.
///
/// The response body is read incrementally and reading stops as soon as more
/// than `FETCH_BODY_MAX_BYTES` have arrived, so an oversized response cannot
/// force an unbounded allocation. The returned body holds at most
/// `FETCH_BODY_MAX_BYTES` bytes and `truncated` reports whether data was cut.
///
/// # Errors
///
/// Returns a message string (and increments `usage.fetch_errors`) when the
/// interrupt flag is already set, the method is not a valid HTTP method, the
/// request fails to send, or reading the body fails.
pub(crate) fn agent_fetch_blocking(
    client: &reqwest::Client,
    runtime: &tokio::runtime::Handle,
    interrupt: &Arc<AtomicBool>,
    usage: &super::ScriptUsage,
    url: &str,
    req: FetchRequest,
) -> Result<FetchResponse, String> {
    usage.fetch_calls.fetch_add(1, Ordering::Relaxed);
    if interrupt.load(Ordering::Relaxed) {
        usage.fetch_errors.fetch_add(1, Ordering::Relaxed);
        return Err("interrupted".into());
    }

    let method = req.method.as_deref().unwrap_or("GET").to_ascii_uppercase();
    let method = reqwest::Method::from_bytes(method.as_bytes()).map_err(|e| {
        usage.fetch_errors.fetch_add(1, Ordering::Relaxed);
        format!("invalid method: {e}")
    })?;
    let timeout = req
        .timeout_ms
        .map(Duration::from_millis)
        .unwrap_or(DEFAULT_FETCH_TIMEOUT);

    // Owned copies so the future does not borrow from the caller's arguments.
    let client = client.clone();
    let url_owned = url.to_string();

    runtime.block_on(async move {
        let mut builder = client.request(method, &url_owned).timeout(timeout);
        if let Some(headers) = req.headers {
            for (k, v) in headers {
                builder = builder.header(k, v);
            }
        }
        if let Some(body) = req.body {
            builder = builder.body(body);
        }

        let mut response = match builder.send().await {
            Ok(r) => r,
            Err(e) => {
                usage.fetch_errors.fetch_add(1, Ordering::Relaxed);
                return Err(format!("fetch error: {e}"));
            }
        };

        let status = response.status().as_u16();
        let ok = response.status().is_success();
        let final_url = response.url().to_string();

        // Header values that are not valid visible ASCII are skipped.
        let mut headers = std::collections::HashMap::new();
        for (k, v) in response.headers().iter() {
            if let Ok(value) = v.to_str() {
                headers.insert(k.as_str().to_string(), value.to_string());
            }
        }

        // Stream the body chunk by chunk and stop once the cap is exceeded,
        // instead of buffering the whole response before truncating it.
        let mut collected: Vec<u8> = Vec::new();
        let mut truncated = false;
        loop {
            match response.chunk().await {
                Ok(Some(chunk)) => {
                    let remaining = FETCH_BODY_MAX_BYTES - collected.len();
                    if chunk.len() > remaining {
                        collected.extend_from_slice(&chunk[..remaining]);
                        truncated = true;
                        break;
                    }
                    collected.extend_from_slice(&chunk);
                }
                Ok(None) => break,
                Err(e) => {
                    usage.fetch_errors.fetch_add(1, Ordering::Relaxed);
                    return Err(format!("read body: {e}"));
                }
            }
        }

        usage
            .fetch_bytes_in
            .fetch_add(collected.len() as u64, Ordering::Relaxed);
        let body = String::from_utf8_lossy(&collected).into_owned();

        Ok(FetchResponse {
            status,
            ok,
            url: final_url,
            headers,
            body,
            truncated,
        })
    })
}
/// A temporary directory accessed through a `cap_std` directory handle, so
/// file operations are resolved relative to that handle.
pub(crate) struct SandboxedDir {
// Capability handle used for all reads and writes. Declared first so it is
// dropped before `_tempdir` (fields drop in declaration order).
dir: cap_std::fs::Dir,
// Absolute path of the temporary directory, exposed via `root_path`.
root: PathBuf,
// Owns the temporary directory for the lifetime of the sandbox.
// NOTE(review): `tempfile::TempDir` is expected to delete the directory on
// drop — confirm against the tempfile documentation.
_tempdir: tempfile::TempDir,
}
impl SandboxedDir {
    /// Creates a new temporary directory and opens a directory handle to it.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from creating the temporary directory or from
    /// opening it.
    pub(crate) fn new() -> std::io::Result<Self> {
        let tempdir = tempfile::tempdir()?;
        let root = tempdir.path().to_path_buf();
        let authority = cap_std::ambient_authority();
        let dir = cap_std::fs::Dir::open_ambient_dir(&root, authority)?;
        Ok(Self {
            dir,
            root,
            _tempdir: tempdir,
        })
    }

    /// Returns the path of the sandbox's root directory.
    pub(crate) fn root_path(&self) -> &std::path::Path {
        &self.root
    }

    /// Reads the UTF-8 file at `relative_path` inside the sandbox.
    ///
    /// # Errors
    ///
    /// Returns a message string when the file cannot be opened, its metadata
    /// cannot be read, it is larger than `SANDBOX_FILE_MAX_BYTES`, or its
    /// contents cannot be read as UTF-8.
    pub(crate) fn read_file(&self, relative_path: &str) -> Result<String, String> {
        use std::io::Read;

        // Compare in u64: narrowing the file length to usize could wrap on
        // 32-bit targets and let an oversized file pass the check.
        let limit = SANDBOX_FILE_MAX_BYTES as u64;

        let mut f = self
            .dir
            .open(relative_path)
            .map_err(|e| format!("open {relative_path}: {e}"))?;
        let meta = f
            .metadata()
            .map_err(|e| format!("metadata {relative_path}: {e}"))?;
        if meta.len() > limit {
            return Err(format!(
                "file too large: {} bytes (max {})",
                meta.len(),
                SANDBOX_FILE_MAX_BYTES
            ));
        }

        // The reported length is only a hint (the file may grow after the
        // check, or be a special file reporting 0), so the read itself is
        // bounded. One extra byte is allowed so an overrun can be detected.
        let mut buf = String::new();
        f.by_ref()
            .take(limit + 1)
            .read_to_string(&mut buf)
            .map_err(|e| format!("read {relative_path}: {e}"))?;
        if buf.len() > SANDBOX_FILE_MAX_BYTES {
            return Err(format!(
                "file too large: exceeds {} bytes",
                SANDBOX_FILE_MAX_BYTES
            ));
        }
        Ok(buf)
    }

    /// Creates (or truncates) the file at `relative_path` inside the sandbox
    /// and writes `content` to it.
    ///
    /// # Errors
    ///
    /// Returns a message string when `content` is longer than
    /// `SANDBOX_FILE_MAX_BYTES`, or when creating or writing the file fails.
    pub(crate) fn write_file(&self, relative_path: &str, content: &str) -> Result<(), String> {
        use std::io::Write;

        if content.len() > SANDBOX_FILE_MAX_BYTES {
            return Err(format!(
                "content too large: {} bytes (max {})",
                content.len(),
                SANDBOX_FILE_MAX_BYTES
            ));
        }
        let mut f = self
            .dir
            .create(relative_path)
            .map_err(|e| format!("create {relative_path}: {e}"))?;
        f.write_all(content.as_bytes())
            .map_err(|e| format!("write {relative_path}: {e}"))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A file written through the sandbox reads back unchanged.
    #[test]
    fn sandbox_round_trip() {
        let sandbox = SandboxedDir::new().unwrap();
        sandbox.write_file("hello.txt", "world").unwrap();
        let contents = sandbox.read_file("hello.txt").unwrap();
        assert_eq!(contents, "world");
    }

    /// Parent-directory traversal out of the sandbox yields an error.
    #[test]
    fn sandbox_rejects_escape() {
        let sandbox = SandboxedDir::new().unwrap();
        let outcome = sandbox.read_file("../../../etc/passwd");
        assert!(outcome.is_err(), "path traversal must be rejected");
    }

    /// Absolute paths yield an error.
    #[test]
    fn sandbox_rejects_absolute() {
        let sandbox = SandboxedDir::new().unwrap();
        let outcome = sandbox.read_file("/etc/passwd");
        assert!(outcome.is_err(), "absolute paths must be rejected");
    }

    /// Draining returns everything written so far and leaves the buffer empty.
    #[test]
    fn output_buffer_drains() {
        let out = OutputBuffer::new();
        for piece in ["hello ", "world"].iter() {
            out.write_str(piece);
        }
        let first = out.drain_to_string();
        assert_eq!(first, "hello world");
        let second = out.drain_to_string();
        assert_eq!(second, "");
    }

    /// An empty JSON object deserializes with no explicit method set.
    #[test]
    fn fetch_request_default_method_is_get() {
        let parsed = serde_json::from_str::<FetchRequest>("{}").unwrap();
        assert!(parsed.method.is_none());
    }
}