1use crate::fft::errors::FFTError;
2
3use crate::field::traits::{IsField, IsSubFieldOf};
4use crate::{
5 field::{
6 element::FieldElement,
7 traits::{IsFFTField, RootsConfig},
8 },
9 polynomial::Polynomial,
10};
11use alloc::{vec, vec::Vec};
12
13#[cfg(feature = "cuda")]
14use crate::fft::gpu::cuda::polynomial::{evaluate_fft_cuda, interpolate_fft_cuda};
15#[cfg(feature = "metal")]
16use crate::fft::gpu::metal::polynomial::{evaluate_fft_metal, interpolate_fft_metal};
17
18use super::cpu::{ops, roots_of_unity};
19
20impl<E: IsField> Polynomial<FieldElement<E>> {
21 pub fn evaluate_fft<F: IsFFTField + IsSubFieldOf<E>>(
26 poly: &Polynomial<FieldElement<E>>,
27 blowup_factor: usize,
28 domain_size: Option<usize>,
29 ) -> Result<Vec<FieldElement<E>>, FFTError> {
30 let domain_size = domain_size.unwrap_or(0);
31 let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
32
33 if poly.coefficients().is_empty() {
34 return Ok(vec![FieldElement::zero(); len]);
35 }
36
37 let mut coeffs = poly.coefficients().to_vec();
38 coeffs.resize(len, FieldElement::zero());
39 #[cfg(feature = "metal")]
42 {
43 if !F::field_name().is_empty() {
44 Ok(evaluate_fft_metal::<F, E>(&coeffs)?)
45 } else {
46 println!(
47 "GPU evaluation failed for field {}. Program will fallback to CPU.",
48 core::any::type_name::<F>()
49 );
50 evaluate_fft_cpu::<F, E>(&coeffs)
51 }
52 }
53
54 #[cfg(feature = "cuda")]
55 {
56 if F::field_name() == "stark256" {
58 Ok(evaluate_fft_cuda(&coeffs)?)
59 } else {
60 evaluate_fft_cpu::<F, E>(&coeffs)
61 }
62 }
63
64 #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
65 {
66 evaluate_fft_cpu::<F, E>(&coeffs)
67 }
68 }
69
70 pub fn evaluate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
75 poly: &Polynomial<FieldElement<E>>,
76 blowup_factor: usize,
77 domain_size: Option<usize>,
78 offset: &FieldElement<F>,
79 ) -> Result<Vec<FieldElement<E>>, FFTError> {
80 let scaled = poly.scale(offset);
81 Polynomial::evaluate_fft::<F>(&scaled, blowup_factor, domain_size)
82 }
83
84 pub fn interpolate_fft<F: IsFFTField + IsSubFieldOf<E>>(
88 fft_evals: &[FieldElement<E>],
89 ) -> Result<Self, FFTError> {
90 #[cfg(feature = "metal")]
91 {
92 if !F::field_name().is_empty() {
93 Ok(interpolate_fft_metal::<F, E>(fft_evals)?)
94 } else {
95 println!(
96 "GPU interpolation failed for field {}. Program will fallback to CPU.",
97 core::any::type_name::<F>()
98 );
99 interpolate_fft_cpu::<F, E>(fft_evals)
100 }
101 }
102
103 #[cfg(feature = "cuda")]
104 {
105 if !F::field_name().is_empty() {
106 Ok(interpolate_fft_cuda(fft_evals)?)
107 } else {
108 interpolate_fft_cpu::<F, E>(fft_evals)
109 }
110 }
111
112 #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
113 {
114 interpolate_fft_cpu::<F, E>(fft_evals)
115 }
116 }
117
118 pub fn interpolate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
122 fft_evals: &[FieldElement<E>],
123 offset: &FieldElement<F>,
124 ) -> Result<Polynomial<FieldElement<E>>, FFTError> {
125 let scaled = Polynomial::interpolate_fft::<F>(fft_evals)?;
126 Ok(scaled.scale(&offset.inv().unwrap()))
127 }
128}
129
130pub fn compose_fft<F, E>(
131 poly_1: &Polynomial<FieldElement<E>>,
132 poly_2: &Polynomial<FieldElement<E>>,
133) -> Polynomial<FieldElement<E>>
134where
135 F: IsFFTField + IsSubFieldOf<E>,
136 E: IsField,
137{
138 let poly_2_evaluations = Polynomial::evaluate_fft::<F>(poly_2, 1, None).unwrap();
139
140 let values: Vec<_> = poly_2_evaluations
141 .iter()
142 .map(|value| poly_1.evaluate(value))
143 .collect();
144
145 Polynomial::interpolate_fft::<F>(values.as_slice()).unwrap()
146}
147
148pub fn evaluate_fft_cpu<F, E>(coeffs: &[FieldElement<E>]) -> Result<Vec<FieldElement<E>>, FFTError>
149where
150 F: IsFFTField + IsSubFieldOf<E>,
151 E: IsField,
152{
153 let order = coeffs.len().trailing_zeros();
154 let twiddles = roots_of_unity::get_twiddles::<F>(order.into(), RootsConfig::BitReverse)?;
155 ops::fft(coeffs, &twiddles)
157}
158
159pub fn interpolate_fft_cpu<F, E>(
160 fft_evals: &[FieldElement<E>],
161) -> Result<Polynomial<FieldElement<E>>, FFTError>
162where
163 F: IsFFTField + IsSubFieldOf<E>,
164 E: IsField,
165{
166 let order = fft_evals.len().trailing_zeros();
167 let twiddles =
168 roots_of_unity::get_twiddles::<F>(order.into(), RootsConfig::BitReverseInversed)?;
169
170 let coeffs = ops::fft(fft_evals, &twiddles)?;
171
172 let scale_factor = FieldElement::from(fft_evals.len() as u64).inv().unwrap();
173 Ok(Polynomial::new(&coeffs).scale_coeffs(&scale_factor))
174}
175
#[cfg(test)]
mod tests {
    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    use crate::field::traits::IsField;

