risc0_zkp/core/
ntt.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! An implementation of a number-theoretic transform (NTT).
16
17use core::ops::{Add, Mul, Sub};
18
19use paste::paste;
20use risc0_core::field::{Elem, RootsOfUnity};
21
22use super::log2_ceil;
23
/// Reverses the bits in a 32-bit number.
///
/// Bit 0 of `x` becomes bit 31 of the result, bit 1 becomes bit 30, and so on.
///
/// # Example
/// ```rust
/// # use risc0_zkp::core::ntt::bit_rev_32;
/// #
/// let a: u32 = 0b1101;
///
/// assert_eq!(format!("{:b}", a), "1101");
/// assert_eq!(format!("{:b}", bit_rev_32(a)), "10110000000000000000000000000000");
/// ```
pub fn bit_rev_32(x: u32) -> u32 {
    // The standard library already provides full-width bit reversal, so there
    // is no need to hand-roll the swap-adjacent-groups mask sequence.
    x.reverse_bits()
}
46
47/// Bit-reverses the indices in an array of (1 << n) numbers.
48/// This permutes the values in the array so that a value which is previously
49/// in index i will now go in the index i', given by reversing the bits of i.
50///
51/// # Example
52/// For example, with the array given below of size n=4,
53/// the indices are `0, 1, 2, 3`; bitwise, they're `0, 01, 10, 11`.
54///
55/// Reversed, these give `0, 10, 01, 11`, permuting the second and third
56/// values.
57/// ```rust
58/// # use risc0_zkp::core::ntt::bit_reverse;
59/// #
60/// let mut some_values = [1, 2, 3, 4];
61/// bit_reverse(&mut some_values);
62/// assert_eq!(some_values, [1, 3, 2, 4]);
63/// ```
64pub fn bit_reverse<T: Copy>(io: &mut [T]) {
65    let n = log2_ceil(io.len());
66    assert_eq!(1 << n, io.len());
67    for i in 0..io.len() {
68        let rev_idx = (bit_rev_32(i as u32) >> (32 - n)) as usize;
69        if i < rev_idx {
70            io.swap(i, rev_idx);
71        }
72    }
73}
74
/// Forward butterfly for a buffer of `1 << 0` elements: there is nothing to
/// combine, so the buffer is left untouched. This is the bottom of the
/// `fwd_butterfly_N` recursion generated by `butterfly!`.
#[inline]
fn fwd_butterfly_0<B, T>(_io: &mut [T], _expand_bits: usize) {}
79
/// Reverse butterfly for a buffer of `1 << 0` elements: there is nothing to
/// combine, so the buffer is left untouched. This is the bottom of the
/// `rev_butterfly_N` recursion generated by `butterfly!`.
#[inline]
fn rev_butterfly_0<B, T>(_io: &mut [T]) {}
84
85// TODO: This generates butterfly functions up to $n = 32, but will panic if $n
86// is bigger than <F as RootsOfUnity>::MAX_ROU_PO2 -- is this the best approach?
// Generates `fwd_butterfly_<n>` and `rev_butterfly_<n>` for a buffer of
// `1 << n` elements. `$x` must be `$n - 1`: each generated function recurses
// into the level-`$x` function on the two halves of its buffer.
macro_rules! butterfly {
    ($n:literal, $x:literal) => {
        paste! {
            // Forward pass: transform both halves first, then merge them
            // pairwise using successive powers of ROU_FWD[$n].
            #[inline]
            fn [<fwd_butterfly_ $n>]<B, T>(io: &mut [T], expand_bits: usize)
            where
                // `B` is the base field supplying the roots of unity; `T` is
                // the element type being transformed (base or extension).
                B: Elem + RootsOfUnity,
                T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
            {
                // Recursion stops at the level equal to `expand_bits`, leaving
                // blocks of `1 << expand_bits` elements as they are.
                if expand_bits == $n {
                    return;
                }
                let mid = 1 << ($n - 1);
                [<fwd_butterfly_ $x>]::<B, T>(&mut io[..mid], expand_bits);
                [<fwd_butterfly_ $x>]::<B, T>(&mut io[mid..], expand_bits);
                let root = <B as RootsOfUnity>::ROU_FWD[$n];
                // `twiddle` walks through root^0, root^1, ..., root^(mid - 1).
                let mut twiddle = B::ONE;
                for lo in 0..mid {
                    let hi = lo + mid;
                    let even = io[lo];
                    let odd = io[hi] * twiddle;
                    io[lo] = even + odd;
                    io[hi] = even - odd;
                    twiddle *= root;
                }
            }

            // Reverse pass: combine the two halves pairwise using successive
            // powers of ROU_REV[$n] first, then transform each half.
            #[inline]
            fn [<rev_butterfly_ $n>]<B, T>(io: &mut [T])
            where
                // `B` is the base field supplying the roots of unity; `T` is
                // the element type being transformed (base or extension).
                B: Elem + RootsOfUnity,
                T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
            {
                let mid = 1 << ($n - 1);
                let root = <B as RootsOfUnity>::ROU_REV[$n];
                // `twiddle` walks through root^0, root^1, ..., root^(mid - 1).
                let mut twiddle = B::ONE;
                for lo in 0..mid {
                    let hi = lo + mid;
                    let (p, q) = (io[lo], io[hi]);
                    io[lo] = p + q;
                    io[hi] = (p - q) * twiddle;
                    twiddle *= root;
                }
                [<rev_butterfly_ $x>]::<B, T>(&mut io[..mid]);
                [<rev_butterfly_ $x>]::<B, T>(&mut io[mid..]);
            }
        }
    };
}
137
138butterfly!(32, 31);
139butterfly!(31, 30);
140butterfly!(30, 29);
141butterfly!(29, 28);
142butterfly!(28, 27);
143butterfly!(27, 26);
144butterfly!(26, 25);
145butterfly!(25, 24);
146butterfly!(24, 23);
147butterfly!(23, 22);
148butterfly!(22, 21);
149butterfly!(21, 20);
150butterfly!(20, 19);
151butterfly!(19, 18);
152butterfly!(18, 17);
153butterfly!(17, 16);
154butterfly!(16, 15);
155butterfly!(15, 14);
156butterfly!(14, 13);
157butterfly!(13, 12);
158butterfly!(12, 11);
159butterfly!(11, 10);
160butterfly!(10, 9);
161butterfly!(9, 8);
162butterfly!(8, 7);
163butterfly!(7, 6);
164butterfly!(6, 5);
165butterfly!(5, 4);
166butterfly!(4, 3);
167butterfly!(3, 2);
168butterfly!(2, 1);
169butterfly!(1, 0);
170
/// Interpolate in place: perform a reverse butterfly transform of a buffer of
/// (1 << n) numbers, then scale every entry by the inverse of the buffer length.
/// Before that scaling, the result is a discrete Fourier transform, but with
/// changed indices. This is described [here](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#Data_reordering,_bit_reversal,_and_in-place_algorithms).
/// The unscaled output of the size-n reverse butterfly at index i is the sum
/// over k from 0 to 2^n-1 of io\[k\] ROU_REV\[n\]^(k i'), where i' is i
/// bit-reversed as an n-bit number and ROU_REV are the 'reverse' roots of unity.
177///
178/// As an example, we'll work through a trace of the rev_butterfly algorithm
179/// with n = 3 on a list of length 8. Let w = ROU_REV\[3\] be the eighth root of
180/// unity. We start with
181///
182///   \[a0, a1, a2, a3, a4, a5, a6, a7\]
183///
184/// After the loop, before the first round of recursive calls, we have
185///
186///   [a0+a4, a1+a5,     a2+a6,         a3+a7,
187///
188///    a0-a4, a1w-a5w, a2w^2-a6w^2, a3w^3-a7w^3]
189///
190/// After first round of recursive calls, we have
191///
192///   [a0+a4+a2+a6,         a1+a5+a3+a7,
193///
194///    a0+a4-a2-a6,         a1w^2+a5w^2-a3w^2-a7w^2,
195///
196///    a0-a4+a2w^2-a6w^2, a1w-a5w+a3w^3-a7w^3,
197///
198///    a0-a4-a2w^2+a6w^2, a1w^3-a5w^3-a3w^5+a7w^5]
199///
200/// And after the second round of recursive calls, we have
201///
202///   [a0+a4+a2+a6+a1+a5+a3+a7,
203///
204///    a0+a4+a2+a6-a1-a5-a3-a7,
205///
206///    a0+a4-a2-a6+a1w^2+a5w^2-a3w^2-a7w^2,
207///
208///    a0+a4-a2-a6-a1w^2-a5w^2+a3w^2+a7w^2,
209///
210///    a0-a4+a2w^2-a6w^2+a1w-a5w+a3w^3-a7w^3,
211///
212///    a0-a4+a2w^2-a6w^2-a1w+a5w-a3w^3+a7w^3,
213///
214///    a0-a4-a2w^2+a6w^2+a1w^3-a5w^3+a3w^5-a7w^5,
215///
216///    a0-a4-a2w^2+a6w^2-a1w^3+a5w^3-a3w^5+a7w^5]
217///
218/// Rewriting this, we get
219///
220///   \[sum_k ak w^0,
221///    sum_k ak w^4k,
222///    sum_k ak w^2k,
223///    sum_k ak w^6k,
224///    sum_k ak w^1k,
225///    sum_k ak w^5k,
226///    sum_k ak w^3k,
227///    sum_k ak w^7k\]
228///
229/// The exponent multiplicands in the sum arise from reversing the indices as
230/// three-bit numbers. For example, 3 is 011 in binary, which reversed is 110,
231/// which is 6. So i' in the exponent of the index-3 value is 6.
232pub fn interpolate_ntt<B, T>(io: &mut [T])
233where
234    // B is a base field element, T may be either base or extension
235    B: Elem + RootsOfUnity,
236    T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
237{
238    let size = io.len();
239    let n = log2_ceil(size);
240    assert_eq!(1 << n, size);
241    match n {
242        0 => rev_butterfly_0::<B, T>(io),
243        1 => rev_butterfly_1(io),
244        2 => rev_butterfly_2(io),
245        3 => rev_butterfly_3(io),
246        4 => rev_butterfly_4(io),
247        5 => rev_butterfly_5(io),
248        6 => rev_butterfly_6(io),
249        7 => rev_butterfly_7(io),
250        8 => rev_butterfly_8(io),
251        9 => rev_butterfly_9(io),
252        10 => rev_butterfly_10(io),
253        11 => rev_butterfly_11(io),
254        12 => rev_butterfly_12(io),
255        13 => rev_butterfly_13(io),
256        14 => rev_butterfly_14(io),
257        15 => rev_butterfly_15(io),
258        16 => rev_butterfly_16(io),
259        17 => rev_butterfly_17(io),
260        18 => rev_butterfly_18(io),
261        19 => rev_butterfly_19(io),
262        20 => rev_butterfly_20(io),
263        21 => rev_butterfly_21(io),
264        22 => rev_butterfly_22(io),
265        23 => rev_butterfly_23(io),
266        24 => rev_butterfly_24(io),
267        25 => rev_butterfly_25(io),
268        26 => rev_butterfly_26(io),
269        27 => rev_butterfly_27(io),
270        28 => rev_butterfly_28(io),
271        29 => rev_butterfly_29(io),
272        30 => rev_butterfly_30(io),
273        31 => rev_butterfly_31(io),
274        32 => rev_butterfly_32(io),
275        _ => unreachable!(),
276    }
277    let norm = B::from_u64(size as u64).inv();
278    for x in io.iter_mut().take(size) {
279        *x = *x * norm;
280    }
281}
282
283/// Perform a forward butterfly transform of a buffer of (1 << n) numbers.
284pub fn evaluate_ntt<B, T>(io: &mut [T], expand_bits: usize)
285where
286    // B is a base field element, T may be either base or extension
287    B: Elem + RootsOfUnity,
288    T: Copy + Mul<B, Output = T> + Add<Output = T> + Sub<Output = T>,
289{
290    // do_ntt::<T, false>(io, expand_bits);
291    let size = io.len();
292    let n = log2_ceil(size);
293    assert_eq!(1 << n, size);
294    match n {
295        0 => fwd_butterfly_0::<B, T>(io, expand_bits),
296        1 => fwd_butterfly_1(io, expand_bits),
297        2 => fwd_butterfly_2(io, expand_bits),
298        3 => fwd_butterfly_3(io, expand_bits),
299        4 => fwd_butterfly_4(io, expand_bits),
300        5 => fwd_butterfly_5(io, expand_bits),
301        6 => fwd_butterfly_6(io, expand_bits),
302        7 => fwd_butterfly_7(io, expand_bits),
303        8 => fwd_butterfly_8(io, expand_bits),
304        9 => fwd_butterfly_9(io, expand_bits),
305        10 => fwd_butterfly_10(io, expand_bits),
306        11 => fwd_butterfly_11(io, expand_bits),
307        12 => fwd_butterfly_12(io, expand_bits),
308        13 => fwd_butterfly_13(io, expand_bits),
309        14 => fwd_butterfly_14(io, expand_bits),
310        15 => fwd_butterfly_15(io, expand_bits),
311        16 => fwd_butterfly_16(io, expand_bits),
312        17 => fwd_butterfly_17(io, expand_bits),
313        18 => fwd_butterfly_18(io, expand_bits),
314        19 => fwd_butterfly_19(io, expand_bits),
315        20 => fwd_butterfly_20(io, expand_bits),
316        21 => fwd_butterfly_21(io, expand_bits),
317        22 => fwd_butterfly_22(io, expand_bits),
318        23 => fwd_butterfly_23(io, expand_bits),
319        24 => fwd_butterfly_24(io, expand_bits),
320        25 => fwd_butterfly_25(io, expand_bits),
321        26 => fwd_butterfly_26(io, expand_bits),
322        27 => fwd_butterfly_27(io, expand_bits),
323        28 => fwd_butterfly_28(io, expand_bits),
324        29 => fwd_butterfly_29(io, expand_bits),
325        30 => fwd_butterfly_30(io, expand_bits),
326        31 => fwd_butterfly_31(io, expand_bits),
327        32 => fwd_butterfly_32(io, expand_bits),
328        _ => unreachable!(),
329    }
330}
331
/// Expand the `input` into `output` to support polynomial evaluation on
/// `input.len() * (1 << expand_bits)` points.
///
/// Every element of `input` is written `1 << expand_bits` times in a row, so
/// afterwards `output[i] == input[i >> expand_bits]` for every `i` below
/// `input.len() * (1 << expand_bits)`. Any elements of `output` past that
/// point are left untouched.
///
/// # Panics
/// Panics if `output` is shorter than `input.len() * (1 << expand_bits)`.
/// The length is checked before anything is written, so `output` is never
/// left partially filled.
pub fn expand<T>(output: &mut [T], input: &[T], expand_bits: usize)
where
    T: Copy,
{
    let repeat = 1usize << expand_bits;
    let size_out = input.len() * repeat;
    assert!(
        output.len() >= size_out,
        "expand: output holds {} elements but {} are required",
        output.len(),
        size_out
    );
    // Each input value owns one contiguous run of `repeat` output slots.
    for (run, &value) in output[..size_out].chunks_exact_mut(repeat).zip(input) {
        run.fill(value);
    }
}
343
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use risc0_core::field::{baby_bear::BabyBearElem, Elem, RootsOfUnity};

    use crate::core::ntt::{bit_reverse, evaluate_ntt, interpolate_ntt};

    /// Returns `SIZE` independently sampled field elements.
    ///
    /// An array-repeat expression such as `[BabyBearElem::random(&mut rng); SIZE]`
    /// evaluates `random` once and copies that single value into every slot,
    /// which would make the buffers constant (and `bit_reverse` a no-op on
    /// them). Filling slot by slot gives genuinely random inputs.
    fn random_elems<const SIZE: usize>() -> [BabyBearElem; SIZE] {
        let mut rng = thread_rng();
        let mut elems = [BabyBearElem::ZERO; SIZE];
        for elem in elems.iter_mut() {
            *elem = BabyBearElem::random(&mut rng);
        }
        elems
    }

    /// Evaluates the polynomial with coefficients `coeffs` (lowest degree
    /// first) the slow way, at the first `SIZE` powers of `ROU_FWD[po2]`
    /// (starting at power 0, i.e. x = 1).
    fn naive_evaluate<const SIZE: usize>(
        coeffs: &[BabyBearElem],
        po2: usize,
    ) -> [BabyBearElem; SIZE] {
        let mut evals = [BabyBearElem::ZERO; SIZE];
        let mut x = BabyBearElem::ONE;
        for eval in evals.iter_mut() {
            // Accumulate sum_k coeffs[k] * x^k
            let mut tot = BabyBearElem::ZERO;
            let mut xn = BabyBearElem::ONE;
            for coeff in coeffs.iter() {
                tot += *coeff * xn;
                xn *= x;
            }
            *eval = tot;
            x *= BabyBearElem::ROU_FWD[po2];
        }
        evals
    }

    // Compare the complex version to the naive version
    #[test]
    fn cmp_evaluate() {
        const N: usize = 6;
        const SIZE: usize = 1 << N;
        // Randomly fill input
        let mut buf: [BabyBearElem; SIZE] = random_elems();
        // Compute the hard way
        let goal: [BabyBearElem; SIZE] = naive_evaluate(&buf, N);
        // Now compute multiEvaluate in place
        bit_reverse(&mut buf);
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, 0);
        // Compare
        assert_eq!(goal, buf);
    }

    // Make sure fwd + rev is identity
    #[test]
    fn roundtrip() {
        const N: usize = 10;
        const SIZE: usize = 1 << N;
        // Randomly fill buffer
        let mut buf: [BabyBearElem; SIZE] = random_elems();
        // Copy it
        let orig = buf;
        // Now go backwards
        interpolate_ntt::<BabyBearElem, BabyBearElem>(&mut buf);
        // Make sure something changed
        assert_ne!(orig, buf);
        // Now go forward
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, 0);
        // It should be back to identical
        assert_eq!(orig, buf);
    }

    #[test]
    fn expand() {
        const N: usize = 6;
        const L: usize = 2;
        const SIZE_IN: usize = 1 << (N - L);
        const SIZE_OUT: usize = 1 << N;
        let mut cmp: [BabyBearElem; SIZE_IN] = random_elems();
        let mut buf = [BabyBearElem::ZERO; SIZE_OUT];
        // Do plain interpolate on cmp
        interpolate_ntt::<BabyBearElem, BabyBearElem>(&mut cmp);
        // Expand to buf
        super::expand(&mut buf, &cmp, L);
        // Evaluate over the larger space
        evaluate_ntt::<BabyBearElem, BabyBearElem>(&mut buf, L);
        // Order cmp nicely for the check
        bit_reverse(&mut cmp);
        // Now verify by comparing with the slow way
        let goal: [BabyBearElem; SIZE_OUT] = naive_evaluate(&cmp, N);
        assert_eq!(goal, buf);
    }
}