Skip to main content

nc_polynomial/
polynomial.rs

// MIT License
//
// Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//! Polynomial arithmetic over `Z_q[x]` with an explicit maximum degree bound.
//!
//! The module stores coefficients densely up to `max_degree`, always reduced modulo `q`.
//! It supports:
//! - basic arithmetic (`+`, `-`, negation, scalar multiplication),
//! - polynomial multiplication (schoolbook and NTT),
//! - quotient-ring operations modulo another polynomial,
//! - evaluation, formal derivative, and long division with remainder.

31use core::fmt;
32
/// A polynomial over `Z_q` stored densely, together with its degree bound and modulus.
///
/// The coefficient buffer always holds exactly `max_degree + 1` entries; every slot above the
/// polynomial's actual degree contains `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Polynomial {
    /// `coeffs[k]` is the coefficient of `x^k`, kept reduced modulo `modulus`.
    coeffs: Vec<u64>,
    /// The modulus `q` of the coefficient ring `Z_q`.
    modulus: u64,
    /// The largest degree `n` this polynomial is allowed to reach.
    max_degree: usize,
}
46
/// Failure modes reported by [`Polynomial`] operations and the NTT helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolynomialError {
    /// The coefficient modulus was below `2`.
    InvalidModulus(u64),
    /// The supplied coefficients imply a degree larger than the configured `max_degree`.
    DegreeOverflow { requested: usize, max_degree: usize },
    /// The operands disagree on `max_degree` or on the modulus.
    IncompatiblePolynomials,
    /// A degree-checked product would exceed `max_degree`.
    ProductDegreeOverflow { degree: usize, max_degree: usize },
    /// The NTT length is zero or not a power of two.
    NttLengthMustBePowerOfTwo(usize),
    /// The NTT length does not divide `q - 1`, so the modulus cannot support it.
    NttLengthUnsupported { length: usize, modulus: u64 },
    /// The supplied primitive root does not yield a valid root of unity for the transform length.
    InvalidPrimitiveRoot {
        primitive_root: u64,
        length: usize,
        modulus: u64,
    },
    /// Attempted division (or reduction) by the zero polynomial.
    DivisionByZeroPolynomial,
    /// A coefficient that had to be inverted has no inverse modulo `modulus`.
    NonInvertibleCoefficient { coefficient: u64, modulus: u64 },
}

impl fmt::Display for PolynomialError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds by value.
        match *self {
            Self::InvalidModulus(value) => write!(f, "invalid modulus {value}, expected q >= 2"),
            Self::DegreeOverflow {
                requested,
                max_degree,
            } => write!(
                f,
                "requested degree {requested} exceeds max_degree {max_degree}"
            ),
            Self::IncompatiblePolynomials => {
                f.write_str("polynomials are incompatible (different n or q)")
            }
            Self::ProductDegreeOverflow { degree, max_degree } => {
                write!(f, "product degree {degree} exceeds max_degree {max_degree}")
            }
            Self::NttLengthMustBePowerOfTwo(length) => {
                write!(f, "NTT length {length} must be a non-zero power of two")
            }
            Self::NttLengthUnsupported { length, modulus } => write!(
                f,
                "NTT length {length} is not supported by modulus {modulus} (length must divide q-1)"
            ),
            Self::InvalidPrimitiveRoot {
                primitive_root,
                length,
                modulus,
            } => write!(
                f,
                "primitive root {primitive_root} is invalid for NTT length {length} under modulus {modulus}"
            ),
            Self::DivisionByZeroPolynomial => f.write_str("division by zero polynomial"),
            Self::NonInvertibleCoefficient {
                coefficient,
                modulus,
            } => write!(
                f,
                "coefficient {coefficient} has no multiplicative inverse modulo {modulus}"
            ),
        }
    }
}

