1extern crate alloc;
16
17use alloc::{
18 collections::{BTreeMap, BTreeSet},
19 format, vec,
20 vec::Vec,
21};
22use core::mem;
23use lazy_static::lazy_static;
24
25#[cfg(feature = "std")]
26use std::sync::Arc;
27
28use anyhow::{anyhow, bail, Result};
29use derive_more::Debug;
30use risc0_zkp::{
31 core::{
32 digest::{Digest, DIGEST_WORDS},
33 hash::poseidon2::{poseidon2_mix, CELLS},
34 },
35 field::{baby_bear::BabyBearElem, Elem as _},
36};
37use serde::{Deserialize, Serialize};
38
39use crate::{
40 addr::{ByteAddr, WordAddr},
41 Program, PAGE_BYTES, PAGE_WORDS, WORD_SIZE,
42};
43
44const MEMORY_BYTES: u64 = 1 << 32;
45const MEMORY_PAGES: usize = (MEMORY_BYTES / PAGE_BYTES as u64) as usize;
46const MERKLE_TREE_DEPTH: usize = MEMORY_PAGES.ilog2() as usize;
47
48pub const USER_START_ADDR: ByteAddr = ByteAddr(0x0001_0000);
50
51pub const KERNEL_START_ADDR: ByteAddr = ByteAddr(0xc000_0000);
53
54const SUSPEND_PC_ADDR: ByteAddr = ByteAddr(0xffff_0210);
55const SUSPEND_MODE_ADDR: ByteAddr = ByteAddr(0xffff_0214);
56
57lazy_static! {
58 static ref ZERO_CACHE: ZeroCache = ZeroCache::new();
59}
60
61struct ZeroCache {
62 pub page: Page,
63 pub digests: Vec<Digest>,
64}
65
66impl ZeroCache {
67 fn new() -> Self {
68 let page = Page::default();
69 let mut digest = page.digest();
70 let mut digests = vec![Digest::ZERO; MERKLE_TREE_DEPTH + 1];
71 for depth in (0..MERKLE_TREE_DEPTH + 1).rev() {
72 digests[depth] = digest;
73 digest = DigestPair {
74 lhs: digest,
75 rhs: digest,
76 }
77 .digest();
78 }
79 Self { page, digests }
80 }
81}
82
#[cfg(feature = "std")]
/// A page of memory, stored as raw bytes (words are read and written
/// little-endian by `load` / `store`).
///
/// With the `std` feature the bytes sit behind an [`Arc`], so cloning a page
/// only bumps a reference count; `store` obtains a private copy through
/// `Arc::make_mut` when the buffer is shared.
#[derive(Clone)]
pub struct Page(Arc<Vec<u8>>);

