1pub mod cpu;
18#[cfg(feature = "cuda")]
19pub mod cuda;
20pub mod dual;
21#[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))]
22pub mod metal;
23
24use std::{
25 fmt::Debug,
26 sync::{Mutex, OnceLock},
27};
28
29use risc0_core::{
30 field::{Elem, ExtElem, Field, RootsOfUnity},
31 scope,
32};
33
34use crate::{
35 core::{digest::Digest, hash::HashSuite, poly::poly_divide},
36 INV_RATE,
37};
38
/// A cloneable handle to a buffer of `T` values owned by a HAL backend.
///
/// NOTE(review): where the storage lives and whether clones share it is up
/// to each implementor; nothing in this trait pins that down.
pub trait Buffer<T>: Clone {
    /// Returns the static name attached to this buffer.
    fn name(&self) -> &'static str;

    /// Returns the buffer's size (presumably a count of `T` elements rather
    /// than bytes — TODO confirm against implementors).
    fn size(&self) -> usize;

    /// Returns another buffer of the same type described by `offset` and
    /// `size`.
    ///
    /// NOTE(review): presumably a sub-range that aliases `self`'s storage;
    /// confirm against implementors.
    fn slice(&self, offset: usize, size: usize) -> Self;

    /// Returns the element at index `idx` by value.
    fn get_at(&self, idx: usize) -> T;

    /// Hands a read-only slice of the contents to `f`. `f` is `FnOnce`, so
    /// it can run at most once.
    fn view<F: FnOnce(&[T])>(&self, f: F);

    /// Hands a mutable slice of the contents to `f`. Note the `&self`
    /// receiver: the contents can be mutated through a shared handle.
    fn view_mut<F: FnOnce(&mut [T])>(&self, f: F);

