use crate::fft::errors::FFTError;
use crate::{
field::{
element::FieldElement,
traits::{IsFFTField, RootsConfig},
},
polynomial::Polynomial,
};
#[cfg(feature = "cuda")]
use crate::fft::gpu::cuda::polynomial::{evaluate_fft_cuda, interpolate_fft_cuda};
#[cfg(feature = "metal")]
use crate::fft::gpu::metal::polynomial::{evaluate_fft_metal, interpolate_fft_metal};
use super::cpu::{ops, roots_of_unity};
/// FFT-based evaluation and interpolation for polynomials over an FFT-friendly field `F`.
pub trait FFTPoly<F: IsFFTField> {
/// Evaluates the polynomial with an FFT and returns the evaluations.
///
/// In the `Polynomial` implementation in this file the number of evaluations returned is
/// `max(coeff_len, domain_size.unwrap_or(0)).next_power_of_two() * blowup_factor`.
/// The tests in this file check the output against a naive evaluation over the powers
/// of a primitive root of unity in natural order.
fn evaluate_fft(
&self,
blowup_factor: usize,
domain_size: Option<usize>,
) -> Result<Vec<FieldElement<F>>, FFTError>;
/// Evaluates the polynomial with an FFT over a domain shifted by `offset`.
///
/// In the `Polynomial` implementation in this file this scales the polynomial by
/// `offset` and then calls `evaluate_fft` with the same `blowup_factor` and
/// `domain_size`.
fn evaluate_offset_fft(
&self,
blowup_factor: usize,
domain_size: Option<usize>,
offset: &FieldElement<F>,
) -> Result<Vec<FieldElement<F>>, FFTError>;
/// Builds a polynomial from the evaluations `fft_evals` using an inverse FFT.
///
/// The tests in this file exercise it as the inverse of `evaluate_fft` for inputs
/// whose length is a power of two.
fn interpolate_fft(
fft_evals: &[FieldElement<F>],
) -> Result<Polynomial<FieldElement<F>>, FFTError>;
/// Builds a polynomial from evaluations `fft_evals` taken over a domain shifted by
/// `offset`.
///
/// In the `Polynomial` implementation in this file this interpolates with
/// `interpolate_fft` and scales the result by the inverse of `offset`; that
/// implementation unwraps the inverse, so it panics when `offset` has none.
fn interpolate_offset_fft(
fft_evals: &[FieldElement<F>],
offset: &FieldElement<F>,
) -> Result<Polynomial<FieldElement<F>>, FFTError>;
}
impl<F: IsFFTField> FFTPoly<F> for Polynomial<FieldElement<F>> {
    /// Evaluates this polynomial with an FFT.
    ///
    /// The number of evaluations returned is
    /// `max(self.coeff_len(), domain_size.unwrap_or(0)).next_power_of_two() * blowup_factor`.
    /// A polynomial without coefficients yields that many zeros without running an FFT.
    ///
    /// With the `metal` or `cuda` feature enabled, the GPU implementation is used when the
    /// field passes the feature-specific name check; otherwise the CPU implementation runs.
    ///
    /// # Errors
    ///
    /// Propagates any `FFTError` reported by the selected FFT backend.
    fn evaluate_fft(
        &self,
        blowup_factor: usize,
        domain_size: Option<usize>,
    ) -> Result<Vec<FieldElement<F>>, FFTError> {
        let domain_size = domain_size.unwrap_or(0);
        let len = std::cmp::max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor;

        if self.coefficients().is_empty() {
            return Ok(vec![FieldElement::zero(); len]);
        }

        // The coefficients are zero-padded up to `len` before being handed to the FFT.
        let mut coeffs = self.coefficients().to_vec();
        coeffs.resize(len, FieldElement::zero());

        #[cfg(feature = "metal")]
        {
            if !F::field_name().is_empty() {
                Ok(evaluate_fft_metal(&coeffs)?)
            } else {
                println!(
                    "GPU evaluation failed for field {}. Program will fallback to CPU.",
                    std::any::type_name::<F>()
                );
                evaluate_fft_cpu(&coeffs)
            }
        }

        #[cfg(feature = "cuda")]
        {
            // Only the field named "stark256" is dispatched to CUDA.
            if F::field_name() == "stark256" {
                Ok(evaluate_fft_cuda(&coeffs)?)
            } else {
                evaluate_fft_cpu(&coeffs)
            }
        }

        #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
        {
            evaluate_fft_cpu(&coeffs)
        }
    }