/// Marks `PolynomialError` as a standard error type; its message comes from `Display`.
impl std::error::Error for PolynomialError {}
119
120impl Polynomial {
121    /// Creates a polynomial with degree bound `max_degree` over modulus `modulus`.
122    ///
123    /// Coefficients are normalized modulo `modulus` and then zero-padded to length
124    /// `max_degree + 1`.
125    ///
126    /// # Errors
127    /// Returns:
128    /// - [`PolynomialError::InvalidModulus`] when `modulus < 2`,
129    /// - [`PolynomialError::DegreeOverflow`] when `coeffs.len() > max_degree + 1`.
130    pub fn new(max_degree: usize, modulus: u64, coeffs: &[u64]) -> Result<Self, PolynomialError> {
131        if modulus < 2 {
132            return Err(PolynomialError::InvalidModulus(modulus));
133        }
134        if coeffs.len() > max_degree + 1 {
135            return Err(PolynomialError::DegreeOverflow {
136                requested: coeffs.len().saturating_sub(1),
137                max_degree,
138            });
139        }
140
141        let mut normalized = vec![0_u64; max_degree + 1];
142        for (i, value) in coeffs.iter().copied().enumerate() {
143            normalized[i] = value % modulus;
144        }
145
146        Ok(Self {
147            coeffs: normalized,
148            modulus,
149            max_degree,
150        })
151    }
152
153    /// Returns the additive identity polynomial `0`.
154    pub fn zero(max_degree: usize, modulus: u64) -> Result<Self, PolynomialError> {
155        Self::new(max_degree, modulus, &[])
156    }
157
158    /// Returns the multiplicative identity polynomial `1`.
159    pub fn one(max_degree: usize, modulus: u64) -> Result<Self, PolynomialError> {
160        Self::new(max_degree, modulus, &[1])
161    }
162
163    /// Returns the configured maximum degree `n`.
164    pub fn max_degree(&self) -> usize {
165        self.max_degree
166    }
167
168    /// Returns the coefficient modulus `q`.
169    pub fn modulus(&self) -> u64 {
170        self.modulus
171    }
172
173    /// Returns the internal dense coefficient slice of length `max_degree + 1`.
174    pub fn coefficients(&self) -> &[u64] {
175        &self.coeffs
176    }
177
178    /// Returns coefficients trimmed to the actual degree.
179    ///
180    /// Zero polynomial is returned as `[0]`.
181    pub fn trimmed_coefficients(&self) -> Vec<u64> {
182        match self.degree() {
183            Some(d) => self.coeffs[..=d].to_vec(),
184            None => vec![0],
185        }
186    }
187
188    /// Returns coefficient of `x^degree`, or `None` when out of bounds.
189    pub fn coeff(&self, degree: usize) -> Option<u64> {
190        self.coeffs.get(degree).copied()
191    }
192
193    /// Returns the true polynomial degree, or `None` for zero polynomial.
194    pub fn degree(&self) -> Option<usize> {
195        self.coeffs.iter().rposition(|&c| c != 0)
196    }
197
198    /// Returns `true` when all coefficients are zero.
199    pub fn is_zero(&self) -> bool {
200        self.degree().is_none()
201    }
202
203    /// Adds two polynomials coefficient-wise in `Z_q[x]`.
204    ///
205    /// # Errors
206    /// Returns [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ.
207    pub fn add(&self, rhs: &Self) -> Result<Self, PolynomialError> {
208        self.ensure_compatible(rhs)?;
209
210        let mut out = vec![0_u64; self.max_degree + 1];
211        for (i, slot) in out.iter_mut().enumerate() {
212            *slot = mod_add(self.coeffs[i], rhs.coeffs[i], self.modulus);
213        }
214
215        Ok(Self {
216            coeffs: out,
217            modulus: self.modulus,
218            max_degree: self.max_degree,
219        })
220    }
221
222    /// Subtracts `rhs` from `self` coefficient-wise in `Z_q[x]`.
223    ///
224    /// # Errors
225    /// Returns [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ.
226    pub fn sub(&self, rhs: &Self) -> Result<Self, PolynomialError> {
227        self.ensure_compatible(rhs)?;
228
229        let mut out = vec![0_u64; self.max_degree + 1];
230        for (i, slot) in out.iter_mut().enumerate() {
231            *slot = mod_sub(self.coeffs[i], rhs.coeffs[i], self.modulus);
232        }
233
234        Ok(Self {
235            coeffs: out,
236            modulus: self.modulus,
237            max_degree: self.max_degree,
238        })
239    }
240
241    /// Returns additive inverse `-self` in `Z_q[x]`.
242    pub fn neg(&self) -> Self {
243        let mut out = vec![0_u64; self.max_degree + 1];
244        for (i, slot) in out.iter_mut().enumerate() {
245            let c = self.coeffs[i];
246            *slot = if c == 0 { 0 } else { self.modulus - c };
247        }
248
249        Self {
250            coeffs: out,
251            modulus: self.modulus,
252            max_degree: self.max_degree,
253        }
254    }
255
256    /// Multiplies all coefficients by `scalar` in `Z_q`.
257    pub fn scalar_mul(&self, scalar: u64) -> Self {
258        let mut out = vec![0_u64; self.max_degree + 1];
259        let reduced_scalar = scalar % self.modulus;
260
261        for (i, slot) in out.iter_mut().enumerate() {
262            *slot = mod_mul(self.coeffs[i], reduced_scalar, self.modulus);
263        }
264
265        Self {
266            coeffs: out,
267            modulus: self.modulus,
268            max_degree: self.max_degree,
269        }
270    }
271
272    /// Schoolbook polynomial multiplication with strict degree bound enforcement.
273    ///
274    /// # Errors
275    /// Returns:
276    /// - [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ,
277    /// - [`PolynomialError::ProductDegreeOverflow`] if exact product degree exceeds `max_degree`.
278    pub fn mul(&self, rhs: &Self) -> Result<Self, PolynomialError> {
279        self.ensure_compatible(rhs)?;
280
281        if let (Some(d1), Some(d2)) = (self.degree(), rhs.degree()) {
282            let d = d1 + d2;
283            if d > self.max_degree {
284                return Err(PolynomialError::ProductDegreeOverflow {
285                    degree: d,
286                    max_degree: self.max_degree,
287                });
288            }
289        }
290
291        let mut out = vec![0_u64; self.max_degree + 1];
292        for (i, &a) in self.coeffs.iter().enumerate() {
293            if a == 0 {
294                continue;
295            }
296            for (j, &b) in rhs.coeffs.iter().enumerate() {
297                if b == 0 {
298                    continue;
299                }
300                let idx = i + j;
301                if idx > self.max_degree {
302                    continue;
303                }
304                // Convolution accumulation: out[idx] += a_i * b_j (mod q).
305                let prod = mod_mul(a, b, self.modulus);
306                out[idx] = mod_add(out[idx], prod, self.modulus);
307            }
308        }
309
310        Ok(Self {
311            coeffs: out,
312            modulus: self.modulus,
313            max_degree: self.max_degree,
314        })
315    }
316
317    /// NTT-based polynomial multiplication.
318    ///
319    /// This computes exact convolution via NTT and then places the result in the dense output
320    /// buffer (length `max_degree + 1`).
321    ///
322    /// # Errors
323    /// Returns:
324    /// - [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ,
325    /// - [`PolynomialError::ProductDegreeOverflow`] if exact product degree exceeds `max_degree`,
326    /// - NTT-specific errors when modulus/root parameters are not valid.
327    pub fn mul_ntt(&self, rhs: &Self, primitive_root: u64) -> Result<Self, PolynomialError> {
328        self.ensure_compatible(rhs)?;
329
330        let (Some(lhs_degree), Some(rhs_degree)) = (self.degree(), rhs.degree()) else {
331            // Zero times anything is zero.
332            return Self::zero(self.max_degree, self.modulus);
333        };
334
335        let degree = lhs_degree + rhs_degree;
336        if degree > self.max_degree {
337            return Err(PolynomialError::ProductDegreeOverflow {
338                degree,
339                max_degree: self.max_degree,
340            });
341        }
342
343        let lhs = &self.coeffs[..=lhs_degree];
344        let rhs = &rhs.coeffs[..=rhs_degree];
345        let product = convolution_ntt(lhs, rhs, self.modulus, primitive_root)?;
346
347        let mut out = vec![0_u64; self.max_degree + 1];
348        for (i, coeff) in product.into_iter().enumerate() {
349            out[i] = coeff;
350        }
351
352        Ok(Self {
353            coeffs: out,
354            modulus: self.modulus,
355            max_degree: self.max_degree,
356        })
357    }
358
359    /// Schoolbook multiplication truncated to degree `max_degree`.
360    ///
361    /// Unlike [`Self::mul`], this method never returns degree-overflow errors; terms above
362    /// `max_degree` are discarded.
363    ///
364    /// # Errors
365    /// Returns [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ.
366    pub fn mul_truncated(&self, rhs: &Self) -> Result<Self, PolynomialError> {
367        self.ensure_compatible(rhs)?;
368
369        let mut out = vec![0_u64; self.max_degree + 1];
370        for (i, &a) in self.coeffs.iter().enumerate() {
371            if a == 0 {
372                continue;
373            }
374            for (j, &b) in rhs.coeffs.iter().enumerate() {
375                if b == 0 {
376                    continue;
377                }
378                let idx = i + j;
379                if idx > self.max_degree {
380                    // Remaining j values only increase idx, so we can stop this inner loop.
381                    break;
382                }
383                let prod = mod_mul(a, b, self.modulus);
384                out[idx] = mod_add(out[idx], prod, self.modulus);
385            }
386        }
387
388        Ok(Self {
389            coeffs: out,
390            modulus: self.modulus,
391            max_degree: self.max_degree,
392        })
393    }
394
395    /// Reduces `self` modulo `modulus_poly`, returning `self mod modulus_poly` in `Z_q[x]`.
396    ///
397    /// # Errors
398    /// Returns:
399    /// - [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ,
400    /// - [`PolynomialError::DivisionByZeroPolynomial`] if `modulus_poly` is zero,
401    /// - [`PolynomialError::NonInvertibleCoefficient`] if leading coefficient of `modulus_poly`
402    ///   is not invertible in `Z_q`.
403    pub fn rem_mod_poly(&self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
404        self.ensure_compatible(modulus_poly)?;
405
406        let mut reduced = self.coeffs.clone();
407        reduce_coefficients_mod_poly(&mut reduced, &modulus_poly.coeffs, self.modulus)?;
408
409        Ok(Self {
410            coeffs: reduced,
411            modulus: self.modulus,
412            max_degree: self.max_degree,
413        })
414    }
415
416    /// Adds two polynomials and reduces modulo `modulus_poly`.
417    pub fn add_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
418        self.add(rhs)?.rem_mod_poly(modulus_poly)
419    }
420
421    /// Subtracts two polynomials and reduces modulo `modulus_poly`.
422    pub fn sub_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
423        self.sub(rhs)?.rem_mod_poly(modulus_poly)
424    }
425
426    /// Multiplies two polynomials and reduces modulo `modulus_poly`.
427    ///
428    /// This computes the exact product first (in a temporary buffer sized to exact degree),
429    /// then applies polynomial reduction.
430    ///
431    /// # Errors
432    /// Returns compatibility and reduction-related errors as described by [`Self::rem_mod_poly`].
433    pub fn mul_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
434        self.ensure_compatible(rhs)?;
435        self.ensure_compatible(modulus_poly)?;
436
437        let (Some(lhs_degree), Some(rhs_degree)) = (self.degree(), rhs.degree()) else {
438            return Self::zero(self.max_degree, self.modulus);
439        };
440
441        let mut product = vec![0_u64; lhs_degree + rhs_degree + 1];
442        for i in 0..=lhs_degree {
443            let a = self.coeffs[i];
444            if a == 0 {
445                continue;
446            }
447            for j in 0..=rhs_degree {
448                let b = rhs.coeffs[j];
449                if b == 0 {
450                    continue;
451                }
452                let idx = i + j;
453                // Exact product accumulation before quotient-ring reduction.
454                let term = mod_mul(a, b, self.modulus);
455                product[idx] = mod_add(product[idx], term, self.modulus);
456            }
457        }
458
459        reduce_coefficients_mod_poly(&mut product, &modulus_poly.coeffs, self.modulus)?;
460
461        let mut out = vec![0_u64; self.max_degree + 1];
462        for (i, coeff) in product.into_iter().enumerate().take(self.max_degree + 1) {
463            out[i] = coeff;
464        }
465
466        Ok(Self {
467            coeffs: out,
468            modulus: self.modulus,
469            max_degree: self.max_degree,
470        })
471    }
472
473    /// Evaluates polynomial at `x` using Horner's method in `Z_q`.
474    pub fn evaluate(&self, x: u64) -> u64 {
475        let x_mod = x % self.modulus;
476        let mut acc = 0_u64;
477
478        for &coeff in self.coeffs.iter().rev() {
479            acc = mod_mul(acc, x_mod, self.modulus);
480            acc = mod_add(acc, coeff, self.modulus);
481        }
482
483        acc
484    }
485
486    /// Returns formal derivative `d/dx(self)` in `Z_q[x]`.
487    pub fn derivative(&self) -> Self {
488        let mut out = vec![0_u64; self.max_degree + 1];
489        for (deg, &coeff) in self.coeffs.iter().enumerate().skip(1) {
490            let factor = deg as u64 % self.modulus;
491            out[deg - 1] = mod_mul(factor, coeff, self.modulus);
492        }
493
494        Self {
495            coeffs: out,
496            modulus: self.modulus,
497            max_degree: self.max_degree,
498        }
499    }
500
501    /// Polynomial long division in `Z_q[x]`: returns `(quotient, remainder)`.
502    ///
503    /// # Errors
504    /// Returns:
505    /// - [`PolynomialError::IncompatiblePolynomials`] when `n` or `q` differ,
506    /// - [`PolynomialError::DivisionByZeroPolynomial`] if divisor is zero,
507    /// - [`PolynomialError::NonInvertibleCoefficient`] if divisor leading coefficient has no
508    ///   inverse in `Z_q`.
509    pub fn div_rem(&self, divisor: &Self) -> Result<(Self, Self), PolynomialError> {
510        self.ensure_compatible(divisor)?;
511        if divisor.is_zero() {
512            return Err(PolynomialError::DivisionByZeroPolynomial);
513        }
514        if self.is_zero() {
515            return Ok((
516                Self::zero(self.max_degree, self.modulus)?,
517                Self::zero(self.max_degree, self.modulus)?,
518            ));
519        }
520
521        let divisor_degree = divisor.degree().expect("checked non-zero divisor");
522        let lead = divisor.coeffs[divisor_degree];
523        let lead_inv =
524            mod_inverse(lead, self.modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
525                coefficient: lead,
526                modulus: self.modulus,
527            })?;
528
529        let mut remainder = self.clone();
530        let mut quotient = Self::zero(self.max_degree, self.modulus)?;
531
532        while let Some(rem_deg) = remainder.degree() {
533            if rem_deg < divisor_degree {
534                break;
535            }
536
537            // Cancel the current highest remainder term.
538            let diff = rem_deg - divisor_degree;
539            let rem_lead = remainder.coeffs[rem_deg];
540            let factor = mod_mul(rem_lead, lead_inv, self.modulus);
541            quotient.coeffs[diff] = mod_add(quotient.coeffs[diff], factor, self.modulus);
542
543            // remainder -= factor * x^diff * divisor
544            for i in 0..=divisor_degree {
545                let idx = i + diff;
546                let scaled = mod_mul(factor, divisor.coeffs[i], self.modulus);
547                remainder.coeffs[idx] = mod_sub(remainder.coeffs[idx], scaled, self.modulus);
548            }
549        }
550
551        Ok((quotient, remainder))
552    }
553
554    /// Ensures both polynomials share identical `(max_degree, modulus)` metadata.
555    fn ensure_compatible(&self, rhs: &Self) -> Result<(), PolynomialError> {
556        if self.modulus != rhs.modulus || self.max_degree != rhs.max_degree {
557            return Err(PolynomialError::IncompatiblePolynomials);
558        }
559        Ok(())
560    }
561}
562
/// Returns `(a + b) mod modulus`, summing in `u128` so the addition cannot overflow.
///
/// # Panics
/// Panics if `modulus == 0`.
fn mod_add(a: u64, b: u64, modulus: u64) -> u64 {
    let wide_sum = u128::from(a) + u128::from(b);
    let reduced = wide_sum % u128::from(modulus);
    reduced as u64
}
567
/// Computes `(a - b) mod modulus` using widened arithmetic.
///
/// `b` is reduced modulo `modulus` before subtracting, so the intermediate
/// `a + modulus - (b mod modulus)` is never negative, even when `b >= a + modulus`.
/// Neither operand has to be pre-reduced.
///
/// # Panics
/// Panics if `modulus == 0`.
fn mod_sub(a: u64, b: u64, modulus: u64) -> u64 {
    let m = u128::from(modulus);
    let b_reduced = u128::from(b) % m;
    ((u128::from(a) + m - b_reduced) % m) as u64
}
572
/// Returns `(a * b) mod modulus`; the product is formed in `u128` so it cannot overflow.
///
/// # Panics
/// Panics if `modulus == 0`.
fn mod_mul(a: u64, b: u64, modulus: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    let reduced = wide_product % u128::from(modulus);
    reduced as u64
}
577
578/// Computes `base^exponent mod modulus` via binary exponentiation.
579fn mod_pow(mut base: u64, mut exponent: u64, modulus: u64) -> u64 {
580    let mut acc = 1_u64 % modulus;
581    base %= modulus;
582
583    while exponent > 0 {
584        if exponent & 1 == 1 {
585            acc = mod_mul(acc, base, modulus);
586        }
587        base = mod_mul(base, base, modulus);
588        exponent >>= 1;
589    }
590
591    acc
592}
593
/// Permutes `values` in place into bit-reversed index order.
///
/// For a power-of-two length `2^k`, the element at index `i` ends up at the index formed by
/// reversing the low `k` bits of `i`. This is the input layout expected by the iterative
/// Cooley-Tukey butterflies in `ntt_in_place`. Callers only pass power-of-two lengths; the
/// resulting order for other lengths carries no particular meaning.
fn bit_reverse_permute(values: &mut [u64]) {
    let len = values.len();
    // `rev` holds the bit-reversed counterpart of `idx`, advanced by a "reversed increment":
    // clear the run of set bits starting from the top, then set the next one down.
    let mut rev = 0_usize;

    for idx in 1..len {
        let mut mask = len >> 1;
        loop {
            if rev & mask == 0 {
                break;
            }
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;

        // Swap each pair once, from its smaller index.
        if idx < rev {
            values.swap(idx, rev);
        }
    }
}
614
615/// In-place iterative radix-2 NTT.
616///
617/// `root` must be an `n`-th root of unity in `Z_modulus`, where `n = values.len()`.
618fn ntt_in_place(values: &mut [u64], root: u64, modulus: u64) -> Result<(), PolynomialError> {
619    let n = values.len();
620    if n == 0 || !n.is_power_of_two() {
621        return Err(PolynomialError::NttLengthMustBePowerOfTwo(n));
622    }
623
624    // The butterfly stages assume bit-reversed layout.
625    bit_reverse_permute(values);
626
627    let mut len = 2_usize;
628    while len <= n {
629        // Twiddle factor increment for this stage.
630        let wlen = mod_pow(root, (n / len) as u64, modulus);
631        for start in (0..n).step_by(len) {
632            let mut w = 1_u64;
633            for i in 0..(len / 2) {
634                // Cooley-Tukey butterfly:
635                // (u, v) -> (u + w*v, u - w*v)
636                let u = values[start + i];
637                let v = mod_mul(values[start + i + len / 2], w, modulus);
638                values[start + i] = mod_add(u, v, modulus);
639                values[start + i + len / 2] = mod_sub(u, v, modulus);
640                w = mod_mul(w, wlen, modulus);
641            }
642        }
643        len <<= 1;
644    }
645
646    Ok(())
647}
648
649/// Convolution via NTT for coefficient vectors over `Z_modulus`.
650///
651/// `primitive_root` should be a primitive root modulo `modulus`.
652fn convolution_ntt(
653    lhs: &[u64],
654    rhs: &[u64],
655    modulus: u64,
656    primitive_root: u64,
657) -> Result<Vec<u64>, PolynomialError> {
658    if lhs.is_empty() || rhs.is_empty() {
659        return Ok(vec![0]);
660    }
661
662    // Convolution size and transform size.
663    let out_len = lhs.len() + rhs.len() - 1;
664    let ntt_len = out_len.next_power_of_two();
665    if !ntt_len.is_power_of_two() {
666        return Err(PolynomialError::NttLengthMustBePowerOfTwo(ntt_len));
667    }
668
669    let ntt_len_u64 = ntt_len as u64;
670    if (modulus - 1) % ntt_len_u64 != 0 {
671        return Err(PolynomialError::NttLengthUnsupported {
672            length: ntt_len,
673            modulus,
674        });
675    }
676
677    // Derive primitive n-th root of unity from primitive root of the field.
678    let root = mod_pow(primitive_root, (modulus - 1) / ntt_len_u64, modulus);
679    let is_valid_root = mod_pow(root, ntt_len_u64, modulus) == 1
680        && (ntt_len == 1 || mod_pow(root, (ntt_len / 2) as u64, modulus) != 1);
681    if !is_valid_root {
682        return Err(PolynomialError::InvalidPrimitiveRoot {
683            primitive_root,
684            length: ntt_len,
685            modulus,
686        });
687    }
688
689    // Inverse transform constants.
690    let root_inv = mod_inverse(root, modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
691        coefficient: root,
692        modulus,
693    })?;
694    let n_inv = mod_inverse(ntt_len_u64 % modulus, modulus).ok_or(
695        PolynomialError::NonInvertibleCoefficient {
696            coefficient: ntt_len_u64 % modulus,
697            modulus,
698        },
699    )?;
700
701    // Zero-pad inputs to the transform size.
702    let mut fa = vec![0_u64; ntt_len];
703    let mut fb = vec![0_u64; ntt_len];
704    for (i, coeff) in lhs.iter().copied().enumerate() {
705        fa[i] = coeff % modulus;
706    }
707    for (i, coeff) in rhs.iter().copied().enumerate() {
708        fb[i] = coeff % modulus;
709    }
710
711    // Forward transforms.
712    ntt_in_place(&mut fa, root, modulus)?;
713    ntt_in_place(&mut fb, root, modulus)?;
714
715    // Point-wise multiplication in evaluation domain.
716    for (a, b) in fa.iter_mut().zip(fb.iter()) {
717        *a = mod_mul(*a, *b, modulus);
718    }
719
720    // Inverse transform and normalization by 1/n.
721    ntt_in_place(&mut fa, root_inv, modulus)?;
722    for coeff in &mut fa {
723        *coeff = mod_mul(*coeff, n_inv, modulus);
724    }
725    // Trim back to exact convolution length.
726    fa.truncate(out_len);
727
728    Ok(fa)
729}
730
/// Returns the multiplicative inverse of `a` modulo `modulus`, via the extended Euclidean
/// algorithm.
///
/// The result lies in `[0, modulus)`. Returns `None` when `gcd(a, modulus) != 1`.
///
/// # Panics
/// Panics if `modulus == 0`.
fn mod_inverse(a: u64, modulus: u64) -> Option<u64> {
    let m = i128::from(modulus);
    // Invariant: old_s * a ≡ old_r and s * a ≡ r (mod m).
    let (mut old_r, mut r) = (m, i128::from(a % modulus));
    let (mut old_s, mut s) = (0_i128, 1_i128);

    while r != 0 {
        let q = old_r / r;

        let next_r = old_r - q * r;
        old_r = r;
        r = next_r;

        let next_s = old_s - q * s;
        old_s = s;
        s = next_s;
    }

    if old_r != 1 {
        return None;
    }
    // The Bezout coefficient lies in (-m, m); shift it into [0, m).
    Some(old_s.rem_euclid(m) as u64)
}
755
/// Returns the index of the last non-zero entry of `coeffs` (the polynomial degree), or
/// `None` when every entry is zero or the slice is empty.
fn degree_of(coeffs: &[u64]) -> Option<usize> {
    (0..coeffs.len()).rev().find(|&idx| coeffs[idx] != 0)
}
760
761/// Reduces a mutable coefficient buffer modulo `modulus_poly` in-place.
762///
763/// This performs polynomial long-division style elimination on the highest degree terms.
764fn reduce_coefficients_mod_poly(
765    coeffs: &mut [u64],
766    modulus_poly: &[u64],
767    modulus: u64,
768) -> Result<(), PolynomialError> {
769    let Some(modulus_degree) = degree_of(modulus_poly) else {
770        return Err(PolynomialError::DivisionByZeroPolynomial);
771    };
772
773    let leading = modulus_poly[modulus_degree];
774    let leading_inverse =
775        mod_inverse(leading, modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
776            coefficient: leading,
777            modulus,
778        })?;
779
780    while let Some(current_degree) = degree_of(coeffs) {
781        if current_degree < modulus_degree {
782            break;
783        }
784
785        // Eliminate the current leading term using shifted/scaled modulus polynomial.
786        let shift = current_degree - modulus_degree;
787        let factor = mod_mul(coeffs[current_degree], leading_inverse, modulus);
788        if factor == 0 {
789            continue;
790        }
791
792        for (i, &modulus_coeff) in modulus_poly.iter().take(modulus_degree + 1).enumerate() {
793            let idx = shift + i;
794            if idx >= coeffs.len() || modulus_coeff == 0 {
795                continue;
796            }
797            let scaled = mod_mul(factor, modulus_coeff, modulus);
798            coeffs[idx] = mod_sub(coeffs[idx], scaled, modulus);
799        }
800    }
801
802    Ok(())
803}
804
/// Unit and property-style tests for polynomial arithmetic over `Z_q[x]`.
///
/// The suite mixes hand-computed examples (expected arithmetic spelled out in comments) with
/// exhaustive checks over tiny coefficient domains and deterministic pseudo-random
/// NTT-versus-schoolbook comparisons.
#[cfg(test)]
mod tests {
    use super::{Polynomial, PolynomialError, mod_inverse, mod_mul, mod_pow, ntt_in_place};

