Skip to main content

rustant_tools/sandbox/
host.rs

1//! Host state and function registration for WASM sandbox guests.
2
3use super::config::Capability;
4use tracing::debug;
5
6// ---------------------------------------------------------------------------
7// HostState
8// ---------------------------------------------------------------------------
9
10/// Mutable state shared between the host and WASM guest via host functions.
11///
12/// Each sandbox execution gets its own `HostState` that accumulates output
13/// written by the guest and provides the input buffer for reading.
14pub struct HostState {
15    /// Bytes written by the guest to stdout.
16    pub stdout: Vec<u8>,
17    /// Bytes written by the guest to stderr.
18    pub stderr: Vec<u8>,
19    /// Bytes written by the guest to the structured output channel.
20    pub output: Vec<u8>,
21    /// Input buffer supplied by the host before execution.
22    pub input: Vec<u8>,
23    /// Current read position within `input`.
24    pub input_pos: usize,
25    /// Capabilities granted to this sandbox execution.
26    pub capabilities: Vec<Capability>,
27    /// Peak linear memory usage observed (in bytes).
28    pub memory_peak: usize,
29}
30
31impl HostState {
32    /// Create a new `HostState` with the given input buffer and capabilities.
33    pub fn new(input: Vec<u8>, capabilities: Vec<Capability>) -> Self {
34        Self {
35            stdout: Vec::new(),
36            stderr: Vec::new(),
37            output: Vec::new(),
38            input,
39            input_pos: 0,
40            capabilities,
41            memory_peak: 0,
42        }
43    }
44
45    /// Check whether this host state includes the given capability.
46    ///
47    /// For simple capabilities like [`Capability::Stdout`] and
48    /// [`Capability::Stderr`], an exact variant match is performed. For
49    /// path-based capabilities like [`Capability::FileRead`], the check
50    /// succeeds if any allowed path is a prefix of (or equal to) the
51    /// requested path.
52    pub fn has_capability(&self, cap: &Capability) -> bool {
53        self.capabilities.iter().any(|c| match (c, cap) {
54            (Capability::Stdout, Capability::Stdout) => true,
55            (Capability::Stderr, Capability::Stderr) => true,
56            (Capability::FileRead(allowed), Capability::FileRead(requested)) => requested
57                .iter()
58                .all(|req| allowed.iter().any(|a| req.starts_with(a))),
59            (Capability::FileWrite(allowed), Capability::FileWrite(requested)) => requested
60                .iter()
61                .all(|req| allowed.iter().any(|a| req.starts_with(a))),
62            (Capability::NetworkAccess(allowed), Capability::NetworkAccess(requested)) => {
63                requested.iter().all(|req| allowed.contains(req))
64            }
65            (Capability::EnvironmentRead(allowed), Capability::EnvironmentRead(requested)) => {
66                requested.iter().all(|req| allowed.contains(req))
67            }
68            _ => false,
69        })
70    }
71
72    /// Track peak memory usage. If `current_bytes` exceeds the previous peak
73    /// the peak is updated.
74    pub fn track_memory(&mut self, current_bytes: usize) {
75        if current_bytes > self.memory_peak {
76            self.memory_peak = current_bytes;
77        }
78    }
79}
80
81// ---------------------------------------------------------------------------
82// Helper functions
83// ---------------------------------------------------------------------------
84
85/// Read a UTF-8 string from WASM linear memory at `[ptr .. ptr+len)`.
86///
87/// Returns `None` if the bounds are invalid or the bytes are not valid UTF-8.
88fn read_wasm_string(
89    memory: &wasmi::Memory,
90    store: &impl wasmi::AsContext,
91    ptr: i32,
92    len: i32,
93) -> Option<String> {
94    let bytes = read_wasm_bytes(memory, store, ptr, len)?;
95    String::from_utf8(bytes).ok()
96}
97
98/// Read raw bytes from WASM linear memory at `[ptr .. ptr+len)`.
99///
100/// Returns `None` if the requested range falls outside the memory bounds.
101fn read_wasm_bytes(
102    memory: &wasmi::Memory,
103    store: &impl wasmi::AsContext,
104    ptr: i32,
105    len: i32,
106) -> Option<Vec<u8>> {
107    if ptr < 0 || len < 0 {
108        return None;
109    }
110    let start = ptr as usize;
111    let size = len as usize;
112    let data = memory.data(store);
113    if start.checked_add(size)? > data.len() {
114        return None;
115    }
116    Some(data[start..start + size].to_vec())
117}
118
119// ---------------------------------------------------------------------------
120// Host function registration
121// ---------------------------------------------------------------------------
122
123/// Register the standard set of host functions in the `"env"` namespace.
124///
125/// The registered functions allow a WASM guest to log messages, write to
126/// stdout/stderr/output, and read from the input buffer provided by the host.
127///
128/// # Host functions
129///
130/// | Name                | Signature                            | Description                          |
131/// |---------------------|--------------------------------------|--------------------------------------|
132/// | `host_log`          | `(ptr: i32, len: i32)`               | Log a UTF-8 message via `tracing`    |
133/// | `host_write_stdout` | `(ptr: i32, len: i32)`               | Append bytes to `HostState::stdout`  |
134/// | `host_write_stderr` | `(ptr: i32, len: i32)`               | Append bytes to `HostState::stderr`  |
135/// | `host_write_output` | `(ptr: i32, len: i32)`               | Append bytes to `HostState::output`  |
136/// | `host_read_input`   | `(buf_ptr: i32, buf_len: i32) -> i32`| Copy input bytes into guest memory   |
137/// | `host_get_input_len`| `() -> i32`                          | Return total input buffer length     |
138pub fn register_host_functions(linker: &mut wasmi::Linker<HostState>) -> Result<(), wasmi::Error> {
139    // -- host_log(ptr, len) ---------------------------------------------------
140    linker.func_wrap(
141        "env",
142        "host_log",
143        |caller: wasmi::Caller<'_, HostState>, ptr: i32, len: i32| {
144            let Some(memory) = caller.get_export("memory").and_then(|e| e.into_memory()) else {
145                return;
146            };
147            let mem_size = memory.data(&caller).len();
148            caller.data().track_memory_peek(mem_size);
149
150            if let Some(msg) = read_wasm_string(&memory, &caller, ptr, len) {
151                debug!(target: "sandbox::guest", "{}", msg);
152            }
153        },
154    )?;
155
156    // -- host_write_stdout(ptr, len) ------------------------------------------
157    linker.func_wrap(
158        "env",
159        "host_write_stdout",
160        |mut caller: wasmi::Caller<'_, HostState>, ptr: i32, len: i32| {
161            let Some(memory) = caller.get_export("memory").and_then(|e| e.into_memory()) else {
162                return;
163            };
164            let mem_size = memory.data(&caller).len();
165
166            // Capability check — read bytes first, then mutate state.
167            let has_cap = caller.data().has_capability(&Capability::Stdout);
168            if !has_cap {
169                return;
170            }
171
172            let bytes = match read_wasm_bytes(&memory, &caller, ptr, len) {
173                Some(b) => b,
174                None => return,
175            };
176
177            let host = caller.data_mut();
178            host.track_memory(mem_size);
179            host.stdout.extend_from_slice(&bytes);
180        },
181    )?;
182
183    // -- host_write_stderr(ptr, len) ------------------------------------------
184    linker.func_wrap(
185        "env",
186        "host_write_stderr",
187        |mut caller: wasmi::Caller<'_, HostState>, ptr: i32, len: i32| {
188            let Some(memory) = caller.get_export("memory").and_then(|e| e.into_memory()) else {
189                return;
190            };
191            let mem_size = memory.data(&caller).len();
192
193            let has_cap = caller.data().has_capability(&Capability::Stderr);
194            if !has_cap {
195                return;
196            }
197
198            let bytes = match read_wasm_bytes(&memory, &caller, ptr, len) {
199                Some(b) => b,
200                None => return,
201            };
202
203            let host = caller.data_mut();
204            host.track_memory(mem_size);
205            host.stderr.extend_from_slice(&bytes);
206        },
207    )?;
208
209    // -- host_write_output(ptr, len) ------------------------------------------
210    linker.func_wrap(
211        "env",
212        "host_write_output",
213        |mut caller: wasmi::Caller<'_, HostState>, ptr: i32, len: i32| {
214            let Some(memory) = caller.get_export("memory").and_then(|e| e.into_memory()) else {
215                return;
216            };
217            let mem_size = memory.data(&caller).len();
218
219            let bytes = match read_wasm_bytes(&memory, &caller, ptr, len) {
220                Some(b) => b,
221                None => return,
222            };
223
224            let host = caller.data_mut();
225            host.track_memory(mem_size);
226            host.output.extend_from_slice(&bytes);
227        },
228    )?;
229
230    // -- host_read_input(buf_ptr, buf_len) -> i32 -----------------------------
231    linker.func_wrap(
232        "env",
233        "host_read_input",
234        |mut caller: wasmi::Caller<'_, HostState>, buf_ptr: i32, buf_len: i32| -> i32 {
235            let Some(memory) = caller.get_export("memory").and_then(|e| e.into_memory()) else {
236                return 0;
237            };
238
239            if buf_ptr < 0 || buf_len < 0 {
240                return 0;
241            }
242
243            let dst_start = buf_ptr as usize;
244            let dst_cap = buf_len as usize;
245
246            // Determine how many bytes remain in the input.
247            let input_pos = caller.data().input_pos;
248            let remaining = caller.data().input.len().saturating_sub(input_pos);
249            let to_copy = remaining.min(dst_cap);
250
251            if to_copy == 0 {
252                return 0;
253            }
254
255            // Validate destination bounds.
256            let mem_size = memory.data(&caller).len();
257            if dst_start.saturating_add(to_copy) > mem_size {
258                return 0;
259            }
260
261            // Copy the input slice into a temporary buffer so we can release
262            // the shared reference before borrowing mutably.
263            let src_bytes: Vec<u8> = caller.data().input[input_pos..input_pos + to_copy].to_vec();
264
265            // Write into WASM memory.
266            let data = memory.data_mut(&mut caller);
267            data[dst_start..dst_start + to_copy].copy_from_slice(&src_bytes);
268
269            // Advance input position and track memory.
270            let host = caller.data_mut();
271            host.input_pos += to_copy;
272            host.track_memory(mem_size);
273
274            to_copy as i32
275        },
276    )?;
277
278    // -- host_get_input_len() -> i32 ------------------------------------------
279    linker.func_wrap(
280        "env",
281        "host_get_input_len",
282        |caller: wasmi::Caller<'_, HostState>| -> i32 { caller.data().input.len() as i32 },
283    )?;
284
285    Ok(())
286}
287
288// ---------------------------------------------------------------------------
289// Private helpers used inside closures that only have shared access
290// ---------------------------------------------------------------------------
291
292impl HostState {
293    /// Non-mutating peek at memory size for use inside closures that hold a
294    /// shared `Caller` reference. The actual peak update happens later via
295    /// [`track_memory`].
296    fn track_memory_peek(&self, _current_bytes: usize) {
297        // Intentional no-op: the host_log function only has a shared reference
298        // so it cannot update peak. The next mutable host call will record it.
299    }
300}
301
302// ---------------------------------------------------------------------------
303// Tests
304// ---------------------------------------------------------------------------
305
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // -- Fixtures -------------------------------------------------------------

    /// Build a `HostState` with an empty input buffer and the given grants.
    fn state_with(grants: Vec<Capability>) -> HostState {
        HostState::new(Vec::new(), grants)
    }

    /// Build a `FileRead` request for a single path.
    fn read_request(path: &str) -> Capability {
        Capability::FileRead(vec![PathBuf::from(path)])
    }

    // -- HostState creation ---------------------------------------------------

    /// A fresh state has empty sinks, an untouched cursor, and keeps its
    /// input and capabilities as supplied.
    #[test]
    fn test_host_state_new() {
        let payload = b"hello world".to_vec();
        let grants = vec![Capability::Stdout, Capability::Stderr];
        let state = HostState::new(payload.clone(), grants.clone());

        for sink in [&state.stdout, &state.stderr, &state.output] {
            assert!(sink.is_empty());
        }
        assert_eq!(state.input, payload);
        assert_eq!(state.input_pos, 0);
        assert_eq!(state.capabilities, grants);
        assert_eq!(state.memory_peak, 0);
    }

    // -- Capability checks ----------------------------------------------------

    /// Granting stdout does not imply stderr.
    #[test]
    fn test_host_state_has_stdout_capability() {
        let state = state_with(vec![Capability::Stdout]);

        assert!(state.has_capability(&Capability::Stdout));
        assert!(!state.has_capability(&Capability::Stderr));
    }

    /// Granting stderr does not imply stdout.
    #[test]
    fn test_host_state_has_stderr_capability() {
        let state = state_with(vec![Capability::Stderr]);

        assert!(state.has_capability(&Capability::Stderr));
        assert!(!state.has_capability(&Capability::Stdout));
    }

    /// With no grants every request is refused.
    #[test]
    fn test_host_state_no_capability() {
        let state = state_with(Vec::new());

        let requests = [
            Capability::Stdout,
            Capability::Stderr,
            read_request("/tmp"),
        ];
        for request in &requests {
            assert!(!state.has_capability(request));
        }
    }

    // -- Memory tracking ------------------------------------------------------

    /// Increasing samples raise the recorded peak.
    #[test]
    fn test_host_state_track_memory() {
        let mut state = state_with(Vec::new());
        assert_eq!(state.memory_peak, 0);

        for sample in [1024usize, 4096] {
            state.track_memory(sample);
            assert_eq!(state.memory_peak, sample);
        }
    }

    /// Smaller or equal samples never lower the recorded peak.
    #[test]
    fn test_host_state_track_memory_no_decrease() {
        let mut state = state_with(Vec::new());

        // (sample, expected peak afterwards): first sets the peak, then a
        // smaller and an equal sample leave it untouched.
        let steps = [(8192usize, 8192usize), (4096, 8192), (8192, 8192)];
        for (sample, expected_peak) in steps {
            state.track_memory(sample);
            assert_eq!(state.memory_peak, expected_peak);
        }
    }

    // -- FileRead capability path matching ------------------------------------

    /// Paths at or below an allowed prefix are accepted; everything else,
    /// including a different capability variant, is refused.
    #[test]
    fn test_host_state_has_file_read_capability() {
        let allowed_roots = vec![PathBuf::from("/tmp"), PathBuf::from("/home/user/data")];
        let state = state_with(vec![Capability::FileRead(allowed_roots)]);

        // Exact prefix, a nested path, and a path under the second prefix.
        let accepted = ["/tmp", "/tmp/foo/bar.txt", "/home/user/data/report.csv"];
        for path in accepted {
            assert!(state.has_capability(&read_request(path)));
        }

        // Path not covered by any prefix.
        assert!(!state.has_capability(&read_request("/etc/passwd")));

        // A different capability variant should not match.
        let write_request = Capability::FileWrite(vec![PathBuf::from("/tmp")]);
        assert!(!state.has_capability(&write_request));
    }

    // -- Registration smoke test ----------------------------------------------

    /// Registering into a fresh linker succeeds.
    #[test]
    fn test_register_host_functions() {
        let engine = wasmi::Engine::default();
        let mut linker = wasmi::Linker::<HostState>::new(&engine);

        // Should not panic.
        register_host_functions(&mut linker).expect("registration should succeed");
    }
}