risc0_circuit_rv32im/execute/executor.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{cell::RefCell, collections::BTreeSet, rc::Rc};
16
17use anyhow::{bail, Result};
18use enum_map::EnumMap;
19use risc0_binfmt::{ByteAddr, MemoryImage, WordAddr};
20use risc0_zkp::core::{
21    digest::{Digest, DIGEST_BYTES},
22    log2_ceil,
23};
24
25use crate::{
26    trace::{TraceCallback, TraceEvent},
27    Rv32imV2Claim, TerminateState,
28};
29
30use super::{
31    bigint,
32    pager::{compute_partial_image, PageTraceEvent, PagedMemory},
33    platform::*,
34    poseidon2::Poseidon2State,
35    r0vm::{EcallKind, LoadOp, Risc0Context, Risc0Machine},
36    rv32im::{disasm, DecodedInstruction, Emulator, Instruction},
37    segment::Segment,
38    sha2::Sha2State,
39    syscall::Syscall,
40    SyscallContext,
41};
42
/// Aggregated usage statistics for a single ecall kind.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EcallMetric {
    /// Number of times this ecall was invoked.
    pub count: u64,
    /// Total cycles spent servicing this ecall.
    pub cycles: u64,
}
49
/// Per-kind ecall metrics, indexed by [`EcallKind`].
#[derive(Default)]
pub struct EcallMetrics(EnumMap<EcallKind, EcallMetric>);
52
/// Drives emulation of a RISC-V guest, splitting execution into segments.
pub struct Executor<'a, 'b, S: Syscall> {
    /// Current program counter.
    pc: ByteAddr,
    /// Program counter of the user-mode context.
    user_pc: ByteAddr,
    machine_mode: u32,
    /// User cycles accumulated within the current segment (reset on split).
    user_cycles: u32,
    /// Paged view of guest memory; also tracks paging cycle costs.
    pager: PagedMemory,
    /// Set once the guest terminates; `None` while still running.
    terminate_state: Option<TerminateState>,
    /// Bytes returned by each host read during the current segment.
    read_record: Vec<Vec<u8>>,
    /// Lengths returned by each host write during the current segment.
    write_record: Vec<u32>,
    syscall_handler: &'a S,
    input_digest: Digest,
    /// Digest of the guest's output region, captured on termination.
    output_digest: Option<Digest>,
    /// Registered trace callbacks; all tracing work is skipped when empty.
    trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
    /// Session-wide cycle totals accumulated across all segments.
    cycles: SessionCycles,
    ecall_metrics: EcallMetrics,
}
69
/// Summary of a completed execution session.
#[non_exhaustive]
pub struct ExecutorResult {
    /// Number of segments produced (including the final one).
    pub segments: u64,
    /// Guest memory image after execution finished.
    pub post_image: MemoryImage,
    /// Total user-instruction cycles across the session.
    pub user_cycles: u64,
    /// Total cycles across all segments (each segment contributes 2^po2).
    pub total_cycles: u64,
    /// Total cycles spent on paging.
    pub paging_cycles: u64,
    /// Remaining cycles: total minus user and paging.
    pub reserved_cycles: u64,
    /// Claim covering the whole session (pre/post state, input/output).
    pub claim: Rv32imV2Claim,
}
80
/// Running cycle totals accumulated over an entire session.
#[derive(Default)]
struct SessionCycles {
    /// All cycles; each finished segment adds its full 2^po2.
    total: u64,
    /// Cycles spent executing user instructions.
    user: u64,
    /// Cycles spent paging memory in/out.
    paging: u64,
    /// Per-segment remainder: total - paging - user.
    reserved: u64,
}
88
/// A session whose segments are all collected in memory alongside the result.
#[non_exhaustive]
pub struct SimpleSession {
    pub segments: Vec<Segment>,
    pub result: ExecutorResult,
}
94
/// Everything needed to build a [`Segment`] off the execution thread.
///
/// The partial-image computation is deferred to a worker thread (see
/// [`compute_partial_images`]) so execution can continue.
struct ComputePartialImageRequest {
    /// Committed pre-segment memory image.
    image: MemoryImage,
    /// Indexes of the pages touched during the segment.
    page_indexes: BTreeSet<u32>,