    /// Builds a polynomial with degree capacity `n` over `Z_q`, panicking if construction fails.
    ///
    /// Tests only pass inputs they know to be valid, so a failure here indicates a test bug.
    fn p(n: usize, q: u64, coeffs: &[u64]) -> Polynomial {
        Polynomial::new(n, q, coeffs).expect("polynomial should build")
    }

    /// Advances a 64-bit linear congruential generator in place and returns the new state.
    ///
    /// Only used to produce reproducible pseudo-random test inputs; it is not a quality RNG.
    /// The multiplier and increment are the widely used MMIX (Knuth) LCG constants.
    fn lcg_next(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state
    }

    /// Returns every polynomial whose first `used_len` coefficients are drawn from `values`.
    ///
    /// Each of the `values.len()^used_len` combinations is produced by reading the loop counter
    /// as a base-`values.len()` number whose digits (least significant first) select the
    /// coefficient for `x^0, x^1, ...`. Coefficients above `used_len - 1` are zero.
    ///
    /// # Panics
    ///
    /// Panics if `used_len > max_degree + 1`, if `values` is empty, if the number of
    /// combinations overflows `usize`, or if a resulting polynomial fails to build.
    fn enumerate_polynomials(
        max_degree: usize,
        modulus: u64,
        values: &[u64],
        used_len: usize,
    ) -> Vec<Polynomial> {
        assert!(
            used_len <= max_degree + 1,
            "used_len exceeds polynomial capacity"
        );
        assert!(!values.is_empty(), "values cannot be empty");

        let base = values.len();
        // Checked product instead of `base.pow(used_len as u32)`: avoids a truncating cast and
        // turns a silently wrapping count (in release builds) into an explicit panic.
        let total = (0..used_len)
            .try_fold(1_usize, |acc, _| acc.checked_mul(base))
            .expect("number of enumerated polynomials overflows usize");

        let mut out = Vec::with_capacity(total);
        for mut state in 0..total {
            let mut coeffs = vec![0_u64; used_len];
            for slot in &mut coeffs {
                // Peel off the least significant base-`base` digit to pick this coefficient.
                *slot = values[state % base];
                state /= base;
            }
            out.push(p(max_degree, modulus, &coeffs));
        }
        out
    }

