risc0_zkvm/host/server/exec/
executor.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cell::{Cell, RefCell},
17    rc::Rc,
18    sync::Arc,
19    time::Instant,
20};
21
22use anyhow::{bail, Context as _, Result};
23use risc0_binfmt::{
24    AbiKind, ByteAddr, ExitCode, MemoryImage, Program, ProgramBinary, ProgramBinaryHeader,
25    SystemState,
26};
27use risc0_circuit_rv32im::{
28    execute::{
29        platform::WORD_SIZE, CycleLimit, Executor, Syscall as CircuitSyscall,
30        SyscallContext as CircuitSyscallContext, DEFAULT_SEGMENT_LIMIT_PO2,
31    },
32    MAX_INSN_CYCLES, MAX_INSN_CYCLES_LOWER_PO2,
33};
34use risc0_core::scope;
35use risc0_zkp::core::digest::Digest;
36use risc0_zkvm_platform::{align_up, fileno};
37use tempfile::tempdir;
38
39use crate::{
40    claim::receipt::exit_code_from_terminate_state,
41    host::{client::env::SegmentPath, server::session::Session},
42    Assumptions, ExecutorEnv, FileSegmentRef, Output, Segment, SegmentRef,
43};
44
45use super::{
46    profiler::{self, Profiler},
47    syscall::{SyscallContext, SyscallTable},
48    Journal,
49};
50
51// The Executor provides an implementation for the execution phase.
52///
53/// The proving phase uses an execution trace generated by the Executor.
54pub struct ExecutorImpl<'a> {
55    pub(crate) env: ExecutorEnv<'a>,
56    pub(crate) image: MemoryImage,
57    pub(crate) syscall_table: SyscallTable<'a>,
58    pub(crate) elf: Option<Vec<u8>>,
59    profiler: Option<Rc<RefCell<Profiler>>>,
60    return_cache: Cell<(u32, u32)>,
61}
62
63/// Check to see if the executor is compatible with the given guest program.
64fn check_program_version(header: &ProgramBinaryHeader) -> Result<()> {
65    let abi_kind = header.abi_kind;
66    let abi_version = &header.abi_version;
67
68    if abi_kind != AbiKind::V1Compat {
69        bail!("ProgramBinary abi_kind mismatch {abi_kind:?} != AbiKind::V1Compat");
70    }
71    if !semver::VersionReq::parse("^1.0.0")
72        .unwrap()
73        .matches(abi_version)
74    {
75        bail!("ProgramBinary abi_version mismatch {abi_version} doesn't match ^1.0.0");
76    }
77
78    Ok(())
79}
80
81impl<'a> ExecutorImpl<'a> {
82    /// Construct a new [ExecutorImpl] from a [MemoryImage] and entry point.
83    ///
84    /// Before a guest program is proven, the [ExecutorImpl] is responsible for
85    /// deciding where a zkVM program should be split into [Segment]s and what
86    /// work will be done in each segment. This is the execution phase:
87    /// the guest program is executed to determine how its proof should be
88    /// divided into subparts.
89    pub fn new(env: ExecutorEnv<'a>, image: MemoryImage) -> Result<Self> {
90        Self::with_details(env, None, image, None)
91    }
92
93    /// Construct a new [ExecutorImpl] from the ELF binary of the guest program
94    /// you want to run and an [ExecutorEnv] containing relevant
95    /// environmental configuration details.
96    pub fn from_elf(mut env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
97        let binary = ProgramBinary::decode(elf)?;
98        check_program_version(&binary.header)?;
99
100        let image = binary.to_image()?;
101
102        let profiler = if env.pprof_out.is_some() {
103            let profiler = Rc::new(RefCell::new(Profiler::new(
104                &binary,
105                None,
106                profiler::read_enable_inline_functions_env_var(),
107            )?));
108            env.trace.push(profiler.clone());
109            Some(profiler)
110        } else {
111            None
112        };
113
114        Self::with_details(env, Some(binary.user_elf), image, profiler)
115    }
116
117    /// TODO(flaub)
118    #[allow(dead_code)]
119    pub(crate) fn from_kernel_elf(env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
120        let kernel = Program::load_elf(elf, u32::MAX)?;
121        let image = MemoryImage::new_kernel(kernel);
122        Self::with_details(env, Some(elf), image, None)
123    }
124
125    fn with_details(
126        env: ExecutorEnv<'a>,
127        elf: Option<&[u8]>,
128        image: MemoryImage,
129        profiler: Option<Rc<RefCell<Profiler>>>,
130    ) -> Result<Self> {
131        let syscall_table = SyscallTable::from_env(&env);
132        Ok(Self {
133            env,
134            elf: elf.map(|e| e.to_owned()),
135            image,
136            syscall_table,
137            profiler,
138            return_cache: Cell::new((0, 0)),
139        })
140    }
141
142    /// This will run the executor to get a [Session] which contain the results
143    /// of the execution.
144    pub fn run(&mut self) -> Result<Session> {
145        if self.env.segment_path.is_none() {
146            self.env.segment_path = Some(SegmentPath::TempDir(Arc::new(tempdir()?)));
147        }
148
149        let path = self.env.segment_path.clone().unwrap();
150        self.run_with_callback(|segment| Ok(Box::new(FileSegmentRef::new(&segment, &path)?)))
151    }
152
153    /// This will run the executor with a gdb server so gdb can be attached.
154    pub fn run_with_debugger(&mut self) -> Result<()> {
155        let debugger = super::gdb::GdbExecutor::new(self)?;
156        eprintln!(
157            "connect gdb by running `riscv32im-gdb -ex \"target remote {}\" {}`",
158            debugger.local_addr()?,
159            debugger.elf_path().display()
160        );
161
162        debugger.run()
163    }
164
165    /// Run the executor until [crate::ExitCode::Halted] or
166    /// [crate::ExitCode::Paused] is reached, producing a [Session] as a result.
167    pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<Session>
168    where
169        F: FnMut(Segment) -> Result<Box<dyn SegmentRef>> + Send,
170    {
171        scope!("execute");
172
173        let journal = Journal::default();
174        self.env
175            .posix_io
176            .borrow_mut()
177            .with_write_fd(fileno::JOURNAL, journal.clone());
178
179        let segment_limit_po2 = self
180            .env
181            .segment_limit_po2
182            .unwrap_or(DEFAULT_SEGMENT_LIMIT_PO2 as u32) as usize;
183
184        let session_limit = match self.env.session_limit {
185            Some(limit) => CycleLimit::Hard(limit),
186            None => CycleLimit::None,
187        };
188
189        let mut refs = Vec::new();
190        let mut exec = Executor::new(
191            self.image.clone(),
192            self,
193            self.env.input_digest,
194            self.env.trace.clone(),
195            self.env.povw_job_id,
196        );
197
198        let max_insn_cycles = if segment_limit_po2 >= 15 {
199            MAX_INSN_CYCLES
200        } else {
201            MAX_INSN_CYCLES_LOWER_PO2
202        };
203
204        let start_time = Instant::now();
205        let result = exec.run(segment_limit_po2, max_insn_cycles, session_limit, |inner| {
206            let output = inner
207                .claim
208                .terminate_state
209                .is_some()
210                .then(|| -> Option<Result<_>> {
211                    inner
212                        .claim
213                        .output
214                        .and_then(|digest| {
215                            (digest != Digest::ZERO).then(|| journal.buf.lock().unwrap().clone())
216                        })
217                        .map(|journal| {
218                            Ok(Output {
219                                journal: journal.into(),
220                                assumptions: Assumptions(
221                                    self.syscall_table
222                                        .assumptions_used
223                                        .lock()
224                                        .unwrap()
225                                        .iter()
226                                        .map(|(a, _)| a.clone().into())
227                                        .collect::<Vec<_>>(),
228                                )
229                                .into(),
230                            })
231                        })
232                })
233                .flatten()
234                .transpose()?;
235
236            let segment = Segment {
237                index: inner.index as u32,
238                inner,
239                output,
240            };
241            let segment_ref = callback(segment)?;
242            refs.push(segment_ref);
243            Ok(())
244        })?;
245        let elapsed = start_time.elapsed();
246
247        tracing::debug!("output_digest: {:?}", result.claim.output);
248
249        let exit_code = exit_code_from_terminate_state(&result.claim.terminate_state)?;
250
251        // Set the session_journal to the committed data iff the guest set a non-zero output.
252        let session_journal = result.claim.output.and_then(|digest| {
253            (digest != Digest::ZERO).then(|| std::mem::take(&mut *journal.buf.lock().unwrap()))
254        });
255        if !exit_code.expects_output() && session_journal.is_some() {
256            tracing::debug!(
257                "dropping non-empty journal due to exit code {exit_code:?}: 0x{}",
258                hex::encode(journal.buf.lock().unwrap().as_slice())
259            );
260        };
261
262        let ecall_metrics = exec.take_ecall_metrics();
263
264        // Take (clear out) the list of accessed assumptions.
265        // Leave the assumptions cache so it can be used if execution is resumed from pause.
266        let assumptions = std::mem::take(&mut *self.syscall_table.assumptions_used.lock().unwrap());
267        let mmr_assumptions = self.syscall_table.mmr_assumptions.take();
268        let pending_keccaks = self.syscall_table.pending_keccaks.take();
269
270        if let Some(profiler) = self.profiler.take() {
271            let report = profiler.borrow_mut().finalize_to_vec();
272            std::fs::write(self.env.pprof_out.as_ref().unwrap(), report)?;
273        }
274
275        self.image = result.post_image.clone();
276        let syscall_metrics = self.syscall_table.metrics.borrow().clone();
277
278        // NOTE: When a segment ends in a Halted(_) state, the post_digest will be null.
279        let post_digest = match exit_code {
280            ExitCode::Halted(_) => Digest::ZERO,
281            _ => result.claim.post_state,
282        };
283
284        let session = Session {
285            segments: refs,
286            input: self.env.input_digest.unwrap_or_default(),
287            journal: session_journal.map(crate::Journal::new),
288            exit_code,
289            assumptions,
290            mmr_assumptions,
291            user_cycles: result.user_cycles,
292            paging_cycles: result.paging_cycles,
293            reserved_cycles: result.reserved_cycles,
294            total_cycles: result.total_cycles,
295            pre_state: SystemState {
296                pc: 0,
297                merkle_root: result.claim.pre_state,
298            },
299            post_state: SystemState {
300                pc: 0,
301                merkle_root: post_digest,
302            },
303            pending_keccaks,
304            syscall_metrics,
305            hooks: vec![],
306            ecall_metrics: ecall_metrics.into(),
307            povw_job_id: self.env.povw_job_id,
308        };
309
310        tracing::info!("execution time: {elapsed:?}");
311        session.log();
312
313        assert_eq!(
314            session.total_cycles,
315            session.user_cycles + session.paging_cycles + session.reserved_cycles
316        );
317
318        Ok(session)
319    }
320}
321
322struct ContextAdapter<'a, 'b> {
323    ctx: &'b mut dyn CircuitSyscallContext,
324    syscall_table: SyscallTable<'a>,
325}
326
327impl<'a> SyscallContext<'a> for ContextAdapter<'a, '_> {
328    fn get_pc(&self) -> u32 {
329        self.ctx.get_pc()
330    }
331
332    fn get_cycle(&self) -> u64 {
333        self.ctx.get_cycle()
334    }
335
336    fn load_register(&mut self, idx: usize) -> u32 {
337        self.ctx.peek_register(idx).unwrap()
338    }
339
340    fn load_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
341        self.ctx.peek_page(page_idx)
342    }
343
344    fn load_u8(&mut self, addr: ByteAddr) -> Result<u8> {
345        self.ctx.peek_u8(addr)
346    }
347
348    fn load_u32(&mut self, addr: ByteAddr) -> Result<u32> {
349        self.ctx.peek_u32(addr)
350    }
351
352    fn syscall_table(&self) -> &SyscallTable<'a> {
353        &self.syscall_table
354    }
355}
356
357impl CircuitSyscall for ExecutorImpl<'_> {
358    fn host_read(
359        &self,
360        ctx: &mut dyn CircuitSyscallContext,
361        fd: u32,
362        buf: &mut [u8],
363    ) -> Result<u32> {
364        if fd == 0 {
365            let (a0, a1) = self.return_cache.get();
366            tracing::trace!("host_read(buf: {}) -> ({a0:#010x}, {a1:#010x})", buf.len());
367            let buf: &mut [u32] = bytemuck::cast_slice_mut(buf);
368            (buf[0], buf[1]) = (a0, a1);
369            return Ok(2 * WORD_SIZE as u32);
370        }
371
372        let mut ctx = ContextAdapter {
373            ctx,
374            syscall_table: self.syscall_table.clone(),
375        };
376
377        let name_ptr = ByteAddr(fd);
378        let syscall = ctx.peek_string(name_ptr)?;
379        tracing::trace!("host_read({syscall}, into_guest: {})", buf.len());
380
381        let words = align_up(buf.len(), WORD_SIZE) / WORD_SIZE;
382        let mut to_guest = vec![0u32; words];
383
384        self.return_cache.set(
385            self.syscall_table
386                .get_syscall(&syscall)
387                .context(format!("Unknown syscall: {syscall:?}"))?
388                .borrow_mut()
389                .syscall(&syscall, &mut ctx, &mut to_guest)?,
390        );
391
392        let bytes = bytemuck::cast_slice(to_guest.as_slice());
393        let rlen = buf.len();
394        buf.copy_from_slice(&bytes[..rlen]);
395
396        Ok(rlen as u32)
397    }
398
399    fn host_write(&self, ctx: &mut dyn CircuitSyscallContext, _fd: u32, buf: &[u8]) -> Result<u32> {
400        let str = String::from_utf8(buf.to_vec())?;
401        tracing::debug!("R0VM[{}] {str}", ctx.get_cycle());
402        Ok(buf.len() as u32)
403    }
404}
405
406impl ContextAdapter<'_, '_> {
407    fn peek_string(&mut self, mut addr: ByteAddr) -> Result<String> {
408        tracing::trace!("peek_string: {addr:?}");
409        let mut buf = Vec::new();
410        loop {
411            let bytes = self.ctx.peek_u8(addr)?;
412            if bytes == 0 {
413                break;
414            }
415            buf.push(bytes);
416            addr += 1u32;
417        }
418        Ok(String::from_utf8(buf)?)
419    }
420}