use crate::field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};
use std::convert::TryFrom;
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum FftError {
#[error("output slice is smaller than specified size")]
OutputTooSmall,
#[error("size is larger than than maximum permitted")]
SizeTooLarge,
#[error("size is not a power of 2")]
SizeInvalid,
}
/// Computes the discrete Fourier transform of `inp` over the field `F` and writes it to the
/// first `size` elements of `outp`.
///
/// `inp` is treated as a vector of length `size`. If it is shorter, the missing elements are
/// taken to be zero; elements at index `size` and beyond are ignored.
///
/// The transform is an in-place iterative radix-2 algorithm:
/// 1. Copy the input into `outp` in bit-reversed index order.
/// 2. Apply `log2(size)` levels of butterflies. Level `l` uses the field element returned by
///    `F::root(l)` (presumably a principal `2^l`-th root of unity — see `FieldElement::root`).
///
/// # Errors
///
/// * [`FftError::SizeInvalid`] if `size` is zero or not a power of two.
/// * [`FftError::OutputTooSmall`] if `outp.len() < size`.
/// * [`FftError::SizeTooLarge`] if `size > 2^MAX_ROOTS`, or if `F::root` returns `None` for a
///   level that is needed. In the latter case `outp` may already be partially overwritten.
#[allow(clippy::many_single_char_names)]
pub fn discrete_fourier_transform<F: FieldElement>(
    outp: &mut [F],
    inp: &[F],
    size: usize,
) -> Result<(), FftError> {
    // Zero is not a power of two; reject it before taking its logarithm.
    if size == 0 {
        return Err(FftError::SizeInvalid);
    }
    let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?;

    if size > outp.len() {
        return Err(FftError::OutputTooSmall);
    }

    // This check must precede `1 << d` below so the shift cannot overflow.
    if size > 1 << MAX_ROOTS {
        return Err(FftError::SizeTooLarge);
    }

    if size != 1 << d {
        return Err(FftError::SizeInvalid);
    }

    // Bit-reversal permutation. Indices past the end of `inp` are zero-padded.
    for (i, out) in outp[..size].iter_mut().enumerate() {
        *out = inp.get(bitrev(d, i)).copied().unwrap_or_else(F::zero);
    }

    // Butterfly levels. At level `l`, elements `half` apart within each block of length
    // `2 * half` are combined using successive powers of the level's root `r`.
    for l in 1..=d {
        let r = F::root(l).ok_or(FftError::SizeTooLarge)?;
        let half = 1 << (l - 1);
        let mut w = F::one();
        for i in 0..half {
            for j in 0..(size / half) >> 1 {
                let x = (1 << l) * j + i;
                let u = outp[x];
                let v = w * outp[x + half];
                outp[x] = u + v;
                outp[x + half] = u - v;
            }
            w *= r;
        }
    }

    Ok(())
}
#[cfg(test)]
/// Computes the inverse DFT of `inp` (zero-padded or truncated to `size` elements) into the
/// first `size` elements of `outp`.
///
/// This runs a forward [`discrete_fourier_transform`] and then
/// [`discrete_fourier_transform_inv_finish`]. With the scale factor passed here, the finishing
/// step multiplies by `1 / size` and moves index `i` to index `size - i`.
///
/// # Errors
///
/// Propagates any error returned by [`discrete_fourier_transform`].
///
/// # Panics
///
/// Panics if `size` cannot be converted into `F::Integer`.
pub(crate) fn discrete_fourier_transform_inv<F: FieldElement>(
    outp: &mut [F],
    inp: &[F],
    size: usize,
) -> Result<(), FftError> {
    let n = F::from(F::Integer::try_from(size).unwrap());
    let n_inv = n.inv();
    discrete_fourier_transform(outp, inp, size)
        .map(|()| discrete_fourier_transform_inv_finish(outp, size, n_inv))
}
/// Completes an inverse DFT on the first `size` elements of `outp`. Those elements are expected
/// to hold the output of a forward DFT of length `size`.
///
/// Every element is multiplied by `size_inv`. Within this module it is called with the inverse
/// of `size` in `F`. For `0 < i < size`, the element at index `i` is also moved to index
/// `size - i`; index 0 stays in place.
///
/// Each element is scaled exactly once, for any `size`. A `size` of zero is a no-op.
///
/// # Panics
///
/// Panics if `outp.len() < size`.
pub(crate) fn discrete_fourier_transform_inv_finish<F: FieldElement>(
    outp: &mut [F],
    size: usize,
    size_inv: F,
) {
    if size == 0 {
        return;
    }

    outp[0] *= size_inv;

    // Swap and scale each mirrored pair (i, size - i) with i < size - i.
    for i in 1..=(size - 1) / 2 {
        let lo = outp[i] * size_inv;
        outp[i] = outp[size - i] * size_inv;
        outp[size - i] = lo;
    }

    // For even sizes, index size / 2 is its own mirror, so scale it once here. Skipping this for
    // odd sizes keeps index 0 from being scaled twice when `size == 1`.
    if size % 2 == 0 {
        outp[size / 2] *= size_inv;
    }
}
/// Reverses the order of the low `d` bits of `x`. Bits of `x` at position `d` and above are
/// ignored. For example, `bitrev(3, 0b110) == 0b011`.
fn bitrev(d: usize, x: usize) -> usize {
    // Bit `bit` of `x` is first placed at position `d - bit`, one above its final position
    // `d - 1 - bit`. The final shift moves every bit down into place.
    let spread: usize = (0..d).map(|bit| ((x >> bit) & 1) << (d - bit)).sum();
    spread >> 1
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::{
        random_vector, split_vector, Field128, Field32, Field64, Field96, FieldPrio2,
    };
    use crate::polynomial::{poly_fft, PolyAuxMemory};

    /// Checks that the inverse DFT undoes the forward DFT for several power-of-two sizes.
    fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
        let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];
        for &size in test_sizes.iter() {
            let mut tmp = vec![F::zero(); size];
            let mut got = vec![F::zero(); size];
            let want = random_vector(size).unwrap();

            discrete_fourier_transform(&mut tmp, &want, want.len())?;
            discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?;

            assert_eq!(got, want);
        }
        Ok(())
    }

    #[test]
    fn test_field32() {
        discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error");
    }

    #[test]
    fn test_prio2_field32() {
        discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error");
    }

    #[test]
    fn test_field64() {
        discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
    }

    #[test]
    fn test_field96() {
        discrete_fourier_transform_then_inv_test::<Field96>().expect("unexpected error");
    }

    #[test]
    fn test_field128() {
        discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error");
    }

    /// Cross-checks the iterative DFT against the recursive `poly_fft` implementation.
    #[test]
    fn test_recursive_fft() {
        let size = 128;
        let mut mem = PolyAuxMemory::new(size / 2);
        let inp = random_vector(size).unwrap();
        let mut want = vec![Field32::zero(); size];
        let mut got = vec![Field32::zero(); size];

        discrete_fourier_transform::<Field32>(&mut want, &inp, inp.len()).unwrap();
        poly_fft(
            &mut got,
            &inp,
            &mem.roots_2n,
            size,
            false,
            &mut mem.fft_memory,
        );

        assert_eq!(got, want);
    }

    /// Checks that the inverse DFT is linear: summing the transforms of additive shares equals
    /// the transform of the shared vector.
    #[test]
    fn test_fft_linearity() {
        let len = 16;
        let num_shares = 3;
        let x: Vec<Field64> = random_vector(len).unwrap();
        let mut x_shares = split_vector(&x, num_shares).unwrap();

        // At odd indices, share 0 carries the full value of `x` and every other share is zero,
        // so the shares still sum to `x` at every index.
        for i in (1..len).step_by(2) {
            x_shares[0][i] = x[i];
            for share in x_shares.iter_mut().skip(1) {
                share[i] = Field64::zero();
            }
        }

        let mut got = vec![Field64::zero(); len];
        let mut buf = vec![Field64::zero(); len];
        for share in &x_shares {
            discrete_fourier_transform_inv(&mut buf, share, len).unwrap();
            for (acc, term) in got.iter_mut().zip(&buf) {
                *acc += *term;
            }
        }

        let mut want = vec![Field64::zero(); len];
        discrete_fourier_transform_inv(&mut want, &x, len).unwrap();

        assert_eq!(got, want);
    }
}