    #[test]
    fn constructor_normalizes_and_pads_coefficients() {
        // 10 mod 7 = 3, 15 mod 7 = 1; the vector is zero-padded to max_degree + 1 = 6 entries.
        let poly = p(5, 7, &[10, 15, 6]);
        assert_eq!(poly.coefficients(), &[3, 1, 6, 0, 0, 0]);
    }

    #[test]
    fn constructor_rejects_invalid_modulus() {
        let err = Polynomial::new(3, 1, &[1, 2]).expect_err("expected error");
        assert_eq!(err, PolynomialError::InvalidModulus(1));
    }

    #[test]
    fn constructor_rejects_degree_overflow() {
        // Four coefficients describe degree 3, which exceeds max_degree = 2.
        let err = Polynomial::new(2, 11, &[1, 2, 3, 4]).expect_err("expected error");
        assert_eq!(
            err,
            PolynomialError::DegreeOverflow {
                requested: 3,
                max_degree: 2
            }
        );
    }

    #[test]
    fn degree_and_zero_behaviour() {
        let zero = p(4, 13, &[]);
        assert!(zero.is_zero());
        assert_eq!(zero.degree(), None);
        assert_eq!(zero.trimmed_coefficients(), vec![0]);

        let poly = p(4, 13, &[0, 2, 0, 9]);
        assert!(!poly.is_zero());
        assert_eq!(poly.degree(), Some(3));
        assert_eq!(poly.trimmed_coefficients(), vec![0, 2, 0, 9]);
    }