    input_digest: Digest,
    output_digest: Option<Digest>,
    /// Host-read bytes recorded during the segment.
    read_record: Vec<Vec<u8>>,
    /// Host-write lengths recorded during the segment.
    write_record: Vec<u32>,
    user_cycles: u32,
    pager_cycles: u32,
    terminate_state: Option<TerminateState>,
    /// Split threshold used for this segment (0 for the final segment).
    segment_threshold: u32,
    /// Pager digest before the segment ran.
    pre_digest: Digest,
    /// Pager digest after the segment ran.
    post_digest: Digest,
    po2: u32,
    index: u64,
}
112
/// Maximum number of segments that may be in flight (queued for partial-image
/// computation) before execution blocks on the channel.
const MAX_OUTSTANDING_SEGMENTS: usize = 5;
115
116fn compute_partial_images(
117    recv: std::sync::mpsc::Receiver<ComputePartialImageRequest>,
118    mut callback: impl FnMut(Segment) -> Result<()>,
119) -> Result<()> {
120    while let Ok(req) = recv.recv() {
121        let partial_image = compute_partial_image(req.image, req.page_indexes);
122        callback(Segment {
123            partial_image,
124            claim: Rv32imV2Claim {
125                pre_state: req.pre_digest,
126                post_state: req.post_digest,
127                input: req.input_digest,
128                output: req.output_digest,
129                terminate_state: req.terminate_state,
130                shutdown_cycle: None,
131            },
132            read_record: req.read_record,
133            write_record: req.write_record,
134            suspend_cycle: req.user_cycles,
135            paging_cycles: req.pager_cycles,
136            po2: req.po2,
137            index: req.index,
138            segment_threshold: req.segment_threshold,
139        })?;
140    }
141    Ok(())
142}
143
impl<'a, 'b, S: Syscall> Executor<'a, 'b, S> {
    /// Creates a new executor over `image`.
    ///
    /// Host interaction is delegated to `syscall_handler`. A `None`
    /// `input_digest` defaults to the all-zero digest. Tracing work is only
    /// performed when `trace` is non-empty.
    pub fn new(
        image: MemoryImage,
        syscall_handler: &'a S,
        input_digest: Option<Digest>,
        trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
    ) -> Self {
        Self {
            pc: ByteAddr(0),
            user_pc: ByteAddr(0),
            machine_mode: 0,
            user_cycles: 0,
            pager: PagedMemory::new(image, !trace.is_empty() /* tracing_enabled */),
            terminate_state: None,
            read_record: Vec::new(),
            write_record: Vec::new(),
            syscall_handler,
            input_digest: input_digest.unwrap_or_default(),
            output_digest: None,
            trace,
            cycles: SessionCycles::default(),
            ecall_metrics: EcallMetrics::default(),
        }
    }