#[cfg(not(feature = "std"))]
/// A page of memory, stored as raw bytes (words are read and written
/// little-endian by `load` / `store`).
///
/// Without the `std` feature the bytes are owned directly, so cloning a page
/// copies its buffer.
#[derive(Clone)]
pub struct Page(Vec<u8>);
100
/// A sparse Merkle tree over the pages of the address space.
///
/// Nodes use heap-style indices: the root is `1`, node `i` has children `2i`
/// and `2i + 1`, and page `p` is the leaf at `MEMORY_PAGES + p`. Only pages and
/// digests that have been touched are stored. A missing node can be
/// reconstructed on demand only when its nearest stored ancestor holds the
/// all-zero digest for that depth.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemoryImage {
    /// Materialized pages, keyed by page index (`Debug` prints only the count).
    #[debug("{}", pages.len())]
    pages: BTreeMap<u32, Page>,

    /// Materialized node digests, keyed by heap index (`Debug` prints only the count).
    #[debug("{}", digests.len())]
    digests: BTreeMap<u32, Digest>,

    /// Interior nodes whose digest must be recomputed by `update_digests`
    /// (`Debug` prints only the count).
    #[debug("{}", dirty.len())]
    dirty: BTreeSet<u32>,
}
121
122impl Default for MemoryImage {
123 fn default() -> Self {
124 Self {
125 pages: Default::default(),
126 digests: BTreeMap::from([(1, ZERO_CACHE.digests[0])]),
127 dirty: Default::default(),
128 }
129 }
130}
131
132impl MemoryImage {
133 fn new(image: BTreeMap<u32, u32>) -> Self {
134 let mut this = Self::default();
135 let mut cur_page_idx = u32::MAX;
136 let mut cur_page: Option<Page> = None;
137
138 for (&addr, &word) in image.iter() {
139 let addr = ByteAddr(addr).waddr();
140 let page_idx = addr.page_idx();
141 if page_idx != cur_page_idx {
142 if let Some(page) = cur_page.take() {
143 this.set_page(cur_page_idx, page);
144 }
145 cur_page = Some(Page::default());
146 cur_page_idx = page_idx;
147 }
148
149 cur_page.as_mut().unwrap().store(addr, word);
150 }
151
152 if let Some(page) = cur_page.take() {
153 this.set_page(cur_page_idx, page);
154 }
155
156 this.update_digests();
157
158 this
159 }
160
161 pub fn new_user(program: Program) -> Self {
163 let mut image = program.image;
164 image.insert(USER_START_ADDR.0, program.entry);
165 Self::new(image)
166 }
167
168 pub fn new_kernel(program: Program) -> Self {
170 let mut image = program.image;
171 image.insert(SUSPEND_PC_ADDR.0, program.entry);
172 image.insert(SUSPEND_MODE_ADDR.0, 1);
173 Self::new(image)
174 }
175
176 pub fn with_kernel(mut user: Program, mut kernel: Program) -> Self {
179 user.image.insert(USER_START_ADDR.0, user.entry);
180 kernel.image.append(&mut user.image);
181 kernel.image.insert(SUSPEND_PC_ADDR.0, kernel.entry);
182 kernel.image.insert(SUSPEND_MODE_ADDR.0, 1);
183 Self::new(kernel.image)
184 }
185
186 pub fn get_page_indexes(&self) -> BTreeSet<u32> {
188 self.pages.keys().copied().collect()
189 }
190
191 pub fn digests(&self) -> impl Iterator<Item = (&'_ u32, &'_ Digest)> + '_ {
193 self.digests.iter()
194 }
195
196 pub fn get_page(&mut self, page_idx: u32) -> Result<Page> {
198 if let Some(page) = self.pages.get(&page_idx) {
200 return Ok(page.clone());
201 }
202
203 let digest_idx = MEMORY_PAGES as u32 + page_idx;
205 if self.expand_if_zero(digest_idx) {
206 let zero_page = &ZERO_CACHE.page;
207 self.pages.insert(page_idx, zero_page.clone());
208 return Ok(zero_page.clone());
209 }
210
211 bail!("Unavailable page: {page_idx}")
213 }
214
215 pub fn get_existing_page(&self, page_idx: u32) -> Page {
217 self.pages.get(&page_idx).unwrap().clone()
218 }
219
220 pub fn set_page(&mut self, page_idx: u32, page: Page) {
222 let digest_idx = MEMORY_PAGES as u32 + page_idx;
224 self.expand_if_zero(digest_idx);
225 self.digests.insert(digest_idx, page.digest());
226 self.pages.insert(page_idx, page);
227 self.mark_dirty(digest_idx);
228 }
229
230 pub fn get_digest(&mut self, digest_idx: u32) -> Result<&Digest> {
232 self.expand_if_zero(digest_idx);
234 self.digests
235 .get(&digest_idx)
236 .ok_or_else(|| anyhow!("Unavailable digest: {digest_idx}"))
237 }
238
239 pub fn get_existing_digest(&self, digest_idx: u32) -> &Digest {
241 self.digests.get(&digest_idx).unwrap()
242 }
243
244 pub fn set_digest(&mut self, digest_idx: u32, digest: Digest) {
246 self.expand_if_zero(digest_idx);
248 self.digests.insert(digest_idx, digest);
250 self.mark_dirty(digest_idx);
251 }
252
253 pub fn image_id(&mut self) -> Digest {
255 *self.get_digest(1).unwrap()
256 }
257
258 pub fn user_id(&mut self) -> Digest {
260 *self.get_digest(2).unwrap()
261 }
262
263 pub fn kernel_id(&mut self) -> Digest {
265 *self.get_digest(3).unwrap()
266 }
267
268 fn expand_if_zero(&mut self, digest_idx: u32) -> bool {
270 self.is_zero(digest_idx)
271 .then(|| {
272 self.expand_zero(digest_idx);
273 })
274 .is_some()
275 }
276
277 fn is_zero(&self, mut digest_idx: u32) -> bool {
279 let mut depth = digest_idx.ilog2() as usize;
281 while !self.digests.contains_key(&digest_idx) && digest_idx > 0 {
283 digest_idx /= 2;
284 depth -= 1;
285 }
286 if digest_idx == 0 {
287 false
288 } else {
289 self.digests[&digest_idx] == ZERO_CACHE.digests[depth]
290 }
291 }
292
293 fn expand_zero(&mut self, mut digest_idx: u32) {
297 let mut depth = digest_idx.ilog2() as usize;
299 while !self.digests.contains_key(&digest_idx) {
301 let parent_idx = digest_idx / 2;
302 let lhs_idx = parent_idx * 2;
303 let rhs_idx = parent_idx * 2 + 1;
304 self.digests.insert(lhs_idx, ZERO_CACHE.digests[depth]);
305 self.digests.insert(rhs_idx, ZERO_CACHE.digests[depth]);
306 digest_idx = parent_idx;
307 depth -= 1;
308 }
309 }
310
311 fn mark_dirty(&mut self, mut digest_idx: u32) {
313 while digest_idx != 1 {
314 let parent_idx = digest_idx / 2;
315 let lhs_idx = parent_idx * 2;
316 let rhs_idx = parent_idx * 2 + 1;
317 let lhs = self.digests.get(&lhs_idx);
318 let rhs = self.digests.get(&rhs_idx);
319 if let (Some(_), Some(_)) = (lhs, rhs) {
320 self.dirty.insert(parent_idx);
321 digest_idx = parent_idx;
322 } else {
323 break;
324 };
325 }
326 }
327
328 pub fn update_digests(&mut self) {
331 let dirty: Vec<_> = mem::take(&mut self.dirty).into_iter().collect();
332 for idx in dirty.into_iter().rev() {
333 let lhs_idx = idx * 2;
334 let rhs_idx = idx * 2 + 1;
335 let lhs = *self.digests.get(&lhs_idx).unwrap();
336 let rhs = *self.digests.get(&rhs_idx).unwrap();
337
338 let parent_digest = DigestPair { lhs, rhs }.digest();
339 self.digests.insert(idx, parent_digest);
340 }
341 }
342}
343
344impl Default for Page {
345 fn default() -> Self {
346 Self::from_vec(vec![0; PAGE_BYTES])
347 }
348}
349
350impl Page {
351 fn from_vec(v: Vec<u8>) -> Self {
353 #[cfg(not(feature = "std"))]
354 return Self(v);
355 #[cfg(feature = "std")]
356 return Self(Arc::new(v));
357 }
358
359 pub fn digest(&self) -> Digest {
364 let mut cells = [BabyBearElem::ZERO; CELLS];
365 for i in 0..PAGE_WORDS / DIGEST_WORDS {
366 for j in 0..DIGEST_WORDS {
367 let addr = WordAddr((i * DIGEST_WORDS + j) as u32);
368 let word = self.load(addr);
369 cells[2 * j] = BabyBearElem::new(word & 0xffff);
370 cells[2 * j + 1] = BabyBearElem::new(word >> 16);
371 }
372 poseidon2_mix(&mut cells);
373 }
374 cells_to_digest(&cells)
375 }
376
377 pub fn load(&self, addr: WordAddr) -> u32 {
385 let byte_addr = addr.page_subaddr().baddr().0 as usize;
386 let mut bytes = [0u8; WORD_SIZE];
387 bytes.clone_from_slice(&self.0[byte_addr..byte_addr + WORD_SIZE]);
388 #[allow(clippy::let_and_return)] let word = u32::from_le_bytes(bytes);
390 word
392 }
393
394 #[cfg(feature = "std")]
395 fn ensure_writable(&mut self) -> &mut [u8] {
396 &mut Arc::make_mut(&mut self.0)[..]
397 }
398
399 #[cfg(not(feature = "std"))]
400 fn ensure_writable(&mut self) -> &mut [u8] {
401 &mut self.0
402 }
403
404 pub fn store(&mut self, addr: WordAddr, word: u32) {
412 let writable_ref = self.ensure_writable();
413
414 let byte_addr = addr.page_subaddr().baddr().0 as usize;
415 writable_ref[byte_addr..byte_addr + WORD_SIZE].clone_from_slice(&word.to_le_bytes());
417 }
418
419 pub fn data(&self) -> &Vec<u8> {
421 &self.0
422 }
423}
424
425impl Serialize for Page {
426 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
427 where
428 S: serde::ser::Serializer,
429 {
430 self.0.serialize(serializer)
431 }
432}
433
434impl<'de> Deserialize<'de> for Page {
435 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
436 where
437 D: serde::de::Deserializer<'de>,
438 {
439 use serde::de::Error as _;
440
441 let vec = <Vec<u8> as Deserialize>::deserialize(deserializer)?;
442 if vec.len() != PAGE_BYTES {
443 return Err(D::Error::custom(format!(
444 "serialized page has wrong length {} != {}",
445 vec.len(),
446 PAGE_BYTES
447 )));
448 }
449 Ok(Self::from_vec(vec))
450 }
451}
452
/// The two child digests of an interior Merkle-tree node.
pub(crate) struct DigestPair {
    /// Digest of the left child (index `2i` for parent `i`).
    pub(crate) lhs: Digest,
    /// Digest of the right child (index `2i + 1` for parent `i`).
    pub(crate) rhs: Digest,
}
457
458impl DigestPair {
459 pub fn digest(&self) -> Digest {
460 let mut cells = [BabyBearElem::ZERO; CELLS];
461 for i in 0..DIGEST_WORDS {
462 cells[i] = BabyBearElem::new(self.rhs.as_words()[i]);
463 cells[DIGEST_WORDS + i] = BabyBearElem::new(self.lhs.as_words()[i]);
464 }
465 poseidon2_mix(&mut cells);
466 cells_to_digest(&cells)
467 }
468}
469
470fn cells_to_digest(cells: &[BabyBearElem; CELLS]) -> Digest {
471 Digest::new([
472 cells[0].as_u32(),
473 cells[1].as_u32(),
474 cells[2].as_u32(),
475 cells[3].as_u32(),
476 cells[4].as_u32(),
477 cells[5].as_u32(),
478 cells[6].as_u32(),
479 cells[7].as_u32(),
480 ])
481}
482
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use risc0_zkp::digest;
    use test_log::test;

    use super::{MemoryImage, Program, ZERO_CACHE};

    /// Pins the all-zero subtree digest at every depth (index 0 is the root,
    /// the last entry is a single zero page).
    #[test]
    fn poseidon2_zeros() {
        let expected = [
            digest!("f85c5a32ccc45c22f9686b08d710d4597d7ce256cdcd63146426270d9432c644"),
            digest!("2ce7714c40af126c2e86f320b10de417eddd8f51d2b9133d3105c3541a154812"),
            digest!("889c443e0c55734c0212fe6c400f00423c421f2070b1340351e77826e4918274"),
            digest!("53ea92273a7dfb7622de685c49f4ce1bd69db1696cd6846e9f5de56c89098b01"),
            digest!("82db13229831cb2ad63df0476dc1f217c702503d46770c283b6ecc1520fff074"),
            digest!("45cba5321f90c34b780d5d1790f23612fb834b3d21dc1e53594826470719ba34"),
            digest!("132689262568ae5ac27a4b65018aef0b2e4345578a16453acd874973a61c6350"),
            digest!("9fc9626e87aa3614eb38b44d9d832712fb2ea32427c6fd49281ca225f1fefd0d"),
            digest!("70947164fe9a4353fa33fb024f09ea0df24be40d88b6025278a3472ac49e6715"),
            digest!("4b707f15d9941c0168d630618cdcc05ccae5d84ab9674a6666123a0039915173"),
            digest!("97fb1325724ddb74b1446b5bfa13f02c2ecb1b2b2a2f5b1334a04c5c76335d12"),
            digest!("adba743a459eb5357487a1238a0c4c238b8313458283900447e9b8540adfb042"),
            digest!("a16e68725fe981434dcca548e972214b2dd85e017c3a4e03909a0f4c31a08741"),
            digest!("fb94f356397279703f12c24da7aa371e192294347af15d46f10ab512708cdb68"),
            digest!("30a2fe1aa5c2ae0e10b91074e34b06742be91e450a9bc10f28ab082263c48750"),
            digest!("2347f636d9a0ea45bbe8bf519f39d3127f72b625e2e5495f26a6dd583eb2965d"),
            digest!("e43d140e71e366521152d932e846c73535674921576711023deaee06de3b091e"),
            digest!("35500a740d3a8b4e5a0ca06a8362f3444456e3206826102dd9e9bc3e5a1a5a18"),
            digest!("7c650c1a2000ef1a9baf4f56c2d66e76a3a0b4510175b171268d156a25d8dd45"),
            digest!("d73a1e0997a00543afd8de5261f316704215ce384e3ea13df3f87e000f04fb5f"),
            digest!("5b77f60275cb272fa0a3d267bdf1fc15021dbe7185ed6a3c94e45d70bbd70148"),
            digest!("e053c93b359c8905c5d8523139988b0ed4ef3426864a80498dfcb91d9b813364"),
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566"),
        ];
        assert_eq!(ZERO_CACHE.digests.len(), expected.len());
        for (depth, (actual, want)) in ZERO_CACHE.digests.iter().zip(&expected).enumerate() {
            assert_eq!(actual, want, "zero digest mismatch at depth {depth}");
        }
    }

    /// Checks a tiny kernel image against digests produced by the circuit.
    #[test]
    fn image_circuit_match() {
        let entry = 0x10000;
        let mut image = MemoryImage::new_kernel(Program {
            entry,
            image: BTreeMap::from([(entry, 0x1234b337)]),
        });

        // The expected value equals the last zero-cache entry (a zero page).
        let untouched_leaf = 0x0040_0100;
        let zero_page_digest =
            digest!("242ce034cc4e9326f8b7071124454b2be1a1cd5d21b6483c7ff81d4ba5ac9566");
        assert_eq!(*image.get_digest(untouched_leaf).unwrap(), zero_page_digest);

        let expected_root =
            digest!("9d41290fa400705127c0240cb646586cc6ea8a23d560aa57cfa86c1369d9d53f");
        assert_eq!(image.image_id(), expected_root);
    }
}
539}