extern crate alloc;
#[cfg(feature = "alloc")]
use super::{
cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive},
cosets::Coset,
twiddles::{get_twiddles, TwiddlesConfig},
};
use crate::{
fft::cpu::bit_reversing::in_place_bit_reverse_permute,
field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field},
};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Evaluates a circle polynomial over the standard circle coset using the CFFT.
///
/// `coeff[k]` is the coefficient of the `k`-th basis element. The evaluation
/// domain is `Coset::new_standard(log2(coeff.len()))`. The returned vector lists
/// the evaluations in the same order as `Coset::get_coset_points` for that coset,
/// which is what the tests in this module compare against.
///
/// `coeff.len()` is expected to be a power of two. The domain size is derived
/// from the length's trailing zeros, so other lengths are not supported. An
/// empty input yields an empty output, matching `interpolate_cfft`.
#[cfg(feature = "alloc")]
pub fn evaluate_cfft(
    mut coeff: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
    // Without this guard, `0usize.trailing_zeros()` returns `usize::BITS`
    // (64 on 64-bit targets). That would request an impossible coset size
    // instead of simply producing no evaluations.
    if coeff.is_empty() {
        return Vec::new();
    }
    let domain_log_2_size: u32 = coeff.len().trailing_zeros();
    let coset = Coset::new_standard(domain_log_2_size);
    let twiddles = get_twiddles(coset, TwiddlesConfig::Evaluation);
    // `cfft` is run on the coefficients in bit-reversed order. Its raw output
    // is then reordered to follow the coset point order.
    in_place_bit_reverse_permute::<FieldElement<Mersenne31Field>>(&mut coeff);
    cfft(&mut coeff, twiddles);
    order_cfft_result_naive(&coeff)
}
/// Interpolates evaluations on the standard circle coset back into coefficients.
///
/// This undoes `evaluate_cfft`: `eval` lists the values at the points of
/// `Coset::new_standard(log2(eval.len()))` in coset point order. The result
/// holds the coefficients in the same order that `evaluate_cfft` accepts them.
///
/// `eval.len()` is expected to be a power of two. An empty input yields an
/// empty coefficient vector.
#[cfg(feature = "alloc")]
pub fn interpolate_cfft(
    mut eval: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
    // There is nothing to interpolate, so return an empty coefficient vector.
    if eval.is_empty() {
        return Vec::new();
    }

    let twiddles = get_twiddles(
        Coset::new_standard(eval.len().trailing_zeros()),
        TwiddlesConfig::Interpolation,
    );

    // Reorder the evaluations for `icfft` and run the inverse butterflies.
    // Then bit-reverse the result so the coefficients come out in the order
    // `evaluate_cfft` consumes.
    let mut coeffs = order_icfft_input_naive(&mut eval);
    icfft(&mut coeffs, twiddles);
    in_place_bit_reverse_permute::<FieldElement<Mersenne31Field>>(&mut coeffs);

    // The inverse transform leaves every coefficient multiplied by the domain
    // size, so divide that factor back out. For a power-of-two length the factor
    // is non-zero modulo the field prime, so `inv` succeeds.
    let inv_domain_size = FieldElement::<Mersenne31Field>::from(eval.len() as u64)
        .inv()
        .unwrap();
    coeffs.into_iter().map(|c| c * inv_domain_size).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circle::cosets::Coset;
    use alloc::vec;

    type FE = FieldElement<Mersenne31Field>;

    /// Naively evaluates the circle polynomial with coefficients `coef` at `(x, y)`.
    ///
    /// `coef.len()` is a power of two. Coefficient `k` multiplies the product of
    /// the factors selected by the set bits of `k`:
    /// - bit 0 selects `y`;
    /// - bit 1 selects `v_1 = x`;
    /// - bit `j >= 2` selects `v_j = 2 * v_{j-1}^2 - 1`.
    fn evaluate_poly(coef: &[FE], x: FE, y: FE) -> FE {
        let bit_count = coef.len().trailing_zeros() as usize;
        let mut factors: Vec<FE> = Vec::with_capacity(bit_count);
        let mut v = x;
        for bit in 0..bit_count {
            if bit == 0 {
                factors.push(y);
            } else {
                factors.push(v);
                v = v.square().double() - FE::one();
            }
        }

        coef.iter()
            .enumerate()
            .map(|(k, c)| {
                let basis = factors
                    .iter()
                    .enumerate()
                    .filter(|&(bit, _)| ((k >> bit) & 1) == 1)
                    .fold(FE::one(), |acc, (_, f)| acc * *f);
                *c * basis
            })
            .fold(FE::zero(), |acc, term| acc + term)
    }

    /// Evaluates `coef` naively at every point of the standard coset whose size
    /// equals `coef.len()`, in coset point order.
    fn naive_evaluations(coef: &[FE]) -> Vec<FE> {
        let coset = Coset::new_standard(coef.len().trailing_zeros());
        Coset::get_coset_points(&coset)
            .into_iter()
            .map(|point| evaluate_poly(coef, point.x, point.y))
            .collect()
    }

    /// Returns the field elements `1, 2, ..., len`.
    fn fe_sequence(len: u64) -> Vec<FE> {
        (1..=len).map(FE::from).collect()
    }

    /// Checks that interpolating the CFFT evaluations of `coeff` recovers `coeff`.
    fn assert_round_trip(coeff: Vec<FE>) {
        let recovered = interpolate_cfft(evaluate_cfft(coeff.clone()));
        assert_eq!(coeff, recovered);
    }

    #[test]
    fn cfft_evaluation_4_points() {
        let input = fe_sequence(4);
        let expected = naive_evaluations(&input);
        assert_eq!(evaluate_cfft(input), expected);
    }

    #[test]
    fn cfft_evaluation_8_points() {
        let input = fe_sequence(8);
        let expected = naive_evaluations(&input);
        assert_eq!(evaluate_cfft(input), expected);
    }

    #[test]
    fn cfft_evaluation_16_points() {
        let input = fe_sequence(16);
        let expected = naive_evaluations(&input);
        assert_eq!(evaluate_cfft(input), expected);
    }

    #[test]
    fn evaluate_and_interpolate_8_points_is_identity() {
        assert_round_trip(fe_sequence(8));
    }

    #[test]
    fn evaluate_and_interpolate_8_other_points() {
        let coeff: Vec<FE> = vec![
            FE::from(2147483650),
            FE::from(147483647),
            FE::from(2147483700),
            FE::from(2147483647),
            FE::from(3147483647),
            FE::from(4147483647),
            FE::from(2147483640),
            FE::from(5147483647),
        ];
        assert_round_trip(coeff);
    }

    #[test]
    fn evaluate_and_interpolate_32_points() {
        assert_round_trip(fe_sequence(32));
    }

    #[test]
    fn evaluate_and_interpolate_2_pow_20_other_points() {
        let coeff: Vec<FE> = (0..2_u32.pow(20)).map(|i| FE::from(&i)).collect();
        assert_round_trip(coeff);
    }
}