    /// Runs the guest until it terminates, splitting execution into segments
    /// of at most `1 << segment_po2` cycles and passing each finished
    /// [`Segment`] to `callback` (invoked on a worker thread).
    ///
    /// * `max_insn_cycles` — worst-case cycle cost of a single instruction;
    ///   a segment is split once fewer than this many cycles remain.
    /// * `max_cycles` — optional session-wide user-cycle limit; exceeding it
    ///   aborts with an error.
    pub fn run(
        &mut self,
        segment_po2: usize,
        max_insn_cycles: usize,
        max_cycles: Option<u64>,
        callback: impl FnMut(Segment) -> Result<()> + Send,
    ) -> Result<ExecutorResult> {
        let segment_limit: u32 = 1 << segment_po2;
        assert!(max_insn_cycles < segment_limit as usize);
        // Split point: leave headroom for one worst-case instruction.
        let segment_threshold = segment_limit - max_insn_cycles as u32;
        let mut segment_counter = 0;

        self.reset();

        let mut emu = Emulator::new();
        Risc0Machine::resume(self)?;
        let initial_digest = self.pager.image.image_id();
        tracing::debug!("initial_digest: {initial_digest}");

        // Bounded channel: execution blocks once MAX_OUTSTANDING_SEGMENTS
        // segments are outstanding (queue capacity plus the one in flight).
        let (commit_sender, commit_recv) =
            std::sync::mpsc::sync_channel(MAX_OUTSTANDING_SEGMENTS - 1);

        let post_digest = std::thread::scope(|scope| {
            // Partial-image computation is deferred to a worker thread.
            let partial_images_thread =
                scope.spawn(move || compute_partial_images(commit_recv, callback));

            while self.terminate_state.is_none() {
                if let Some(max_cycles) = max_cycles {
                    if self.cycles.user >= max_cycles {
                        bail!(
                            "Session limit exceeded: {} >= {max_cycles}",
                            self.cycles.user
                        );
                    }
                }

                // Split the segment when the next instruction might not fit.
                if self.segment_cycles() > segment_threshold {
                    tracing::debug!(
                        "split(phys: {} + pager: {} + reserved: {RESERVED_CYCLES}) = {} >= {segment_threshold}",
                        self.user_cycles,
                        self.pager.cycles,
                        self.segment_cycles()
                    );

                    assert!(
                        self.segment_cycles() < segment_limit,
                        "segment limit ({segment_limit}) too small for instruction at pc: {:?}",
                        self.pc
                    );
                    Risc0Machine::suspend(self)?;

                    let (pre_image, pre_digest, post_digest) = self.pager.commit();

                    // Hand the per-segment state off; records are drained so
                    // the next segment starts fresh.
                    let req = ComputePartialImageRequest {
                        image: pre_image,
                        page_indexes: self.pager.page_indexes(),
                        input_digest: self.input_digest,
                        output_digest: self.output_digest,
                        read_record: std::mem::take(&mut self.read_record),
                        write_record: std::mem::take(&mut self.write_record),
                        user_cycles: self.user_cycles,
                        pager_cycles: self.pager.cycles,
                        terminate_state: self.terminate_state,
                        segment_threshold,
                        pre_digest,
                        post_digest,
                        po2: segment_po2 as u32,
                        index: segment_counter,
                    };
                    // A send error means the worker exited early; join it to
                    // surface the underlying callback error.
                    if commit_sender.send(req).is_err() {
                        return Err(partial_images_thread.join().unwrap().unwrap_err());
                    }

                    // Account the full 2^po2 cycles of the finished segment,
                    // attributing the slack to `reserved`.
                    segment_counter += 1;
                    let total_cycles = 1 << segment_po2;
                    let pager_cycles = self.pager.cycles as u64;
                    let user_cycles = self.user_cycles as u64;
                    self.cycles.total += total_cycles;
                    self.cycles.paging += pager_cycles;
                    self.cycles.reserved += total_cycles - pager_cycles - user_cycles;
                    self.user_cycles = 0;
                    self.pager.reset();

                    Risc0Machine::resume(self)?;
                }

                // On a step failure, try to dump the segment for postmortem
                // debugging before propagating the error.
                Risc0Machine::step(&mut emu, self).map_err(|err| {
                    let result = self.dump_segment(segment_po2, segment_threshold, segment_counter);
                    if let Err(inner) = result {
                        err.context(inner)
                    } else {
                        err
                    }
                })?;
            }

            Risc0Machine::suspend(self)?;

            // The final segment is sized to the smallest power of two that
            // fits the cycles actually used.
            let final_cycles = self.segment_cycles().next_power_of_two();
            let final_po2 = log2_ceil(final_cycles as usize);
            let (pre_image, pre_digest, post_digest) = self.pager.commit();
            let req = ComputePartialImageRequest {
                image: pre_image,
                page_indexes: self.pager.page_indexes(),
                input_digest: self.input_digest,
                output_digest: self.output_digest,
                read_record: std::mem::take(&mut self.read_record),
                write_record: std::mem::take(&mut self.write_record),
                user_cycles: self.user_cycles,
                pager_cycles: self.pager.cycles,
                terminate_state: self.terminate_state,
                segment_threshold: 0, // meaningless for final segment
                pre_digest,
                post_digest,
                po2: final_po2 as u32,
                index: segment_counter,
            };
            if commit_sender.send(req).is_err() {
                return Err(partial_images_thread.join().unwrap().unwrap_err());
            }

            let final_cycles = final_cycles as u64;
            let user_cycles = self.user_cycles as u64;
            let pager_cycles = self.pager.cycles as u64;
            self.cycles.total += final_cycles;
            self.cycles.paging += pager_cycles;
            self.cycles.reserved += final_cycles - pager_cycles - user_cycles;

            // Close the channel so the worker's loop ends, then collect any
            // remaining callback error.
            drop(commit_sender);

            partial_images_thread.join().unwrap()?;

            Ok(post_digest)
        })?;

        let session_claim = Rv32imV2Claim {
            pre_state: initial_digest,
            post_state: post_digest,
            input: self.input_digest,
            output: self.output_digest,
            terminate_state: self.terminate_state,
            shutdown_cycle: None,
        };

        Ok(ExecutorResult {
            // +1 accounts for the final segment, which did not bump the counter.
            segments: segment_counter + 1,
            post_image: self.pager.image.clone(),
            user_cycles: self.cycles.user,
            total_cycles: self.cycles.total,
            paging_cycles: self.cycles.paging,
            reserved_cycles: self.cycles.reserved,
            claim: session_claim,
        })
    }

