use crate::prelude::*;
use super::arith::is_prime;
use super::arith::{mod_inv, mod_mul, mod_pow, primitive_root, two_adic_valuation};
use super::error::NttError;
/// Precomputed parameters for radix-2 number-theoretic transforms of one fixed
/// power-of-two length over the prime field `Z/modulus`.
///
/// Built by `NttPlan::new`, which validates the length and modulus and fills in
/// the root of unity, the scaling factor and both twiddle tables.
#[derive(Debug, Clone)]
pub struct NttPlan {
/// Transform length; a nonzero power of two (checked in `NttPlan::new`).
n: usize,
/// Prime modulus of the field the transform works in.
modulus: u64,
/// `g^((modulus - 1) / n)` for the `g` returned by `primitive_root`, i.e. an
/// `n`-th root of unity; `1` for the trivial plan with `n == 1`.
root: u64,
/// `n^{-1} mod modulus`, multiplied into every element after the inverse butterflies.
inv_n: u64,
/// `root^j mod modulus` for `j` in `0..n / 2` (just `[1]` when `n == 1`).
twiddles: Vec<u64>,
/// `root^{-j} mod modulus` for `j` in `0..n / 2` (just `[1]` when `n == 1`).
inv_twiddles: Vec<u64>,
}
impl NttPlan {
    /// Builds a transform plan of length `n` over the prime field `Z/modulus`.
    ///
    /// The plan stores `root = g^((modulus - 1) / n)` (with `g` from `primitive_root`,
    /// which makes `root` a primitive `n`-th root of unity provided `g` generates the
    /// multiplicative group), `n^{-1} mod modulus`, and the `n / 2` powers of `root`
    /// and of `root^{-1}` consumed by the butterflies.
    ///
    /// # Errors
    ///
    /// * [`NttError::NotPowerOfTwo`] if `n` is zero or not a power of two.
    /// * [`NttError::NotPrime`] if `modulus` is below 2 or not prime.
    /// * [`NttError::SizeTooLarge`] if `n` exceeds the largest power of two dividing
    ///   `modulus - 1`, so the field has no `n`-th root of unity.
    /// * [`NttError::NoRootOfUnity`] if a primitive root or a required modular inverse
    ///   cannot be obtained from the arithmetic helpers.
    pub fn new(n: usize, modulus: u64) -> Result<Self, NttError> {
        if !n.is_power_of_two() {
            return Err(NttError::NotPowerOfTwo(n));
        }
        if modulus < 2 || !is_prime(modulus) {
            return Err(NttError::NotPrime(modulus));
        }
        let p_minus_1 = modulus - 1;
        let k = two_adic_valuation(p_minus_1).unwrap_or(0);
        // 2^k may not fit in `usize`: for 2^64 - 2^32 + 1 on a 32-bit target (e.g. wasm32)
        // `1usize << 32` overflows (and is masked to `1 << 0` in release builds). Saturate
        // instead, which admits every power-of-two `usize`.
        let max_n = (0..k)
            .try_fold(1usize, |acc, _| acc.checked_mul(2))
            .unwrap_or(usize::MAX);
        if n > max_n {
            return Err(NttError::SizeTooLarge { n, max: max_n });
        }
        if n == 1 {
            // The length-1 transform is the identity; no roots are needed.
            return Ok(Self {
                n: 1,
                modulus,
                root: 1,
                inv_n: 1,
                twiddles: vec![1],
                inv_twiddles: vec![1],
            });
        }
        let no_root = || NttError::NoRootOfUnity { n, modulus };
        let g = primitive_root(modulus).ok_or_else(no_root)?;
        let root = mod_pow(g, p_minus_1 / n as u64, modulus);
        let inv_root = mod_inv(root, modulus).ok_or_else(no_root)?;
        let inv_n = mod_inv(n as u64, modulus).ok_or_else(no_root)?;
        let half = n / 2;
        Ok(Self {
            n,
            modulus,
            root,
            inv_n,
            twiddles: Self::powers(root, half, modulus),
            inv_twiddles: Self::powers(inv_root, half, modulus),
        })
    }

    /// Transform length this plan was built for.
    #[inline]
    pub fn size(&self) -> usize {
        self.n
    }

    /// Prime modulus of the field the plan works in.
    #[inline]
    pub fn modulus(&self) -> u64 {
        self.modulus
    }

    /// Root of unity used by [`forward`](Self::forward); `1` when `size() == 1`.
    #[inline]
    pub fn root(&self) -> u64 {
        self.root
    }

