use super::arith::{mod_inv_32, mod_pow_32};
use alloc::vec;
use alloc::vec::Vec;
/// Returns `true` when `n` is prime, using 6k ± 1 trial division.
///
/// Values below 4 are answered directly and multiples of 2 or 3 are rejected
/// up front. Every remaining candidate divisor is tried in pairs `(d, d + 2)`
/// for `d = 5, 11, 17, …` for as long as `d * d <= n`.
pub fn is_prime_32(n: u32) -> bool {
    match n {
        0 | 1 => return false,
        2 | 3 => return true,
        _ => {}
    }
    if n.is_multiple_of(2) || n.is_multiple_of(3) {
        return false;
    }
    // The saturating square keeps the bound test well-defined once `d * d`
    // no longer fits in a `u32`.
    (5u32..)
        .step_by(6)
        .take_while(|&d| d.saturating_mul(d) <= n)
        .all(|d| !n.is_multiple_of(d) && !n.is_multiple_of(d + 2))
}
/// Returns the `count` largest primes `p < 2^28` with
/// `p ≡ 1 (mod 2 * poly_degree)`, in descending order.
///
/// Candidates have the form `k * 2N + 1` and are scanned from the largest
/// admissible `k` downwards, so the first prime found is the largest one below
/// `2^28`. Requesting `count == 0` yields an empty vector.
///
/// # Panics
///
/// Panics if `poly_degree` is zero, or if fewer than `count` such primes exist
/// below `2^28`.
pub fn generate_primes_28(poly_degree: usize, count: usize) -> Vec<u32> {
    assert!(
        poly_degree > 0,
        "generate_primes_28: poly_degree must be non-zero"
    );
    // Saturating: a degree whose double does not fit simply leaves no
    // candidates (`k_max` becomes 0) instead of overflowing.
    let two_n = (poly_degree as u64).saturating_mul(2);
    let k_max = ((1u64 << 28) - 1) / two_n;
    let primes: Vec<u32> = (1..=k_max)
        .rev()
        .map(|k| {
            let candidate = k * two_n + 1;
            debug_assert!(candidate < (1u64 << 28));
            // Lossless: `k <= k_max` keeps the candidate below 2^28.
            candidate as u32
        })
        .filter(|&candidate| is_prime_32(candidate))
        // `take` stops the scan as soon as enough primes were collected and
        // never starts it when `count` is zero.
        .take(count)
        .collect();
    assert!(
        primes.len() == count,
        "Cannot find {count} 28-bit primes for N={poly_degree}, found only {}",
        primes.len()
    );
    primes
}
/// Returns the distinct prime factors of `n` in ascending order, found by
/// trial division.
///
/// `0` and `1` have no prime factors, so both yield an empty vector.
pub(crate) fn small_factor_32(mut n: u32) -> Vec<u32> {
    let mut factors = Vec::new();
    // Zero is a multiple of every divisor and stays zero under division, so
    // the stripping loops below would never terminate on it.
    if n == 0 {
        return factors;
    }
    if n.is_multiple_of(2) {
        factors.push(2);
        while n.is_multiple_of(2) {
            n /= 2;
        }
    }
    let mut d = 3u32;
    // The saturating square keeps the bound test well-defined once `d * d`
    // no longer fits in a `u32`.
    while d.saturating_mul(d) <= n {
        if n.is_multiple_of(d) {
            factors.push(d);
            while n.is_multiple_of(d) {
                n /= d;
            }
        }
        d += 2;
    }
    // Whatever remains above 1 has no divisor up to its square root, so it is
    // itself a prime factor.
    if n > 1 {
        factors.push(n);
    }
    factors
}
/// Returns the smallest `g` in `2..q` that passes the generator test against
/// every entry of `prime_factors`.
///
/// A candidate is rejected as soon as `mod_pow_32(g, (q - 1) / p, q) == 1` for
/// some `p` in `prime_factors`. The caller in this file passes the distinct
/// prime factors of `q - 1`.
///
/// # Panics
///
/// Panics if no candidate in `2..q` passes the test.
fn find_generator_32(q: u32, prime_factors: &[u32]) -> u32 {
    let group_order = q - 1;
    let passes_all = |g: u32| {
        prime_factors
            .iter()
            .all(|&p| mod_pow_32(g, group_order / p, q) != 1)
    };
    match (2..q).find(|&g| passes_all(g)) {
        Some(generator) => generator,
        None => panic!("No generator found for q={q} — this should never happen"),
    }
}
/// Returns a primitive `2N`-th root of unity `ψ` modulo `q`.
///
/// A generator `g` is searched for using the distinct prime factors of
/// `q - 1`, and `ψ = g^((q - 1) / 2N) mod q` is returned. Debug builds
/// additionally verify `ψ^(2N) = 1` and `ψ^N = q - 1`.
///
/// `q` is expected to be prime; primality is not checked here.
///
/// # Panics
///
/// Panics if `n` is zero, if `2 * n` does not fit in a `u32`, if `q < 2`, or
/// if `q - 1` is not a multiple of `2 * n`.
pub fn find_primitive_root(n: usize, q: u32) -> u32 {
    // Reject sizes that the `as u32` narrowing below would silently truncate
    // or whose double would overflow.
    assert!(
        n >= 1 && n as u64 <= u64::from(u32::MAX / 2),
        "find_primitive_root: N={n} must be non-zero and 2N must fit in u32"
    );
    // `q - 1` underflows for q = 0, and q = 1 would satisfy the congruence
    // check vacuously (0 is a multiple of everything).
    assert!(q >= 2, "find_primitive_root: q={q} must be at least 2");
    // Lossless after the range check above.
    let n_u32 = n as u32;
    let two_n = 2 * n_u32;
    let q_minus_1 = q - 1;
    assert!(
        q_minus_1.is_multiple_of(two_n),
        "find_primitive_root: q={q} does not satisfy q ≡ 1 (mod 2N={})",
        two_n
    );
    let prime_factors = small_factor_32(q_minus_1);
    let g = find_generator_32(q, &prime_factors);
    let psi = mod_pow_32(g, q_minus_1 / two_n, q);
    debug_assert_eq!(mod_pow_32(psi, two_n, q), 1, "ψ^(2N) ≠ 1: not a 2N-th root");
    debug_assert_eq!(
        mod_pow_32(psi, n_u32, q),
        q_minus_1,
        "ψ^N ≠ -1: not a PRIMITIVE 2N-th root"
    );
    psi
}
/// Reverses the lowest `bits` bits of `x`; any higher bits of `x` are dropped.
///
/// For example `bit_reverse(0b110, 3) == 0b011`. With `bits == 0` the result
/// is `0`.
///
/// `bits` must be at most 32; debug builds assert this.
#[inline]
pub(crate) fn bit_reverse(x: u32, bits: u32) -> u32 {
    debug_assert!(bits <= u32::BITS, "bit_reverse: bits={bits} exceeds 32");
    // A plain `>>` by 32 (the `bits == 0` case) is an overflowing shift:
    // it panics in debug builds and leaves the value unshifted in release.
    x.reverse_bits().checked_shr(u32::BITS - bits).unwrap_or(0)
}
/// Precomputed powers of a primitive `2N`-th root of unity modulo `q`, stored
/// in bit-reversed index order.
#[derive(Debug, Clone)]
// NOTE(review): presumably not every field is read by all users of the table;
// confirm before removing this allow.
#[allow(dead_code)]
pub(crate) struct NttRootTable {
    /// Transform size `N` (a power of two, at least 2).
    pub n: usize,
    /// `log2(N)`, computed as `n.trailing_zeros()`.
    pub log_n: u32,
    /// Modulus, below `2^28`.
    pub q: u32,
    /// `root_powers[i] = ψ^bit_reverse(i, log_n) mod q`, where `ψ` is the root
    /// returned by `find_primitive_root(n, q)`.
    pub root_powers: Vec<u32>,
    /// Same layout as `root_powers`, built from `mod_inv_32(ψ, q)` instead of `ψ`.
    pub inv_root_powers: Vec<u32>,
    /// `mod_inv_32(N, q)`.
    pub n_inv: u32,
}

