1extern crate alloc;
16
17use alloc::{
18 collections::{BTreeMap, BTreeSet},
19 format, vec,
20 vec::Vec,
21};
22use core::mem;
23use lazy_static::lazy_static;
24
25#[cfg(feature = "std")]
26use std::sync::Arc;
27
28use anyhow::{anyhow, bail, Result};
29use derive_more::Debug;
30use risc0_zkp::{
31 core::{
32 digest::{Digest, DIGEST_WORDS},
33 hash::poseidon2::{poseidon2_mix, CELLS},
34 },
35 field::{baby_bear::BabyBearElem, Elem as _},
36};
37use serde::{Deserialize, Serialize};
38
39use crate::{
40 addr::{ByteAddr, WordAddr},
41 Program, PAGE_BYTES, PAGE_WORDS, WORD_SIZE,
42};
43
/// Size in bytes of the address space covered by a memory image (`1 << 32`).
const MEMORY_BYTES: u64 = 1 << 32;
/// Number of pages needed to cover the whole address space.
const MEMORY_PAGES: usize = (MEMORY_BYTES / PAGE_BYTES as u64) as usize;
/// Base-2 logarithm of the page count; used as the depth of the leaf level of the digest tree,
/// with the root at depth 0.
const MERKLE_TREE_DEPTH: usize = MEMORY_PAGES.ilog2() as usize;

/// Address at which `MemoryImage::new_user` and `MemoryImage::with_kernel` store the user
/// program's entry point.
pub const USER_START_ADDR: ByteAddr = ByteAddr(0x0001_0000);

/// NOTE(review): presumably the first address of kernel memory; it is not referenced in this part
/// of the file, so confirm against its users.
pub const KERNEL_START_ADDR: ByteAddr = ByteAddr(0xc000_0000);

/// Address at which the kernel constructors store the kernel program's entry point.
const SUSPEND_PC_ADDR: ByteAddr = ByteAddr(0xffff_0210);
/// Address at which the kernel constructors store the value `1`.
const SUSPEND_MODE_ADDR: ByteAddr = ByteAddr(0xffff_0214);
56
57lazy_static! {
58 static ref ZERO_CACHE: ZeroCache = ZeroCache::new();
59}
60
61struct ZeroCache {
62 pub page: Page,
63 pub digests: Vec<Digest>,
64}
65
66impl ZeroCache {
67 fn new() -> Self {
68 let page = Page::default();
69 let mut digest = page.digest();
70 let mut digests = vec![Digest::ZERO; MERKLE_TREE_DEPTH + 1];
71 for depth in (0..MERKLE_TREE_DEPTH + 1).rev() {
72 digests[depth] = digest;
73 digest = DigestPair {
74 lhs: digest,
75 rhs: digest,
76 }
77 .digest();
78 }
79 Self { page, digests }
80 }
81}
82
#[cfg(feature = "std")]
/// One page of memory, stored as a byte vector.
///
/// In this (`std`) configuration the bytes sit behind an [`Arc`], so cloning a page shares the
/// buffer; writes go through `Arc::make_mut` (see `ensure_writable`).
#[derive(Clone)]
pub struct Page(Arc<Vec<u8>>);

