advanced_syscalls/
advanced_syscalls.rs

1use libriscv::{
2    error_handler, stdout_handler, syscall, syscall_registry, ErrorContext, Machine, Options,
3    Registers, Result, StdoutContext, SyscallContext, SyscallResult,
4};
5use std::ffi::CStr;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Integer type used for guest addresses and counts in the guest-shared structs below.
type GuestAddr = u64;
/// Guest address of the callback registered through syscall 502.
/// `0` means no callback was registered; `main` checks for this after the first run.
static HOST_FN_ADDR: AtomicU64 = AtomicU64::new(0);
11
/// Guest-side table of C-string pointers consumed by syscall 500.
///
/// Copied straight out of guest memory with `read_pod`, so the field order and
/// types form part of the guest ABI. NOTE(review): presumably this mirrors a
/// struct declared in the guest program — keep the two in sync.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Strings {
    /// Number of valid entries in `strings`; the host never visits more than 32.
    count: GuestAddr,
    /// Guest addresses of the strings (each read with a 256-byte limit).
    strings: [GuestAddr; 32],
}
18
/// Guest-side in/out descriptor for syscall 501.
///
/// Read from guest memory with `read_pod` and written back whole with
/// `write_pod`, so the field order and types form part of the guest ABI.
/// NOTE(review): presumably this mirrors a struct declared in the guest
/// program — keep the two in sync.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Buffers {
    /// Output: length (excluding the NUL) of the text the host stored in `buffer`.
    count: GuestAddr,
    /// Output: inline buffer that receives a NUL-terminated string from the host.
    buffer: [u8; 256],
    /// Input: capacity in bytes of the buffer at `another_buffer_address` (0 = none).
    /// Output: length of the string written there, when the host wrote one.
    another_count: GuestAddr,
    /// Input: guest address of a second, caller-owned buffer.
    another_buffer_address: GuestAddr,
}
27
28#[error_handler]
29fn error_callback(ctx: &mut ErrorContext) {
30    let data = ctx.data();
31    let text = ctx
32        .message()
33        .map(CStr::to_string_lossy)
34        .unwrap_or_else(|| "<null>".into());
35    eprintln!("Error: {} (data: 0x{:X})", text, data);
36}
37
38#[stdout_handler]
39fn stdout_callback(ctx: &mut StdoutContext) {
40    let slice = ctx.data();
41    if slice.is_empty() {
42        return;
43    }
44    let text = String::from_utf8_lossy(slice);
45    print!("[libriscv] stdout: {}", text);
46}
47
/// Copies `text` into `dst` as a NUL-terminated C string, truncating if needed.
///
/// At most `dst.len() - 1` bytes of `text` are copied, and a `0` byte is always
/// written immediately after them. Returns the number of text bytes copied
/// (excluding the terminator). An empty `dst` is left untouched and yields `0`.
fn write_c_string(dst: &mut [u8], text: &[u8]) -> usize {
    // The last byte of `dst` is reserved for the terminator.
    let capacity = match dst.len().checked_sub(1) {
        Some(capacity) => capacity,
        None => return 0,
    };
    let copied = text.len().min(capacity);
    let (payload, tail) = dst.split_at_mut(copied);
    payload.copy_from_slice(&text[..copied]);
    // `tail` is non-empty because `copied <= dst.len() - 1`.
    tail[0] = 0;
    copied
}
58
59#[syscall_registry]
60mod host_syscalls {
61    use super::*;
62
63    #[syscall(id = 500)]
64    fn host_function_500(ctx: &mut SyscallContext) -> SyscallResult<()> {
65        println!("Hello from host function 0!");
66        let addr = {
67            let regs = ctx.registers()?;
68            regs.x(10)?
69        };
70        let strings: Strings = ctx.read_pod(addr)?;
71        let count = (strings.count as usize).min(strings.strings.len());
72        for i in 0..count {
73            let slice = match ctx.memstring(strings.strings[i], 256) {
74                Ok(slice) => slice,
75                Err(_) => continue,
76            };
77            println!("  {}", String::from_utf8_lossy(slice));
78        }
79        Ok(())
80    }
81
82    #[syscall(id = 501)]
83    fn host_function_501(ctx: &mut SyscallContext) -> SyscallResult<()> {
84        println!("Hello from host function 501!");
85        let addr = ctx.registers()?.x(10)?;
86        let mut buf: Buffers = ctx.read_pod(addr)?;
87        let len = write_c_string(&mut buf.buffer, b"Hello from host function 501!");
88        buf.count = len as GuestAddr;
89
90        let another_len = buf.another_count as usize;
91        if another_len == 0 {
92            ctx.write_pod(addr, &buf)?;
93            return Ok(());
94        }
95        if another_len > u32::MAX as usize {
96            eprintln!("host_function_501: another buffer too large");
97            ctx.write_pod(addr, &buf)?;
98            return Ok(());
99        }
100        let another_slice = ctx.writable_memview(buf.another_buffer_address, another_len)?;
101        let second = b"Another buffer from host function 501!";
102        if second.len() >= another_len {
103            eprintln!("host_function_501: another buffer too small");
104            ctx.write_pod(addr, &buf)?;
105            return Ok(());
106        }
107        let len = write_c_string(another_slice, second);
108        buf.another_count = len as GuestAddr;
109        ctx.write_pod(addr, &buf)?;
110        Ok(())
111    }
112
113    #[syscall(id = 502)]
114    fn host_function_502(ctx: &mut SyscallContext) -> SyscallResult<()> {
115        let addr = ctx.registers()?.x(10)?;
116        HOST_FN_ADDR.store(addr, Ordering::Relaxed);
117        Ok(())
118    }
119
120    #[syscall(id = 503)]
121    fn host_function_503(ctx: &mut SyscallContext) -> SyscallResult<()> {
122        let mut regs = ctx.registers()?;
123        let mut x = regs.f32(10)?;
124        let mut y = regs.f32(11)?;
125        let mut z = regs.f32(12)?;
126
127        let len = (x * x + y * y + z * z).sqrt();
128        if len > 0.0 {
129            let inv = 1.0 / len;
130            x *= inv;
131            y *= inv;
132            z *= inv;
133        }
134
135        regs.set_f32(10, x)?;
136        regs.set_f32(11, y)?;
137        regs.set_f32(12, z)?;
138        Ok(())
139    }
140}
141
142fn reserve_stack(regs: &mut Registers<'_>, size: usize) -> Result<u64> {
143    let sp = regs.x(2)?;
144    let new_sp = sp.wrapping_sub(size as u64) & !0xFu64;
145    regs.set_x(2, new_sp)?;
146    Ok(new_sp)
147}
148
149fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
150    let args: Vec<String> = std::env::args().collect();
151    if args.len() < 2 {
152        eprintln!("Usage: {} [program file]", args[0]);
153        std::process::exit(1);
154    }
155
156    let elf = std::fs::read(&args[1])?;
157
158    let registry = host_syscalls::registry()?;
159
160    let options = Options::builder()
161        .stdout_handler(stdout_callback_handler())
162        .error_handler(error_callback_handler())
163        .args(["program"])
164        .build()?;
165
166    let mut machine = Machine::new(elf, options, &registry)?;
167    machine.run(u64::MAX)?;
168
169    let addr = HOST_FN_ADDR.load(Ordering::Relaxed);
170    if addr != 0 {
171        machine.setup_vmcall(addr)?;
172        let msg = b"Hello from a callback function!\0";
173        let str_addr = {
174            let mut regs = machine.registers()?;
175            reserve_stack(&mut regs, msg.len())?
176        };
177        machine.copy_to_guest(str_addr, msg)?;
178        {
179            let mut regs = machine.registers()?;
180            regs.set_x(10, str_addr)?;
181        }
182        machine.run(u64::MAX)?;
183    } else {
184        eprintln!("Host function 502 was not called");
185    }
186
187    println!("Done");
188    Ok(())
189}