1use std::{collections::BTreeSet, fs, path::PathBuf};
19
20use anyhow::{ensure, Result};
21use enum_map::EnumMap;
22use risc0_binfmt::{MemoryImage, SystemState};
23use risc0_circuit_rv32im::prove::{emu::exec::EcallMetric, segment::Segment as CircuitSegment};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 host::{
28 client::env::{ProveKeccakRequest, ProveZkrRequest, SegmentPath},
29 prove_info::SessionStats,
30 },
31 sha::Digest,
32 Assumption, AssumptionReceipt, Assumptions, ExitCode, Journal, MaybePruned, Output,
33 ReceiptClaim,
34};
35
36use super::exec::syscall::{SyscallKind, SyscallMetric};
37
/// Two ordered sets of `u32` values, tracked separately for reads and writes.
///
/// NOTE(review): presumably these are the page indices (or addresses) that
/// faulted during execution -- confirm against the code that populates them,
/// which is not part of this file.
#[derive(Clone, Default, Serialize, Deserialize, Debug)]
pub struct PageFaults {
    /// Entries recorded for reads; the set keeps them sorted and unique.
    pub(crate) reads: BTreeSet<u32>,
    /// Entries recorded for writes; the set keeps them sorted and unique.
    pub(crate) writes: BTreeSet<u32>,
}
43
/// Aggregated record of one execution: its segments, journal, exit code,
/// memory/system state and cycle accounting.
///
/// Instances are assembled by `Session::new`. Because the type is
/// `#[non_exhaustive]`, code outside this crate cannot build one with
/// struct-literal syntax.
#[non_exhaustive]
pub struct Session {
    /// References to the segments of this session; each one is turned into a
    /// [`Segment`] on demand through [`SegmentRef::resolve`].
    pub segments: Vec<Box<dyn SegmentRef>>,

    /// Input digest; it is placed (in pruned form) into the `input` field of
    /// the claim built by `claim_with_assumptions`.
    pub input: Digest,

    /// Journal of the session, if any. `claim_with_assumptions` requires this
    /// to be `None` when the exit code does not expect output.
    pub journal: Option<Journal>,

    /// Exit code of the session; copied into the claim.
    pub exit_code: ExitCode,

    /// Memory image associated with the end of the session.
    /// NOTE(review): presumably the memory state after execution -- confirm.
    pub post_image: MemoryImage,

    /// Assumptions paired with their receipts. `claim` forwards only the
    /// receipts; unresolved ones end up in the claim's output.
    pub assumptions: Vec<(Assumption, AssumptionReceipt)>,

    /// Event observers registered through `add_hook`; empty right after
    /// construction.
    pub hooks: Vec<Box<dyn SessionEvents>>,

    /// Cycle count attributed to the user program.
    pub user_cycles: u64,

    /// Cycle count attributed to paging.
    pub paging_cycles: u64,

    /// Reserved cycle count; `Session::new` fills this from its
    /// `padding_cycles` argument.
    pub reserved_cycles: u64,

    /// Total cycle count. `log` asserts that it equals
    /// `user_cycles + paging_cycles + reserved_cycles`.
    pub total_cycles: u64,

    /// System state used as the `pre` state of the claim.
    pub pre_state: SystemState,

    /// System state used as the `post` state of the claim.
    pub post_state: SystemState,

    /// ZKR proving requests attached to this session.
    /// NOTE(review): presumably requests gathered during execution that still
    /// have to be proven -- confirm with the prover code.
    pub(crate) pending_zkrs: Vec<ProveZkrRequest>,

    /// Keccak proving requests attached to this session.
    /// NOTE(review): presumably requests gathered during execution that still
    /// have to be proven -- confirm with the prover code.
    pub(crate) pending_keccaks: Vec<ProveKeccakRequest>,

    /// `(name, metric)` pairs for ecalls; `log` prints them ordered by cycle
    /// count, largest first.
    pub(crate) ecall_metrics: Vec<(String, EcallMetric)>,

    /// Metrics per [`SyscallKind`]; `log` prints them ordered by call count,
    /// largest first.
    pub(crate) syscall_metrics: EnumMap<SyscallKind, SyscallMetric>,
}
110
/// One segment of a session, wrapping the circuit-level segment type.
///
/// The type is serializable, which is what allows `FileSegmentRef` to store it
/// on disk and read it back.
#[derive(Clone, Serialize, Deserialize)]
pub struct Segment {
    /// Index of this segment; `FileSegmentRef::new` uses it as the stem of the
    /// file name (`<index>.bincode`).
    pub index: u32,