#[cfg(not(feature = "std"))]
/// One page of memory, stored as a byte vector.
///
/// In this (non-`std`) configuration the page owns its buffer directly, so cloning a page clones
/// the vector.
#[derive(Clone)]
pub struct Page(Vec<u8>);
100
/// A sparse set of memory pages together with the digests of a binary hash tree over them.
///
/// Tree nodes are numbered heap-style: the root is node 1, node `n` has children `2n` and
/// `2n + 1`, and the leaf for page `p` is node `MEMORY_PAGES + p`. Only some nodes are stored;
/// a stored all-zero node stands in for its whole subtree until it is expanded.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemoryImage {
    /// Pages held by the image, keyed by page index. Debug output shows only the count.
    #[debug("{}", pages.len())]
    pages: BTreeMap<u32, Page>,

    /// Stored tree digests, keyed by node index. Debug output shows only the count.
    #[debug("{}", digests.len())]
    digests: BTreeMap<u32, Digest>,

    /// Interior nodes whose digest must be recomputed by `update_digests`.
    #[debug("{}", dirty.len())]
    dirty: BTreeSet<u32>,
}
121
122impl Default for MemoryImage {
123 fn default() -> Self {
124 Self {
125 pages: Default::default(),
126 digests: BTreeMap::from([(1, ZERO_CACHE.digests[0])]),
127 dirty: Default::default(),
128 }
129 }
130}
131
132impl MemoryImage {
133 fn new(image: BTreeMap<u32, u32>) -> Self {
134 let mut this = Self::default();
135 let mut cur_page_idx = u32::MAX;
136 let mut cur_page: Option<Page> = None;
137
138 for (&addr, &word) in image.iter() {
139 let addr = ByteAddr(addr).waddr();
140 let page_idx = addr.page_idx();
141 if page_idx != cur_page_idx {
142 if let Some(page) = cur_page.take() {
143 this.set_page(cur_page_idx, page);
144 }
145 cur_page = Some(Page::default());
146 cur_page_idx = page_idx;
147 }
148
149 cur_page.as_mut().unwrap().store(addr, word);
150 }
151
152 if let Some(page) = cur_page.take() {
153 this.set_page(cur_page_idx, page);
154 }
155
156 this.update_digests();
157
158 this
159 }
160
161 pub fn new_user(program: Program) -> Self {
163 let mut image = program.image;
164 image.insert(USER_START_ADDR.0, program.entry);
165 Self::new(image)
166 }
167
168 pub fn new_kernel(program: Program) -> Self {
170 let mut image = program.image;
171 image.insert(SUSPEND_PC_ADDR.0, program.entry);
172 image.insert(SUSPEND_MODE_ADDR.0, 1);
173 Self::new(image)
174 }
175
176 pub fn with_kernel(mut user: Program, mut kernel: Program) -> Self {
179 user.image.insert(USER_START_ADDR.0, user.entry);
180 kernel.image.append(&mut user.image);
181 kernel.image.insert(SUSPEND_PC_ADDR.0, kernel.entry);
182 kernel.image.insert(SUSPEND_MODE_ADDR.0, 1);
183 Self::new(kernel.image)
184 }
185
186 pub fn get_page_indexes(&self) -> BTreeSet<u32> {
188 self.pages.keys().copied().collect()
189 }
190
191 pub fn digests(&self) -> impl Iterator<Item = (&'_ u32, &'_ Digest)> + '_ {
193 self.digests.iter()
194 }
195
196 pub fn get_page(&mut self, page_idx: u32) -> Result<Page> {
198 if let Some(page) = self.pages.get(&page_idx) {
200 return Ok(page.clone());
201 }
202
203 let digest_idx = MEMORY_PAGES as u32 + page_idx;
205 if self.expand_if_zero(digest_idx) {
206 let zero_page = &ZERO_CACHE.page;
207 self.pages.insert(page_idx, zero_page.clone());
208 return Ok(zero_page.clone());
209 }
210
211 bail!("Unavailable page: {page_idx}")
213 }
214
215 pub fn set_page(&mut self, page_idx: u32, page: Page) {
217 let digest_idx = MEMORY_PAGES as u32 + page_idx;
219 self.expand_if_zero(digest_idx);
220 self.digests.insert(digest_idx, page.digest());
221 self.pages.insert(page_idx, page);
222 self.mark_dirty(digest_idx);
223 }
224
225 pub fn set_page_with_digest(&mut self, page_idx: u32, page: Page, digest: Digest) {
227 let digest_idx = MEMORY_PAGES as u32 + page_idx;
228 self.expand_if_zero(digest_idx);
229 self.digests.insert(digest_idx, digest);
230 self.pages.insert(page_idx, page);
231 self.mark_dirty(digest_idx);
232 }
233
234 pub fn get_digest(&mut self, digest_idx: u32) -> Result<&Digest> {
236 self.expand_if_zero(digest_idx);
238 self.digests
239 .get(&digest_idx)
240 .ok_or_else(|| anyhow!("Unavailable digest: {digest_idx}"))
241 }
242
243 pub fn set_digest(&mut self, digest_idx: u32, digest: Digest) {
245 self.expand_if_zero(digest_idx);
247 self.digests.insert(digest_idx, digest);
249 self.mark_dirty(digest_idx);
250 }
251
252 pub fn image_id(&mut self) -> Digest {
254 *self.get_digest(1).unwrap()
255 }
256
257 pub fn user_id(&mut self) -> Digest {
259 *self.get_digest(2).unwrap()
260 }
261
262 pub fn kernel_id(&mut self) -> Digest {
264 *self.get_digest(3).unwrap()
265 }
266
267 fn expand_if_zero(&mut self, digest_idx: u32) -> bool {
269 self.is_zero(digest_idx)
270 .then(|| {
271 self.expand_zero(digest_idx);
272 })
273 .is_some()
274 }
275
276 fn is_zero(&self, mut digest_idx: u32) -> bool {
278 let mut depth = digest_idx.ilog2() as usize;
280 while !self.digests.contains_key(&digest_idx) && digest_idx > 0 {
282 digest_idx /= 2;
283 depth -= 1;
284 }
285 if digest_idx == 0 {
286 false
287 } else {
288 self.digests[&digest_idx] == ZERO_CACHE.digests[depth]
289 }
290 }
291
292 fn expand_zero(&mut self, mut digest_idx: u32) {
296 let mut depth = digest_idx.ilog2() as usize;
298 while !self.digests.contains_key(&digest_idx) {
300 let parent_idx = digest_idx / 2;
301 let lhs_idx = parent_idx * 2;
302 let rhs_idx = parent_idx * 2 + 1;
303 self.digests.insert(lhs_idx, ZERO_CACHE.digests[depth]);
304 self.digests.insert(rhs_idx, ZERO_CACHE.digests[depth]);
305 digest_idx = parent_idx;
306 depth -= 1;
307 }
308 }
309
310 fn mark_dirty(&mut self, mut digest_idx: u32) {
312 while digest_idx != 1 {
313 let parent_idx = digest_idx / 2;
314 let lhs_idx = parent_idx * 2;
315 let rhs_idx = parent_idx * 2 + 1;
316 let lhs = self.digests.get(&lhs_idx);
317 let rhs = self.digests.get(&rhs_idx);
318 if let (Some(_), Some(_)) = (lhs, rhs) {
319 self.dirty.insert(parent_idx);
320 digest_idx = parent_idx;
321 } else {
322 break;
323 };
324 }
325 }
326
327 pub fn update_digests(&mut self) {
330 let dirty: Vec<_> = mem::take(&mut self.dirty).into_iter().collect();
331 for idx in dirty.into_iter().rev() {
332 let lhs_idx = idx * 2;
333 let rhs_idx = idx * 2 + 1;
334 let lhs = *self.digests.get(&lhs_idx).unwrap();
335 let rhs = *self.digests.get(&rhs_idx).unwrap();
336
337 let parent_digest = DigestPair { lhs, rhs }.digest();
338 self.digests.insert(idx, parent_digest);
339 }
340 }
341
342 pub fn into_pages(self) -> BTreeMap<u32, Page> {
344 self.pages
345 }
346}
347
348impl Default for Page {
349 fn default() -> Self {
350 Self::from_vec(vec![0; PAGE_BYTES])
351 }
352}
353
354impl Page {
355 fn from_vec(v: Vec<u8>) -> Self {
357 #[cfg(not(feature = "std"))]
358 return Self(v);
359 #[cfg(feature = "std")]
360 return Self(Arc::new(v));
361 }
362
363 pub fn digest(&self) -> Digest {
368 let mut cells = [BabyBearElem::ZERO; CELLS];
369 for i in 0..PAGE_WORDS / DIGEST_WORDS {
370 for j in 0..DIGEST_WORDS {
371 let addr = WordAddr((i * DIGEST_WORDS + j) as u32);
372 let word = self.load(addr);
373 cells[2 * j] = BabyBearElem::new(word & 0xffff);
374 cells[2 * j + 1] = BabyBearElem::new(word >> 16);
375 }
376 poseidon2_mix(&mut cells);
377 }
378 cells_to_digest(&cells)
379 }
380
381 #[inline(always)]
389 pub fn load(&self, addr: WordAddr) -> u32 {
390 let byte_addr = addr.page_subaddr().baddr().0 as usize;
391 let mut bytes = [0u8; WORD_SIZE];
392 bytes.clone_from_slice(&self.0[byte_addr..byte_addr + WORD_SIZE]);
393 #[allow(clippy::let_and_return)] let word = u32::from_le_bytes(bytes);
395 word
397 }
398
399 #[cfg(feature = "std")]
400 #[inline(always)]
401 fn ensure_writable(&mut self) -> &mut [u8] {
402 &mut Arc::make_mut(&mut self.0)[..]
403 }
404
405 #[cfg(not(feature = "std"))]
406 #[inline(always)]
407 fn ensure_writable(&mut self) -> &mut [u8] {
408 &mut self.0
409 }
410
411 #[inline(always)]
419 pub fn store(&mut self, addr: WordAddr, word: u32) {
420 let writable_ref = self.ensure_writable();
421
422 let byte_addr = addr.page_subaddr().baddr().0 as usize;
423 writable_ref[byte_addr..byte_addr + WORD_SIZE].clone_from_slice(&word.to_le_bytes());
425 }
426
427 #[inline(always)]
429 pub fn data(&self) -> &Vec<u8> {
430 &self.0
431 }
432}
433
434impl Serialize for Page {
435 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
436 where
437 S: serde::ser::Serializer,
438 {
439 self.0.serialize(serializer)
440 }
441}
442
443impl<'de> Deserialize<'de> for Page {
444 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
445 where
446 D: serde::de::Deserializer<'de>,
447 {
448 use serde::de::Error as _;
449
450 let vec = <Vec<u8> as Deserialize>::deserialize(deserializer)?;
451 if vec.len() != PAGE_BYTES {
452 return Err(D::Error::custom(format!(
453 "serialized page has wrong length {} != {}",
454 vec.len(),
455 PAGE_BYTES
456 )));
457 }
458 Ok(Self::from_vec(vec))
459 }
460}
461
462pub(crate) struct DigestPair {
463 pub(crate) lhs: Digest,
464 pub(crate) rhs: Digest,
465}
466
467impl DigestPair {
468 pub fn digest(&self) -> Digest {
469 let mut cells = [BabyBearElem::ZERO; CELLS];
470 for i in 0..DIGEST_WORDS {
471 cells[i] = BabyBearElem::new(self.rhs.as_words()[i]);
472 cells[DIGEST_WORDS + i] = BabyBearElem::new(self.lhs.as_words()[i]);
473 }
474 poseidon2_mix(&mut cells);
475 cells_to_digest(&cells)
476 }
477}
478
479fn cells_to_digest(cells: &[BabyBearElem; CELLS]) -> Digest {
480 Digest::new([
481 cells[0].as_u32(),
482 cells[1].as_u32(),
483 cells[2].as_u32(),
484 cells[3].as_u32(),
485 cells[4].as_u32(),
486 cells[5].as_u32(),
487 cells[6].as_u32(),
488 cells[7].as_u32(),
489 ])
490}
491
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use risc0_zkp::digest;
    use test_log::test;

    use super::{MemoryImage, Program, ZERO_CACHE};

    /// Pins the all-zero digest for every tree depth, listed from the root (index 0) down to a
    /// single zero page (last entry).
    #[test]
    fn poseidon2_zeros() {
        let known_zero_digests = [
            digest!("f85c5a32ccc45c22f9686b08d710d4597d7ce256cdcd63146426270d9432c644"),
            digest!("2ce7714c40af126c2e86f320b10de417eddd8f51d2b9133d3105c3541a154812"),
            digest!("889c443e0c55734c0212fe6c400f00423c421f2070b1340351e77826e4918274"),
            digest!("53ea92273a7dfb7622de685c49f4ce1bd69db1696cd6846e9f5de56c89098b01"),
            digest!("82db13229831cb2ad63df0476dc1f217c702503d46770c283b6ecc1520fff074"),
            digest!("45cba5321f90c34b780d5d1790f23612fb834b3d21dc1e53594826470719ba34"),
            digest!("132689262568ae5ac27a4b65018aef0b2e4345578a16453acd874973a61c6350"),
            digest!("9fc9626e87aa3614eb38b44d9d832712fb2ea32427c6fd49281ca225f1fefd0d"),
            digest!("70947164fe9a4353fa33fb024f09ea0df24be40d88b6025278a3472ac49e6715"),
            digest!("4b707f15d9941c0168d630618cdcc05ccae5d84ab9674a6666123a0039915173"),
            digest!("97fb1325724ddb74b1446b5bfa13f02c2ecb1b2b2a2f5b1334a04c5c76335d12"),
            digest!("adba743a459eb5357487a1238a0c4c238b8313458283900447e9b8540adfb042"),
            digest!("a16e68725fe981434dcca548e972214b2dd85e017c3a4e03909a0f4c31a08741"),
            digest!("fb94f356397279703f12c24da7aa371e192294347af15d46f10ab512708cdb68"),
            digest!("30a2fe1aa5c2ae0e10b91074e34b06742be91e450a9bc10f28ab082263c48750"),
            digest!("2347f636d9a0ea45bbe8bf519f39d3127f72b625e2e5495f26a6dd583eb2965d"),
            digest!("e43d140e71e366521152d932e846c73535674921576711023deaee06de3b091e"),
            digest!("35500a740d3a8b4e5a0ca06a8362f3444456e3206826102dd9e9bc3e5a1a5a18"),
            digest!("7c650c1a2000ef1a9baf4f56c2d66e76a3a0b4510175b171268d156a25d8dd45"),
            digest!("d73a1e0997a00543afd8de5261f316704215ce384e3ea13df3f87e000f04fb5f"),
            digest!("5b77f60275cb272fa0a3d267bdf1fc15021dbe7185ed6a3c94e45d70bbd70148"),
            digest!("e053c93b359c8905c5d8523139988b0ed4ef3426864a80498dfcb91d9b813364"),
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566"),
        ];
        assert_eq!(ZERO_CACHE.digests, known_zero_digests);
    }

    /// Builds a one-word kernel image and pins the digest of one tree node that still equals
    /// the zero-page digest, plus the resulting image ID.
    #[test]
    fn image_circuit_match() {
        const ENTRY: u32 = 0x10000;
        // Tree node whose digest is expected to match the last entry of `poseidon2_zeros`.
        const UNTOUCHED_NODE: u32 = 0x0040_0100;

        let program = Program {
            entry: ENTRY,
            image: BTreeMap::from([(ENTRY, 0x1234b337)]),
        };
        let mut image = MemoryImage::new_kernel(program);

        let node_digest = *image.get_digest(UNTOUCHED_NODE).unwrap();
        assert_eq!(
            node_digest,
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566")
        );

        let image_id = image.image_id();
        assert_eq!(
            image_id,
            digest!("9d41290fa400705127c0240cb646586cc6ea8a23d560aa57cfa86c1369d9d53f")
        );
    }
}