    /// Evaluates this polynomial with an FFT over a domain shifted by `offset`.
    ///
    /// The polynomial is first scaled by `offset` and the scaled polynomial is then
    /// evaluated with [`FFTPoly::evaluate_fft`] using the same `blowup_factor` and
    /// `domain_size`.
    fn evaluate_offset_fft(
        &self,
        blowup_factor: usize,
        domain_size: Option<usize>,
        offset: &FieldElement<F>,
    ) -> Result<Vec<FieldElement<F>>, FFTError> {
        let scaled = self.scale(offset);
        scaled.evaluate_fft(blowup_factor, domain_size)
    }

    /// Builds the polynomial whose FFT evaluations are `fft_evals`.
    ///
    /// With the `metal` or `cuda` feature enabled, the GPU implementation is used when the
    /// field passes the feature-specific name check; otherwise the CPU implementation runs.
    ///
    /// # Errors
    ///
    /// Propagates any `FFTError` reported by the selected FFT backend.
    fn interpolate_fft(fft_evals: &[FieldElement<F>]) -> Result<Self, FFTError> {
        #[cfg(feature = "metal")]
        {
            if !F::field_name().is_empty() {
                Ok(interpolate_fft_metal(fft_evals)?)
            } else {
                println!(
                    "GPU interpolation failed for field {}. Program will fallback to CPU.",
                    std::any::type_name::<F>()
                );
                interpolate_fft_cpu(fft_evals)
            }
        }

        #[cfg(feature = "cuda")]
        {
            // Same eligibility check as `evaluate_fft`, so every other field falls back
            // to the CPU instead of being sent to CUDA.
            if F::field_name() == "stark256" {
                Ok(interpolate_fft_cuda(fft_evals)?)
            } else {
                interpolate_fft_cpu(fft_evals)
            }
        }

        #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
        {
            interpolate_fft_cpu(fft_evals)
        }
    }