    /// Computes the forward transform of `data` in place.
    ///
    /// On return, `data[k] = sum_j x[j] * root^(j*k) mod modulus` in natural order, where
    /// `x` is the input reduced modulo [`modulus`](Self::modulus). Inputs may be any `u64`;
    /// outputs lie in `[0, modulus)`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != self.size()`.
    pub fn forward(&self, data: &mut [u64]) {
        assert_eq!(data.len(), self.n, "data length must equal plan size");
        self.reduce(data);
        if self.n > 1 {
            bit_reverse_permutation(data);
            self.butterfly_dit(data, &self.twiddles);
        }
    }

    /// Computes the inverse transform of `data` in place, undoing [`forward`](Self::forward).
    ///
    /// Inputs may be any `u64` (they are reduced first); outputs lie in `[0, modulus)`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != self.size()`.
    pub fn inverse(&self, data: &mut [u64]) {
        assert_eq!(data.len(), self.n, "data length must equal plan size");
        self.reduce(data);
        if self.n <= 1 {
            return;
        }
        bit_reverse_permutation(data);
        self.butterfly_dit(data, &self.inv_twiddles);
        // Running the butterflies with root^{-1} yields n times the inverse; divide it out.
        let (m, inv_n) = (self.modulus, self.inv_n);
        for x in data.iter_mut() {
            *x = mod_mul(*x, inv_n, m);
        }
    }

    /// Copies `input` into `output` and runs [`forward`](Self::forward) on `output`,
    /// leaving `input` untouched.
    ///
    /// # Panics
    ///
    /// Panics if either slice's length differs from `self.size()`.
    pub fn forward_into(&self, input: &[u64], output: &mut [u64]) {
        assert_eq!(input.len(), self.n, "input length must equal plan size");
        assert_eq!(output.len(), self.n, "output length must equal plan size");
        output.copy_from_slice(input);
        self.forward(output);
    }

    /// Returns `[1, base, base^2, ..., base^(count - 1)]`, each product taken with `mod_mul`.
    fn powers(base: u64, count: usize, modulus: u64) -> Vec<u64> {
        std::iter::successors(Some(1u64), |&prev| Some(mod_mul(prev, base, modulus)))
            .take(count)
            .collect()
    }

    /// Reduces every element of `data` into `[0, modulus)` so the butterflies may assume
    /// reduced operands.
    fn reduce(&self, data: &mut [u64]) {
        let m = self.modulus;
        for x in data.iter_mut() {
            *x %= m;
        }
    }

    /// Runs the iterative radix-2 decimation-in-time stages over bit-reversed `data`.
    ///
    /// `tw` holds `w^j` for `j` in `0..n / 2`; the stage of span `len` uses every
    /// `(n / len)`-th entry, i.e. the powers of `w^(n / len)`. Every element of `data`
    /// must already be below `self.modulus` (assumes `mod_mul` also returns reduced values).
    fn butterfly_dit(&self, data: &mut [u64], tw: &[u64]) {
        let n = self.n;
        let m = self.modulus;
        let mut len = 2;
        while len <= n {
            let half = len / 2;
            let step = n / len;
            for block in data.chunks_exact_mut(len) {
                let (lo, hi) = block.split_at_mut(half);
                for ((a, b), &w) in lo.iter_mut().zip(hi.iter_mut()).zip(tw.iter().step_by(step)) {
                    let u = *a;
                    let v = mod_mul(*b, w, m);
                    *a = Self::add_mod(u, v, m);
                    *b = Self::sub_mod(u, v, m);
                }
            }
            len <<= 1;
        }
    }

    /// `(a + b) mod m` for `a, b < m`, without the `u64` overflow that `a + b` hits
    /// when `m > 2^63`.
    #[inline]
    fn add_mod(a: u64, b: u64, m: u64) -> u64 {
        // `gap > 0` because `b < m`; `a + b >= m` exactly when `a >= gap`.
        let gap = m - b;
        if a >= gap {
            a - gap
        } else {
            a + b
        }
    }

