1use crate::poly::Coefficients;
2use crate::util::batch_inversion;
3use crate::PointsValue;
4#[cfg(feature = "std")]
5use rayon::join;
6use zkstd::common::{vec, FftField, Vec};
7
8#[derive(Clone, Debug, Eq, PartialEq)]
10pub struct Fft<F: FftField> {
11 n: usize,
13 twiddle_factors: Vec<F>,
15 inv_twiddle_factors: Vec<F>,
17 cosets: Vec<F>,
19 inv_cosets: Vec<F>,
21 n_inv: F,
23 bit_reverse: Vec<(usize, usize)>,
25 pub elements: Vec<F>,
26}
27
28impl<F: FftField> Fft<F> {
30 pub fn new(k: usize) -> Self {
31 assert!(k >= 1);
32 let n = 1 << k;
33 let half_n = n >> 1;
34 let offset = 64 - k;
35
36 let g = (0..F::S - k).fold(F::ROOT_OF_UNITY, |acc, _| acc.square());
38 let twiddle_factors = (0..half_n)
39 .scan(F::one(), |w, _| {
40 let tw = *w;
41 *w *= g;
42 Some(tw)
43 })
44 .collect::<Vec<_>>();
45
46 let g_inv = g.invert().unwrap();
48 let inv_twiddle_factors = (0..half_n)
49 .scan(F::one(), |w, _| {
50 let tw = *w;
51 *w *= g_inv;
52 Some(tw)
53 })
54 .collect::<Vec<_>>();
55
56 let mul_g = F::MULTIPLICATIVE_GENERATOR;
58 let cosets = (0..n)
59 .scan(F::one(), |w, _| {
60 let tw = *w;
61 *w *= mul_g;
62 Some(tw)
63 })
64 .collect::<Vec<_>>();
65
66 let mul_g_inv = mul_g.invert().unwrap();
68 let inv_cosets = (0..n)
69 .scan(F::one(), |w, _| {
70 let tw = *w;
71 *w *= mul_g_inv;
72 Some(tw)
73 })
74 .collect::<Vec<_>>();
75
76 let elements = (0..n)
77 .scan(F::one(), |w, _| {
78 let tw = *w;
79 *w *= g;
80 Some(tw)
81 })
82 .collect::<Vec<_>>();
83
84 let bit_reverse = (0..n as u64)
85 .filter_map(|i| {
86 let r = i.reverse_bits() >> offset;
87 (i < r).then_some((i as usize, r as usize))
88 })
89 .collect::<Vec<_>>();
90
91 Self {
92 n,
93 twiddle_factors,
94 inv_twiddle_factors,
95 cosets,
96 inv_cosets,
97 n_inv: F::from(n as u64).invert().unwrap(),
98 bit_reverse,
99 elements,
100 }
101 }
102
103 pub fn size(&self) -> usize {
105 self.n
106 }
107
108 pub fn size_inv(&self) -> F {
110 self.n_inv
111 }
112
113 pub fn generator(&self) -> F {
115 self.twiddle_factors[1]
116 }
117
118 pub fn generator_inv(&self) -> F {
120 self.inv_twiddle_factors[1]
121 }
122
123 pub fn dft(&self, coeffs: Coefficients<F>) -> PointsValue<F> {
125 let mut evals = coeffs.0;
126 self.prepare_fft(&mut evals);
127 classic_fft_arithmetic(&mut evals, self.n, 1, &self.twiddle_factors);
128 PointsValue::new(evals.clone())
129 }
130
131 pub fn idft(&self, points: PointsValue<F>) -> Coefficients<F> {
133 let mut coeffs = points.0;
134 self.prepare_fft(&mut coeffs);
135 classic_fft_arithmetic(&mut coeffs, self.n, 1, &self.inv_twiddle_factors);
136 coeffs.iter_mut().for_each(|coeff| *coeff *= self.n_inv);
137 Coefficients::new(coeffs.clone())
138 }
139
140 pub fn coset_dft(&self, mut coeffs: Coefficients<F>) -> PointsValue<F> {
142 coeffs
143 .0
144 .iter_mut()
145 .zip(self.cosets.iter())
146 .for_each(|(coeff, coset)| *coeff *= *coset);
147 self.dft(coeffs)
148 }
149
150 pub fn coset_idft(&self, points: PointsValue<F>) -> Coefficients<F> {
152 let mut points = self.idft(points);
153 points
154 .0
155 .iter_mut()
156 .zip(self.inv_cosets.iter())
157 .for_each(|(coeff, inv_coset)| *coeff *= *inv_coset);
158 Coefficients::new(points.0)
159 }
160
161 pub fn z(&self, tau: &F) -> F {
164 let mut tmp = tau.pow(self.n as u64);
165 tmp.sub_assign(&F::one());
166
167 tmp
168 }
169
170 pub fn z_on_coset(&self) -> F {
173 let mut tmp = F::MULTIPLICATIVE_GENERATOR.pow(self.n as u64);
174 tmp.sub_assign(&F::one());
175
176 tmp
177 }
178
179 pub fn divide_by_z_on_coset(&self, points: PointsValue<F>) -> PointsValue<F> {
183 let i = self.z_on_coset().invert().unwrap();
184
185 PointsValue(points.0.into_iter().map(|v| v * i).collect())
186 }
187
188 fn prepare_fft(&self, coeffs: &mut Vec<F>) {
190 coeffs.resize(self.n, F::zero());
191 self.bit_reverse
192 .iter()
193 .for_each(|(i, ri)| coeffs.swap(*ri, *i));
194 }
195
196 pub fn poly_mul(&self, rhs: Coefficients<F>, lhs: Coefficients<F>) -> Coefficients<F> {
198 let rhs = self.dft(rhs);
199 let lhs = self.dft(lhs);
200 let mul_poly = PointsValue::new(
201 rhs.0
202 .iter()
203 .zip(lhs.0.iter())
204 .map(|(a, b)| *a * *b)
205 .collect(),
206 );
207 self.idft(mul_poly)
208 }
209
210 pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
213 let size = self.n;
215 let t_size = tau.pow(size as u64);
216 let one = F::one();
217 if t_size == F::one() {
218 let mut u = vec![F::zero(); size];
219 let mut omega_i = one;
220 for x in u.iter_mut().take(size) {
221 if omega_i == tau {
222 *x = one;
223 break;
224 }
225 omega_i *= &self.generator();
226 }
227 u
228 } else {
229 let mut l = (t_size - one) * self.n_inv;
230 let mut r = one;
231 let mut u = vec![F::zero(); size];
232 let mut ls = vec![F::zero(); size];
233 for i in 0..size {
234 u[i] = tau - r;
235 ls[i] = l;
236 l *= &self.generator();
237 r *= &self.generator();
238 }
239
240 batch_inversion(u.as_mut_slice());
241
242 u.iter()
243 .zip(ls)
244 .map(|(tau_minus_r, l)| l * *tau_minus_r)
245 .collect()
246 }
247 }
248
249 pub fn compute_vanishing_poly_over_coset(
253 &self, poly_degree: u64, ) -> PointsValue<F> {
256 assert!((self.size() as u64) > poly_degree);
257 let coset_gen = F::MULTIPLICATIVE_GENERATOR.pow(poly_degree);
258 let v_h: Vec<_> = (0..self.size())
259 .map(|i| (coset_gen * self.generator().pow(poly_degree * i as u64)) - F::one())
260 .collect();
261 PointsValue::new(v_h)
262 }
263}
264
265fn classic_fft_arithmetic<F: FftField>(
267 coeffs: &mut [F],
268 n: usize,
269 twiddle_chunk: usize,
270 twiddles: &[F],
271) {
272 if n == 2 {
273 let t = coeffs[1];
274 coeffs[1] = coeffs[0];
275 coeffs[0] += t;
276 coeffs[1] -= t;
277 } else {
278 let (left, right) = coeffs.split_at_mut(n / 2);
279 #[cfg(feature = "std")]
280 join(
281 || classic_fft_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
282 || classic_fft_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
283 );
284 #[cfg(not(feature = "std"))]
285 {
286 classic_fft_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles);
288 classic_fft_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles);
289 };
290 butterfly_arithmetic(left, right, twiddle_chunk, twiddles)
291 }
292}
293
294fn butterfly_arithmetic<F: FftField>(
296 left: &mut [F],
297 right: &mut [F],
298 twiddle_chunk: usize,
299 twiddles: &[F],
300) {
301 let t = right[0];
303 right[0] = left[0];
304 left[0] += t;
305 right[0] -= t;
306
307 left.iter_mut()
308 .zip(right.iter_mut())
309 .enumerate()
310 .skip(1)
311 .for_each(|(i, (a, b))| {
312 let mut t = *b;
313 t *= twiddles[i * twiddle_chunk];
314 *b = *a;
315 *a += t;
316 *b -= t;
317 });
318}
319
#[cfg(test)]
mod tests {
    use super::Fft;
    use crate::poly::Coefficients;

