// risc0_circuit_rv32im/execute/testutil.rs

// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::BTreeMap;
16
17use anyhow::{bail, Result};
18use risc0_binfmt::{MemoryImage, Program};
19use risc0_core::scope;
20use risc0_zkp::{
21    core::{digest::Digest, log2_ceil},
22    MAX_CYCLES_PO2,
23};
24
25use super::{
26    pager::RESERVED_PAGING_CYCLES, platform::*, syscall::Syscall, Executor, SimpleSession,
27    SyscallContext,
28};
29
30pub const DEFAULT_SESSION_LIMIT: Option<u64> = Some(1 << 24);
31pub const MIN_CYCLES_PO2: usize = log2_ceil(RESERVED_CYCLES + RESERVED_PAGING_CYCLES as usize);
32
33#[derive(Default)]
34pub struct NullSyscall;
35
36impl Syscall for NullSyscall {
37    fn host_read(&self, _ctx: &mut dyn SyscallContext, _fd: u32, buf: &mut [u8]) -> Result<u32> {
38        for (i, byte) in buf.iter_mut().enumerate() {
39            *byte = i as u8;
40        }
41        Ok(buf.len() as u32)
42    }
43
44    fn host_write(&self, _ctx: &mut dyn SyscallContext, _fd: u32, _buf: &[u8]) -> Result<u32> {
45        unimplemented!()
46    }
47}
48
49pub fn execute<S: Syscall>(
50    image: MemoryImage,
51    segment_limit_po2: usize,
52    max_insn_cycles: usize,
53    max_cycles: Option<u64>,
54    syscall_handler: &S,
55    input_digest: Option<Digest>,
56) -> Result<SimpleSession> {
57    scope!("execute");
58
59    if !(MIN_CYCLES_PO2..=MAX_CYCLES_PO2).contains(&segment_limit_po2) {
60        bail!("Invalid segment_limit_po2: {segment_limit_po2}");
61    }
62
63    let mut segments = Vec::new();
64    let trace = Vec::new();
65    let result = Executor::new(image, syscall_handler, input_digest, trace).run(
66        segment_limit_po2,
67        max_insn_cycles,
68        max_cycles,
69        |segment| {
70            tracing::trace!("{segment:#?}");
71            segments.push(segment);
72            Ok(())
73        },
74    )?;
75
76    Ok(SimpleSession { segments, result })
77}
78
79pub mod user {
80    use super::*;
81
82    pub fn basic() -> Program {
83        let mut asm = Assembler::new();
84        asm.li(REG_A1, 0x4000_0000);
85        asm.ecall();
86        asm.program()
87    }
88
89    pub fn simple_loop(count: u32) -> Program {
90        let mut asm = Assembler::new();
91        asm.addi(REG_A4, REG_ZERO, 0);
92        asm.li(REG_A5, count);
93        // loop:
94        asm.addi(REG_A4, REG_A4, 1);
95        asm.blt(REG_A4, REG_A5, -4 /*loop: */);
96        asm.lui(REG_A1, 0x1000);
97        asm.ecall();
98        asm.program()
99    }
100}
101
102pub mod kernel {
103    use super::*;
104
105    pub fn basic() -> Program {
106        let mut asm = Assembler::new();
107        asm.host_terminate(0, 0);
108        asm.program()
109    }
110
111    pub fn simple_loop(count: u32) -> Program {
112        let mut asm = Assembler::new();
113        asm.addi(REG_A4, REG_ZERO, 0);
114        asm.li(REG_A5, count);
115        // loop:
116        asm.addi(REG_A4, REG_A4, 1);
117        asm.blt(REG_A4, REG_A5, -4 /*loop: */);
118        asm.host_terminate(0, 0);
119        asm.program()
120    }
121
122    pub fn multi_read() -> Program {
123        const LENGTHS: &[u32] = &[0, 1, 2, 3, 4, 5, 7, 13, 19, 40, 101];
124
125        let ptr = 0x0050_0000;
126
127        let mut asm = Assembler::new();
128        asm.li(REG_T0, ptr);
129        // Try all 4 alignments
130        for i in 0..4 {
131            // Try a variety of size
132            for &len in LENGTHS {
133                asm.host_ecall_read(0, ptr + i, len);
134                for k in 0..len {
135                    asm.lb(REG_T1, REG_T0, i + k);
136                    asm.li(REG_T2, k);
137                    asm.beq(REG_T1, REG_T2, 8);
138                    asm.die();
139                }
140            }
141        }
142
143        asm.host_terminate(0, 0);
144
145        asm.program()
146    }
147}
148
// Several encodings here (e.g. OP_STORE, OP_JAL, FUNCT3_GEU) are not used by
// the assembler yet; the full set is kept so new helpers can be added easily.
#[allow(unused)]
mod consts {
    // Major opcodes: instruction bits [6:0].

