use std::{fmt::Debug, ops::Range, sync::Arc};

use ndarray::{ArrayView, ArrayViewMut, Axis};
use parking_lot::{
    MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use rayon::prelude::*;
use risc0_core::field::{Elem, ExtElem, Field};

use super::{tracker, Buffer, Hal};
use crate::{
    core::{
        digest::Digest,
        hash::HashSuite,
        log2_ceil,
        ntt::{bit_rev_32, bit_reverse, evaluate_ntt, expand, interpolate_ntt},
    },
    FRI_FOLD,
};

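/// A CPU-based implementation of the [Hal] (hardware abstraction layer),
/// generic over the field used by the proof system.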
pub struct CpuHal<F: Field> {
    suite: HashSuite<F>,
}

impl<F: Field> CpuHal<F> {
    pub fn new(suite: HashSuite<F>) -> Self {
        Self { suite }
    }
}

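/// A contiguous (offset, size) region within a backing buffer.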
#[derive(Debug, Clone)]
struct Region(usize, usize);

impl Region {
    fn offset(&self) -> usize {
        self.0
    }

    fn size(&self) -> usize {
        self.1
    }

    fn range(&self) -> Range<usize> {
        Range {
            start: self.offset(),
            end: self.offset() + self.size(),
        }
    }
}

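/// A [Vec] wrapper that reports allocations and frees to the global memory
/// tracker so peak memory usage can be monitored.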
struct TrackedVec<T>(Vec<T>);

impl<T> TrackedVec<T> {
    pub fn new(vec: Vec<T>) -> Self {
        tracker()
            .lock()
            .unwrap()
            .alloc(vec.capacity() * std::mem::size_of::<T>());
        Self(vec)
    }
}

impl<T> Drop for TrackedVec<T> {
    fn drop(&mut self) {
        tracker()
            .lock()
            .unwrap()
            .free(self.0.capacity() * std::mem::size_of::<T>());
    }
}

#[derive(Clone)]
pub struct CpuBuffer<T> {
    name: &'static str,
    buf: Arc<RwLock<TrackedVec<T>>>,
    region: Region,
}

enum SyncSliceRef<'a, T: Default + Clone> {
    FromBuf {
        _inner: MappedRwLockWriteGuard<'a, [T]>,
    },
    FromSlice {
        _inner: &'a SyncSlice<'a, T>,
    },
}

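/// A mutable slice of a [CpuBuffer] that may be shared across threads, with
/// unsynchronized element access through a raw pointer. Callers must ensure
/// that concurrent accesses target disjoint elements.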
pub struct SyncSlice<'a, T: Default + Clone> {
    // Keeps the backing write guard (or parent slice) alive so the raw
    // pointer below remains valid for the lifetime 'a.
    _buf: SyncSliceRef<'a, T>,
    ptr: *mut T,
    size: usize,
}

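// SAFETY: `SyncSlice` allows shared mutation through a raw pointer. `get` and
// `set` only assert bounds; freedom from data races relies on callers (such as
// the rayon loops below) touching disjoint indices from different threads.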
unsafe impl<T: Default + Clone> Sync for SyncSlice<'_, T> {}

impl<'a, T: Default + Clone> SyncSlice<'a, T> {
    pub fn new(mut buf: MappedRwLockWriteGuard<'a, [T]>) -> Self {
        let ptr = buf.as_mut_ptr();
        let size = buf.len();
        SyncSlice {
            ptr,
            size,
            _buf: SyncSliceRef::FromBuf { _inner: buf },
        }
    }

    pub fn get_ptr(&self) -> *mut T {
        self.ptr
    }

    pub fn get(&self, offset: usize) -> T {
        assert!(offset < self.size);
        unsafe { self.ptr.add(offset).read() }
    }

    pub fn set(&self, offset: usize, val: T) {
        assert!(offset < self.size);
        unsafe { self.ptr.add(offset).write(val) }
    }

    pub fn slice(&self, offset: usize, size: usize) -> SyncSlice<'_, T> {
        assert!(
            offset + size <= self.size,
            "Attempting to slice [{offset}, {offset} + {size} = {}) from a slice of length {}",
            offset + size,
            self.size
        );
        SyncSlice {
            _buf: SyncSliceRef::FromSlice { _inner: self },
            ptr: unsafe { self.ptr.add(offset) },
            size,
        }
    }

    pub fn size(&self) -> usize {
        self.size
    }
}

impl<T: Default + Clone> CpuBuffer<T> {
    fn new(name: &'static str, size: usize) -> Self {
        let buf = vec![T::default(); size];
        CpuBuffer {
            name,
            buf: Arc::new(RwLock::new(TrackedVec::new(buf))),
            region: Region(0, size),
        }
    }

    pub fn get_ptr(&self) -> *mut T {
        self.as_slice_sync().get_ptr()
    }

    fn copy_from(name: &'static str, slice: &[T]) -> Self {
        CpuBuffer {
            name,
            buf: Arc::new(RwLock::new(TrackedVec::new(slice.to_vec()))),
            region: Region(0, slice.len()),
        }
    }

    pub fn from_fn<F>(name: &'static str, size: usize, f: F) -> Self
    where
        F: FnMut(usize) -> T,
    {
        let vec = (0..size).map(f).collect();
        CpuBuffer {
            name,
            buf: Arc::new(RwLock::new(TrackedVec::new(vec))),
            region: Region(0, size),
        }
    }

    // Locking discipline: `as_slice` takes a read lock, while `as_slice_mut`
    // and `as_slice_sync` take a write lock on the same `RwLock`. The locks
    // are not reentrant, so drop a guard before acquiring a conflicting one
    // on the same buffer from the same thread.
    pub fn as_slice(&self) -> MappedRwLockReadGuard<'_, [T]> {
        let vec = self.buf.read();
        RwLockReadGuard::map(vec, |vec| &vec.0[self.region.range()])
    }

    pub fn as_slice_mut(&self) -> MappedRwLockWriteGuard<'_, [T]> {
        let vec = self.buf.write();
        RwLockWriteGuard::map(vec, |vec| &mut vec.0[self.region.range()])
    }

    pub fn as_slice_sync(&self) -> SyncSlice<'_, T> {
        SyncSlice::new(self.as_slice_mut())
    }
}

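/// Wraps an existing [Vec] in a [CpuBuffer] (named "vec") without copying.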
impl<T: Default + Clone> From<Vec<T>> for CpuBuffer<T> {
    fn from(vec: Vec<T>) -> CpuBuffer<T> {
        let size = vec.len();
        CpuBuffer {
            name: "vec",
            buf: Arc::new(RwLock::new(TrackedVec::new(vec))),
            region: Region(0, size),
        }
    }
}

impl<T: Clone> Buffer<T> for CpuBuffer<T> {
    fn name(&self) -> &'static str {
        self.name
    }

    fn size(&self) -> usize {
        self.region.size()
    }

    fn slice(&self, offset: usize, size: usize) -> CpuBuffer<T> {
        assert!(offset + size <= self.size());
        let region = Region(self.region.offset() + offset, size);
        CpuBuffer {
            name: self.name,
            buf: Arc::clone(&self.buf),
            region,
        }
    }

    fn get_at(&self, idx: usize) -> T {
        let buf = self.buf.read();
        // Index relative to this buffer's region so that sliced buffers
        // behave consistently with `view` and `view_mut`.
        buf.0[self.region.offset() + idx].clone()
    }

    fn view<F: FnOnce(&[T])>(&self, f: F) {
        let buf = self.buf.read();
        f(&buf.0[self.region.range()]);
    }

    fn view_mut<F: FnOnce(&mut [T])>(&self, f: F) {
        let mut buf = self.buf.write();
        f(&mut buf.0[self.region.range()]);
    }

    fn to_vec(&self) -> Vec<T> {
        // Copy only this buffer's region, not the entire backing allocation.
        self.buf.read().0[self.region.range()].to_vec()
    }
}

impl<F: Field> Hal for CpuHal<F> {
    type Field = F;
    type Elem = F::Elem;
    type ExtElem = F::ExtElem;
    type Buffer<T: Clone + Debug + PartialEq> = CpuBuffer<T>;

    fn alloc_elem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::Elem> {
        CpuBuffer::new(name, size)
    }

    fn copy_from_elem(&self, name: &'static str, slice: &[Self::Elem]) -> Self::Buffer<Self::Elem> {
        CpuBuffer::copy_from(name, slice)
    }

    fn alloc_extelem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem> {
        CpuBuffer::new(name, size)
    }

    fn copy_from_extelem(
        &self,
        name: &'static str,
        slice: &[Self::ExtElem],
    ) -> Self::Buffer<Self::ExtElem> {
        CpuBuffer::copy_from(name, slice)
    }

    fn alloc_digest(&self, name: &'static str, size: usize) -> Self::Buffer<Digest> {
        CpuBuffer::new(name, size)
    }

    fn copy_from_digest(&self, name: &'static str, slice: &[Digest]) -> Self::Buffer<Digest> {
        CpuBuffer::copy_from(name, slice)
    }

    fn alloc_u32(&self, name: &'static str, size: usize) -> Self::Buffer<u32> {
        CpuBuffer::new(name, size)
    }

    fn copy_from_u32(&self, name: &'static str, slice: &[u32]) -> Self::Buffer<u32> {
        CpuBuffer::copy_from(name, slice)
    }

    fn batch_expand_into_evaluate_ntt(
        &self,
        output: &Self::Buffer<Self::Elem>,
        input: &Self::Buffer<Self::Elem>,
        count: usize,
        expand_bits: usize,
    ) {
        // Batch expand: copy each input row into the low part of the
        // corresponding (larger) output row.
        {
            let out_size = output.size() / count;
            let in_size = input.size() / count;
            // Recomputed locally from the buffer sizes; this shadows the
            // parameter, which is used by the NTT phase below.
            let expand_bits = log2_ceil(out_size / in_size);
            assert_eq!(out_size, in_size * (1 << expand_bits));
            assert_eq!(out_size * count, output.size());
            assert_eq!(in_size * count, input.size());
            output
                .as_slice_mut()
                .par_chunks_exact_mut(out_size)
                .zip(input.as_slice().par_chunks_exact(in_size))
                .for_each(|(output, input)| {
                    expand(output, input, expand_bits);
                });
        }

        // Batch evaluate: run a forward NTT over each output row.
        {
            let row_size = output.size() / count;
            assert_eq!(row_size * count, output.size());
            output
                .as_slice_mut()
                .par_chunks_exact_mut(row_size)
                .for_each(|row| {
                    evaluate_ntt::<Self::Elem, Self::Elem>(row, expand_bits);
                });
        }
    }

    fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
        let row_size = io.size() / count;
        assert_eq!(row_size * count, io.size());
        io.as_slice_mut()
            .par_chunks_exact_mut(row_size)
            .for_each(|row| {
                interpolate_ntt::<Self::Elem, Self::Elem>(row);
            });
    }

    fn batch_bit_reverse(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
        let row_size = io.size() / count;
        assert_eq!(row_size * count, io.size());
        io.as_slice_mut()
            .par_chunks_exact_mut(row_size)
            .for_each(|row| {
                bit_reverse(row);
            });
    }

    fn batch_evaluate_any(
        &self,
        coeffs: &Self::Buffer<Self::Elem>,
        poly_count: usize,
        which: &Self::Buffer<u32>,
        xs: &Self::Buffer<Self::ExtElem>,
        out: &Self::Buffer<Self::ExtElem>,
    ) {
        let po2 = log2_ceil(coeffs.size() / poly_count);
        assert_eq!(poly_count * (1 << po2), coeffs.size());
        let eval_count = which.size();
        assert_eq!(xs.size(), eval_count);
        assert_eq!(out.size(), eval_count);
        let coeffs = &*coeffs.as_slice();
        let which = which.as_slice();
        let xs = xs.as_slice();
        let mut out = out.as_slice_mut();
        (&which[..], &xs[..], &mut out[..])
            .into_par_iter()
            .for_each(|(id, x, out)| {
                // Evaluate polynomial `id` at `x` by accumulating successive
                // powers of `x` against its coefficients.
                let mut tot = Self::ExtElem::ZERO;
                let mut cur = Self::ExtElem::ONE;
                let id = *id as usize;
                let count = 1 << po2;
                let local = &coeffs[count * id..count * id + count];
                for coeff in local {
                    tot += cur * *coeff;
                    cur *= *x;
                }
                *out = tot;
            });
    }

    fn zk_shift(&self, io: &Self::Buffer<Self::Elem>, poly_count: usize) {
        let bits = log2_ceil(io.size() / poly_count);
        let count = io.size();
        assert_eq!(io.size(), poly_count * (1 << bits));
        let mut io = io.as_slice_mut();
        (&mut io[..], 0..count)
            .into_par_iter()
            .for_each(|(io, idx)| {
                // Coefficients are stored in bit-reversed order; multiply
                // coefficient i by 3^i to apply a multiplicative shift of 3.
                let pos = idx & ((1 << bits) - 1);
                let rev = bit_rev_32(pos as u32) >> (32 - bits);
                let pow3 = Self::Elem::from_u64(3).pow(rev as usize);
                *io *= pow3;
            });
    }

    fn mix_poly_coeffs(
        &self,
        output: &Self::Buffer<Self::ExtElem>,
        mix_start: &Self::ExtElem,
        mix: &Self::ExtElem,
        input: &Self::Buffer<Self::Elem>,
        combos: &Self::Buffer<u32>,
        input_size: usize,
        count: usize,
    ) {
        tracing::debug!(
            "output: {}, input: {}, combos: {}, input_size: {input_size}, count: {count}",
            output.size(),
            input.size(),
            combos.size()
        );

        // Precompute mix_start * mix^i for each input polynomial.
        let mut mix_cur = *mix_start;
        let mix_pows: Vec<_> = (0..input_size)
            .map(|_| {
                let val = mix_cur;
                mix_cur *= *mix;
                val
            })
            .collect();

        let combos: &[u32] = &combos.as_slice();
        let mix_pows: &[Self::ExtElem] = mix_pows.as_slice();
        let input: &[Self::Elem] = &input.as_slice();

        output
            .as_slice_mut()
            .par_chunks_exact_mut(count)
            .enumerate()
            .for_each(|(id, out_chunk): (usize, &mut [Self::ExtElem])| {
                // Accumulate every input polynomial assigned to combo `id`,
                // weighted by its power of the mixing parameter.
                for i in 0..input_size {
                    if combos[i] != id as u32 {
                        continue;
                    }
                    for idx in 0..count {
                        out_chunk[idx] += mix_pows[i] * input[count * i + idx];
                    }
                }
            });
    }

    fn eltwise_add_elem(
        &self,
        output: &Self::Buffer<Self::Elem>,
        input1: &Self::Buffer<Self::Elem>,
        input2: &Self::Buffer<Self::Elem>,
    ) {
        assert_eq!(output.size(), input1.size());
        assert_eq!(output.size(), input2.size());
        let mut output = output.as_slice_mut();
        let input1 = input1.as_slice();
        let input2 = input2.as_slice();
        (&mut output[..], &input1[..], &input2[..])
            .into_par_iter()
            .for_each(|(o, a, b)| {
                *o = *a + *b;
            });
    }

    fn eltwise_sum_extelem(
        &self,
        output: &Self::Buffer<Self::Elem>,
        input: &Self::Buffer<Self::ExtElem>,
    ) {
        let count = output.size() / Self::ExtElem::EXT_SIZE;
        let to_add = input.size() / count;
        assert_eq!(output.size(), count * Self::ExtElem::EXT_SIZE);
        assert_eq!(input.size(), count * to_add);
        let mut output = output.as_slice_mut();
        let mut output =
            ArrayViewMut::from_shape((Self::ExtElem::EXT_SIZE, count), &mut output).unwrap();
        let output = output.axis_iter_mut(Axis(1)).into_par_iter();
        let input = input.as_slice();
        let input = ArrayView::from_shape((to_add, count), &input).unwrap();
        let input = input.axis_iter(Axis(1)).into_par_iter();
        output.zip(input).for_each(|(mut output, input)| {
            // Sum the `to_add` extension elements for this column, then write
            // the result out one base-field subelement at a time.
            let mut sum = Self::ExtElem::ZERO;
            for i in input {
                sum += *i;
            }
            for i in 0..Self::ExtElem::EXT_SIZE {
                output[i] = sum.subelems()[i];
            }
        });
    }

    fn eltwise_copy_elem(
        &self,
        output: &Self::Buffer<Self::Elem>,
        input: &Self::Buffer<Self::Elem>,
    ) {
        let count = output.size();
        assert_eq!(count, input.size());
        let mut output = output.as_slice_mut();
        let input = input.as_slice();
        (&mut output[..], &input[..])
            .into_par_iter()
            .for_each(|(output, input)| {
                *output = *input;
            });
    }

    fn eltwise_zeroize_elem(&self, elems: &Self::Buffer<Self::Elem>) {
        elems.as_slice_mut().par_iter_mut().for_each(|elem| {
            *elem = elem.valid_or_zero();
        });
    }

    fn fri_fold(
        &self,
        output: &Self::Buffer<Self::Elem>,
        input: &Self::Buffer<Self::Elem>,
        mix: &Self::ExtElem,
    ) {
        let count = output.size() / Self::ExtElem::EXT_SIZE;
        assert_eq!(output.size(), count * Self::ExtElem::EXT_SIZE);
        assert_eq!(input.size(), output.size() * FRI_FOLD);
        let mut output = output.as_slice_mut();
        let input = input.as_slice();

        // Sequential for now; each output index is independent, so this loop
        // could be parallelized.
        for idx in 0..count {
            let mut tot = Self::ExtElem::ZERO;
            let mut cur_mix = Self::ExtElem::ONE;
            for i in 0..FRI_FOLD {
                // Inputs are stored in bit-reversed order.
                let rev_i = bit_rev_32(i as u32) >> (32 - log2_ceil(FRI_FOLD));
                let rev_idx = rev_i as usize * count + idx;
                let factor = Self::ExtElem::from_subelems(
                    (0..Self::ExtElem::EXT_SIZE).map(|i| input[i * count * FRI_FOLD + rev_idx]),
                );
                tot += cur_mix * factor;
                cur_mix *= *mix;
            }
            for i in 0..Self::ExtElem::EXT_SIZE {
                output[count * i + idx] = tot.subelems()[i];
            }
        }
    }

    fn hash_rows(&self, output: &Self::Buffer<Digest>, matrix: &Self::Buffer<Self::Elem>) {
        let row_size = output.size();
        let col_size = matrix.size() / output.size();
        assert_eq!(matrix.size(), col_size * row_size);
        let mut output = output.as_slice_mut();
        let matrix = &*matrix.as_slice();
        let hashfn = self.suite.hashfn.as_ref();
        output.par_iter_mut().enumerate().for_each(|(idx, output)| {
            // Gather the column-major elements of row `idx`, then hash them.
            let column: Vec<Self::Elem> =
                (0..col_size).map(|i| matrix[i * row_size + idx]).collect();
            *output = *hashfn.hash_elem_slice(column.as_slice());
        });
    }

    fn hash_fold(&self, io: &Self::Buffer<Digest>, input_size: usize, output_size: usize) {
        assert!(io.size() >= 2 * input_size);
        assert_eq!(input_size, 2 * output_size);
        // `io` holds a binary Merkle tree in heap order: the layer of size n
        // starts at offset n, so each pair of inputs folds into one output.
        let io = io.as_slice_sync();
        let output = io.slice(output_size, output_size);
        let input = io.slice(input_size, input_size);
        let hashfn = self.suite.hashfn.as_ref();
        (0..output.size()).into_par_iter().for_each(|idx| {
            let in1 = input.get(2 * idx);
            let in2 = input.get(2 * idx + 1);
            output.set(idx, *hashfn.hash_pair(&in1, &in2));
        });
    }

    fn gather_sample(
        &self,
        dst: &Self::Buffer<Self::Elem>,
        src: &Self::Buffer<Self::Elem>,
        idx: usize,
        size: usize,
        stride: usize,
    ) {
        let src = src.as_slice();
        let mut dst = dst.as_slice_mut();
        for gid in 0..size {
            dst[gid] = src[gid * stride + idx];
        }
    }

    fn scatter(
        &self,
        into: &Self::Buffer<Self::Elem>,
        index: &[u32],
        offsets: &[u32],
        values: &[Self::Elem],
    ) {
        if index.is_empty() {
            return;
        }

        let mut into = into.as_slice_mut();
        for cycle in 0..index.len() - 1 {
            for idx in index[cycle]..index[cycle + 1] {
                into[offsets[idx as usize] as usize] = values[idx as usize];
            }
        }
    }

    fn eltwise_copy_elem_slice(
        &self,
        into: &Self::Buffer<Self::Elem>,
        from: &[Self::Elem],
        from_rows: usize,
        from_cols: usize,
        from_offset: usize,
        from_stride: usize,
        into_offset: usize,
        into_stride: usize,
    ) {
        let mut into = into.as_slice_mut();
        for row in 0..from_rows {
            for col in 0..from_cols {
                into[into_offset + row * into_stride + col] =
                    from[from_offset + row * from_stride + col];
            }
        }
    }

    fn prefix_products(&self, io: &Self::Buffer<Self::ExtElem>) {
        // In-place inclusive prefix product: io[i] becomes io[0] * ... * io[i].
        let mut io = io.as_slice_mut();
        for i in 1..io.len() {
            io[i] = io[i] * io[i - 1];
        }
    }

    fn has_unified_memory(&self) -> bool {
        true
    }

    fn get_hash_suite(&self) -> &HashSuite<Self::Field> {
        &self.suite
    }
}

#[cfg(test)]
mod tests {
    use hex::FromHex;
    use rand::thread_rng;
    use risc0_core::field::baby_bear::{BabyBear, BabyBearExtElem};

    use super::*;
    use crate::core::hash::sha::Sha256HashSuite;

    #[test]
    #[should_panic]
    fn check_req() {
        let hal: CpuHal<BabyBear> = CpuHal::new(Sha256HashSuite::new_suite());
        let a = hal.alloc_elem("a", 10);
        let b = hal.alloc_elem("b", 20);
        hal.eltwise_add_elem(&a, &b, &b);
    }

    #[test]
    fn fp() {
        let hal: CpuHal<BabyBear> = CpuHal::new(Sha256HashSuite::new_suite());
        const COUNT: usize = 1024 * 1024;
        test_binary(
            &hal,
            |o, a, b| {
                hal.eltwise_add_elem(o, a, b);
            },
            |a, b| *a + *b,
            COUNT,
        );
    }

    fn test_binary<H, HF, CF>(hal: &H, hal_fn: HF, cpu_fn: CF, count: usize)
    where
        H: Hal,
        HF: Fn(&H::Buffer<H::Elem>, &H::Buffer<H::Elem>, &H::Buffer<H::Elem>),
        CF: Fn(&H::Elem, &H::Elem) -> H::Elem,
    {
        let a = hal.alloc_elem("a", count);
        let b = hal.alloc_elem("b", count);
        let o = hal.alloc_elem("o", count);
        let mut golden = Vec::with_capacity(count);
        let mut rng = thread_rng();
        a.view_mut(|a| {
            b.view_mut(|b| {
                for i in 0..count {
                    a[i] = H::Elem::random(&mut rng);
                    b[i] = H::Elem::random(&mut rng);
                }
                for i in 0..count {
                    golden.push(cpu_fn(&a[i], &b[i]));
                }
            });
        });
        hal_fn(&o, &a, &b);
        o.view(|o| {
            assert_eq!(o, &golden[..]);
        });
    }
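
    // The two tests below are illustrative additions (not part of the original
    // suite): they exercise the region semantics of `Buffer::slice` and the
    // disjoint-write contract of `SyncSlice`, using only APIs defined above.

    #[test]
    fn buffer_slice_region() {
        let buf = CpuBuffer::from(vec![1u32, 2, 3, 4, 5, 6]);
        let sub = buf.slice(2, 3);
        assert_eq!(sub.size(), 3);
        // Views and copies of a sliced buffer see only its region.
        sub.view(|v| assert_eq!(v, &[3, 4, 5]));
        assert_eq!(sub.to_vec(), vec![3, 4, 5]);
    }

    #[test]
    fn sync_slice_parallel_writes() {
        let buf = CpuBuffer::from(vec![0u32; 1024]);
        let slice = buf.as_slice_sync();
        // Each index is written by exactly one rayon task, so the concurrent
        // `set` calls never alias.
        (0..slice.size()).into_par_iter().for_each(|i| {
            slice.set(i, i as u32);
        });
        // Release the write guard held by the SyncSlice before re-locking.
        drop(slice);
        buf.view(|v| {
            for (i, val) in v.iter().enumerate() {
                assert_eq!(*val, i as u32);
            }
        });
    }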

    fn do_hash_rows(rows: usize, cols: usize, expected: &[&str]) {
        let hal: CpuHal<BabyBear> = CpuHal::new(Sha256HashSuite::new_suite());
        let matrix_size = rows * cols;
        // The matrix is left at its default (zero) contents; the expected
        // digests are hashes of all-zero rows.
        let matrix = hal.alloc_elem("matrix", matrix_size);
        let output = hal.alloc_digest("output", rows);
        hal.hash_rows(&output, &matrix);
        output.view(|view| {
            assert_eq!(expected.len(), view.len());
            for (expected, actual) in expected.iter().zip(view) {
                assert_eq!(Digest::from_hex(expected).unwrap(), *actual);
            }
        });
    }

    #[test]
    fn hash_rows() {
        do_hash_rows(
            1,
            16,
            &["da5698be17b9b46962335799779fbeca8ce5d491c0d26243bafef9ea1837a9d8"],
        );
    }

    #[test]
    fn prefix_products() {
        let hal: CpuHal<BabyBear> = CpuHal::new(Sha256HashSuite::new_suite());
        let io = vec![BabyBearExtElem::from_u32(2); 4];
        let io = hal.copy_from_extelem("io", &io);
        hal.prefix_products(&io);

        let io: Vec<_> = io.as_slice().iter().cloned().collect();

        assert_eq!(
            &io,
            &[
                BabyBearExtElem::from_u32(2),
                BabyBearExtElem::from_u32(4),
                BabyBearExtElem::from_u32(8),
                BabyBearExtElem::from_u32(16),
            ]
        );
    }
}