    /// Serializes the current in-progress segment to the path named by the
    /// `RISC0_DUMP_PATH` environment variable (no-op when unset). Takes
    /// `&mut self` because it drains the read/write records.
    fn dump_segment(
        &mut self,
        po2: usize,
        segment_threshold: u32,
        index: u64,
    ) -> anyhow::Result<()> {
        if let Some(dump_path) = std::env::var_os("RISC0_DUMP_PATH") {
            tracing::error!(
                "Execution failure, saving segment to {}:",
                dump_path.to_string_lossy()
            );

            let (image, pre_digest, post_digest) = self.pager.commit();
            let page_indexes = self.pager.page_indexes();
            let partial_image = compute_partial_image(image, page_indexes);

            let segment = Segment {
                partial_image,
                claim: Rv32imV2Claim {
                    pre_state: pre_digest,
                    post_state: post_digest,
                    input: self.input_digest,
                    output: self.output_digest,
                    terminate_state: self.terminate_state,
                    shutdown_cycle: None,
                },
                read_record: std::mem::take(&mut self.read_record),
                write_record: std::mem::take(&mut self.write_record),
                suspend_cycle: self.user_cycles,
                paging_cycles: self.pager.cycles,
                po2: po2 as u32,
                index,
                segment_threshold,
            };
            tracing::error!("{segment:?}");

            let bytes = segment.encode()?;
            tracing::error!("serialized {} bytes", bytes.len());

            std::fs::write(dump_path, bytes)?;
        }
        Ok(())
    }

    /// Returns the accumulated ecall metrics, resetting them to zero.
    pub fn take_ecall_metrics(&mut self) -> EcallMetrics {
        std::mem::take(&mut self.ecall_metrics)
    }

    /// Resets per-session state before a run. Note: `user_pc` and
    /// `input_digest` are deliberately left untouched.
    fn reset(&mut self) {
        self.pager.reset();
        self.terminate_state = None;
        self.read_record.clear();
        self.write_record.clear();
        self.output_digest = None;
        self.machine_mode = 0;
        self.user_cycles = 0;
        self.cycles = SessionCycles::default();
        self.pc = ByteAddr(0);
        self.ecall_metrics = EcallMetrics::default();
    }

    /// Projected cycle count of the current segment: user work plus paging
    /// plus the fixed reserved overhead.
    fn segment_cycles(&self) -> u32 {
        self.user_cycles + self.pager.cycles + RESERVED_CYCLES as u32
    }

    /// Adds `count` cycles to the session and segment totals, attributing
    /// them to `ecall`'s metrics when given.
    fn inc_user_cycles(&mut self, count: usize, ecall: Option<EcallKind>) {
        self.cycles.user += count as u64;
        self.user_cycles += count as u32;
        if let Some(kind) = ecall {
            self.ecall_metrics.0[kind].cycles += count as u64;
        }
    }

    /// Dispatches `event` to every registered trace callback.
    #[cold]
    fn trace(&mut self, event: TraceEvent) -> Result<()> {
        for trace in self.trace.iter() {
            trace.borrow_mut().trace_callback(event.clone())?;
        }
        Ok(())
    }

    /// Flushes the pager's buffered page trace events to every registered
    /// trace callback, then clears them.
    #[cold]
    fn trace_pager(&mut self) -> Result<()> {
        for &event in self.pager.trace_events() {
            let event = TraceEvent::from(event);
            for trace in self.trace.iter() {
                trace.borrow_mut().trace_callback(event.clone())?;
            }
        }
        self.pager.clear_trace_events();
        Ok(())
    }