    #[test]
    fn addition_and_subtraction_work_modulo_q() {
        let a = p(5, 17, &[16, 5, 0, 1]);
        let b = p(5, 17, &[3, 13, 6, 10]);

        // [19, 18, 6, 11] mod 17.
        let sum = a.add(&b).expect("add should work");
        assert_eq!(sum.trimmed_coefficients(), vec![2, 1, 6, 11]);

        // [13, -8, -6, -9] mod 17.
        let diff = a.sub(&b).expect("sub should work");
        assert_eq!(diff.trimmed_coefficients(), vec![13, 9, 11, 8]);
    }

    #[test]
    fn unary_negation_and_scalar_multiplication() {
        let a = p(4, 19, &[0, 4, 18, 3]);
        let neg = a.neg();
        assert_eq!(neg.trimmed_coefficients(), vec![0, 15, 1, 16]);

        // [0, 28, 126, 21] mod 19.
        let scaled = a.scalar_mul(7);
        assert_eq!(scaled.trimmed_coefficients(), vec![0, 9, 12, 2]);
    }

    #[test]
    fn incompatible_polynomials_fail_for_binary_ops() {
        // Same modulus but different max_degree must be rejected.
        let a = p(3, 11, &[1, 2]);
        let b = p(4, 11, &[1, 2]);
        let err = a.add(&b).expect_err("expected incompatibility");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn multiplication_checked_and_truncated() {
        let a = p(5, 23, &[1, 2, 3]);
        let b = p(5, 23, &[4, 5]);

        let prod = a.mul(&b).expect("mul should work");
        assert_eq!(prod.trimmed_coefficients(), vec![4, 13, 22, 15]);

        // Two cubics give a degree-6 product, which cannot fit in max_degree = 3.
        let c = p(3, 29, &[1, 2, 3, 4]);
        let d = p(3, 29, &[1, 1, 1, 1]);
        let err = c.mul(&d).expect_err("degree should overflow");
        assert_eq!(
            err,
            PolynomialError::ProductDegreeOverflow {
                degree: 6,
                max_degree: 3
            }
        );

        // Truncation keeps only the prefix sums of [1, 2, 3, 4] up to x^3.
        let truncated = c.mul_truncated(&d).expect("mul_truncated should work");
        assert_eq!(truncated.coefficients(), &[1, 3, 6, 10]);
    }

    #[test]
    fn schoolbook_convolution_matches_expected_coefficients() {
        let a = p(10, 998_244_353, &[1, 2, 3, 4]);
        let b = p(10, 998_244_353, &[5, 6, 7]);
        let product = a.mul(&b).expect("schoolbook multiplication should work");

        // [1,2,3,4] * [5,6,7] = [5,16,34,52,45,28]
        assert_eq!(product.trimmed_coefficients(), vec![5, 16, 34, 52, 45, 28]);
    }

    #[test]
    fn ntt_round_trip_recovers_original_vector() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;
        let n = 8_usize;

        // omega = g^((q - 1) / n) is an n-th root of unity; the inverse transform uses
        // omega^-1 followed by scaling every entry by n^-1.
        let omega = mod_pow(primitive_root, (modulus - 1) / n as u64, modulus);
        let omega_inv = mod_inverse(omega, modulus).expect("omega must be invertible");
        let n_inv = mod_inverse(n as u64, modulus).expect("length must be invertible");

        let mut values = vec![7, 11, 19, 23, 31, 2, 5, 13];
        let original = values.clone();

        ntt_in_place(&mut values, omega, modulus).expect("forward NTT should work");
        ntt_in_place(&mut values, omega_inv, modulus).expect("inverse NTT should work");
        for value in &mut values {
            *value = mod_mul(*value, n_inv, modulus);
        }

        assert_eq!(values, original);
    }

