ntt/
lib.rs

1use reikna::totient::totient;
2use reikna::factor::quick_factorize;
3use std::collections::HashMap;
4
/// Modular addition: returns `(a + b) mod p` as the canonical residue in `[0, p)`.
///
/// The sum is formed in `i128`, so it cannot overflow even when `p` is close to
/// `i64::MAX`, and negative operands still yield a non-negative result.
///
/// # Panics
/// Panics if `p == 0`.
fn mod_add(a: i64, b: i64, p: i64) -> i64 {
    (i128::from(a) + i128::from(b)).rem_euclid(i128::from(p)) as i64
}
9
/// Modular multiplication: returns `(a * b) mod p` as the canonical residue in `[0, p)`.
///
/// The product is formed in `i128`. A plain `i64` product overflows as soon as both
/// factors exceed roughly `3.03e9`, i.e. for any modulus above about `2^31.5`.
///
/// # Panics
/// Panics if `p == 0`.
fn mod_mul(a: i64, b: i64, p: i64) -> i64 {
    (i128::from(a) * i128::from(b)).rem_euclid(i128::from(p)) as i64
}
14
/// Modular exponentiation by repeated squaring.
///
/// # Arguments
///
/// * `base` - Base of the exponentiation; may be negative or unreduced.
/// * `exp` - Non-negative exponent.
/// * `p` - Positive modulus (need not be prime).
///
/// # Returns
/// `base^exp mod p` as the canonical residue in `[0, p)`. In particular the result is
/// `0` when `p == 1`, and `1 mod p` when `exp == 0`.
///
/// # Panics
/// Panics if `exp` is negative or `p == 0`.
pub fn mod_exp(base: i64, exp: i64, p: i64) -> i64 {
    assert!(exp >= 0, "mod_exp: negative exponent {}", exp);
    let modulus = i128::from(p);
    // Intermediates stay in i128 so that squaring never overflows for any i64 modulus.
    let mut acc = 1i128.rem_euclid(modulus);
    let mut square = i128::from(base).rem_euclid(modulus);
    let mut remaining = exp;
    while remaining > 0 {
        if remaining & 1 == 1 {
            acc = acc * square % modulus;
        }
        square = square * square % modulus;
        remaining >>= 1;
    }
    acc as i64
}
36
/// Extended Euclidean algorithm, written without recursion.
///
/// # Arguments
///
/// * `a` - First number.
/// * `b` - Second number.
///
/// # Returns
/// `(g, x, y)` with `a * x + b * y == g`, where `g` is the last non-zero remainder of
/// Euclid's sequence. `|g|` is the gcd; its sign follows Rust's truncating `%`. When
/// `b == 0` the result is `(a, 1, 0)`.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    // Forward pass: run Euclid's algorithm and remember every quotient.
    let mut quotients = Vec::new();
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let rem = x % y;
        quotients.push(x / y);
        x = y;
        y = rem;
    }
    // Backward pass: start from the trivial identity g = g*1 + 0*0 and unwind the
    // quotients deepest-first, mapping (s, t) -> (t, s - q*t) at each level.
    let (s, t) = quotients
        .iter()
        .rev()
        .fold((1, 0), |(s, t), &q| (t, s - q * t));
    (x, s, t)
}
53
54/// Compute the modular inverse of a modulo modulus
55fn mod_inv(a: i64, modulus: i64) -> i64 {
56    let (gcd, x, _) = extended_gcd(a, modulus);
57    if gcd != 1 {
58        panic!("{} and {} are not coprime, no inverse exists", a, modulus);
59    }
60    (x % modulus + modulus) % modulus  // Ensure a positive result
61}
62
63/// Compute n-th root of unity (omega) for p not necessarily prime
64/// # Arguments
65///
66/// * `modulus` - Modulus. n must divide each prime power factor.
67/// * `n` - Order of the root of unity.
68/// 
69/// # Returns
70/// The n-th root of unity modulo `modulus`.
71///
72/// # Examples
73///
74/// ```
75/// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity.
76/// let modulus = 17 * 17;
77/// let n = 8;
78/// let omega = ntt::omega(modulus, n);
79/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
80/// 
81/// // For modulus = 17*41*73, we compute and verify an 8th root of unity.
82/// let modulus = 17*41*73;
83/// let omega = ntt::omega(modulus, n);
84/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
85/// ```
86pub fn omega(modulus: i64, n: usize) -> i64 {
87    let factors = factorize(modulus as i64);
88    if factors.len() == 1 {
89        let (p, e) = factors.into_iter().next().unwrap();
90        let root = primitive_root(p, e); // primitive root mod p
91        let grp_size = totient(modulus as u64) as i64;
92        assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size);
93        return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function
94    }
95    else {
96        return root_of_unity(modulus, n as i64)
97    }
98}
99
100/// Forward transform using NTT, output bit-reversed
101/// # Arguments
102///
103/// * `a` - Input vector.
104/// * `omega` - Primitive root of unity modulo `p`.
105/// * `n` - Length of the input vector and the result.
106/// * `p` - Prime modulus for the operations.
107///
108/// # Returns
109/// A vector representing the NTT of the input vector.
110///
111/// # Examples
112///
113/// ```
114/// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
115/// let n: usize = 8;  // Length of the NTT (must be a power of 2)
116/// let omega = ntt::omega(modulus, n); // n-th root of unity
117/// let mut a = vec![1, 2, 3, 4];
118/// a.resize(n, 0);
119/// // Perform the forward NTT
120/// let a_ntt = ntt::ntt(&a, omega, n, modulus);
121/// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15];
122/// assert_eq!(a_ntt, a_ntt_expected);
123pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
124    let mut result = a.to_vec();
125    let mut step = n/2;
126	while step > 0 {
127		let w_i  = mod_exp(omega, (n/(2*step)).try_into().unwrap(), p);
128		for i in (0..n).step_by(2*step) { 
129			let mut w = 1;
130			for j in 0..step {
131				let u = result[i+j];
132				let v = result[i+j+step];
133				result[i+j] = mod_add(u,v,p);
134				result[i+j+step] = mod_mul(mod_add(u,p-v,p),w,p);
135				w = mod_mul(w,w_i,p);
136			}
137		}
138		step/=2;
139	}
140	result
141}
142
143/// Inverse transform using INTT, input bit-reversed
144/// # Arguments
145/// 
146/// * `a` - Input vector (bit-reversed).
147/// * `omega` - Primitive root of unity modulo `p`.
148/// * `n` - Length of the input vector and the result.
149/// * `p` - Prime modulus for the operations.
150///
151/// # Returns
152/// A vector representing the inverse NTT of the input vector.
153pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
154    let omega_inv = mod_inv(omega, p);
155    let n_inv = mod_inv(n as i64, p);
156    let mut result = a.to_vec();
157    let mut step = 1;
158	while step < n {  
159		let w_i = mod_exp(omega_inv, (n/(2*step)).try_into().unwrap(), p);
160		for i in (0..n).step_by(2*step) { 
161			let mut w = 1;
162			for j in 0..step {
163				let u = result[i+j];
164				let v = mod_mul(result[i+j+step],w,p);
165				result[i+j] = mod_add(u,v,p);
166				result[i+j+step] = mod_add(u,p-v,p);
167				w = mod_mul(w,w_i,p);
168			}
169		}
170		step*=2;
171	}
172	result
173		.iter()
174        .map(|x| mod_mul(*x,n_inv,p))
175        .collect()
176}
177
178/// Naive polynomial multiplication
179/// # Arguments
180///
181/// * `a` - First polynomial (as a vector of coefficients).
182/// * `b` - Second polynomial (as a vector of coefficients).
183/// * `n` - Length of the polynomials and the result.
184/// * `p` - Prime modulus for the operations.
185///
186/// # Returns
187/// A vector representing the polynomial product modulo `p`.
188pub fn polymul(a: &Vec<i64>, b: &Vec<i64>, n: i64, p: i64) -> Vec<i64> {
189    let mut result = vec![0; n as usize];
190    for i in 0..a.len() {
191        for j in 0..b.len() {
192            result[(i + j) % n as usize] = mod_add(result[(i + j) % n as usize], mod_mul(a[i], b[j], p), p);
193        }
194    }
195    result
196}
197
198/// Multiply two polynomials using NTT (Number Theoretic Transform)
199/// 
200/// # Arguments
201/// 
202/// * `a` - First polynomial (as a vector of coefficients).
203/// * `b` - Second polynomial (as a vector of coefficients).
204/// * `n` - Length of the polynomials and the NTT (must be a power of 2).
205/// * `p` - Prime modulus for the operations.
206/// * `root` - Primitive root of unity modulo `p`.
207///
208/// # Returns
209/// A vector representing the polynomial product modulo `p`.
210pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i64> {
211
212    // Step 1: Perform the NTT (forward transform) on both polynomials
213    let a_ntt = ntt(a, omega, n, p);
214    let b_ntt = ntt(b, omega, n, p);
215    
216    // Step 2: Perform pointwise multiplication in the NTT domain
217    let c_ntt: Vec<i64> = a_ntt
218        .iter()
219        .zip(b_ntt.iter())
220        .map(|(x, y)| mod_mul(*x, *y, p)) // pointwise multiplication
221        .collect();
222
223    // Step 3: Apply the inverse NTT to get the result
224    let c = intt(&c_ntt, omega, n, p);
225    
226    c
227}
228
229/// Compute the prime factorization of `n` (with multiplicities)
230/// Uses reikna::quick_factorize internally
231/// # Arguments
232/// 
233/// * `n` - Number to factorize.
234/// 
235/// # Returns
236/// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
237fn factorize(n: i64) -> HashMap<i64, u32> {
238    let mut factors = HashMap::new();
239    for factor in quick_factorize(n as u64) {
240        *factors.entry(factor as i64).or_insert(0) += 1;
241    }
242    factors
243}
244
245/// Fast computation of a primitive root mod p^e
246/// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p
247/// # Arguments
248///
249/// * `p` - Prime modulus.
250/// * `e` - Exponent.
251///
252/// # Returns
253/// A primitive root modulo `p^e`.
254/// 
255/// # Examples
256///
257/// ```
258/// // For p = 17 and e = 2, we compute a primitive root modulo 289.
259/// let p = 17;
260/// let e = 2;
261/// let g = ntt::primitive_root(p, e);
262/// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1);
263pub fn primitive_root(p: i64, e: u32) -> i64 {
264    let g = primitive_root_mod_p(p);
265    let mut g_lifted = g; // Lift it to p^e
266    for _ in 1..e {
267        if mod_exp(g_lifted, p-1, p.pow(e)) == 1 {
268            g_lifted += p.pow(e - 1);
269        }
270    }
271    g_lifted
272}
273
274/// Finds a primitive root modulo a prime p
275/// # Arguments
276///
277/// * `p` - Prime modulus.
278///
279/// # Returns
280/// A primitive root modulo `p`.
281fn primitive_root_mod_p(p: i64) -> i64 {
282    let phi = p - 1;
283    let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities
284    for g in 2..p {
285        // Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi
286        if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) {
287            return g;
288        }
289    }
290    0 // Should never happen
291}
292
293/// the Chinese remainder theorem for two moduli
294/// # Arguments
295///
296/// * `a1` - First residue.
297/// * `n1` - First modulus.
298/// * `a2` - Second residue.
299/// * `n2` - Second modulus.
300///
301/// # Returns
302/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
303pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
304    let n = n1 * n2;
305    let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2
306    let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1
307    let x = (a1 * m2 * n2 + a2 * m1 * n1) % n;
308    if x < 0 { x + n } else { x }
309}
310
311/// computes an n^th root of unity modulo a composite modulus
312/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
313/// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus
314/// for the NTT, we require than a 2n^th root of unity exists
315/// # Arguments
316///
317/// * `modulus` - Modulus. n must divide each prime power factor.
318/// * `n` - Order of the root of unity.
319///
320/// # Returns
321/// The n-th root of unity modulo `modulus`.
322pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
323    let factors = factorize(modulus);
324    let mut result = 1;
325    for (&p, &e) in factors.iter() {
326		let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e
327        result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT
328	}
329	result
330}
331
332/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
333/// # Arguments
334///
335/// * `omega` - n-th root of unity.
336/// * `n` - Order of the root of unity.
337/// * `modulus` - Modulus.
338///
339/// # Returns
340/// True if the root of unity satisfies the condition.
341pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool {
342    assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity");
343    assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)");
344    true
345}