dcrypt_algorithms/poly/ntt/
mod.rs

//! Number Theoretic Transform implementation
//!
//! Generic NTT/iNTT for polynomials over finite fields, implementing FIPS 204
//! Algorithms 41/42 for Dilithium and a Kyber-style variant.
//!
//! ## Dilithium (FIPS 204)
//! - Forward NTT: Algorithm 41 (standard-domain input and output)
//! - Inverse NTT: Algorithm 42 (Gentleman–Sande butterflies; standard-domain
//!   input, output domain selected by `POST_INVNTT_MODE`)
//! - Twiddle factors: precomputed in Montgomery form (ζ·R mod q)
//! - Butterfly differences: computed lazily as a + q − b and fully reduced at the end
//! - Pointwise multiplication: standard-domain multiplication
//!
//! ## Kyber
//! - Cooley–Tukey NTT with on-the-fly twiddle computation
//! - Montgomery-domain processing throughout
//! - Pointwise multiplication: Montgomery-domain multiplication
17
18#![cfg_attr(not(feature = "std"), no_std)]
19
20use super::params::{Modulus, NttModulus, PostInvNtt};
21use super::polynomial::Polynomial;
22use crate::error::{Error, Result};
23
24/// Modular exponentiation in standard domain
25#[inline(always)]
26fn pow_mod<M: Modulus>(mut base: u32, mut exp: u32) -> u32 {
27    let mut acc: u32 = 1;
28    while exp != 0 {
29        if (exp & 1) == 1 {
30            acc = ((acc as u64 * base as u64) % M::Q as u64) as u32;
31        }
32        base = ((base as u64 * base as u64) % M::Q as u64) as u32;
33        exp >>= 1;
34    }
35    acc
36}
37
38/// Forward Number Theoretic Transform
39pub trait NttOperator<M: NttModulus> {
40    /// Performs forward NTT on polynomial in-place
41    ///
42    /// # Dilithium (FIPS-204)
43    /// - Implements Algorithm 41 (DIF)
44    /// - Input: coefficients in standard domain
45    /// - Output: coefficients in standard domain
46    ///
47    /// # Kyber
48    /// - Implements Cooley-Tukey NTT
49    /// - Converts to Montgomery domain internally
50    fn ntt(poly: &mut Polynomial<M>) -> Result<()>;
51}
52
53/// Inverse Number Theoretic Transform
54pub trait InverseNttOperator<M: NttModulus> {
55    /// Performs inverse NTT on polynomial in-place
56    ///
57    /// # Dilithium (FIPS-204)
58    /// - Implements Algorithm 42 (GS)
59    /// - Input: coefficients in standard domain
60    /// - Output: standard or Montgomery domain based on POST_INVNTT_MODE
61    ///
62    /// # Kyber
63    /// - Implements Cooley-Tukey inverse NTT
64    /// - Scales by N^(-1) and converts back to standard domain
65    fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()>;
66}
67
/// Zero-sized marker type implementing the forward and inverse NTT for any
/// [`NttModulus`].
///
/// The concrete algorithm (table-driven FIPS 204 vs. Kyber-style on-the-fly
/// twiddles) is selected per modulus from `M::ZETAS`. Despite the name, both
/// inverse transforms use Gentleman–Sande butterflies.
#[derive(Debug, Clone, Copy, Default)]
pub struct CooleyTukeyNtt;
70
71/// Montgomery reduction: computes a * R^-1 mod Q
72///
73/// For a ∈ [0, Q·R), returns a·R^(-1) mod Q in [0, Q)
74#[inline(always)]
75pub fn montgomery_reduce<M: NttModulus>(a: u64) -> u32 {
76    let q = M::Q as u64;
77    let neg_qinv = M::NEG_QINV as u64;
78
79    // Compute m = (a * NEG_QINV) mod 2^32
80    let m = ((a as u32) as u64).wrapping_mul(neg_qinv) & 0xFFFFFFFF;
81    // Compute t = (a + m * q) >> 32
82    let t = a.wrapping_add(m.wrapping_mul(q)) >> 32;
83
84    // Conditional reduction
85    let result = t as u32;
86    let mask = ((result >= M::Q) as u32).wrapping_neg();
87    result.wrapping_sub(M::Q & mask)
88}
89
90/// Reduce any u32 to [0, Q)
91/// Handles both normal range and wrapped values from underflow
92#[inline]
93fn reduce_to_q<M: Modulus>(x: u32) -> u32 {
94    // Fast path for common case (x < 4Q)
95    let mut y = x;
96    y -= M::Q & ((y >= M::Q) as u32).wrapping_neg();
97    y -= M::Q & ((y >= M::Q) as u32).wrapping_neg();
98
99    if y < M::Q {
100        return y;
101    }
102
103    // Barrett reduction for large/wrapped values
104    let (mu, k) = if M::BARRETT_MU != 0 {
105        (M::BARRETT_MU, M::BARRETT_K)
106    } else {
107        // Dynamic computation for moduli without precomputed constants
108        let log_q = 64 - (M::Q as u64).leading_zeros(); // FIXED: Removed unnecessary cast
109        let k = log_q + 32;
110        let mu = (1u128 << k) / M::Q as u128; // FIXED: Removed unnecessary cast
111        (mu, k)
112    };
113
114    let x_wide = y as u128;
115    let q = ((x_wide * mu) >> k) as u32;
116    let mut r = y.wrapping_sub(q.wrapping_mul(M::Q));
117
118    r = r.wrapping_sub(M::Q & ((r >= M::Q) as u32).wrapping_neg());
119    r
120}
121
122/// Montgomery multiplication: a * b * R^-1 mod Q
123/// Accepts extended range inputs (e.g., [0, 9Q)) to preserve sign encoding
124#[inline(always)]
125fn montgomery_mul<M: NttModulus>(a: u32, b: u32) -> u32 {
126    montgomery_reduce::<M>((a as u64) * (b as u64))
127}
128
129/// Modular addition with full reduction
130#[inline(always)]
131fn add_mod<M: Modulus>(a: u32, b: u32) -> u32 {
132    ((a as u64 + b as u64) % M::Q as u64) as u32
133}
134
135/// Fast modular addition for inputs < Q
136#[inline(always)]
137fn add_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
138    let s = a + b;
139    let mask = ((s >= M::Q) as u32).wrapping_neg();
140    s - (M::Q & mask)
141}
142
143/// Fast modular subtraction for inputs < Q
144#[inline(always)]
145fn sub_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
146    let t = a.wrapping_add(M::Q).wrapping_sub(b);
147    let mask = ((t >= M::Q) as u32).wrapping_neg();
148    t - (M::Q & mask)
149}
150
151/// Modular subtraction returning [0, 2Q)
152/// Used in FIPS-204 butterflies to preserve sign information
153#[inline(always)]
154fn sub_mod_upto_2q<M: Modulus>(a: u32, b: u32) -> u32 {
155    a.wrapping_add(M::Q).wrapping_sub(b)
156}
157
158/// Convert standard domain to Montgomery domain
159#[inline(always)]
160fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
161    ((val as u64 * M::MONT_R as u64) % M::Q as u64) as u32
162}
163
164impl<M: NttModulus> NttOperator<M> for CooleyTukeyNtt {
165    fn ntt(poly: &mut Polynomial<M>) -> Result<()> {
166        let n = M::N;
167        if n & (n - 1) != 0 {
168            return Err(Error::Parameter {
169                name: "NTT".into(),
170                reason: "Polynomial degree must be a power of 2".into(),
171            });
172        }
173
174        let coeffs = poly.as_mut_coeffs_slice();
175        let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
176
177        if is_dilithium {
178            // FIPS-204 Algorithm 41: Forward NTT
179            // Decimation-in-Frequency (DIF) with row-major twiddle traversal
180            // Input: standard domain, Output: standard domain
181            let mut k = 0;
182            let mut len = n / 2; // Start at 128 for N=256
183
184            while len >= 1 {
185                // Row-major (block-first) iteration matches twiddle table order
186                for start in (0..n).step_by(2 * len) {
187                    let zeta = M::ZETAS[k]; // ζ·R mod q (Montgomery form)
188                    k += 1;
189
190                    for j in start..start + len {
191                        let a = coeffs[j];
192                        let b = coeffs[j + len];
193
194                        // FIPS-204 DIF butterfly:
195                        // t = ζ * b (Montgomery mul with ζ·R gives standard domain)
196                        let t = montgomery_mul::<M>(b, zeta);
197                        // a' = a + t mod q
198                        coeffs[j] = add_mod::<M>(a, t);
199                        // b' = a - t + Q (kept in [0, 2Q) per Algorithm 41)
200                        coeffs[j + len] = sub_mod_upto_2q::<M>(a, t);
201                    }
202                }
203
204                len >>= 1;
205            }
206
207            // Reduce all coefficients to [0, Q) for Dilithium compatibility
208            for c in coeffs.iter_mut() {
209                *c = reduce_to_q::<M>(*c);
210            }
211        } else {
212            // Kyber NTT
213            for c in coeffs.iter_mut() {
214                *c = to_montgomery::<M>(*c);
215            }
216
217            let mut len = 1_usize;
218            while len < n {
219                let exp = n / (len << 1);
220                let root_std = pow_mod::<M>(M::ZETA, exp as u32);
221                let root_mont = to_montgomery::<M>(root_std);
222
223                for start in (0..n).step_by(len << 1) {
224                    let mut w_mont = M::MONT_R;
225
226                    for j in 0..len {
227                        let u = coeffs[start + j];
228                        let v = montgomery_mul::<M>(coeffs[start + j + len], w_mont);
229
230                        coeffs[start + j] = add_mod_fast::<M>(u, v);
231                        coeffs[start + j + len] = sub_mod_fast::<M>(u, v);
232
233                        w_mont = montgomery_mul::<M>(w_mont, root_mont);
234                    }
235                }
236                len <<= 1;
237            }
238            // Kyber: Do NOT reduce here - coefficients must stay in Montgomery form!
239        }
240
241        Ok(())
242    }
243}
244
245impl<M: NttModulus> InverseNttOperator<M> for CooleyTukeyNtt {
246    fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()> {
247        let n = M::N;
248        if n & (n - 1) != 0 {
249            return Err(Error::Parameter {
250                name: "Inverse NTT".into(),
251                reason: "Polynomial degree must be a power of 2".into(),
252            });
253        }
254
255        let coeffs = poly.as_mut_coeffs_slice();
256        let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
257
258        if is_dilithium {
259            // FIPS-204 Algorithm 42: Inverse NTT
260            // Gentleman-Sande (GS) with row-major traversal
261
262            // Pre-condition: ensure coefficients < Q for GS butterflies
263            for c in coeffs.iter_mut() {
264                *c = reduce_to_q::<M>(*c);
265            }
266
267            let mut k = M::ZETAS.len(); // Start after last entry
268            let mut len = 1;
269
270            while len < n {
271                // Row-major iteration matching forward NTT structure
272                for start in (0..n).step_by(2 * len) {
273                    k -= 1; // Traverse ZETAS in reverse
274
275                    // Use negated forward twiddle for inverse
276                    let zeta_fwd = M::ZETAS[k];
277                    let zeta = if zeta_fwd == 0 { 0 } else { M::Q - zeta_fwd };
278
279                    for j in start..start + len {
280                        let t = coeffs[j];
281                        let u = coeffs[j + len];
282
283                        // FIPS-204 GS butterfly:
284                        // Line 13: w_j ← w_j + w_{j+len}
285                        coeffs[j] = add_mod::<M>(t, u);
286                        // Line 14: w_{j+len} ← ζ^(-1) * (w_j - w_{j+len})
287                        let diff = sub_mod_upto_2q::<M>(t, u);
288                        coeffs[j + len] = montgomery_mul::<M>(diff, zeta);
289                    }
290                }
291
292                len <<= 1;
293            }
294
295            // Final reduction before N^(-1) scaling
296            for c in coeffs.iter_mut() {
297                *c = reduce_to_q::<M>(*c);
298            }
299
300            // Scale by N^(-1) in standard domain
301            let n_inv_std = pow_mod::<M>(M::N as u32, M::Q - 2);
302            for c in coeffs.iter_mut() {
303                *c = ((*c as u64 * n_inv_std as u64) % M::Q as u64) as u32;
304            }
305
306            match M::POST_INVNTT_MODE {
307                PostInvNtt::Standard => {} // Already in standard domain
308                PostInvNtt::Montgomery => {
309                    // Convert to Montgomery if requested
310                    for c in coeffs.iter_mut() {
311                        *c = to_montgomery::<M>(*c);
312                    }
313                }
314            }
315        } else {
316            // Kyber Inverse NTT
317            let root_inv_std = pow_mod::<M>(M::ZETA, M::Q - 2); // FIXED: Removed unnecessary cast
318
319            let mut len = n >> 1;
320            while len >= 1 {
321                let exp = n / (len << 1);
322                let root_std = pow_mod::<M>(root_inv_std, exp as u32);
323                let root_mont = to_montgomery::<M>(root_std);
324
325                for start in (0..n).step_by(len << 1) {
326                    let mut w_mont = M::MONT_R;
327
328                    for j in 0..len {
329                        let u = coeffs[start + j];
330                        let v = coeffs[start + j + len];
331
332                        coeffs[start + j] = add_mod_fast::<M>(u, v);
333                        coeffs[start + j + len] =
334                            montgomery_mul::<M>(sub_mod_fast::<M>(u, v), w_mont);
335
336                        w_mont = montgomery_mul::<M>(w_mont, root_mont);
337                    }
338                }
339                len >>= 1;
340            }
341
342            // Scale by N^(-1)
343            for c in coeffs.iter_mut() {
344                *c = montgomery_mul::<M>(*c, M::N_INV);
345            }
346
347            if M::POST_INVNTT_MODE == PostInvNtt::Standard {
348                for c in coeffs.iter_mut() {
349                    *c = montgomery_reduce::<M>(*c as u64);
350                }
351            }
352        }
353
354        Ok(())
355    }
356}
357
358/// Extension methods for Polynomial
359impl<M: NttModulus> Polynomial<M> {
360    /// Convert polynomial to NTT domain
361    pub fn ntt_inplace(&mut self) -> Result<()> {
362        CooleyTukeyNtt::ntt(self)
363    }
364
365    /// Convert polynomial from NTT domain
366    pub fn from_ntt_inplace(&mut self) -> Result<()> {
367        CooleyTukeyNtt::inv_ntt(self)
368    }
369
370    /// Pointwise multiplication in NTT domain
371    ///
372    /// Both polynomials must already be in NTT domain.
373    /// For Dilithium: inputs/output in standard domain (post-NTT)
374    /// For Kyber: inputs/output in Montgomery domain
375    pub fn ntt_mul(&self, other: &Self) -> Self {
376        let mut result = Self::zero();
377        let n = M::N;
378        let is_dilithium = !M::ZETAS.is_empty(); // FIXED: Use is_empty()
379
380        if is_dilithium {
381            // Dilithium: coefficients are in standard domain after NTT
382            // Use standard multiplication
383            for i in 0..n {
384                result.coeffs[i] =
385                    ((self.coeffs[i] as u64 * other.coeffs[i] as u64) % M::Q as u64) as u32;
386            }
387        } else {
388            // Kyber: coefficients are in Montgomery domain after NTT
389            // Use Montgomery multiplication to keep result in Montgomery domain
390            for i in 0..n {
391                result.coeffs[i] = montgomery_mul::<M>(self.coeffs[i], other.coeffs[i]);
392            }
393        }
394
395        result
396    }
397}
398
399#[cfg(test)]
400mod tests;