1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use agent_fetch::SafeClient;
5use wasmtime::{
6 Caller, Config, Engine, Linker, Module, Store, StoreLimits, StoreLimitsBuilder, Trap,
7};
8use wasmtime_wasi::WasiCtx;
9use wasmtime_wasi::p2::pipe::{MemoryInputPipe, MemoryOutputPipe};
10
11use crate::config::SandboxConfig;
12use crate::error::{Result, SandboxError};
13
/// Outcome of a single command executed inside the sandbox.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (for example
/// in tests) and `Default` so an "exit 0, no output" value is easy to build.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecResult {
    /// Exit status of the guest: `0` when the entry point returned normally,
    /// otherwise the code carried by the guest's WASI exit.
    pub exit_code: i32,
    /// Raw bytes the guest wrote to its standard output.
    pub stdout: Vec<u8>,
    /// Raw bytes the guest wrote to its standard error.
    pub stderr: Vec<u8>,
}
21
/// JSON request payload that the `__sandbox_fetch` host function decodes from
/// guest memory.
#[derive(serde::Deserialize)]
struct GuestFetchRequest {
    /// Target URL, forwarded unchanged to the fetch client.
    url: String,
    /// HTTP method; filled in by `default_method()` when the JSON omits it.
    #[serde(default = "default_method")]
    method: String,
    /// Request headers; an empty map when the JSON omits them.
    #[serde(default)]
    headers: HashMap<String, String>,
    /// Optional textual request body; converted to bytes before the request
    /// is issued.
    #[serde(default)]
    body: Option<String>,
}
33
/// Serde fallback for `GuestFetchRequest::method`: requests that do not name
/// a method are treated as `GET`.
fn default_method() -> String {
    String::from("GET")
}
37
/// JSON response payload made available to the guest after a
/// `__sandbox_fetch` call.
#[derive(serde::Serialize)]
struct GuestFetchResponse {
    /// HTTP status code; `0` when no HTTP response was obtained.
    status: u16,
    /// Response headers; empty when no HTTP response was obtained.
    headers: HashMap<String, String>,
    /// Response body as text (invalid UTF-8 is replaced lossily by the host).
    body: String,
    /// `true` when `status` lies in `200..300`.
    ok: bool,
    /// Failure description (fetch error or networking disabled); `None` when
    /// an HTTP response was received.
    error: Option<String>,
}
47
/// Per-execution data owned by the wasmtime `Store` and reachable from host
/// functions through `Caller::data` / `Caller::data_mut`.
struct SandboxState {
    /// WASI context handed to the WASI functions registered on the linker.
    wasi: wasmtime_wasi::p1::WasiP1Ctx,
    /// Resource limits installed on the store via `Store::limiter`.
    limits: StoreLimits,
    /// HTTP client used by `__sandbox_fetch`; `None` means networking is
    /// disabled for this execution.
    fetch_client: Option<Arc<SafeClient>>,
    /// Serialized JSON of the most recent fetch response, read back by the
    /// guest through the `__sandbox_fetch_response_*` host functions.
    fetch_response: Option<Vec<u8>>,
    /// Handle of the tokio runtime used to drive the async fetch to
    /// completion from the synchronous host function.
    tokio_handle: Option<tokio::runtime::Handle>,
}
56
/// Engine together with the toolbox module that was deserialized with it;
/// stored once in `MODULE_CACHE` and shared by every runtime instance.
struct CachedModule {
    /// Engine configured with fuel consumption enabled.
    engine: Engine,
    /// Toolbox module loaded from the embedded precompiled bytes.
    module: Module,
}
62
/// Process-wide, lazily initialized engine/module pair.
///
/// Holds either the loaded `CachedModule` or the error message produced by
/// the first (and only) initialization attempt; see `get_or_compile_module`.
static MODULE_CACHE: OnceLock<std::result::Result<CachedModule, String>> = OnceLock::new();
66
67fn get_or_compile_module() -> Result<(&'static Engine, &'static Module)> {
68 let cached = MODULE_CACHE.get_or_init(|| {
69 let precompiled_bytes = include_bytes!(env!("TOOLBOX_CWASM_PATH"));
70
71 if precompiled_bytes.is_empty() {
72 return Err("WASM toolbox not available".to_string());
73 }
74
75 let mut engine_config = Config::new();
77 engine_config.consume_fuel(true);
78
79 let engine =
80 Engine::new(&engine_config).map_err(|e| format!("engine creation failed: {e}"))?;
81
82 let module = unsafe { Module::deserialize(&engine, precompiled_bytes) }
85 .map_err(|e| format!("module deserialization failed: {e}"))?;
86
87 Ok(CachedModule { engine, module })
88 });
89
90 match cached {
91 Ok(c) => Ok((&c.engine, &c.module)),
92 Err(e) => Err(SandboxError::Other(e.clone())),
93 }
94}
95
/// Handle for running toolbox commands inside the WASI sandbox.
pub struct WasiRuntime {
    /// Shared engine obtained from the process-wide module cache.
    engine: &'static Engine,
    /// Shared toolbox module obtained from the process-wide module cache.
    module: &'static Module,
    /// Sandbox settings, shared with the blocking task spawned per execution.
    config: Arc<SandboxConfig>,
    /// Optional HTTP client exposed to the guest; `None` disables networking.
    fetch_client: Option<Arc<SafeClient>>,
}
103
104impl WasiRuntime {
105 pub fn new(config: SandboxConfig, fetch_client: Option<Arc<SafeClient>>) -> Result<Self> {
108 let (engine, module) = get_or_compile_module()?;
109
110 Ok(Self {
111 engine,
112 module,
113 config: Arc::new(config),
114 fetch_client,
115 })
116 }
117
118 pub async fn exec(&self, command: &str, args: &[String]) -> Result<ExecResult> {
120 let config = self.config.clone();
121 let engine = self.engine;
122 let module = self.module;
123 let command = command.to_string();
124 let args = args.to_vec();
125 let timeout = config.timeout;
126 let fetch_client = self.fetch_client.clone();
127 let tokio_handle = tokio::runtime::Handle::current();
128
129 let task = tokio::task::spawn_blocking(move || {
131 exec_sync(
132 engine,
133 module,
134 &config,
135 &command,
136 &args,
137 fetch_client,
138 tokio_handle,
139 )
140 });
141
142 match tokio::time::timeout(timeout, task).await {
143 Ok(Ok(result)) => result,
144 Ok(Err(e)) => Err(SandboxError::Other(format!("task join error: {}", e))),
145 Err(_) => Err(SandboxError::Timeout(timeout)),
146 }
147 }
148}
149
150fn read_guest_memory(caller: &mut Caller<'_, SandboxState>, ptr: i32, len: i32) -> Option<Vec<u8>> {
152 if ptr < 0 || len < 0 {
153 return None;
154 }
155 let memory = caller.get_export("memory")?.into_memory()?;
156 let data = memory.data(&*caller);
157 let start = ptr as usize;
158 let end = start.checked_add(len as usize)?;
159 if end > data.len() {
160 return None;
161 }
162 Some(data[start..end].to_vec())
163}
164
165fn write_guest_memory(caller: &mut Caller<'_, SandboxState>, ptr: i32, buf: &[u8]) -> bool {
167 if ptr < 0 {
168 return false;
169 }
170 let memory = match caller.get_export("memory") {
171 Some(ext) => match ext.into_memory() {
172 Some(m) => m,
173 None => return false,
174 },
175 None => return false,
176 };
177 let data = memory.data_mut(caller);
178 let start = ptr as usize;
179 let end = match start.checked_add(buf.len()) {
180 Some(e) => e,
181 None => return false,
182 };
183 if end > data.len() {
184 return false;
185 }
186 data[start..end].copy_from_slice(buf);
187 true
188}
189
190fn exec_sync(
191 engine: &Engine,
192 module: &Module,
193 config: &SandboxConfig,
194 command: &str,
195 args: &[String],
196 fetch_client: Option<Arc<SafeClient>>,
197 tokio_handle: tokio::runtime::Handle,
198) -> Result<ExecResult> {
199 let mut argv: Vec<String> = vec![command.to_string()];
201 argv.extend(args.iter().cloned());
202
203 let argv_refs: Vec<&str> = argv.iter().map(|s| s.as_str()).collect();
204
205 let stdout_pipe = MemoryOutputPipe::new(1024 * 1024); let stderr_pipe = MemoryOutputPipe::new(1024 * 1024);
208
209 let mut builder = WasiCtx::builder();
211 builder.args(&argv_refs);
212 builder.stdin(MemoryInputPipe::new(b"" as &[u8])); builder.stdout(stdout_pipe.clone());
214 builder.stderr(stderr_pipe.clone());
215
216 builder.env("TOOLBOX_CMD", command);
218
219 for (key, value) in &config.env_vars {
221 builder.env(key, value);
222 }
223
224 let work_dir = config.work_dir.canonicalize().map_err(|e| {
226 SandboxError::Io(std::io::Error::new(
227 std::io::ErrorKind::NotFound,
228 format!("work_dir '{}': {}", config.work_dir.display(), e),
229 ))
230 })?;
231
232 let dir = wasmtime_wasi::DirPerms::all();
233 let file = wasmtime_wasi::FilePerms::all();
234 builder.preopened_dir(&work_dir, "/work", dir, file)?;
235
236 for mount in &config.mounts {
238 let host = mount.host_path.canonicalize().map_err(|e| {
239 SandboxError::Io(std::io::Error::new(
240 std::io::ErrorKind::NotFound,
241 format!("mount '{}': {}", mount.host_path.display(), e),
242 ))
243 })?;
244
245 let (d, f) = if mount.writable {
246 (
247 wasmtime_wasi::DirPerms::all(),
248 wasmtime_wasi::FilePerms::all(),
249 )
250 } else {
251 (
252 wasmtime_wasi::DirPerms::READ,
253 wasmtime_wasi::FilePerms::READ,
254 )
255 };
256
257 builder.preopened_dir(&host, &mount.guest_path, d, f)?;
258 }
259
260 let wasi_p1 = builder.build_p1();
262
263 let limits = StoreLimitsBuilder::new()
265 .memory_size(config.memory_limit_bytes as usize)
266 .build();
267
268 let mut store = Store::new(
269 engine,
270 SandboxState {
271 wasi: wasi_p1,
272 limits,
273 fetch_client,
274 fetch_response: None,
275 tokio_handle: Some(tokio_handle),
276 },
277 );
278 store.limiter(|state| &mut state.limits);
279
280 store.set_fuel(config.fuel_limit)?;
282
283 let mut linker = Linker::new(engine);
285 wasmtime_wasi::p1::add_to_linker_sync(&mut linker, |state: &mut SandboxState| &mut state.wasi)?;
286
287 linker.func_wrap(
289 "sandbox",
290 "__sandbox_fetch",
291 |mut caller: Caller<'_, SandboxState>, req_ptr: i32, req_len: i32| -> i32 {
292 let req_bytes = match read_guest_memory(&mut caller, req_ptr, req_len) {
294 Some(b) => b,
295 None => return -1,
296 };
297
298 let guest_req: GuestFetchRequest = match serde_json::from_slice(&req_bytes) {
299 Ok(r) => r,
300 Err(_) => return -1,
301 };
302
303 let client = match caller.data().fetch_client.as_ref() {
304 Some(c) => c.clone(),
305 None => {
306 let resp = GuestFetchResponse {
308 status: 0,
309 headers: HashMap::new(),
310 body: String::new(),
311 ok: false,
312 error: Some("networking disabled: configure fetch_policy to enable".into()),
313 };
314 caller.data_mut().fetch_response = Some(serde_json::to_vec(&resp).unwrap());
315 return -2;
316 }
317 };
318
319 let handle = match caller.data().tokio_handle.as_ref() {
320 Some(h) => h.clone(),
321 None => return -1,
322 };
323
324 let fetch_req = agent_fetch::FetchRequest {
325 url: guest_req.url,
326 method: guest_req.method,
327 headers: guest_req.headers,
328 body: guest_req.body.map(|s| s.into_bytes()),
329 };
330
331 let result = std::thread::scope(|_| handle.block_on(client.fetch(fetch_req)));
333
334 let resp = match result {
335 Ok(r) => GuestFetchResponse {
336 status: r.status,
337 headers: r.headers,
338 body: String::from_utf8_lossy(&r.body).to_string(),
339 ok: (200..300).contains(&(r.status as u32)),
340 error: None,
341 },
342 Err(e) => GuestFetchResponse {
343 status: 0,
344 headers: HashMap::new(),
345 body: String::new(),
346 ok: false,
347 error: Some(e.to_string()),
348 },
349 };
350
351 caller.data_mut().fetch_response = Some(serde_json::to_vec(&resp).unwrap());
352 0
353 },
354 )?;
355
356 linker.func_wrap(
357 "sandbox",
358 "__sandbox_fetch_response_len",
359 |caller: Caller<'_, SandboxState>| -> i32 {
360 caller
361 .data()
362 .fetch_response
363 .as_ref()
364 .map(|r| r.len() as i32)
365 .unwrap_or(0)
366 },
367 )?;
368
369 linker.func_wrap(
370 "sandbox",
371 "__sandbox_fetch_response_read",
372 |mut caller: Caller<'_, SandboxState>, buf_ptr: i32, buf_len: i32| -> i32 {
373 if buf_ptr < 0 || buf_len < 0 {
374 return -1;
375 }
376 let resp = match caller.data().fetch_response.as_ref() {
377 Some(r) => r.clone(),
378 None => return -1,
379 };
380 let copy_len = std::cmp::min(resp.len(), buf_len as usize);
381 if write_guest_memory(&mut caller, buf_ptr, &resp[..copy_len]) {
382 copy_len as i32
383 } else {
384 -1
385 }
386 },
387 )?;
388
389 linker.module(&mut store, "", module)?;
390
391 let func = linker
393 .get_default(&mut store, "")?
394 .typed::<(), ()>(&store)?;
395
396 let exit_code = match func.call(&mut store, ()) {
397 Ok(()) => 0,
398 Err(e) => {
399 if let Some(exit) = e.downcast_ref::<wasmtime_wasi::I32Exit>() {
401 exit.0
402 } else if e.downcast_ref::<Trap>() == Some(&Trap::OutOfFuel) {
403 return Err(SandboxError::Timeout(config.timeout));
404 } else {
405 return Err(SandboxError::Runtime(e));
406 }
407 }
408 };
409
410 let stdout_bytes = stdout_pipe.contents().to_vec();
411 let stderr_bytes = stderr_pipe.contents().to_vec();
412
413 Ok(ExecResult {
414 exit_code,
415 stdout: stdout_bytes,
416 stderr: stderr_bytes,
417 })
418}