Skip to main content

agent_sandbox/runtime/
mod.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use agent_fetch::SafeClient;
5use wasmtime::{
6    Caller, Config, Engine, Linker, Module, Store, StoreLimits, StoreLimitsBuilder, Trap,
7};
8use wasmtime_wasi::WasiCtx;
9use wasmtime_wasi::p2::pipe::{MemoryInputPipe, MemoryOutputPipe};
10
11use crate::config::SandboxConfig;
12use crate::error::{Result, SandboxError};
13
/// Outcome of running a single command inside the sandbox.
///
/// The captured streams are raw bytes exactly as the guest wrote them; callers
/// that want text must decode them themselves.
#[derive(Debug, Clone)]
pub struct ExecResult {
    /// Exit status of the guest program.
    pub exit_code: i32,
    /// Bytes the guest wrote to standard output.
    pub stdout: Vec<u8>,
    /// Bytes the guest wrote to standard error.
    pub stderr: Vec<u8>,
}
21
22/// JSON request sent from WASM guest to host for fetch.
23#[derive(serde::Deserialize)]
24struct GuestFetchRequest {
25    url: String,
26    #[serde(default = "default_method")]
27    method: String,
28    #[serde(default)]
29    headers: HashMap<String, String>,
30    #[serde(default)]
31    body: Option<String>,
32}
33
/// HTTP method used when a guest fetch request does not specify one.
fn default_method() -> String {
    String::from("GET")
}
37
38/// JSON response sent from host back to WASM guest.
39#[derive(serde::Serialize)]
40struct GuestFetchResponse {
41    status: u16,
42    headers: HashMap<String, String>,
43    body: String,
44    ok: bool,
45    error: Option<String>,
46}
47
48/// Store data combining WASI context with resource limits and fetch state.
49struct SandboxState {
50    wasi: wasmtime_wasi::p1::WasiP1Ctx,
51    limits: StoreLimits,
52    fetch_client: Option<Arc<SafeClient>>,
53    fetch_response: Option<Vec<u8>>,
54    tokio_handle: Option<tokio::runtime::Handle>,
55}
56
57/// Cached WASM engine and compiled module shared across all Sandbox instances.
58struct CachedModule {
59    engine: Engine,
60    module: Module,
61}
62
63/// Global cache for the compiled WASM module.
64/// Compiling the toolbox WASM binary is expensive, so we do it once.
65static MODULE_CACHE: OnceLock<std::result::Result<CachedModule, String>> = OnceLock::new();
66
67fn get_or_compile_module() -> Result<(&'static Engine, &'static Module)> {
68    let cached = MODULE_CACHE.get_or_init(|| {
69        let precompiled_bytes = include_bytes!(env!("TOOLBOX_CWASM_PATH"));
70
71        if precompiled_bytes.is_empty() {
72            return Err("WASM toolbox not available".to_string());
73        }
74
75        // Engine config MUST match build.rs exactly
76        let mut engine_config = Config::new();
77        engine_config.consume_fuel(true);
78
79        let engine =
80            Engine::new(&engine_config).map_err(|e| format!("engine creation failed: {e}"))?;
81
82        // SAFETY: The precompiled bytes come from our own build.rs via
83        // Engine::precompile_module() with the same engine config and wasmtime version.
84        let module = unsafe { Module::deserialize(&engine, precompiled_bytes) }
85            .map_err(|e| format!("module deserialization failed: {e}"))?;
86
87        Ok(CachedModule { engine, module })
88    });
89
90    match cached {
91        Ok(c) => Ok((&c.engine, &c.module)),
92        Err(e) => Err(SandboxError::Other(e.clone())),
93    }
94}
95
96/// The WASI runtime that manages Wasmtime engine and module compilation.
97pub struct WasiRuntime {
98    engine: &'static Engine,
99    module: &'static Module,
100    config: Arc<SandboxConfig>,
101    fetch_client: Option<Arc<SafeClient>>,
102}
103
104impl WasiRuntime {
105    /// Create a new WASI runtime with the given sandbox config.
106    /// The toolbox WASM binary is compiled once and cached globally.
107    pub fn new(config: SandboxConfig, fetch_client: Option<Arc<SafeClient>>) -> Result<Self> {
108        let (engine, module) = get_or_compile_module()?;
109
110        Ok(Self {
111            engine,
112            module,
113            config: Arc::new(config),
114            fetch_client,
115        })
116    }
117
118    /// Execute a command inside the WASM sandbox.
119    pub async fn exec(&self, command: &str, args: &[String]) -> Result<ExecResult> {
120        let config = self.config.clone();
121        let engine = self.engine;
122        let module = self.module;
123        let command = command.to_string();
124        let args = args.to_vec();
125        let timeout = config.timeout;
126        let fetch_client = self.fetch_client.clone();
127        let tokio_handle = tokio::runtime::Handle::current();
128
129        // Run in blocking thread since Wasmtime is synchronous, with a wall-clock timeout
130        let task = tokio::task::spawn_blocking(move || {
131            exec_sync(
132                engine,
133                module,
134                &config,
135                &command,
136                &args,
137                fetch_client,
138                tokio_handle,
139            )
140        });
141
142        match tokio::time::timeout(timeout, task).await {
143            Ok(Ok(result)) => result,
144            Ok(Err(e)) => Err(SandboxError::Other(format!("task join error: {}", e))),
145            Err(_) => Err(SandboxError::Timeout(timeout)),
146        }
147    }
148}
149
150/// Read a byte slice from WASM guest memory.
151fn read_guest_memory(caller: &mut Caller<'_, SandboxState>, ptr: i32, len: i32) -> Option<Vec<u8>> {
152    if ptr < 0 || len < 0 {
153        return None;
154    }
155    let memory = caller.get_export("memory")?.into_memory()?;
156    let data = memory.data(&*caller);
157    let start = ptr as usize;
158    let end = start.checked_add(len as usize)?;
159    if end > data.len() {
160        return None;
161    }
162    Some(data[start..end].to_vec())
163}
164
165/// Write a byte slice into WASM guest memory.
166fn write_guest_memory(caller: &mut Caller<'_, SandboxState>, ptr: i32, buf: &[u8]) -> bool {
167    if ptr < 0 {
168        return false;
169    }
170    let memory = match caller.get_export("memory") {
171        Some(ext) => match ext.into_memory() {
172            Some(m) => m,
173            None => return false,
174        },
175        None => return false,
176    };
177    let data = memory.data_mut(caller);
178    let start = ptr as usize;
179    let end = match start.checked_add(buf.len()) {
180        Some(e) => e,
181        None => return false,
182    };
183    if end > data.len() {
184        return false;
185    }
186    data[start..end].copy_from_slice(buf);
187    true
188}
189
190fn exec_sync(
191    engine: &Engine,
192    module: &Module,
193    config: &SandboxConfig,
194    command: &str,
195    args: &[String],
196    fetch_client: Option<Arc<SafeClient>>,
197    tokio_handle: tokio::runtime::Handle,
198) -> Result<ExecResult> {
199    // Build argv: [command, ...args]
200    let mut argv: Vec<String> = vec![command.to_string()];
201    argv.extend(args.iter().cloned());
202
203    let argv_refs: Vec<&str> = argv.iter().map(|s| s.as_str()).collect();
204
205    // Set up stdout/stderr capture via MemoryOutputPipe
206    let stdout_pipe = MemoryOutputPipe::new(1024 * 1024); // 1MB capacity
207    let stderr_pipe = MemoryOutputPipe::new(1024 * 1024);
208
209    // Build WASI context using WasiCtx::builder()
210    let mut builder = WasiCtx::builder();
211    builder.args(&argv_refs);
212    builder.stdin(MemoryInputPipe::new(b"" as &[u8])); // Empty stdin — prevents blocking on host stdin
213    builder.stdout(stdout_pipe.clone());
214    builder.stderr(stderr_pipe.clone());
215
216    // Set TOOLBOX_CMD env var for BusyBox-style dispatch
217    builder.env("TOOLBOX_CMD", command);
218
219    // Set user-configured env vars
220    for (key, value) in &config.env_vars {
221        builder.env(key, value);
222    }
223
224    // Mount work directory
225    let work_dir = config.work_dir.canonicalize().map_err(|e| {
226        SandboxError::Io(std::io::Error::new(
227            std::io::ErrorKind::NotFound,
228            format!("work_dir '{}': {}", config.work_dir.display(), e),
229        ))
230    })?;
231
232    let dir = wasmtime_wasi::DirPerms::all();
233    let file = wasmtime_wasi::FilePerms::all();
234    builder.preopened_dir(&work_dir, "/work", dir, file)?;
235
236    // Mount additional directories
237    for mount in &config.mounts {
238        let host = mount.host_path.canonicalize().map_err(|e| {
239            SandboxError::Io(std::io::Error::new(
240                std::io::ErrorKind::NotFound,
241                format!("mount '{}': {}", mount.host_path.display(), e),
242            ))
243        })?;
244
245        let (d, f) = if mount.writable {
246            (
247                wasmtime_wasi::DirPerms::all(),
248                wasmtime_wasi::FilePerms::all(),
249            )
250        } else {
251            (
252                wasmtime_wasi::DirPerms::READ,
253                wasmtime_wasi::FilePerms::READ,
254            )
255        };
256
257        builder.preopened_dir(&host, &mount.guest_path, d, f)?;
258    }
259
260    // Build the WASIp1 context
261    let wasi_p1 = builder.build_p1();
262
263    // Build memory limiter
264    let limits = StoreLimitsBuilder::new()
265        .memory_size(config.memory_limit_bytes as usize)
266        .build();
267
268    let mut store = Store::new(
269        engine,
270        SandboxState {
271            wasi: wasi_p1,
272            limits,
273            fetch_client,
274            fetch_response: None,
275            tokio_handle: Some(tokio_handle),
276        },
277    );
278    store.limiter(|state| &mut state.limits);
279
280    // Set fuel limit
281    store.set_fuel(config.fuel_limit)?;
282
283    // Link WASI p1 and instantiate
284    let mut linker = Linker::new(engine);
285    wasmtime_wasi::p1::add_to_linker_sync(&mut linker, |state: &mut SandboxState| &mut state.wasi)?;
286
287    // Link sandbox host functions for fetch bridge
288    linker.func_wrap(
289        "sandbox",
290        "__sandbox_fetch",
291        |mut caller: Caller<'_, SandboxState>, req_ptr: i32, req_len: i32| -> i32 {
292            // Read request JSON from guest memory
293            let req_bytes = match read_guest_memory(&mut caller, req_ptr, req_len) {
294                Some(b) => b,
295                None => return -1,
296            };
297
298            let guest_req: GuestFetchRequest = match serde_json::from_slice(&req_bytes) {
299                Ok(r) => r,
300                Err(_) => return -1,
301            };
302
303            let client = match caller.data().fetch_client.as_ref() {
304                Some(c) => c.clone(),
305                None => {
306                    // Networking disabled — store error response
307                    let resp = GuestFetchResponse {
308                        status: 0,
309                        headers: HashMap::new(),
310                        body: String::new(),
311                        ok: false,
312                        error: Some("networking disabled: configure fetch_policy to enable".into()),
313                    };
314                    caller.data_mut().fetch_response = Some(serde_json::to_vec(&resp).unwrap());
315                    return -2;
316                }
317            };
318
319            let handle = match caller.data().tokio_handle.as_ref() {
320                Some(h) => h.clone(),
321                None => return -1,
322            };
323
324            let fetch_req = agent_fetch::FetchRequest {
325                url: guest_req.url,
326                method: guest_req.method,
327                headers: guest_req.headers,
328                body: guest_req.body.map(|s| s.into_bytes()),
329            };
330
331            // Bridge async fetch to sync context via the tokio handle
332            let result = std::thread::scope(|_| handle.block_on(client.fetch(fetch_req)));
333
334            let resp = match result {
335                Ok(r) => GuestFetchResponse {
336                    status: r.status,
337                    headers: r.headers,
338                    body: String::from_utf8_lossy(&r.body).to_string(),
339                    ok: (200..300).contains(&(r.status as u32)),
340                    error: None,
341                },
342                Err(e) => GuestFetchResponse {
343                    status: 0,
344                    headers: HashMap::new(),
345                    body: String::new(),
346                    ok: false,
347                    error: Some(e.to_string()),
348                },
349            };
350
351            caller.data_mut().fetch_response = Some(serde_json::to_vec(&resp).unwrap());
352            0
353        },
354    )?;
355
356    linker.func_wrap(
357        "sandbox",
358        "__sandbox_fetch_response_len",
359        |caller: Caller<'_, SandboxState>| -> i32 {
360            caller
361                .data()
362                .fetch_response
363                .as_ref()
364                .map(|r| r.len() as i32)
365                .unwrap_or(0)
366        },
367    )?;
368
369    linker.func_wrap(
370        "sandbox",
371        "__sandbox_fetch_response_read",
372        |mut caller: Caller<'_, SandboxState>, buf_ptr: i32, buf_len: i32| -> i32 {
373            if buf_ptr < 0 || buf_len < 0 {
374                return -1;
375            }
376            let resp = match caller.data().fetch_response.as_ref() {
377                Some(r) => r.clone(),
378                None => return -1,
379            };
380            let copy_len = std::cmp::min(resp.len(), buf_len as usize);
381            if write_guest_memory(&mut caller, buf_ptr, &resp[..copy_len]) {
382                copy_len as i32
383            } else {
384                -1
385            }
386        },
387    )?;
388
389    linker.module(&mut store, "", module)?;
390
391    // Get the default function (_start) and call it
392    let func = linker
393        .get_default(&mut store, "")?
394        .typed::<(), ()>(&store)?;
395
396    let exit_code = match func.call(&mut store, ()) {
397        Ok(()) => 0,
398        Err(e) => {
399            // Check if it's a normal process exit
400            if let Some(exit) = e.downcast_ref::<wasmtime_wasi::I32Exit>() {
401                exit.0
402            } else if e.downcast_ref::<Trap>() == Some(&Trap::OutOfFuel) {
403                return Err(SandboxError::Timeout(config.timeout));
404            } else {
405                return Err(SandboxError::Runtime(e));
406            }
407        }
408    };
409
410    let stdout_bytes = stdout_pipe.contents().to_vec();
411    let stderr_bytes = stderr_pipe.contents().to_vec();
412
413    Ok(ExecResult {
414        exit_code,
415        stdout: stdout_bytes,
416        stderr: stderr_bytes,
417    })
418}