ntt/lib.rs
1use reikna::totient::totient;
2use reikna::factor::quick_factorize;
3use std::collections::HashMap;
4
/// Modular addition: returns `(a + b) mod p`, reduced into `[0, |p|)`.
///
/// The sum is formed in `i128`, so it cannot overflow even when `a` and `b`
/// are close to `i64::MAX`, and `rem_euclid` keeps the result non-negative
/// when an input is negative.
///
/// # Panics
/// Panics if `p == 0`.
fn mod_add(a: i64, b: i64, p: i64) -> i64 {
    let sum = (i128::from(a) + i128::from(b)).rem_euclid(i128::from(p));
    // `sum` lies in [0, |p|), which always fits in an i64.
    sum as i64
}
9
/// Modular multiplication: returns `(a * b) mod p`, reduced into `[0, |p|)`.
///
/// The product of two `i64` values always fits in an `i128`, so computing it
/// there avoids the overflow that `a * b` in `i64` hits once `p` exceeds
/// roughly 3·10^9. `rem_euclid` keeps the result non-negative for negative inputs.
///
/// # Panics
/// Panics if `p == 0`.
fn mod_mul(a: i64, b: i64, p: i64) -> i64 {
    let product = (i128::from(a) * i128::from(b)).rem_euclid(i128::from(p));
    // `product` lies in [0, |p|), which always fits in an i64.
    product as i64
}
14
/// Modular exponentiation by repeated squaring.
///
/// # Arguments
///
/// * `base` - Base of the exponentiation (may be negative; it is reduced into `[0, p)` first).
/// * `exp` - Exponent. Non-positive exponents are treated as 0, i.e. the result is `1 mod p`.
/// * `p` - Modulus for the operations (need not be prime).
///
/// # Returns
/// `base^exp mod p`, in the range `[0, p)` for positive `p`.
///
/// # Panics
/// Panics if `p == 0`.
pub fn mod_exp(base: i64, exp: i64, p: i64) -> i64 {
    let modulus = i128::from(p);
    // Every intermediate product is below p^2 < 2^126, so i128 arithmetic cannot overflow.
    let mut base = i128::from(base).rem_euclid(modulus);
    let mut exp = exp;
    // `1 % modulus` makes the empty product correct for p == 1 (everything is 0 mod 1).
    let mut result = 1 % modulus;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * base % modulus;
        }
        base = base * base % modulus;
        exp >>= 1;
    }
    // `result` is reduced modulo p, so it fits in an i64.
    result as i64
}
36
/// Extended Euclidean algorithm, written iteratively.
///
/// # Arguments
///
/// * `a` - First number.
/// * `b` - Second number.
///
/// # Returns
/// A tuple `(g, x, y)` where `g` is the last non-zero remainder of the Euclidean
/// sequence (the gcd up to sign; it can be negative when an input is negative)
/// and `x`, `y` are Bézout coefficients with `a * x + b * y == g`.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    // Loop invariant: old_r == a*old_x + b*old_y and r == a*x + b*y.
    let (mut old_r, mut r) = (a, b);
    let (mut old_x, mut x) = (1, 0);
    let (mut old_y, mut y) = (0, 1);
    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_x, x) = (x, old_x - q * x);
        (old_y, y) = (y, old_y - q * y);
    }
    (old_r, old_x, old_y)
}
53
54/// Compute the modular inverse of a modulo modulus
55fn mod_inv(a: i64, modulus: i64) -> i64 {
56 let (gcd, x, _) = extended_gcd(a, modulus);
57 if gcd != 1 {
58 panic!("{} and {} are not coprime, no inverse exists", a, modulus);
59 }
60 (x % modulus + modulus) % modulus // Ensure a positive result
61}
62
63/// Compute n-th root of unity (omega) for p not necessarily prime
64/// # Arguments
65///
66/// * `modulus` - Modulus. n must divide each prime power factor.
67/// * `n` - Order of the root of unity.
68///
69/// # Returns
70/// The n-th root of unity modulo `modulus`.
71///
72/// # Examples
73///
74/// ```
75/// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity.
76/// let modulus = 17 * 17;
77/// let n = 8;
78/// let omega = ntt::omega(modulus, n);
79/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
80///
81/// // For modulus = 17*41*73, we compute and verify an 8th root of unity.
82/// let modulus = 17*41*73;
83/// let omega = ntt::omega(modulus, n);
84/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
85/// ```
86pub fn omega(modulus: i64, n: usize) -> i64 {
87 let factors = factorize(modulus as i64);
88 if factors.len() == 1 {
89 let (p, e) = factors.into_iter().next().unwrap();
90 let root = primitive_root(p, e); // primitive root mod p
91 let grp_size = totient(modulus as u64) as i64;
92 assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size);
93 return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function
94 }
95 else {
96 return root_of_unity(modulus, n as i64)
97 }
98}
99
100/// Forward transform using NTT, output bit-reversed
101/// # Arguments
102///
103/// * `a` - Input vector.
104/// * `omega` - Primitive root of unity modulo `p`.
105/// * `n` - Length of the input vector and the result.
106/// * `p` - Prime modulus for the operations.
107///
108/// # Returns
109/// A vector representing the NTT of the input vector.
110///
111/// # Examples
112///
113/// ```
114/// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
115/// let n: usize = 8; // Length of the NTT (must be a power of 2)
116/// let omega = ntt::omega(modulus, n); // n-th root of unity
117/// let mut a = vec![1, 2, 3, 4];
118/// a.resize(n, 0);
119/// // Perform the forward NTT
120/// let a_ntt = ntt::ntt(&a, omega, n, modulus);
121/// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15];
122/// assert_eq!(a_ntt, a_ntt_expected);
123pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
124 let mut result = a.to_vec();
125 let mut step = n/2;
126 while step > 0 {
127 let w_i = mod_exp(omega, (n/(2*step)).try_into().unwrap(), p);
128 for i in (0..n).step_by(2*step) {
129 let mut w = 1;
130 for j in 0..step {
131 let u = result[i+j];
132 let v = result[i+j+step];
133 result[i+j] = mod_add(u,v,p);
134 result[i+j+step] = mod_mul(mod_add(u,p-v,p),w,p);
135 w = mod_mul(w,w_i,p);
136 }
137 }
138 step/=2;
139 }
140 result
141}
142
143/// Inverse transform using INTT, input bit-reversed
144/// # Arguments
145///
146/// * `a` - Input vector (bit-reversed).
147/// * `omega` - Primitive root of unity modulo `p`.
148/// * `n` - Length of the input vector and the result.
149/// * `p` - Prime modulus for the operations.
150///
151/// # Returns
152/// A vector representing the inverse NTT of the input vector.
153pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
154 let omega_inv = mod_inv(omega, p);
155 let n_inv = mod_inv(n as i64, p);
156 let mut result = a.to_vec();
157 let mut step = 1;
158 while step < n {
159 let w_i = mod_exp(omega_inv, (n/(2*step)).try_into().unwrap(), p);
160 for i in (0..n).step_by(2*step) {
161 let mut w = 1;
162 for j in 0..step {
163 let u = result[i+j];
164 let v = mod_mul(result[i+j+step],w,p);
165 result[i+j] = mod_add(u,v,p);
166 result[i+j+step] = mod_add(u,p-v,p);
167 w = mod_mul(w,w_i,p);
168 }
169 }
170 step*=2;
171 }
172 result
173 .iter()
174 .map(|x| mod_mul(*x,n_inv,p))
175 .collect()
176}
177
/// Naive cyclic polynomial multiplication: computes `a * b mod (x^n - 1)` with coefficients mod `p`.
///
/// Coefficient `i + j` of the ordinary product is folded onto index `(i + j) % n`.
/// This is the same cyclic convolution that [`polymul_ntt`] computes, so this
/// O(len(a) · len(b)) version serves as a reference implementation. Pad `n` to at
/// least `len(a) + len(b) - 1` to get the ordinary (non-wrapped) product.
///
/// # Arguments
///
/// * `a` - First polynomial (coefficients, lowest degree first).
/// * `b` - Second polynomial (coefficients, lowest degree first).
/// * `n` - Length of the result, i.e. the wrap-around length of the convolution.
/// * `p` - Modulus for the coefficient arithmetic.
///
/// # Returns
/// A vector of length `n` whose coefficients are reduced into `[0, p)`, including
/// when inputs are negative.
///
/// # Panics
/// Panics if `n` is negative, or if `n == 0` while both `a` and `b` are non-empty.
pub fn polymul(a: &[i64], b: &[i64], n: i64, p: i64) -> Vec<i64> {
    let len = usize::try_from(n).expect("polymul: n must be non-negative");
    let modulus = i128::from(p);
    let mut result = vec![0i64; len];
    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            let slot = &mut result[(i + j) % len];
            // Work in i128 so neither the product nor the running sum can overflow.
            let acc = (i128::from(*slot) + i128::from(ai) * i128::from(bj)).rem_euclid(modulus);
            // `acc` is reduced modulo p, so it fits in an i64.
            *slot = acc as i64;
        }
    }
    result
}
197
198/// Multiply two polynomials using NTT (Number Theoretic Transform)
199///
200/// # Arguments
201///
202/// * `a` - First polynomial (as a vector of coefficients).
203/// * `b` - Second polynomial (as a vector of coefficients).
204/// * `n` - Length of the polynomials and the NTT (must be a power of 2).
205/// * `p` - Prime modulus for the operations.
206/// * `root` - Primitive root of unity modulo `p`.
207///
208/// # Returns
209/// A vector representing the polynomial product modulo `p`.
210pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i64> {
211
212 // Step 1: Perform the NTT (forward transform) on both polynomials
213 let a_ntt = ntt(a, omega, n, p);
214 let b_ntt = ntt(b, omega, n, p);
215
216 // Step 2: Perform pointwise multiplication in the NTT domain
217 let c_ntt: Vec<i64> = a_ntt
218 .iter()
219 .zip(b_ntt.iter())
220 .map(|(x, y)| mod_mul(*x, *y, p)) // pointwise multiplication
221 .collect();
222
223 // Step 3: Apply the inverse NTT to get the result
224 let c = intt(&c_ntt, omega, n, p);
225
226 c
227}
228
229/// Compute the prime factorization of `n` (with multiplicities)
230/// Uses reikna::quick_factorize internally
231/// # Arguments
232///
233/// * `n` - Number to factorize.
234///
235/// # Returns
236/// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
237fn factorize(n: i64) -> HashMap<i64, u32> {
238 let mut factors = HashMap::new();
239 for factor in quick_factorize(n as u64) {
240 *factors.entry(factor as i64).or_insert(0) += 1;
241 }
242 factors
243}
244
245/// Fast computation of a primitive root mod p^e
246/// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p
247/// # Arguments
248///
249/// * `p` - Prime modulus.
250/// * `e` - Exponent.
251///
252/// # Returns
253/// A primitive root modulo `p^e`.
254///
255/// # Examples
256///
257/// ```
258/// // For p = 17 and e = 2, we compute a primitive root modulo 289.
259/// let p = 17;
260/// let e = 2;
261/// let g = ntt::primitive_root(p, e);
262/// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1);
263pub fn primitive_root(p: i64, e: u32) -> i64 {
264 let g = primitive_root_mod_p(p);
265 let mut g_lifted = g; // Lift it to p^e
266 for _ in 1..e {
267 if mod_exp(g_lifted, p-1, p.pow(e)) == 1 {
268 g_lifted += p.pow(e - 1);
269 }
270 }
271 g_lifted
272}
273
274/// Finds a primitive root modulo a prime p
275/// # Arguments
276///
277/// * `p` - Prime modulus.
278///
279/// # Returns
280/// A primitive root modulo `p`.
281fn primitive_root_mod_p(p: i64) -> i64 {
282 let phi = p - 1;
283 let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities
284 for g in 2..p {
285 // Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi
286 if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) {
287 return g;
288 }
289 }
290 0 // Should never happen
291}
292
293/// the Chinese remainder theorem for two moduli
294/// # Arguments
295///
296/// * `a1` - First residue.
297/// * `n1` - First modulus.
298/// * `a2` - Second residue.
299/// * `n2` - Second modulus.
300///
301/// # Returns
302/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
303pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
304 let n = n1 * n2;
305 let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2
306 let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1
307 let x = (a1 * m2 * n2 + a2 * m1 * n1) % n;
308 if x < 0 { x + n } else { x }
309}
310
311/// computes an n^th root of unity modulo a composite modulus
312/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
313/// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus
314/// for the NTT, we require than a 2n^th root of unity exists
315/// # Arguments
316///
317/// * `modulus` - Modulus. n must divide each prime power factor.
318/// * `n` - Order of the root of unity.
319///
320/// # Returns
321/// The n-th root of unity modulo `modulus`.
322pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
323 let factors = factorize(modulus);
324 let mut result = 1;
325 for (&p, &e) in factors.iter() {
326 let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e
327 result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT
328 }
329 result
330}
331
332/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
333/// # Arguments
334///
335/// * `omega` - n-th root of unity.
336/// * `n` - Order of the root of unity.
337/// * `modulus` - Modulus.
338///
339/// # Returns
340/// True if the root of unity satisfies the condition.
341pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool {
342 assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity");
343 assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)");
344 true
345}