risc0_zkp/hal/
mod.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Hardware Abstraction Layer (HAL) for accelerating the ZKP system.
16
17pub mod cpu;
18#[cfg(feature = "cuda")]
19pub mod cuda;
20pub mod dual;
21#[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))]
22pub mod metal;
23
24use std::{
25    fmt::Debug,
26    sync::{Mutex, OnceLock},
27};
28
29use risc0_core::{
30    field::{Elem, ExtElem, Field, RootsOfUnity},
31    scope,
32};
33
34use crate::{
35    core::{digest::Digest, hash::HashSuite, poly::poly_divide},
36    INV_RATE,
37};
38
/// A cloneable handle to a run of `T` values managed by a HAL backend.
///
/// Every method takes `&self`, including [`Buffer::view_mut`], so an
/// implementation has to provide whatever interior mutability it needs.
pub trait Buffer<T>: Clone {
    /// Returns the static label attached to this buffer.
    fn name(&self) -> &'static str;

    /// Returns the size of this buffer.
    // NOTE(review): presumably a count of `T` elements rather than bytes — confirm
    // against the backend implementations.
    fn size(&self) -> usize;

    /// Returns a buffer of the same type described by `offset` and `size`.
    // NOTE(review): presumably a sub-range that shares storage with `self` —
    // confirm against the backend implementations.
    fn slice(&self, offset: usize, size: usize) -> Self;

    /// Returns the value stored at index `idx`.
    fn get_at(&self, idx: usize) -> T;

    /// Invokes `f` once with the buffer contents as a shared slice.
    fn view<F>(&self, f: F)
    where
        F: FnOnce(&[T]);

    /// Invokes `f` once with the buffer contents as a mutable slice.
    fn view_mut<F>(&self, f: F)
    where
        F: FnOnce(&mut [T]);

