risc0_zkvm/host/server/exec/
executor.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    cell::{Cell, RefCell},
17    rc::Rc,
18    sync::Arc,
19    time::Instant,
20};
21
22use anyhow::{bail, Context as _, Result};
23use risc0_binfmt::{
24    AbiKind, ByteAddr, ExitCode, MemoryImage, Program, ProgramBinary, ProgramBinaryHeader,
25    SystemState,
26};
27use risc0_circuit_rv32im::{
28    execute::{
29        platform::WORD_SIZE, Executor, Syscall as CircuitSyscall,
30        SyscallContext as CircuitSyscallContext, DEFAULT_SEGMENT_LIMIT_PO2,
31    },
32    MAX_INSN_CYCLES,
33};
34use risc0_core::scope;
35use risc0_zkp::core::digest::Digest;
36use risc0_zkvm_platform::{align_up, fileno};
37use tempfile::tempdir;
38
39use crate::{
40    host::{client::env::SegmentPath, server::session::Session},
41    receipt_claim::exit_code_from_rv32im_v2_claim,
42    Assumptions, ExecutorEnv, FileSegmentRef, Output, Segment, SegmentRef,
43};
44
45use super::{
46    profiler::{self, Profiler},
47    syscall::{SyscallContext, SyscallTable},
48    Journal,
49};
50
51// The Executor provides an implementation for the execution phase.
52///
53/// The proving phase uses an execution trace generated by the Executor.
54pub struct ExecutorImpl<'a> {
55    env: ExecutorEnv<'a>,
56    image: MemoryImage,
57    pub(crate) syscall_table: SyscallTable<'a>,
58    profiler: Option<Rc<RefCell<Profiler>>>,
59    return_cache: Cell<(u32, u32)>,
60}
61
62/// Check to see if the executor is compatible with the given guest program.
63fn check_program_version(header: &ProgramBinaryHeader) -> Result<()> {
64    let abi_kind = header.abi_kind;
65    let abi_version = &header.abi_version;
66
67    if abi_kind != AbiKind::V1Compat {
68        bail!("ProgramBinary abi_kind mismatch {abi_kind:?} != AbiKind::V1Compat");
69    }
70    if !semver::VersionReq::parse("^1.0.0")
71        .unwrap()
72        .matches(abi_version)
73    {
74        bail!("ProgramBinary abi_version mismatch {abi_version} doesn't match ^1.0.0");
75    }
76
77    Ok(())
78}
79
80impl<'a> ExecutorImpl<'a> {
81    /// Construct a new [ExecutorImpl] from a [MemoryImage] and entry point.
82    ///
83    /// Before a guest program is proven, the [ExecutorImpl] is responsible for
84    /// deciding where a zkVM program should be split into [Segment]s and what
85    /// work will be done in each segment. This is the execution phase:
86    /// the guest program is executed to determine how its proof should be
87    /// divided into subparts.
88    #[allow(dead_code)]
89    pub fn new(env: ExecutorEnv<'a>, image: MemoryImage) -> Result<Self> {
90        Self::with_details(env, image, None)
91    }
92
93    /// Construct a new [ExecutorImpl] from the ELF binary of the guest program
94    /// you want to run and an [ExecutorEnv] containing relevant
95    /// environmental configuration details.
96    pub fn from_elf(mut env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
97        let binary = ProgramBinary::decode(elf)?;
98        check_program_version(&binary.header)?;
99
100        let image = binary.to_image()?;
101
102        let profiler = if env.pprof_out.is_some() {
103            let profiler = Rc::new(RefCell::new(Profiler::new(
104                &binary,
105                None,
106                profiler::read_enable_inline_functions_env_var(),
107            )?));
108            env.trace.push(profiler.clone());
109            Some(profiler)
110        } else {
111            None
112        };
113
114        Self::with_details(env, image, profiler)
115    }
116
117    /// TODO(flaub)
118    #[allow(dead_code)]
119    pub(crate) fn from_kernel_elf(env: ExecutorEnv<'a>, elf: &[u8]) -> Result<Self> {
120        let kernel = Program::load_elf(elf, u32::MAX)?;
121        let image = MemoryImage::new_kernel(kernel);
122        Self::with_details(env, image, None)
123    }
124
125    fn with_details(
126        env: ExecutorEnv<'a>,
127        image: MemoryImage,
128        profiler: Option<Rc<RefCell<Profiler>>>,
129    ) -> Result<Self> {
130        let syscall_table = SyscallTable::from_env(&env);
131        Ok(Self {
132            env,
133            image,
134            syscall_table,
135            profiler,
136            return_cache: Cell::new((0, 0)),
137        })
138    }
139
140    /// This will run the executor to get a [Session] which contain the results
141    /// of the execution.
142    pub fn run(&mut self) -> Result<Session> {
143        if self.env.segment_path.is_none() {
144            self.env.segment_path = Some(SegmentPath::TempDir(Arc::new(tempdir()?)));
145        }
146
147        let path = self.env.segment_path.clone().unwrap();
148        self.run_with_callback(|segment| Ok(Box::new(FileSegmentRef::new(&segment, &path)?)))
149    }
150
151    /// Run the executor until [crate::ExitCode::Halted] or
152    /// [crate::ExitCode::Paused] is reached, producing a [Session] as a result.
153    pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<Session>
154    where
155        F: FnMut(Segment) -> Result<Box<dyn SegmentRef>> + Send,
156    {
157        scope!("execute");
158
159        let journal = Journal::default();
160        self.env
161            .posix_io
162            .borrow_mut()
163            .with_write_fd(fileno::JOURNAL, journal.clone());
164
165        let segment_limit_po2 = self
166            .env
167            .segment_limit_po2
168            .unwrap_or(DEFAULT_SEGMENT_LIMIT_PO2 as u32) as usize;
169
170        let mut refs = Vec::new();
171        let mut exec = Executor::new(
172            self.image.clone(),
173            self,
174            self.env.input_digest,
175            self.env.trace.clone(),
176        );
177
178        let start_time = Instant::now();
179        let result = exec.run(
180            segment_limit_po2,
181            MAX_INSN_CYCLES,
182            self.env.session_limit,
183            |inner| {
184                let output = inner
185                    .claim
186                    .terminate_state
187                    .is_some()
188                    .then(|| -> Option<Result<_>> {
189                        inner
190                            .claim
191                            .output
192                            .and_then(|digest| {
193                                (digest != Digest::ZERO)
194                                    .then(|| journal.buf.lock().unwrap().clone())
195                            })
196                            .map(|journal| {
197                                Ok(Output {
198                                    journal: journal.into(),
199                                    assumptions: Assumptions(
200                                        self.syscall_table
201                                            .assumptions_used
202                                            .lock()
203                                            .unwrap()
204                                            .iter()
205                                            .map(|(a, _)| a.clone().into())
206                                            .collect::<Vec<_>>(),
207                                    )
208                                    .into(),
209                                })
210                            })
211                    })
212                    .flatten()
213                    .transpose()?;
214
215                let segment = Segment {
216                    index: inner.index as u32,
217                    inner,
218                    output,
219                };
220                let segment_ref = callback(segment)?;
221                refs.push(segment_ref);
222                Ok(())
223            },
224        )?;
225        let elapsed = start_time.elapsed();
226
227        tracing::debug!("output_digest: {:?}", result.claim.output);
228
229        let exit_code = exit_code_from_rv32im_v2_claim(&result.claim)?;
230
231        // Set the session_journal to the committed data iff the guest set a non-zero output.
232        let session_journal = result.claim.output.and_then(|digest| {
233            (digest != Digest::ZERO).then(|| std::mem::take(&mut *journal.buf.lock().unwrap()))
234        });
235        if !exit_code.expects_output() && session_journal.is_some() {
236            tracing::debug!(
237                "dropping non-empty journal due to exit code {exit_code:?}: 0x{}",
238                hex::encode(journal.buf.lock().unwrap().as_slice())
239            );
240        };
241
242        let ecall_metrics = exec.take_ecall_metrics();
243
244        // Take (clear out) the list of accessed assumptions.
245        // Leave the assumptions cache so it can be used if execution is resumed from pause.
246        let assumptions = std::mem::take(&mut *self.syscall_table.assumptions_used.lock().unwrap());
247        let mmr_assumptions = self.syscall_table.mmr_assumptions.take();
248        let pending_zkrs = self.syscall_table.pending_zkrs.take();
249        let pending_keccaks = self.syscall_table.pending_keccaks.take();
250
251        if let Some(profiler) = self.profiler.take() {
252            let report = profiler.borrow_mut().finalize_to_vec();
253            std::fs::write(self.env.pprof_out.as_ref().unwrap(), report)?;
254        }
255
256        self.image = result.post_image.clone();
257        let syscall_metrics = self.syscall_table.metrics.borrow().clone();
258
259        // NOTE: When a segment ends in a Halted(_) state, the post_digest will be null.
260        let post_digest = match exit_code {
261            ExitCode::Halted(_) => Digest::ZERO,
262            _ => result.claim.post_state,
263        };
264
265        let session = Session {
266            segments: refs,
267            input: self.env.input_digest.unwrap_or_default(),
268            journal: session_journal.map(crate::Journal::new),
269            exit_code,
270            assumptions,
271            mmr_assumptions,
272            user_cycles: result.user_cycles,
273            paging_cycles: result.paging_cycles,
274            reserved_cycles: result.reserved_cycles,
275            total_cycles: result.total_cycles,
276            pre_state: SystemState {
277                pc: 0,
278                merkle_root: result.claim.pre_state,
279            },
280            post_state: SystemState {
281                pc: 0,
282                merkle_root: post_digest,
283            },
284            pending_zkrs,
285            pending_keccaks,
286            syscall_metrics,
287            hooks: vec![],
288            ecall_metrics: ecall_metrics.into(),
289        };
290
291        tracing::info!("execution time: {elapsed:?}");
292        session.log();
293
294        assert_eq!(
295            session.total_cycles,
296            session.user_cycles + session.paging_cycles + session.reserved_cycles
297        );
298
299        Ok(session)
300    }
301}
302
303struct ContextAdapter<'a, 'b> {
304    ctx: &'b mut dyn CircuitSyscallContext,
305    syscall_table: SyscallTable<'a>,
306}
307
308impl<'a> SyscallContext<'a> for ContextAdapter<'a, '_> {
309    fn get_pc(&self) -> u32 {
310        self.ctx.get_pc()
311    }
312
313    fn get_cycle(&self) -> u64 {
314        self.ctx.get_cycle()
315    }
316
317    fn load_register(&mut self, idx: usize) -> u32 {
318        self.ctx.peek_register(idx).unwrap()
319    }
320
321    fn load_page(&mut self, page_idx: u32) -> Result<Vec<u8>> {
322        self.ctx.peek_page(page_idx)
323    }
324
325    fn load_u8(&mut self, addr: ByteAddr) -> Result<u8> {
326        self.ctx.peek_u8(addr)
327    }
328
329    fn load_u32(&mut self, addr: ByteAddr) -> Result<u32> {
330        self.ctx.peek_u32(addr)
331    }
332
333    fn syscall_table(&self) -> &SyscallTable<'a> {
334        &self.syscall_table
335    }
336}
337
338impl CircuitSyscall for ExecutorImpl<'_> {
339    fn host_read(
340        &self,
341        ctx: &mut dyn CircuitSyscallContext,
342        fd: u32,
343        buf: &mut [u8],
344    ) -> Result<u32> {
345        if fd == 0 {
346            let (a0, a1) = self.return_cache.get();
347            tracing::trace!("host_read(buf: {}) -> ({a0:#010x}, {a1:#010x})", buf.len());
348            let buf: &mut [u32] = bytemuck::cast_slice_mut(buf);
349            (buf[0], buf[1]) = (a0, a1);
350            return Ok(2 * WORD_SIZE as u32);
351        }
352
353        let mut ctx = ContextAdapter {
354            ctx,
355            syscall_table: self.syscall_table.clone(),
356        };
357
358        let name_ptr = ByteAddr(fd);
359        let syscall = ctx.peek_string(name_ptr)?;
360        tracing::trace!("host_read({syscall}, into_guest: {})", buf.len());
361
362        let words = align_up(buf.len(), WORD_SIZE) / WORD_SIZE;
363        let mut to_guest = vec![0u32; words];
364
365        self.return_cache.set(
366            self.syscall_table
367                .get_syscall(&syscall)
368                .context(format!("Unknown syscall: {syscall:?}"))?
369                .borrow_mut()
370                .syscall(&syscall, &mut ctx, &mut to_guest)?,
371        );
372
373        let bytes = bytemuck::cast_slice(to_guest.as_slice());
374        let rlen = buf.len();
375        buf.copy_from_slice(&bytes[..rlen]);
376
377        Ok(rlen as u32)
378    }
379
380    fn host_write(&self, ctx: &mut dyn CircuitSyscallContext, _fd: u32, buf: &[u8]) -> Result<u32> {
381        let str = String::from_utf8(buf.to_vec())?;
382        tracing::debug!("R0VM[{}] {str}", ctx.get_cycle());
383        Ok(buf.len() as u32)
384    }
385}
386
387impl ContextAdapter<'_, '_> {
388    fn peek_string(&mut self, mut addr: ByteAddr) -> Result<String> {
389        tracing::trace!("peek_string: {addr:?}");
390        let mut buf = Vec::new();
391        loop {
392            let bytes = self.ctx.peek_u8(addr)?;
393            if bytes == 0 {
394                break;
395            }
396            buf.push(bytes);
397            addr += 1u32;
398        }
399        Ok(String::from_utf8(buf)?)
400    }
401}