1use std::{cell::RefCell, collections::BTreeSet, rc::Rc};
16
17use anyhow::{bail, Result};
18use enum_map::EnumMap;
19use risc0_binfmt::{ByteAddr, MemoryImage, WordAddr};
20use risc0_zkp::core::{
21 digest::{Digest, DIGEST_BYTES},
22 log2_ceil,
23};
24
25use crate::{
26 trace::{TraceCallback, TraceEvent},
27 Rv32imV2Claim, TerminateState,
28};
29
30use super::{
31 bigint,
32 pager::{compute_partial_image, PageTraceEvent, PagedMemory},
33 platform::*,
34 poseidon2::Poseidon2State,
35 r0vm::{EcallKind, LoadOp, Risc0Context, Risc0Machine},
36 rv32im::{disasm, DecodedInstruction, Emulator, Instruction},
37 segment::Segment,
38 sha2::Sha2State,
39 syscall::Syscall,
40 SyscallContext,
41};
42
/// Per-kind ecall statistics gathered while executing a session.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct EcallMetric {
    /// Number of times an ecall of this kind was entered (bumped on the
    /// machine-ecall cycle).
    pub count: u64,
    /// User cycles attributed to this ecall kind.
    pub cycles: u64,
}
49
50#[derive(Default)]
51pub struct EcallMetrics(EnumMap<EcallKind, EcallMetric>);
52
53pub struct Executor<'a, 'b, S: Syscall> {
54 pc: ByteAddr,
55 user_pc: ByteAddr,
56 machine_mode: u32,
57 user_cycles: u32,
58 pager: PagedMemory,
59 terminate_state: Option<TerminateState>,
60 read_record: Vec<Vec<u8>>,
61 write_record: Vec<u32>,
62 syscall_handler: &'a S,
63 input_digest: Digest,
64 output_digest: Option<Digest>,
65 trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
66 cycles: SessionCycles,
67 ecall_metrics: EcallMetrics,
68}
69
70#[non_exhaustive]
71pub struct ExecutorResult {
72 pub segments: u64,
73 pub post_image: MemoryImage,
74 pub user_cycles: u64,
75 pub total_cycles: u64,
76 pub paging_cycles: u64,
77 pub reserved_cycles: u64,
78 pub claim: Rv32imV2Claim,
79}
80
81#[derive(Default)]
82struct SessionCycles {
83 total: u64,
84 user: u64,
85 paging: u64,
86 reserved: u64,
87}
88
89#[non_exhaustive]
90pub struct SimpleSession {
91 pub segments: Vec<Segment>,
92 pub result: ExecutorResult,
93}
94
95struct ComputePartialImageRequest {
96 image: MemoryImage,
97 page_indexes: BTreeSet<u32>,
98
99 input_digest: Digest,
100 output_digest: Option<Digest>,
101 read_record: Vec<Vec<u8>>,
102 write_record: Vec<u32>,
103 user_cycles: u32,
104 pager_cycles: u32,
105 terminate_state: Option<TerminateState>,
106 segment_threshold: u32,
107 pre_digest: Digest,
108 post_digest: Digest,
109 po2: u32,
110 index: u64,
111}
112
113const MAX_OUTSTANDING_SEGMENTS: usize = 5;
115
116fn compute_partial_images(
117 recv: std::sync::mpsc::Receiver<ComputePartialImageRequest>,
118 mut callback: impl FnMut(Segment) -> Result<()>,
119) -> Result<()> {
120 while let Ok(req) = recv.recv() {
121 let partial_image = compute_partial_image(req.image, req.page_indexes);
122 callback(Segment {
123 partial_image,
124 claim: Rv32imV2Claim {
125 pre_state: req.pre_digest,
126 post_state: req.post_digest,
127 input: req.input_digest,
128 output: req.output_digest,
129 terminate_state: req.terminate_state,
130 shutdown_cycle: None,
131 },
132 read_record: req.read_record,
133 write_record: req.write_record,
134 suspend_cycle: req.user_cycles,
135 paging_cycles: req.pager_cycles,
136 po2: req.po2,
137 index: req.index,
138 segment_threshold: req.segment_threshold,
139 })?;
140 }
141 Ok(())
142}
143
144impl<'a, 'b, S: Syscall> Executor<'a, 'b, S> {
145 pub fn new(
146 image: MemoryImage,
147 syscall_handler: &'a S,
148 input_digest: Option<Digest>,
149 trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
150 ) -> Self {
151 Self {
152 pc: ByteAddr(0),
153 user_pc: ByteAddr(0),
154 machine_mode: 0,
155 user_cycles: 0,
156 pager: PagedMemory::new(image, !trace.is_empty() ),
157 terminate_state: None,
158 read_record: Vec::new(),
159 write_record: Vec::new(),
160 syscall_handler,
161 input_digest: input_digest.unwrap_or_default(),
162 output_digest: None,
163 trace,
164 cycles: SessionCycles::default(),
165 ecall_metrics: EcallMetrics::default(),
166 }
167 }
168
169 pub fn run(
170 &mut self,
171 segment_po2: usize,
172 max_insn_cycles: usize,
173 max_cycles: Option<u64>,
174 callback: impl FnMut(Segment) -> Result<()> + Send,
175 ) -> Result<ExecutorResult> {
176 let segment_limit: u32 = 1 << segment_po2;
177 assert!(max_insn_cycles < segment_limit as usize);
178 let segment_threshold = segment_limit - max_insn_cycles as u32;
179 let mut segment_counter = 0;
180
181 self.reset();
182
183 let mut emu = Emulator::new();
184 Risc0Machine::resume(self)?;
185 let initial_digest = self.pager.image.image_id();
186 tracing::debug!("initial_digest: {initial_digest}");
187
188 let (commit_sender, commit_recv) =
189 std::sync::mpsc::sync_channel(MAX_OUTSTANDING_SEGMENTS - 1);
190
191 let post_digest = std::thread::scope(|scope| {
192 let partial_images_thread =
193 scope.spawn(move || compute_partial_images(commit_recv, callback));
194
195 while self.terminate_state.is_none() {
196 if let Some(max_cycles) = max_cycles {
197 if self.cycles.user >= max_cycles {
198 bail!(
199 "Session limit exceeded: {} >= {max_cycles}",
200 self.cycles.user
201 );
202 }
203 }
204
205 if self.segment_cycles() >= segment_threshold {
206 tracing::debug!(
207 "split(phys: {} + pager: {} + reserved: {LOOKUP_TABLE_CYCLES}) = {} >= {segment_threshold}",
208 self.user_cycles,
209 self.pager.cycles,
210 self.segment_cycles()
211 );
212
213 assert!(
214 self.segment_cycles() < segment_limit,
215 "segment limit ({segment_limit}) too small for instruction at pc: {:?}",
216 self.pc
217 );
218 Risc0Machine::suspend(self)?;
219
220 let (pre_image, pre_digest, post_digest) = self.pager.commit();
221
222 let req = ComputePartialImageRequest {
223 image: pre_image,
224 page_indexes: self.pager.page_indexes(),
225 input_digest: self.input_digest,
226 output_digest: self.output_digest,
227 read_record: std::mem::take(&mut self.read_record),
228 write_record: std::mem::take(&mut self.write_record),
229 user_cycles: self.user_cycles,
230 pager_cycles: self.pager.cycles,
231 terminate_state: self.terminate_state,
232 segment_threshold,
233 pre_digest,
234 post_digest,
235 po2: segment_po2 as u32,
236 index: segment_counter,
237 };
238 if commit_sender.send(req).is_err() {
239 return Err(partial_images_thread.join().unwrap().unwrap_err());
240 }
241
242 segment_counter += 1;
243 let total_cycles = 1 << segment_po2;
244 let pager_cycles = self.pager.cycles as u64;
245 let user_cycles = self.user_cycles as u64;
246 self.cycles.total += total_cycles;
247 self.cycles.paging += pager_cycles;
248 self.cycles.reserved += total_cycles - pager_cycles - user_cycles;
249 self.user_cycles = 0;
250 self.pager.reset();
251
252 Risc0Machine::resume(self)?;
253 }
254
255 Risc0Machine::step(&mut emu, self).map_err(|err| {
256 let result = self.dump_segment(segment_po2, segment_threshold, segment_counter);
257 if let Err(inner) = result {
258 err.context(inner)
259 } else {
260 err
261 }
262 })?;
263 }
264
265 Risc0Machine::suspend(self)?;
266
267 let final_cycles = self.segment_cycles().next_power_of_two();
268 let final_po2 = log2_ceil(final_cycles as usize);
269 let segment_threshold = (1 << final_po2) - max_insn_cycles as u32;
270 let (pre_image, pre_digest, post_digest) = self.pager.commit();
271 let req = ComputePartialImageRequest {
272 image: pre_image,
273 page_indexes: self.pager.page_indexes(),
274 input_digest: self.input_digest,
275 output_digest: self.output_digest,
276 read_record: std::mem::take(&mut self.read_record),
277 write_record: std::mem::take(&mut self.write_record),
278 user_cycles: self.user_cycles,
279 pager_cycles: self.pager.cycles,
280 terminate_state: self.terminate_state,
281 segment_threshold,
282 pre_digest,
283 post_digest,
284 po2: final_po2 as u32,
285 index: segment_counter,
286 };
287 if commit_sender.send(req).is_err() {
288 return Err(partial_images_thread.join().unwrap().unwrap_err());
289 }
290
291 let final_cycles = final_cycles as u64;
292 let user_cycles = self.user_cycles as u64;
293 let pager_cycles = self.pager.cycles as u64;
294 self.cycles.total += final_cycles;
295 self.cycles.paging += pager_cycles;
296 self.cycles.reserved += final_cycles - pager_cycles - user_cycles;
297
298 drop(commit_sender);
299
300 partial_images_thread.join().unwrap()?;
301
302 Ok(post_digest)
303 })?;
304
305 let session_claim = Rv32imV2Claim {
306 pre_state: initial_digest,
307 post_state: post_digest,
308 input: self.input_digest,
309 output: self.output_digest,
310 terminate_state: self.terminate_state,
311 shutdown_cycle: None,
312 };
313
314 Ok(ExecutorResult {
315 segments: segment_counter + 1,
316 post_image: self.pager.image.clone(),
317 user_cycles: self.cycles.user,
318 total_cycles: self.cycles.total,
319 paging_cycles: self.cycles.paging,
320 reserved_cycles: self.cycles.reserved,
321 claim: session_claim,
322 })
323 }
324
325 fn dump_segment(
326 &mut self,
327 po2: usize,
328 segment_threshold: u32,
329 index: u64,
330 ) -> anyhow::Result<()> {
331 if let Some(dump_path) = std::env::var_os("RISC0_DUMP_PATH") {
332 tracing::error!(
333 "Execution failure, saving segment to {}:",
334 dump_path.to_string_lossy()
335 );
336
337 let (image, pre_digest, post_digest) = self.pager.commit();
338 let page_indexes = self.pager.page_indexes();
339 let partial_image = compute_partial_image(image, page_indexes);
340
341 let segment = Segment {
342 partial_image,
343 claim: Rv32imV2Claim {
344 pre_state: pre_digest,
345 post_state: post_digest,
346 input: self.input_digest,
347 output: self.output_digest,
348 terminate_state: self.terminate_state,
349 shutdown_cycle: None,
350 },
351 read_record: std::mem::take(&mut self.read_record),
352 write_record: std::mem::take(&mut self.write_record),
353 suspend_cycle: self.user_cycles,
354 paging_cycles: self.pager.cycles,
355 po2: po2 as u32,
356 index,
357 segment_threshold,
358 };
359 tracing::error!("{segment:?}");
360
361 let bytes = segment.encode()?;
362 tracing::error!("serialized {} bytes", bytes.len());
363
364 std::fs::write(dump_path, bytes)?;
365 }
366 Ok(())
367 }
368
369 pub fn take_ecall_metrics(&mut self) -> EcallMetrics {
370 std::mem::take(&mut self.ecall_metrics)
371 }
372
373 fn reset(&mut self) {
374 self.pager.reset();
375 self.terminate_state = None;
376 self.read_record.clear();
377 self.write_record.clear();
378 self.output_digest = None;
379 self.machine_mode = 0;
380 self.user_cycles = 0;
381 self.cycles = SessionCycles::default();
382 self.pc = ByteAddr(0);
383 self.ecall_metrics = EcallMetrics::default();
384 }
385
386 fn segment_cycles(&self) -> u32 {
387 self.user_cycles + self.pager.cycles + LOOKUP_TABLE_CYCLES as u32
388 }
389
390 fn inc_user_cycles(&mut self, count: usize, ecall: Option<EcallKind>) {
391 self.cycles.user += count as u64;
392 self.user_cycles += count as u32;
393 if let Some(kind) = ecall {
394 self.ecall_metrics.0[kind].cycles += count as u64;
395 }
396 }
397
398 #[cold]
399 fn trace(&mut self, event: TraceEvent) -> Result<()> {
400 for trace in self.trace.iter() {
401 trace.borrow_mut().trace_callback(event.clone())?;
402 }
403 Ok(())
404 }
405
406 #[cold]
407 fn trace_pager(&mut self) -> Result<()> {
408 for &event in self.pager.trace_events() {
409 let event = TraceEvent::from(event);
410 for trace in self.trace.iter() {
411 trace.borrow_mut().trace_callback(event.clone())?;
412 }
413 }
414 self.pager.clear_trace_events();
415 Ok(())
416 }
417
418 #[cold]
419 fn trace_instruction(&self, cycle: u64, insn: &Instruction, decoded: &DecodedInstruction) {
420 tracing::trace!(
421 "[{}:{}:{cycle}] {:?}> {:#010x} {}",
422 self.user_cycles + 1,
423 self.segment_cycles() + 1,
424 self.pc,
425 decoded.insn,
426 disasm(insn, decoded)
427 );
428 }
429}
430
431impl<S: Syscall> Risc0Context for Executor<'_, '_, S> {
432 fn get_pc(&self) -> ByteAddr {
433 self.pc
434 }
435
436 fn set_pc(&mut self, addr: ByteAddr) {
437 self.pc = addr;
438 }
439
440 fn set_user_pc(&mut self, addr: ByteAddr) {
441 self.user_pc = addr;
442 }
443
444 fn get_machine_mode(&self) -> u32 {
445 self.machine_mode
446 }
447
448 fn set_machine_mode(&mut self, mode: u32) {
449 self.machine_mode = mode;
450 }
451
452 fn resume(&mut self) -> Result<()> {
453 let input_words = self.input_digest.as_words().to_vec();
454 for (i, word) in input_words.iter().enumerate() {
455 self.store_u32(GLOBAL_INPUT_ADDR.waddr() + i, *word)?;
456 }
457 Ok(())
458 }
459
460 fn on_insn_start(&mut self, insn: &Instruction, decoded: &DecodedInstruction) -> Result<()> {
461 let cycle = self.cycles.user;
462 if tracing::enabled!(tracing::Level::TRACE) {
463 self.trace_instruction(cycle, insn, decoded);
464 }
465 if !self.trace.is_empty() {
466 self.trace(TraceEvent::InstructionStart {
467 cycle,
468 pc: self.pc.0,
469 insn: decoded.insn,
470 })
471 } else {
472 Ok(())
473 }
474 }
475
476 fn on_insn_end(&mut self, _insn: &Instruction, _decoded: &DecodedInstruction) -> Result<()> {
477 self.inc_user_cycles(1, None);
478 if !self.trace.is_empty() {
479 self.trace_pager()?;
480 }
481 Ok(())
482 }
483
484 fn on_ecall_cycle(
485 &mut self,
486 cur: CycleState,
487 _next: CycleState,
488 _s0: u32,
489 _s1: u32,
490 _s2: u32,
491 kind: EcallKind,
492 ) -> Result<()> {
493 if cur == CycleState::MachineEcall {
494 self.ecall_metrics.0[kind].count += 1;
495 }
496 self.inc_user_cycles(1, Some(kind));
497 if !self.trace.is_empty() {
498 self.trace_pager()?;
499 }
500 Ok(())
501 }
502
503 fn load_u32(&mut self, op: LoadOp, addr: WordAddr) -> Result<u32> {
504 let word = match op {
505 LoadOp::Peek => self.pager.peek(addr)?,
506 LoadOp::Load | LoadOp::Record => self.pager.load(addr)?,
507 };
508 Ok(word)
510 }
511
512 fn load_register(&mut self, _op: LoadOp, base: WordAddr, idx: usize) -> Result<u32> {
513 let word = self.pager.load_register(base, idx);
514 Ok(word)
516 }
517
518 fn store_u32(&mut self, addr: WordAddr, word: u32) -> Result<()> {
519 if !self.trace.is_empty() {
525 self.trace(TraceEvent::MemorySet {
526 addr: addr.baddr().0,
527 region: word.to_be_bytes().to_vec(),
528 })?;
529 }
530 self.pager.store(addr, word)
531 }
532
533 fn store_register(&mut self, base: WordAddr, idx: usize, word: u32) -> Result<()> {
534 if !self.trace.is_empty() {
536 self.trace(TraceEvent::MemorySet {
537 addr: (base + idx).baddr().0,
538 region: word.to_be_bytes().to_vec(),
539 })?;
540 }
541 self.pager.store_register(base, idx, word);
542 Ok(())
543 }
544
545 fn on_terminate(&mut self, a0: u32, a1: u32) -> Result<()> {
546 self.terminate_state = Some(TerminateState {
547 a0: a0.into(),
548 a1: a1.into(),
549 });
550 tracing::debug!("{:?}", self.terminate_state);
551
552 let output: Digest = self
553 .load_region(LoadOp::Peek, GLOBAL_OUTPUT_ADDR, DIGEST_BYTES)?
554 .as_slice()
555 .try_into()?;
556 self.output_digest = Some(output);
557
558 Ok(())
559 }
560
561 fn host_read(&mut self, fd: u32, buf: &mut [u8]) -> Result<u32> {
562 let rlen = self.syscall_handler.host_read(self, fd, buf)?;
563 let slice = &buf[..rlen as usize];
564 self.read_record.push(slice.to_vec());
565 Ok(rlen)
566 }
567
568 fn host_write(&mut self, fd: u32, buf: &[u8]) -> Result<u32> {
569 let rlen = self.syscall_handler.host_write(self, fd, buf)?;
570 self.write_record.push(rlen);
571 Ok(rlen)
572 }
573
574 fn on_sha2_cycle(&mut self, _cur_state: CycleState, _sha2: &Sha2State) {
575 self.inc_user_cycles(1, Some(EcallKind::Sha2));
576 }
577
578 fn on_poseidon2_cycle(&mut self, _cur_state: CycleState, _p2: &Poseidon2State) {
579 self.inc_user_cycles(1, Some(EcallKind::Poseidon2));
580 }
581
582 fn ecall_bigint(&mut self) -> Result<()> {
583 let cycles = bigint::ecall_execute(self)?;
584 self.inc_user_cycles(cycles, Some(EcallKind::BigInt));
585 Ok(())
586 }
587}
588
589impl<S: Syscall> SyscallContext for Executor<'_, '_, S> {
590 fn peek_register(&mut self, idx: usize) -> Result<u32> {
591 if idx >= REG_MAX {
592 bail!("invalid register: x{idx}");
593 }
594 self.load_register(LoadOp::Peek, USER_REGS_ADDR.waddr(), idx)
595 }
596
597 fn peek_u32(&mut self, addr: ByteAddr) -> Result<u32> {
598 self.load_u32(LoadOp::Peek, addr.waddr())
600 }
601
602 fn peek_u8(&mut self, addr: ByteAddr) -> Result<u8> {
603 self.load_u8(LoadOp::Peek, addr)
605 }
606
607 fn peek_region(&mut self, addr: ByteAddr, size: usize) -> Result<Vec<u8>> {
608 self.load_region(LoadOp::Peek, addr, size)
610 }
611
612 fn peek_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
613 self.pager.peek_page(page_idx)
614 }
615
616 fn get_cycle(&self) -> u64 {
617 self.cycles.user
618 }
619
620 fn get_pc(&self) -> u32 {
621 self.user_pc.0
622 }
623}
624
625impl From<EcallMetrics> for Vec<(String, EcallMetric)> {
626 fn from(metrics: EcallMetrics) -> Self {
627 metrics
628 .0
629 .into_iter()
630 .map(|(kind, metric)| (format!("{kind:?}"), metric))
631 .collect()
632 }
633}
634
635impl From<PageTraceEvent> for TraceEvent {
636 fn from(event: PageTraceEvent) -> Self {
637 match event {
638 PageTraceEvent::PageIn { cycles } => TraceEvent::PageIn {
639 cycles: cycles as u64,
640 },
641 PageTraceEvent::PageOut { cycles } => TraceEvent::PageOut {
642 cycles: cycles as u64,
643 },
644 }
645 }
646}