// risc0_circuit_rv32im/execute/executor.rs

// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15use std::{cell::RefCell, collections::BTreeSet, rc::Rc};
16
17use anyhow::{bail, Result};
18use enum_map::EnumMap;
19use risc0_binfmt::{ByteAddr, MemoryImage, WordAddr};
20use risc0_zkp::core::{
21    digest::{Digest, DIGEST_BYTES},
22    log2_ceil,
23};
24
25use crate::{
26    trace::{TraceCallback, TraceEvent},
27    Rv32imV2Claim, TerminateState,
28};
29
30use super::{
31    bigint,
32    pager::{compute_partial_image, PageTraceEvent, PagedMemory},
33    platform::*,
34    poseidon2::Poseidon2State,
35    r0vm::{EcallKind, LoadOp, Risc0Context, Risc0Machine},
36    rv32im::{disasm, DecodedInstruction, Emulator, Instruction},
37    segment::Segment,
38    sha2::Sha2State,
39    syscall::Syscall,
40    SyscallContext,
41};
42
/// Aggregate statistics for a single kind of ecall.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EcallMetric {
    /// Number of times this ecall was invoked.
    pub count: u64,
    /// Total cycles charged to this ecall.
    pub cycles: u64,
}
49
/// Per-ecall metrics, indexed by [`EcallKind`].
#[derive(Default)]
pub struct EcallMetrics(EnumMap<EcallKind, EcallMetric>);
52
/// Drives emulation of a RISC-V guest, splitting execution into segments
/// which are reported through a callback (see `Executor::run`).
pub struct Executor<'a, 'b, S: Syscall> {
    // Current (machine-level) program counter.
    pc: ByteAddr,
    // Program counter of user-mode code, tracked separately from `pc`.
    user_pc: ByteAddr,
    // Current privilege mode; 0 after `reset` — presumably 0 = user,
    // nonzero = machine mode (TODO confirm against r0vm).
    machine_mode: u32,
    // Cycles spent in the *current* segment; reset to 0 at each split.
    user_cycles: u32,
    // Paged view of guest memory; also accumulates paging cycles.
    pager: PagedMemory,
    // Set once the guest terminates; `run` loops until this is `Some`.
    terminate_state: Option<TerminateState>,
    // Bytes returned by each `host_read` in the current segment.
    read_record: Vec<Vec<u8>>,
    // Lengths returned by each `host_write` in the current segment.
    write_record: Vec<u32>,
    // Host-side handler for guest syscalls.
    syscall_handler: &'a S,
    // Input digest written to GLOBAL_INPUT_ADDR on resume.
    input_digest: Digest,
    // Output digest captured from GLOBAL_OUTPUT_ADDR at termination.
    output_digest: Option<Digest>,
    // Registered trace sinks; tracing work is skipped when empty.
    trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
    // Whole-session cycle accounting (across all segments).
    cycles: SessionCycles,
    // Per-ecall counters/cycles, retrievable via `take_ecall_metrics`.
    ecall_metrics: EcallMetrics,
}
69
/// Summary of a completed execution session.
#[non_exhaustive]
pub struct ExecutorResult {
    /// Number of segments produced (including the final one).
    pub segments: u64,
    /// Memory image after the guest terminated.
    pub post_image: MemoryImage,
    /// Cycles spent executing user instructions across the session.
    pub user_cycles: u64,
    /// Sum of each segment's power-of-two cycle budget.
    pub total_cycles: u64,
    /// Cycles spent on page-in/page-out work.
    pub paging_cycles: u64,
    /// Budgeted cycles not used by user or paging work.
    pub reserved_cycles: u64,
    /// Claim covering the whole session (pre/post state, input/output).
    pub claim: Rv32imV2Claim,
}
80
/// Running cycle totals accumulated over an entire session.
#[derive(Default)]
struct SessionCycles {
    // Sum of 2^po2 for every emitted segment.
    total: u64,
    // User-instruction cycles.
    user: u64,
    // Paging overhead cycles.
    paging: u64,
    // total - user - paging, accumulated per segment.
    reserved: u64,
}
88
/// Convenience bundle of a session's segments plus its result, for callers
/// that simply collect every segment in memory.
#[non_exhaustive]
pub struct SimpleSession {
    /// All segments produced by the run.
    pub segments: Vec<Segment>,
    /// Session-level result/claim.
    pub result: ExecutorResult,
}
94
/// Everything the background worker needs to turn one committed segment
/// into a finished [`Segment`] (see `compute_partial_images`).
struct ComputePartialImageRequest {
    // Full pre-segment memory image; reduced to a partial image on the worker.
    image: MemoryImage,
    // Pages touched during the segment; selects what the partial image keeps.
    page_indexes: BTreeSet<u32>,