    /// `OP`: register-register ALU instructions (R-type).
    pub(crate) const OP_BASE: u32 = 0x33;
    /// `OP-IMM`: register-immediate ALU instructions (I-type).
    pub(crate) const OP_IMM: u32 = 0x13;
    /// `LOAD` (I-type).
    pub(crate) const OP_LOAD: u32 = 0x03;
    /// `STORE` (S-type).
    pub(crate) const OP_STORE: u32 = 0x23;
    /// `BRANCH` (B-type).
    pub(crate) const OP_BRANCH: u32 = 0x63;
    /// `JAL` (J-type).
    pub(crate) const OP_JAL: u32 = 0x6f;
    /// `JALR` (I-type).
    pub(crate) const OP_JALR: u32 = 0x67;
    /// `LUI` (U-type).
    pub(crate) const OP_LUI: u32 = 0x37;
    /// `AUIPC` (U-type).
    pub(crate) const OP_AUIPC: u32 = 0x17;
    /// `SYSTEM`: home of `ecall`.
    pub(crate) const OP_ENV: u32 = 0x73;

    // Branch conditions: funct3 values for OP_BRANCH.

    /// Equal.
    pub(crate) const FUNCT3_EQ: u32 = 0;
    /// Not equal.
    pub(crate) const FUNCT3_NE: u32 = 1;
    /// Signed less-than.
    pub(crate) const FUNCT3_LT: u32 = 4;
    /// Signed greater-or-equal.
    pub(crate) const FUNCT3_GE: u32 = 5;
    /// Unsigned less-than.
    pub(crate) const FUNCT3_LTU: u32 = 6;
    /// Unsigned greater-or-equal.
    pub(crate) const FUNCT3_GEU: u32 = 7;

    // Access widths: funct3 values for OP_LOAD (byte/half/word also for OP_STORE).

    /// Signed byte.
    pub(crate) const FUNCT3_BYTE: u32 = 0;
    /// Signed halfword.
    pub(crate) const FUNCT3_HALF: u32 = 1;
    /// Word.
    pub(crate) const FUNCT3_WORD: u32 = 2;
    /// Unsigned byte.
    pub(crate) const FUNCT3_BYTEU: u32 = 4;
    /// Unsigned halfword.
    pub(crate) const FUNCT3_HALFU: u32 = 5;
}

