1use std::{collections::BTreeSet, fs, path::PathBuf};
19
20use anyhow::{ensure, Context, Result};
21use enum_map::EnumMap;
22use risc0_binfmt::{PovwJobId, SystemState};
23use risc0_circuit_keccak::{compute_keccak_digest, KECCAK_CONTROL_ROOT};
24use risc0_circuit_rv32im::{execute::EcallMetric, TerminateState};
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 host::{
29 client::env::{ProveKeccakRequest, SegmentPath},
30 prove_info::SessionStats,
31 },
32 mmr::{GuestPeak, MerkleMountainAccumulator},
33 sha::Digest,
34 Assumption, AssumptionReceipt, Assumptions, ExitCode, Journal, MaybePruned, Output,
35 ReceiptClaim, Work,
36};
37
38use super::exec::syscall::{SyscallKind, SyscallMetric};
39
40#[derive(Clone, Default, Serialize, Deserialize, Debug)]
41pub struct PageFaults {
42 pub(crate) reads: BTreeSet<u32>,
43 pub(crate) writes: BTreeSet<u32>,
44}
45
46#[non_exhaustive]
53pub struct Session {
54 pub segments: Vec<Box<dyn SegmentRef>>,
59
60 pub input: Digest,
62
63 pub journal: Option<Journal>,
65
66 pub exit_code: ExitCode,
68
69 pub assumptions: Vec<(Assumption, AssumptionReceipt)>,
71
72 pub mmr_assumptions: Vec<AssumptionReceipt>,
74
75 pub hooks: Vec<Box<dyn SessionEvents>>,
77
78 pub user_cycles: u64,
81
82 pub paging_cycles: u64,
84
85 pub reserved_cycles: u64,
88
89 pub total_cycles: u64,
92
93 pub pre_state: SystemState,
95
96 pub post_state: SystemState,
98
99 pub(crate) pending_keccaks: Vec<ProveKeccakRequest>,
102
103 pub(crate) ecall_metrics: Vec<(String, EcallMetric)>,
105
106 pub(crate) syscall_metrics: EnumMap<SyscallKind, SyscallMetric>,
108
109 pub(crate) povw_job_id: Option<PovwJobId>,
111}
112
113#[derive(Clone, Serialize, Deserialize)]
122pub struct Segment {
123 pub index: u32,
125
126 pub(crate) inner: risc0_circuit_rv32im::execute::Segment,
127
128 pub(crate) output: Option<Output>,
129}
130
131impl Segment {
132 pub fn po2(&self) -> usize {
136 self.inner.po2 as usize
137 }
138
139 pub(crate) fn user_cycles(&self) -> u32 {
140 self.inner.suspend_cycle
141 }
142}
143
144pub struct PreflightResults {
146 pub(crate) inner: risc0_circuit_rv32im::prove::PreflightResults,
147
148 pub(crate) terminate_state: Option<TerminateState>,
149 pub(crate) output: Option<Output>,
150 pub(crate) segment_index: u32,
151}
152
153impl PreflightResults {
154 pub fn segment_index(&self) -> u32 {
156 self.segment_index
157 }
158}
159
160pub trait SegmentRef: Send {
166 fn resolve(&self) -> Result<Segment>;
168}
169
170pub trait SessionEvents {
172 #[allow(unused)]
174 fn on_pre_prove_segment(&self, segment: &Segment) {}
175
176 #[allow(unused)]
178 fn on_post_prove_segment(&self, segment: &Segment) {}
179}
180
181impl Session {
182 pub fn add_hook<E: SessionEvents + 'static>(&mut self, hook: E) {
184 self.hooks.push(Box::new(hook));
185 }
186
187 pub fn claim(&self) -> Result<ReceiptClaim> {
191 let output = if self.exit_code.expects_output() {
195 self.journal
196 .as_ref()
197 .map(|journal| -> Result<_> {
198 Ok(Output {
199 journal: journal.bytes.clone().into(),
200 assumptions: self
201 .unresolved_assumptions()
202 .context("failed to compute unresolved_assumptions")?
203 .into(),
204 })
205 })
206 .transpose()?
207 } else {
208 ensure!(
209 self.journal.is_none(),
210 "Session with exit code {:?} has a journal",
211 self.exit_code
212 );
213 ensure!(
214 self.assumptions.is_empty(),
215 "Session with exit code {:?} has encoded assumptions",
216 self.exit_code
217 );
218 None
219 };
220
221 Ok(ReceiptClaim {
222 pre: self.pre_state.clone().into(),
223 post: self.post_state.clone().into(),
224 exit_code: self.exit_code,
225 input: MaybePruned::Pruned(self.input),
226 output: output.into(),
227 })
228 }
229
230 fn keccak_root_assumption(&self) -> Result<Option<Assumption>> {
231 let mut keccak_receipts = MerkleMountainAccumulator::<GuestPeak>::new();
232 for proof_request in self.pending_keccaks.iter() {
233 let claim = compute_keccak_digest(bytemuck::cast_slice(proof_request.input.as_slice()));
234 tracing::debug!("adding keccak assumption: {}", claim);
235 keccak_receipts.insert(Assumption {
236 claim,
237 control_root: KECCAK_CONTROL_ROOT,
238 })?;
239 }
240
241 if !keccak_receipts.is_empty() {
242 let root_assumption = keccak_receipts.root()?;
243 tracing::debug!("keccak root assumption for session: {:?}", root_assumption);
244 return Ok(Some(root_assumption));
245 }
246 Ok(None)
247 }
248
249 fn unresolved_assumptions(&self) -> Result<Assumptions> {
250 let keccak_root_assumption = self
251 .keccak_root_assumption()
252 .context("failed to compute keccak root assumption")?;
253 Ok(self
254 .assumptions
255 .iter()
256 .filter_map(|(_, receipt)| match receipt {
257 AssumptionReceipt::Proven(_) => None,
258 AssumptionReceipt::Unresolved(assumption) => {
259 if let Some(ref keccak) = keccak_root_assumption {
260 if keccak == assumption {
261 return None;
262 }
263 }
264 Some(assumption.clone())
265 }
266 })
267 .collect::<Vec<_>>()
268 .into())
269 }
270
271 pub fn work(&self) -> Option<Work> {
273 self.povw_job_id.map(|povw_job_id| Work {
274 nonce_min: povw_job_id.nonce(0),
275 nonce_max: povw_job_id.nonce(self.segments.len() as u32),
276 value: self.total_cycles,
277 })
278 }
279
280 pub fn log(&self) {
284 if std::env::var_os("RISC0_INFO").is_none() {
285 return;
286 }
287
288 let pct = |cycles: u64| cycles as f64 / self.total_cycles as f64 * 100.0;
289
290 tracing::info!("number of segments: {}", self.segments.len());
291 tracing::info!("{} total cycles", self.total_cycles);
292 tracing::info!(
293 "{} user cycles ({:.2}%)",
294 self.user_cycles,
295 pct(self.user_cycles)
296 );
297 tracing::info!(
298 "{} paging cycles ({:.2}%)",
299 self.paging_cycles,
300 pct(self.paging_cycles)
301 );
302 tracing::info!(
303 "{} reserved cycles ({:.2}%)",
304 self.reserved_cycles,
305 pct(self.reserved_cycles)
306 );
307
308 tracing::info!("ecalls");
309 let mut ecall_metrics = self.ecall_metrics.clone();
310 ecall_metrics.sort_by(|a, b| a.1.cycles.cmp(&b.1.cycles));
311 for (name, metric) in ecall_metrics.iter().rev() {
312 tracing::info!(
313 "\t{} {name} calls, {} cycles, ({:.2}%)",
314 metric.count,
315 metric.cycles,
316 pct(metric.cycles)
317 );
318 }
319
320 tracing::info!("syscalls");
321 let mut syscall_metrics: Vec<_> = self.syscall_metrics.iter().collect();
322 syscall_metrics.sort_by(|a, b| a.1.count.cmp(&b.1.count));
323 for (name, metric) in syscall_metrics.iter().rev() {
324 tracing::info!("\t{} {name:?} calls", metric.count);
325 }
326 }
327
328 pub fn stats(&self) -> SessionStats {
332 SessionStats {
333 segments: self.segments.len(),
334 total_cycles: self.total_cycles,
335 user_cycles: self.user_cycles,
336 paging_cycles: self.paging_cycles,
337 reserved_cycles: self.reserved_cycles,
338 }
339 }
340}
341
342#[derive(Serialize, Deserialize)]
346pub struct NullSegmentRef;
347
348impl SegmentRef for NullSegmentRef {
349 fn resolve(&self) -> anyhow::Result<Segment> {
350 unimplemented!()
351 }
352}
353
354pub fn null_callback(_: Segment) -> Result<Box<dyn SegmentRef>> {
355 Ok(Box::new(NullSegmentRef))
356}
357
358#[derive(Clone, Serialize, Deserialize)]
362pub struct SimpleSegmentRef {
363 segment: Segment,
364}
365
366impl SegmentRef for SimpleSegmentRef {
367 fn resolve(&self) -> Result<Segment> {
368 Ok(self.segment.clone())
369 }
370}
371
372impl SimpleSegmentRef {
373 pub fn new(segment: Segment) -> Self {
375 Self { segment }
376 }
377}
378
379pub struct FileSegmentRef {
388 path: PathBuf,
389 _dir: SegmentPath,
390}
391
392impl SegmentRef for FileSegmentRef {
393 fn resolve(&self) -> Result<Segment> {
394 let contents = fs::read(&self.path)?;
395 let segment = bincode::deserialize(&contents)?;
396 Ok(segment)
397 }
398}
399
400impl FileSegmentRef {
401 pub fn new(segment: &Segment, dir: &SegmentPath) -> Result<Self> {
405 let path = dir.path().join(format!("{}.bincode", segment.index));
406 fs::write(&path, bincode::serialize(&segment)?)?;
407 Ok(Self {
408 path,
409 _dir: dir.clone(),
410 })
411 }
412}