1use core::ops::{Add, Mul, Sub};
18
19use paste::paste;
20use risc0_core::field::{Elem, RootsOfUnity};
21
22use super::log2_ceil;
23
/// Reverses the order of the 32 bits of `x`.
///
/// Bit 0 of the input becomes bit 31 of the result, bit 1 becomes bit 30,
/// and so on. Applying the function twice returns the original value.
///
/// This delegates to the standard library's [`u32::reverse_bits`] instead of
/// a hand-written sequence of mask-and-shift swaps.
pub fn bit_rev_32(x: u32) -> u32 {
    x.reverse_bits()
}
46
47pub fn bit_reverse<T: Copy>(io: &mut [T]) {
65 let n = log2_ceil(io.len());
66 assert_eq!(1 << n, io.len());
67 for i in 0..io.len() {
68 let rev_idx = (bit_rev_32(i as u32) >> (32 - n)) as usize;
69 if i < rev_idx {
70 io.swap(i, rev_idx);
71 }
72 }
73}
74
/// Base case (`n = 0`) of the forward butterfly recursion.
///
/// This is where the generated `fwd_butterfly_1` bottoms out. It performs no
/// work and leaves the slice untouched; both arguments are ignored.
#[inline]
fn fwd_butterfly_0<B, T>(_io: &mut [T], _expand_bits: usize) {}
79
/// Base case (`n = 0`) of the reverse butterfly recursion.
///
/// This is where the generated `rev_butterfly_1` bottoms out. It performs no
/// work and leaves the slice untouched.
#[inline]
fn rev_butterfly_0<B, T>(_io: &mut [T]) {}
84
// Generates `fwd_butterfly_$n` and `rev_butterfly_$n`, the butterfly passes
// for a block of `1 << $n` elements. `$x` names the next-smaller pair of
// functions that the generated ones recurse into, so it is always invoked
// with `$x == $n - 1`.
macro_rules! butterfly {
    ($n:literal, $x:literal) => {
        paste! {
            // Forward pass: transform both halves first, then combine them
            // using successive powers of `ROU_FWD[$n]`.
            #[inline]
            fn [<fwd_butterfly_ $n>]<B, T>(io: &mut [T], expand_bits: usize)
            where
                B: Elem + RootsOfUnity,
                T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
            {
                // Once the recursion reaches level `expand_bits` it stops:
                // this level and every smaller one are left untouched.
                if expand_bits == $n {
                    return;
                }
                let mid = 1 << ($n - 1);
                [<fwd_butterfly_ $x>]::<B, T>(&mut io[..mid], expand_bits);
                [<fwd_butterfly_ $x>]::<B, T>(&mut io[mid..], expand_bits);
                let root = <B as RootsOfUnity>::ROU_FWD[$n];
                let mut twiddle = B::ONE;
                for lo in 0..mid {
                    let hi = lo + mid;
                    let even = io[lo];
                    let odd = io[hi] * twiddle;
                    io[lo] = even + odd;
                    io[hi] = even - odd;
                    twiddle *= root;
                }
            }

            // Reverse pass: combine the two halves using successive powers of
            // `ROU_REV[$n]` first, then recurse into each half.
            #[inline]
            fn [<rev_butterfly_ $n>]<B, T>(io: &mut [T])
            where
                B: Elem + RootsOfUnity,
                T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
            {
                let mid = 1 << ($n - 1);
                let root = <B as RootsOfUnity>::ROU_REV[$n];
                let mut twiddle = B::ONE;
                for lo in 0..mid {
                    let hi = lo + mid;
                    let (top, bottom) = (io[lo], io[hi]);
                    io[lo] = top + bottom;
                    io[hi] = (top - bottom) * twiddle;
                    twiddle *= root;
                }
                [<rev_butterfly_ $x>]::<B, T>(&mut io[..mid]);
                [<rev_butterfly_ $x>]::<B, T>(&mut io[mid..]);
            }
        }
    };
}
137
138butterfly!(32, 31);
139butterfly!(31, 30);
140butterfly!(30, 29);
141butterfly!(29, 28);
142butterfly!(28, 27);
143butterfly!(27, 26);
144butterfly!(26, 25);
145butterfly!(25, 24);
146butterfly!(24, 23);
147butterfly!(23, 22);
148butterfly!(22, 21);
149butterfly!(21, 20);
150butterfly!(20, 19);
151butterfly!(19, 18);
152butterfly!(18, 17);
153butterfly!(17, 16);
154butterfly!(16, 15);
155butterfly!(15, 14);
156butterfly!(14, 13);
157butterfly!(13, 12);
158butterfly!(12, 11);
159butterfly!(11, 10);
160butterfly!(10, 9);
161butterfly!(9, 8);
162butterfly!(8, 7);
163butterfly!(7, 6);
164butterfly!(6, 5);
165butterfly!(5, 4);
166butterfly!(4, 3);
167butterfly!(3, 2);
168butterfly!(2, 1);
169butterfly!(1, 0);
170
171pub fn interpolate_ntt<B, T>(io: &mut [T])
233where
234 B: Elem + RootsOfUnity,
236 T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
237{
238 let size = io.len();
239 let n = log2_ceil(size);
240 assert_eq!(1 << n, size);
241 match n {
242 0 => rev_butterfly_0::<B, T>(io),
243 1 => rev_butterfly_1(io),
244 2 => rev_butterfly_2(io),
245 3 => rev_butterfly_3(io),
246 4 => rev_butterfly_4(io),
247 5 => rev_butterfly_5(io),
248 6 => rev_butterfly_6(io),
249 7 => rev_butterfly_7(io),
250 8 => rev_butterfly_8(io),
251 9 => rev_butterfly_9(io),
252 10 => rev_butterfly_10(io),
253 11 => rev_butterfly_11(io),
254 12 => rev_butterfly_12(io),
255 13 => rev_butterfly_13(io),
256 14 => rev_butterfly_14(io),
257 15 => rev_butterfly_15(io),
258 16 => rev_butterfly_16(io),
259 17 => rev_butterfly_17(io),
260 18 => rev_butterfly_18(io),
261 19 => rev_butterfly_19(io),
262 20 => rev_butterfly_20(io),
263 21 => rev_butterfly_21(io),
264 22 => rev_butterfly_22(io),
265 23 => rev_butterfly_23(io),
266 24 => rev_butterfly_24(io),
267 25 => rev_butterfly_25(io),
268 26 => rev_butterfly_26(io),
269 27 => rev_butterfly_27(io),
270 28 => rev_butterfly_28(io),
271 29 => rev_butterfly_29(io),
272 30 => rev_butterfly_30(io),
273 31 => rev_butterfly_31(io),
274 32 => rev_butterfly_32(io),
275 _ => unreachable!(),
276 }
277 let norm = B::from_u64(size as u64).inv();
278 for x in io.iter_mut().take(size) {
279 *x = *x * norm;
280 }
281}
282
283pub fn evaluate_ntt<B, T>(io: &mut [T], expand_bits: usize)
285where
286 B: Elem + RootsOfUnity,
288 T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
289{
290 let size = io.len();
292 let n = log2_ceil(size);
293 assert_eq!(1 << n, size);
294 match n {
295 0 => fwd_butterfly_0::<B, T>(io, expand_bits),
296 1 => fwd_butterfly_1(io, expand_bits),
297 2 => fwd_butterfly_2(io, expand_bits),
298 3 => fwd_butterfly_3(io, expand_bits),
299 4 => fwd_butterfly_4(io, expand_bits),
300 5 => fwd_butterfly_5(io, expand_bits),
301 6 => fwd_butterfly_6(io, expand_bits),
302 7 => fwd_butterfly_7(io, expand_bits),
303 8 => fwd_butterfly_8(io, expand_bits),
304 9 => fwd_butterfly_9(io, expand_bits),
305 10 => fwd_butterfly_10(io, expand_bits),
306 11 => fwd_butterfly_11(io, expand_bits),
307 12 => fwd_butterfly_12(io, expand_bits),
308 13 => fwd_butterfly_13(io, expand_bits),
309 14 => fwd_butterfly_14(io, expand_bits),
310 15 => fwd_butterfly_15(io, expand_bits),
311 16 => fwd_butterfly_16(io, expand_bits),
312 17 => fwd_butterfly_17(io, expand_bits),
313 18 => fwd_butterfly_18(io, expand_bits),
314 19 => fwd_butterfly_19(io, expand_bits),
315 20 => fwd_butterfly_20(io, expand_bits),
316 21 => fwd_butterfly_21(io, expand_bits),
317 22 => fwd_butterfly_22(io, expand_bits),
318 23 => fwd_butterfly_23(io, expand_bits),
319 24 => fwd_butterfly_24(io, expand_bits),
320 25 => fwd_butterfly_25(io, expand_bits),
321 26 => fwd_butterfly_26(io, expand_bits),
322 27 => fwd_butterfly_27(io, expand_bits),
323 28 => fwd_butterfly_28(io, expand_bits),
324 29 => fwd_butterfly_29(io, expand_bits),
325 30 => fwd_butterfly_30(io, expand_bits),
326 31 => fwd_butterfly_31(io, expand_bits),
327 32 => fwd_butterfly_32(io, expand_bits),
328 _ => unreachable!(),
329 }
330}
331
/// Writes each element of `input` into `output` `1 << expand_bits` times in a
/// row, so that `output[i] == input[i >> expand_bits]`.
///
/// Only the first `input.len() << expand_bits` elements of `output` are
/// written; any elements beyond that are left as they were.
///
/// # Panics
///
/// Panics if `output` is shorter than `input.len() << expand_bits`. The
/// length is checked before anything is written, so a failed call never
/// leaves `output` partially filled.
pub fn expand<T>(output: &mut [T], input: &[T], expand_bits: usize)
where
    T: Copy,
{
    let repeat = 1usize << expand_bits;
    let size_out = input.len() * repeat;
    assert!(
        output.len() >= size_out,
        "expand: output holds {} elements but {} are required",
        output.len(),
        size_out
    );
    // One chunk of `repeat` slots per input element.
    for (chunk, &value) in output[..size_out].chunks_exact_mut(repeat).zip(input) {
        chunk.fill(value);
    }
}
343
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use risc0_core::field::{baby_bear::BabyBearElem, Elem, RootsOfUnity};

    use crate::core::ntt::{bit_reverse, evaluate_ntt, interpolate_ntt};

    /// Overwrites every element of `buf` with an independently drawn random
    /// field element.
    ///
    /// An array repeat expression such as `[BabyBearElem::random(&mut rng); N]`
    /// must not be used for this: it evaluates its operand once and copies the
    /// result, which yields `N` identical elements.
    fn fill_random(buf: &mut [BabyBearElem]) {
        let mut rng = thread_rng();
        for elem in buf.iter_mut() {
            *elem = BabyBearElem::random(&mut rng);
        }
    }

    /// Evaluates the polynomial with coefficients `coeffs` (lowest degree
    /// first) term by term at successive powers of `ROU_FWD[po2]`, starting
    /// from one, storing one result per element of `out`.
    fn naive_evaluate(coeffs: &[BabyBearElem], po2: usize, out: &mut [BabyBearElem]) {
        let mut x = BabyBearElem::ONE;
        for slot in out.iter_mut() {
            let mut tot = BabyBearElem::ZERO;
            // `xn` walks through x^0, x^1, x^2, ...
            let mut xn = BabyBearElem::ONE;
            for coeff in coeffs.iter() {
                tot += *coeff * xn;
                xn *= x;
            }
            *slot = tot;
            x *= BabyBearElem::ROU_FWD[po2];
        }
    }

    /// Compares `evaluate_ntt` against the naive term-by-term evaluation.
    #[test]
    fn cmp_evaluate() {
        const N: usize = 6;
        const SIZE: usize = 1 << N;
        // Random coefficients.
        let mut buf = [BabyBearElem::ZERO; SIZE];
        fill_random(&mut buf);
        // Reference result computed the slow way.
        let mut goal = [BabyBearElem::ZERO; SIZE];
        naive_evaluate(&buf, N, &mut goal);
        // `evaluate_ntt` is fed the coefficients in bit-reversed order.
        bit_reverse(&mut buf);
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, 0);
        assert_eq!(goal, buf);
    }

    /// Interpolating and then evaluating must reproduce the original values.
    #[test]
    fn roundtrip() {
        const N: usize = 10;
        const SIZE: usize = 1 << N;
        let mut buf = [BabyBearElem::ZERO; SIZE];
        fill_random(&mut buf);
        let orig = buf;
        interpolate_ntt::<BabyBearElem, BabyBearElem>(&mut buf);
        // The interpolated data must actually differ from the input.
        assert_ne!(orig, buf);
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, 0);
        assert_eq!(orig, buf);
    }

    /// Interpolate, expand by `L` bits and evaluate with `expand_bits = L`;
    /// the result must match a naive evaluation over the larger domain.
    #[test]
    fn expand() {
        const N: usize = 6;
        const L: usize = 2;
        const SIZE_IN: usize = 1 << (N - L);
        const SIZE_OUT: usize = 1 << N;
        let mut cmp = [BabyBearElem::ZERO; SIZE_IN];
        fill_random(&mut cmp);
        let mut buf = [BabyBearElem::ZERO; SIZE_OUT];
        interpolate_ntt::<BabyBearElem, BabyBearElem>(&mut cmp);
        super::expand(&mut buf, &cmp, L);
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, L);
        // Undo the bit-reversed ordering so `cmp` can be used as plain
        // coefficients by the naive evaluation.
        bit_reverse(&mut cmp);
        let mut goal = [BabyBearElem::ZERO; SIZE_OUT];
        naive_evaluate(&cmp, N, &mut goal);
        assert_eq!(goal, buf);
    }
}