poly_commit/
fft.rs

1use crate::poly::Coefficients;
2use crate::util::batch_inversion;
3use crate::PointsValue;
4#[cfg(feature = "std")]
5use rayon::join;
6use zkstd::common::{vec, FftField, Vec};
7
8/// fft construction using n th root of unity supports polynomial operation less than n degree
9#[derive(Clone, Debug, Eq, PartialEq)]
10pub struct Fft<F: FftField> {
11    // polynomial degree 2^k
12    n: usize,
13    // generator of order 2^{k - 1} multiplicative group used as twiddle factors
14    twiddle_factors: Vec<F>,
15    // multiplicative group generator inverse
16    inv_twiddle_factors: Vec<F>,
17    // coset domain
18    cosets: Vec<F>,
19    // inverse coset domain
20    inv_cosets: Vec<F>,
21    // n inverse for inverse discrete fourier transform
22    n_inv: F,
23    // bit reverse index
24    bit_reverse: Vec<(usize, usize)>,
25    pub elements: Vec<F>,
26}
27
28// SBP-M1 review: use safe math operations
29impl<F: FftField> Fft<F> {
30    pub fn new(k: usize) -> Self {
31        assert!(k >= 1);
32        let n = 1 << k;
33        let half_n = n >> 1;
34        let offset = 64 - k;
35
36        // precompute twiddle factors
37        let g = (0..F::S - k).fold(F::ROOT_OF_UNITY, |acc, _| acc.square());
38        let twiddle_factors = (0..half_n)
39            .scan(F::one(), |w, _| {
40                let tw = *w;
41                *w *= g;
42                Some(tw)
43            })
44            .collect::<Vec<_>>();
45
46        // precompute inverse twiddle factors
47        let g_inv = g.invert().unwrap();
48        let inv_twiddle_factors = (0..half_n)
49            .scan(F::one(), |w, _| {
50                let tw = *w;
51                *w *= g_inv;
52                Some(tw)
53            })
54            .collect::<Vec<_>>();
55
56        // precompute cosets
57        let mul_g = F::MULTIPLICATIVE_GENERATOR;
58        let cosets = (0..n)
59            .scan(F::one(), |w, _| {
60                let tw = *w;
61                *w *= mul_g;
62                Some(tw)
63            })
64            .collect::<Vec<_>>();
65
66        // precompute inverse cosets
67        let mul_g_inv = mul_g.invert().unwrap();
68        let inv_cosets = (0..n)
69            .scan(F::one(), |w, _| {
70                let tw = *w;
71                *w *= mul_g_inv;
72                Some(tw)
73            })
74            .collect::<Vec<_>>();
75
76        let elements = (0..n)
77            .scan(F::one(), |w, _| {
78                let tw = *w;
79                *w *= g;
80                Some(tw)
81            })
82            .collect::<Vec<_>>();
83
84        let bit_reverse = (0..n as u64)
85            .filter_map(|i| {
86                let r = i.reverse_bits() >> offset;
87                (i < r).then_some((i as usize, r as usize))
88            })
89            .collect::<Vec<_>>();
90
91        Self {
92            n,
93            twiddle_factors,
94            inv_twiddle_factors,
95            cosets,
96            inv_cosets,
97            n_inv: F::from(n as u64).invert().unwrap(),
98            bit_reverse,
99            elements,
100        }
101    }
102
103    /// polynomial degree
104    pub fn size(&self) -> usize {
105        self.n
106    }
107
108    /// size inverse
109    pub fn size_inv(&self) -> F {
110        self.n_inv
111    }
112
113    /// nth unity of root
114    pub fn generator(&self) -> F {
115        self.twiddle_factors[1]
116    }
117
118    /// nth unity of root
119    pub fn generator_inv(&self) -> F {
120        self.inv_twiddle_factors[1]
121    }
122
123    /// perform discrete fourier transform
124    pub fn dft(&self, coeffs: Coefficients<F>) -> PointsValue<F> {
125        let mut evals = coeffs.0;
126        self.prepare_fft(&mut evals);
127        classic_fft_arithmetic(&mut evals, self.n, 1, &self.twiddle_factors);
128        PointsValue::new(evals.clone())
129    }
130
131    /// perform classic inverse discrete fourier transform
132    pub fn idft(&self, points: PointsValue<F>) -> Coefficients<F> {
133        let mut coeffs = points.0;
134        self.prepare_fft(&mut coeffs);
135        classic_fft_arithmetic(&mut coeffs, self.n, 1, &self.inv_twiddle_factors);
136        coeffs.iter_mut().for_each(|coeff| *coeff *= self.n_inv);
137        Coefficients::new(coeffs.clone())
138    }
139
140    /// perform discrete fourier transform on coset
141    pub fn coset_dft(&self, mut coeffs: Coefficients<F>) -> PointsValue<F> {
142        coeffs
143            .0
144            .iter_mut()
145            .zip(self.cosets.iter())
146            .for_each(|(coeff, coset)| *coeff *= *coset);
147        self.dft(coeffs)
148    }
149
150    /// perform discrete fourier transform on coset
151    pub fn coset_idft(&self, points: PointsValue<F>) -> Coefficients<F> {
152        let mut points = self.idft(points);
153        points
154            .0
155            .iter_mut()
156            .zip(self.inv_cosets.iter())
157            .for_each(|(coeff, inv_coset)| *coeff *= *inv_coset);
158        Coefficients::new(points.0)
159    }
160
161    /// This evaluates t(tau) for this domain, which is
162    /// tau^m - 1 for these radix-2 domains.
163    pub fn z(&self, tau: &F) -> F {
164        let mut tmp = tau.pow(self.n as u64);
165        tmp.sub_assign(&F::one());
166
167        tmp
168    }
169
170    /// This evaluates t(tau) for this domain, which is
171    /// tau^m - 1 for these radix-2 domains.
172    pub fn z_on_coset(&self) -> F {
173        let mut tmp = F::MULTIPLICATIVE_GENERATOR.pow(self.n as u64);
174        tmp.sub_assign(&F::one());
175
176        tmp
177    }
178
179    /// The target polynomial is the zero polynomial in our
180    /// evaluation domain, so we must perform division over
181    /// a coset.
182    pub fn divide_by_z_on_coset(&self, points: PointsValue<F>) -> PointsValue<F> {
183        let i = self.z_on_coset().invert().unwrap();
184
185        PointsValue(points.0.into_iter().map(|v| v * i).collect())
186    }
187
188    /// resize polynomial and bit reverse swap
189    fn prepare_fft(&self, coeffs: &mut Vec<F>) {
190        coeffs.resize(self.n, F::zero());
191        self.bit_reverse
192            .iter()
193            .for_each(|(i, ri)| coeffs.swap(*ri, *i));
194    }
195
196    /// polynomial multiplication
197    pub fn poly_mul(&self, rhs: Coefficients<F>, lhs: Coefficients<F>) -> Coefficients<F> {
198        let rhs = self.dft(rhs);
199        let lhs = self.dft(lhs);
200        let mul_poly = PointsValue::new(
201            rhs.0
202                .iter()
203                .zip(lhs.0.iter())
204                .map(|(a, b)| *a * *b)
205                .collect(),
206        );
207        self.idft(mul_poly)
208    }
209
210    /// Evaluate all the lagrange polynomials defined by this domain at the
211    /// point `tau`.
212    pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
213        // Evaluate all Lagrange polynomials
214        let size = self.n;
215        let t_size = tau.pow(size as u64);
216        let one = F::one();
217        if t_size == F::one() {
218            let mut u = vec![F::zero(); size];
219            let mut omega_i = one;
220            for x in u.iter_mut().take(size) {
221                if omega_i == tau {
222                    *x = one;
223                    break;
224                }
225                omega_i *= &self.generator();
226            }
227            u
228        } else {
229            let mut l = (t_size - one) * self.n_inv;
230            let mut r = one;
231            let mut u = vec![F::zero(); size];
232            let mut ls = vec![F::zero(); size];
233            for i in 0..size {
234                u[i] = tau - r;
235                ls[i] = l;
236                l *= &self.generator();
237                r *= &self.generator();
238            }
239
240            batch_inversion(u.as_mut_slice());
241
242            u.iter()
243                .zip(ls)
244                .map(|(tau_minus_r, l)| l * *tau_minus_r)
245                .collect()
246        }
247    }
248
249    /// Given that the domain size is `D`
250    /// This function computes the `D` evaluation points for
251    /// the vanishing polynomial of degree `n` over a coset
252    pub fn compute_vanishing_poly_over_coset(
253        &self,            // domain to evaluate over
254        poly_degree: u64, // degree of the vanishing polynomial
255    ) -> PointsValue<F> {
256        assert!((self.size() as u64) > poly_degree);
257        let coset_gen = F::MULTIPLICATIVE_GENERATOR.pow(poly_degree);
258        let v_h: Vec<_> = (0..self.size())
259            .map(|i| (coset_gen * self.generator().pow(poly_degree * i as u64)) - F::one())
260            .collect();
261        PointsValue::new(v_h)
262    }
263}
264
265// classic fft using divide and conquer algorithm
266fn classic_fft_arithmetic<F: FftField>(
267    coeffs: &mut [F],
268    n: usize,
269    twiddle_chunk: usize,
270    twiddles: &[F],
271) {
272    if n == 2 {
273        let t = coeffs[1];
274        coeffs[1] = coeffs[0];
275        coeffs[0] += t;
276        coeffs[1] -= t;
277    } else {
278        let (left, right) = coeffs.split_at_mut(n / 2);
279        #[cfg(feature = "std")]
280        join(
281            || classic_fft_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
282            || classic_fft_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
283        );
284        #[cfg(not(feature = "std"))]
285        {
286            // TODO: recursion is quite inefficient when not parallel
287            classic_fft_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles);
288            classic_fft_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles);
289        };
290        butterfly_arithmetic(left, right, twiddle_chunk, twiddles)
291    }
292}
293
294// butterfly arithmetic polynomial evaluation
295fn butterfly_arithmetic<F: FftField>(
296    left: &mut [F],
297    right: &mut [F],
298    twiddle_chunk: usize,
299    twiddles: &[F],
300) {
301    // case when twiddle factor is one
302    let t = right[0];
303    right[0] = left[0];
304    left[0] += t;
305    right[0] -= t;
306
307    left.iter_mut()
308        .zip(right.iter_mut())
309        .enumerate()
310        .skip(1)
311        .for_each(|(i, (a, b))| {
312            let mut t = *b;
313            t *= twiddles[i * twiddle_chunk];
314            *b = *a;
315            *a += t;
316            *b -= t;
317        });
318}
319
#[cfg(test)]
mod tests {
    use crate::poly::Coefficients;

