1use crate::fft::errors::FFTError;
2
3use crate::field::errors::FieldError;
4use crate::field::traits::{IsField, IsSubFieldOf};
5use crate::{
6 field::{
7 element::FieldElement,
8 traits::{IsFFTField, RootsConfig},
9 },
10 polynomial::Polynomial,
11};
12use alloc::{vec, vec::Vec};
13
14#[cfg(feature = "cuda")]
15use crate::fft::gpu::cuda::polynomial::{evaluate_fft_cuda, interpolate_fft_cuda};
16
17use super::cpu::{ops, roots_of_unity};
18
19impl<E: IsField> Polynomial<FieldElement<E>> {
20 pub fn evaluate_fft<F: IsFFTField + IsSubFieldOf<E>>(
25 poly: &Polynomial<FieldElement<E>>,
26 blowup_factor: usize,
27 domain_size: Option<usize>,
28 ) -> Result<Vec<FieldElement<E>>, FFTError> {
29 let domain_size = domain_size.unwrap_or(0);
30 let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
31 if len.trailing_zeros() as u64 > F::TWO_ADICITY {
32 return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize));
33 }
34 if poly.coefficients().is_empty() {
35 return Ok(vec![FieldElement::zero(); len]);
36 }
37
38 let mut coeffs = poly.coefficients().to_vec();
39 coeffs.resize(len, FieldElement::zero());
40 #[cfg(feature = "cuda")]
43 {
44 if F::field_name() == "stark256" {
46 Ok(evaluate_fft_cuda(&coeffs)?)
47 } else {
48 evaluate_fft_cpu::<F, E>(&coeffs)
49 }
50 }
51
52 #[cfg(not(feature = "cuda"))]
53 {
54 evaluate_fft_cpu::<F, E>(&coeffs)
55 }
56 }
57
58 pub fn evaluate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
63 poly: &Polynomial<FieldElement<E>>,
64 blowup_factor: usize,
65 domain_size: Option<usize>,
66 offset: &FieldElement<F>,
67 ) -> Result<Vec<FieldElement<E>>, FFTError> {
68 let scaled = poly.scale(offset);
69 Polynomial::evaluate_fft::<F>(&scaled, blowup_factor, domain_size)
70 }
71
72 pub fn interpolate_fft<F: IsFFTField + IsSubFieldOf<E>>(
76 fft_evals: &[FieldElement<E>],
77 ) -> Result<Self, FFTError> {
78 #[cfg(feature = "cuda")]
79 {
80 if !F::field_name().is_empty() {
81 Ok(interpolate_fft_cuda(fft_evals)?)
82 } else {
83 interpolate_fft_cpu::<F, E>(fft_evals)
84 }
85 }
86
87 #[cfg(not(feature = "cuda"))]
88 {
89 interpolate_fft_cpu::<F, E>(fft_evals)
90 }
91 }
92
93 pub fn interpolate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
97 fft_evals: &[FieldElement<E>],
98 offset: &FieldElement<F>,
99 ) -> Result<Polynomial<FieldElement<E>>, FFTError> {
100 let scaled = Polynomial::interpolate_fft::<F>(fft_evals)?;
101 Ok(scaled.scale(&offset.inv().unwrap()))
102 }
103
104 pub fn fast_fft_multiplication<F: IsFFTField + IsSubFieldOf<E>>(
113 &self,
114 other: &Self,
115 ) -> Result<Self, FFTError> {
116 let domain_size = self.degree() + other.degree() + 1;
117 let p = Polynomial::evaluate_fft::<F>(self, 1, Some(domain_size))?;
118 let q = Polynomial::evaluate_fft::<F>(other, 1, Some(domain_size))?;
119 let r = p.into_iter().zip(q).map(|(a, b)| a * b).collect::<Vec<_>>();
120
121 Polynomial::interpolate_fft::<F>(&r)
122 }
123
124 pub fn fast_division<F: IsSubFieldOf<E> + IsFFTField>(
128 &self,
129 divisor: &Self,
130 ) -> Result<(Self, Self), FFTError> {
131 let n = self.degree();
132 let m = divisor.degree();
133 if divisor.coefficients.is_empty()
134 || divisor
135 .coefficients
136 .iter()
137 .all(|c| c == &FieldElement::zero())
138 {
139 return Err(FieldError::DivisionByZero.into());
140 }
141 if n < m {
142 return Ok((Self::zero(), self.clone()));
143 }
144 let d = n - m; let a_rev = self.reverse(n);
146 let b_rev = divisor.reverse(m);
147 let inv_b_rev = b_rev.invert_polynomial_mod::<F>(d + 1)?;
148 let q = a_rev
149 .fast_fft_multiplication::<F>(&inv_b_rev)?
150 .truncate(d + 1)
151 .reverse(d);
152
153 let r = self - q.fast_fft_multiplication::<F>(divisor)?;
154 Ok((q, r))
155 }
156
157 pub fn invert_polynomial_mod<F: IsSubFieldOf<E> + IsFFTField>(
160 &self,
161 k: usize,
162 ) -> Result<Self, FFTError> {
163 if self.coefficients.is_empty()
164 || self.coefficients.iter().all(|c| c == &FieldElement::zero())
165 {
166 return Err(FieldError::DivisionByZero.into());
167 }
168 let mut q = Self::new(&[self.coefficients[0].inv()?]);
169 let mut current_precision = 1;
170
171 let two = Self::new(&[FieldElement::<F>::one() + FieldElement::one()]);
172 while current_precision < k {
173 current_precision *= 2;
174 let temp = self
175 .fast_fft_multiplication::<F>(&q)?
176 .truncate(current_precision);
177 let correction = &two - temp;
178 q = q
179 .fast_fft_multiplication::<F>(&correction)?
180 .truncate(current_precision);
181 }
182
183 Ok(q.truncate(k))
185 }
186}
187
188pub fn compose_fft<F, E>(
189 poly_1: &Polynomial<FieldElement<E>>,
190 poly_2: &Polynomial<FieldElement<E>>,
191) -> Polynomial<FieldElement<E>>
192where
193 F: IsFFTField + IsSubFieldOf<E>,
194 E: IsField,
195{
196 let poly_2_evaluations = Polynomial::evaluate_fft::<F>(poly_2, 1, None).unwrap();
197
198 let values: Vec<_> = poly_2_evaluations
199 .iter()
200 .map(|value| poly_1.evaluate(value))
201 .collect();
202
203 Polynomial::interpolate_fft::<F>(values.as_slice()).unwrap()
204}
205
206pub fn evaluate_fft_cpu<F, E>(coeffs: &[FieldElement<E>]) -> Result<Vec<FieldElement<E>>, FFTError>
207where
208 F: IsFFTField + IsSubFieldOf<E>,
209 E: IsField,
210{
211 let order = coeffs.len().trailing_zeros();
212 let twiddles = roots_of_unity::get_twiddles::<F>(order.into(), RootsConfig::BitReverse)?;
213 ops::fft(coeffs, &twiddles)
215}
216
217pub fn interpolate_fft_cpu<F, E>(
218 fft_evals: &[FieldElement<E>],
219) -> Result<Polynomial<FieldElement<E>>, FFTError>
220where
221 F: IsFFTField + IsSubFieldOf<E>,
222 E: IsField,
223{
224 let order = fft_evals.len().trailing_zeros();
225 let twiddles =
226 roots_of_unity::get_twiddles::<F>(order.into(), RootsConfig::BitReverseInversed)?;
227
228 let coeffs = ops::fft(fft_evals, &twiddles)?;
229
230 let scale_factor = FieldElement::from(fft_evals.len() as u64).inv().unwrap();
231 Ok(Polynomial::new(&coeffs).scale_coeffs(&scale_factor))
232}
233
#[cfg(test)]
mod tests {
    // Imported only in the non-CUDA build; it appears to be needed by the u64 `offset()`
    // strategy (`F::neg(&1)`), whose module is likewise compiled only without `cuda`.
    #[cfg(not(feature = "cuda"))]
    use crate::field::traits::IsField;