impl NttRootTable {
    /// Builds the forward and inverse root tables for size `n` and modulus `q`.
    ///
    /// `q` is expected to be prime; primality is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two of at least 2, if `2 * n` does not
    /// fit in a `u32`, if `q` is outside `2..2^28`, or if `q - 1` is not a
    /// multiple of `2 * n`.
    pub fn new(n: usize, q: u32) -> Self {
        assert!(n >= 2 && n.is_power_of_two(), "N must be a power of 2 >= 2");
        assert!(q < (1u32 << 28), "q={q} must be < 2^28");
        // `q - 1` underflows for q = 0, and q = 1 would satisfy the congruence
        // check vacuously (0 is a multiple of everything).
        assert!(q >= 2, "q={q} must be at least 2");
        // Guard the `as u32` narrowing below against silent truncation.
        assert!(
            n as u64 <= u64::from(u32::MAX / 2),
            "N={n} is too large: 2N must fit in u32"
        );
        let n_u32 = n as u32;
        let two_n = 2 * n_u32;
        let log_n = n.trailing_zeros();
        assert!(
            (q - 1).is_multiple_of(two_n),
            "q={q} does not satisfy q ≡ 1 (mod 2N={})",
            two_n
        );
        let psi = find_primitive_root(n, q);
        let psi_inv = mod_inv_32(psi, q);
        let n_inv = mod_inv_32(n_u32, q);
        let mut root_powers = vec![0u32; n];
        let mut inv_root_powers = vec![0u32; n];
        let slots = root_powers.iter_mut().zip(inv_root_powers.iter_mut());
        for (i, (fwd, inv)) in slots.enumerate() {
            // Entry `i` holds the power whose exponent is `i` bit-reversed.
            let exp = bit_reverse(i as u32, log_n);
            *fwd = mod_pow_32(psi, exp, q);
            *inv = mod_pow_32(psi_inv, exp, q);
        }
        Self {
            n,
            log_n,
            q,
            root_powers,
            inv_root_powers,
            n_inv,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks small non-primes and primes, plus one prime just below 2^28.
    #[test]
    fn test_is_prime_32() {
        for non_prime in [0u32, 1, 4, 9] {
            assert!(!is_prime_32(non_prime));
        }
        for prime in [2u32, 3, 5, 7, 13, 268_435_399] {
            assert!(is_prime_32(prime));
        }
    }

    /// Every generated prime must be below 2^28, prime, congruent to 1 mod 2N,
    /// and distinct from the others.
    #[test]
    fn test_generate_primes_28() {
        for n in [16usize, 64, 1024, 4096] {
            let primes = generate_primes_28(n, 5);
            assert_eq!(primes.len(), 5);
            let two_n = 2 * n as u32;
            for &p in &primes {
                assert!(p < (1u32 << 28), "Prime {p} >= 2^28");
                assert!(is_prime_32(p), "{p} is not prime");
                assert_eq!(
                    (p - 1) % two_n,
                    0,
                    "Prime {p} is not NTT-friendly for N={n}"
                );
            }
            // Compare each prime against all later ones.
            for (idx, &first) in primes.iter().enumerate() {
                for &later in &primes[idx + 1..] {
                    assert_ne!(first, later, "Duplicate primes");
                }
            }
        }
    }

    /// The returned root must satisfy ψ^(2N) = 1 and ψ^N = q - 1.
    #[test]
    fn test_find_primitive_root() {
        for n in [16usize, 64, 1024] {
            let q = generate_primes_28(n, 1)[0];
            let psi = find_primitive_root(n, q);
            let n_u32 = n as u32;
            assert_eq!(
                mod_pow_32(psi, 2 * n_u32, q),
                1,
                "ψ^(2N) ≠ 1 for N={n}, q={q}"
            );
            assert_eq!(
                mod_pow_32(psi, n_u32, q),
                q - 1,
                "ψ^N ≠ -1 for N={n}, q={q}"
            );
        }
    }
}