    use alloc::format;

    use crate::field::{
        test_fields::u64_test_field::{U64TestField, U64TestFieldExtension},
        traits::RootsConfig,
    };
    use proptest::{collection, prelude::*};

    use roots_of_unity::{get_powers_of_primitive_root, get_powers_of_primitive_root_coset};

    use super::*;

    /// Returns `(fft, naive)` evaluations of `poly`. The FFT uses `evaluate_fft` with blowup 1;
    /// the naive path evaluates at the natural-order powers of a primitive root of unity of
    /// order `poly.coeff_len().next_power_of_two()`.
    fn gen_fft_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = len.trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap();

        let fft_eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns `(fft, naive)` evaluations of `poly` over the coset generated from `offset`.
    /// The domain has length `poly.coeff_len().next_power_of_two() * blowup_factor`.
    fn gen_fft_coset_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
        offset: FieldElement<F>,
        blowup_factor: usize,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = (len * blowup_factor).trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset).unwrap();

        let fft_eval =
            Polynomial::evaluate_offset_fft::<F>(&poly, blowup_factor, None, &offset).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns `(fft, naive)` interpolations of `fft_evals`, treated as evaluations over the
    /// natural-order roots of unity of order `fft_evals.len()`.
    fn gen_fft_and_naive_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles =
            get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_fft::<F>(fft_evals).unwrap();

        (fft_poly, naive_poly)
    }

    /// Returns `(fft, naive)` interpolations of `fft_evals`, treated as evaluations over the
    /// coset generated from `offset`.
    fn gen_fft_and_naive_coset_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
        offset: &FieldElement<F>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap();

        (fft_poly, naive_poly)
    }

    /// Round-trips `poly` through `evaluate_fft` and `interpolate_fft`; returns
    /// `(original, round_tripped)`.
    fn gen_fft_interpolate_and_evaluate<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let new_poly = Polynomial::interpolate_fft::<F>(&eval).unwrap();

        (poly, new_poly)
    }

    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    mod u64_field_tests {
        use super::*;
        use crate::field::test_fields::u64_test_field::U64TestField;

        type F = U64TestField;
        type FE = FieldElement<F>;

        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // Offsets are drawn from a range that excludes zero, so they are invertible.
            fn offset()(num in 1..F::neg(&1)) -> FE { FE::from(num) }
        }
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation must agree with naive evaluation at the roots of unity.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // Coset FFT evaluation must agree with naive evaluation on the coset.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(6), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation must agree with naive interpolation at the roots of unity.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                .prop_filter("Avoid polynomials of size not power of two",
                    |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Coset FFT interpolation must agree with naive interpolation on the coset.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                .prop_filter("Avoid polynomials of size not power of two",
                    |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolating the FFT evaluations must give back the original polynomial.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(poly in poly(4)
                .prop_filter("Avoid polynomials of size not power of two",
                    |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);

                prop_assert_eq!(poly, new_poly);
            }
        }

        #[test]
        fn composition_fft_works() {
            let p = Polynomial::new(&[FE::new(0), FE::new(2)]);
            let q = Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(1)]);
            assert_eq!(
                compose_fft::<F, F>(&p, &q),
                Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(2)])
            );
        }
    }

    mod u256_field_tests {
        use super::*;
        use crate::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField;

        type F = Stark252PrimeField;
        type FE = FieldElement<F>;

        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // The base is kept non-zero so the offset is invertible: a field has no zero
            // divisors, so every power of a non-zero base is non-zero. A zero offset would
            // make every coset point zero, and inverting it in `interpolate_offset_fft` fails.
            fn offset()(num in any::<u64>().prop_filter("Avoid null offset", |x| x != &0), factor in any::<u64>()) -> FE { FE::from(num).pow(factor) }
        }
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation must agree with naive evaluation at the roots of unity.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // Coset FFT evaluation must agree with naive evaluation on the coset.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation must agree with naive interpolation at the roots of unity.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                .prop_filter("Avoid polynomials of size not power of two",
                    |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Coset FFT interpolation must agree with naive interpolation on the coset.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                .prop_filter("Avoid polynomials of size not power of two",
                    |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolating the FFT evaluations must give back the original polynomial.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(
                poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }
        }
    }

    /// Values live in a field extension while the FFT domain comes from the prime subfield;
    /// an offset evaluate/interpolate round trip must recover the original polynomial.
    #[test]
    fn test_fft_with_values_in_field_extension_over_domain_in_prime_field() {
        type TF = U64TestField;
        type TL = U64TestFieldExtension;

        let a = FieldElement::<TL>::from(&[FieldElement::one(), FieldElement::one()]);
        let b = FieldElement::<TL>::from(&[-FieldElement::from(2), FieldElement::from(17)]);
        let c = FieldElement::<TL>::one();
        let poly = Polynomial::new(&[a, b, c]);

        let eval = Polynomial::evaluate_offset_fft::<TF>(&poly, 8, Some(4), &FieldElement::from(2))
            .unwrap();
        let new_poly =
            Polynomial::interpolate_offset_fft::<TF>(&eval, &FieldElement::from(2)).unwrap();
        assert_eq!(poly, new_poly);
    }
}