use crate::error::{FFTError, FFTResult};
/// The prime 998244353 = 119 * 2^23 + 1.
pub const MOD998244353: u64 = 998_244_353;
/// The prime 1_000_000_007 (p - 1 = 2 * 500_000_003).
pub const MOD1000000007: u64 = 1_000_000_007;
/// The prime 469762049 = 7 * 2^26 + 1.
pub const MOD469762049: u64 = 469_762_049;

/// NTT-friendly primes as `(p, g)` pairs, where `g` is the generator this
/// module uses as the primitive root of `p`. Each `p - 1` carries a large
/// power of two, which bounds the supported transform length.
pub const KNOWN_NTT_PRIMES: &[(u64, u64)] = &[
    (998_244_353, 3),    // 119 * 2^23 + 1
    (1_004_535_809, 3),  // 479 * 2^21 + 1
    (469_762_049, 3),    // 7 * 2^26 + 1
    (167_772_161, 3),    // 5 * 2^25 + 1
    (2_013_265_921, 31), // 15 * 2^27 + 1
    (786_433, 10),       // 3 * 2^18 + 1
];
/// Returns `a * b mod m`, widening to `u128` so the product cannot overflow.
#[inline(always)]
pub fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    (product % u128::from(m)) as u64
}
/// Returns `(a + b) mod m` for operands already reduced into `0..m`.
///
/// The addition is overflow-aware, so the result is correct for every modulus
/// up to `u64::MAX` rather than only for moduli below `2^63`.
#[inline(always)]
pub fn addmod(a: u64, b: u64, m: u64) -> u64 {
    let (sum, overflowed) = a.overflowing_add(b);
    // If the true sum reached 2^64 it is certainly >= m; the wrapping
    // subtraction then recovers the exact value (true sum - m), which is < m.
    if overflowed || sum >= m {
        sum.wrapping_sub(m)
    } else {
        sum
    }
}
/// Returns `(a - b) mod m` for operands already reduced into `0..m`.
///
/// The negative case is computed as `m - (b - a)`, which never exceeds `u64`
/// (unlike `a + m - b`, which overflows for moduli above `2^63`).
#[inline(always)]
pub fn submod(a: u64, b: u64, m: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        m - (b - a)
    }
}
/// Computes `base^exp mod m` by square-and-multiply.
///
/// Returns 0 when `m == 1`, since every residue modulo 1 is 0.
pub fn powmod(base: u64, exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let mut acc = 1u64;
    let mut square = base % m;
    let mut remaining = exp;
    while remaining != 0 {
        if remaining & 1 != 0 {
            acc = mulmod(acc, square, m);
        }
        square = mulmod(square, square, m);
        remaining >>= 1;
    }
    acc
}
/// Returns the multiplicative inverse of `a` modulo `m`, or 0 if none exists.
///
/// Uses the extended Euclidean algorithm, so it works for any modulus (prime
/// or composite) with `gcd(a, m) == 1`. For a prime `m` it agrees with the
/// Fermat inverse `a^(m-2) mod m`. Non-invertible inputs (including
/// `a ≡ 0 (mod m)`) and the degenerate moduli `m <= 1` yield 0 instead of
/// panicking on `m - 2` underflow.
pub fn modinv(a: u64, m: u64) -> u64 {
    if m <= 1 {
        return 0;
    }
    let modulus = i128::from(m);
    // Invariant: old_r ≡ old_s * a and r ≡ s * a (mod m).
    let (mut old_r, mut r) = (i128::from(a % m), modulus);
    let (mut old_s, mut s) = (1i128, 0i128);
    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (s, old_s - q * s);
    }
    // old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1 {
        return 0;
    }
    old_s.rem_euclid(modulus) as u64
}
/// Finds a prime `p < 2^bits` with `n | p - 1`, together with a primitive root of `p`.
///
/// Entries of [`KNOWN_NTT_PRIMES`] are tried first, in table order. Otherwise
/// candidates `k * n + 1` for `k = 1, 2, …` are tested and the first prime is
/// returned. `bits >= 64` means "any `u64`".
///
/// # Errors
/// Returns `FFTError::ValueError` if `n` is not a power of two, if no suitable
/// prime exists below `2^bits`, or if no primitive root can be found.
pub fn find_ntt_prime(n: usize, bits: u32) -> FFTResult<(u64, u64)> {
    if !n.is_power_of_two() {
        return Err(FFTError::ValueError(format!(
            "NTT length {n} must be a power of two"
        )));
    }
    // Exclusive upper bound. `1u64 << 64` would overflow; u64::MAX is composite,
    // so using it as the exclusive bound for bits >= 64 loses no candidates.
    let max_val = 1u64.checked_shl(bits).unwrap_or(u64::MAX);
    let n64 = n as u64;
    if let Some(&(p, g)) = KNOWN_NTT_PRIMES
        .iter()
        .find(|&&(p, _)| p < max_val && (p - 1) % n64 == 0)
    {
        return Ok((p, g));
    }
    // Checked arithmetic so a huge `n` cannot overflow `k * n + 1`.
    let mut k = 1u64;
    while let Some(candidate) = k.checked_mul(n64).and_then(|kn| kn.checked_add(1)) {
        if candidate >= max_val {
            break;
        }
        if is_prime(candidate) {
            let g = find_primitive_root_of(candidate)?;
            return Ok((candidate, g));
        }
        k += 1;
    }
    Err(FFTError::ValueError(format!(
        "no NTT-friendly prime found for n={n} within {bits} bits"
    )))
}
/// Miller–Rabin primality test for `u64`.
///
/// Small values and multiples of 2, 3 and 5 are decided directly. Everything
/// else is tested against the first twelve primes (2..=37) as witnesses, a
/// set known to make the test deterministic for all 64-bit inputs.
fn is_prime(n: u64) -> bool {
    match n {
        0 | 1 => return false,
        2 | 3 | 5 | 7 => return true,
        _ => {}
    }
    if [2u64, 3, 5].iter().any(|&q| n % q == 0) {
        return false;
    }
    // Decompose n - 1 = d * 2^s with d odd (n is odd here, so s >= 1).
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;
    let proves_composite = |a: u64| -> bool {
        let mut x = powmod(a, d, n);
        if x == 1 || x == n - 1 {
            return false;
        }
        for _ in 1..s {
            x = mulmod(x, x, n);
            if x == n - 1 {
                return false;
            }
        }
        true
    };
    const WITNESSES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
    !WITNESSES
        .iter()
        .filter(|&&a| a < n)
        .any(|&a| proves_composite(a))
}
/// Returns a primitive root modulo `p`, which is assumed (not checked) to be prime.
///
/// Primes listed in [`KNOWN_NTT_PRIMES`] use the tabulated generator. Otherwise
/// the result is the smallest `g >= 2` with `g^((p-1)/q) != 1` for every prime
/// factor `q` of `p - 1`. For `p == 2` the group is `{1}`, so 1 is returned.
///
/// # Errors
/// Returns `FFTError::ValueError` for `p < 2` or if the search finds no root.
fn find_primitive_root_of(p: u64) -> FFTResult<u64> {
    if let Some(&(_, g)) = KNOWN_NTT_PRIMES.iter().find(|&&(q, _)| q == p) {
        return Ok(g);
    }
    match p {
        // `p - 1` would underflow for 0, and there is no unit group mod 1.
        0 | 1 => {
            return Err(FFTError::ValueError(format!(
                "no primitive root found for prime {p}"
            )));
        }
        // The search below starts at 2, which is outside the group mod 2.
        2 => return Ok(1),
        _ => {}
    }
    let phi = p - 1;
    let factors = factorize(phi);
    (2..p)
        .find(|&g| factors.iter().all(|&f| powmod(g, phi / f, p) != 1))
        .ok_or_else(|| {
            FFTError::ValueError(format!("no primitive root found for prime {p}"))
        })
}
/// Returns the distinct prime factors of `n` in ascending order, by trial division.
///
/// `factorize(0)` and `factorize(1)` return an empty vector.
fn factorize(mut n: u64) -> Vec<u64> {
    let mut factors = Vec::new();
    if n == 0 {
        return factors;
    }
    if n % 2 == 0 {
        factors.push(2);
        n >>= n.trailing_zeros();
    }
    // Only odd divisors remain. `d <= n / d` is the overflow-free form of
    // `d * d <= n`.
    let mut d = 3u64;
    while d <= n / d {
        if n % d == 0 {
            factors.push(d);
            while n % d == 0 {
                n /= d;
            }
        }
        d += 2;
    }
    if n > 1 {
        factors.push(n);
    }
    factors
}
/// Permutes `a` in place so each index trades places with its bit-reversal.
///
/// The reversal width is `floor(log2(len))`. The caller in this file (`ntt`)
/// only passes power-of-two lengths, after `validate`.
fn bit_reverse_permute(a: &mut [u64]) {
    let len = a.len();
    if len == 0 {
        return;
    }
    let width = len.ilog2() as usize;
    // Swap each pair only once, from its smaller index.
    (0..len)
        .map(|i| (i, reverse_bits(i, width)))
        .filter(|&(i, j)| i < j)
        .for_each(|(i, j)| a.swap(i, j));
}
/// Reverses the lowest `bits` bits of `x`; bits above that width are dropped.
fn reverse_bits(x: usize, bits: usize) -> usize {
    let (reversed, _) = (0..bits).fold((0usize, x), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
/// Checks that an NTT of length `n` modulo `p` can be performed.
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - `n` is zero or not a power of two;
/// - `p < 2`, which would otherwise underflow `p - 1`;
/// - `n` does not divide `p - 1`. For a prime `p` this divisibility is what
///   guarantees a primitive `n`-th root of unity.
fn validate(n: usize, p: u64) -> FFTResult<()> {
    if n == 0 {
        return Err(FFTError::ValueError("NTT length must be > 0".into()));
    }
    if !n.is_power_of_two() {
        return Err(FFTError::ValueError(format!(
            "NTT length {n} must be a power of two"
        )));
    }
    if p < 2 {
        return Err(FFTError::ValueError(format!(
            "modulus {p} must be at least 2"
        )));
    }
    if (p - 1) % n as u64 != 0 {
        return Err(FFTError::ValueError(format!(
            "modulus {p} does not support NTT of length {n} ((p-1) mod n ≠ 0)"
        )));
    }
    Ok(())
}
/// In-place iterative (Cooley–Tukey) number-theoretic transform of `a` modulo `p`.
///
/// `g` should be a primitive root modulo the prime `p`. The stage of width
/// `len` uses the root `g^((p-1)/len)`. With `inverse == true` the roots come
/// from `g^-1` and the output is scaled by `n^-1`. A forward transform followed
/// by an inverse one therefore returns the input reduced modulo `p`.
///
/// Inputs need not be reduced. Every element is first taken modulo `p`, as
/// `addmod`/`submod` require, and all outputs lie in `0..p`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `validate` rejects the length or modulus:
/// the length must be a nonzero power of two dividing `p - 1`.
pub fn ntt(a: &mut [u64], p: u64, g: u64, inverse: bool) -> FFTResult<()> {
    let n = a.len();
    validate(n, p)?;
    for x in a.iter_mut() {
        *x %= p;
    }
    bit_reverse_permute(a);
    // Loop-invariant: compute the inverse generator once, not once per stage.
    let base = if inverse { modinv(g, p) } else { g };
    let mut len = 2_usize;
    while len <= n {
        let half = len / 2;
        let root = powmod(base, (p - 1) / len as u64, p);
        // `len` divides `n` (both powers of two), so the chunks are exact.
        for block in a.chunks_exact_mut(len) {
            let (lo, hi) = block.split_at_mut(half);
            let mut w = 1u64;
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let u = *x;
                let v = mulmod(*y, w, p);
                *x = addmod(u, v, p);
                *y = submod(u, v, p);
                w = mulmod(w, root, p);
            }
        }
        len <<= 1;
    }
    if inverse {
        let n_inv = modinv(n as u64, p);
        for x in a.iter_mut() {
            *x = mulmod(*x, n_inv, p);
        }
    }
    Ok(())
}
/// Multiplies polynomials `a` and `b` (coefficients low-to-high) modulo `p` via NTT.
///
/// Coefficients are reduced modulo `p`. The product has `a.len() + b.len() - 1`
/// coefficients, and its length rounded up to a power of two must divide `p - 1`.
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - either input is empty;
/// - `validate` rejects the padded length for `p`;
/// - no generator for `p` can be found.
pub fn ntt_mul(a: &[u64], b: &[u64], p: u64) -> FFTResult<Vec<u64>> {
    if a.is_empty() || b.is_empty() {
        return Err(FFTError::ValueError(
            "polynomial inputs must not be empty".to_string(),
        ));
    }
    let out_len = a.len() + b.len() - 1;
    let size = out_len.next_power_of_two();
    validate(size, p)?;
    let g = generator_for_prime(p)?;
    // Reduce coefficients and zero-pad to the transform size.
    let to_padded = |poly: &[u64]| -> Vec<u64> {
        poly.iter()
            .map(|&c| c % p)
            .chain(std::iter::repeat(0))
            .take(size)
            .collect()
    };
    let mut lhs = to_padded(a);
    let mut rhs = to_padded(b);
    ntt(&mut lhs, p, g, false)?;
    ntt(&mut rhs, p, g, false)?;
    lhs.iter_mut()
        .zip(&rhs)
        .for_each(|(x, &y)| *x = mulmod(*x, y, p));
    ntt(&mut lhs, p, g, true)?;
    lhs.truncate(out_len);
    Ok(lhs)
}
/// Returns the generator for `p`: the [`KNOWN_NTT_PRIMES`] entry if present,
/// otherwise the result of `find_primitive_root_of`.
fn generator_for_prime(p: u64) -> FFTResult<u64> {
    match KNOWN_NTT_PRIMES.iter().find(|&&(known_p, _)| known_p == p) {
        Some(&(_, known_g)) => Ok(known_g),
        None => find_primitive_root_of(p),
    }
}
/// In-place NTT modulo [`MOD998244353`] with generator 3. See [`ntt`].
pub fn ntt_998244353(a: &mut [u64], inverse: bool) -> FFTResult<()> {
    let (modulus, root) = (MOD998244353, 3);
    ntt(a, modulus, root, inverse)
}
/// Polynomial product modulo [`MOD998244353`]. See [`ntt_mul`].
pub fn ntt_mul_998244353(a: &[u64], b: &[u64]) -> FFTResult<Vec<u64>> {
    let modulus = MOD998244353;
    ntt_mul(a, b, modulus)
}
/// Exact integer convolution of `a` and `b` via NTT plus the Chinese remainder theorem.
///
/// The convolution is computed modulo the three NTT primes 998244353,
/// 469762049 and 167772161. The residues are recombined with Garner's
/// algorithm into a value modulo their product `M ≈ 7.9e25`, which is then
/// mapped to the symmetric range `[-(M-1)/2, (M-1)/2]`.
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - an input is empty;
/// - `max|a| * max|b| * min(a.len(), b.len())` exceeds `(M-1)/2`, so exact
///   reconstruction cannot be guaranteed;
/// - a coefficient of the result does not fit in `i64`;
/// - an underlying `ntt_mul` rejects the result length. The binding limit is
///   `2^23`, from 998244353.
pub fn convolve_exact(a: &[i64], b: &[i64]) -> FFTResult<Vec<i64>> {
    if a.is_empty() || b.is_empty() {
        return Err(FFTError::ValueError("inputs must not be empty".into()));
    }
    const M1: u64 = MOD998244353;
    const M2: u64 = MOD469762049;
    const M3: u64 = 167_772_161;
    let modulus = i128::from(M1) * i128::from(M2) * i128::from(M3);
    let half = modulus / 2;

    // Each output satisfies |c_k| <= max|a| * max|b| * min(len_a, len_b). If
    // that bound lies inside the symmetric residue range, the CRT result is exact.
    let max_abs = |v: &[i64]| -> u128 {
        u128::from(v.iter().map(|x| x.unsigned_abs()).max().unwrap_or(0))
    };
    let terms = a.len().min(b.len()) as u128;
    let bound = max_abs(a)
        .checked_mul(max_abs(b))
        .and_then(|v| v.checked_mul(terms));
    if !matches!(bound, Some(v) if v <= half as u128) {
        return Err(FFTError::ValueError(
            "coefficient magnitudes too large for exact convolution".into(),
        ));
    }

    // rem_euclid maps into 0..m without negating (negating i64::MIN overflows).
    let residues_mod = |m: u64| -> FFTResult<Vec<u64>> {
        let lift = |v: &[i64]| -> Vec<u64> {
            v.iter().map(|&x| x.rem_euclid(m as i64) as u64).collect()
        };
        ntt_mul(&lift(a), &lift(b), m)
    };
    let c1 = residues_mod(M1)?;
    let c2 = residues_mod(M2)?;
    let c3 = residues_mod(M3)?;

    let inv_m1_mod_m2 = modinv(M1 % M2, M2);
    let inv_m1m2_mod_m3 = modinv(mulmod(M1 % M3, M2 % M3, M3), M3);
    let m1m2 = i128::from(M1) * i128::from(M2);

    c1.iter()
        .zip(&c2)
        .zip(&c3)
        .map(|((&r1, &r2), &r3)| {
            // x12 = r1 + M1*t1 is the residue mod M1*M2 matching r1 and r2;
            // it is below M1*M2 < 2^59, so u64 suffices.
            let t1 = mulmod(submod(r2, r1 % M2, M2), inv_m1_mod_m2, M2);
            let x12 = r1 + M1 * t1;
            // Lift to a residue mod M1*M2*M3 that also matches r3.
            let t2 = mulmod(submod(r3, x12 % M3, M3), inv_m1m2_mod_m3, M3);
            let x = i128::from(x12) + m1m2 * i128::from(t2);
            let signed = if x > half { x - modulus } else { x };
            if signed < i128::from(i64::MIN) || signed > i128::from(i64::MAX) {
                return Err(FFTError::ValueError(format!(
                    "convolution coefficient {signed} does not fit in i64"
                )));
            }
            Ok(signed as i64)
        })
        .collect()
}
/// Computes `g` with `f * g ≡ 1 (mod x^n)` over `Z_p` by Newton iteration.
///
/// Starts from `g = f[0]^-1` and doubles the precision `k` (capped at `n`)
/// with `g ← 2g − f·g² mod x^k`. Only the first `k` coefficients of `f` affect
/// the result mod `x^k`, so `f` is truncated each round. Returns exactly `n`
/// coefficients, or an empty vector when `n == 0`.
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - `f` is empty or `f[0] ≡ 0 (mod p)`;
/// - an underlying `ntt_mul` rejects the transform length.
pub fn poly_inv_mod_xn(f: &[u64], n: usize, p: u64) -> FFTResult<Vec<u64>> {
    // Test the reduced constant term: e.g. f[0] == p is also non-invertible.
    let c0 = match f.first() {
        Some(&c) if c % p != 0 => c % p,
        _ => {
            return Err(FFTError::ValueError(
                "constant term must be nonzero for polynomial inversion".into(),
            ))
        }
    };
    let mut g = vec![modinv(c0, p)];
    let mut k = 1_usize;
    while k < n {
        k = (2 * k).min(n);
        let f_head = &f[..f.len().min(k)];
        let mut fg = ntt_mul(f_head, &g, p)?;
        fg.truncate(k);
        let fgg = ntt_mul(&fg, &g, p)?;
        // g_new[i] = 2*g[i] - (f*g*g)[i] for i < k (missing terms are 0).
        g = (0..k)
            .map(|i| {
                let two_g = g.get(i).map_or(0, |&x| addmod(x, x, p));
                submod(two_g, fgg.get(i).copied().unwrap_or(0), p)
            })
            .collect();
    }
    g.truncate(n);
    Ok(g)
}
/// Formal derivative of `f` (coefficients low-to-high) modulo `p`.
///
/// Returns `[0]` for empty or constant input, otherwise `f.len() - 1` coefficients.
pub fn poly_deriv_mod(f: &[u64], p: u64) -> Vec<u64> {
    match f.split_first() {
        Some((_, rest)) if !rest.is_empty() => rest
            .iter()
            .zip(1u64..)
            .map(|(&c, k)| mulmod(c, k % p, p))
            .collect(),
        _ => vec![0],
    }
}
/// Formal antiderivative of `f` modulo `p`, with a zero constant term.
///
/// Coefficient `i` of `f` becomes coefficient `i + 1` of the result,
/// multiplied by `modinv(i + 1, p)`.
/// NOTE(review): when `i + 1 ≡ 0 (mod p)` there is no inverse, and the
/// coefficient is whatever `modinv` yields for 0. Consider rejecting such input.
pub fn poly_integral_mod(f: &[u64], p: u64) -> Vec<u64> {
    std::iter::once(0)
        .chain(
            f.iter()
                .zip(1u64..)
                .map(|(&c, k)| mulmod(c, modinv(k % p, p), p)),
        )
        .collect()
}
/// Monic greatest common divisor of `a` and `b` over `Z_p`, by Euclid's algorithm.
///
/// Coefficients are low-to-high and are reduced modulo `p` first, so a value
/// like `p` counts as zero. Without this, an unreduced leading coefficient
/// could stall the remainder loop. `p` should be prime so every nonzero
/// leading coefficient is invertible. Returns `[1]` when both inputs are zero
/// or empty.
///
/// # Errors
/// Propagates errors from the remainder step (`poly_rem_mod`).
pub fn poly_gcd_mod(a: &[u64], b: &[u64], p: u64) -> FFTResult<Vec<u64>> {
    let reduce = |poly: &[u64]| -> Vec<u64> {
        let reduced: Vec<u64> = poly.iter().map(|&c| c % p).collect();
        trim_zeros(&reduced)
    };
    let mut u = reduce(a);
    let mut v = reduce(b);
    while v.iter().any(|&x| x != 0) {
        let r = poly_rem_mod(&u, &v, p)?;
        u = v;
        v = trim_zeros(&r);
    }
    match u.last() {
        Some(&lc) if lc != 0 => {
            let lc_inv = modinv(lc, p);
            Ok(u.iter().map(|&c| mulmod(c, lc_inv, p)).collect())
        }
        // gcd(0, 0): fall back to the constant polynomial 1, as before.
        _ => Ok(vec![1]),
    }
}
/// Remainder of `a` divided by `b` over `Z_p` (coefficients low-to-high).
///
/// Both inputs are reduced modulo `p` first. The leading term of the running
/// remainder is cancelled repeatedly until its length is at most
/// `b.len() - 1`. Trailing zeros are trimmed along the way, down to a single
/// coefficient.
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - `b` is empty;
/// - the last coefficient of `b` is `≡ 0 (mod p)`;
/// - that coefficient has no inverse modulo `p`.
fn poly_rem_mod(a: &[u64], b: &[u64], p: u64) -> FFTResult<Vec<u64>> {
    if b.is_empty() {
        return Err(FFTError::ValueError("division by zero polynomial".into()));
    }
    // Canonical residues are required. With an unreduced coefficient, the
    // leading-term cancellation leaves a nonzero multiple of p behind, and the
    // loop below would never shrink `r`.
    let b: Vec<u64> = b.iter().map(|&c| c % p).collect();
    let mut r: Vec<u64> = a.iter().map(|&c| c % p).collect();
    let db = b.len() - 1;
    let lc_b = b[db];
    if lc_b == 0 {
        return Err(FFTError::ValueError(
            "leading coefficient of divisor is zero".into(),
        ));
    }
    let lc_b_inv = modinv(lc_b, p);
    // Without a true inverse the top term never cancels, and the loop would not end.
    if mulmod(lc_b, lc_b_inv, p) != 1 {
        return Err(FFTError::ValueError(format!(
            "leading coefficient of divisor is not invertible modulo {p}"
        )));
    }
    while r.len() > db {
        let top = mulmod(r[r.len() - 1], lc_b_inv, p);
        let shift = r.len() - 1 - db;
        for (slot, &bi) in r[shift..].iter_mut().zip(&b) {
            *slot = submod(*slot, mulmod(top, bi, p), p);
        }
        // The top coefficient is now zero, so the remainder strictly shrinks.
        while r.len() > 1 && r.last() == Some(&0) {
            r.pop();
        }
        if r.len() == 1 && r[0] == 0 {
            break;
        }
    }
    Ok(r)
}
/// Copies `a` without trailing zero coefficients, keeping at least one element
/// when `a` is non-empty (an all-zero input becomes `[0]`).
fn trim_zeros(a: &[u64]) -> Vec<u64> {
    let keep = match a.iter().rposition(|&c| c != 0) {
        Some(last_nonzero) => last_nonzero + 1,
        // All zeros: keep a single zero, or nothing if `a` is empty.
        None => a.len().min(1),
    };
    a[..keep].to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;

    const P: u64 = MOD998244353;

    #[test]
    fn test_ntt_roundtrip() {
        let original = [1u64, 2, 3, 4];
        let mut data = original.to_vec();
        ntt(&mut data, P, 3, false).expect("forward");
        ntt(&mut data, P, 3, true).expect("inverse");
        assert_eq!(data, original);
    }

    #[test]
    fn test_ntt_998244353_roundtrip() {
        let original = [0u64, 5, 3, 2, 7, 1, 0, 0];
        let mut data = original.to_vec();
        ntt_998244353(&mut data, false).expect("forward");
        ntt_998244353(&mut data, true).expect("inverse");
        assert_eq!(data, original);
    }

    #[test]
    fn test_ntt_mul_basic() {
        let product = ntt_mul(&[1, 2], &[3, 4], P).expect("mul");
        assert_eq!(product, [3u64, 10, 8]);
    }

    #[test]
    fn test_ntt_mul_identity() {
        let poly = [3u64, 1, 4, 1, 5, 9, 2, 6];
        let product = ntt_mul(&poly, &[1], P).expect("mul by 1");
        assert_eq!(product, poly);
    }

    #[test]
    fn test_ntt_mul_all_ones() {
        let ones = [1u64; 4];
        let product = ntt_mul(&ones, &ones, P).expect("mul");
        assert_eq!(product, [1u64, 2, 3, 4, 3, 2, 1]);
    }

    #[test]
    fn test_ntt_mul_matches_brute_force() {
        let a: Vec<u64> = (1..=8).collect();
        let b: Vec<u64> = (1..=8).rev().collect();
        // Schoolbook reference product.
        let mut expected = vec![0u64; a.len() + b.len() - 1];
        for (i, &ai) in a.iter().enumerate() {
            for (slot, &bj) in expected[i..].iter_mut().zip(&b) {
                *slot = (*slot + ai * bj) % P;
            }
        }
        assert_eq!(ntt_mul(&a, &b, P).expect("ntt mul"), expected);
    }

    #[test]
    fn test_find_ntt_prime_small() {
        let (p, _) = find_ntt_prime(8, 30).expect("prime found");
        assert!(p < (1u64 << 30));
        assert_eq!((p - 1) % 8, 0);
        assert!(is_prime(p));
    }

    #[test]
    fn test_find_ntt_prime_known() {
        let (p, g) = find_ntt_prime(8, 32).expect("prime found");
        assert_eq!((p - 1) % 8, 0);
        let phi = p - 1;
        assert!(
            factorize(phi).iter().all(|&f| powmod(g, phi / f, p) != 1),
            "g is not a primitive root"
        );
    }

    #[test]
    fn test_find_ntt_prime_non_power_of_two_errors() {
        assert!(find_ntt_prime(3, 30).is_err());
    }

    #[test]
    fn test_convolve_exact_basic() {
        let product = convolve_exact(&[1, 2, 3], &[4, 5, 6]).expect("ok");
        assert_eq!(product, [4i64, 13, 28, 27, 18]);
    }

    #[test]
    fn test_convolve_exact_negatives() {
        let product = convolve_exact(&[-1, 2], &[3, -4]).expect("ok");
        assert_eq!(product, [-3i64, 10, -8]);
    }

    #[test]
    fn test_poly_inv_mod_xn() {
        // f = 1 - x, whose inverse mod x^4 is 1 + x + x^2 + x^3.
        let f = [1u64, P - 1];
        let g = poly_inv_mod_xn(&f, 4, P).expect("inv");
        assert_eq!(g.len(), 4);
        let fg = ntt_mul(&f, &g, P).expect("mul");
        assert_eq!(fg[0], 1);
        assert!(
            fg[1..4].iter().all(|&c| c == 0),
            "non-zero coefficient in {fg:?}"
        );
    }

    #[test]
    fn test_poly_deriv_mod() {
        let df = poly_deriv_mod(&[1, 2, 3], P);
        assert_eq!((df[0], df[1]), (2, 6));
    }

    #[test]
    fn test_poly_gcd_mod() {
        // gcd(x^2 - 1, x - 1) is monic.
        let a = [P - 1, 0, 1];
        let b = [P - 1, 1];
        let g = poly_gcd_mod(&a, &b, P).expect("gcd");
        assert_eq!(g.last(), Some(&1u64));
    }

    #[test]
    fn test_ntt_invalid_length_error() {
        let mut data = vec![1u64, 2, 3];
        assert!(ntt(&mut data, P, 3, false).is_err());
    }
}