    /// The underlying circuit segment; `po2` reads its `po2` field.
    pub(crate) inner: CircuitSegment,
    /// Optional output attached to this segment.
    /// NOTE(review): presumably only populated for the segment that ends the
    /// session -- confirm with the executor code.
    pub(crate) output: Option<Output>,
}
127
128impl Segment {
129 pub fn po2(&self) -> usize {
133 self.inner.po2
134 }
135}
136
/// A handle from which a [`Segment`] can be obtained.
///
/// Implementations in this file either keep the segment in memory
/// (`SimpleSegmentRef`), load it from a file (`FileSegmentRef`) or hold
/// nothing at all (`NullSegmentRef`). The `Send` supertrait lets boxed
/// references be moved across threads.
pub trait SegmentRef: Send {
    /// Produces the [`Segment`] this reference stands for.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the segment cannot be obtained,
    /// e.g. when reading or decoding a backing file fails.
    fn resolve(&self) -> Result<Segment>;
}
146
/// Observer interface for events around the proving of a segment.
///
/// Both methods have empty default bodies, so an implementor only overrides
/// the events it cares about. Observers are attached with `Session::add_hook`.
/// NOTE(review): the call sites that fire these events are not in this file.
pub trait SessionEvents {
    /// Hook intended to run just before `segment` is proven; does nothing by
    /// default.
    // The default body ignores `segment`, hence the lint suppression.
    #[allow(unused)]
    fn on_pre_prove_segment(&self, segment: &Segment) {}