    input_digest: Digest,
    output_digest: Option<Digest>,
    // Host-read bytes / host-write lengths recorded during the segment.
    read_record: Vec<Vec<u8>>,
    write_record: Vec<u32>,
    user_cycles: u32,
    pager_cycles: u32,
    terminate_state: Option<TerminateState>,
    segment_threshold: u32,
    // Pager digests before and after the segment.
    pre_digest: Digest,
    post_digest: Digest,
    po2: u32,
    // Zero-based segment index within the session.
    index: u64,
}
112
113/// Maximum number of segments we can queue up before we block execution
114const MAX_OUTSTANDING_SEGMENTS: usize = 5;
115
116fn compute_partial_images(
117    recv: std::sync::mpsc::Receiver<ComputePartialImageRequest>,
118    mut callback: impl FnMut(Segment) -> Result<()>,
119) -> Result<()> {
120    while let Ok(req) = recv.recv() {
121        let partial_image = compute_partial_image(req.image, req.page_indexes);
122        callback(Segment {
123            partial_image,
124            claim: Rv32imV2Claim {
125                pre_state: req.pre_digest,
126                post_state: req.post_digest,
127                input: req.input_digest,
128                output: req.output_digest,
129                terminate_state: req.terminate_state,
130                shutdown_cycle: None,
131            },
132            read_record: req.read_record,
133            write_record: req.write_record,
134            suspend_cycle: req.user_cycles,
135            paging_cycles: req.pager_cycles,
136            po2: req.po2,
137            index: req.index,
138            segment_threshold: req.segment_threshold,
139        })?;
140    }
141    Ok(())
142}
143
144impl<'a, 'b, S: Syscall> Executor<'a, 'b, S> {
145    pub fn new(
146        image: MemoryImage,
147        syscall_handler: &'a S,
148        input_digest: Option<Digest>,
149        trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
150    ) -> Self {
151        Self {
152            pc: ByteAddr(0),
153            user_pc: ByteAddr(0),
154            machine_mode: 0,
155            user_cycles: 0,
156            pager: PagedMemory::new(image, !trace.is_empty() /* tracing_enabled */),
157            terminate_state: None,
158            read_record: Vec::new(),
159            write_record: Vec::new(),
160            syscall_handler,
161            input_digest: input_digest.unwrap_or_default(),
162            output_digest: None,
163            trace,
164            cycles: SessionCycles::default(),
165            ecall_metrics: EcallMetrics::default(),
166        }
167    }
168
    /// Execute the guest until it terminates, invoking `callback` with each
    /// completed [`Segment`].
    ///
    /// * `segment_po2` — log2 of the per-segment cycle budget.
    /// * `max_insn_cycles` — worst-case cycles for one instruction; used to
    ///   derive the split threshold so the next instruction always fits.
    /// * `max_cycles` — optional cap on total user cycles for the session.
    ///
    /// Segment finalization (partial-image computation and `Segment`
    /// construction) runs on a background thread fed through a bounded
    /// channel, so execution can run ahead of `callback` by at most
    /// `MAX_OUTSTANDING_SEGMENTS` segments.
    pub fn run(
        &mut self,
        segment_po2: usize,
        max_insn_cycles: usize,
        max_cycles: Option<u64>,
        callback: impl FnMut(Segment) -> Result<()> + Send,
    ) -> Result<ExecutorResult> {
        let segment_limit: u32 = 1 << segment_po2;
        assert!(max_insn_cycles < segment_limit as usize);
        // Split before the segment is so full that one more (worst-case)
        // instruction could overflow the budget.
        let segment_threshold = segment_limit - max_insn_cycles as u32;
        let mut segment_counter = 0;

        self.reset();

        let mut emu = Emulator::new();
        Risc0Machine::resume(self)?;
        let initial_digest = self.pager.image.image_id();
        tracing::debug!("initial_digest: {initial_digest}");

        // Channel capacity of MAX_OUTSTANDING_SEGMENTS - 1 plus the request
        // the worker is currently processing bounds the number in flight.
        let (commit_sender, commit_recv) =
            std::sync::mpsc::sync_channel(MAX_OUTSTANDING_SEGMENTS - 1);

        let post_digest = std::thread::scope(|scope| {
            let partial_images_thread =
                scope.spawn(move || compute_partial_images(commit_recv, callback));

            while self.terminate_state.is_none() {
                // Enforce the optional session-wide user-cycle cap.
                if let Some(max_cycles) = max_cycles {
                    if self.cycles.user >= max_cycles {
                        bail!(
                            "Session limit exceeded: {} >= {max_cycles}",
                            self.cycles.user
                        );
                    }
                }

                // Split point: suspend, commit the pager, and queue this
                // segment for finalization on the worker thread.
                if self.segment_cycles() > segment_threshold {
                    tracing::debug!(
                        "split(phys: {} + pager: {} + reserved: {RESERVED_CYCLES}) = {} >= {segment_threshold}",
                        self.user_cycles,
                        self.pager.cycles,
                        self.segment_cycles()
                    );

                    assert!(
                        self.segment_cycles() < segment_limit,
                        "segment limit ({segment_limit}) too small for instruction at pc: {:?}",
                        self.pc
                    );
                    Risc0Machine::suspend(self)?;

                    let (pre_image, pre_digest, post_digest) = self.pager.commit();

                    let req = ComputePartialImageRequest {
                        image: pre_image,
                        page_indexes: self.pager.page_indexes(),
                        input_digest: self.input_digest,
                        output_digest: self.output_digest,
                        // take() moves the records out and leaves the
                        // executor's buffers empty for the next segment.
                        read_record: std::mem::take(&mut self.read_record),
                        write_record: std::mem::take(&mut self.write_record),
                        user_cycles: self.user_cycles,
                        pager_cycles: self.pager.cycles,
                        terminate_state: self.terminate_state,
                        segment_threshold,
                        pre_digest,
                        post_digest,
                        po2: segment_po2 as u32,
                        index: segment_counter,
                    };
                    // A send error means the worker exited early; surface its
                    // error instead of a generic channel failure.
                    if commit_sender.send(req).is_err() {
                        return Err(partial_images_thread.join().unwrap().unwrap_err());
                    }

                    segment_counter += 1;
                    // Session accounting: reserved = budget - paging - user.
                    let total_cycles = 1 << segment_po2;
                    let pager_cycles = self.pager.cycles as u64;
                    let user_cycles = self.user_cycles as u64;
                    self.cycles.total += total_cycles;
                    self.cycles.paging += pager_cycles;
                    self.cycles.reserved += total_cycles - pager_cycles - user_cycles;
                    self.user_cycles = 0;
                    self.pager.reset();

                    Risc0Machine::resume(self)?;
                }

                // Execute one step; on failure, best-effort dump the current
                // segment (see `dump_segment`) and attach any dump error.
                Risc0Machine::step(&mut emu, self).map_err(|err| {
                    let result = self.dump_segment(segment_po2, segment_threshold, segment_counter);
                    if let Err(inner) = result {
                        err.context(inner)
                    } else {
                        err
                    }
                })?;
            }

            // Guest terminated: emit the final (possibly smaller) segment,
            // sized to the next power of two of the cycles actually used.
            Risc0Machine::suspend(self)?;

            let final_cycles = self.segment_cycles().next_power_of_two();
            let final_po2 = log2_ceil(final_cycles as usize);
            let segment_threshold = (1 << final_po2) - max_insn_cycles as u32;
            let (pre_image, pre_digest, post_digest) = self.pager.commit();
            let req = ComputePartialImageRequest {
                image: pre_image,
                page_indexes: self.pager.page_indexes(),
                input_digest: self.input_digest,
                output_digest: self.output_digest,
                read_record: std::mem::take(&mut self.read_record),
                write_record: std::mem::take(&mut self.write_record),
                user_cycles: self.user_cycles,
                pager_cycles: self.pager.cycles,
                terminate_state: self.terminate_state,
                segment_threshold,
                pre_digest,
                post_digest,
                po2: final_po2 as u32,
                index: segment_counter,
            };
            if commit_sender.send(req).is_err() {
                return Err(partial_images_thread.join().unwrap().unwrap_err());
            }

            let final_cycles = final_cycles as u64;
            let user_cycles = self.user_cycles as u64;
            let pager_cycles = self.pager.cycles as u64;
            self.cycles.total += final_cycles;
            self.cycles.paging += pager_cycles;
            self.cycles.reserved += final_cycles - pager_cycles - user_cycles;

            // Close the channel so the worker's receive loop ends …
            drop(commit_sender);

            // … then propagate any error it produced.
            partial_images_thread.join().unwrap()?;

            Ok(post_digest)
        })?;

        let session_claim = Rv32imV2Claim {
            pre_state: initial_digest,
            post_state: post_digest,
            input: self.input_digest,
            output: self.output_digest,
            terminate_state: self.terminate_state,
            shutdown_cycle: None,
        };

        Ok(ExecutorResult {
            // +1 accounts for the final segment, which does not bump
            // segment_counter above.
            segments: segment_counter + 1,
            post_image: self.pager.image.clone(),
            user_cycles: self.cycles.user,
            total_cycles: self.cycles.total,
            paging_cycles: self.cycles.paging,
            reserved_cycles: self.cycles.reserved,
            claim: session_claim,
        })
    }
324
    /// Debug aid: if `RISC0_DUMP_PATH` is set in the environment, serialize
    /// the in-progress segment to that path so a failed execution can be
    /// inspected or replayed. No-op otherwise.
    ///
    /// Note: this takes `read_record`/`write_record` out of the executor,
    /// so it is intended only for the failure path.
    fn dump_segment(
        &mut self,
        po2: usize,
        segment_threshold: u32,
        index: u64,
    ) -> anyhow::Result<()> {
        if let Some(dump_path) = std::env::var_os("RISC0_DUMP_PATH") {
            tracing::error!(
                "Execution failure, saving segment to {}:",
                dump_path.to_string_lossy()
            );

            // Commit the pager to get pre/post digests and the image, then
            // shrink the image to just the pages this segment touched.
            let (image, pre_digest, post_digest) = self.pager.commit();
            let page_indexes = self.pager.page_indexes();
            let partial_image = compute_partial_image(image, page_indexes);

            let segment = Segment {
                partial_image,
                claim: Rv32imV2Claim {
                    pre_state: pre_digest,
                    post_state: post_digest,
                    input: self.input_digest,
                    output: self.output_digest,
                    terminate_state: self.terminate_state,
                    shutdown_cycle: None,
                },
                read_record: std::mem::take(&mut self.read_record),
                write_record: std::mem::take(&mut self.write_record),
                suspend_cycle: self.user_cycles,
                paging_cycles: self.pager.cycles,
                po2: po2 as u32,
                index,
                segment_threshold,
            };
            tracing::error!("{segment:?}");

            let bytes = segment.encode()?;
            tracing::error!("serialized {} bytes", bytes.len());

            std::fs::write(dump_path, bytes)?;
        }
        Ok(())
    }
368
369    pub fn take_ecall_metrics(&mut self) -> EcallMetrics {
370        std::mem::take(&mut self.ecall_metrics)
371    }
372
373    fn reset(&mut self) {
374        self.pager.reset();
375        self.terminate_state = None;
376        self.read_record.clear();
377        self.write_record.clear();
378        self.output_digest = None;
379        self.machine_mode = 0;
380        self.user_cycles = 0;
381        self.cycles = SessionCycles::default();
382        self.pc = ByteAddr(0);
383        self.ecall_metrics = EcallMetrics::default();
384    }
385
    /// Total cycles the current segment would occupy: user work plus paging
    /// overhead plus the fixed per-segment reserved cycles.
    fn segment_cycles(&self) -> u32 {
        self.user_cycles + self.pager.cycles + RESERVED_CYCLES as u32
    }
389
390    fn inc_user_cycles(&mut self, count: usize, ecall: Option<EcallKind>) {
391        self.cycles.user += count as u64;
392        self.user_cycles += count as u32;
393        if let Some(kind) = ecall {
394            self.ecall_metrics.0[kind].cycles += count as u64;
395        }
396    }
397
398    #[cold]
399    fn trace(&mut self, event: TraceEvent) -> Result<()> {
400        for trace in self.trace.iter() {
401            trace.borrow_mut().trace_callback(event.clone())?;
402        }
403        Ok(())
404    }
405
406    #[cold]
407    fn trace_pager(&mut self) -> Result<()> {
408        for &event in self.pager.trace_events() {
409            let event = TraceEvent::from(event);
410            for trace in self.trace.iter() {
411                trace.borrow_mut().trace_callback(event.clone())?;
412            }
413        }
414        self.pager.clear_trace_events();
415        Ok(())
416    }
417
    /// Log one disassembled instruction at TRACE level, tagged with the
    /// segment-local cycle counts and the session cycle.
    #[cold]
    fn trace_instruction(&self, cycle: u64, insn: &Instruction, decoded: &DecodedInstruction) {
        tracing::trace!(
            "[{}:{}:{cycle}] {:?}> {:#010x}  {}",
            self.user_cycles + 1,
            self.segment_cycles() + 1,
            self.pc,
            decoded.insn,
            disasm(insn, decoded)
        );
    }
429}
430
431impl<S: Syscall> Risc0Context for Executor<'_, '_, S> {
    /// Current machine-level program counter.
    fn get_pc(&self) -> ByteAddr {
        self.pc
    }

    fn set_pc(&mut self, addr: ByteAddr) {
        self.pc = addr;
    }

    /// Record the user-mode PC (reported to hosts via `SyscallContext::get_pc`).
    fn set_user_pc(&mut self, addr: ByteAddr) {
        self.user_pc = addr;
    }

    /// Current privilege mode value (0 after reset).
    fn get_machine_mode(&self) -> u32 {
        self.machine_mode
    }

    fn set_machine_mode(&mut self, mode: u32) {
        self.machine_mode = mode;
    }
451
452    fn resume(&mut self) -> Result<()> {
453        let input_words = self.input_digest.as_words().to_vec();
454        for (i, word) in input_words.iter().enumerate() {
455            self.store_u32(GLOBAL_INPUT_ADDR.waddr() + i, *word)?;
456        }
457        Ok(())
458    }
459
460    fn on_insn_start(&mut self, insn: &Instruction, decoded: &DecodedInstruction) -> Result<()> {
461        let cycle = self.cycles.user;
462        if tracing::enabled!(tracing::Level::TRACE) {
463            self.trace_instruction(cycle, insn, decoded);
464        }
465        if !self.trace.is_empty() {
466            self.trace(TraceEvent::InstructionStart {
467                cycle,
468                pc: self.pc.0,
469                insn: decoded.insn,
470            })
471        } else {
472            Ok(())
473        }
474    }
475
476    fn on_insn_end(&mut self, _insn: &Instruction, _decoded: &DecodedInstruction) -> Result<()> {
477        self.inc_user_cycles(1, None);
478        if !self.trace.is_empty() {
479            self.trace_pager()?;
480        }
481        Ok(())
482    }
483
484    fn on_ecall_cycle(
485        &mut self,
486        cur: CycleState,
487        _next: CycleState,
488        _s0: u32,
489        _s1: u32,
490        _s2: u32,
491        kind: EcallKind,
492    ) -> Result<()> {
493        if cur == CycleState::MachineEcall {
494            self.ecall_metrics.0[kind].count += 1;
495        }
496        self.inc_user_cycles(1, Some(kind));
497        if !self.trace.is_empty() {
498            self.trace_pager()?;
499        }
500        Ok(())
501    }
502
503    fn load_u32(&mut self, op: LoadOp, addr: WordAddr) -> Result<u32> {
504        let word = match op {
505            LoadOp::Peek => self.pager.peek(addr)?,
506            LoadOp::Load | LoadOp::Record => self.pager.load(addr)?,
507        };
508        // tracing::trace!("load_mem({:?}) -> {word:#010x}", addr.baddr());
509        Ok(word)
510    }
511
512    fn load_register(&mut self, _op: LoadOp, base: WordAddr, idx: usize) -> Result<u32> {
513        let word = self.pager.load_register(base, idx);
514        // tracing::trace!("load_register({:?}) -> {word:#010x}", addr.baddr());
515        Ok(word)
516    }
517
518    fn store_u32(&mut self, addr: WordAddr, word: u32) -> Result<()> {
519        // tracing::trace!(
520        //     "store_u32({:?}, {word:#010x}), pc: {:?}",
521        //     addr.baddr(),
522        //     self.pc
523        // );
524        if !self.trace.is_empty() {
525            self.trace(TraceEvent::MemorySet {
526                addr: addr.baddr().0,
527                region: word.to_be_bytes().to_vec(),
528            })?;
529        }
530        self.pager.store(addr, word)
531    }
532
533    fn store_register(&mut self, base: WordAddr, idx: usize, word: u32) -> Result<()> {
534        // tracing::trace!("store_register({:?}, {word:#010x})", addr.baddr());
535        if !self.trace.is_empty() {
536            self.trace(TraceEvent::MemorySet {
537                addr: (base + idx).baddr().0,
538                region: word.to_be_bytes().to_vec(),
539            })?;
540        }
541        self.pager.store_register(base, idx, word);
542        Ok(())
543    }
544
545    fn on_terminate(&mut self, a0: u32, a1: u32) -> Result<()> {
546        self.terminate_state = Some(TerminateState {
547            a0: a0.into(),
548            a1: a1.into(),
549        });
550        tracing::debug!("{:?}", self.terminate_state);
551
552        let output: Digest = self
553            .load_region(LoadOp::Peek, GLOBAL_OUTPUT_ADDR, DIGEST_BYTES)?
554            .as_slice()
555            .try_into()?;
556        self.output_digest = Some(output);
557
558        Ok(())
559    }
560
561    fn host_read(&mut self, fd: u32, buf: &mut [u8]) -> Result<u32> {
562        let rlen = self.syscall_handler.host_read(self, fd, buf)?;
563        let slice = &buf[..rlen as usize];
564        self.read_record.push(slice.to_vec());
565        Ok(rlen)
566    }
567
568    fn host_write(&mut self, fd: u32, buf: &[u8]) -> Result<u32> {
569        let rlen = self.syscall_handler.host_write(self, fd, buf)?;
570        self.write_record.push(rlen);
571        Ok(rlen)
572    }
573
    /// Bill each SHA-2 accelerator cycle to the `Sha2` ecall metric.
    fn on_sha2_cycle(&mut self, _cur_state: CycleState, _sha2: &Sha2State) {
        self.inc_user_cycles(1, Some(EcallKind::Sha2));
    }

    /// Bill each Poseidon2 accelerator cycle to the `Poseidon2` ecall metric.
    fn on_poseidon2_cycle(&mut self, _cur_state: CycleState, _p2: &Poseidon2State) {
        self.inc_user_cycles(1, Some(EcallKind::Poseidon2));
    }
581
582    fn ecall_bigint(&mut self) -> Result<()> {
583        let cycles = bigint::ecall_execute(self)?;
584        self.inc_user_cycles(cycles, Some(EcallKind::BigInt));
585        Ok(())
586    }
587}
588
589impl<S: Syscall> SyscallContext for Executor<'_, '_, S> {
590    fn peek_register(&mut self, idx: usize) -> Result<u32> {
591        if idx >= REG_MAX {
592            bail!("invalid register: x{idx}");
593        }
594        self.load_register(LoadOp::Peek, USER_REGS_ADDR.waddr(), idx)
595    }
596
    /// Read a word from guest memory without paging side effects.
    fn peek_u32(&mut self, addr: ByteAddr) -> Result<u32> {
        // NOTE(review): guest-address validation is commented out here and in
        // the peeks below — confirm whether callers guarantee valid addresses.
        // let addr = Self::check_guest_addr(addr)?;
        self.load_u32(LoadOp::Peek, addr.waddr())
    }

    /// Read a single byte from guest memory without paging side effects.
    fn peek_u8(&mut self, addr: ByteAddr) -> Result<u8> {
        // let addr = Self::check_guest_addr(addr)?;
        self.load_u8(LoadOp::Peek, addr)
    }

    /// Read `size` bytes starting at `addr` without paging side effects.
    fn peek_region(&mut self, addr: ByteAddr, size: usize) -> Result<Vec<u8>> {
        // let addr = Self::check_guest_addr(addr)?;
        self.load_region(LoadOp::Peek, addr, size)
    }
611
    /// Fetch the raw bytes of page `page_idx` from the pager.
    fn peek_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
        self.pager.peek_page(page_idx)
    }

    /// Session-wide user cycle count.
    fn get_cycle(&self) -> u64 {
        self.cycles.user
    }

    /// User-mode PC as a raw u32 (set via `Risc0Context::set_user_pc`).
    fn get_pc(&self) -> u32 {
        self.user_pc.0
    }
623}
624
625impl From<EcallMetrics> for Vec<(String, EcallMetric)> {
626    fn from(metrics: EcallMetrics) -> Self {
627        metrics
628            .0
629            .into_iter()
630            .map(|(kind, metric)| (format!("{kind:?}"), metric))
631            .collect()
632    }
633}
634
635impl From<PageTraceEvent> for TraceEvent {
636    fn from(event: PageTraceEvent) -> Self {
637        match event {
638            PageTraceEvent::PageIn { cycles } => TraceEvent::PageIn {
639                cycles: cycles as u64,
640            },
641            PageTraceEvent::PageOut { cycles } => TraceEvent::PageOut {
642                cycles: cycles as u64,
643            },
644        }
645    }
646}