    use bls_12_381::Fr;
    use rand_core::OsRng;
    use zkstd::common::Vec;
    use zkstd::common::{Group, PrimeField};

    /// Samples `2^k` random scalars to use as polynomial coefficients.
    fn arb_poly(k: u32) -> Vec<Fr> {
        let len = 1usize << k;
        let mut coeffs = Vec::with_capacity(len);
        for _ in 0..len {
            coeffs.push(Fr::random(OsRng));
        }
        coeffs
    }

    /// Schoolbook polynomial product, used as the reference result.
    ///
    /// The output has `a.len() + b.len()` entries.
    fn naive_multiply<F: PrimeField>(a: Vec<F>, b: Vec<F>) -> Vec<F> {
        assert_eq!(a.len(), b.len());
        let mut product = vec![F::zero(); a.len() + b.len()];
        for (i, x) in a.iter().enumerate() {
            for (j, y) in b.iter().enumerate() {
                product[i + j] += *x * *y;
            }
        }
        product
    }

    /// A forward transform followed by an inverse transform returns the input.
    #[test]
    fn fft_transformation_test() {
        let classic_fft = Fft::new(10);
        let original = Coefficients(arb_poly(10));

        let evals = classic_fft.dft(original.clone());
        let recovered = classic_fft.idft(evals);

        assert_eq!(recovered, original)
    }

    /// Pointwise multiplication of evaluations and `poly_mul` both match the schoolbook
    /// product.
    #[test]
    fn fft_multiplication_test() {
        let fft = Fft::new(5);
        let coeffs_a = arb_poly(4);
        let coeffs_b = arb_poly(4);

        let expected = Coefficients::new(naive_multiply(coeffs_a.clone(), coeffs_b.clone()));

        let poly_a = Coefficients(coeffs_a);
        let poly_b = Coefficients(coeffs_b);

        let evals_a = fft.dft(poly_a.clone());
        let evals_b = fft.dft(poly_b.clone());
        let pointwise = fft.idft(&evals_a * &evals_b);

        let via_poly_mul = fft.poly_mul(poly_a, poly_b);

        assert_eq!(expected, pointwise);
        assert_eq!(expected, via_poly_mul)
    }
}