    /// Builds the polynomial whose FFT evaluations over a domain shifted by `offset` are
    /// `fft_evals`.
    ///
    /// The evaluations are interpolated with [`FFTPoly::interpolate_fft`] and the result
    /// is scaled by the inverse of `offset`, undoing the scaling applied by
    /// [`FFTPoly::evaluate_offset_fft`].
    ///
    /// # Panics
    ///
    /// Panics if `offset` has no inverse (e.g. it is zero).
    fn interpolate_offset_fft(
        fft_evals: &[FieldElement<F>],
        offset: &FieldElement<F>,
    ) -> Result<Polynomial<FieldElement<F>>, FFTError> {
        let scaled = Polynomial::interpolate_fft(fft_evals)?;
        let offset_inverse = offset
            .inv()
            .expect("coset offset must be invertible (non-zero)");
        Ok(scaled.scale(&offset_inverse))
    }
}
/// Composes two polynomials, returning `poly_1(poly_2(x))`, via FFT evaluation and
/// interpolation.
///
/// `poly_2` is evaluated with an FFT, `poly_1` is evaluated at each of the resulting
/// values, and those values are interpolated back into a polynomial.
///
/// The composition has degree at most `deg(poly_1) * deg(poly_2)`, so the evaluation
/// domain is requested with at least `deg(poly_1) * deg(poly_2) + 1` points. With fewer
/// points the interpolation would only recover the composition reduced modulo `x^n - 1`
/// (for instance `x^2` composed with `x` over two points would come back as `1`).
///
/// # Panics
///
/// Panics if the FFT evaluation of `poly_2` or the final interpolation returns an error.
pub fn compose_fft<F>(
    poly_1: &Polynomial<FieldElement<F>>,
    poly_2: &Polynomial<FieldElement<F>>,
) -> Polynomial<FieldElement<F>>
where
    F: IsFFTField,
{
    // `coeff_len() - 1` is an upper bound on each degree; an over-estimate only enlarges
    // the domain, which keeps the interpolation exact.
    let degree_bound_1 = poly_1.coeff_len().saturating_sub(1);
    let degree_bound_2 = poly_2.coeff_len().saturating_sub(1);
    let required_points = degree_bound_1 * degree_bound_2 + 1;

    let poly_2_evaluations = poly_2.evaluate_fft(1, Some(required_points)).unwrap();

    let values: Vec<_> = poly_2_evaluations
        .iter()
        .map(|value| poly_1.evaluate(value))
        .collect();

    Polynomial::interpolate_fft(values.as_slice()).unwrap()
}
/// Runs the CPU FFT over `coeffs` and returns the resulting evaluations.
///
/// The twiddle factors are requested in `RootsConfig::BitReverse` order for an order equal
/// to the number of trailing zero bits of `coeffs.len()` (its base-2 logarithm when the
/// length is a power of two).
///
/// # Errors
///
/// Returns an error if the twiddle factors cannot be generated or if the FFT itself fails.
pub fn evaluate_fft_cpu<F>(coeffs: &[FieldElement<F>]) -> Result<Vec<FieldElement<F>>, FFTError>
where
    F: IsFFTField,
{
    let log_size = coeffs.len().trailing_zeros();
    let twiddle_factors = roots_of_unity::get_twiddles(log_size.into(), RootsConfig::BitReverse)?;

    ops::fft(coeffs, &twiddle_factors)
}
/// Interpolates `fft_evals` on the CPU, returning the polynomial with those evaluations.
///
/// An FFT is run with twiddle factors in `RootsConfig::BitReverseInversed` order, and the
/// resulting coefficients are scaled by the inverse of `fft_evals.len()`.
///
/// # Errors
///
/// Returns an error if the twiddle factors cannot be generated or if the FFT itself fails.
///
/// # Panics
///
/// Panics if `fft_evals.len()`, seen as a field element, has no inverse.
pub fn interpolate_fft_cpu<F>(
    fft_evals: &[FieldElement<F>],
) -> Result<Polynomial<FieldElement<F>>, FFTError>
where
    F: IsFFTField,
{
    let eval_count = fft_evals.len();
    let log_size = eval_count.trailing_zeros();

    let inverse_twiddles =
        roots_of_unity::get_twiddles(log_size.into(), RootsConfig::BitReverseInversed)?;
    let unscaled_coeffs = ops::fft(fft_evals, &inverse_twiddles)?;

    // The inverse transform needs a final division by the number of points; this is done
    // only after the FFT succeeded, so invalid inputs surface as errors first.
    let count_inverse = FieldElement::from(eval_count as u64).inv().unwrap();
    let interpolated = Polynomial::new(&unscaled_coeffs).scale_coeffs(&count_inverse);

    Ok(interpolated)
}
#[cfg(test)]
mod tests {
    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    use crate::field::traits::IsField;

    use crate::field::traits::RootsConfig;
    use proptest::{collection, prelude::*};

    use roots_of_unity::{get_powers_of_primitive_root, get_powers_of_primitive_root_coset};

    use super::*;

    /// Returns the FFT evaluation of `poly` together with a naive evaluation over the
    /// powers of a primitive root of unity in natural order.
    fn gen_fft_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = len.trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap();