    /// Logs a disassembled instruction at TRACE level.
    #[cold]
    fn trace_instruction(&self, cycle: u64, insn: &Instruction, decoded: &DecodedInstruction) {
        tracing::trace!(
            "[{}:{}:{cycle}] {:?}> {:#010x}  {}",
            self.user_cycles + 1,
            self.segment_cycles() + 1,
            self.pc,
            decoded.insn,
            disasm(insn, decoded)
        );
    }
}
429
430impl<S: Syscall> Risc0Context for Executor<'_, '_, S> {
431    fn get_pc(&self) -> ByteAddr {
432        self.pc
433    }
434
435    fn set_pc(&mut self, addr: ByteAddr) {
436        self.pc = addr;
437    }
438
439    fn set_user_pc(&mut self, addr: ByteAddr) {
440        self.user_pc = addr;
441    }
442
443    fn get_machine_mode(&self) -> u32 {
444        self.machine_mode
445    }
446
447    fn set_machine_mode(&mut self, mode: u32) {
448        self.machine_mode = mode;
449    }
450
451    fn resume(&mut self) -> Result<()> {
452        let input_words = self.input_digest.as_words().to_vec();
453        for (i, word) in input_words.iter().enumerate() {
454            self.store_u32(GLOBAL_INPUT_ADDR.waddr() + i, *word)?;
455        }
456        Ok(())
457    }
458
459    fn on_insn_start(&mut self, insn: &Instruction, decoded: &DecodedInstruction) -> Result<()> {
460        let cycle = self.cycles.user;
461        if tracing::enabled!(tracing::Level::TRACE) {
462            self.trace_instruction(cycle, insn, decoded);
463        }
464        if !self.trace.is_empty() {
465            self.trace(TraceEvent::InstructionStart {
466                cycle,
467                pc: self.pc.0,
468                insn: decoded.insn,
469            })
470        } else {
471            Ok(())
472        }
473    }
474
475    fn on_insn_end(&mut self, _insn: &Instruction, _decoded: &DecodedInstruction) -> Result<()> {
476        self.inc_user_cycles(1, None);
477        if !self.trace.is_empty() {
478            self.trace_pager()?;
479        }
480        Ok(())
481    }
482
483    fn on_ecall_cycle(
484        &mut self,
485        cur: CycleState,
486        _next: CycleState,
487        _s0: u32,
488        _s1: u32,
489        _s2: u32,
490        kind: EcallKind,
491    ) -> Result<()> {
492        if cur == CycleState::MachineEcall {
493            self.ecall_metrics.0[kind].count += 1;
494        }
495        self.inc_user_cycles(1, Some(kind));
496        if !self.trace.is_empty() {
497            self.trace_pager()?;
498        }
499        Ok(())
500    }
501
502    fn load_u32(&mut self, op: LoadOp, addr: WordAddr) -> Result<u32> {
503        let word = match op {
504            LoadOp::Peek => self.pager.peek(addr)?,
505            LoadOp::Load | LoadOp::Record => self.pager.load(addr)?,
506        };
507        // tracing::trace!("load_mem({:?}) -> {word:#010x}", addr.baddr());
508        Ok(word)
509    }
510
511    fn load_register(&mut self, _op: LoadOp, base: WordAddr, idx: usize) -> Result<u32> {
512        let word = self.pager.load_register(base, idx);
513        // tracing::trace!("load_register({:?}) -> {word:#010x}", addr.baddr());
514        Ok(word)
515    }
516
517    fn store_u32(&mut self, addr: WordAddr, word: u32) -> Result<()> {
518        // tracing::trace!(
519        //     "store_u32({:?}, {word:#010x}), pc: {:?}",
520        //     addr.baddr(),
521        //     self.pc
522        // );
523        if !self.trace.is_empty() {
524            self.trace(TraceEvent::MemorySet {
525                addr: addr.baddr().0,
526                region: word.to_be_bytes().to_vec(),
527            })?;
528        }
529        self.pager.store(addr, word)
530    }
531
532    fn store_register(&mut self, base: WordAddr, idx: usize, word: u32) -> Result<()> {
533        // tracing::trace!("store_register({:?}, {word:#010x})", addr.baddr());
534        if !self.trace.is_empty() {
535            self.trace(TraceEvent::MemorySet {
536                addr: (base + idx).baddr().0,
537                region: word.to_be_bytes().to_vec(),
538            })?;
539        }
540        self.pager.store_register(base, idx, word);
541        Ok(())
542    }
543
544    fn on_terminate(&mut self, a0: u32, a1: u32) -> Result<()> {
545        self.terminate_state = Some(TerminateState {
546            a0: a0.into(),
547            a1: a1.into(),
548        });
549        tracing::debug!("{:?}", self.terminate_state);
550
551        let output: Digest = self
552            .load_region(LoadOp::Peek, GLOBAL_OUTPUT_ADDR, DIGEST_BYTES)?
553            .as_slice()
554            .try_into()?;
555        self.output_digest = Some(output);
556
557        Ok(())
558    }
559
560    fn host_read(&mut self, fd: u32, buf: &mut [u8]) -> Result<u32> {
561        let rlen = self.syscall_handler.host_read(self, fd, buf)?;
562        let slice = &buf[..rlen as usize];
563        self.read_record.push(slice.to_vec());
564        Ok(rlen)
565    }
566
567    fn host_write(&mut self, fd: u32, buf: &[u8]) -> Result<u32> {
568        let rlen = self.syscall_handler.host_write(self, fd, buf)?;
569        self.write_record.push(rlen);
570        Ok(rlen)
571    }
572
573    fn on_sha2_cycle(&mut self, _cur_state: CycleState, _sha2: &Sha2State) {
574        self.inc_user_cycles(1, Some(EcallKind::Sha2));
575    }
576
577    fn on_poseidon2_cycle(&mut self, _cur_state: CycleState, _p2: &Poseidon2State) {
578        self.inc_user_cycles(1, Some(EcallKind::Poseidon2));
579    }
580
581    fn ecall_bigint(&mut self) -> Result<()> {
582        let cycles = bigint::ecall_execute(self)?;
583        self.inc_user_cycles(cycles, Some(EcallKind::BigInt));
584        Ok(())
585    }
586}
587
impl<S: Syscall> SyscallContext for Executor<'_, '_, S> {
    /// Reads user register `idx` without recording the access; rejects
    /// indexes at or beyond `REG_MAX`.
    fn peek_register(&mut self, idx: usize) -> Result<u32> {
        if idx >= REG_MAX {
            bail!("invalid register: x{idx}");
        }
        self.load_register(LoadOp::Peek, USER_REGS_ADDR.waddr(), idx)
    }