use consts::*;
177
/// A minimal RV32IM assembler for building test programs.
///
/// `text` holds encoded instructions in program order; `data` holds extra
/// 32-bit words keyed by byte address. `Default` yields the same empty state
/// as `Assembler::new()`.
#[derive(Debug, Default)]
struct Assembler {
    /// Encoded instructions, in the order they are laid out in memory.
    text: Vec<u32>,
    /// Additional memory words, keyed by byte address.
    data: BTreeMap<u32, u32>,
}
182
183#[allow(dead_code)]
184impl Assembler {
185    pub fn new() -> Self {
186        Self {
187            text: vec![],
188            data: BTreeMap::new(),
189        }
190    }
191
192    pub fn program(&self) -> Program {
193        let entry = USER_START_ADDR + WORD_SIZE;
194        let entry = entry.0;
195        let mut pc = entry;
196
197        let mut image: BTreeMap<_, _> = self
198            .text
199            .iter()
200            .map(|instr| {
201                let result = (pc, *instr);
202                pc += WORD_SIZE as u32;
203                result
204            })
205            .collect();
206
207        image.extend(self.data.iter());
208
209        Program::new_from_entry_and_image(entry, image)
210    }
211
212    pub fn word(&mut self, addr: u32, word: u32) {
213        self.data.insert(addr, word);
214    }
215
216    pub fn add(&mut self, rd: usize, rs1: usize, rs2: usize) {
217        self.text.push(insn_r(
218            0x00, rs2 as u32, rs1 as u32, 0x0, rd as u32, OP_BASE,
219        ));
220    }
221
222    pub fn addi(&mut self, rd: usize, rs1: usize, imm: u32) {
223        self.text
224            .push(insn_i(imm, rs1 as u32, 0x0, rd as u32, OP_IMM));
225    }
226
227    pub fn blt(&mut self, rs1: usize, rs2: usize, offset: i32) {
228        self.text.push(insn_b(
229            offset as u32,
230            rs2 as u32,
231            rs1 as u32,
232            FUNCT3_LT,
233            OP_BRANCH,
234        ));
235    }
236
237    pub fn beq(&mut self, rs1: usize, rs2: usize, offset: i32) {
238        self.text.push(insn_b(
239            offset as u32,
240            rs2 as u32,
241            rs1 as u32,
242            FUNCT3_EQ,
243            OP_BRANCH,
244        ));
245    }
246
247    pub fn bne(&mut self, rs1: usize, rs2: usize, offset: i32) {
248        self.text.push(insn_b(
249            offset as u32,
250            rs2 as u32,
251            rs1 as u32,
252            FUNCT3_NE,
253            OP_BRANCH,
254        ));
255    }
256
257    pub fn ecall(&mut self) {
258        self.text.push(insn_i(0x0, 0x0, 0x0, 0x0, OP_ENV));
259    }
260
261    pub fn lui(&mut self, rd: usize, imm: u32) {
262        self.text.push(insn_u(imm, rd as u32, OP_LUI));
263    }
264
265    pub fn li(&mut self, rd: usize, imm: u32) {
266        if imm < (1 << 11) {
267            self.addi(rd, REG_ZERO, imm);
268        } else {
269            // sign extend low 12 bits
270            let low = ((imm as i32) << 20) >> 20;
271            // upper 20 bits
272            let high = (imm as i32 - low) >> 12;
273
274            self.lui(rd, high as u32);
275            self.addi(rd, rd, low as u32);
276        }
277    }
278
279    pub fn lb(&mut self, rd: usize, rs1: usize, imm: u32) {
280        self.text
281            .push(insn_i(imm, rs1 as u32, FUNCT3_BYTE, rd as u32, OP_LOAD));
282    }
283
284    pub fn lw(&mut self, rd: usize, rs1: usize, imm: u32) {
285        self.text
286            .push(insn_i(imm, rs1 as u32, FUNCT3_WORD, rd as u32, OP_LOAD));
287    }
288
289    pub fn host_ecall_read(&mut self, fd: u32, ptr: u32, len: u32) {
290        self.li(REG_A7, HOST_ECALL_READ);
291        self.li(REG_A0, fd);
292        self.li(REG_A1, ptr);
293        self.li(REG_A2, len);
294        self.ecall();
295    }
296
297    pub fn host_terminate(&mut self, a0: u32, a1: u32) {
298        self.li(REG_A7, HOST_ECALL_TERMINATE);
299        self.li(REG_A0, a0);
300        self.li(REG_A1, a1);
301        self.ecall();
302    }
303
304    pub fn die(&mut self) {
305        self.text.push(fence());
306    }
307}
308
/// Encodes an R-type instruction from its already-positioned-at-zero fields.
///
/// ```text
/// 31        25 | 24  20 | 19  15 | 14  12 | 11        7 | 6    0 |
///    funct7    |   rs2  |   rs1  | funct3 |      rd     | opcode |
/// ```
fn insn_r(funct7: u32, rs2: u32, rs1: u32, funct3: u32, rd: u32, opcode: u32) -> u32 {
    // (field value, bit index of the field's least-significant bit)
    let fields: [(u32, u32); 6] = [
        (funct7, 25),
        (rs2, 20),
        (rs1, 15),
        (funct3, 12),
        (rd, 7),
        (opcode, 0),
    ];
    fields
        .iter()
        .fold(0, |word, &(value, lsb)| word | (value << lsb))
}
314
/// Encodes an I-type instruction.
///
/// Only the low 12 bits of `imm` survive the shift into place.
///
/// ```text
/// 31                 20 | 19  15 | 14  12 | 11        7 | 6    0 |
///     imm[11:0]         |   rs1  | funct3 |      rd     | opcode |
/// ```
fn insn_i(imm: u32, rs1: u32, funct3: u32, rd: u32, opcode: u32) -> u32 {
    let imm_field = imm << 20;
    let rs1_field = rs1 << 15;
    let funct3_field = funct3 << 12;
    let rd_field = rd << 7;
    opcode | rd_field | funct3_field | rs1_field | imm_field
}
320
/// Encodes a B-type (conditional branch) instruction.
///
/// `imm` is the byte offset. Bit 0 and bits above 12 are ignored, and the
/// remaining bits are scattered into the instruction as follows:
///
/// ```text
/// 31        25 | 24  20 | 19  15 | 14  12 | 11        7 | 6    0 |
/// imm[12|10:5] |   rs2  |   rs1  | funct3 | imm[4:1|11] | opcode |
/// ```
fn insn_b(imm: u32, rs2: u32, rs1: u32, funct3: u32, opcode: u32) -> u32 {
    // Move each immediate slice straight to its final bit position.
    let scattered_imm = ((imm & 0x1000) << 19) // imm[12]   -> bit 31
        | ((imm & 0x07e0) << 20) // imm[10:5] -> bits 30..25
        | ((imm & 0x001e) << 7) // imm[4:1]  -> bits 11..8
        | ((imm & 0x0800) >> 4); // imm[11]   -> bit 7
    scattered_imm | (rs2 << 20) | (rs1 << 15) | (funct3 << 12) | opcode
}
335
/// Encodes a U-type instruction; bits of `imm` above the low 20 are dropped.
///
/// ```text
/// 31                                   12 | 11        7 | 6    0 |
///    imm[31:12]                           |      rd     | opcode |
/// ```
fn insn_u(imm: u32, rd: u32, opcode: u32) -> u32 {
    let upper = imm.wrapping_shl(12);
    let rd_field = rd << 7;
    upper | rd_field | opcode
}
341
342fn fence() -> u32 {
343    insn_i(0, 0, 0, 0, 0b0001111)
344}