    use crate::field::{
        test_fields::u64_test_field::{U64TestField, U64TestFieldExtension},
        traits::RootsConfig,
    };
    use proptest::{collection, prelude::*};

    use roots_of_unity::{get_powers_of_primitive_root, get_powers_of_primitive_root_coset};

    use super::*;

    /// Returns `(fft_eval, naive_eval)`: `poly` evaluated with `evaluate_fft` (blowup 1, no
    /// explicit domain size), and `poly` evaluated point by point at the natural-order powers
    /// of a primitive root of order `coeff_len().next_power_of_two()`.
    fn gen_fft_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = len.trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap();

        let fft_eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns `(fft_eval, naive_eval)` for coset evaluation: `evaluate_offset_fft` against
    /// point-by-point evaluation on the `offset`-shifted domain of size
    /// `coeff_len().next_power_of_two() * blowup_factor`.
    fn gen_fft_coset_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
        offset: FieldElement<F>,
        blowup_factor: usize,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = (len * blowup_factor).trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset).unwrap();

        let fft_eval =
            Polynomial::evaluate_offset_fft::<F>(&poly, blowup_factor, None, &offset).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns `(fft_poly, naive_poly)`: `interpolate_fft` of `fft_evals` against generic
    /// `Polynomial::interpolate` over the natural-order roots of unity. Assumes
    /// `fft_evals.len()` is a power of two (callers filter for it).
    fn gen_fft_and_naive_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles =
            get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_fft::<F>(fft_evals).unwrap();

        (fft_poly, naive_poly)
    }

    /// Coset variant of `gen_fft_and_naive_interpolate`: compares `interpolate_offset_fft`
    /// with generic interpolation over the `offset`-shifted roots of unity.
    fn gen_fft_and_naive_coset_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
        offset: &FieldElement<F>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap();

        (fft_poly, naive_poly)
    }

    /// Round-trips `poly` through `evaluate_fft` then `interpolate_fft` and returns
    /// `(original, round_tripped)`.
    fn gen_fft_interpolate_and_evaluate<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let new_poly = Polynomial::interpolate_fft::<F>(&eval).unwrap();

        (poly, new_poly)
    }

    // Property tests over the small u64 test field (skipped when the `cuda` feature is on).
    #[cfg(not(feature = "cuda"))]
    mod u64_field_tests {
        use super::*;
        use crate::field::test_fields::u64_test_field::U64TestField;

        type F = U64TestField;
        type FE = FieldElement<F>;

        // Powers of two 2^1 .. 2^(max_exp - 1).
        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        // Non-zero field elements built from a random u64.
        prop_compose! {
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        // Coset offsets drawn from 1 .. F::neg(&1) (exclusive upper bound), so never zero.
        prop_compose! {
            fn offset()(num in 1..F::neg(&1)) -> FE { FE::from(num) }
        }
        // Vectors of length 0 .. 2^max_exp (exclusive).
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        // Vectors of exactly 2^max_exp non-zero elements.
        prop_compose! {
            fn non_empty_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        // Vectors whose length is at least 2 and not a power of two.
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn non_zero_poly(max_exp: u8)(coeffs in non_empty_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation agrees with naive evaluation at the roots of unity.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // Coset FFT evaluation (with blowup) agrees with naive coset evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(6), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation agrees with generic interpolation (power-of-two sizes only).
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Coset FFT interpolation agrees with generic coset interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // interpolate_fft(evaluate_fft(p)) == p.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(poly in poly(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);

                prop_assert_eq!(poly, new_poly);
            }

            // FFT multiplication agrees with schoolbook multiplication.
            #[test]
            fn test_fft_multiplication_works(poly in poly(7), other in poly(7)) {
                prop_assert_eq!(poly.fast_fft_multiplication::<F>(&other).unwrap(), poly * other);
            }

            // Fast division agrees with long division (quotient and remainder).
            #[test]
            fn test_fft_division_works(poly in non_zero_poly(7), other in non_zero_poly(7)) {
                prop_assert_eq!(poly.fast_division::<F>(&other).unwrap(), poly.long_division_with_remainder(&other));
            }

            // poly * invert_polynomial_mod(poly, k) ≡ 1 (mod x^k).
            #[test]
            fn test_invert_polynomial_mod_works(poly in non_zero_poly(7), k in powers_of_two(4)) {
                let inverted_poly = poly.invert_polynomial_mod::<F>(k).unwrap();
                prop_assert_eq!((poly * inverted_poly).truncate(k), Polynomial::new(&[FE::one()]));
            }
        }

        // (2x) ∘ (x^3) == 2x^3.
        #[test]
        fn composition_fft_works() {
            let p = Polynomial::new(&[FE::new(0), FE::new(2)]);
            let q = Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(1)]);
            assert_eq!(
                compose_fft::<F, F>(&p, &q),
                Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(2)])
            );
        }
    }

    // The same property tests over the Stark252 prime field.
    mod u256_field_tests {
        use super::*;
        use crate::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField;

        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        // Offsets are num^factor; unlike the u64 strategy, this can yield zero when num == 0.
        prop_compose! {
            fn offset()(num in any::<u64>(), factor in any::<u64>()) -> FE { FE::from(num).pow(factor) }
        }
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_empty_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn non_zero_poly(max_exp: u8)(coeffs in non_empty_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        type F = Stark252PrimeField;
        type FE = FieldElement<F>;

        proptest! {
            // FFT evaluation agrees with naive evaluation at the roots of unity.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // Coset FFT evaluation (with blowup) agrees with naive coset evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation agrees with generic interpolation (power-of-two sizes only).
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Coset FFT interpolation agrees with generic coset interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // interpolate_fft(evaluate_fft(p)) == p.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(
                poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }

            // FFT multiplication agrees with schoolbook multiplication.
            #[test]
            fn test_fft_multiplication_works(poly in poly(7), other in poly(7)) {
                prop_assert_eq!(poly.fast_fft_multiplication::<F>(&other).unwrap(), poly * other);
            }

            // Fast division agrees with long division; the dividend may be the zero polynomial.
            #[test]
            fn test_fft_division_works(poly in poly(7), other in non_zero_poly(7)) {
                prop_assert_eq!(poly.fast_division::<F>(&other).unwrap(), poly.long_division_with_remainder(&other));
            }

            // poly * invert_polynomial_mod(poly, k) ≡ 1 (mod x^k).
            #[test]
            fn test_invert_polynomial_mod_works(poly in non_zero_poly(7), k in powers_of_two(4)) {
                let inverted_poly = poly.invert_polynomial_mod::<F>(k).unwrap();
                prop_assert_eq!((poly * inverted_poly).truncate(k), Polynomial::new(&[FE::one()]));
            }
        }
    }

    // Coefficients in the extension field, FFT domain in the base prime field: a coset
    // evaluate/interpolate round trip must return the original polynomial.
    #[test]
    fn test_fft_with_values_in_field_extension_over_domain_in_prime_field() {
        type TF = U64TestField;
        type TL = U64TestFieldExtension;

        let a = FieldElement::<TL>::from(&[FieldElement::one(), FieldElement::one()]);
        let b = FieldElement::<TL>::from(&[-FieldElement::from(2), FieldElement::from(17)]);
        let c = FieldElement::<TL>::one();
        let poly = Polynomial::new(&[a, b, c]);

        let eval = Polynomial::evaluate_offset_fft::<TF>(&poly, 8, Some(4), &FieldElement::from(2))
            .unwrap();
        let new_poly =
            Polynomial::interpolate_offset_fft::<TF>(&eval, &FieldElement::from(2)).unwrap();
        assert_eq!(poly, new_poly);
    }
}