risc0_zkvm/host/server/exec/
executor.rs1use std::{
16 cell::{Cell, RefCell},
17 rc::Rc,
18 sync::Arc,
19 time::Instant,
20};
21
22use anyhow::{bail, Context as _, Result};
23use risc0_binfmt::{
24 AbiKind, ByteAddr, ExitCode, MemoryImage, Program, ProgramBinary, ProgramBinaryHeader,
25 SystemState,
26};
27use risc0_circuit_rv32im::{
28 execute::{
29 platform::WORD_SIZE, Executor, Syscall as CircuitSyscall,
30 SyscallContext as CircuitSyscallContext, DEFAULT_SEGMENT_LIMIT_PO2,
31 },
32 MAX_INSN_CYCLES, MAX_INSN_CYCLES_LOWER_PO2,
33};
34use risc0_core::scope;
35use risc0_zkp::core::digest::Digest;
36use risc0_zkvm_platform::{align_up, fileno};
37use tempfile::tempdir;
38
39use crate::{
40 host::{client::env::SegmentPath, server::session::Session},
41 receipt_claim::exit_code_from_rv32im_v2_claim,
42 Assumptions, ExecutorEnv, FileSegmentRef, Output, Segment, SegmentRef,
43};
44
45use super::{
46 profiler::{self, Profiler},
47 syscall::{SyscallContext, SyscallTable},
48 Journal,
49};
50
51pub struct ExecutorImpl<'a> {
55 env: ExecutorEnv<'a>,
56 image: MemoryImage,
57 pub(crate) syscall_table: SyscallTable<'a>,
58 profiler: Option<Rc<RefCell<Profiler>>>,
59 return_cache: Cell<(u32, u32)>,
60}
61
62fn check_program_version(header: &ProgramBinaryHeader) -> Result<()> {
64 let abi_kind = header.abi_kind;
65 let abi_version = &header.abi_version;
66
67 if abi_kind != AbiKind::V1Compat {
68 bail!("ProgramBinary abi_kind mismatch {abi_kind:?} != AbiKind::V1Compat");
69 }
70 if !semver::VersionReq::parse("^1.0.0")
71 .unwrap()
72 .matches(abi_version)
73 {
74 bail!("ProgramBinary abi_version mismatch {abi_version} doesn't match ^1.0.0");
75 }
76
77 Ok(())
78}
79
80impl<'a> ExecutorImpl<'a> {
81 #[allow(dead_code)]
89 pub fn new(env: ExecutorEnv<'a>, image: MemoryImage) -> Result<Self> {
90 Self::with_details(env, image, None)
91 }
92
93 pub fn from_elf(mut env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
97 let binary = ProgramBinary::decode(elf)?;
98 check_program_version(&binary.header)?;
99
100 let image = binary.to_image()?;
101
102 let profiler = if env.pprof_out.is_some() {
103 let profiler = Rc::new(RefCell::new(Profiler::new(
104 &binary,
105 None,
106 profiler::read_enable_inline_functions_env_var(),
107 )?));
108 env.trace.push(profiler.clone());
109 Some(profiler)
110 } else {
111 None
112 };
113
114 Self::with_details(env, image, profiler)
115 }
116
117 #[allow(dead_code)]
119 pub(crate) fn from_kernel_elf(env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
120 let kernel = Program::load_elf(elf, u32::MAX)?;
121 let image = MemoryImage::new_kernel(kernel);
122 Self::with_details(env, image, None)
123 }
124
125 fn with_details(
126 env: ExecutorEnv<'a>,
127 image: MemoryImage,
128 profiler: Option<Rc<RefCell<Profiler>>>,
129 ) -> Result<Self> {
130 let syscall_table = SyscallTable::from_env(&env);
131 Ok(Self {
132 env,
133 image,
134 syscall_table,
135 profiler,
136 return_cache: Cell::new((0, 0)),
137 })
138 }
139
140 pub fn run(&mut self) -> Result<Session> {
143 if self.env.segment_path.is_none() {
144 self.env.segment_path = Some(SegmentPath::TempDir(Arc::new(tempdir()?)));
145 }
146
147 let path = self.env.segment_path.clone().unwrap();
148 self.run_with_callback(|segment| Ok(Box::new(FileSegmentRef::new(&segment, &path)?)))
149 }
150
151 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<Session>
154 where
155 F: FnMut(Segment) -> Result<Box<dyn SegmentRef>> + Send,
156 {
157 scope!("execute");
158
159 let journal = Journal::default();
160 self.env
161 .posix_io
162 .borrow_mut()
163 .with_write_fd(fileno::JOURNAL, journal.clone());
164
165 let segment_limit_po2 = self
166 .env
167 .segment_limit_po2
168 .unwrap_or(DEFAULT_SEGMENT_LIMIT_PO2 as u32) as usize;
169
170 let mut refs = Vec::new();
171 let mut exec = Executor::new(
172 self.image.clone(),
173 self,
174 self.env.input_digest,
175 self.env.trace.clone(),
176 );
177
178 let max_insn_cycles = if segment_limit_po2 >= 15 {
179 MAX_INSN_CYCLES
180 } else {
181 MAX_INSN_CYCLES_LOWER_PO2
182 };
183
184 let start_time = Instant::now();
185 let result = exec.run(
186 segment_limit_po2,
187 max_insn_cycles,
188 self.env.session_limit,
189 |inner| {
190 let output = inner
191 .claim
192 .terminate_state
193 .is_some()
194 .then(|| -> Option<Result<_>> {
195 inner
196 .claim
197 .output
198 .and_then(|digest| {
199 (digest != Digest::ZERO)
200 .then(|| journal.buf.lock().unwrap().clone())
201 })
202 .map(|journal| {
203 Ok(Output {
204 journal: journal.into(),
205 assumptions: Assumptions(
206 self.syscall_table
207 .assumptions_used
208 .lock()
209 .unwrap()
210 .iter()
211 .map(|(a, _)| a.clone().into())
212 .collect::<Vec<_>>(),
213 )
214 .into(),
215 })
216 })
217 })
218 .flatten()
219 .transpose()?;
220
221 let segment = Segment {
222 index: inner.index as u32,
223 inner,
224 output,
225 };
226 let segment_ref = callback(segment)?;
227 refs.push(segment_ref);
228 Ok(())
229 },
230 )?;
231 let elapsed = start_time.elapsed();
232
233 tracing::debug!("output_digest: {:?}", result.claim.output);
234
235 let exit_code = exit_code_from_rv32im_v2_claim(&result.claim)?;
236
237 let session_journal = result.claim.output.and_then(|digest| {
239 (digest != Digest::ZERO).then(|| std::mem::take(&mut *journal.buf.lock().unwrap()))
240 });
241 if !exit_code.expects_output() && session_journal.is_some() {
242 tracing::debug!(
243 "dropping non-empty journal due to exit code {exit_code:?}: 0x{}",
244 hex::encode(journal.buf.lock().unwrap().as_slice())
245 );
246 };
247
248 let ecall_metrics = exec.take_ecall_metrics();
249
250 let assumptions = std::mem::take(&mut *self.syscall_table.assumptions_used.lock().unwrap());
253 let mmr_assumptions = self.syscall_table.mmr_assumptions.take();
254 let pending_zkrs = self.syscall_table.pending_zkrs.take();
255 let pending_keccaks = self.syscall_table.pending_keccaks.take();
256
257 if let Some(profiler) = self.profiler.take() {
258 let report = profiler.borrow_mut().finalize_to_vec();
259 std::fs::write(self.env.pprof_out.as_ref().unwrap(), report)?;
260 }
261
262 self.image = result.post_image.clone();
263 let syscall_metrics = self.syscall_table.metrics.borrow().clone();
264
265 let post_digest = match exit_code {
267 ExitCode::Halted(_) => Digest::ZERO,
268 _ => result.claim.post_state,
269 };
270
271 let session = Session {
272 segments: refs,
273 input: self.env.input_digest.unwrap_or_default(),
274 journal: session_journal.map(crate::Journal::new),
275 exit_code,
276 assumptions,
277 mmr_assumptions,
278 user_cycles: result.user_cycles,
279 paging_cycles: result.paging_cycles,
280 reserved_cycles: result.reserved_cycles,
281 total_cycles: result.total_cycles,
282 pre_state: SystemState {
283 pc: 0,
284 merkle_root: result.claim.pre_state,
285 },
286 post_state: SystemState {
287 pc: 0,
288 merkle_root: post_digest,
289 },
290 pending_zkrs,
291 pending_keccaks,
292 syscall_metrics,
293 hooks: vec![],
294 ecall_metrics: ecall_metrics.into(),
295 };
296
297 tracing::info!("execution time: {elapsed:?}");
298 session.log();
299
300 assert_eq!(
301 session.total_cycles,
302 session.user_cycles + session.paging_cycles + session.reserved_cycles
303 );
304
305 Ok(session)
306 }
307}
308
309struct ContextAdapter<'a, 'b> {
310 ctx: &'b mut dyn CircuitSyscallContext,
311 syscall_table: SyscallTable<'a>,
312}
313
314impl<'a> SyscallContext<'a> for ContextAdapter<'a, '_> {
315 fn get_pc(&self) -> u32 {
316 self.ctx.get_pc()
317 }
318
319 fn get_cycle(&self) -> u64 {
320 self.ctx.get_cycle()
321 }
322
323 fn load_register(&mut self, idx: usize) -> u32 {
324 self.ctx.peek_register(idx).unwrap()
325 }
326
327 fn load_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
328 self.ctx.peek_page(page_idx)
329 }
330
331 fn load_u8(&mut self, addr: ByteAddr) -> Result<u8> {
332 self.ctx.peek_u8(addr)
333 }
334
335 fn load_u32(&mut self, addr: ByteAddr) -> Result<u32> {
336 self.ctx.peek_u32(addr)
337 }
338
339 fn syscall_table(&self) -> &SyscallTable<'a> {
340 &self.syscall_table
341 }
342}
343
344impl CircuitSyscall for ExecutorImpl<'_> {
345 fn host_read(
346 &self,
347 ctx: &mut dyn CircuitSyscallContext,
348 fd: u32,
349 buf: &mut [u8],
350 ) -> Result<u32> {
351 if fd == 0 {
352 let (a0, a1) = self.return_cache.get();
353 tracing::trace!("host_read(buf: {}) -> ({a0:#010x}, {a1:#010x})", buf.len());
354 let buf: &mut [u32] = bytemuck::cast_slice_mut(buf);
355 (buf[0], buf[1]) = (a0, a1);
356 return Ok(2 * WORD_SIZE as u32);
357 }
358
359 let mut ctx = ContextAdapter {
360 ctx,
361 syscall_table: self.syscall_table.clone(),
362 };
363
364 let name_ptr = ByteAddr(fd);
365 let syscall = ctx.peek_string(name_ptr)?;
366 tracing::trace!("host_read({syscall}, into_guest: {})", buf.len());
367
368 let words = align_up(buf.len(), WORD_SIZE) / WORD_SIZE;
369 let mut to_guest = vec![0u32; words];
370
371 self.return_cache.set(
372 self.syscall_table
373 .get_syscall(&syscall)
374 .context(format!("Unknown syscall: {syscall:?}"))?
375 .borrow_mut()
376 .syscall(&syscall, &mut ctx, &mut to_guest)?,
377 );
378
379 let bytes = bytemuck::cast_slice(to_guest.as_slice());
380 let rlen = buf.len();
381 buf.copy_from_slice(&bytes[..rlen]);
382
383 Ok(rlen as u32)
384 }
385
386 fn host_write(&self, ctx: &mut dyn CircuitSyscallContext, _fd: u32, buf: &[u8]) -> Result<u32> {
387 let str = String::from_utf8(buf.to_vec())?;
388 tracing::debug!("R0VM[{}] {str}", ctx.get_cycle());
389 Ok(buf.len() as u32)
390 }
391}
392
393impl ContextAdapter<'_, '_> {
394 fn peek_string(&mut self, mut addr: ByteAddr) -> Result<String> {
395 tracing::trace!("peek_string: {addr:?}");
396 let mut buf = Vec::new();
397 loop {
398 let bytes = self.ctx.peek_u8(addr)?;
399 if bytes == 0 {
400 break;
401 }
402 buf.push(bytes);
403 addr += 1u32;
404 }
405 Ok(String::from_utf8(buf)?)
406 }
407}