        let fft_eval = poly.evaluate_fft(1, None).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns the offset FFT evaluation of `poly` (with the given blowup factor) together
    /// with a naive evaluation over the corresponding coset.
    fn gen_fft_coset_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
        offset: FieldElement<F>,
        blowup_factor: usize,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = (len * blowup_factor).trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset).unwrap();

        let fft_eval = poly
            .evaluate_offset_fft(blowup_factor, None, &offset)
            .unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    /// Returns the FFT interpolation of `fft_evals` together with a naive interpolation
    /// over the powers of a primitive root of unity in natural order.
    fn gen_fft_and_naive_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles =
            get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_fft(fft_evals).unwrap();

        (fft_poly, naive_poly)
    }

    /// Returns the offset FFT interpolation of `fft_evals` together with a naive
    /// interpolation over the corresponding coset.
    fn gen_fft_and_naive_coset_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
        offset: &FieldElement<F>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap();

        (fft_poly, naive_poly)
    }

    /// Evaluates `poly` with an FFT and interpolates the result back, returning the
    /// original polynomial and the round-tripped one.
    fn gen_fft_interpolate_and_evaluate<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let eval = poly.evaluate_fft(1, None).unwrap();
        let new_poly = Polynomial::interpolate_fft(&eval).unwrap();

        (poly, new_poly)
    }

    // These tests only run on the CPU backend (neither `metal` nor `cuda` enabled).
    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    mod u64_field_tests {
        use super::*;
        use crate::field::test_fields::u64_test_field::U64TestField;

        // FFT related tests
        type F = U64TestField;
        type FE = FieldElement<F>;

        prop_compose! {
            // Powers of two `2^exp` with `1 <= exp < max_exp`.
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            // Field elements built from a non-zero `u64`.
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // Offsets built from the range `1..F::neg(&1)`, so the raw value is never zero.
            fn offset()(num in 1..F::neg(&1)) -> FE { FE::from(num) }
        }
        prop_compose! {
            // Vectors of field elements with fewer than `2^max_exp` entries (possibly empty).
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            // Vectors of field elements whose length is not a power of two.
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1 << max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            // Polynomials built from `field_vec` coefficients.
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            // Polynomials built from a non-power-of-two number of coefficients.
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation gives the same result as a naive polynomial evaluation.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT evaluation over a coset gives the same result as a naive evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(6), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation gives the same result as a naive interpolation.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // FFT interpolation over a coset gives the same result as a naive interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolation is the inverse operation of evaluation.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(poly in poly(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }
        }

        // Composes `2x` with `x^3` and expects `2x^3`.
        #[test]
        fn composition_fft_works() {
            let p = Polynomial::new(&[FE::new(0), FE::new(2)]);
            let q = Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(1)]);
            assert_eq!(
                compose_fft(&p, &q),
                Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(2)])
            );
        }
    }

    mod u256_field_tests {
        use super::*;
        use crate::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField;

        prop_compose! {
            // Powers of two `2^exp` with `1 <= exp < max_exp`.
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            // Field elements built from a non-zero `u64`.
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // The base is filtered to be non-zero, like `field_element`, so that the offset
            // stays invertible; `interpolate_offset_fft` panics on an offset without inverse.
            fn offset()(
                num in any::<u64>().prop_filter("Avoid null offset", |x| x != &0),
                factor in any::<u64>()
            ) -> FE { FE::from(num).pow(factor) }
        }
        prop_compose! {
            // Vectors of field elements with fewer than `2^max_exp` entries (possibly empty).
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            // Vectors of field elements whose length is not a power of two.
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1 << max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            // Polynomials built from `field_vec` coefficients.
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            // Polynomials built from a non-power-of-two number of coefficients.
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        // FFT related tests
        type F = Stark252PrimeField;
        type FE = FieldElement<F>;

        proptest! {
            // FFT evaluation gives the same result as a naive polynomial evaluation.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT evaluation over a coset gives the same result as a naive evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation gives the same result as a naive interpolation.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // FFT interpolation over a coset gives the same result as a naive interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolation is the inverse operation of evaluation.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(
                poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }
        }
    }
}