    /// Reads the word containing `addr` without recording the access.
    fn peek_u32(&mut self, addr: ByteAddr) -> Result<u32> {
        // let addr = Self::check_guest_addr(addr)?;
        self.load_u32(LoadOp::Peek, addr.waddr())
    }

    /// Reads one byte without recording the access.
    fn peek_u8(&mut self, addr: ByteAddr) -> Result<u8> {
        // let addr = Self::check_guest_addr(addr)?;
        self.load_u8(LoadOp::Peek, addr)
    }

    /// Reads `size` bytes starting at `addr` without recording the access.
    fn peek_region(&mut self, addr: ByteAddr, size: usize) -> Result<Vec<u8>> {
        // let addr = Self::check_guest_addr(addr)?;
        self.load_region(LoadOp::Peek, addr, size)
    }

    /// Reads a whole page by index via the pager, without recording.
    fn peek_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
        self.pager.peek_page(page_idx)
    }

    /// Returns the session-wide user cycle count so far.
    fn get_cycle(&self) -> u64 {
        self.cycles.user
    }

    /// Returns the user-mode program counter (not the machine `pc`).
    fn get_pc(&self) -> u32 {
        self.user_pc.0
    }
}
623
624impl From<EcallMetrics> for Vec<(String, EcallMetric)> {
625    fn from(metrics: EcallMetrics) -> Self {
626        metrics
627            .0
628            .into_iter()
629            .map(|(kind, metric)| (format!("{kind:?}"), metric))
630            .collect()
631    }
632}
633
634impl From<PageTraceEvent> for TraceEvent {
635    fn from(event: PageTraceEvent) -> Self {
636        match event {
637            PageTraceEvent::PageIn { cycles } => TraceEvent::PageIn {
638                cycles: cycles as u64,
639            },
640            PageTraceEvent::PageOut { cycles } => TraceEvent::PageOut {
641                cycles: cycles as u64,
642            },
643        }
644    }
645}