    /// Returns the contents as an owned `Vec<T>`.
    fn to_vec(&self) -> Vec<T>;
}
54
55pub trait Hal {
56 type Field: Field<Elem = Self::Elem, ExtElem = Self::ExtElem>;
57 type Elem: Elem + RootsOfUnity;
58 type ExtElem: ExtElem<SubElem = Self::Elem>;
59 type Buffer<T: Clone + Debug + PartialEq>: Buffer<T>;
60
61 const CHECK_SIZE: usize = INV_RATE * Self::ExtElem::EXT_SIZE;
62
63 fn has_unified_memory(&self) -> bool;
64
65 fn get_hash_suite(&self) -> &HashSuite<Self::Field>;
66
67 fn alloc_digest(&self, name: &'static str, size: usize) -> Self::Buffer<Digest>;
68 fn alloc_elem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::Elem>;
69 fn alloc_extelem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem>;
70 fn alloc_u32(&self, name: &'static str, size: usize) -> Self::Buffer<u32>;
71
72 fn alloc_elem_init(
73 &self,
74 name: &'static str,
75 size: usize,
76 value: Self::Elem,
77 ) -> Self::Buffer<Self::Elem> {
78 let buffer = self.alloc_elem(name, size);
79 buffer.view_mut(|slice| {
80 slice.fill(value);
81 });
82 buffer
83 }
84
85 fn alloc_extelem_zeroed(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem> {
86 let buffer = self.alloc_extelem(name, size);
87 buffer.view_mut(|slice| {
88 slice.fill(Self::ExtElem::ZERO);
89 });
90 buffer
91 }
92
93 fn copy_from_digest(&self, name: &'static str, slice: &[Digest]) -> Self::Buffer<Digest>;
94 fn copy_from_elem(&self, name: &'static str, slice: &[Self::Elem]) -> Self::Buffer<Self::Elem>;
95 fn copy_from_extelem(
96 &self,
97 name: &'static str,
98 slice: &[Self::ExtElem],
99 ) -> Self::Buffer<Self::ExtElem>;
100 fn copy_from_u32(&self, name: &'static str, slice: &[u32]) -> Self::Buffer<u32>;
101
102 fn batch_expand_into_evaluate_ntt(
103 &self,
104 output: &Self::Buffer<Self::Elem>,
105 input: &Self::Buffer<Self::Elem>,
106 count: usize,
107 expand_bits: usize,
108 );
109
110 fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize);
111
112 fn batch_bit_reverse(&self, io: &Self::Buffer<Self::Elem>, count: usize);
113
114 fn batch_evaluate_any(
115 &self,
116 coeffs: &Self::Buffer<Self::Elem>,
117 poly_count: usize,
118 which: &Self::Buffer<u32>,
119 xs: &Self::Buffer<Self::ExtElem>,
120 out: &Self::Buffer<Self::ExtElem>,
121 );
122
123 fn zk_shift(&self, io: &Self::Buffer<Self::Elem>, count: usize);
124
125 #[allow(clippy::too_many_arguments)]
126 fn mix_poly_coeffs(
127 &self,
128 out: &Self::Buffer<Self::ExtElem>,
129 mix_start: &Self::ExtElem,
130 mix: &Self::ExtElem,
131 input: &Self::Buffer<Self::Elem>,
132 combos: &Self::Buffer<u32>,
133 input_size: usize,
134 count: usize,
135 );
136
137 fn eltwise_add_elem(
138 &self,
139 output: &Self::Buffer<Self::Elem>,
140 input1: &Self::Buffer<Self::Elem>,
141 input2: &Self::Buffer<Self::Elem>,
142 );
143
144 fn eltwise_sum_extelem(
145 &self,
146 output: &Self::Buffer<Self::Elem>,
147 input: &Self::Buffer<Self::ExtElem>,
148 );
149
150 fn eltwise_copy_elem(
151 &self,
152 output: &Self::Buffer<Self::Elem>,
153 input: &Self::Buffer<Self::Elem>,
154 );
155
156 #[allow(clippy::too_many_arguments)]
157 fn eltwise_copy_elem_slice(
158 &self,
159 into: &Self::Buffer<Self::Elem>,
160 from: &[Self::Elem],
161 from_rows: usize,
162 from_cols: usize,
163 from_offset: usize,
164 from_stride: usize,
165 into_offset: usize,
166 into_stride: usize,
167 );
168
169 fn eltwise_zeroize_elem(&self, elems: &Self::Buffer<Self::Elem>);
170
171 fn fri_fold(
172 &self,
173 output: &Self::Buffer<Self::Elem>,
174 input: &Self::Buffer<Self::Elem>,
175 mix: &Self::ExtElem,
176 );
177
178 fn hash_rows(&self, output: &Self::Buffer<Digest>, matrix: &Self::Buffer<Self::Elem>);
179
180 fn hash_fold(&self, io: &Self::Buffer<Digest>, input_size: usize, output_size: usize);
181
182 fn gather_sample(
183 &self,
184 dst: &Self::Buffer<Self::Elem>,
185 src: &Self::Buffer<Self::Elem>,
186 idx: usize,
187 size: usize,
188 stride: usize,
189 );
190
191 fn scatter(
192 &self,
193 into: &Self::Buffer<Self::Elem>,
194 index: &[u32],
195 offsets: &[u32],
196 values: &[Self::Elem],
197 );
198
199 fn prefix_products(&self, io: &Self::Buffer<Self::ExtElem>);
200
201 #[allow(clippy::too_many_arguments)]
202 fn combos_prepare(
203 &self,
204 combos: &Self::Buffer<Self::ExtElem>,
205 coeff_u: &[Self::ExtElem],
206 combo_count: usize,
207 cycles: usize,
208 reg_sizes: &[u32],
209 reg_combo_ids: &[u32],
210 mix: &Self::ExtElem,
211 ) {
212 combos.view_mut(|combos| {
213 scope!("combos_prepare", {
214 let mut cur_pos = 0;
215 let mut cur = Self::ExtElem::ONE;
216 for (reg_size, reg_combo_id) in reg_sizes.iter().zip(reg_combo_ids) {
218 let reg_size = *reg_size as usize;
219 let reg_combo_id = *reg_combo_id as usize;
220 for i in 0..reg_size {
221 combos[cycles * reg_combo_id + i] -= cur * coeff_u[cur_pos + i];
222 }
223 cur *= *mix;
224 cur_pos += reg_size;
225 }
226 for _ in 0..Self::CHECK_SIZE {
228 combos[cycles * combo_count] -= cur * coeff_u[cur_pos];
229 cur_pos += 1;
230 cur *= *mix;
231 }
232 });
233 });
234 }
235
236 fn combos_divide(
237 &self,
238 combos: &Self::Buffer<Self::ExtElem>,
239 chunks: Vec<(usize, Vec<Self::ExtElem>)>,
240 cycles: usize,
241 ) {
242 use rayon::prelude::*;
243
244 scope!("combos_divide");
245
246 combos.view_mut(|combos| {
247 combos
248 .par_chunks_exact_mut(cycles)
249 .zip(chunks)
250 .for_each(|(combo_slice, (i, pows))| {
251 for pow in pows {
252 let remainder = poly_divide(combo_slice, pow);
253 assert_eq!(remainder, Self::ExtElem::ZERO, "i: {i}");
254 }
255 });
256 });
257 }
258}
259
/// Preflight data handed to [`CircuitHal::accumulate`].
#[derive(Clone, Default)]
pub struct AccumPreflight {
    /// One byte per entry.
    ///
    /// NOTE(review): presumably a 0/1 flag per cycle saying whether that
    /// cycle may be processed in parallel — confirm against the producers
    /// and consumers of this struct.
    pub is_par_safe: Vec<u8>,
}
264
/// Circuit-specific operations expressed over the buffers of a [`Hal`]
/// backend `H`.
pub trait CircuitHal<H: Hal> {
    /// Circuit accumulation step over the `ctrl`, `io`, `data`, `mix` and
    /// `accum` buffers for `steps` steps, guided by `preflight`.
    ///
    /// NOTE(review): which buffers are inputs and which are written is not
    /// visible from this signature (presumably `accum` is the output) —
    /// confirm against implementors.
    #[allow(clippy::too_many_arguments)]
    fn accumulate(
        &self,
        preflight: &AccumPreflight,
        ctrl: &H::Buffer<H::Elem>,
        io: &H::Buffer<H::Elem>,
        data: &H::Buffer<H::Elem>,
        mix: &H::Buffer<H::Elem>,
        accum: &H::Buffer<H::Elem>,
        steps: usize,
    );

    /// Circuit check evaluation over the `groups` and `globals` buffers
    /// with mixing value `poly_mix`, for a trace described by `po2` and
    /// `steps`.
    ///
    /// NOTE(review): presumably writes the result into `check` and expects
    /// `steps == 1 << po2` — neither is established by this signature.
    fn eval_check(
        &self,
        check: &H::Buffer<H::Elem>,
        groups: &[&H::Buffer<H::Elem>],
        globals: &[&H::Buffer<H::Elem>],
        poly_mix: H::ExtElem,
        po2: usize,
        steps: usize,
    );
}
291
292pub fn tracker() -> &'static Mutex<MemoryTracker> {
293 static ONCE: OnceLock<Mutex<MemoryTracker>> = OnceLock::new();
294 ONCE.get_or_init(|| Mutex::new(MemoryTracker::default()))
295}
296
/// Running account of allocated memory and its high-water mark.
///
/// Sizes are whatever unit callers pass to [`MemoryTracker::alloc`] and
/// [`MemoryTracker::free`] (presumably bytes — TODO confirm at call sites).
#[derive(Debug, Default)]
pub struct MemoryTracker {
    /// Sum of all `alloc` sizes minus all `free` sizes since the last reset.
    /// Signed, so it can go negative if more is freed than was allocated.
    pub total: isize,
    /// Largest value `total` has held immediately after an `alloc` since the
    /// last reset.
    pub peak: isize,
}

impl MemoryTracker {
    /// Sets both `total` and `peak` back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Adds `size` to `total` and raises `peak` if the new total exceeds it.
    pub fn alloc(&mut self, size: usize) {
        self.total += size as isize;
        if self.total > self.peak {
            self.peak = self.total;
        }
    }

    /// Subtracts `size` from `total`; `peak` is left untouched.
    pub fn free(&mut self, size: usize) {
        self.total -= size as isize;
    }
}
318
319#[cfg(test)]
320#[allow(unused)]
321mod testutil {
322 use std::rc::Rc;
323
324 use rand::{thread_rng, RngCore};
325 use risc0_core::field::{
326 baby_bear::{BabyBearElem, BabyBearExtElem},
327 Elem, ExtElem,
328 };
329
330 use super::{dual::DualHal, Hal};
331 use crate::{
332 core::digest::Digest,
333 hal::{cpu::CpuHal, Buffer},
334 FRI_FOLD, INV_RATE,
335 };
336
    // Element counts iterated by the `eltwise_add_elem`, `eltwise_copy_elem`
    // and `fri_fold` helpers below.
    const COUNTS: [usize; 7] = [1, 9, 12, 1001, 1024, 1025, 1024 * 1024];
    // Batch count used by the bit-reverse and NTT helpers below.
    const DATA_SIZE: usize = 224;
339
340 fn generate_elem<H: Hal, R: RngCore>(hal: &H, rng: &mut R, size: usize) -> H::Buffer<H::Elem> {
341 let values: Vec<H::Elem> = (0..size).map(|_| H::Elem::random(rng)).collect();
342 hal.copy_from_elem("values", &values)
343 }
344
345 fn generate_extelem<H: Hal, R: RngCore>(
346 hal: &H,
347 rng: &mut R,
348 size: usize,
349 ) -> H::Buffer<H::ExtElem> {
350 let values: Vec<H::ExtElem> = (0..size).map(|_| H::ExtElem::random(rng)).collect();
351 hal.copy_from_extelem("values", &values)
352 }
353
354 pub(crate) fn batch_bit_reverse<H: Hal>(hal_gpu: H) {
355 let mut rng = thread_rng();
356 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
357 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
358
359 let steps = 1 << 12;
360 let count = DATA_SIZE;
361 let domain = steps * INV_RATE;
362 let io_size = count * domain;
363
364 let io = generate_elem(&hal, &mut rng, io_size);
365 hal.batch_bit_reverse(&io, count);
366 }
367
368 pub(crate) fn batch_evaluate_any<H: Hal>(hal_gpu: H) {
369 let mut rng = thread_rng();
370 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
371 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
372
373 let eval_size = 865;
374 let poly_count = 223;
375 let steps = 1 << 16;
376 let coeffs_size = steps * poly_count;
377
378 let z = H::ExtElem::random(&mut rng);
379 let z_pow = z.pow(H::ExtElem::EXT_SIZE);
380
381 let coeffs = generate_elem(&hal, &mut rng, coeffs_size);
382 let which = hal.copy_from_u32("which", &vec![0; eval_size]);
383 let xs = hal.copy_from_extelem("xs", &vec![z_pow; eval_size]);
384 let out = hal.alloc_extelem("out", eval_size);
385
386 hal.batch_evaluate_any(&coeffs, poly_count, &which, &xs, &out);
387 }
388
389 pub(crate) fn batch_expand_into_evaluate_ntt<H: Hal>(hal_gpu: H) {
390 let mut rng = thread_rng();
391 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
392 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
393
394 let count = DATA_SIZE;
395 let expand_bits = 2;
396 let steps = 1 << 16;
397 let domain = steps * INV_RATE;
398 let input_size = count * steps;
399 let output_size = count * domain;
400
401 let input = generate_elem(&hal, &mut rng, input_size);
402 let output = hal.alloc_elem("output", output_size);
403 hal.batch_expand_into_evaluate_ntt(&output, &input, count, expand_bits);
404 }
405
406 pub(crate) fn batch_interpolate_ntt<H: Hal>(hal_gpu: H) {
407 let mut rng = thread_rng();
408 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
409 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
410
411 let count = DATA_SIZE;
412 let steps = 1 << 16;
413 let domain = steps * INV_RATE;
414 let io_size = count * domain;
415
416 let io = generate_elem(&hal, &mut rng, io_size);
417 hal.batch_interpolate_ntt(&io, count);
418 }
419
420 pub(crate) fn gather_sample<H: Hal>(hal: H) {
421 let mut rng = thread_rng();
422 let rows = 1000;
423 let cols = 900;
424 let idx = 400;
425 let src_size = rows * cols;
426 let src = hal.alloc_elem("src", src_size);
427 let dst = hal.alloc_elem("dst", rows);
428 src.view_mut(|buf| {
429 for x in 0..cols {
430 for y in 0..rows {
431 let value = H::Elem::random(&mut rng);
432 buf[y * cols + x] = value;
433 }
434 }
435 });
436 hal.gather_sample(&dst, &src, idx, rows, cols);
437 src.view(|src| {
438 dst.view(|dst| {
439 for y in 0..rows {
440 assert_eq!(src[y * cols + idx], dst[y]);
441 }
442 });
443 });
444 }
445
446 pub(crate) fn check_req<H: Hal>(hal: H) {
447 let a = hal.alloc_elem("a", 10);
448 let b = hal.alloc_elem("b", 20);
449 hal.eltwise_add_elem(&a, &b, &b);
450 }
451
452 pub(crate) fn eltwise_add_elem<H: Hal>(hal_gpu: H) {
453 for (x, count) in COUNTS.iter().enumerate() {
454 let a = hal_gpu.alloc_elem("a", *count);
455 let b = hal_gpu.alloc_elem("b", *count);
456 let o = hal_gpu.alloc_elem("o", *count);
457 let mut golden = Vec::with_capacity(*count);
458
459 let mut rng = thread_rng();
460 a.view_mut(|a| {
461 b.view_mut(|b| {
462 assert_eq!(a.len(), b.len());
463 for i in 0..a.len() {
464 a[i] = H::Elem::random(&mut rng);
465 b[i] = H::Elem::random(&mut rng);
466 }
467 for i in 0..a.len() {
468 golden.push(a[i] + b[i]);
469 }
470 });
471 });
472
473 hal_gpu.eltwise_add_elem(&o, &a, &b);
474
475 o.view(|o| {
476 for i in 0..o.len() {
477 assert_eq!(o[i], golden[i], "x: {x}, count: {count}, i: {i}");
478 }
479 });
480 }
481 }
482
483 pub(crate) fn eltwise_copy_elem<H: Hal>(hal_gpu: H) {
484 let mut rng = thread_rng();
485 for count in COUNTS {
486 let input = generate_elem(&hal_gpu, &mut rng, count);
487 let output = hal_gpu.alloc_elem("output", count);
488 hal_gpu.eltwise_copy_elem(&output, &input);
489 output.view(|output| {
490 input.view(|input| assert_eq!(output, input));
491 });
492 }
493 }
494
495 pub(crate) fn eltwise_sum_extelem<H: Hal>(hal_gpu: H) {
496 const COUNT: usize = 1024 * 1024;
497
498 let mut rng = thread_rng();
499 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
500 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
501
502 let input = generate_extelem(&hal, &mut rng, COUNT);
503 let output = hal.alloc_elem("output", COUNT);
504 hal.eltwise_sum_extelem(&output, &input);
505 }
506
507 pub(crate) fn fri_fold<H: Hal>(hal_gpu: H) {
508 let mut rng = thread_rng();
509 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
510 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
511 for count in COUNTS {
512 let output_size = count * H::ExtElem::EXT_SIZE;
513 let input_size = output_size * FRI_FOLD;
514
515 let output = hal.alloc_elem("output", output_size);
516 let mix = H::ExtElem::random(&mut rng);
517 let input = generate_elem(&hal, &mut rng, input_size);
518 hal.fri_fold(&output, &input, &mix);
519 }
520 }
521
522 pub(crate) fn mix_poly_coeffs<H: Hal>(hal_gpu: H) {
523 let mut rng = thread_rng();
524 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
525 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
526
527 let combo_count = 100;
528 let steps = 1 << 12;
529 let domain = steps * INV_RATE;
530 let input_size = H::CHECK_SIZE * domain;
531 let output_size = steps * (combo_count + 1);
532 let combos = vec![0; H::CHECK_SIZE];
533 let mix_start = H::ExtElem::random(&mut rng);
534 let mix = H::ExtElem::random(&mut rng);
535
536 let output = hal.alloc_extelem("output", output_size);
537 let combos = hal.copy_from_u32("combos", &combos);
538 let input = generate_elem(&hal, &mut rng, input_size);
539
540 hal.mix_poly_coeffs(
541 &output,
542 &mix_start,
543 &mix,
544 &input,
545 &combos,
546 H::CHECK_SIZE,
547 steps,
548 );
549 }
550
551 pub(crate) fn hash_fold<H: Hal>(hal_gpu: H) {
552 const INPUTS: usize = 1024;
553 const OUTPUTS: usize = INPUTS / 2;
554 let mut rng = thread_rng();
555 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
556 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
557 let io = hal.alloc_digest("io", INPUTS * 2);
558 io.view_mut(|g| {
559 for i in 0..INPUTS {
560 g[i + INPUTS] = Digest::from([
561 rng.next_u32() / 3,
562 rng.next_u32() / 3,
563 rng.next_u32() / 3,
564 rng.next_u32() / 3,
565 rng.next_u32() / 3,
566 rng.next_u32() / 3,
567 rng.next_u32() / 3,
568 rng.next_u32() / 3,
569 ]);
570 }
571 });
572 hal.hash_fold(&io, INPUTS, OUTPUTS);
573 }
574
575 pub(crate) fn hash_rows<H: Hal<Elem = BabyBearElem>>(hal_gpu: H) {
576 let mut rng = thread_rng();
577 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
578 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
579 let rows = [1, 2, 3, 4, 10];
580 let cols = [16, 32, 64, 128];
581 for row_count in rows {
582 for col_count in cols {
583 let matrix_size = row_count * col_count;
584 let matrix = generate_elem(&hal, &mut rng, matrix_size);
585 let output = hal.alloc_digest("output", row_count);
586 hal.hash_rows(&output, &matrix);
587 }
588 }
589 }
590
591 pub(crate) fn slice<H: Hal<Elem = BabyBearElem>>(hal_gpu: H) {
592 let mut rng = thread_rng();
593 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
594 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
595
596 let rows = 4096;
597 let cols = 256;
598 let matrix_size = rows * cols;
599
600 let nodes = hal.alloc_digest("nodes", rows * 2);
601 let matrix = generate_elem(&hal, &mut rng, matrix_size);
602 hal.hash_rows(&nodes.slice(rows, rows), &matrix);
603 }
604
605 pub(crate) fn zk_shift<H: Hal>(hal_gpu: H) {
606 let mut rng = thread_rng();
607 let hal_cpu = CpuHal::new(hal_gpu.get_hash_suite().clone());
608 let hal = DualHal::new(Rc::new(hal_cpu), Rc::new(hal_gpu));
609 let counts = [(1000, (1 << 8)), (900, (1 << 12))];
610 for (poly_count, steps) in counts {
611 let count = poly_count * steps;
612 let io = generate_elem(&hal, &mut rng, count);
613 hal.zk_shift(&io, poly_count);
614 }
615 }
616}