risc0_zkvm/host/server/exec/
executor.rs1use std::{
16 cell::{Cell, RefCell},
17 rc::Rc,
18 sync::Arc,
19 time::Instant,
20};
21
22use anyhow::{bail, Context as _, Result};
23use risc0_binfmt::{
24 AbiKind, ByteAddr, ExitCode, MemoryImage, Program, ProgramBinary, ProgramBinaryHeader,
25 SystemState,
26};
27use risc0_circuit_rv32im::{
28 execute::{
29 platform::WORD_SIZE, Executor, Syscall as CircuitSyscall,
30 SyscallContext as CircuitSyscallContext, DEFAULT_SEGMENT_LIMIT_PO2,
31 },
32 MAX_INSN_CYCLES,
33};
34use risc0_core::scope;
35use risc0_zkp::core::digest::Digest;
36use risc0_zkvm_platform::{align_up, fileno};
37use tempfile::tempdir;
38
39use crate::{
40 host::{client::env::SegmentPath, server::session::Session},
41 receipt_claim::exit_code_from_rv32im_v2_claim,
42 Assumptions, ExecutorEnv, FileSegmentRef, Output, Segment, SegmentRef,
43};
44
45use super::{
46 profiler::{self, Profiler},
47 syscall::{SyscallContext, SyscallTable},
48 Journal,
49};
50
51pub struct ExecutorImpl<'a> {
55 env: ExecutorEnv<'a>,
56 image: MemoryImage,
57 pub(crate) syscall_table: SyscallTable<'a>,
58 profiler: Option<Rc<RefCell<Profiler>>>,
59 return_cache: Cell<(u32, u32)>,
60}
61
62fn check_program_version(header: &ProgramBinaryHeader) -> Result<()> {
64 let abi_kind = header.abi_kind;
65 let abi_version = &header.abi_version;
66
67 if abi_kind != AbiKind::V1Compat {
68 bail!("ProgramBinary abi_kind mismatch {abi_kind:?} != AbiKind::V1Compat");
69 }
70 if !semver::VersionReq::parse("^1.0.0")
71 .unwrap()
72 .matches(abi_version)
73 {
74 bail!("ProgramBinary abi_version mismatch {abi_version} doesn't match ^1.0.0");
75 }
76
77 Ok(())
78}
79
80impl<'a> ExecutorImpl<'a> {
81 #[allow(dead_code)]
89 pub fn new(env: ExecutorEnv<'a>, image: MemoryImage) -> Result<Self> {
90 Self::with_details(env, image, None)
91 }
92
93 pub fn from_elf(mut env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
97 let binary = ProgramBinary::decode(elf)?;
98 check_program_version(&binary.header)?;
99
100 let image = binary.to_image()?;
101
102 let profiler = if env.pprof_out.is_some() {
103 let profiler = Rc::new(RefCell::new(Profiler::new(
104 &binary,
105 None,
106 profiler::read_enable_inline_functions_env_var(),
107 )?));
108 env.trace.push(profiler.clone());
109 Some(profiler)
110 } else {
111 None
112 };
113
114 Self::with_details(env, image, profiler)
115 }
116
117 #[allow(dead_code)]
119 pub(crate) fn from_kernel_elf(env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
120 let kernel = Program::load_elf(elf, u32::MAX)?;
121 let image = MemoryImage::new_kernel(kernel);
122 Self::with_details(env, image, None)
123 }
124
125 fn with_details(
126 env: ExecutorEnv<'a>,
127 image: MemoryImage,
128 profiler: Option<Rc<RefCell<Profiler>>>,
129 ) -> Result<Self> {
130 let syscall_table = SyscallTable::from_env(&env);
131 Ok(Self {
132 env,
133 image,
134 syscall_table,
135 profiler,
136 return_cache: Cell::new((0, 0)),
137 })
138 }
139
140 pub fn run(&mut self) -> Result<Session> {
143 if self.env.segment_path.is_none() {
144 self.env.segment_path = Some(SegmentPath::TempDir(Arc::new(tempdir()?)));
145 }
146
147 let path = self.env.segment_path.clone().unwrap();
148 self.run_with_callback(|segment| Ok(Box::new(FileSegmentRef::new(&segment, &path)?)))
149 }
150
151 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<Session>
154 where
155 F: FnMut(Segment) -> Result<Box<dyn SegmentRef>> + Send,
156 {
157 scope!("execute");
158
159 let journal = Journal::default();
160 self.env
161 .posix_io
162 .borrow_mut()
163 .with_write_fd(fileno::JOURNAL, journal.clone());
164
165 let segment_limit_po2 = self
166 .env
167 .segment_limit_po2
168 .unwrap_or(DEFAULT_SEGMENT_LIMIT_PO2 as u32) as usize;
169
170 let mut refs = Vec::new();
171 let mut exec = Executor::new(
172 self.image.clone(),
173 self,
174 self.env.input_digest,
175 self.env.trace.clone(),
176 );
177
178 let start_time = Instant::now();
179 let result = exec.run(
180 segment_limit_po2,
181 MAX_INSN_CYCLES,
182 self.env.session_limit,
183 |inner| {
184 let output = inner
185 .claim
186 .terminate_state
187 .is_some()
188 .then(|| -> Option<Result<_>> {
189 inner
190 .claim
191 .output
192 .and_then(|digest| {
193 (digest != Digest::ZERO)
194 .then(|| journal.buf.lock().unwrap().clone())
195 })
196 .map(|journal| {
197 Ok(Output {
198 journal: journal.into(),
199 assumptions: Assumptions(
200 self.syscall_table
201 .assumptions_used
202 .lock()
203 .unwrap()
204 .iter()
205 .map(|(a, _)| a.clone().into())
206 .collect::<Vec<_>>(),
207 )
208 .into(),
209 })
210 })
211 })
212 .flatten()
213 .transpose()?;
214
215 let segment = Segment {
216 index: inner.index as u32,
217 inner,
218 output,
219 };
220 let segment_ref = callback(segment)?;
221 refs.push(segment_ref);
222 Ok(())
223 },
224 )?;
225 let elapsed = start_time.elapsed();
226
227 tracing::debug!("output_digest: {:?}", result.claim.output);
228
229 let exit_code = exit_code_from_rv32im_v2_claim(&result.claim)?;
230
231 let session_journal = result.claim.output.and_then(|digest| {
233 (digest != Digest::ZERO).then(|| std::mem::take(&mut *journal.buf.lock().unwrap()))
234 });
235 if !exit_code.expects_output() && session_journal.is_some() {
236 tracing::debug!(
237 "dropping non-empty journal due to exit code {exit_code:?}: 0x{}",
238 hex::encode(journal.buf.lock().unwrap().as_slice())
239 );
240 };
241
242 let ecall_metrics = exec.take_ecall_metrics();
243
244 let assumptions = std::mem::take(&mut *self.syscall_table.assumptions_used.lock().unwrap());
247 let mmr_assumptions = self.syscall_table.mmr_assumptions.take();
248 let pending_zkrs = self.syscall_table.pending_zkrs.take();
249 let pending_keccaks = self.syscall_table.pending_keccaks.take();
250
251 if let Some(profiler) = self.profiler.take() {
252 let report = profiler.borrow_mut().finalize_to_vec();
253 std::fs::write(self.env.pprof_out.as_ref().unwrap(), report)?;
254 }
255
256 self.image = result.post_image.clone();
257 let syscall_metrics = self.syscall_table.metrics.borrow().clone();
258
259 let post_digest = match exit_code {
261 ExitCode::Halted(_) => Digest::ZERO,
262 _ => result.claim.post_state,
263 };
264
265 let session = Session {
266 segments: refs,
267 input: self.env.input_digest.unwrap_or_default(),
268 journal: session_journal.map(crate::Journal::new),
269 exit_code,
270 assumptions,
271 mmr_assumptions,
272 user_cycles: result.user_cycles,
273 paging_cycles: result.paging_cycles,
274 reserved_cycles: result.reserved_cycles,
275 total_cycles: result.total_cycles,
276 pre_state: SystemState {
277 pc: 0,
278 merkle_root: result.claim.pre_state,
279 },
280 post_state: SystemState {
281 pc: 0,
282 merkle_root: post_digest,
283 },
284 pending_zkrs,
285 pending_keccaks,
286 syscall_metrics,
287 hooks: vec![],
288 ecall_metrics: ecall_metrics.into(),
289 };
290
291 tracing::info!("execution time: {elapsed:?}");
292 session.log();
293
294 assert_eq!(
295 session.total_cycles,
296 session.user_cycles + session.paging_cycles + session.reserved_cycles
297 );
298
299 Ok(session)
300 }
301}
302
303struct ContextAdapter<'a, 'b> {
304 ctx: &'b mut dyn CircuitSyscallContext,
305 syscall_table: SyscallTable<'a>,
306}
307
308impl<'a> SyscallContext<'a> for ContextAdapter<'a, '_> {
309 fn get_pc(&self) -> u32 {
310 self.ctx.get_pc()
311 }
312
313 fn get_cycle(&self) -> u64 {
314 self.ctx.get_cycle()
315 }
316
317 fn load_register(&mut self, idx: usize) -> u32 {
318 self.ctx.peek_register(idx).unwrap()
319 }
320
321 fn load_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
322 self.ctx.peek_page(page_idx)
323 }
324
325 fn load_u8(&mut self, addr: ByteAddr) -> Result<u8> {
326 self.ctx.peek_u8(addr)
327 }
328
329 fn load_u32(&mut self, addr: ByteAddr) -> Result<u32> {
330 self.ctx.peek_u32(addr)
331 }
332
333 fn syscall_table(&self) -> &SyscallTable<'a> {
334 &self.syscall_table
335 }
336}
337
338impl CircuitSyscall for ExecutorImpl<'_> {
339 fn host_read(
340 &self,
341 ctx: &mut dyn CircuitSyscallContext,
342 fd: u32,
343 buf: &mut [u8],
344 ) -> Result<u32> {
345 if fd == 0 {
346 let (a0, a1) = self.return_cache.get();
347 tracing::trace!("host_read(buf: {}) -> ({a0:#010x}, {a1:#010x})", buf.len());
348 let buf: &mut [u32] = bytemuck::cast_slice_mut(buf);
349 (buf[0], buf[1]) = (a0, a1);
350 return Ok(2 * WORD_SIZE as u32);
351 }
352
353 let mut ctx = ContextAdapter {
354 ctx,
355 syscall_table: self.syscall_table.clone(),
356 };
357
358 let name_ptr = ByteAddr(fd);
359 let syscall = ctx.peek_string(name_ptr)?;
360 tracing::trace!("host_read({syscall}, into_guest: {})", buf.len());
361
362 let words = align_up(buf.len(), WORD_SIZE) / WORD_SIZE;
363 let mut to_guest = vec![0u32; words];
364
365 self.return_cache.set(
366 self.syscall_table
367 .get_syscall(&syscall)
368 .context(format!("Unknown syscall: {syscall:?}"))?
369 .borrow_mut()
370 .syscall(&syscall, &mut ctx, &mut to_guest)?,
371 );
372
373 let bytes = bytemuck::cast_slice(to_guest.as_slice());
374 let rlen = buf.len();
375 buf.copy_from_slice(&bytes[..rlen]);
376
377 Ok(rlen as u32)
378 }
379
380 fn host_write(&self, ctx: &mut dyn CircuitSyscallContext, _fd: u32, buf: &[u8]) -> Result<u32> {
381 let str = String::from_utf8(buf.to_vec())?;
382 tracing::debug!("R0VM[{}] {str}", ctx.get_cycle());
383 Ok(buf.len() as u32)
384 }
385}
386
387impl ContextAdapter<'_, '_> {
388 fn peek_string(&mut self, mut addr: ByteAddr) -> Result<String> {
389 tracing::trace!("peek_string: {addr:?}");
390 let mut buf = Vec::new();
391 loop {
392 let bytes = self.ctx.peek_u8(addr)?;
393 if bytes == 0 {
394 break;
395 }
396 buf.push(bytes);
397 addr += 1u32;
398 }
399 Ok(String::from_utf8(buf)?)
400 }
401}