1use std::{
16 collections::BTreeMap,
17 fmt::Debug,
18 iter::Peekable,
19 ops::{Index, IndexMut},
20};
21
22use risc0_core::{
23 field::{baby_bear::BabyBearElem, Elem},
24 scope,
25};
26use risc0_zkp::{
27 adapter::TapsProvider,
28 core::{
29 digest::{Digest, DIGEST_WORDS},
30 hash::sha::SHA256_INIT,
31 },
32 hal::Hal,
33 prove::poly_group::PolyGroup,
34 MAX_CYCLES_PO2, MIN_CYCLES_PO2, ZK_CYCLES,
35};
36use risc0_zkvm_platform::{memory, WORD_SIZE};
37
38use crate::CIRCUIT;
39
40pub const SHA_K_OFFSET: usize = memory::PRE_LOAD.start();
41pub const SHA_K_SIZE: usize = 64;
42pub const SHA_INIT_OFFSET: usize = SHA_K_OFFSET + SHA_K_SIZE * WORD_SIZE;
43pub const ZEROS_OFFSET: usize = SHA_INIT_OFFSET + DIGEST_WORDS * WORD_SIZE;
44
45pub const SETUP_STEP_REGS: usize = 84;
47pub const SETUP_CYCLES: usize = setup_count(SETUP_STEP_REGS);
48pub const RAM_LOAD_CYCLES: usize = 27;
49
50pub const INIT_CYCLES: usize = 1 + SETUP_CYCLES + 1 + RAM_LOAD_CYCLES + 2;
57
58pub const FINI_CYCLES: usize = 2 + 2 + 1 + 1;
64
/// The 64 SHA-256 round constants `K[0..64]` as listed in FIPS 180-4.
///
/// `ram_load_cycles` writes these words into the preloaded image starting at
/// `SHA_K_OFFSET`, one word every `WORD_SIZE` bytes.
pub static SHA_K: [u32; SHA_K_SIZE] = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
75
/// Returns `a / b` rounded up to the next integer.
///
/// The result is computed from the quotient and remainder rather than as
/// `(a + b - 1) / b`, so it cannot overflow for any `a` and non-zero `b`.
///
/// # Panics
///
/// Panics (or fails const evaluation) if `b` is zero.
const fn div_ceil(a: usize, b: usize) -> usize {
    let quotient = a / b;
    if a % b == 0 {
        quotient
    } else {
        quotient + 1
    }
}
79
80const fn setup_count(regs: usize) -> usize {
81 let pairs = regs / 4;
82 div_ceil(32 * 1024, pairs)
83}
84
/// Column indices of a control row.
///
/// The discriminant of each variant is the column it selects inside a
/// `CtrlCycle`; `NumRegs` is not a column but the total column count.
#[derive(Copy, Clone)]
enum CtrlReg {
    /// Column 0. Never set through a `CtrlCycle`; the loader writes the cycle index here.
    _Cycle,
    /// Flag column set on the `bytes_init` row.
    BytesInit,
    /// Flag column set on `bytes_setup` rows.
    BytesSetup,
    /// Flag column set on the `ram_init` row.
    RamInit,
    /// Flag column set on `ram_load` rows.
    RamLoad,
    /// Flag column set on `reset` rows.
    Reset,
    /// Flag column set on `body` rows.
    Body,
    /// Flag column set on the `ram_fini` row.
    RamFini,
    /// Flag column set on the `bytes_fini` row.
    BytesFini,
    /// Per-row auxiliary value (setup marker, load address, or reset `is_first`).
    Info,
    /// Low 16 bits of the first loaded word; also the phase-0 selector on reset rows.
    Data1Lo,
    /// High 16 bits of the first loaded word; also the phase-1 selector on reset rows.
    Data1Hi,
    /// Low 16 bits of the second loaded word; also the phase-2 selector on reset rows.
    Data2Lo,
    /// High 16 bits of the second loaded word.
    Data2Hi,
    /// Low 16 bits of the third loaded word.
    Data3Lo,
    /// High 16 bits of the third loaded word.
    Data3Hi,
    /// Number of columns; must stay the last variant.
    NumRegs,
}
106
/// One row of the control table: a value for every `CtrlReg` column.
///
/// Cells are addressed with a `CtrlReg` through the `Index`/`IndexMut` impls.
#[derive(Clone)]
pub struct CtrlCycle([BabyBearElem; CtrlReg::NumRegs as usize]);
109
110fn split_word16(value: u32) -> (BabyBearElem, BabyBearElem) {
111 (
112 BabyBearElem::new(value & 0xffff),
113 BabyBearElem::new(value >> 16),
114 )
115}
116
117impl CtrlCycle {
118 fn bytes_init() -> Self {
119 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
120 row[CtrlReg::BytesInit] = BabyBearElem::ONE;
121 row
122 }
123
124 fn bytes_setup(info: BabyBearElem) -> Self {
125 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
126 row[CtrlReg::BytesSetup] = BabyBearElem::ONE;
127 row[CtrlReg::Info] = info;
128 row
129 }
130
131 fn ram_init() -> Self {
132 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
133 row[CtrlReg::RamInit] = BabyBearElem::ONE;
134 row
135 }
136
137 fn ram_load(triple: TripleWord) -> Self {
138 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
139 row[CtrlReg::RamLoad] = BabyBearElem::ONE;
140 row[CtrlReg::Info] = BabyBearElem::new(triple.addr);
141 (row[CtrlReg::Data1Lo], row[CtrlReg::Data1Hi]) = split_word16(triple.data[0]);
142 (row[CtrlReg::Data2Lo], row[CtrlReg::Data2Hi]) = split_word16(triple.data[1]);
143 (row[CtrlReg::Data3Lo], row[CtrlReg::Data3Hi]) = split_word16(triple.data[2]);
144 row
145 }
146
147 fn reset(is_first: BabyBearElem, phase: CtrlReg) -> Self {
148 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
149 row[CtrlReg::Reset] = BabyBearElem::ONE;
150 row[CtrlReg::Info] = is_first;
151 row[phase] = BabyBearElem::ONE;
152 row
153 }
154
155 fn ram_fini() -> Self {
156 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
157 row[CtrlReg::RamFini] = BabyBearElem::ONE;
158 row
159 }
160
161 fn bytes_fini() -> Self {
162 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
163 row[CtrlReg::BytesFini] = BabyBearElem::ONE;
164 row
165 }
166
167 fn body() -> Self {
168 let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
169 row[CtrlReg::Body] = BabyBearElem::ONE;
170 row
171 }
172}
173
174impl Index<CtrlReg> for CtrlCycle {
175 type Output = BabyBearElem;
176
177 fn index(&self, index: CtrlReg) -> &Self::Output {
178 &self.0[index as usize]
179 }
180}
181
182impl IndexMut<CtrlReg> for CtrlCycle {
183 fn index_mut(&mut self, index: CtrlReg) -> &mut Self::Output {
184 &mut self.0[index as usize]
185 }
186}
187
/// Up to three consecutive 32-bit words of a memory image.
#[derive(Clone, Copy, PartialEq)]
struct TripleWord {
    /// Word address (byte address divided by 4) of `data[0]`.
    addr: u32,
    /// Words at `addr`, `addr + 1` and `addr + 2`; slots with no entry are zero.
    data: [u32; 3],
}

