1use std::{cell::RefCell, collections::BTreeSet, rc::Rc};
16
17use anyhow::{bail, Result};
18use enum_map::EnumMap;
19use risc0_binfmt::{ByteAddr, MemoryImage, WordAddr};
20use risc0_zkp::core::{
21 digest::{Digest, DIGEST_BYTES},
22 log2_ceil,
23};
24
25use crate::{
26 trace::{TraceCallback, TraceEvent},
27 Rv32imV2Claim, TerminateState,
28};
29
30use super::{
31 bigint,
32 pager::{compute_partial_image, PageTraceEvent, PagedMemory},
33 platform::*,
34 poseidon2::Poseidon2State,
35 r0vm::{EcallKind, LoadOp, Risc0Context, Risc0Machine},
36 rv32im::{disasm, DecodedInstruction, Emulator, Instruction},
37 segment::Segment,
38 sha2::Sha2State,
39 syscall::Syscall,
40 SyscallContext,
41};
42
/// Execution statistics for one kind of ecall, collected by [`Executor`].
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct EcallMetric {
    /// Number of invocations, counted once per ecall on its
    /// `CycleState::MachineEcall` cycle.
    pub count: u64,
    /// User cycles attributed to this kind of ecall.
    pub cycles: u64,
}
49
50#[derive(Default)]
51pub struct EcallMetrics(EnumMap<EcallKind, EcallMetric>);
52
/// Drives a [`Risc0Machine`] over a paged memory image and splits the
/// execution into proof-sized [`Segment`]s.
pub struct Executor<'a, 'b, S: Syscall> {
    /// Current program counter (read and written through `Risc0Context`).
    pc: ByteAddr,
    /// Last user-mode PC reported by the machine; exposed to syscall handlers
    /// via `SyscallContext::get_pc`.
    user_pc: ByteAddr,
    /// Current machine mode as set by the machine.
    machine_mode: u32,
    /// User cycles executed in the current segment; zeroed on every split.
    user_cycles: u32,
    /// Paged memory; also tracks the paging cycles of the current segment.
    pager: PagedMemory,
    /// Set by `on_terminate`; execution stops once this is `Some`.
    terminate_state: Option<TerminateState>,
    /// Bytes returned by each `host_read` in the current segment.
    read_record: Vec<Vec<u8>>,
    /// Lengths returned by each `host_write` in the current segment.
    write_record: Vec<u32>,
    /// Host-side syscall implementation.
    syscall_handler: &'a S,
    /// Digest written to `GLOBAL_INPUT_ADDR` whenever execution resumes.
    input_digest: Digest,
    /// Digest read from `GLOBAL_OUTPUT_ADDR` on termination.
    output_digest: Option<Digest>,
    /// Trace callbacks; when empty, all tracing work is skipped.
    trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
    /// Cycle totals for the whole session.
    cycles: SessionCycles,
    /// Per-ecall-kind counters for the whole session.
    ecall_metrics: EcallMetrics,
}
69
/// Summary returned by [`Executor::run`].
#[non_exhaustive]
pub struct ExecutorResult {
    /// Number of segments produced, including the final one.
    pub segments: u64,
    /// Memory image at the end of execution.
    pub post_image: MemoryImage,
    /// User cycles executed over the whole session.
    pub user_cycles: u64,
    /// Sum of segment sizes: `2^segment_po2` for every split segment plus the
    /// final segment's cycle count rounded up to a power of two.
    pub total_cycles: u64,
    /// Paging cycles summed over all segments.
    pub paging_cycles: u64,
    /// Per-segment cycles that were neither user nor paging cycles, summed.
    pub reserved_cycles: u64,
    /// Claim for the whole session: initial image id to final post-state.
    pub claim: Rv32imV2Claim,
}
80
/// Session-wide cycle accumulators; reset at the start of every run.
#[derive(Default)]
struct SessionCycles {
    /// Sum of (padded) segment sizes.
    total: u64,
    /// User cycles, incremented as they execute.
    user: u64,
    /// Paging cycles, added per committed segment.
    paging: u64,
    /// Segment cycles not used for user work or paging.
    reserved: u64,
}
88
/// Segments of an execution together with its [`ExecutorResult`].
///
/// NOTE(review): not constructed in this part of the file; presumably built
/// by a caller that collects every segment passed to `Executor::run`'s
/// callback — confirm.
#[non_exhaustive]
pub struct SimpleSession {
    /// Segments produced by the run.
    pub segments: Vec<Segment>,
    /// Summary of the run.
    pub result: ExecutorResult,
}
94
/// Everything needed to turn one committed segment into a [`Segment`].
///
/// Sent from the executor to the `compute_partial_images` worker, which
/// builds the partial image off the execution thread.
struct ComputePartialImageRequest {
    /// Pre-state image returned by `PagedMemory::commit`.
    image: MemoryImage,
    /// Pages of `image` to include in the partial image.
    page_indexes: BTreeSet<u32>,

