// risc0_circuit_rv32im/prove/engine/loader.rs

// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::{
16    collections::BTreeMap,
17    fmt::Debug,
18    iter::Peekable,
19    ops::{Index, IndexMut},
20};
21
22use risc0_core::{
23    field::{baby_bear::BabyBearElem, Elem},
24    scope,
25};
26use risc0_zkp::{
27    adapter::TapsProvider,
28    core::{
29        digest::{Digest, DIGEST_WORDS},
30        hash::sha::SHA256_INIT,
31    },
32    hal::Hal,
33    prove::poly_group::PolyGroup,
34    MAX_CYCLES_PO2, MIN_CYCLES_PO2, ZK_CYCLES,
35};
36use risc0_zkvm_platform::{memory, WORD_SIZE};
37
38use crate::CIRCUIT;
39
/// Byte address where the SHA-256 round constants (`SHA_K`) are pre-loaded: the start of
/// the `memory::PRE_LOAD` region.
pub const SHA_K_OFFSET: usize = memory::PRE_LOAD.start();
/// Number of SHA-256 round constants (entries of `SHA_K`).
pub const SHA_K_SIZE: usize = 64;
/// Byte address of the `SHA256_INIT` words, placed immediately after `SHA_K`.
pub const SHA_INIT_OFFSET: usize = SHA_K_OFFSET + SHA_K_SIZE * WORD_SIZE;
/// Byte address of a run of `DIGEST_WORDS` zero words, placed immediately after the
/// `SHA256_INIT` words.
pub const ZEROS_OFFSET: usize = SHA_INIT_OFFSET + DIGEST_WORDS * WORD_SIZE;
44
// TODO: generate from zirgen
/// Register count of one BytesSetup step; `setup_count` converts it into a cycle count.
pub const SETUP_STEP_REGS: usize = 84;
/// Number of BytesSetup cycles: `setup_count(84)` = ceil(32768 / 21) = 1561.
pub const SETUP_CYCLES: usize = setup_count(SETUP_STEP_REGS);
/// Number of RamLoad cycles. `Loader::new` asserts that `ram_load_cycles()` yields exactly
/// this many rows (64 + 2 * DIGEST_WORDS pre-loaded words, packed up to three per row).
pub const RAM_LOAD_CYCLES: usize = 27;
49
/// Number of control cycles emitted before the body phase (see `Loader::pre_steps`):
/// - BytesInit: 1
/// - BytesSetup: 1561 (`SETUP_CYCLES`)
/// - RamInit: 1
/// - RamLoad: 27 (`RAM_LOAD_CYCLES`)
/// - Reset(0): 2
pub const INIT_CYCLES: usize = 1 + SETUP_CYCLES + 1 + RAM_LOAD_CYCLES + 2;
57
/// Number of control cycles emitted after the body phase (see `Loader::post_steps`):
/// - Reset(1): 2
/// - Reset(2): 2
/// - RamFini: 1
/// - BytesFini: 1
///
/// A further `ZK_CYCLES` rows at the end of the trace are reserved separately by
/// `Loader::body` and are never written.
pub const FINI_CYCLES: usize = 2 + 2 + 1 + 1;
64
/// SHA-256 round constants K[0..64] (FIPS 180-4, section 4.2.2), pre-loaded into memory
/// at `SHA_K_OFFSET` by the RamLoad phase.
pub static SHA_K: [u32; SHA_K_SIZE] = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
75
/// Integer division rounding up: the smallest `q` such that `q * b >= a`.
///
/// Unlike the textbook `(a + b - 1) / b`, this form cannot overflow when `a` is close to
/// `usize::MAX`.
///
/// # Panics
/// Panics (or fails const evaluation) if `b == 0`.
const fn div_ceil(a: usize, b: usize) -> usize {
    let quotient = a / b;
    if a % b == 0 {
        quotient
    } else {
        quotient + 1
    }
}
79
80const fn setup_count(regs: usize) -> usize {
81    let pairs = regs / 4;
82    div_ceil(32 * 1024, pairs)
83}
84
/// Registers (columns) of the control group, in column order.
///
/// The explicit discriminants are the column indices used by `CtrlCycle` indexing and by
/// `Loader::add_cycle`. `NumRegs` is not a register; it is the column count.
#[derive(Copy, Clone)]
enum CtrlReg {
    _Cycle = 0, // Written by the loader itself (the cycle number), not by row constructors.
    BytesInit = 1,
    BytesSetup = 2,
    RamInit = 3,
    RamLoad = 4,
    Reset = 5,
    Body = 6,
    RamFini = 7,
    BytesFini = 8,
    Info = 9,
    Data1Lo = 10,
    Data1Hi = 11,
    Data2Lo = 12,
    Data2Hi = 13,
    Data3Lo = 14,
    Data3Hi = 15,
    NumRegs = 16,
}
106
/// One row of the control group: one value per `CtrlReg`, indexed by `CtrlReg`.
///
/// Slot 0 (`CtrlReg::_Cycle`) of a row is ignored by `Loader::add_cycle`, which writes the
/// cycle number into that column itself.
#[derive(Clone)]
pub struct CtrlCycle([BabyBearElem; CtrlReg::NumRegs as usize]);
109
110fn split_word16(value: u32) -> (BabyBearElem, BabyBearElem) {
111    (
112        BabyBearElem::new(value & 0xffff),
113        BabyBearElem::new(value >> 16),
114    )
115}
116
117impl CtrlCycle {
118    fn bytes_init() -> Self {
119        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
120        row[CtrlReg::BytesInit] = BabyBearElem::ONE;
121        row
122    }
123
124    fn bytes_setup(info: BabyBearElem) -> Self {
125        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
126        row[CtrlReg::BytesSetup] = BabyBearElem::ONE;
127        row[CtrlReg::Info] = info;
128        row
129    }
130
131    fn ram_init() -> Self {
132        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
133        row[CtrlReg::RamInit] = BabyBearElem::ONE;
134        row
135    }
136
137    fn ram_load(triple: TripleWord) -> Self {
138        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
139        row[CtrlReg::RamLoad] = BabyBearElem::ONE;
140        row[CtrlReg::Info] = BabyBearElem::new(triple.addr);
141        (row[CtrlReg::Data1Lo], row[CtrlReg::Data1Hi]) = split_word16(triple.data[0]);
142        (row[CtrlReg::Data2Lo], row[CtrlReg::Data2Hi]) = split_word16(triple.data[1]);
143        (row[CtrlReg::Data3Lo], row[CtrlReg::Data3Hi]) = split_word16(triple.data[2]);
144        row
145    }
146
147    fn reset(is_first: BabyBearElem, phase: CtrlReg) -> Self {
148        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
149        row[CtrlReg::Reset] = BabyBearElem::ONE;
150        row[CtrlReg::Info] = is_first;
151        row[phase] = BabyBearElem::ONE;
152        row
153    }
154
155    fn ram_fini() -> Self {
156        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
157        row[CtrlReg::RamFini] = BabyBearElem::ONE;
158        row
159    }
160
161    fn bytes_fini() -> Self {
162        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
163        row[CtrlReg::BytesFini] = BabyBearElem::ONE;
164        row
165    }
166
167    fn body() -> Self {
168        let mut row = Self([BabyBearElem::ZERO; CtrlReg::NumRegs as usize]);
169        row[CtrlReg::Body] = BabyBearElem::ONE;
170        row
171    }
172}
173
174impl Index<CtrlReg> for CtrlCycle {
175    type Output = BabyBearElem;
176
177    fn index(&self, index: CtrlReg) -> &Self::Output {
178        &self.0[index as usize]
179    }
180}
181
182impl IndexMut<CtrlReg> for CtrlCycle {
183    fn index_mut(&mut self, index: CtrlReg) -> &mut Self::Output {
184        &mut self.0[index as usize]
185    }
186}
187
/// Up to three consecutive memory words, packed into one RamLoad control row.
#[derive(Clone, Copy, PartialEq)]
struct TripleWord {
    // Word address (byte address / 4) of `data[0]`.
    addr: u32,
    // `data[k]` is the word at word address `addr + k`, or 0 if the image had no entry there.
    data: [u32; 3],
}
193
194impl Debug for TripleWord {
195    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
196        f.write_fmt(format_args!(
197            "0x{:08X}[0x{:08X}, 0x{:08X}, 0x{:08X}]",
198            self.addr * 4,
199            self.data[0],
200            self.data[1],
201            self.data[2]
202        ))
203    }
204}
205
/// Iterator that packs an address-ordered sparse word image into `TripleWord`s.
struct TripleWordIter<'a> {
    // Remaining (byte address, word) entries. It is peekable so that an entry which does
    // not fit the current triple can be left for a later slot or a later triple.
    it: Peekable<std::collections::btree_map::Iter<'a, u32, u32>>,
}
209
210impl<'a> TripleWordIter<'a> {
211    fn new(image: &'a BTreeMap<u32, u32>) -> Self {
212        Self {
213            it: image.iter().peekable(),
214        }
215    }
216}
217
218impl<'a> Iterator for TripleWordIter<'a> {
219    type Item = TripleWord;
220
221    fn next(&mut self) -> Option<Self::Item> {
222        let mut cur = TripleWord {
223            addr: 0,
224            data: [0, 0, 0],
225        };
226        for i in 0..3 {
227            match self.it.peek() {
228                Some((addr, data)) => {
229                    let addr = **addr / 4;
230                    if i == 0 {
231                        cur.addr = addr;
232                    } else if addr != cur.addr + i as u32 {
233                        continue;
234                    }
235                    cur.data[i] = **data;
236                    self.it.next();
237                }
238                None => {
239                    if i == 0 {
240                        return None;
241                    }
242                }
243            }
244        }
245        Some(cur)
246    }
247}
248
/// Builds the control-group columns for a trace of `max_cycles` rows.
pub struct Loader {
    // Trace length, i.e. the number of rows in each control column.
    max_cycles: usize,
    /// Control columns stored column-major: register `r` of cycle `c` lives at index
    /// `r * max_cycles + c`.
    pub ctrl: Vec<BabyBearElem>,
    // Index of the next cycle (row) to be written.
    cycle: usize,
    // Precomputed RamLoad rows (from `ram_load_cycles`), emitted by `ram_load`.
    ram_load_cycles: Vec<CtrlCycle>,
}
255
256pub fn ram_load_cycles() -> Vec<CtrlCycle> {
257    let mut image: BTreeMap<u32, u32> = BTreeMap::new();
258
259    // Setup 'k' for SHA
260    for (i, word) in SHA_K.iter().enumerate() {
261        image.insert((SHA_K_OFFSET + i * WORD_SIZE) as u32, *word);
262    }
263
264    // Setup SHA-256 Init
265    for (i, word) in SHA256_INIT.as_words().iter().enumerate() {
266        image.insert((SHA_INIT_OFFSET + i * WORD_SIZE) as u32, *word);
267    }
268
269    // Setup ZEROS
270    for i in 0..DIGEST_WORDS {
271        image.insert((ZEROS_OFFSET + i * WORD_SIZE) as u32, 0);
272    }
273
274    TripleWordIter::new(&image)
275        .map(CtrlCycle::ram_load)
276        .collect()
277}
278
279impl Loader {
280    pub fn new(max_cycles: usize, ctrl_size: usize) -> Self {
281        let ram_load_cycles = ram_load_cycles();
282        assert_eq!(ram_load_cycles.len(), RAM_LOAD_CYCLES);
283        Self {
284            max_cycles,
285            ctrl: vec![BabyBearElem::ZERO; max_cycles * ctrl_size],
286            cycle: 0,
287            ram_load_cycles,
288        }
289    }
290
291    pub fn load(&mut self) -> usize {
292        scope!("load");
293
294        self.pre_steps();
295        self.body();
296        self.post_steps();
297        self.cycle
298    }
299
300    fn pre_steps(&mut self) {
301        self.bytes_init();
302        self.bytes_setup();
303        self.ram_init();
304        self.ram_load();
305        self.reset(0);
306    }
307
308    fn body(&mut self) {
309        let body_cycles = self.max_cycles - self.cycle - FINI_CYCLES - ZK_CYCLES;
310        tracing::debug!("[{}] BODY: {body_cycles}", self.cycle);
311        for _ in 0..body_cycles {
312            self.add_cycle(CtrlCycle::body());
313        }
314    }
315
316    fn post_steps(&mut self) {
317        self.reset(1);
318        self.reset(2);
319        self.fini();
320    }
321
322    fn bytes_init(&mut self) {
323        tracing::debug!("[{}] BYTES_INIT", self.cycle);
324        self.add_cycle(CtrlCycle::bytes_init());
325    }
326
327    fn bytes_setup(&mut self) {
328        tracing::debug!("[{}] BYTES_SETUP", self.cycle);
329        for _ in 0..SETUP_CYCLES - 1 {
330            self.add_cycle(CtrlCycle::bytes_setup(BabyBearElem::ZERO));
331        }
332        self.add_cycle(CtrlCycle::bytes_setup(BabyBearElem::ONE));
333    }
334
335    fn ram_init(&mut self) {
336        tracing::debug!("[{}] RAM_INIT", self.cycle);
337        self.add_cycle(CtrlCycle::ram_init());
338    }
339
340    fn ram_load(&mut self) {
341        for cycle in self.ram_load_cycles.clone() {
342            self.add_cycle(cycle);
343        }
344    }
345
346    fn reset(&mut self, phase: u32) {
347        tracing::debug!("[{}] RESET({phase})", self.cycle);
348        let phase = match phase {
349            0 => CtrlReg::Data1Lo,
350            1 => CtrlReg::Data1Hi,
351            2 => CtrlReg::Data2Lo,
352            _ => unimplemented!("Invalid phase"),
353        };
354        self.add_cycle(CtrlCycle::reset(BabyBearElem::ONE, phase));
355        self.add_cycle(CtrlCycle::reset(BabyBearElem::ZERO, phase));
356    }
357
358    fn fini(&mut self) {
359        tracing::debug!("[{}] RAM_FINI", self.cycle);
360        self.add_cycle(CtrlCycle::ram_fini());
361        tracing::debug!("[{}] BYTES_FINI", self.cycle);
362        self.add_cycle(CtrlCycle::bytes_fini());
363    }
364
365    fn add_cycle(&mut self, row: CtrlCycle) {
366        self.ctrl[self.cycle] = BabyBearElem::new(self.cycle as u32);
367        for i in 1..row.0.len() {
368            self.ctrl[self.max_cycles * i + self.cycle] = row.0[i];
369        }
370        self.cycle += 1;
371    }
372
373    // Compute the `ControlId` associated with the given HAL, along with a textual description.
374    pub fn compute_control_id_table<H: Hal<Elem = BabyBearElem>>(hal: &H) -> Vec<(String, Digest)> {
375        // Make the digest for each level
376        let mut table = Vec::new();
377        for po2 in MIN_CYCLES_PO2..=MAX_CYCLES_PO2 {
378            table.push((
379                format!("rv32im po2={po2}"),
380                Self::compute_control_id(hal, po2),
381            ));
382        }
383        table
384    }
385
386    pub fn compute_control_id<H: Hal<Elem = BabyBearElem>>(hal: &H, po2: usize) -> Digest {
387        tracing::debug!("po2: {po2}");
388        let cycles = 1 << po2;
389        let ctrl_size = CIRCUIT.ctrl_size();
390        let mut loader = Loader::new(cycles, ctrl_size);
391        // Make a vector & set it up with the elf data
392        loader.load();
393        // Copy into accel buffer
394        let coeffs = hal.copy_from_elem("coeffs", &loader.ctrl);
395        // Do interpolate & shift
396        hal.batch_interpolate_ntt(&coeffs, ctrl_size);
397        hal.zk_shift(&coeffs, ctrl_size);
398        // Make the poly-group & extract the root
399        let group = PolyGroup::new(hal, coeffs, ctrl_size, cycles, "ctrl");
400        *group.merkle.root()
401    }
402}
403
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use test_log::test;

    use super::{TripleWord, TripleWordIter};

    /// Builds an image from `(byte address, word)` pairs, packs it with `TripleWordIter`,
    /// and checks the produced triples against `expected`.
    fn triple_test(input: &[(u32, u32)], expected: &[TripleWord]) {
        let map: BTreeMap<u32, u32> = input.iter().copied().collect();
        let result: Vec<TripleWord> = TripleWordIter::new(&map).collect();
        assert_eq!(result.as_slice(), expected);
    }

    /// Shorthand constructor for an expected triple.
    fn tw(addr: u32, data: [u32; 3]) -> TripleWord {
        TripleWord { addr, data }
    }

    #[test]
    fn triple_word_iter() {
        triple_test(&[], &[]);
        triple_test(&[(0, 1)], &[tw(0, [1, 0, 0])]);
        triple_test(&[(0, 1), (4, 2)], &[tw(0, [1, 2, 0])]);
        triple_test(&[(0, 1), (4, 2), (8, 3)], &[tw(0, [1, 2, 3])]);
        triple_test(&[(0, 1), (8, 3)], &[tw(0, [1, 0, 3])]);
        triple_test(
            &[(0, 1), (4, 2), (8, 3), (12, 4)],
            &[tw(0, [1, 2, 3]), tw(3, [4, 0, 0])],
        );
    }
}