    /// `(a - b) mod m` for `a, b < m`, without the overflow of `a + m - b`.
    #[inline]
    fn sub_mod(a: u64, b: u64, m: u64) -> u64 {
        if a >= b {
            a - b
        } else {
            m - (b - a)
        }
    }
}
/// Permutes `data` in place so that the element at index `i` moves to the index
/// obtained by reversing the low `log2(data.len())` bits of `i`.
///
/// `data.len()` is expected to be a power of two (callers pass a plan-sized buffer).
/// Lengths 0, 1 and 2 are returned unchanged because reversing at most one index bit
/// is the identity.
///
/// # Panics
///
/// Panics if `data.len()` exceeds `2^32`: indices are reversed with the 32-bit
/// [`reverse_bits`] helper, which would otherwise drop high bits and scramble the
/// data silently.
fn bit_reverse_permutation(data: &mut [u64]) {
    let n = data.len();
    if n <= 2 {
        return;
    }
    assert!(
        n - 1 <= u32::MAX as usize,
        "bit_reverse_permutation supports at most 2^32 elements"
    );
    let log_n = n.trailing_zeros();
    for i in 0..n {
        // Lossless: the assertion above bounds every index by `u32::MAX`.
        #[allow(clippy::cast_possible_truncation)]
        let j = reverse_bits(i as u32, log_n) as usize;
        // Swap each pair once, from its smaller index.
        if i < j {
            data.swap(i, j);
        }
    }
}
/// Reverses the lowest `bits` bits of `x`: bit `k` of `x` becomes bit
/// `bits - 1 - k` of the result, and bits of `x` at position `bits` or above are ignored.
///
/// For `bits > 32`, bits pushed past the top of the `u32` accumulator are discarded
/// (`x` contributes zeros once it has been shifted out).
#[inline]
fn reverse_bits(x: u32, bits: u32) -> u32 {
    (0..bits).fold(0u32, |acc, k| {
        // `checked_shr` yields `None` for `k >= 32`, where `x` has been fully consumed.
        let bit = x.checked_shr(k).unwrap_or(0) & 1;
        (acc << 1) | bit
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// NTT-friendly prime `119 * 2^23 + 1`.
    const MOD: u64 = 998_244_353;

    /// `(a * b) mod MOD`, computed independently of the crate's arithmetic helpers.
    fn mul(a: u64, b: u64) -> u64 {
        // The remainder is below MOD, so narrowing back to u64 is lossless.
        ((u128::from(a) * u128::from(b)) % u128::from(MOD)) as u64
    }

    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(0b000, 3), 0b000);
        assert_eq!(reverse_bits(0b001, 3), 0b100);
        assert_eq!(reverse_bits(0b011, 3), 0b110);
        assert_eq!(reverse_bits(0b101, 3), 0b101);
    }

    #[test]
    fn test_bit_reverse_permutation() {
        let mut data = vec![0, 1, 2, 3, 4, 5, 6, 7];
        bit_reverse_permutation(&mut data);
        assert_eq!(data, vec![0, 4, 2, 6, 1, 5, 3, 7]);
    }

    #[test]
    fn test_plan_n1() {
        let plan = NttPlan::new(1, MOD).expect("n=1 plan");
        let mut data = vec![42u64];
        plan.forward(&mut data);
        assert_eq!(data, vec![42]);
        plan.inverse(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_forward_of_unit_impulse_gives_root_powers() {
        let plan = NttPlan::new(8, MOD).expect("n=8 plan");
        let mut data = vec![0u64; 8];
        data[1] = 1;
        plan.forward(&mut data);
        let w = plan.root();
        let expected: Vec<u64> = std::iter::successors(Some(1u64), |&x| Some(mul(x, w)))
            .take(8)
            .collect();
        assert_eq!(data, expected);
    }

    #[test]
    fn test_roundtrip_restores_input() {
        let plan = NttPlan::new(16, MOD).expect("n=16 plan");
        let original: Vec<u64> = (0..16u64).map(|i| (i * i * 12_345 + 7) % MOD).collect();
        let mut data = original.clone();
        plan.forward(&mut data);
        plan.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_cyclic_convolution() {
        let plan = NttPlan::new(4, MOD).expect("n=4 plan");
        let mut a = vec![1u64, 2, 3, 4];
        let mut b = vec![5u64, 6, 7, 8];
        plan.forward(&mut a);
        plan.forward(&mut b);
        let mut c: Vec<u64> = a.iter().zip(&b).map(|(&x, &y)| mul(x, y)).collect();
        plan.inverse(&mut c);
        // Cyclic convolution of [1, 2, 3, 4] and [5, 6, 7, 8].
        assert_eq!(c, vec![66, 68, 66, 60]);
    }

    #[test]
    fn test_forward_into_matches_forward() {
        let plan = NttPlan::new(8, MOD).expect("n=8 plan");
        let input: Vec<u64> = (1..=8).collect();
        let mut in_place = input.clone();
        plan.forward(&mut in_place);
        let mut out = vec![0u64; 8];
        plan.forward_into(&input, &mut out);
        assert_eq!(out, in_place);
        assert_eq!(input, (1..=8).collect::<Vec<u64>>());
    }

    #[test]
    fn test_plan_errors() {
        assert!(matches!(
            NttPlan::new(3, MOD),
            Err(NttError::NotPowerOfTwo(3))
        ));
        assert!(matches!(
            NttPlan::new(0, MOD),
            Err(NttError::NotPowerOfTwo(0))
        ));
        assert!(matches!(NttPlan::new(4, 15), Err(NttError::NotPrime(15))));
        assert!(matches!(
            NttPlan::new(1 << 24, MOD),
            Err(NttError::SizeTooLarge { .. })
        ));
    }
}