    /// Session input digest.
    input_digest: Digest,
    /// Output digest, if the guest has terminated.
    output_digest: Option<Digest>,
    /// Data returned by `host_read` during this segment.
    read_record: Vec<Vec<u8>>,
    /// Lengths returned by `host_write` during this segment.
    write_record: Vec<u32>,
    /// User cycles executed in this segment (becomes `suspend_cycle`).
    user_cycles: u32,
    /// Paging cycles of this segment.
    pager_cycles: u32,
    /// Termination state, set only for the final segment.
    terminate_state: Option<TerminateState>,
    /// Split threshold used for this segment (0 for the final segment).
    segment_threshold: u32,
    /// Digest before the segment.
    pre_digest: Digest,
    /// Digest after the segment.
    post_digest: Digest,
    /// Log2 of the segment size.
    po2: u32,
    /// Zero-based segment index.
    index: u64,
}
112
/// Bound on segments handed to the partial-image worker but not yet consumed.
///
/// The hand-off channel buffers `MAX_OUTSTANDING_SEGMENTS - 1` requests; with
/// the one the worker is currently processing, execution stalls rather than
/// queueing more than this many.
const MAX_OUTSTANDING_SEGMENTS: usize = 5;
115
116fn compute_partial_images(
117 recv: std::sync::mpsc::Receiver<ComputePartialImageRequest>,
118 mut callback: impl FnMut(Segment) -> Result<()>,
119) -> Result<()> {
120 while let Ok(req) = recv.recv() {
121 let partial_image = compute_partial_image(req.image, req.page_indexes);
122 callback(Segment {
123 partial_image,
124 claim: Rv32imV2Claim {
125 pre_state: req.pre_digest,
126 post_state: req.post_digest,
127 input: req.input_digest,
128 output: req.output_digest,
129 terminate_state: req.terminate_state,
130 shutdown_cycle: None,
131 },
132 read_record: req.read_record,
133 write_record: req.write_record,
134 suspend_cycle: req.user_cycles,
135 paging_cycles: req.pager_cycles,
136 po2: req.po2,
137 index: req.index,
138 segment_threshold: req.segment_threshold,
139 })?;
140 }
141 Ok(())
142}
143
144impl<'a, 'b, S: Syscall> Executor<'a, 'b, S> {
145 pub fn new(
146 image: MemoryImage,
147 syscall_handler: &'a S,
148 input_digest: Option<Digest>,
149 trace: Vec<Rc<RefCell<dyn TraceCallback + 'b>>>,
150 ) -> Self {
151 Self {
152 pc: ByteAddr(0),
153 user_pc: ByteAddr(0),
154 machine_mode: 0,
155 user_cycles: 0,
156 pager: PagedMemory::new(image, !trace.is_empty() ),
157 terminate_state: None,
158 read_record: Vec::new(),
159 write_record: Vec::new(),
160 syscall_handler,
161 input_digest: input_digest.unwrap_or_default(),
162 output_digest: None,
163 trace,
164 cycles: SessionCycles::default(),
165 ecall_metrics: EcallMetrics::default(),
166 }
167 }
168
169 pub fn run(
170 &mut self,
171 segment_po2: usize,
172 max_insn_cycles: usize,
173 max_cycles: Option<u64>,
174 callback: impl FnMut(Segment) -> Result<()> + Send,
175 ) -> Result<ExecutorResult> {
176 let segment_limit: u32 = 1 << segment_po2;
177 assert!(max_insn_cycles < segment_limit as usize);
178 let segment_threshold = segment_limit - max_insn_cycles as u32;
179 let mut segment_counter = 0;
180
181 self.reset();
182
183 let mut emu = Emulator::new();
184 Risc0Machine::resume(self)?;
185 let initial_digest = self.pager.image.image_id();
186 tracing::debug!("initial_digest: {initial_digest}");
187
188 let (commit_sender, commit_recv) =
189 std::sync::mpsc::sync_channel(MAX_OUTSTANDING_SEGMENTS - 1);
190
191 let post_digest = std::thread::scope(|scope| {
192 let partial_images_thread =
193 scope.spawn(move || compute_partial_images(commit_recv, callback));
194
195 while self.terminate_state.is_none() {
196 if let Some(max_cycles) = max_cycles {
197 if self.cycles.user >= max_cycles {
198 bail!(
199 "Session limit exceeded: {} >= {max_cycles}",
200 self.cycles.user
201 );
202 }
203 }
204
205 if self.segment_cycles() > segment_threshold {
206 tracing::debug!(
207 "split(phys: {} + pager: {} + reserved: {RESERVED_CYCLES}) = {} >= {segment_threshold}",
208 self.user_cycles,
209 self.pager.cycles,
210 self.segment_cycles()
211 );
212
213 assert!(
214 self.segment_cycles() < segment_limit,
215 "segment limit ({segment_limit}) too small for instruction at pc: {:?}",
216 self.pc
217 );
218 Risc0Machine::suspend(self)?;
219
220 let (pre_image, pre_digest, post_digest) = self.pager.commit();
221
222 let req = ComputePartialImageRequest {
223 image: pre_image,
224 page_indexes: self.pager.page_indexes(),
225 input_digest: self.input_digest,
226 output_digest: self.output_digest,
227 read_record: std::mem::take(&mut self.read_record),
228 write_record: std::mem::take(&mut self.write_record),
229 user_cycles: self.user_cycles,
230 pager_cycles: self.pager.cycles,
231 terminate_state: self.terminate_state,
232 segment_threshold,
233 pre_digest,
234 post_digest,
235 po2: segment_po2 as u32,
236 index: segment_counter,
237 };
238 if commit_sender.send(req).is_err() {
239 return Err(partial_images_thread.join().unwrap().unwrap_err());
240 }
241
242 segment_counter += 1;
243 let total_cycles = 1 << segment_po2;
244 let pager_cycles = self.pager.cycles as u64;
245 let user_cycles = self.user_cycles as u64;
246 self.cycles.total += total_cycles;
247 self.cycles.paging += pager_cycles;
248 self.cycles.reserved += total_cycles - pager_cycles - user_cycles;
249 self.user_cycles = 0;
250 self.pager.reset();
251
252 Risc0Machine::resume(self)?;
253 }
254
255 Risc0Machine::step(&mut emu, self).map_err(|err| {
256 let result = self.dump_segment(segment_po2, segment_threshold, segment_counter);
257 if let Err(inner) = result {
258 err.context(inner)
259 } else {
260 err
261 }
262 })?;
263 }
264
265 Risc0Machine::suspend(self)?;
266
267 let final_cycles = self.segment_cycles().next_power_of_two();
268 let final_po2 = log2_ceil(final_cycles as usize);
269 let (pre_image, pre_digest, post_digest) = self.pager.commit();
270 let req = ComputePartialImageRequest {
271 image: pre_image,
272 page_indexes: self.pager.page_indexes(),
273 input_digest: self.input_digest,
274 output_digest: self.output_digest,
275 read_record: std::mem::take(&mut self.read_record),
276 write_record: std::mem::take(&mut self.write_record),
277 user_cycles: self.user_cycles,
278 pager_cycles: self.pager.cycles,
279 terminate_state: self.terminate_state,
280 segment_threshold: 0, pre_digest,
282 post_digest,
283 po2: final_po2 as u32,
284 index: segment_counter,
285 };
286 if commit_sender.send(req).is_err() {
287 return Err(partial_images_thread.join().unwrap().unwrap_err());
288 }
289
290 let final_cycles = final_cycles as u64;
291 let user_cycles = self.user_cycles as u64;
292 let pager_cycles = self.pager.cycles as u64;
293 self.cycles.total += final_cycles;
294 self.cycles.paging += pager_cycles;
295 self.cycles.reserved += final_cycles - pager_cycles - user_cycles;
296
297 drop(commit_sender);
298
299 partial_images_thread.join().unwrap()?;
300
301 Ok(post_digest)
302 })?;
303
304 let session_claim = Rv32imV2Claim {
305 pre_state: initial_digest,
306 post_state: post_digest,
307 input: self.input_digest,
308 output: self.output_digest,
309 terminate_state: self.terminate_state,
310 shutdown_cycle: None,
311 };
312
313 Ok(ExecutorResult {
314 segments: segment_counter + 1,
315 post_image: self.pager.image.clone(),
316 user_cycles: self.cycles.user,
317 total_cycles: self.cycles.total,
318 paging_cycles: self.cycles.paging,
319 reserved_cycles: self.cycles.reserved,
320 claim: session_claim,
321 })
322 }
323
324 fn dump_segment(
325 &mut self,
326 po2: usize,
327 segment_threshold: u32,
328 index: u64,
329 ) -> anyhow::Result<()> {
330 if let Some(dump_path) = std::env::var_os("RISC0_DUMP_PATH") {
331 tracing::error!(
332 "Execution failure, saving segment to {}:",
333 dump_path.to_string_lossy()
334 );
335
336 let (image, pre_digest, post_digest) = self.pager.commit();
337 let page_indexes = self.pager.page_indexes();
338 let partial_image = compute_partial_image(image, page_indexes);
339
340 let segment = Segment {
341 partial_image,
342 claim: Rv32imV2Claim {
343 pre_state: pre_digest,
344 post_state: post_digest,
345 input: self.input_digest,
346 output: self.output_digest,
347 terminate_state: self.terminate_state,
348 shutdown_cycle: None,
349 },
350 read_record: std::mem::take(&mut self.read_record),
351 write_record: std::mem::take(&mut self.write_record),
352 suspend_cycle: self.user_cycles,
353 paging_cycles: self.pager.cycles,
354 po2: po2 as u32,
355 index,
356 segment_threshold,
357 };
358 tracing::error!("{segment:?}");
359
360 let bytes = segment.encode()?;
361 tracing::error!("serialized {} bytes", bytes.len());
362
363 std::fs::write(dump_path, bytes)?;
364 }
365 Ok(())
366 }
367
368 pub fn take_ecall_metrics(&mut self) -> EcallMetrics {
369 std::mem::take(&mut self.ecall_metrics)
370 }
371
372 fn reset(&mut self) {
373 self.pager.reset();
374 self.terminate_state = None;
375 self.read_record.clear();
376 self.write_record.clear();
377 self.output_digest = None;
378 self.machine_mode = 0;
379 self.user_cycles = 0;
380 self.cycles = SessionCycles::default();
381 self.pc = ByteAddr(0);
382 self.ecall_metrics = EcallMetrics::default();
383 }
384
385 fn segment_cycles(&self) -> u32 {
386 self.user_cycles + self.pager.cycles + RESERVED_CYCLES as u32
387 }
388
389 fn inc_user_cycles(&mut self, count: usize, ecall: Option<EcallKind>) {
390 self.cycles.user += count as u64;
391 self.user_cycles += count as u32;
392 if let Some(kind) = ecall {
393 self.ecall_metrics.0[kind].cycles += count as u64;
394 }
395 }
396
397 #[cold]
398 fn trace(&mut self, event: TraceEvent) -> Result<()> {
399 for trace in self.trace.iter() {
400 trace.borrow_mut().trace_callback(event.clone())?;
401 }
402 Ok(())
403 }
404
405 #[cold]
406 fn trace_pager(&mut self) -> Result<()> {
407 for &event in self.pager.trace_events() {
408 let event = TraceEvent::from(event);
409 for trace in self.trace.iter() {
410 trace.borrow_mut().trace_callback(event.clone())?;
411 }
412 }
413 self.pager.clear_trace_events();
414 Ok(())
415 }
416
417 #[cold]
418 fn trace_instruction(&self, cycle: u64, insn: &Instruction, decoded: &DecodedInstruction) {
419 tracing::trace!(
420 "[{}:{}:{cycle}] {:?}> {:#010x} {}",
421 self.user_cycles + 1,
422 self.segment_cycles() + 1,
423 self.pc,
424 decoded.insn,
425 disasm(insn, decoded)
426 );
427 }
428}
429
430impl<S: Syscall> Risc0Context for Executor<'_, '_, S> {
431 fn get_pc(&self) -> ByteAddr {
432 self.pc
433 }
434
435 fn set_pc(&mut self, addr: ByteAddr) {
436 self.pc = addr;
437 }
438
439 fn set_user_pc(&mut self, addr: ByteAddr) {
440 self.user_pc = addr;
441 }
442
443 fn get_machine_mode(&self) -> u32 {
444 self.machine_mode
445 }
446
447 fn set_machine_mode(&mut self, mode: u32) {
448 self.machine_mode = mode;
449 }
450
451 fn resume(&mut self) -> Result<()> {
452 let input_words = self.input_digest.as_words().to_vec();
453 for (i, word) in input_words.iter().enumerate() {
454 self.store_u32(GLOBAL_INPUT_ADDR.waddr() + i, *word)?;
455 }
456 Ok(())
457 }
458
459 fn on_insn_start(&mut self, insn: &Instruction, decoded: &DecodedInstruction) -> Result<()> {
460 let cycle = self.cycles.user;
461 if tracing::enabled!(tracing::Level::TRACE) {
462 self.trace_instruction(cycle, insn, decoded);
463 }
464 if !self.trace.is_empty() {
465 self.trace(TraceEvent::InstructionStart {
466 cycle,
467 pc: self.pc.0,
468 insn: decoded.insn,
469 })
470 } else {
471 Ok(())
472 }
473 }
474
475 fn on_insn_end(&mut self, _insn: &Instruction, _decoded: &DecodedInstruction) -> Result<()> {
476 self.inc_user_cycles(1, None);
477 if !self.trace.is_empty() {
478 self.trace_pager()?;
479 }
480 Ok(())
481 }
482
483 fn on_ecall_cycle(
484 &mut self,
485 cur: CycleState,
486 _next: CycleState,
487 _s0: u32,
488 _s1: u32,
489 _s2: u32,
490 kind: EcallKind,
491 ) -> Result<()> {
492 if cur == CycleState::MachineEcall {
493 self.ecall_metrics.0[kind].count += 1;
494 }
495 self.inc_user_cycles(1, Some(kind));
496 if !self.trace.is_empty() {
497 self.trace_pager()?;
498 }
499 Ok(())
500 }
501
502 fn load_u32(&mut self, op: LoadOp, addr: WordAddr) -> Result<u32> {
503 let word = match op {
504 LoadOp::Peek => self.pager.peek(addr)?,
505 LoadOp::Load | LoadOp::Record => self.pager.load(addr)?,
506 };
507 Ok(word)
509 }
510
511 fn load_register(&mut self, _op: LoadOp, base: WordAddr, idx: usize) -> Result<u32> {
512 let word = self.pager.load_register(base, idx);
513 Ok(word)
515 }
516
517 fn store_u32(&mut self, addr: WordAddr, word: u32) -> Result<()> {
518 if !self.trace.is_empty() {
524 self.trace(TraceEvent::MemorySet {
525 addr: addr.baddr().0,
526 region: word.to_be_bytes().to_vec(),
527 })?;
528 }
529 self.pager.store(addr, word)
530 }
531
532 fn store_register(&mut self, base: WordAddr, idx: usize, word: u32) -> Result<()> {
533 if !self.trace.is_empty() {
535 self.trace(TraceEvent::MemorySet {
536 addr: (base + idx).baddr().0,
537 region: word.to_be_bytes().to_vec(),
538 })?;
539 }
540 self.pager.store_register(base, idx, word);
541 Ok(())
542 }
543
544 fn on_terminate(&mut self, a0: u32, a1: u32) -> Result<()> {
545 self.terminate_state = Some(TerminateState {
546 a0: a0.into(),
547 a1: a1.into(),
548 });
549 tracing::debug!("{:?}", self.terminate_state);
550
551 let output: Digest = self
552 .load_region(LoadOp::Peek, GLOBAL_OUTPUT_ADDR, DIGEST_BYTES)?
553 .as_slice()
554 .try_into()?;
555 self.output_digest = Some(output);
556
557 Ok(())
558 }
559
560 fn host_read(&mut self, fd: u32, buf: &mut [u8]) -> Result<u32> {
561 let rlen = self.syscall_handler.host_read(self, fd, buf)?;
562 let slice = &buf[..rlen as usize];
563 self.read_record.push(slice.to_vec());
564 Ok(rlen)
565 }
566
567 fn host_write(&mut self, fd: u32, buf: &[u8]) -> Result<u32> {
568 let rlen = self.syscall_handler.host_write(self, fd, buf)?;
569 self.write_record.push(rlen);
570 Ok(rlen)
571 }
572
573 fn on_sha2_cycle(&mut self, _cur_state: CycleState, _sha2: &Sha2State) {
574 self.inc_user_cycles(1, Some(EcallKind::Sha2));
575 }
576
577 fn on_poseidon2_cycle(&mut self, _cur_state: CycleState, _p2: &Poseidon2State) {
578 self.inc_user_cycles(1, Some(EcallKind::Poseidon2));
579 }
580
581 fn ecall_bigint(&mut self) -> Result<()> {
582 let cycles = bigint::ecall_execute(self)?;
583 self.inc_user_cycles(cycles, Some(EcallKind::BigInt));
584 Ok(())
585 }
586}
587
588impl<S: Syscall> SyscallContext for Executor<'_, '_, S> {
589 fn peek_register(&mut self, idx: usize) -> Result<u32> {
590 if idx >= REG_MAX {
591 bail!("invalid register: x{idx}");
592 }
593 self.load_register(LoadOp::Peek, USER_REGS_ADDR.waddr(), idx)
594 }
595
596 fn peek_u32(&mut self, addr: ByteAddr) -> Result<u32> {
597 self.load_u32(LoadOp::Peek, addr.waddr())
599 }
600
601 fn peek_u8(&mut self, addr: ByteAddr) -> Result<u8> {
602 self.load_u8(LoadOp::Peek, addr)
604 }
605
606 fn peek_region(&mut self, addr: ByteAddr, size: usize) -> Result<Vec<u8>> {
607 self.load_region(LoadOp::Peek, addr, size)
609 }
610
611 fn peek_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
612 self.pager.peek_page(page_idx)
613 }
614
615 fn get_cycle(&self) -> u64 {
616 self.cycles.user
617 }
618
619 fn get_pc(&self) -> u32 {
620 self.user_pc.0
621 }
622}
623
624impl From<EcallMetrics> for Vec<(String, EcallMetric)> {
625 fn from(metrics: EcallMetrics) -> Self {
626 metrics
627 .0
628 .into_iter()
629 .map(|(kind, metric)| (format!("{kind:?}"), metric))
630 .collect()
631 }
632}
633
634impl From<PageTraceEvent> for TraceEvent {
635 fn from(event: PageTraceEvent) -> Self {
636 match event {
637 PageTraceEvent::PageIn { cycles } => TraceEvent::PageIn {
638 cycles: cycles as u64,
639 },
640 PageTraceEvent::PageOut { cycles } => TraceEvent::PageOut {
641 cycles: cycles as u64,
642 },
643 }
644 }
645}