risc0_binfmt/
image.rs

// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate alloc;

use alloc::{
    collections::{BTreeMap, BTreeSet},
    format, vec,
    vec::Vec,
};
use core::mem;
use lazy_static::lazy_static;

#[cfg(feature = "std")]
use std::sync::Arc;

use anyhow::{anyhow, bail, Result};
use derive_more::Debug;
use risc0_zkp::{
    core::{
        digest::{Digest, DIGEST_WORDS},
        hash::poseidon2::{poseidon2_mix, CELLS},
    },
    field::{baby_bear::BabyBearElem, Elem as _},
};
use serde::{Deserialize, Serialize};

use crate::{
    addr::{ByteAddr, WordAddr},
    Program, PAGE_BYTES, PAGE_WORDS, WORD_SIZE,
};

const MEMORY_BYTES: u64 = 1 << 32;
const MEMORY_PAGES: usize = (MEMORY_BYTES / PAGE_BYTES as u64) as usize;
const MERKLE_TREE_DEPTH: usize = MEMORY_PAGES.ilog2() as usize;

/// Start address for user-mode memory.
pub const USER_START_ADDR: ByteAddr = ByteAddr(0x0001_0000);

/// Start address for kernel-mode memory.
pub const KERNEL_START_ADDR: ByteAddr = ByteAddr(0xc000_0000);

const SUSPEND_PC_ADDR: ByteAddr = ByteAddr(0xffff_0210);
const SUSPEND_MODE_ADDR: ByteAddr = ByteAddr(0xffff_0214);

lazy_static! {
    static ref ZERO_CACHE: ZeroCache = ZeroCache::new();
}

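/// Cache of the digests of all-zero Merkle subtrees.
///
/// `page` is a zero-filled page, and `digests[depth]` holds the digest of a
/// fully zero subtree rooted at that depth: `digests[MERKLE_TREE_DEPTH]` is
/// the digest of a single zero page, and `digests[0]` is the root digest of an
/// entirely zero memory image.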
struct ZeroCache {
    pub page: Page,
    pub digests: Vec<Digest>,
}

impl ZeroCache {
    fn new() -> Self {
        let page = Page::default();
        let mut digest = page.digest();
        let mut digests = vec![Digest::ZERO; MERKLE_TREE_DEPTH + 1];
        for depth in (0..MERKLE_TREE_DEPTH + 1).rev() {
            digests[depth] = digest;
            digest = DigestPair {
                lhs: digest,
                rhs: digest,
            }
            .digest();
        }
        Self { page, digests }
    }
}

/// A page of memory
///
/// This represents a single page of memory. When accessing memory, all the
/// memory in the page is paged in and then accessible for the rest of the
/// segment, at which point it is paged out.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct Page(Arc<Vec<u8>>);

/// A page of memory
///
/// This represents a single page of memory. When accessing memory, all the
/// memory in the page is paged in and then accessible for the rest of the
/// segment, at which point it is paged out.
#[cfg(not(feature = "std"))]
#[derive(Clone)]
pub struct Page(Vec<u8>);

/// A memory image
///
/// A full memory image of a zkVM guest. Includes functionality for accessing
/// memory and associated digests, and for initializing the memory state for a
/// [Program].
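///
/// Page digests are kept in a sparse binary Merkle tree over all of memory,
/// using heap-style indexing: the root is node 1, the children of node `i`
/// are nodes `2 * i` and `2 * i + 1`, and the leaf for page `p` is node
/// `MEMORY_PAGES + p`. Subtrees that are entirely zero are represented
/// implicitly via a cache of zero digests and expanded on demand.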
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemoryImage {
    /// Loaded memory pages, keyed by page index.
    #[debug("{}", pages.len())]
    // #[debug("{:#010x?}", pages.keys())]
    pages: BTreeMap<u32, Page>,

    /// Merkle tree node digests, keyed by node index.
    #[debug("{}", digests.len())]
    // #[debug("{:#010x?}", digests.keys())]
    digests: BTreeMap<u32, Digest>,

    /// Merkle tree nodes whose digests need to be recomputed.
    #[debug("{}", dirty.len())]
    dirty: BTreeSet<u32>,
}

impl Default for MemoryImage {
    fn default() -> Self {
        Self {
            pages: Default::default(),
            digests: BTreeMap::from([(1, ZERO_CACHE.digests[0])]),
            dirty: Default::default(),
        }
    }
}

impl MemoryImage {
    fn new(image: BTreeMap<u32, u32>) -> Self {
        let mut this = Self::default();
        let mut cur_page_idx = u32::MAX;
        let mut cur_page: Option<Page> = None;

        for (&addr, &word) in image.iter() {
            let addr = ByteAddr(addr).waddr();
            let page_idx = addr.page_idx();
            if page_idx != cur_page_idx {
                if let Some(page) = cur_page.take() {
                    this.set_page(cur_page_idx, page);
                }
                cur_page = Some(Page::default());
                cur_page_idx = page_idx;
            }

            cur_page.as_mut().unwrap().store(addr, word);
        }

        if let Some(page) = cur_page.take() {
            this.set_page(cur_page_idx, page);
        }

        this.update_digests();

        this
    }

    /// Creates the initial memory state for a user-mode `program`.
    pub fn new_user(program: Program) -> Self {
        let mut image = program.image;
        image.insert(USER_START_ADDR.0, program.entry);
        Self::new(image)
    }

    /// Creates the initial memory state for a kernel-mode `program`.
    pub fn new_kernel(program: Program) -> Self {
        let mut image = program.image;
        image.insert(SUSPEND_PC_ADDR.0, program.entry);
        image.insert(SUSPEND_MODE_ADDR.0, 1);
        Self::new(image)
    }

    /// Creates the initial memory state for a user-mode `user` [Program] with a
    /// kernel-mode `kernel` [Program].
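    ///
    /// A minimal usage sketch (illustrative; `user_prog` and `kernel_prog` are
    /// assumed to be [Program]s loaded elsewhere):
    ///
    /// ```ignore
    /// let mut image = MemoryImage::with_kernel(user_prog, kernel_prog);
    /// let image_id = image.image_id();
    /// ```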
    pub fn with_kernel(mut user: Program, mut kernel: Program) -> Self {
        user.image.insert(USER_START_ADDR.0, user.entry);
        kernel.image.append(&mut user.image);
        kernel.image.insert(SUSPEND_PC_ADDR.0, kernel.entry);
        kernel.image.insert(SUSPEND_MODE_ADDR.0, 1);
        Self::new(kernel.image)
    }

    /// Returns a set of the page indexes that are loaded.
    pub fn get_page_indexes(&self) -> BTreeSet<u32> {
        self.pages.keys().copied().collect()
    }

    /// Sorted iterator over Merkle tree node digests (digest_idx -> Digest)
    pub fn digests(&self) -> impl Iterator<Item = (&'_ u32, &'_ Digest)> + '_ {
        self.digests.iter()
    }

    /// Return the page data, fails if unavailable
    pub fn get_page(&mut self, page_idx: u32) -> Result<Page> {
        // If page exists, return it
        if let Some(page) = self.pages.get(&page_idx) {
            return Ok(page.clone());
        }

        // Otherwise, try expanding a zero region
        let digest_idx = MEMORY_PAGES as u32 + page_idx;
        if self.expand_if_zero(digest_idx) {
            let zero_page = &ZERO_CACHE.page;
            self.pages.insert(page_idx, zero_page.clone());
            return Ok(zero_page.clone());
        }

        // Otherwise fail
        bail!("Unavailable page: {page_idx}")
    }

    /// Return the page data, panics if not available
    pub fn get_existing_page(&self, page_idx: u32) -> Page {
        self.pages.get(&page_idx).unwrap().clone()
    }

    /// Set the data for a page
    pub fn set_page(&mut self, page_idx: u32, page: Page) {
        // tracing::trace!("set_page({page_idx:#08x})");
        let digest_idx = MEMORY_PAGES as u32 + page_idx;
        self.expand_if_zero(digest_idx);
        self.digests.insert(digest_idx, page.digest());
        self.pages.insert(page_idx, page);
        self.mark_dirty(digest_idx);
    }

    /// Get a digest, fails if unavailable
    pub fn get_digest(&mut self, digest_idx: u32) -> Result<&Digest> {
        // Expand if needed
        self.expand_if_zero(digest_idx);
        self.digests
            .get(&digest_idx)
            .ok_or_else(|| anyhow!("Unavailable digest: {digest_idx}"))
    }

    /// Get a digest, panics if not available
    pub fn get_existing_digest(&self, digest_idx: u32) -> &Digest {
        self.digests.get(&digest_idx).unwrap()
    }

    /// Set a digest
    pub fn set_digest(&mut self, digest_idx: u32, digest: Digest) {
        // If digest is in a zero region, reify for proper uncles
        self.expand_if_zero(digest_idx);
        // Set the digest value
        self.digests.insert(digest_idx, digest);
        self.mark_dirty(digest_idx);
    }

    /// Return the root digest of the memory Merkle tree (node index 1).
    pub fn image_id(&mut self) -> Digest {
        *self.get_digest(1).unwrap()
    }

    /// Return the digest of the user portion of the Merkle tree (node index 2,
    /// the root's left child, covering the lower half of the address space).
    pub fn user_id(&mut self) -> Digest {
        *self.get_digest(2).unwrap()
    }

    /// Return the digest of the kernel portion of the Merkle tree (node index 3,
    /// the root's right child, covering the upper half of the address space).
    pub fn kernel_id(&mut self) -> Digest {
        *self.get_digest(3).unwrap()
    }

    /// Expand if digest at `digest_idx` is a zero, return if expanded
    fn expand_if_zero(&mut self, digest_idx: u32) -> bool {
        self.is_zero(digest_idx)
            .then(|| {
                self.expand_zero(digest_idx);
            })
            .is_some()
    }

    /// Check if given MT node is a zero
    fn is_zero(&self, mut digest_idx: u32) -> bool {
        // Compute the depth in the tree of this node
        let mut depth = digest_idx.ilog2() as usize;
        // Go up until we hit a valid node or get past the root
        while !self.digests.contains_key(&digest_idx) && digest_idx > 0 {
            digest_idx /= 2;
            depth -= 1;
        }
        if digest_idx == 0 {
            false
        } else {
            self.digests[&digest_idx] == ZERO_CACHE.digests[depth]
        }
    }

    /// Expand zero MT node.
    ///
    /// Presumes `is_zero(digest_idx)` returned true.
    fn expand_zero(&mut self, mut digest_idx: u32) {
        // Compute the depth in the tree of this node
        let mut depth = digest_idx.ilog2() as usize;
        // Go up until we hit the valid zero node
        while !self.digests.contains_key(&digest_idx) {
            let parent_idx = digest_idx / 2;
            let lhs_idx = parent_idx * 2;
            let rhs_idx = parent_idx * 2 + 1;
            self.digests.insert(lhs_idx, ZERO_CACHE.digests[depth]);
            self.digests.insert(rhs_idx, ZERO_CACHE.digests[depth]);
            digest_idx = parent_idx;
            depth -= 1;
        }
    }

    /// Mark inner digests as dirty after a change
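    ///
    /// Walks from the changed node toward the root (node 1), recording each
    /// parent whose two children are both present; `update_digests` later
    /// recomputes those parents in reverse (deepest-first) order.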
    fn mark_dirty(&mut self, mut digest_idx: u32) {
        while digest_idx != 1 {
            let parent_idx = digest_idx / 2;
            let lhs_idx = parent_idx * 2;
            let rhs_idx = parent_idx * 2 + 1;
            let lhs = self.digests.get(&lhs_idx);
            let rhs = self.digests.get(&rhs_idx);
            if let (Some(_), Some(_)) = (lhs, rhs) {
                self.dirty.insert(parent_idx);
                digest_idx = parent_idx;
            } else {
                break;
            };
        }
    }

    /// After making changes to the image, call this to update all the digests
    /// that need to be updated.
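    ///
    /// A typical update sequence might look like this (illustrative; `image`,
    /// `page_idx`, `addr`, and `word` are assumed to be in scope):
    ///
    /// ```ignore
    /// let mut page = image.get_page(page_idx)?;
    /// page.store(addr, word);
    /// image.set_page(page_idx, page);
    /// image.update_digests();
    /// let new_root = image.image_id();
    /// ```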
    pub fn update_digests(&mut self) {
        let dirty: Vec<_> = mem::take(&mut self.dirty).into_iter().collect();
        for idx in dirty.into_iter().rev() {
            let lhs_idx = idx * 2;
            let rhs_idx = idx * 2 + 1;
            let lhs = *self.digests.get(&lhs_idx).unwrap();
            let rhs = *self.digests.get(&rhs_idx).unwrap();

            let parent_digest = DigestPair { lhs, rhs }.digest();
            self.digests.insert(idx, parent_digest);
        }
    }
}

impl Default for Page {
    fn default() -> Self {
        Self::from_vec(vec![0; PAGE_BYTES])
    }
}

impl Page {
    /// Caller must ensure given Vec is of length `PAGE_BYTES`
    fn from_vec(v: Vec<u8>) -> Self {
        #[cfg(not(feature = "std"))]
        return Self(v);
        #[cfg(feature = "std")]
        return Self(Arc::new(v));
    }

    /// Produce the digest of this page
    ///
    /// Hashes the data in this page to produce a digest which can be used for
    /// verifying memory integrity.
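    ///
    /// Each word is split into its low and high 16-bit halves, which are fed
    /// into the Poseidon2 sponge `DIGEST_WORDS` words at a time, with one
    /// `poseidon2_mix` permutation per chunk.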
    pub fn digest(&self) -> Digest {
        let mut cells = [BabyBearElem::ZERO; CELLS];
        for i in 0..PAGE_WORDS / DIGEST_WORDS {
            for j in 0..DIGEST_WORDS {
                let addr = WordAddr((i * DIGEST_WORDS + j) as u32);
                let word = self.load(addr);
                cells[2 * j] = BabyBearElem::new(word & 0xffff);
                cells[2 * j + 1] = BabyBearElem::new(word >> 16);
            }
            poseidon2_mix(&mut cells);
        }
        cells_to_digest(&cells)
    }

    /// Read a word from a page
    ///
    /// Loads the data at `addr` from this page. This only looks at the
    /// subaddress, and does not check if the address belongs to this page.
    /// Thus, if you pass a [WordAddr] belonging to a different page,
    /// [Page::load] will load from the address in _this_ page with the same
    /// [WordAddr::page_subaddr].
    pub fn load(&self, addr: WordAddr) -> u32 {
        let byte_addr = addr.page_subaddr().baddr().0 as usize;
        let mut bytes = [0u8; WORD_SIZE];
        bytes.clone_from_slice(&self.0[byte_addr..byte_addr + WORD_SIZE]);
        #[allow(clippy::let_and_return)] // easier to toggle optional tracing
        let word = u32::from_le_bytes(bytes);
        // tracing::trace!("load({addr:?}) -> {word:#010x}");
        word
    }

    #[cfg(feature = "std")]
    fn ensure_writable(&mut self) -> &mut [u8] {
        &mut Arc::make_mut(&mut self.0)[..]
    }

    #[cfg(not(feature = "std"))]
    fn ensure_writable(&mut self) -> &mut [u8] {
        &mut self.0
    }

    /// Store a word to this page
    ///
    /// Stores the data `word` to the address `addr` in this page. This only
    /// looks at the subaddress, and does not check if the address belongs to
    /// this page. Thus, if you pass a [WordAddr] belonging to a different page,
    /// [Page::store] will store to the address in _this_ page with the same
    /// [WordAddr::page_subaddr].
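    ///
    /// A round-trip sketch (illustrative):
    ///
    /// ```ignore
    /// let mut page = Page::default();
    /// page.store(WordAddr(3), 0xdead_beef);
    /// assert_eq!(page.load(WordAddr(3)), 0xdead_beef);
    /// ```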
    pub fn store(&mut self, addr: WordAddr, word: u32) {
        let writable_ref = self.ensure_writable();

        let byte_addr = addr.page_subaddr().baddr().0 as usize;
        // tracing::trace!("store({addr:?}, {byte_addr:#05x}, {word:#010x})");
        writable_ref[byte_addr..byte_addr + WORD_SIZE].clone_from_slice(&word.to_le_bytes());
    }

    /// Get a shared reference to the underlying data in the page
    pub fn data(&self) -> &Vec<u8> {
        &self.0
    }
}

impl Serialize for Page {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        self.0.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for Page {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        use serde::de::Error as _;

        let vec = <Vec<u8> as Deserialize>::deserialize(deserializer)?;
        if vec.len() != PAGE_BYTES {
            return Err(D::Error::custom(format!(
                "serialized page has wrong length {} != {}",
                vec.len(),
                PAGE_BYTES
            )));
        }
        Ok(Self::from_vec(vec))
    }
}

pub(crate) struct DigestPair {
    pub(crate) lhs: Digest,
    pub(crate) rhs: Digest,
}

impl DigestPair {
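    /// Combine two child digests into their parent digest.
    ///
    /// The right child's words fill the first `DIGEST_WORDS` cells and the
    /// left child's the next `DIGEST_WORDS`, followed by a single
    /// `poseidon2_mix` permutation.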
    pub fn digest(&self) -> Digest {
        let mut cells = [BabyBearElem::ZERO; CELLS];
        for i in 0..DIGEST_WORDS {
            cells[i] = BabyBearElem::new(self.rhs.as_words()[i]);
            cells[DIGEST_WORDS + i] = BabyBearElem::new(self.lhs.as_words()[i]);
        }
        poseidon2_mix(&mut cells);
        cells_to_digest(&cells)
    }
}

fn cells_to_digest(cells: &[BabyBearElem; CELLS]) -> Digest {
    Digest::new([
        cells[0].as_u32(),
        cells[1].as_u32(),
        cells[2].as_u32(),
        cells[3].as_u32(),
        cells[4].as_u32(),
        cells[5].as_u32(),
        cells[6].as_u32(),
        cells[7].as_u32(),
    ])
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use risc0_zkp::digest;
    use test_log::test;

    use super::{MemoryImage, Program, ZERO_CACHE};

    #[test]
    fn poseidon2_zeros() {
        let expected = [
            digest!("f85c5a32ccc45c22f9686b08d710d4597d7ce256cdcd63146426270d9432c644"),
            digest!("2ce7714c40af126c2e86f320b10de417eddd8f51d2b9133d3105c3541a154812"),
            digest!("889c443e0c55734c0212fe6c400f00423c421f2070b1340351e77826e4918274"),
            digest!("53ea92273a7dfb7622de685c49f4ce1bd69db1696cd6846e9f5de56c89098b01"),
            digest!("82db13229831cb2ad63df0476dc1f217c702503d46770c283b6ecc1520fff074"),
            digest!("45cba5321f90c34b780d5d1790f23612fb834b3d21dc1e53594826470719ba34"),
            digest!("132689262568ae5ac27a4b65018aef0b2e4345578a16453acd874973a61c6350"),
            digest!("9fc9626e87aa3614eb38b44d9d832712fb2ea32427c6fd49281ca225f1fefd0d"),
            digest!("70947164fe9a4353fa33fb024f09ea0df24be40d88b6025278a3472ac49e6715"),
            digest!("4b707f15d9941c0168d630618cdcc05ccae5d84ab9674a6666123a0039915173"),
            digest!("97fb1325724ddb74b1446b5bfa13f02c2ecb1b2b2a2f5b1334a04c5c76335d12"),
            digest!("adba743a459eb5357487a1238a0c4c238b8313458283900447e9b8540adfb042"),
            digest!("a16e68725fe981434dcca548e972214b2dd85e017c3a4e03909a0f4c31a08741"),
            digest!("fb94f356397279703f12c24da7aa371e192294347af15d46f10ab512708cdb68"),
            digest!("30a2fe1aa5c2ae0e10b91074e34b06742be91e450a9bc10f28ab082263c48750"),
            digest!("2347f636d9a0ea45bbe8bf519f39d3127f72b625e2e5495f26a6dd583eb2965d"),
            digest!("e43d140e71e366521152d932e846c73535674921576711023deaee06de3b091e"),
            digest!("35500a740d3a8b4e5a0ca06a8362f3444456e3206826102dd9e9bc3e5a1a5a18"),
            digest!("7c650c1a2000ef1a9baf4f56c2d66e76a3a0b4510175b171268d156a25d8dd45"),
            digest!("d73a1e0997a00543afd8de5261f316704215ce384e3ea13df3f87e000f04fb5f"),
            digest!("5b77f60275cb272fa0a3d267bdf1fc15021dbe7185ed6a3c94e45d70bbd70148"),
            digest!("e053c93b359c8905c5d8523139988b0ed4ef3426864a80498dfcb91d9b813364"),
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566"),
        ];
        assert_eq!(ZERO_CACHE.digests, expected);
    }

    #[test]
    fn image_circuit_match() {
        let entry = 0x10000;
        let program = Program {
            entry,
            image: BTreeMap::from([(entry, 0x1234b337)]),
        };
        let mut image = MemoryImage::new_kernel(program);
        assert_eq!(
            *image.get_digest(0x0040_0100).unwrap(),
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566")
        );
        assert_eq!(
            image.image_id(),
            digest!("9d41290fa400705127c0240cb646586cc6ea8a23d560aa57cfa86c1369d9d53f")
        );
    }
}