risc0_zkvm/host/server/
session.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! This module defines [Session] and [Segment], which provide a way to share
//! execution traces between the execution phase and the proving phase.
17
18use std::{collections::BTreeSet, fs, path::PathBuf};
19
20use anyhow::{ensure, Result};
21use enum_map::EnumMap;
22use risc0_binfmt::{MemoryImage, SystemState};
23use risc0_circuit_rv32im::prove::{emu::exec::EcallMetric, segment::Segment as CircuitSegment};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27    host::{
28        client::env::{ProveKeccakRequest, ProveZkrRequest, SegmentPath},
29        prove_info::SessionStats,
30    },
31    sha::Digest,
32    Assumption, AssumptionReceipt, Assumptions, ExitCode, Journal, MaybePruned, Output,
33    ReceiptClaim,
34};
35
36use super::exec::syscall::{SyscallKind, SyscallMetric};
37
38#[derive(Clone, Default, Serialize, Deserialize, Debug)]
39pub struct PageFaults {
40    pub(crate) reads: BTreeSet<u32>,
41    pub(crate) writes: BTreeSet<u32>,
42}
43
44/// The execution trace of a program.
45///
46/// The record of memory transactions of an execution that starts from an
47/// initial memory image (which includes the starting PC) and proceeds until
48/// either a sys_halt or a sys_pause syscall is encountered. This record is
49/// stored as a vector of [Segment]s.
50#[non_exhaustive]
51pub struct Session {
52    /// The constituent [Segment]s of the Session. The final [Segment] will have
53    /// an [ExitCode] of [Halted](ExitCode::Halted), [Paused](ExitCode::Paused),
54    /// or [SessionLimit](ExitCode::SessionLimit), and all other [Segment]s (if
55    /// any) will have [ExitCode::SystemSplit].
56    pub segments: Vec<Box<dyn SegmentRef>>,
57
58    /// The input digest.
59    pub input: Digest,
60
61    /// The data publicly committed by the guest program.
62    pub journal: Option<Journal>,
63
64    /// The [ExitCode] of the session.
65    pub exit_code: ExitCode,
66
67    /// The final [MemoryImage] at the end of execution.
68    pub post_image: MemoryImage,
69
70    /// The list of assumptions made by the guest and resolved by the host.
71    pub assumptions: Vec<(Assumption, AssumptionReceipt)>,
72
73    /// The hooks to be called during the proving phase.
74    pub hooks: Vec<Box<dyn SessionEvents>>,
75
76    /// The number of user cycles without any overhead for continuations or po2
77    /// padding.
78    pub user_cycles: u64,
79
80    /// The number of cycles needed for paging operations.
81    pub paging_cycles: u64,
82
83    /// The number of cycles needed for the proof system which includes padding up to the nearest power of 2.
84    pub reserved_cycles: u64,
85
86    /// Total number of cycles that a prover experiences. This includes overhead
87    /// associated with continuations and padding up to the nearest power of 2.
88    pub total_cycles: u64,
89
90    /// The system state of the initial [MemoryImage].
91    pub pre_state: SystemState,
92
93    /// The system state of the final [MemoryImage] at the end of execution.
94    pub post_state: SystemState,
95
96    /// A list of pending ZKR proof requests.
97    // TODO: make this scalable so we don't OOM
98    pub(crate) pending_zkrs: Vec<ProveZkrRequest>,
99
100    /// A list of pending keccak proof requests.
101    // TODO: make this scalable so we don't OOM
102    pub(crate) pending_keccaks: Vec<ProveKeccakRequest>,
103
104    /// ecall metrics grouped by name.
105    pub(crate) ecall_metrics: Vec<(String, EcallMetric)>,
106
107    /// syscall metrics grouped by kind.
108    pub(crate) syscall_metrics: EnumMap<SyscallKind, SyscallMetric>,
109}
110
111/// The execution trace of a portion of a program.
112///
113/// The record of memory transactions of an execution that starts from an
114/// initial memory image, and proceeds until terminated by the system or user.
115/// This represents a chunk of execution work that will be proven in a single
116/// call to the ZKP system. It does not necessarily represent an entire program;
117/// see [Session] for tracking memory transactions until a user-requested
118/// termination.
119#[derive(Clone, Serialize, Deserialize)]
120pub struct Segment {
121    /// The index of this [Segment] within the [Session]
122    pub index: u32,
123
124    pub(crate) inner: CircuitSegment,
125    pub(crate) output: Option<Output>,
126}
127
128impl Segment {
129    /// Give the power of two length of this [Segment]
130    ///
131    /// If the [Segment]'s execution trace had 2^20 rows, this would return 20.
132    pub fn po2(&self) -> usize {
133        self.inner.po2
134    }
135}
136
137/// A reference to a [Segment].
138///
139/// This allows implementers to determine the best way to represent this in an
140/// pluggable manner. See the [SimpleSegmentRef] for a very basic
141/// implementation.
142pub trait SegmentRef: Send {
143    /// Resolve this reference into an actual [Segment].
144    fn resolve(&self) -> Result<Segment>;
145}
146
147/// The Events of [Session]
148pub trait SessionEvents {
149    /// Fired before the proving of a segment starts.
150    #[allow(unused)]
151    fn on_pre_prove_segment(&self, segment: &Segment) {}
152
153    /// Fired after the proving of a segment ends.
154    #[allow(unused)]
155    fn on_post_prove_segment(&self, segment: &Segment) {}
156}
157
158impl Session {
159    /// Construct a new [Session] from its constituent components.
160    #[allow(clippy::too_many_arguments)]
161    pub(crate) fn new(
162        segments: Vec<Box<dyn SegmentRef>>,
163        input: Digest,
164        journal: Option<Vec<u8>>,
165        exit_code: ExitCode,
166        post_image: MemoryImage,
167        assumptions: Vec<(Assumption, AssumptionReceipt)>,
168        user_cycles: u64,
169        paging_cycles: u64,
170        padding_cycles: u64,
171        total_cycles: u64,
172        pre_state: SystemState,
173        post_state: SystemState,
174        pending_zkrs: Vec<ProveZkrRequest>,
175        pending_keccaks: Vec<ProveKeccakRequest>,
176        ecall_metrics: Vec<(String, EcallMetric)>,
177        syscall_metrics: EnumMap<SyscallKind, SyscallMetric>,
178    ) -> Self {
179        Self {
180            segments,
181            input,
182            journal: journal.map(Journal::new),
183            exit_code,
184            post_image,
185            assumptions,
186            hooks: Vec::new(),
187            user_cycles,
188            paging_cycles,
189            reserved_cycles: padding_cycles,
190            total_cycles,
191            pre_state,
192            post_state,
193            pending_zkrs,
194            pending_keccaks,
195            ecall_metrics,
196            syscall_metrics,
197        }
198    }
199
200    /// Add a hook to be called during the proving phase.
201    pub fn add_hook<E: SessionEvents + 'static>(&mut self, hook: E) {
202        self.hooks.push(Box::new(hook));
203    }
204
205    /// Calculate for the [ReceiptClaim] associated with this [Session]. The
206    /// [ReceiptClaim] is the claim that will be proven if this [Session]
207    /// is passed to the [crate::Prover].
208    pub fn claim(&self) -> Result<ReceiptClaim> {
209        // Construct the Output struct for the session, checking internal consistency.
210        // NOTE: The Session output is distinct from the final Segment output because in the
211        // Session output any proven assumptions are not included.
212        self.claim_with_assumptions(self.assumptions.iter().map(|(_, x)| x))
213    }
214
215    pub(crate) fn claim_with_assumptions<'a>(
216        &self,
217        assumptions: impl Iterator<Item = &'a AssumptionReceipt>,
218    ) -> Result<ReceiptClaim> {
219        let output = if self.exit_code.expects_output() {
220            self.journal
221                .as_ref()
222                .map(|journal| -> Result<_> {
223                    Ok(Output {
224                        journal: journal.bytes.clone().into(),
225                        assumptions: Assumptions(
226                            assumptions
227                                .filter_map(|x| match x {
228                                    AssumptionReceipt::Proven(_) => None,
229                                    AssumptionReceipt::Unresolved(a) => Some(a.clone().into()),
230                                })
231                                .collect::<Vec<_>>(),
232                        )
233                        .into(),
234                    })
235                })
236                .transpose()?
237        } else {
238            ensure!(
239                self.journal.is_none(),
240                "Session with exit code {:?} has a journal",
241                self.exit_code
242            );
243            ensure!(
244                self.assumptions.is_empty(),
245                "Session with exit code {:?} has encoded assumptions",
246                self.exit_code
247            );
248            None
249        };
250
251        Ok(ReceiptClaim {
252            pre: self.pre_state.clone().into(),
253            post: self.post_state.clone().into(),
254            exit_code: self.exit_code,
255            input: MaybePruned::Pruned(self.input),
256            output: output.into(),
257        })
258    }
259
260    /// Log cycle information for this [Session].
261    ///
262    /// This logs the total and user cycles for this [Session] at the INFO level.
263    pub fn log(&self) {
264        if std::env::var_os("RISC0_INFO").is_none() {
265            return;
266        }
267
268        let pct = |cycles: u64| cycles as f64 / self.total_cycles as f64 * 100.0;
269
270        tracing::info!("number of segments: {}", self.segments.len());
271        tracing::info!("{} total cycles", self.total_cycles);
272        tracing::info!(
273            "{} user cycles ({:.2}%)",
274            self.user_cycles,
275            pct(self.user_cycles)
276        );
277        tracing::info!(
278            "{} paging cycles ({:.2}%)",
279            self.paging_cycles,
280            pct(self.paging_cycles)
281        );
282        tracing::info!(
283            "{} reserved cycles ({:.2}%)",
284            self.reserved_cycles,
285            pct(self.reserved_cycles)
286        );
287
288        tracing::info!("ecalls");
289        let mut ecall_metrics = self.ecall_metrics.clone();
290        ecall_metrics.sort_by(|a, b| a.1.cycles.cmp(&b.1.cycles));
291        for (name, metric) in ecall_metrics.iter().rev() {
292            tracing::info!(
293                "\t{} {name} calls, {} cycles, ({:.2}%)",
294                metric.count,
295                metric.cycles,
296                pct(metric.cycles)
297            );
298        }
299
300        tracing::info!("syscalls");
301        let mut syscall_metrics: Vec<_> = self.syscall_metrics.iter().collect();
302        syscall_metrics.sort_by(|a, b| a.1.count.cmp(&b.1.count));
303        for (name, metric) in syscall_metrics.iter().rev() {
304            tracing::info!("\t{} {name:?} calls", metric.count);
305        }
306
307        assert_eq!(
308            self.total_cycles,
309            self.user_cycles + self.paging_cycles + self.reserved_cycles
310        );
311    }
312
313    /// Returns stats for the session
314    ///
315    /// This contains cycle and segment information about the session useful for debugging and measuring performance.
316    pub fn stats(&self) -> SessionStats {
317        SessionStats {
318            segments: self.segments.len(),
319            total_cycles: self.total_cycles,
320            user_cycles: self.user_cycles,
321            paging_cycles: self.paging_cycles,
322            reserved_cycles: self.reserved_cycles,
323        }
324    }
325}
326
327/// Implementation of a [SegmentRef] that does not save the segment.
328///
329/// This is useful for DevMode where the segments aren't needed.
330#[derive(Serialize, Deserialize)]
331pub struct NullSegmentRef;
332
333impl SegmentRef for NullSegmentRef {
334    fn resolve(&self) -> anyhow::Result<Segment> {
335        unimplemented!()
336    }
337}
338
339pub fn null_callback(_: Segment) -> Result<Box<dyn SegmentRef>> {
340    Ok(Box::new(NullSegmentRef))
341}
342
343/// A very basic implementation of a [SegmentRef].
344///
345/// The [Segment] itself is stored in this implementation.
346#[derive(Clone, Serialize, Deserialize)]
347pub struct SimpleSegmentRef {
348    segment: Segment,
349}
350
351impl SegmentRef for SimpleSegmentRef {
352    fn resolve(&self) -> Result<Segment> {
353        Ok(self.segment.clone())
354    }
355}
356
357impl SimpleSegmentRef {
358    /// Construct a [SimpleSegmentRef] with the specified [Segment].
359    pub fn new(segment: Segment) -> Self {
360        Self { segment }
361    }
362}
363
364/// A basic implementation of a [SegmentRef] that saves the segment to a file
365///
366/// The [Segment] is stored in a user-specified file in this implementation,
367/// and the SegmentRef holds the filename.
368///
369/// There is an example of using [FileSegmentRef] in our [EVM example][1]
370///
371/// [1]: https://github.com/risc0/risc0/blob/main/examples/zkevm-demo/src/main.rs
372pub struct FileSegmentRef {
373    path: PathBuf,
374    _dir: SegmentPath,
375}
376
377impl SegmentRef for FileSegmentRef {
378    fn resolve(&self) -> Result<Segment> {
379        let contents = fs::read(&self.path)?;
380        let segment = bincode::deserialize(&contents)?;
381        Ok(segment)
382    }
383}
384
385impl FileSegmentRef {
386    /// Construct a [FileSegmentRef]
387    ///
388    /// This builds a FileSegmentRef that stores `segment` in a file at `path`.
389    pub fn new(segment: &Segment, dir: &SegmentPath) -> Result<Self> {
390        let path = dir.path().join(format!("{}.bincode", segment.index));
391        fs::write(&path, bincode::serialize(&segment)?)?;
392        Ok(Self {
393            path,
394            _dir: dir.clone(),
395        })
396    }
397}