/// Formats as `0x<byte address>[0x<word0>, 0x<word1>, 0x<word2>]` in upper-case hex.
impl Debug for TripleWord {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Widen before scaling: `addr * 4` in `u32` overflows for `addr >= 2^30`,
        // which would panic in debug builds and print a wrapped address in release.
        let byte_addr = u64::from(self.addr) * 4;
        let [w0, w1, w2] = self.data;
        write!(
            f,
            "0x{byte_addr:08X}[0x{w0:08X}, 0x{w1:08X}, 0x{w2:08X}]"
        )
    }
}
205
/// Iterator that groups the entries of a `byte address -> word` map into `TripleWord`s.
struct TripleWordIter<'a> {
    /// Remaining map entries in ascending address order, with one-entry lookahead.
    it: Peekable<std::collections::btree_map::Iter<'a, u32, u32>>,
}

impl<'a> TripleWordIter<'a> {
    /// Starts iterating over `image` from its lowest address.
    fn new(image: &'a BTreeMap<u32, u32>) -> Self {
        let it = image.iter().peekable();
        Self { it }
    }
}
217
218impl<'a> Iterator for TripleWordIter<'a> {
219 type Item = TripleWord;
220
221 fn next(&mut self) -> Option<Self::Item> {
222 let mut cur = TripleWord {
223 addr: 0,
224 data: [0, 0, 0],
225 };
226 for i in 0..3 {
227 match self.it.peek() {
228 Some((addr, data)) => {
229 let addr = **addr / 4;
230 if i == 0 {
231 cur.addr = addr;
232 } else if addr != cur.addr + i as u32 {
233 continue;
234 }
235 cur.data[i] = **data;
236 self.it.next();
237 }
238 None => {
239 if i == 0 {
240 return None;
241 }
242 }
243 }
244 }
245 Some(cur)
246 }
247}
248
/// Builds the control table one cycle (row) at a time.
pub struct Loader {
    /// Number of rows per column; also the stride between columns in `ctrl`.
    max_cycles: usize,
    /// Control table stored column by column: the cell for column `i` and
    /// cycle `c` lives at index `max_cycles * i + c`.
    pub ctrl: Vec<BabyBearElem>,
    /// Index of the next row to write, i.e. the number of rows written so far.
    cycle: usize,
    /// Precomputed `RamLoad` rows, as returned by `ram_load_cycles()`.
    ram_load_cycles: Vec<CtrlCycle>,
}
255
256pub fn ram_load_cycles() -> Vec<CtrlCycle> {
257 let mut image: BTreeMap<u32, u32> = BTreeMap::new();
258
259 for (i, word) in SHA_K.iter().enumerate() {
261 image.insert((SHA_K_OFFSET + i * WORD_SIZE) as u32, *word);
262 }
263
264 for (i, word) in SHA256_INIT.as_words().iter().enumerate() {
266 image.insert((SHA_INIT_OFFSET + i * WORD_SIZE) as u32, *word);
267 }
268
269 for i in 0..DIGEST_WORDS {
271 image.insert((ZEROS_OFFSET + i * WORD_SIZE) as u32, 0);
272 }
273
274 TripleWordIter::new(&image)
275 .map(CtrlCycle::ram_load)
276 .collect()
277}
278
279impl Loader {
280 pub fn new(max_cycles: usize, ctrl_size: usize) -> Self {
281 let ram_load_cycles = ram_load_cycles();
282 assert_eq!(ram_load_cycles.len(), RAM_LOAD_CYCLES);
283 Self {
284 max_cycles,
285 ctrl: vec![BabyBearElem::ZERO; max_cycles * ctrl_size],
286 cycle: 0,
287 ram_load_cycles,
288 }
289 }
290
291 pub fn load(&mut self) -> usize {
292 scope!("load");
293
294 self.pre_steps();
295 self.body();
296 self.post_steps();
297 self.cycle
298 }
299
300 fn pre_steps(&mut self) {
301 self.bytes_init();
302 self.bytes_setup();
303 self.ram_init();
304 self.ram_load();
305 self.reset(0);
306 }
307
308 fn body(&mut self) {
309 let body_cycles = self.max_cycles - self.cycle - FINI_CYCLES - ZK_CYCLES;
310 tracing::debug!("[{}] BODY: {body_cycles}", self.cycle);
311 for _ in 0..body_cycles {
312 self.add_cycle(CtrlCycle::body());
313 }
314 }
315
316 fn post_steps(&mut self) {
317 self.reset(1);
318 self.reset(2);
319 self.fini();
320 }
321
322 fn bytes_init(&mut self) {
323 tracing::debug!("[{}] BYTES_INIT", self.cycle);
324 self.add_cycle(CtrlCycle::bytes_init());
325 }
326
327 fn bytes_setup(&mut self) {
328 tracing::debug!("[{}] BYTES_SETUP", self.cycle);
329 for _ in 0..SETUP_CYCLES - 1 {
330 self.add_cycle(CtrlCycle::bytes_setup(BabyBearElem::ZERO));
331 }
332 self.add_cycle(CtrlCycle::bytes_setup(BabyBearElem::ONE));
333 }
334
335 fn ram_init(&mut self) {
336 tracing::debug!("[{}] RAM_INIT", self.cycle);
337 self.add_cycle(CtrlCycle::ram_init());
338 }
339
340 fn ram_load(&mut self) {
341 for cycle in self.ram_load_cycles.clone() {
342 self.add_cycle(cycle);
343 }
344 }
345
346 fn reset(&mut self, phase: u32) {
347 tracing::debug!("[{}] RESET({phase})", self.cycle);
348 let phase = match phase {
349 0 => CtrlReg::Data1Lo,
350 1 => CtrlReg::Data1Hi,
351 2 => CtrlReg::Data2Lo,
352 _ => unimplemented!("Invalid phase"),
353 };
354 self.add_cycle(CtrlCycle::reset(BabyBearElem::ONE, phase));
355 self.add_cycle(CtrlCycle::reset(BabyBearElem::ZERO, phase));
356 }
357
358 fn fini(&mut self) {
359 tracing::debug!("[{}] RAM_FINI", self.cycle);
360 self.add_cycle(CtrlCycle::ram_fini());
361 tracing::debug!("[{}] BYTES_FINI", self.cycle);
362 self.add_cycle(CtrlCycle::bytes_fini());
363 }
364
365 fn add_cycle(&mut self, row: CtrlCycle) {
366 self.ctrl[self.cycle] = BabyBearElem::new(self.cycle as u32);
367 for i in 1..row.0.len() {
368 self.ctrl[self.max_cycles * i + self.cycle] = row.0[i];
369 }
370 self.cycle += 1;
371 }
372
373 pub fn compute_control_id_table<H: Hal<Elem = BabyBearElem>>(hal: &H) -> Vec<(String, Digest)> {
375 let mut table = Vec::new();
377 for po2 in MIN_CYCLES_PO2..=MAX_CYCLES_PO2 {
378 table.push((
379 format!("rv32im po2={po2}"),
380 Self::compute_control_id(hal, po2),
381 ));
382 }
383 table
384 }
385
386 pub fn compute_control_id<H: Hal<Elem = BabyBearElem>>(hal: &H, po2: usize) -> Digest {
387 tracing::debug!("po2: {po2}");
388 let cycles = 1 << po2;
389 let ctrl_size = CIRCUIT.ctrl_size();
390 let mut loader = Loader::new(cycles, ctrl_size);
391 loader.load();
393 let coeffs = hal.copy_from_elem("coeffs", &loader.ctrl);
395 hal.batch_interpolate_ntt(&coeffs, ctrl_size);
397 hal.zk_shift(&coeffs, ctrl_size);
398 let group = PolyGroup::new(hal, coeffs, ctrl_size, cycles, "ctrl");
400 *group.merkle.root()
401 }
402}
403
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use test_log::test;

    use super::{TripleWord, TripleWordIter};

    /// Shorthand for an expected triple.
    fn triple(addr: u32, data: [u32; 3]) -> TripleWord {
        TripleWord { addr, data }
    }

    /// Feeds `(byte address, word)` pairs through the iterator and checks the grouping.
    fn triple_test(input: &[(u32, u32)], expected: &[TripleWord]) {
        let map: BTreeMap<u32, u32> = input.iter().copied().collect();
        let result: Vec<TripleWord> = TripleWordIter::new(&map).collect();
        assert_eq!(result.as_slice(), expected);
    }

    #[test]
    fn triple_word_iter() {
        // Empty image.
        triple_test(&[], &[]);
        // Partially filled triples are padded with zeros.
        triple_test(&[(0, 1)], &[triple(0, [1, 0, 0])]);
        triple_test(&[(0, 1), (4, 2)], &[triple(0, [1, 2, 0])]);
        triple_test(&[(0, 1), (4, 2), (8, 3)], &[triple(0, [1, 2, 3])]);
        // A gap in the middle leaves that slot zero.
        triple_test(&[(0, 1), (8, 3)], &[triple(0, [1, 0, 3])]);
        // A fourth word starts a new triple at its own word address.
        triple_test(
            &[(0, 1), (4, 2), (8, 3), (12, 4)],
            &[triple(0, [1, 2, 3]), triple(3, [4, 0, 0])],
        );
    }
}