risc0_zkvm/host/server/exec/
executor.rs1use std::{
16 cell::{Cell, RefCell},
17 rc::Rc,
18 sync::Arc,
19 time::Instant,
20};
21
22use anyhow::{bail, Context as _, Result};
23use risc0_binfmt::{
24 AbiKind, ByteAddr, ExitCode, MemoryImage, Program, ProgramBinary, ProgramBinaryHeader,
25 SystemState,
26};
27use risc0_circuit_rv32im::{
28 execute::{
29 platform::WORD_SIZE, CycleLimit, Executor, Syscall as CircuitSyscall,
30 SyscallContext as CircuitSyscallContext, DEFAULT_SEGMENT_LIMIT_PO2,
31 },
32 MAX_INSN_CYCLES, MAX_INSN_CYCLES_LOWER_PO2,
33};
34use risc0_core::scope;
35use risc0_zkp::core::digest::Digest;
36use risc0_zkvm_platform::{align_up, fileno};
37use tempfile::tempdir;
38
39use crate::{
40 claim::receipt::exit_code_from_terminate_state,
41 host::{client::env::SegmentPath, server::session::Session},
42 Assumptions, ExecutorEnv, FileSegmentRef, Output, Segment, SegmentRef,
43};
44
45use super::{
46 profiler::{self, Profiler},
47 syscall::{SyscallContext, SyscallTable},
48 Journal,
49};
50
51pub struct ExecutorImpl<'a> {
55 pub(crate) env: ExecutorEnv<'a>,
56 pub(crate) image: MemoryImage,
57 pub(crate) syscall_table: SyscallTable<'a>,
58 pub(crate) elf: Option<Vec<u8>>,
59 profiler: Option<Rc<RefCell<Profiler>>>,
60 return_cache: Cell<(u32, u32)>,
61}
62
63fn check_program_version(header: &ProgramBinaryHeader) -> Result<()> {
65 let abi_kind = header.abi_kind;
66 let abi_version = &header.abi_version;
67
68 if abi_kind != AbiKind::V1Compat {
69 bail!("ProgramBinary abi_kind mismatch {abi_kind:?} != AbiKind::V1Compat");
70 }
71 if !semver::VersionReq::parse("^1.0.0")
72 .unwrap()
73 .matches(abi_version)
74 {
75 bail!("ProgramBinary abi_version mismatch {abi_version} doesn't match ^1.0.0");
76 }
77
78 Ok(())
79}
80
81impl<'a> ExecutorImpl<'a> {
82 pub fn new(env: ExecutorEnv<'a>, image: MemoryImage) -> Result<Self> {
90 Self::with_details(env, None, image, None)
91 }
92
93 pub fn from_elf(mut env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
97 let binary = ProgramBinary::decode(elf)?;
98 check_program_version(&binary.header)?;
99
100 let image = binary.to_image()?;
101
102 let profiler = if env.pprof_out.is_some() {
103 let profiler = Rc::new(RefCell::new(Profiler::new(
104 &binary,
105 None,
106 profiler::read_enable_inline_functions_env_var(),
107 )?));
108 env.trace.push(profiler.clone());
109 Some(profiler)
110 } else {
111 None
112 };
113
114 Self::with_details(env, Some(binary.user_elf), image, profiler)
115 }
116
117 #[allow(dead_code)]
119 pub(crate) fn from_kernel_elf(env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
120 let kernel = Program::load_elf(elf, u32::MAX)?;
121 let image = MemoryImage::new_kernel(kernel);
122 Self::with_details(env, Some(elf), image, None)
123 }
124
125 fn with_details(
126 env: ExecutorEnv<'a>,
127 elf: Option<&[u8]>,
128 image: MemoryImage,
129 profiler: Option<Rc<RefCell<Profiler>>>,
130 ) -> Result<Self> {
131 let syscall_table = SyscallTable::from_env(&env);
132 Ok(Self {
133 env,
134 elf: elf.map(|e| e.to_owned()),
135 image,
136 syscall_table,
137 profiler,
138 return_cache: Cell::new((0, 0)),
139 })
140 }
141
142 pub fn run(&mut self) -> Result<Session> {
145 if self.env.segment_path.is_none() {
146 self.env.segment_path = Some(SegmentPath::TempDir(Arc::new(tempdir()?)));
147 }
148
149 let path = self.env.segment_path.clone().unwrap();
150 self.run_with_callback(|segment| Ok(Box::new(FileSegmentRef::new(&segment, &path)?)))
151 }
152
153 pub fn run_with_debugger(&mut self) -> Result<()> {
155 let debugger = super::gdb::GdbExecutor::new(self)?;
156 eprintln!(
157 "connect gdb by running `riscv32im-gdb -ex \"target remote {}\" {}`",
158 debugger.local_addr()?,
159 debugger.elf_path().display()
160 );
161
162 debugger.run()
163 }
164
165 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<Session>
168 where
169 F: FnMut(Segment) -> Result<Box<dyn SegmentRef>> + Send,
170 {
171 scope!("execute");
172
173 let journal = Journal::default();
174 self.env
175 .posix_io
176 .borrow_mut()
177 .with_write_fd(fileno::JOURNAL, journal.clone());
178
179 let segment_limit_po2 = self
180 .env
181 .segment_limit_po2
182 .unwrap_or(DEFAULT_SEGMENT_LIMIT_PO2 as u32) as usize;
183
184 let session_limit = match self.env.session_limit {
185 Some(limit) => CycleLimit::Hard(limit),
186 None => CycleLimit::None,
187 };
188
189 let mut refs = Vec::new();
190 let mut exec = Executor::new(
191 self.image.clone(),
192 self,
193 self.env.input_digest,
194 self.env.trace.clone(),
195 self.env.povw_job_id,
196 );
197
198 let max_insn_cycles = if segment_limit_po2 >= 15 {
199 MAX_INSN_CYCLES
200 } else {
201 MAX_INSN_CYCLES_LOWER_PO2
202 };
203
204 let start_time = Instant::now();
205 let result = exec.run(segment_limit_po2, max_insn_cycles, session_limit, |inner| {
206 let output = inner
207 .claim
208 .terminate_state
209 .is_some()
210 .then(|| -> Option<Result<_>> {
211 inner
212 .claim
213 .output
214 .and_then(|digest| {
215 (digest != Digest::ZERO).then(|| journal.buf.lock().unwrap().clone())
216 })
217 .map(|journal| {
218 Ok(Output {
219 journal: journal.into(),
220 assumptions: Assumptions(
221 self.syscall_table
222 .assumptions_used
223 .lock()
224 .unwrap()
225 .iter()
226 .map(|(a, _)| a.clone().into())
227 .collect::<Vec<_>>(),
228 )
229 .into(),
230 })
231 })
232 })
233 .flatten()
234 .transpose()?;
235
236 let segment = Segment {
237 index: inner.index as u32,
238 inner,
239 output,
240 };
241 let segment_ref = callback(segment)?;
242 refs.push(segment_ref);
243 Ok(())
244 })?;
245 let elapsed = start_time.elapsed();
246
247 tracing::debug!("output_digest: {:?}", result.claim.output);
248
249 let exit_code = exit_code_from_terminate_state(&result.claim.terminate_state)?;
250
251 let session_journal = result.claim.output.and_then(|digest| {
253 (digest != Digest::ZERO).then(|| std::mem::take(&mut *journal.buf.lock().unwrap()))
254 });
255 if !exit_code.expects_output() && session_journal.is_some() {
256 tracing::debug!(
257 "dropping non-empty journal due to exit code {exit_code:?}: 0x{}",
258 hex::encode(journal.buf.lock().unwrap().as_slice())
259 );
260 };
261
262 let ecall_metrics = exec.take_ecall_metrics();
263
264 let assumptions = std::mem::take(&mut *self.syscall_table.assumptions_used.lock().unwrap());
267 let mmr_assumptions = self.syscall_table.mmr_assumptions.take();
268 let pending_keccaks = self.syscall_table.pending_keccaks.take();
269
270 if let Some(profiler) = self.profiler.take() {
271 let report = profiler.borrow_mut().finalize_to_vec();
272 std::fs::write(self.env.pprof_out.as_ref().unwrap(), report)?;
273 }
274
275 self.image = result.post_image.clone();
276 let syscall_metrics = self.syscall_table.metrics.borrow().clone();
277
278 let post_digest = match exit_code {
280 ExitCode::Halted(_) => Digest::ZERO,
281 _ => result.claim.post_state,
282 };
283
284 let session = Session {
285 segments: refs,
286 input: self.env.input_digest.unwrap_or_default(),
287 journal: session_journal.map(crate::Journal::new),
288 exit_code,
289 assumptions,
290 mmr_assumptions,
291 user_cycles: result.user_cycles,
292 paging_cycles: result.paging_cycles,
293 reserved_cycles: result.reserved_cycles,
294 total_cycles: result.total_cycles,
295 pre_state: SystemState {
296 pc: 0,
297 merkle_root: result.claim.pre_state,
298 },
299 post_state: SystemState {
300 pc: 0,
301 merkle_root: post_digest,
302 },
303 pending_keccaks,
304 syscall_metrics,
305 hooks: vec![],
306 ecall_metrics: ecall_metrics.into(),
307 povw_job_id: self.env.povw_job_id,
308 };
309
310 tracing::info!("execution time: {elapsed:?}");
311 session.log();
312
313 assert_eq!(
314 session.total_cycles,
315 session.user_cycles + session.paging_cycles + session.reserved_cycles
316 );
317
318 Ok(session)
319 }
320}
321
322struct ContextAdapter<'a, 'b> {
323 ctx: &'b mut dyn CircuitSyscallContext,
324 syscall_table: SyscallTable<'a>,
325}
326
327impl<'a> SyscallContext<'a> for ContextAdapter<'a, '_> {
328 fn get_pc(&self) -> u32 {
329 self.ctx.get_pc()
330 }
331
332 fn get_cycle(&self) -> u64 {
333 self.ctx.get_cycle()
334 }
335
336 fn load_register(&mut self, idx: usize) -> u32 {
337 self.ctx.peek_register(idx).unwrap()
338 }
339
340 fn load_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
341 self.ctx.peek_page(page_idx)
342 }
343
344 fn load_u8(&mut self, addr: ByteAddr) -> Result<u8> {
345 self.ctx.peek_u8(addr)
346 }
347
348 fn load_u32(&mut self, addr: ByteAddr) -> Result<u32> {
349 self.ctx.peek_u32(addr)
350 }
351
352 fn syscall_table(&self) -> &SyscallTable<'a> {
353 &self.syscall_table
354 }
355}
356
357impl CircuitSyscall for ExecutorImpl<'_> {
358 fn host_read(
359 &self,
360 ctx: &mut dyn CircuitSyscallContext,
361 fd: u32,
362 buf: &mut [u8],
363 ) -> Result<u32> {
364 if fd == 0 {
365 let (a0, a1) = self.return_cache.get();
366 tracing::trace!("host_read(buf: {}) -> ({a0:#010x}, {a1:#010x})", buf.len());
367 let buf: &mut [u32] = bytemuck::cast_slice_mut(buf);
368 (buf[0], buf[1]) = (a0, a1);
369 return Ok(2 * WORD_SIZE as u32);
370 }
371
372 let mut ctx = ContextAdapter {
373 ctx,
374 syscall_table: self.syscall_table.clone(),
375 };
376
377 let name_ptr = ByteAddr(fd);
378 let syscall = ctx.peek_string(name_ptr)?;
379 tracing::trace!("host_read({syscall}, into_guest: {})", buf.len());
380
381 let words = align_up(buf.len(), WORD_SIZE) / WORD_SIZE;
382 let mut to_guest = vec![0u32; words];
383
384 self.return_cache.set(
385 self.syscall_table
386 .get_syscall(&syscall)
387 .context(format!("Unknown syscall: {syscall:?}"))?
388 .borrow_mut()
389 .syscall(&syscall, &mut ctx, &mut to_guest)?,
390 );
391
392 let bytes = bytemuck::cast_slice(to_guest.as_slice());
393 let rlen = buf.len();
394 buf.copy_from_slice(&bytes[..rlen]);
395
396 Ok(rlen as u32)
397 }
398
399 fn host_write(&self, ctx: &mut dyn CircuitSyscallContext, _fd: u32, buf: &[u8]) -> Result<u32> {
400 let str = String::from_utf8(buf.to_vec())?;
401 tracing::debug!("R0VM[{}] {str}", ctx.get_cycle());
402 Ok(buf.len() as u32)
403 }
404}
405
406impl ContextAdapter<'_, '_> {
407 fn peek_string(&mut self, mut addr: ByteAddr) -> Result<String> {
408 tracing::trace!("peek_string: {addr:?}");
409 let mut buf = Vec::new();
410 loop {
411 let bytes = self.ctx.peek_u8(addr)?;
412 if bytes == 0 {
413 break;
414 }
415 buf.push(bytes);
416 addr += 1u32;
417 }
418 Ok(String::from_utf8(buf)?)
419 }
420}