    #[test]
    fn ntt_multiplication_matches_schoolbook_convolution() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;

        let a = p(31, modulus, &[4, 1, 9, 16, 25, 36, 49, 64, 81, 100]);
        let b = p(31, modulus, &[3, 14, 15, 92, 65, 35, 89, 79]);

        let schoolbook = a.mul(&b).expect("schoolbook multiplication should work");
        let ntt = a
            .mul_ntt(&b, primitive_root)
            .expect("NTT multiplication should work");

        assert_eq!(ntt.coefficients(), schoolbook.coefficients());
    }

    #[test]
    fn ntt_matches_schoolbook_on_many_deterministic_cases() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;
        // Fixed seed keeps the "random" cases reproducible across runs.
        let mut seed = 0xC0FFEE_u64;

        for _ in 0..200 {
            let len_a = (lcg_next(&mut seed) % 48 + 1) as usize;
            let len_b = (lcg_next(&mut seed) % 48 + 1) as usize;
            // The full product has degree len_a + len_b - 2, so this capacity always fits it.
            let max_degree = len_a + len_b;

            let mut coeffs_a = vec![0_u64; len_a];
            let mut coeffs_b = vec![0_u64; len_b];
            for coeff in &mut coeffs_a {
                *coeff = lcg_next(&mut seed) % modulus;
            }
            for coeff in &mut coeffs_b {
                *coeff = lcg_next(&mut seed) % modulus;
            }

            let a = p(max_degree, modulus, &coeffs_a);
            let b = p(max_degree, modulus, &coeffs_b);
            let schoolbook = a.mul(&b).expect("schoolbook multiplication should work");
            let ntt = a
                .mul_ntt(&b, primitive_root)
                .expect("NTT multiplication should work");

            assert_eq!(ntt.coefficients(), schoolbook.coefficients());
        }
    }

    #[test]
    fn ntt_errors_are_reported_for_bad_parameters() {
        // Two degree-10 operands: out_len = 21, next power of two is 32, but 32 does not
        // divide q-1 = 16.
        let coeffs = vec![1_u64; 11];
        let a = p(20, 17, &coeffs);
        let b = p(20, 17, &coeffs);
        let err = a
            .mul_ntt(&b, 3)
            .expect_err("unsupported NTT length should fail");
        assert_eq!(
            err,
            PolynomialError::NttLengthUnsupported {
                length: 32,
                modulus: 17
            }
        );

        // Two cubics give out_len = 7, so the transform length is 8; a "root" of 1 cannot
        // generate a primitive 8th root of unity.
        let c = p(16, 998_244_353, &[1, 2, 3, 4]);
        let d = p(16, 998_244_353, &[5, 6, 7, 8]);
        let err = c
            .mul_ntt(&d, 1)
            .expect_err("invalid primitive root should fail");
        assert_eq!(
            err,
            PolynomialError::InvalidPrimitiveRoot {
                primitive_root: 1,
                length: 8,
                modulus: 998_244_353
            }
        );
    }

    #[test]
    fn remainder_mod_polynomial_reduces_degree() {
        // Modulus m(x) = x^4 + 1 over mod 17.
        let m = p(6, 17, &[1, 0, 0, 0, 1]);
        let a = p(6, 17, &[2, 0, 0, 0, 1, 3]); // 2 + x^4 + 3x^5
        let reduced = a.rem_mod_poly(&m).expect("reduction should work");

        // x^4 = -1 and x^5 = -x, so result is 1 - 3x = 1 + 14x mod 17.
        assert_eq!(reduced.trimmed_coefficients(), vec![1, 14]);
    }

    #[test]
    fn quotient_ring_add_sub_and_mul() {
        // Work in Z_17[x] / (x^4 + 1).
        let m = p(4, 17, &[1, 0, 0, 0, 1]);
        let a = p(4, 17, &[0, 0, 0, 16]); // -x^3
        let b = p(4, 17, &[0, 0, 0, 2]); // 2x^3

        let added = a.add_mod_poly(&b, &m).expect("add mod poly should work");
        assert_eq!(added.trimmed_coefficients(), vec![0, 0, 0, 1]);

        // -x^3 - 2x^3 = -3x^3 = 14x^3 mod 17.
        let subbed = a.sub_mod_poly(&b, &m).expect("sub mod poly should work");
        assert_eq!(subbed.trimmed_coefficients(), vec![0, 0, 0, 14]);

        let c = p(4, 17, &[1, 0, 0, 1]); // 1 + x^3
        let d = p(4, 17, &[1, 0, 0, 1]); // 1 + x^3
        let err = c.mul(&d).expect_err("plain mul overflows max_degree");
        assert_eq!(
            err,
            PolynomialError::ProductDegreeOverflow {
                degree: 6,
                max_degree: 4
            }
        );

        let ring_product = c.mul_mod_poly(&d, &m).expect("mul mod poly should work");
        // (1 + x^3)^2 = 1 + 2x^3 + x^6, x^6 = -x^2 over x^4 + 1.
        assert_eq!(ring_product.trimmed_coefficients(), vec![1, 0, 16, 2]);
    }

    #[test]
    fn polynomial_modulus_errors_are_reported() {
        let a = p(4, 11, &[1, 2, 3]);
        let b = p(4, 11, &[3, 4]);
        let zero_poly = p(4, 11, &[]);
        let err = a
            .mul_mod_poly(&b, &zero_poly)
            .expect_err("zero modulus polynomial should fail");
        assert_eq!(err, PolynomialError::DivisionByZeroPolynomial);

        // Leading coefficient 2 has no inverse mod 8.
        let x = p(4, 8, &[0, 1]);
        let non_invertible_modulus = p(4, 8, &[1, 2]); // 1 + 2x
        let err = x
            .rem_mod_poly(&non_invertible_modulus)
            .expect_err("non-invertible modulus lead should fail");
        assert_eq!(
            err,
            PolynomialError::NonInvertibleCoefficient {
                coefficient: 2,
                modulus: 8
            }
        );

        // Modulus polynomial with a different max_degree than the operands.
        let c = p(5, 11, &[1, 2]);
        let err = a
            .add_mod_poly(&b, &c)
            .expect_err("incompatible polynomial settings should fail");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn evaluate_uses_horner_rule_modulo_q() {
        let poly = p(4, 31, &[7, 0, 3, 4]); // 7 + 3x^2 + 4x^3
        let value = poly.evaluate(10);
        // 7 + 3*100 + 4*1000 = 4307, 4307 mod 31 = 29
        assert_eq!(value, 29);
    }

    #[test]
    fn derivative_is_computed_modulo_q() {
        let poly = p(6, 11, &[3, 5, 7, 9, 2]); // 3 + 5x + 7x^2 + 9x^3 + 2x^4
        let deriv = poly.derivative(); // 5 + 14x + 27x^2 + 8x^3 mod 11
        assert_eq!(deriv.trimmed_coefficients(), vec![5, 3, 5, 8]);
    }

    #[test]
    fn long_division_exact_case() {
        // (x^3 + 2x^2 + 6x + 5) / (x + 1) over mod 7 = x^2 + x + 5, remainder 0
        let dividend = p(5, 7, &[5, 6, 2, 1]);
        let divisor = p(5, 7, &[1, 1]);
        let (quotient, remainder) = dividend.div_rem(&divisor).expect("division should work");

        assert_eq!(quotient.trimmed_coefficients(), vec![5, 1, 1]);
        assert!(remainder.is_zero());
    }

    #[test]
    fn long_division_with_remainder() {
        // (3 + 0x + x^2) / (2 + x) over mod 5: (x + 2)(x + 3) = x^2 + 1, so the remainder is 2.
        let dividend = p(4, 5, &[3, 0, 1]);
        let divisor = p(4, 5, &[2, 1]);
        let (quotient, remainder) = dividend.div_rem(&divisor).expect("division should work");

        assert_eq!(quotient.trimmed_coefficients(), vec![3, 1]);
        assert_eq!(remainder.trimmed_coefficients(), vec![2]);

        // Validate dividend = divisor * quotient + remainder
        let reconstructed = divisor
            .mul(&quotient)
            .expect("reconstruction product should fit")
            .add(&remainder)
            .expect("reconstruction sum should fit");
        assert_eq!(reconstructed.coefficients(), dividend.coefficients());
    }

    #[test]
    fn division_errors_are_reported() {
        let a = p(4, 9, &[1, 2, 3]);
        let zero = p(4, 9, &[]);
        let err = a.div_rem(&zero).expect_err("division by zero should fail");
        assert_eq!(err, PolynomialError::DivisionByZeroPolynomial);

        // Leading coefficient is 2, not invertible mod 8.
        let dividend = p(4, 8, &[1, 0, 1]);
        let divisor = p(4, 8, &[0, 2]);
        let err = dividend
            .div_rem(&divisor)
            .expect_err("division should fail when inverse does not exist");
        assert_eq!(
            err,
            PolynomialError::NonInvertibleCoefficient {
                coefficient: 2,
                modulus: 8
            }
        );
    }

    #[test]
    fn mod_q_algebraic_laws_hold_on_small_exhaustive_domain() {
        let modulus = 5_u64;
        let max_degree = 2_usize;
        // `all` covers every polynomial of degree <= 2 over Z_5 (125 of them); `reps` is a
        // 27-element subset that keeps the triple loops below affordable.
        let all = enumerate_polynomials(max_degree, modulus, &[0, 1, 2, 3, 4], 3);
        let reps = enumerate_polynomials(max_degree, modulus, &[0, 1, 4], 3);
        let zero = p(max_degree, modulus, &[]);
        let one = p(max_degree, modulus, &[1]);

        for a in &all {
            assert_eq!(a.add(&zero).expect("a + 0 should work"), *a);
            assert_eq!(zero.add(a).expect("0 + a should work"), *a);
            assert_eq!(a.sub(&zero).expect("a - 0 should work"), *a);
            assert!(
                a.add(&a.neg()).expect("a + (-a) should work").is_zero(),
                "additive inverse should cancel"
            );
            // Scalars are reduced mod q, so q + 2 must act exactly like 2.
            assert_eq!(a.scalar_mul(modulus + 2), a.scalar_mul(2));
            assert_eq!(a.scalar_mul(0), zero);
            assert_eq!(
                a.mul_truncated(&one).expect("a * 1 should work"),
                *a,
                "multiplicative identity should hold"
            );
        }

        for a in &all {
            for b in &all {
                assert_eq!(
                    a.add(b).expect("a+b should work"),
                    b.add(a).expect("b+a should work"),
                    "addition should commute"
                );
                assert_eq!(
                    a.mul_truncated(b).expect("a*b should work"),
                    b.mul_truncated(a).expect("b*a should work"),
                    "truncated multiplication should commute"
                );
                assert_eq!(
                    a.sub(b).expect("a-b should work"),
                    a.add(&b.neg()).expect("a+(-b) should work"),
                    "subtraction should match addition with negation"
                );
            }
        }

        for a in &reps {
            for b in &reps {
                for c in &reps {
                    assert_eq!(
                        a.add(&b.add(c).expect("b+c should work"))
                            .expect("a+(b+c) should work"),
                        a.add(b)
                            .expect("a+b should work")
                            .add(c)
                            .expect("(a+b)+c should work"),
                        "addition should associate"
                    );

                    let lhs = a
                        .mul_truncated(&b.add(c).expect("b+c should work"))
                        .expect("a*(b+c) should work");
                    let rhs = a
                        .mul_truncated(b)
                        .expect("a*b should work")
                        .add(&a.mul_truncated(c).expect("a*c should work"))
                        .expect("ab+ac should work");
                    assert_eq!(lhs, rhs, "multiplication should distribute over addition");
                }
            }
        }
    }

    #[test]
    fn evaluate_matches_naive_formula_on_exhaustive_small_domain() {
        let modulus = 11_u64;
        let max_degree = 3_usize;
        let polys = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 4);

        for poly in &polys {
            for x in 0..modulus {
                // Reference value: sum of coeff_i * x^i with explicit powers (no Horner).
                let mut acc = 0_u64;
                let mut x_pow = 1_u64;
                for &coeff in poly.coefficients() {
                    acc = (acc + (coeff * x_pow) % modulus) % modulus;
                    x_pow = (x_pow * x) % modulus;
                }
                assert_eq!(poly.evaluate(x), acc);
            }
        }
    }

    #[test]
    fn ntt_matches_schoolbook_exhaustive_small_domains() {
        // Degree <= 1 under q=5 gives out_len <= 3 -> ntt_len=4, supported since 4 | (5-1).
        let small_q_polys = enumerate_polynomials(2, 5, &[0, 1, 2, 3, 4], 2);
        for a in &small_q_polys {
            for b in &small_q_polys {
                let schoolbook = a.mul(b).expect("schoolbook should work");
                let ntt = a.mul_ntt(b, 2).expect("NTT should work");
                assert_eq!(ntt.coefficients(), schoolbook.coefficients());
            }
        }

        // Degree <= 2 under q=17 gives out_len <= 5 -> ntt_len=8, supported since 8 | (17-1).
        let larger_polys = enumerate_polynomials(4, 17, &[0, 1, 2, 3], 3);
        for a in &larger_polys {
            for b in &larger_polys {
                let schoolbook = a.mul(b).expect("schoolbook should work");
                let ntt = a.mul_ntt(b, 3).expect("NTT should work");
                assert_eq!(ntt.coefficients(), schoolbook.coefficients());
            }
        }
    }

    #[test]
    fn mod_polynomial_reduction_is_bounded_and_idempotent() {
        let modulus = 17_u64;
        let max_degree = 4_usize;
        let modulus_poly = p(max_degree, modulus, &[1, 0, 0, 0, 1]); // x^4 + 1
        // Derive the bound from the modulus polynomial instead of hard-coding it, so the
        // assertion stays correct if the modulus above is changed.
        let modulus_degree = modulus_poly
            .degree()
            .expect("modulus polynomial is non-zero");
        let all = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 5);
        // The pairwise congruence checks below are quadratic, so they use a prefix only.
        let sample: Vec<_> = all.iter().take(64).cloned().collect();

        for poly in &all {
            let reduced = poly
                .rem_mod_poly(&modulus_poly)
                .expect("reduction should succeed");
            if let Some(degree) = reduced.degree() {
                assert!(
                    degree < modulus_degree,
                    "canonical representative degree must be < deg(modulus)"
                );
            }
            assert_eq!(
                reduced,
                reduced
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction idempotence should hold")
            );
        }

        for a in &sample {
            for b in &sample {
                // Reducing the operands first must not change the ring result.
                let lhs_add = a
                    .add_mod_poly(b, &modulus_poly)
                    .expect("add mod poly should work");
                let rhs_add = a
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction should work")
                    .add_mod_poly(
                        &b.rem_mod_poly(&modulus_poly)
                            .expect("reduction should work"),
                        &modulus_poly,
                    )
                    .expect("add mod poly should work");
                assert_eq!(lhs_add, rhs_add);

                let lhs_mul = a
                    .mul_mod_poly(b, &modulus_poly)
                    .expect("mul mod poly should work");
                let rhs_mul = a
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction should work")
                    .mul_mod_poly(
                        &b.rem_mod_poly(&modulus_poly)
                            .expect("reduction should work"),
                        &modulus_poly,
                    )
                    .expect("mul mod poly should work");
                assert_eq!(lhs_mul, rhs_mul);
            }
        }
    }

    #[test]
    fn long_division_invariants_hold_on_exhaustive_small_domain() {
        let modulus = 5_u64;
        let max_degree = 4_usize;
        let dividends = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 5);
        // Non-zero divisors of degree <= 2; their leading coefficients are 1 or 2, both
        // invertible mod 5, so every division is expected to succeed.
        let divisors: Vec<_> = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 3)
            .into_iter()
            .filter(|poly| !poly.is_zero())
            .collect();

        for dividend in &dividends {
            for divisor in &divisors {
                let (quotient, remainder) =
                    dividend.div_rem(divisor).expect("division should succeed");

                let reconstructed = divisor
                    .mul(&quotient)
                    .expect("product should fit in max degree")
                    .add(&remainder)
                    .expect("sum should fit in max degree");
                assert_eq!(reconstructed.coefficients(), dividend.coefficients());

                if let Some(rem_degree) = remainder.degree() {
                    let divisor_degree = divisor.degree().expect("divisors are non-zero");
                    assert!(
                        rem_degree < divisor_degree,
                        "remainder degree must be strictly less than divisor degree"
                    );
                }
            }
        }
    }
}