    use super::Fft;
    use bls_12_381::Fr;
    use rand_core::OsRng;
    use zkstd::common::Vec;
    use zkstd::common::{Group, PrimeField};

    /// random polynomial with `2^k` coefficients
    fn arb_poly(k: u32) -> Vec<Fr> {
        let len: usize = 1 << k;
        (0..len).map(|_| Fr::random(OsRng)).collect()
    }

    /// schoolbook polynomial multiplication used as the reference result
    fn naive_multiply<F: PrimeField>(a: Vec<F>, b: Vec<F>) -> Vec<F> {
        assert_eq!(a.len(), b.len());
        let mut product = vec![F::zero(); a.len() + b.len()];
        for (i_a, coeff_a) in a.iter().enumerate() {
            for (i_b, coeff_b) in b.iter().enumerate() {
                product[i_a + i_b] += *coeff_a * *coeff_b;
            }
        }
        product
    }

    /// a forward transform followed by the inverse one restores the polynomial
    #[test]
    fn fft_transformation_test() {
        let original = Coefficients(arb_poly(10));
        let classic_fft = Fft::new(10);

        let evals = classic_fft.dft(original.clone());
        let recovered = classic_fft.idft(evals);

        assert_eq!(recovered, original)
    }

    /// pointwise multiplication and `poly_mul` both agree with the schoolbook product
    #[test]
    fn fft_multiplication_test() {
        let coeffs_a = arb_poly(4);
        let coeffs_b = arb_poly(4);
        let fft = Fft::new(5);

        let expected = Coefficients::new(naive_multiply(coeffs_a.clone(), coeffs_b.clone()));

        let poly_a = Coefficients(coeffs_a);
        let poly_b = Coefficients(coeffs_b);

        let evals_a = fft.dft(poly_a.clone());
        let evals_b = fft.dft(poly_b.clone());
        let pointwise = fft.idft(&evals_a * &evals_b);

        let via_poly_mul = fft.poly_mul(poly_a, poly_b);

        assert_eq!(expected, pointwise);
        assert_eq!(expected, via_poly_mul)
    }
}