    /// Returns the buffer contents as an owned `Vec`.
    fn to_vec(&self) -> Vec<T>;
}
54
55pub trait Hal {
56    type Field: Field<Elem = Self::Elem, ExtElem = Self::ExtElem>;
57    type Elem: Elem + RootsOfUnity;
58    type ExtElem: ExtElem<SubElem = Self::Elem>;
59    type Buffer<T: Clone + Debug + PartialEq>: Buffer<T>;
60
61    const CHECK_SIZE: usize = INV_RATE * Self::ExtElem::EXT_SIZE;
62
63    fn has_unified_memory(&self) -> bool;
64
65    fn get_hash_suite(&self) -> &HashSuite<Self::Field>;
66
67    fn alloc_digest(&self, name: &'static str, size: usize) -> Self::Buffer<Digest>;
68    fn alloc_elem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::Elem>;
69    fn alloc_extelem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem>;
70    fn alloc_u32(&self, name: &'static str, size: usize) -> Self::Buffer<u32>;
71
72    fn alloc_elem_init(
73        &self,
74        name: &'static str,
75        size: usize,
76        value: Self::Elem,
77    ) -> Self::Buffer<Self::Elem> {
78        let buffer = self.alloc_elem(name, size);
79        buffer.view_mut(|slice| {
80            slice.fill(value);
81        });
82        buffer
83    }
84
85    fn alloc_extelem_zeroed(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem> {
86        let buffer = self.alloc_extelem(name, size);
87        buffer.view_mut(|slice| {
88            slice.fill(Self::ExtElem::ZERO);
89        });
90        buffer
91    }
92
93    fn copy_from_digest(&self, name: &'static str, slice: &[Digest]) -> Self::Buffer<Digest>;
94    fn copy_from_elem(&self, name: &'static str, slice: &[Self::Elem]) -> Self::Buffer<Self::Elem>;
95    fn copy_from_extelem(
96        &self,
97        name: &'static str,
98        slice: &[Self::ExtElem],
99    ) -> Self::Buffer<Self::ExtElem>;
100    fn copy_from_u32(&self, name: &'static str, slice: &[u32]) -> Self::Buffer<u32>;
101
102    fn batch_expand_into_evaluate_ntt(
103        &self,
104        output: &Self::Buffer<Self::Elem>,
105        input: &Self::Buffer<Self::Elem>,
106        count: usize,
107        expand_bits: usize,
108    );
109
110    fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize);
111
112    fn batch_bit_reverse(&self, io: &Self::Buffer<Self::Elem>, count: usize);
113
114    fn batch_evaluate_any(
115        &self,
116        coeffs: &Self::Buffer<Self::Elem>,
117        poly_count: usize,
118        which: &Self::Buffer<u32>,
119        xs: &Self::Buffer<Self::ExtElem>,
120        out: &Self::Buffer<Self::ExtElem>,
121    );
122
123    fn zk_shift(&self, io: &Self::Buffer<Self::Elem>, count: usize);
124
125    #[allow(clippy::too_many_arguments)]
126    fn mix_poly_coeffs(
127        &self,
128        out: &Self::Buffer<Self::ExtElem>,
129        mix_start: &Self::ExtElem,
130        mix: &Self::ExtElem,
131        input: &Self::Buffer<Self::Elem>,
132        combos: &Self::Buffer<u32>,
133        input_size: usize,
134        count: usize,
135    );
136
137    fn eltwise_add_elem(
138        &self,
139        output: &Self::Buffer<Self::Elem>,
140        input1: &Self::Buffer<Self::Elem>,
141        input2: &Self::Buffer<Self::Elem>,
142    );
143
144    fn eltwise_sum_extelem(
145        &self,
146        output: &Self::Buffer<Self::Elem>,
147        input: &Self::Buffer<Self::ExtElem>,
148    );
149
150    fn eltwise_copy_elem(
151        &self,
152        output: &Self::Buffer<Self::Elem>,
153        input: &Self::Buffer<Self::Elem>,
154    );
155
156    #[allow(clippy::too_many_arguments)]
157    fn eltwise_copy_elem_slice(
158        &self,
159        into: &Self::Buffer<Self::Elem>,
160        from: &[Self::Elem],
161        from_rows: usize,
162        from_cols: usize,
163        from_offset: usize,
164        from_stride: usize,
165        into_offset: usize,
166        into_stride: usize,
167    );
168
169    fn eltwise_zeroize_elem(&self, elems: &Self::Buffer<Self::Elem>);
170
171    fn fri_fold(
172        &self,
173        output: &Self::Buffer<Self::Elem>,
174        input: &Self::Buffer<Self::Elem>,
175        mix: &Self::ExtElem,
176    );
177
178    fn hash_rows(&self, output: &Self::Buffer<Digest>, matrix: &Self::Buffer<Self::Elem>);
179
180    fn hash_fold(&self, io: &Self::Buffer<Digest>, input_size: usize, output_size: usize);
181
182    fn gather_sample(
183        &self,
184        dst: &Self::Buffer<Self::Elem>,
185        src: &Self::Buffer<Self::Elem>,
186        idx: usize,
187        size: usize,
188        stride: usize,
189    );
190
191    fn scatter(
192        &self,
193        into: &Self::Buffer<Self::Elem>,
194        index: &[u32],
195        offsets: &[u32],
196        values: &[Self::Elem],
197    );
198
199    fn prefix_products(&self, io: &Self::Buffer<Self::ExtElem>);
200
201    #[allow(clippy::too_many_arguments)]
202    fn combos_prepare(
203        &self,
204        combos: &Self::Buffer<Self::ExtElem>,
205        coeff_u: &[Self::ExtElem],
206        combo_count: usize,
207        cycles: usize,
208        reg_sizes: &[u32],
209        reg_combo_ids: &[u32],
210        mix: &Self::ExtElem,
211    ) {
212        combos.view_mut(|combos| {
213            scope!("combos_prepare", {
214                let mut cur_pos = 0;
215                let mut cur = Self::ExtElem::ONE;
216                // Subtract the U coeffs from the combos
217                for (reg_size, reg_combo_id) in reg_sizes.iter().zip(reg_combo_ids) {
218                    let reg_size = *reg_size as usize;
219                    let reg_combo_id = *reg_combo_id as usize;
220                    for i in 0..reg_size {
221                        combos[cycles * reg_combo_id + i] -= cur * coeff_u[cur_pos + i];
222                    }
223                    cur *= *mix;
224                    cur_pos += reg_size;
225                }
226                // Subtract the final 'check' coefficients
227                for _ in 0..Self::CHECK_SIZE {
228                    combos[cycles * combo_count] -= cur * coeff_u[cur_pos];
229                    cur_pos += 1;
230                    cur *= *mix;
231                }
232            });
233        });
234    }
235
236    fn combos_divide(
237        &self,
238        combos: &Self::Buffer<Self::ExtElem>,
239        chunks: Vec<(usize, Vec<Self::ExtElem>)>,
240        cycles: usize,
241    ) {
242        use rayon::prelude::*;
243
244        scope!("combos_divide");
245
246        combos.view_mut(|combos| {
247            combos
248                .par_chunks_exact_mut(cycles)
249                .zip(chunks)
250                .for_each(|(combo_slice, (i, pows))| {
251                    for pow in pows {
252                        let remainder = poly_divide(combo_slice, pow);
253                        assert_eq!(remainder, Self::ExtElem::ZERO, "i: {i}");
254                    }
255                });
256        });
257    }
258}
259
/// Preflight information passed by reference to [`CircuitHal::accumulate`].
#[derive(Clone, Default)]
pub struct AccumPreflight {
    /// Byte flags consumed by the accumulate step.
    // NOTE(review): presumably one entry per cycle, non-zero when that cycle is
    // safe to process in parallel — confirm against the circuit implementations.
    pub is_par_safe: Vec<u8>,
}
264
/// Circuit-specific operations layered on top of a [`Hal`] backend `H`.
///
/// Both methods are declared without default bodies; their behavior is
/// defined by the circuit implementations, which are not visible in this file.
pub trait CircuitHal<H: Hal> {
    /// Runs the circuit's accumulation step over the given buffers.
    ///
    /// Takes the [`AccumPreflight`] data plus the `ctrl`, `io`, `data`, `mix`
    /// and `accum` element buffers and a `steps` count.
    // NOTE(review): which buffers are read and which are written is not visible
    // here (presumably `accum` is the output) — confirm against implementors.
    #[allow(clippy::too_many_arguments)]
    fn accumulate(
        &self,
        preflight: &AccumPreflight,
        ctrl: &H::Buffer<H::Elem>,
        io: &H::Buffer<H::Elem>,
        data: &H::Buffer<H::Elem>,
        mix: &H::Buffer<H::Elem>,
        accum: &H::Buffer<H::Elem>,
        steps: usize,
    );

    /// Compute check polynomial.
    ///
    /// `groups` and `globals` are described inline below.
    // NOTE(review): `check` is presumably the output buffer and `po2` the log2
    // of the cycle count — confirm against implementors.
    fn eval_check(
        &self,
        check: &H::Buffer<H::Elem>,
        // Register groups, e.g. accum, code, data.  These should have one row for each cycle.
        groups: &[&H::Buffer<H::Elem>],
        // Globals.  These should have one row total.
        globals: &[&H::Buffer<H::Elem>],
        poly_mix: H::ExtElem,
        po2: usize,
        steps: usize,
    );
}
291
292pub fn tracker() -> &'static Mutex<MemoryTracker> {
293    static ONCE: OnceLock<Mutex<MemoryTracker>> = OnceLock::new();
294    ONCE.get_or_init(|| Mutex::new(MemoryTracker::default()))
295}
296
/// Running counters for allocation sizes reported through `alloc` and `free`.
// NOTE(review): sizes are presumably byte counts — confirm against the callers.
#[derive(Debug, Default)]
pub struct MemoryTracker {
    /// Net outstanding amount: everything passed to `alloc` minus everything
    /// passed to `free`. Signed, so it goes negative if more is freed than
    /// was allocated.
    pub total: isize,
    /// Highest value `total` has held immediately after an `alloc` call.
    pub peak: isize,
}

impl MemoryTracker {
    /// Sets both counters back to zero.
    pub fn reset(&mut self) {
        // The derived `Default` is exactly the all-zero state.
        *self = Self::default();
    }

    /// Adds `size` to `total` and raises `peak` if `total` now exceeds it.
    pub fn alloc(&mut self, size: usize) {
        self.total += size as isize;
        if self.total > self.peak {
            self.peak = self.total;
        }
    }

    /// Subtracts `size` from `total`; `peak` is left untouched.
    pub fn free(&mut self, size: usize) {
        self.total -= size as isize;
    }
}
318
// Backend-agnostic test drivers. Each `pub(crate)` function takes a HAL by
// value and exercises one operation on it; the drivers carry no `#[test]`
// attribute themselves, so they are presumably wrapped by the backend modules.
//
// Most drivers run the operation through a `DualHal` built from a fresh
// `CpuHal` (sharing the given HAL's hash suite) and the HAL under test.
// NOTE(review): `DualHal` presumably cross-checks the two backends' results —
// confirm in `dual`.
#[cfg(test)]
#[allow(unused)]
mod testutil {
    use std::rc::Rc;

    use rand::{thread_rng, RngCore};
    use risc0_core::field::{
        baby_bear::{BabyBearElem, BabyBearExtElem},
        Elem, ExtElem,
    };

    use super::{dual::DualHal, Hal};
    use crate::{
        core::digest::Digest,
        hal::{cpu::CpuHal, Buffer},
        FRI_FOLD, INV_RATE,
    };

    /// Element counts used by the size-sweeping drivers.
    const COUNTS: [usize; 7] = [1, 9, 12, 1001, 1024, 1025, 1024 * 1024];
    /// Number of polynomials/columns used by the NTT and bit-reverse drivers.
    const DATA_SIZE: usize = 224;

    /// Builds a buffer labelled "values" holding `size` random elements.
    fn generate_elem<H: Hal, R: RngCore>(hal: &H, rng: &mut R, size: usize) -> H::Buffer<H::Elem> {
        let host: Vec<H::Elem> = std::iter::repeat_with(|| H::Elem::random(rng))
            .take(size)
            .collect();
        hal.copy_from_elem("values", &host)
    }

    /// Builds a buffer labelled "values" holding `size` random extension elements.
    fn generate_extelem<H: Hal, R: RngCore>(
        hal: &H,
        rng: &mut R,
        size: usize,
    ) -> H::Buffer<H::ExtElem> {
        let host: Vec<H::ExtElem> = std::iter::repeat_with(|| H::ExtElem::random(rng))
            .take(size)
            .collect();
        hal.copy_from_extelem("values", &host)
    }

    /// Runs `batch_bit_reverse` on `DATA_SIZE * (2^12 * INV_RATE)` random elements.
    pub(crate) fn batch_bit_reverse<H: Hal>(hal_gpu: H) {
        const STEPS: usize = 1 << 12;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let domain = STEPS * INV_RATE;
        let io = generate_elem(&hal, &mut rng, DATA_SIZE * domain);
        hal.batch_bit_reverse(&io, DATA_SIZE);
    }

    /// Runs `batch_evaluate_any` with 865 evaluations, all selecting index 0
    /// and all at the same random point.
    pub(crate) fn batch_evaluate_any<H: Hal>(hal_gpu: H) {
        const EVAL_SIZE: usize = 865;
        const POLY_COUNT: usize = 223;
        const STEPS: usize = 1 << 16;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        // The point is drawn before the coefficients, then raised to EXT_SIZE.
        let z_pow = H::ExtElem::random(&mut rng).pow(H::ExtElem::EXT_SIZE);

        let coeffs = generate_elem(&hal, &mut rng, STEPS * POLY_COUNT);
        let which_host = vec![0; EVAL_SIZE];
        let which = hal.copy_from_u32("which", &which_host);
        let xs_host = vec![z_pow; EVAL_SIZE];
        let xs = hal.copy_from_extelem("xs", &xs_host);
        let out = hal.alloc_extelem("out", EVAL_SIZE);

        hal.batch_evaluate_any(&coeffs, POLY_COUNT, &which, &xs, &out);
    }

    /// Runs `batch_expand_into_evaluate_ntt` with `expand_bits = 2` on
    /// `DATA_SIZE` columns of `2^16` random elements each.
    pub(crate) fn batch_expand_into_evaluate_ntt<H: Hal>(hal_gpu: H) {
        const EXPAND_BITS: usize = 2;
        const STEPS: usize = 1 << 16;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let domain = STEPS * INV_RATE;
        let input = generate_elem(&hal, &mut rng, DATA_SIZE * STEPS);
        let output = hal.alloc_elem("output", DATA_SIZE * domain);
        hal.batch_expand_into_evaluate_ntt(&output, &input, DATA_SIZE, EXPAND_BITS);
    }

    /// Runs `batch_interpolate_ntt` on `DATA_SIZE * (2^16 * INV_RATE)` random elements.
    pub(crate) fn batch_interpolate_ntt<H: Hal>(hal_gpu: H) {
        const STEPS: usize = 1 << 16;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let domain = STEPS * INV_RATE;
        let io = generate_elem(&hal, &mut rng, DATA_SIZE * domain);
        hal.batch_interpolate_ntt(&io, DATA_SIZE);
    }

    /// Fills a 1000x900 row-major matrix with random elements, gathers
    /// column 400 with `gather_sample`, and checks every gathered value.
    pub(crate) fn gather_sample<H: Hal>(hal: H) {
        const ROWS: usize = 1000;
        const COLS: usize = 900;
        const IDX: usize = 400;

        let mut rng = thread_rng();
        let src = hal.alloc_elem("src", ROWS * COLS);
        let dst = hal.alloc_elem("dst", ROWS);

        // Random values are drawn column by column.
        src.view_mut(|matrix| {
            for col in 0..COLS {
                for row in 0..ROWS {
                    matrix[row * COLS + col] = H::Elem::random(&mut rng);
                }
            }
        });

        hal.gather_sample(&dst, &src, IDX, ROWS, COLS);

        src.view(|matrix| {
            dst.view(|gathered| {
                for row in 0..ROWS {
                    assert_eq!(matrix[row * COLS + IDX], gathered[row]);
                }
            });
        });
    }

    /// Calls `eltwise_add_elem` with a 10-element output and 20-element inputs.
    // NOTE(review): presumably callers expect this size mismatch to trip a
    // requirement check in the backend — confirm at the call sites.
    pub(crate) fn check_req<H: Hal>(hal: H) {
        let short = hal.alloc_elem("a", 10);
        let long = hal.alloc_elem("b", 20);
        hal.eltwise_add_elem(&short, &long, &long);
    }

    /// For every size in `COUNTS`, adds two random buffers with
    /// `eltwise_add_elem` and compares against sums computed on the host.
    pub(crate) fn eltwise_add_elem<H: Hal>(hal_gpu: H) {
        for (x, &count) in COUNTS.iter().enumerate() {
            let a = hal_gpu.alloc_elem("a", count);
            let b = hal_gpu.alloc_elem("b", count);
            let o = hal_gpu.alloc_elem("o", count);
            let mut golden = Vec::with_capacity(count);

            let mut rng = thread_rng();
            a.view_mut(|lhs| {
                b.view_mut(|rhs| {
                    assert_eq!(lhs.len(), rhs.len());
                    // Draws alternate between the two inputs, element by element.
                    for (l, r) in lhs.iter_mut().zip(rhs.iter_mut()) {
                        *l = H::Elem::random(&mut rng);
                        *r = H::Elem::random(&mut rng);
                    }
                    golden.extend(lhs.iter().zip(rhs.iter()).map(|(&l, &r)| l + r));
                });
            });

            hal_gpu.eltwise_add_elem(&o, &a, &b);

            o.view(|sums| {
                for i in 0..sums.len() {
                    assert_eq!(sums[i], golden[i], "x: {x}, count: {count}, i: {i}");
                }
            });
        }
    }

    /// For every size in `COUNTS`, copies a random buffer with
    /// `eltwise_copy_elem` and checks the output equals the input.
    pub(crate) fn eltwise_copy_elem<H: Hal>(hal_gpu: H) {
        let mut rng = thread_rng();
        for &count in &COUNTS {
            let input = generate_elem(&hal_gpu, &mut rng, count);
            let output = hal_gpu.alloc_elem("output", count);
            hal_gpu.eltwise_copy_elem(&output, &input);
            output.view(|copied| {
                input.view(|original| assert_eq!(copied, original));
            });
        }
    }

    /// Runs `eltwise_sum_extelem` on `1024 * 1024` random extension elements.
    pub(crate) fn eltwise_sum_extelem<H: Hal>(hal_gpu: H) {
        const COUNT: usize = 1024 * 1024;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let input = generate_extelem(&hal, &mut rng, COUNT);
        let output = hal.alloc_elem("output", COUNT);
        hal.eltwise_sum_extelem(&output, &input);
    }

    /// For every size in `COUNTS`, runs `fri_fold` with a random mix value on
    /// a random input that is `FRI_FOLD` times larger than the output.
    pub(crate) fn fri_fold<H: Hal>(hal_gpu: H) {
        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        for &count in &COUNTS {
            let output_size = count * H::ExtElem::EXT_SIZE;
            let output = hal.alloc_elem("output", output_size);
            // The mix value is drawn before the input data.
            let mix = H::ExtElem::random(&mut rng);
            let input = generate_elem(&hal, &mut rng, output_size * FRI_FOLD);
            hal.fri_fold(&output, &input, &mix);
        }
    }

    /// Runs `mix_poly_coeffs` over `CHECK_SIZE` random inputs that all map to
    /// combo index 0.
    pub(crate) fn mix_poly_coeffs<H: Hal>(hal_gpu: H) {
        const COMBO_COUNT: usize = 100;
        const STEPS: usize = 1 << 12;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let domain = STEPS * INV_RATE;
        let input_size = H::CHECK_SIZE * domain;
        let output_size = STEPS * (COMBO_COUNT + 1);
        let combo_ids = vec![0; H::CHECK_SIZE];
        let mix_start = H::ExtElem::random(&mut rng);
        let mix = H::ExtElem::random(&mut rng);

        let output = hal.alloc_extelem("output", output_size);
        let combos = hal.copy_from_u32("combos", &combo_ids);
        let input = generate_elem(&hal, &mut rng, input_size);

        hal.mix_poly_coeffs(
            &output,
            &mix_start,
            &mix,
            &input,
            &combos,
            H::CHECK_SIZE,
            STEPS,
        );
    }

    /// Fills the upper half of a `2 * 1024` digest buffer with random digests
    /// and runs `hash_fold` from 1024 inputs down to 512 outputs.
    pub(crate) fn hash_fold<H: Hal>(hal_gpu: H) {
        const INPUTS: usize = 1024;
        const OUTPUTS: usize = INPUTS / 2;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let io = hal.alloc_digest("io", INPUTS * 2);
        io.view_mut(|digests| {
            for i in 0..INPUTS {
                // Eight words per digest, each a random u32 divided by 3.
                let words: [u32; 8] = std::array::from_fn(|_| rng.next_u32() / 3);
                digests[i + INPUTS] = Digest::from(words);
            }
        });
        hal.hash_fold(&io, INPUTS, OUTPUTS);
    }

    /// Runs `hash_rows` over random matrices for every combination of the
    /// row counts `[1, 2, 3, 4, 10]` and column counts `[16, 32, 64, 128]`.
    pub(crate) fn hash_rows<H: Hal<Elem = BabyBearElem>>(hal_gpu: H) {
        const ROW_COUNTS: [usize; 5] = [1, 2, 3, 4, 10];
        const COL_COUNTS: [usize; 4] = [16, 32, 64, 128];

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        for row_count in ROW_COUNTS {
            for col_count in COL_COUNTS {
                let matrix = generate_elem(&hal, &mut rng, row_count * col_count);
                let output = hal.alloc_digest("output", row_count);
                hal.hash_rows(&output, &matrix);
            }
        }
    }

    /// Runs `hash_rows` on a random 4096x256 matrix, writing into the
    /// `slice(rows, rows)` half of a `2 * rows` digest buffer.
    pub(crate) fn slice<H: Hal<Elem = BabyBearElem>>(hal_gpu: H) {
        const ROWS: usize = 4096;
        const COLS: usize = 256;

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        let nodes = hal.alloc_digest("nodes", ROWS * 2);
        let matrix = generate_elem(&hal, &mut rng, ROWS * COLS);
        let upper_half = nodes.slice(ROWS, ROWS);
        hal.hash_rows(&upper_half, &matrix);
    }

    /// Runs `zk_shift` on random data for two `(poly_count, steps)` shapes.
    pub(crate) fn zk_shift<H: Hal>(hal_gpu: H) {
        const SHAPES: [(usize, usize); 2] = [(1000, 1 << 8), (900, 1 << 12)];

        let mut rng = thread_rng();
        let hal = DualHal::new(
            Rc::new(CpuHal::new(hal_gpu.get_hash_suite().clone())),
            Rc::new(hal_gpu),
        );

        for (poly_count, steps) in SHAPES {
            let io = generate_elem(&hal, &mut rng, poly_count * steps);
            hal.zk_shift(&io, poly_count);
        }
    }
}