dcrypt_algorithms/poly/ntt/mod.rs
1//! Number Theoretic Transform Implementation
2//!
3//! Generic NTT/iNTT for polynomials over finite fields, with full
4//! FIPS-204 compliance for Dilithium and support for Kyber variants.
5//!
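//! ## Dilithium (FIPS-204) — selected when `M::ZETAS` is non-empty
//! - Forward NTT: Algorithm 41 with Cooley-Tukey butterflies; standard-domain input and output
//! - Inverse NTT: Algorithm 42 with Gentleman-Sande butterflies; standard-domain input,
//!   output domain chosen by `POST_INVNTT_MODE`
//! - Twiddle factors: Precomputed in Montgomery form (ζ·R mod q)
//! - Butterfly differences: computed as `a - b + Q` without an immediate reduction
//!   (in the forward transform they can exceed 2Q between layers); all outputs are
//!   fully reduced to [0, Q)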
11//! - Pointwise multiplication: Standard domain multiplication
12//!
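//! ## Kyber — selected when `M::ZETAS` is empty
//! - Cooley-Tukey forward and Gentleman-Sande inverse butterflies, with twiddles
//!   computed on the fly from `M::ZETA`
//! - Full Montgomery domain processing (the forward output stays in Montgomery form)
//! - Pointwise multiplication: Montgomery domain multiplication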
17
18#![cfg_attr(not(feature = "std"), no_std)]
19
20use super::params::{Modulus, NttModulus, PostInvNtt};
21use super::polynomial::Polynomial;
22use crate::error::{Error, Result};
23
24/// Modular exponentiation in standard domain
25#[inline(always)]
26fn pow_mod<M: Modulus>(mut base: u32, mut exp: u32) -> u32 {
27 let mut acc: u32 = 1;
28 while exp != 0 {
29 if (exp & 1) == 1 {
30 acc = ((acc as u64 * base as u64) % M::Q as u64) as u32;
31 }
32 base = ((base as u64 * base as u64) % M::Q as u64) as u32;
33 exp >>= 1;
34 }
35 acc
36}
37
38/// Forward Number Theoretic Transform
39pub trait NttOperator<M: NttModulus> {
40 /// Performs forward NTT on polynomial in-place
41 ///
42 /// # Dilithium (FIPS-204)
43 /// - Implements Algorithm 41 (DIF)
44 /// - Input: coefficients in standard domain
45 /// - Output: coefficients in standard domain
46 ///
47 /// # Kyber
48 /// - Implements Cooley-Tukey NTT
49 /// - Converts to Montgomery domain internally
50 fn ntt(poly: &mut Polynomial<M>) -> Result<()>;
51}
52
53/// Inverse Number Theoretic Transform
54pub trait InverseNttOperator<M: NttModulus> {
55 /// Performs inverse NTT on polynomial in-place
56 ///
57 /// # Dilithium (FIPS-204)
58 /// - Implements Algorithm 42 (GS)
59 /// - Input: coefficients in standard domain
60 /// - Output: standard or Montgomery domain based on POST_INVNTT_MODE
61 ///
62 /// # Kyber
63 /// - Implements Cooley-Tukey inverse NTT
64 /// - Scales by N^(-1) and converts back to standard domain
65 fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()>;
66}
67
/// Stateless NTT engine implementing both [`NttOperator`] and
/// [`InverseNttOperator`].
///
/// The concrete algorithm is selected per modulus: a non-empty `M::ZETAS` table
/// selects the FIPS-204 path (Cooley-Tukey forward, Gentleman-Sande inverse);
/// an empty table selects the path that derives its twiddles from `M::ZETA`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CooleyTukeyNtt;
70
71/// Montgomery reduction: computes a * R^-1 mod Q
72///
73/// For a ∈ [0, Q·R), returns a·R^(-1) mod Q in [0, Q)
74#[inline(always)]
75pub fn montgomery_reduce<M: NttModulus>(a: u64) -> u32 {
76 let q = M::Q as u64;
77 let neg_qinv = M::NEG_QINV as u64;
78
79 // Compute m = (a * NEG_QINV) mod 2^32
80 let m = ((a as u32) as u64).wrapping_mul(neg_qinv) & 0xFFFFFFFF;
81 // Compute t = (a + m * q) >> 32
82 let t = a.wrapping_add(m.wrapping_mul(q)) >> 32;
83
84 // Conditional reduction
85 let result = t as u32;
86 let mask = ((result >= M::Q) as u32).wrapping_neg();
87 result.wrapping_sub(M::Q & mask)
88}
89
90/// Reduce any u32 to [0, Q)
91/// Handles both normal range and wrapped values from underflow
92#[inline]
93fn reduce_to_q<M: Modulus>(x: u32) -> u32 {
94 // Fast path for common case (x < 4Q)
95 let mut y = x;
96 y -= M::Q & ((y >= M::Q) as u32).wrapping_neg();
97 y -= M::Q & ((y >= M::Q) as u32).wrapping_neg();
98
99 if y < M::Q {
100 return y;
101 }
102
103 // Barrett reduction for large/wrapped values
104 let (mu, k) = if M::BARRETT_MU != 0 {
105 (M::BARRETT_MU, M::BARRETT_K)
106 } else {
107 // Dynamic computation for moduli without precomputed constants
108 let log_q = 64 - (M::Q as u64).leading_zeros(); // FIXED: Removed unnecessary cast
109 let k = log_q + 32;
110 let mu = (1u128 << k) / M::Q as u128; // FIXED: Removed unnecessary cast
111 (mu, k)
112 };
113
114 let x_wide = y as u128;
115 let q = ((x_wide * mu) >> k) as u32;
116 let mut r = y.wrapping_sub(q.wrapping_mul(M::Q));
117
118 r = r.wrapping_sub(M::Q & ((r >= M::Q) as u32).wrapping_neg());
119 r
120}
121
122/// Montgomery multiplication: a * b * R^-1 mod Q
123/// Accepts extended range inputs (e.g., [0, 9Q)) to preserve sign encoding
124#[inline(always)]
125fn montgomery_mul<M: NttModulus>(a: u32, b: u32) -> u32 {
126 montgomery_reduce::<M>((a as u64) * (b as u64))
127}
128
129/// Modular addition with full reduction
130#[inline(always)]
131fn add_mod<M: Modulus>(a: u32, b: u32) -> u32 {
132 ((a as u64 + b as u64) % M::Q as u64) as u32
133}
134
135/// Fast modular addition for inputs < Q
136#[inline(always)]
137fn add_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
138 let s = a + b;
139 let mask = ((s >= M::Q) as u32).wrapping_neg();
140 s - (M::Q & mask)
141}
142
143/// Fast modular subtraction for inputs < Q
144#[inline(always)]
145fn sub_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
146 let t = a.wrapping_add(M::Q).wrapping_sub(b);
147 let mask = ((t >= M::Q) as u32).wrapping_neg();
148 t - (M::Q & mask)
149}
150
151/// Modular subtraction returning [0, 2Q)
152/// Used in FIPS-204 butterflies to preserve sign information
153#[inline(always)]
154fn sub_mod_upto_2q<M: Modulus>(a: u32, b: u32) -> u32 {
155 a.wrapping_add(M::Q).wrapping_sub(b)
156}
157
158/// Convert standard domain to Montgomery domain
159#[inline(always)]
160fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
161 ((val as u64 * M::MONT_R as u64) % M::Q as u64) as u32
162}
163
164impl<M: NttModulus> NttOperator<M> for CooleyTukeyNtt {
165 fn ntt(poly: &mut Polynomial<M>) -> Result<()> {
166 let n = M::N;
167 if n & (n - 1) != 0 {
168 return Err(Error::Parameter {
169 name: "NTT".into(),
170 reason: "Polynomial degree must be a power of 2".into(),
171 });
172 }
173
174 let coeffs = poly.as_mut_coeffs_slice();
175 let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
176
177 if is_dilithium {
178 // FIPS-204 Algorithm 41: Forward NTT
179 // Decimation-in-Frequency (DIF) with row-major twiddle traversal
180 // Input: standard domain, Output: standard domain
181 let mut k = 0;
182 let mut len = n / 2; // Start at 128 for N=256
183
184 while len >= 1 {
185 // Row-major (block-first) iteration matches twiddle table order
186 for start in (0..n).step_by(2 * len) {
187 let zeta = M::ZETAS[k]; // ζ·R mod q (Montgomery form)
188 k += 1;
189
190 for j in start..start + len {
191 let a = coeffs[j];
192 let b = coeffs[j + len];
193
194 // FIPS-204 DIF butterfly:
195 // t = ζ * b (Montgomery mul with ζ·R gives standard domain)
196 let t = montgomery_mul::<M>(b, zeta);
197 // a' = a + t mod q
198 coeffs[j] = add_mod::<M>(a, t);
199 // b' = a - t + Q (kept in [0, 2Q) per Algorithm 41)
200 coeffs[j + len] = sub_mod_upto_2q::<M>(a, t);
201 }
202 }
203
204 len >>= 1;
205 }
206
207 // Reduce all coefficients to [0, Q) for Dilithium compatibility
208 for c in coeffs.iter_mut() {
209 *c = reduce_to_q::<M>(*c);
210 }
211 } else {
212 // Kyber NTT
213 for c in coeffs.iter_mut() {
214 *c = to_montgomery::<M>(*c);
215 }
216
217 let mut len = 1_usize;
218 while len < n {
219 let exp = n / (len << 1);
220 let root_std = pow_mod::<M>(M::ZETA, exp as u32);
221 let root_mont = to_montgomery::<M>(root_std);
222
223 for start in (0..n).step_by(len << 1) {
224 let mut w_mont = M::MONT_R;
225
226 for j in 0..len {
227 let u = coeffs[start + j];
228 let v = montgomery_mul::<M>(coeffs[start + j + len], w_mont);
229
230 coeffs[start + j] = add_mod_fast::<M>(u, v);
231 coeffs[start + j + len] = sub_mod_fast::<M>(u, v);
232
233 w_mont = montgomery_mul::<M>(w_mont, root_mont);
234 }
235 }
236 len <<= 1;
237 }
238 // Kyber: Do NOT reduce here - coefficients must stay in Montgomery form!
239 }
240
241 Ok(())
242 }
243}
244
245impl<M: NttModulus> InverseNttOperator<M> for CooleyTukeyNtt {
246 fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()> {
247 let n = M::N;
248 if n & (n - 1) != 0 {
249 return Err(Error::Parameter {
250 name: "Inverse NTT".into(),
251 reason: "Polynomial degree must be a power of 2".into(),
252 });
253 }
254
255 let coeffs = poly.as_mut_coeffs_slice();
256 let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
257
258 if is_dilithium {
259 // FIPS-204 Algorithm 42: Inverse NTT
260 // Gentleman-Sande (GS) with row-major traversal
261
262 // Pre-condition: ensure coefficients < Q for GS butterflies
263 for c in coeffs.iter_mut() {
264 *c = reduce_to_q::<M>(*c);
265 }
266
267 let mut k = M::ZETAS.len(); // Start after last entry
268 let mut len = 1;
269
270 while len < n {
271 // Row-major iteration matching forward NTT structure
272 for start in (0..n).step_by(2 * len) {
273 k -= 1; // Traverse ZETAS in reverse
274
275 // Use negated forward twiddle for inverse
276 let zeta_fwd = M::ZETAS[k];
277 let zeta = if zeta_fwd == 0 { 0 } else { M::Q - zeta_fwd };
278
279 for j in start..start + len {
280 let t = coeffs[j];
281 let u = coeffs[j + len];
282
283 // FIPS-204 GS butterfly:
284 // Line 13: w_j ← w_j + w_{j+len}
285 coeffs[j] = add_mod::<M>(t, u);
286 // Line 14: w_{j+len} ← ζ^(-1) * (w_j - w_{j+len})
287 let diff = sub_mod_upto_2q::<M>(t, u);
288 coeffs[j + len] = montgomery_mul::<M>(diff, zeta);
289 }
290 }
291
292 len <<= 1;
293 }
294
295 // Final reduction before N^(-1) scaling
296 for c in coeffs.iter_mut() {
297 *c = reduce_to_q::<M>(*c);
298 }
299
300 // Scale by N^(-1) in standard domain
301 let n_inv_std = pow_mod::<M>(M::N as u32, M::Q - 2);
302 for c in coeffs.iter_mut() {
303 *c = ((*c as u64 * n_inv_std as u64) % M::Q as u64) as u32;
304 }
305
306 match M::POST_INVNTT_MODE {
307 PostInvNtt::Standard => {} // Already in standard domain
308 PostInvNtt::Montgomery => {
309 // Convert to Montgomery if requested
310 for c in coeffs.iter_mut() {
311 *c = to_montgomery::<M>(*c);
312 }
313 }
314 }
315 } else {
316 // Kyber Inverse NTT
317 let root_inv_std = pow_mod::<M>(M::ZETA, M::Q - 2); // FIXED: Removed unnecessary cast
318
319 let mut len = n >> 1;
320 while len >= 1 {
321 let exp = n / (len << 1);
322 let root_std = pow_mod::<M>(root_inv_std, exp as u32);
323 let root_mont = to_montgomery::<M>(root_std);
324
325 for start in (0..n).step_by(len << 1) {
326 let mut w_mont = M::MONT_R;
327
328 for j in 0..len {
329 let u = coeffs[start + j];
330 let v = coeffs[start + j + len];
331
332 coeffs[start + j] = add_mod_fast::<M>(u, v);
333 coeffs[start + j + len] =
334 montgomery_mul::<M>(sub_mod_fast::<M>(u, v), w_mont);
335
336 w_mont = montgomery_mul::<M>(w_mont, root_mont);
337 }
338 }
339 len >>= 1;
340 }
341
342 // Scale by N^(-1)
343 for c in coeffs.iter_mut() {
344 *c = montgomery_mul::<M>(*c, M::N_INV);
345 }
346
347 if M::POST_INVNTT_MODE == PostInvNtt::Standard {
348 for c in coeffs.iter_mut() {
349 *c = montgomery_reduce::<M>(*c as u64);
350 }
351 }
352 }
353
354 Ok(())
355 }
356}
357
358/// Extension methods for Polynomial
359impl<M: NttModulus> Polynomial<M> {
360 /// Convert polynomial to NTT domain
361 pub fn ntt_inplace(&mut self) -> Result<()> {
362 CooleyTukeyNtt::ntt(self)
363 }
364
365 /// Convert polynomial from NTT domain
366 pub fn from_ntt_inplace(&mut self) -> Result<()> {
367 CooleyTukeyNtt::inv_ntt(self)
368 }
369
370 /// Pointwise multiplication in NTT domain
371 ///
372 /// Both polynomials must already be in NTT domain.
373 /// For Dilithium: inputs/output in standard domain (post-NTT)
374 /// For Kyber: inputs/output in Montgomery domain
375 pub fn ntt_mul(&self, other: &Self) -> Self {
376 let mut result = Self::zero();
377 let n = M::N;
378 let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
379
380 if is_dilithium {
381 // Dilithium: coefficients are in standard domain after NTT
382 // Use standard multiplication
383 for i in 0..n {
384 result.coeffs[i] =
385 ((self.coeffs[i] as u64 * other.coeffs[i] as u64) % M::Q as u64) as u32;
386 }
387 } else {
388 // Kyber: coefficients are in Montgomery domain after NTT
389 // Use Montgomery multiplication to keep result in Montgomery domain
390 for i in 0..n {
391 result.coeffs[i] = montgomery_mul::<M>(self.coeffs[i], other.coeffs[i]);
392 }
393 }
394
395 result
396 }
397}
398
399#[cfg(test)]
400mod tests;