    /// Hook intended to run just after `segment` has been proven; does
    /// nothing by default.
    // The default body ignores `segment`, hence the lint suppression.
    #[allow(unused)]
    fn on_post_prove_segment(&self, segment: &Segment) {}
}
157
158impl Session {
159 #[allow(clippy::too_many_arguments)]
161 pub(crate) fn new(
162 segments: Vec<Box<dyn SegmentRef>>,
163 input: Digest,
164 journal: Option<Vec<u8>>,
165 exit_code: ExitCode,
166 post_image: MemoryImage,
167 assumptions: Vec<(Assumption, AssumptionReceipt)>,
168 user_cycles: u64,
169 paging_cycles: u64,
170 padding_cycles: u64,
171 total_cycles: u64,
172 pre_state: SystemState,
173 post_state: SystemState,
174 pending_zkrs: Vec<ProveZkrRequest>,
175 pending_keccaks: Vec<ProveKeccakRequest>,
176 ecall_metrics: Vec<(String, EcallMetric)>,
177 syscall_metrics: EnumMap<SyscallKind, SyscallMetric>,
178 ) -> Self {
179 Self {
180 segments,
181 input,
182 journal: journal.map(Journal::new),
183 exit_code,
184 post_image,
185 assumptions,
186 hooks: Vec::new(),
187 user_cycles,
188 paging_cycles,
189 reserved_cycles: padding_cycles,
190 total_cycles,
191 pre_state,
192 post_state,
193 pending_zkrs,
194 pending_keccaks,
195 ecall_metrics,
196 syscall_metrics,
197 }
198 }
199
200 pub fn add_hook<E: SessionEvents + 'static>(&mut self, hook: E) {
202 self.hooks.push(Box::new(hook));
203 }
204
205 pub fn claim(&self) -> Result<ReceiptClaim> {
209 self.claim_with_assumptions(self.assumptions.iter().map(|(_, x)| x))
213 }
214
215 pub(crate) fn claim_with_assumptions<'a>(
216 &self,
217 assumptions: impl Iterator<Item = &'a AssumptionReceipt>,
218 ) -> Result<ReceiptClaim> {
219 let output = if self.exit_code.expects_output() {
220 self.journal
221 .as_ref()
222 .map(|journal| -> Result<_> {
223 Ok(Output {
224 journal: journal.bytes.clone().into(),
225 assumptions: Assumptions(
226 assumptions
227 .filter_map(|x| match x {
228 AssumptionReceipt::Proven(_) => None,
229 AssumptionReceipt::Unresolved(a) => Some(a.clone().into()),
230 })
231 .collect::<Vec<_>>(),
232 )
233 .into(),
234 })
235 })
236 .transpose()?
237 } else {
238 ensure!(
239 self.journal.is_none(),
240 "Session with exit code {:?} has a journal",
241 self.exit_code
242 );
243 ensure!(
244 self.assumptions.is_empty(),
245 "Session with exit code {:?} has encoded assumptions",
246 self.exit_code
247 );
248 None
249 };
250
251 Ok(ReceiptClaim {
252 pre: self.pre_state.clone().into(),
253 post: self.post_state.clone().into(),
254 exit_code: self.exit_code,
255 input: MaybePruned::Pruned(self.input),
256 output: output.into(),
257 })
258 }
259
260 pub fn log(&self) {
264 if std::env::var_os("RISC0_INFO").is_none() {
265 return;
266 }
267
268 let pct = |cycles: u64| cycles as f64 / self.total_cycles as f64 * 100.0;
269
270 tracing::info!("number of segments: {}", self.segments.len());
271 tracing::info!("{} total cycles", self.total_cycles);
272 tracing::info!(
273 "{} user cycles ({:.2}%)",
274 self.user_cycles,
275 pct(self.user_cycles)
276 );
277 tracing::info!(
278 "{} paging cycles ({:.2}%)",
279 self.paging_cycles,
280 pct(self.paging_cycles)
281 );
282 tracing::info!(
283 "{} reserved cycles ({:.2}%)",
284 self.reserved_cycles,
285 pct(self.reserved_cycles)
286 );
287
288 tracing::info!("ecalls");
289 let mut ecall_metrics = self.ecall_metrics.clone();
290 ecall_metrics.sort_by(|a, b| a.1.cycles.cmp(&b.1.cycles));
291 for (name, metric) in ecall_metrics.iter().rev() {
292 tracing::info!(
293 "\t{} {name} calls, {} cycles, ({:.2}%)",
294 metric.count,
295 metric.cycles,
296 pct(metric.cycles)
297 );
298 }
299
300 tracing::info!("syscalls");
301 let mut syscall_metrics: Vec<_> = self.syscall_metrics.iter().collect();
302 syscall_metrics.sort_by(|a, b| a.1.count.cmp(&b.1.count));
303 for (name, metric) in syscall_metrics.iter().rev() {
304 tracing::info!("\t{} {name:?} calls", metric.count);
305 }
306
307 assert_eq!(
308 self.total_cycles,
309 self.user_cycles + self.paging_cycles + self.reserved_cycles
310 );
311 }
312
313 pub fn stats(&self) -> SessionStats {
317 SessionStats {
318 segments: self.segments.len(),
319 total_cycles: self.total_cycles,
320 user_cycles: self.user_cycles,
321 paging_cycles: self.paging_cycles,
322 reserved_cycles: self.reserved_cycles,
323 }
324 }
325}
326
/// A [`SegmentRef`] that holds no segment data.
///
/// It is what `null_callback` hands back after discarding the segment it was
/// given, so resolving it never yields a [`Segment`].
#[derive(Serialize, Deserialize)]
pub struct NullSegmentRef;
332
333impl SegmentRef for NullSegmentRef {
334 fn resolve(&self) -> anyhow::Result<Segment> {
335 unimplemented!()
336 }
337}
338
339pub fn null_callback(_: Segment) -> Result<Box<dyn SegmentRef>> {
340 Ok(Box::new(NullSegmentRef))
341}
342
/// A [`SegmentRef`] that keeps the [`Segment`] itself in memory.
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleSegmentRef {
    /// The owned segment; `resolve` returns a clone of it.
    segment: Segment,
}
350
351impl SegmentRef for SimpleSegmentRef {
352 fn resolve(&self) -> Result<Segment> {
353 Ok(self.segment.clone())
354 }
355}
356
357impl SimpleSegmentRef {
358 pub fn new(segment: Segment) -> Self {
360 Self { segment }
361 }
362}
363
/// A [`SegmentRef`] backed by a bincode-encoded file on disk.
///
/// `FileSegmentRef::new` writes the segment to `<dir>/<index>.bincode`;
/// `resolve` reads that file back and decodes it.
pub struct FileSegmentRef {
    /// Location of the serialized segment file.
    path: PathBuf,
    /// Clone of the directory handle the file was written into. It is never
    /// read; it is only held for as long as this reference exists.
    /// NOTE(review): presumably this keeps a temporary directory from being
    /// cleaned up while the file is still needed -- confirm in `SegmentPath`.
    _dir: SegmentPath,
}
376
377impl SegmentRef for FileSegmentRef {
378 fn resolve(&self) -> Result<Segment> {
379 let contents = fs::read(&self.path)?;
380 let segment = bincode::deserialize(&contents)?;
381 Ok(segment)
382 }
383}
384
385impl FileSegmentRef {
386 pub fn new(segment: &Segment, dir: &SegmentPath) -> Result<Self> {
390 let path = dir.path().join(format!("{}.bincode", segment.index));
391 fs::write(&path, bincode::serialize(&segment)?)?;
392 Ok(Self {
393 path,
394 _dir: dir.clone(),
395 })
396 }
397}