mod cooley_tukey;
mod matrix;
#[cfg(test)]
pub(crate) mod test_utils;
mod transpose;
mod utils;
mod wavelet;
use std::{
fmt::Debug,
marker::PhantomData,
sync::{Arc, LazyLock},
};
use ark_ff::{FftField, Field};
use static_assertions::assert_obj_safe;
#[cfg(feature = "tracing")]
use tracing::instrument;
use self::matrix::MatrixMut;
pub use self::{
cooley_tukey::{generator, intt, intt_batch, ntt, ntt_batch},
transpose::transpose,
wavelet::{inverse_wavelet_transform, wavelet_transform},
};
use crate::{
algebra::fields,
type_map::{self, TypeMap},
};
pub static NTT: LazyLock<TypeMap<NttFamily>> = LazyLock::new(|| {
let map = TypeMap::new();
map.insert(Arc::new(ArkNtt::<fields::Field64>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(Arc::new(ArkNtt::<fields::Field128>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(Arc::new(ArkNtt::<fields::Field192>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(Arc::new(ArkNtt::<fields::Field256>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(Arc::new(ArkNtt::<fields::Field64_2>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(Arc::new(ArkNtt::<fields::Field64_3>::default()) as Arc<dyn ReedSolomon<_>>);
map.insert(
Arc::new(ArkNtt::<<fields::Field64_2 as Field>::BasePrimeField>::default())
as Arc<dyn ReedSolomon<_>>,
);
map.insert(
Arc::new(ArkNtt::<<fields::Field64_3 as Field>::BasePrimeField>::default())
as Arc<dyn ReedSolomon<_>>,
);
map
});
#[derive(Default)]
pub struct NttFamily;
impl type_map::Family for NttFamily {
type Dyn<F: 'static> = dyn ReedSolomon<F>;
}
pub trait ReedSolomon<F>: Debug + Send + Sync {
fn interleaved_encode(
&self,
interleaved_coeffs: &[&[F]],
codeword_length: usize,
interleaving_depth: usize,
) -> Vec<F>;
}
assert_obj_safe!(ReedSolomon<crate::algebra::fields::Field256>);
#[derive(Clone, Copy, Debug, Default)]
pub struct ArkNtt<F: FftField>(PhantomData<F>);
impl<F: FftField> ReedSolomon<F> for ArkNtt<F> {
fn interleaved_encode(
&self,
interleaved_coeffs: &[&[F]],
codeword_length: usize,
interleaving_depth: usize,
) -> Vec<F> {
ark_ntt(interleaved_coeffs, codeword_length, interleaving_depth)
}
}
pub fn interleaved_rs_encode<F: 'static>(
interleaved_coeffs: &[&[F]],
codeword_length: usize,
interleaving_depth: usize,
) -> Vec<F> {
let engine = NTT.get::<F>().expect("Unsupported field");
engine.interleaved_encode(interleaved_coeffs, codeword_length, interleaving_depth)
}
#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(coeffs), fields(size = coeffs.len())))]
fn ark_ntt<F: FftField>(
coeffs: &[&[F]],
codeword_length: usize,
interleaving_depth: usize,
) -> Vec<F> {
if coeffs.is_empty() {
return Vec::new();
}
let poly_size = coeffs[0].len();
for poly in coeffs {
assert_eq!(poly.len(), poly_size);
}
assert!(poly_size.is_multiple_of(interleaving_depth));
let message_length = poly_size / interleaving_depth;
let per_poly_size = codeword_length * interleaving_depth;
let total_size = per_poly_size * coeffs.len();
assert!(codeword_length.is_multiple_of(message_length));
let mut result = vec![F::ZERO; total_size];
for (poly_index, poly) in coeffs.iter().enumerate() {
for (block_index, block) in poly.chunks_exact(message_length).enumerate() {
let dst = poly_index * per_poly_size + block_index * codeword_length;
result[dst..dst + message_length].copy_from_slice(block);
}
}
ntt_batch(&mut result, codeword_length);
transpose(
&mut result,
coeffs.len() * interleaving_depth,
codeword_length,
);
result
}
#[cfg(test)]
mod tests {
use ark_ff::Field;
use super::*;
use crate::algebra::{fields::Field64, ntt::cooley_tukey::NttEngine};
#[test]
fn test_expand_from_coeff_size_2() {
let engine = NttEngine::<Field64>::new_from_fftfield();
let c0 = Field64::from(1);
let c1 = Field64::from(2);
let coeffs = vec![c0, c1];
let expansion = 2;
let omega = engine.root(4);
let f0 = c0;
let f1 = c1;
let f2 = c0 * omega.pow([0]);
let f3 = c1 * omega.pow([1]);
let expected_f0 = f0 + f1;
let expected_f1 = f0 - f1;
let expected_f2 = f2 + f3;
let expected_f3 = f2 - f3;
let expected_values_transposed = vec![expected_f0, expected_f2, expected_f1, expected_f3];
let computed_values = test_utils::expand_from_coeff(&coeffs, expansion);
assert_eq!(computed_values, expected_values_transposed);
}
#[test]
fn test_expand_from_coeff_size_4() {
let engine = NttEngine::<Field64>::new_from_fftfield();
let c0 = Field64::from(1);
let c1 = Field64::from(2);
let c2 = Field64::from(3);
let c3 = Field64::from(4);
let coeffs = vec![c0, c1, c2, c3];
let expansion = 4;
let omega = engine.root(16);
let f0 = c0;
let f1 = c1;
let f2 = c2;
let f3 = c3;
let f4 = c0 * omega.pow([1]).pow([0]);
let f5 = c1 * omega.pow([1]).pow([1]);
let f6 = c2 * omega.pow([1]).pow([2]);
let f7 = c3 * omega.pow([1]).pow([3]);
let f8 = c0 * omega.pow([2]).pow([0]);
let f9 = c1 * omega.pow([2]).pow([1]);
let f10 = c2 * omega.pow([2]).pow([2]);
let f11 = c3 * omega.pow([2]).pow([3]);
let f12 = c0 * omega.pow([3]).pow([0]);
let f13 = c1 * omega.pow([3]).pow([1]);
let f14 = c2 * omega.pow([3]).pow([2]);
let f15 = c3 * omega.pow([3]).pow([3]);
let omega = engine.root(4);
let omega1 = omega; let omega2 = omega * omega; let omega3 = omega * omega2; let omega4 = omega * omega3;
let expected_f0 = f0 + f1 + f2 + f3;
let expected_f1 = f0 + f1 * omega1 + f2 * omega2 + f3 * omega3;
let expected_f2 = f0 + f1 * omega2 + f2 * omega4 + f3 * omega2;
let expected_f3 = f0 + f1 * omega3 + f2 * omega2 + f3 * omega1;
let expected_f4 = f4 + f5 + f6 + f7;
let expected_f5 = f4 + f5 * omega1 + f6 * omega2 + f7 * omega3;
let expected_f6 = f4 + f5 * omega2 + f6 * omega4 + f7 * omega2;
let expected_f7 = f4 + f5 * omega3 + f6 * omega2 + f7 * omega1;
let expected_f8 = f8 + f9 + f10 + f11;
let expected_f9 = f8 + f9 * omega1 + f10 * omega2 + f11 * omega3;
let expected_f10 = f8 + f9 * omega2 + f10 * omega4 + f11 * omega2;
let expected_f11 = f8 + f9 * omega3 + f10 * omega2 + f11 * omega1;
let expected_f12 = f12 + f13 + f14 + f15;
let expected_f13 = f12 + f13 * omega1 + f14 * omega2 + f15 * omega3;
let expected_f14 = f12 + f13 * omega2 + f14 * omega4 + f15 * omega2;
let expected_f15 = f12 + f13 * omega3 + f14 * omega2 + f15 * omega1;
let expected_values_transposed = vec![
expected_f0,
expected_f4,
expected_f8,
expected_f12,
expected_f1,
expected_f5,
expected_f9,
expected_f13,
expected_f2,
expected_f6,
expected_f10,
expected_f14,
expected_f3,
expected_f7,
expected_f11,
expected_f15,
];
let computed_values = test_utils::expand_from_coeff(&coeffs, expansion);
assert_eq!(computed_values, expected_values_transposed);
}
#[test]
fn test_interleaved_rs_encode() {
use ark_std::UniformRand;
let mut rng = ark_std::test_rng();
let count = 1 << 20;
let expansion = 4;
let folding_factor = 6;
let codeword_length = (count >> folding_factor) * expansion;
let poly: Vec<_> = (0..count).map(|_| Field64::rand(&mut rng)).collect();
let block_size = count / (1 << folding_factor);
let mut blocks = Vec::with_capacity(1 << folding_factor);
for block in poly.chunks_exact(block_size) {
blocks.push(test_utils::expand_from_coeff(block, expansion));
}
let evals_len = block_size * expansion;
let mut expected = Vec::with_capacity(count * expansion);
for i in 0..evals_len {
for block in &blocks {
expected.push(block[i]);
}
}
let interleaved_ntt =
interleaved_rs_encode(&[poly.as_slice()], codeword_length, 1 << folding_factor);
assert_eq!(expected, interleaved_ntt);
}
}