1use core::fmt;
32
/// A dense univariate polynomial with coefficients in `Z_q` (`q = modulus`) and a
/// fixed capacity of `max_degree + 1` coefficients.
///
/// Binary operations require both operands to share `modulus` and `max_degree`;
/// the derived `PartialEq` compares all three fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Polynomial {
    /// Coefficient of `x^i` stored at index `i`. The constructors in this module keep
    /// exactly `max_degree + 1` entries, each reduced into `0..modulus`.
    coeffs: Vec<u64>,
    /// The coefficient modulus `q` (at least 2 for values built by `Polynomial::new`).
    modulus: u64,
    /// Highest degree this polynomial can represent.
    max_degree: usize,
}
46
/// Errors reported by [`Polynomial`] construction and arithmetic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolynomialError {
    /// The coefficient modulus `q` was below 2.
    InvalidModulus(u64),
    /// More coefficients were supplied than a polynomial of degree `max_degree` can
    /// hold; `requested` is the degree implied by the coefficient count.
    DegreeOverflow { requested: usize, max_degree: usize },
    /// The operands differ in modulus or in `max_degree`.
    IncompatiblePolynomials,
    /// An exact product would have degree `degree`, exceeding the operands' `max_degree`.
    ProductDegreeOverflow { degree: usize, max_degree: usize },
    /// An NTT buffer length was zero or not a power of two.
    NttLengthMustBePowerOfTwo(usize),
    /// The padded NTT length does not divide `modulus - 1`, so no root of unity of
    /// that order can be derived from the primitive root.
    NttLengthUnsupported { length: usize, modulus: u64 },
    /// The root derived from `primitive_root` does not have multiplicative order
    /// exactly `length` modulo `modulus`.
    InvalidPrimitiveRoot {
        primitive_root: u64,
        length: usize,
        modulus: u64,
    },
    /// Division or reduction by the zero polynomial.
    DivisionByZeroPolynomial,
    /// `coefficient` has no multiplicative inverse modulo `modulus`, e.g. a divisor's
    /// leading coefficient that shares a factor with a composite modulus.
    NonInvertibleCoefficient { coefficient: u64, modulus: u64 },
}
73
74impl fmt::Display for PolynomialError {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::InvalidModulus(m) => write!(f, "invalid modulus {m}, expected q >= 2"),
78 Self::DegreeOverflow {
79 requested,
80 max_degree,
81 } => write!(
82 f,
83 "requested degree {requested} exceeds max_degree {max_degree}"
84 ),
85 Self::IncompatiblePolynomials => {
86 write!(f, "polynomials are incompatible (different n or q)")
87 }
88 Self::ProductDegreeOverflow { degree, max_degree } => {
89 write!(f, "product degree {degree} exceeds max_degree {max_degree}")
90 }
91 Self::NttLengthMustBePowerOfTwo(length) => {
92 write!(f, "NTT length {length} must be a non-zero power of two")
93 }
94 Self::NttLengthUnsupported { length, modulus } => write!(
95 f,
96 "NTT length {length} is not supported by modulus {modulus} (length must divide q-1)"
97 ),
98 Self::InvalidPrimitiveRoot {
99 primitive_root,
100 length,
101 modulus,
102 } => write!(
103 f,
104 "primitive root {primitive_root} is invalid for NTT length {length} under modulus {modulus}"
105 ),
106 Self::DivisionByZeroPolynomial => write!(f, "division by zero polynomial"),
107 Self::NonInvertibleCoefficient {
108 coefficient,
109 modulus,
110 } => write!(
111 f,
112 "coefficient {coefficient} has no multiplicative inverse modulo {modulus}"
113 ),
114 }
115 }
116}
117
118impl std::error::Error for PolynomialError {}
119
120impl Polynomial {
121 pub fn new(max_degree: usize, modulus: u64, coeffs: &[u64]) -> Result<Self, PolynomialError> {
131 if modulus < 2 {
132 return Err(PolynomialError::InvalidModulus(modulus));
133 }
134 if coeffs.len() > max_degree + 1 {
135 return Err(PolynomialError::DegreeOverflow {
136 requested: coeffs.len().saturating_sub(1),
137 max_degree,
138 });
139 }
140
141 let mut normalized = vec![0_u64; max_degree + 1];
142 for (i, value) in coeffs.iter().copied().enumerate() {
143 normalized[i] = value % modulus;
144 }
145
146 Ok(Self {
147 coeffs: normalized,
148 modulus,
149 max_degree,
150 })
151 }
152
153 pub fn zero(max_degree: usize, modulus: u64) -> Result<Self, PolynomialError> {
155 Self::new(max_degree, modulus, &[])
156 }
157
158 pub fn one(max_degree: usize, modulus: u64) -> Result<Self, PolynomialError> {
160 Self::new(max_degree, modulus, &[1])
161 }
162
163 pub fn max_degree(&self) -> usize {
165 self.max_degree
166 }
167
168 pub fn modulus(&self) -> u64 {
170 self.modulus
171 }
172
173 pub fn coefficients(&self) -> &[u64] {
175 &self.coeffs
176 }
177
178 pub fn trimmed_coefficients(&self) -> Vec<u64> {
182 match self.degree() {
183 Some(d) => self.coeffs[..=d].to_vec(),
184 None => vec![0],
185 }
186 }
187
188 pub fn coeff(&self, degree: usize) -> Option<u64> {
190 self.coeffs.get(degree).copied()
191 }
192
193 pub fn degree(&self) -> Option<usize> {
195 self.coeffs.iter().rposition(|&c| c != 0)
196 }
197
198 pub fn is_zero(&self) -> bool {
200 self.degree().is_none()
201 }
202
203 pub fn add(&self, rhs: &Self) -> Result<Self, PolynomialError> {
208 self.ensure_compatible(rhs)?;
209
210 let mut out = vec![0_u64; self.max_degree + 1];
211 for (i, slot) in out.iter_mut().enumerate() {
212 *slot = mod_add(self.coeffs[i], rhs.coeffs[i], self.modulus);
213 }
214
215 Ok(Self {
216 coeffs: out,
217 modulus: self.modulus,
218 max_degree: self.max_degree,
219 })
220 }
221
222 pub fn sub(&self, rhs: &Self) -> Result<Self, PolynomialError> {
227 self.ensure_compatible(rhs)?;
228
229 let mut out = vec![0_u64; self.max_degree + 1];
230 for (i, slot) in out.iter_mut().enumerate() {
231 *slot = mod_sub(self.coeffs[i], rhs.coeffs[i], self.modulus);
232 }
233
234 Ok(Self {
235 coeffs: out,
236 modulus: self.modulus,
237 max_degree: self.max_degree,
238 })
239 }
240
241 pub fn neg(&self) -> Self {
243 let mut out = vec![0_u64; self.max_degree + 1];
244 for (i, slot) in out.iter_mut().enumerate() {
245 let c = self.coeffs[i];
246 *slot = if c == 0 { 0 } else { self.modulus - c };
247 }
248
249 Self {
250 coeffs: out,
251 modulus: self.modulus,
252 max_degree: self.max_degree,
253 }
254 }
255
256 pub fn scalar_mul(&self, scalar: u64) -> Self {
258 let mut out = vec![0_u64; self.max_degree + 1];
259 let reduced_scalar = scalar % self.modulus;
260
261 for (i, slot) in out.iter_mut().enumerate() {
262 *slot = mod_mul(self.coeffs[i], reduced_scalar, self.modulus);
263 }
264
265 Self {
266 coeffs: out,
267 modulus: self.modulus,
268 max_degree: self.max_degree,
269 }
270 }
271
272 pub fn mul(&self, rhs: &Self) -> Result<Self, PolynomialError> {
279 self.ensure_compatible(rhs)?;
280
281 if let (Some(d1), Some(d2)) = (self.degree(), rhs.degree()) {
282 let d = d1 + d2;
283 if d > self.max_degree {
284 return Err(PolynomialError::ProductDegreeOverflow {
285 degree: d,
286 max_degree: self.max_degree,
287 });
288 }
289 }
290
291 let mut out = vec![0_u64; self.max_degree + 1];
292 for (i, &a) in self.coeffs.iter().enumerate() {
293 if a == 0 {
294 continue;
295 }
296 for (j, &b) in rhs.coeffs.iter().enumerate() {
297 if b == 0 {
298 continue;
299 }
300 let idx = i + j;
301 if idx > self.max_degree {
302 continue;
303 }
304 let prod = mod_mul(a, b, self.modulus);
306 out[idx] = mod_add(out[idx], prod, self.modulus);
307 }
308 }
309
310 Ok(Self {
311 coeffs: out,
312 modulus: self.modulus,
313 max_degree: self.max_degree,
314 })
315 }
316
317 pub fn mul_ntt(&self, rhs: &Self, primitive_root: u64) -> Result<Self, PolynomialError> {
328 self.ensure_compatible(rhs)?;
329
330 let (Some(lhs_degree), Some(rhs_degree)) = (self.degree(), rhs.degree()) else {
331 return Self::zero(self.max_degree, self.modulus);
333 };
334
335 let degree = lhs_degree + rhs_degree;
336 if degree > self.max_degree {
337 return Err(PolynomialError::ProductDegreeOverflow {
338 degree,
339 max_degree: self.max_degree,
340 });
341 }
342
343 let lhs = &self.coeffs[..=lhs_degree];
344 let rhs = &rhs.coeffs[..=rhs_degree];
345 let product = convolution_ntt(lhs, rhs, self.modulus, primitive_root)?;
346
347 let mut out = vec![0_u64; self.max_degree + 1];
348 for (i, coeff) in product.into_iter().enumerate() {
349 out[i] = coeff;
350 }
351
352 Ok(Self {
353 coeffs: out,
354 modulus: self.modulus,
355 max_degree: self.max_degree,
356 })
357 }
358
359 pub fn mul_truncated(&self, rhs: &Self) -> Result<Self, PolynomialError> {
367 self.ensure_compatible(rhs)?;
368
369 let mut out = vec![0_u64; self.max_degree + 1];
370 for (i, &a) in self.coeffs.iter().enumerate() {
371 if a == 0 {
372 continue;
373 }
374 for (j, &b) in rhs.coeffs.iter().enumerate() {
375 if b == 0 {
376 continue;
377 }
378 let idx = i + j;
379 if idx > self.max_degree {
380 break;
382 }
383 let prod = mod_mul(a, b, self.modulus);
384 out[idx] = mod_add(out[idx], prod, self.modulus);
385 }
386 }
387
388 Ok(Self {
389 coeffs: out,
390 modulus: self.modulus,
391 max_degree: self.max_degree,
392 })
393 }
394
395 pub fn rem_mod_poly(&self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
404 self.ensure_compatible(modulus_poly)?;
405
406 let mut reduced = self.coeffs.clone();
407 reduce_coefficients_mod_poly(&mut reduced, &modulus_poly.coeffs, self.modulus)?;
408
409 Ok(Self {
410 coeffs: reduced,
411 modulus: self.modulus,
412 max_degree: self.max_degree,
413 })
414 }
415
416 pub fn add_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
418 self.add(rhs)?.rem_mod_poly(modulus_poly)
419 }
420
421 pub fn sub_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
423 self.sub(rhs)?.rem_mod_poly(modulus_poly)
424 }
425
426 pub fn mul_mod_poly(&self, rhs: &Self, modulus_poly: &Self) -> Result<Self, PolynomialError> {
434 self.ensure_compatible(rhs)?;
435 self.ensure_compatible(modulus_poly)?;
436
437 let (Some(lhs_degree), Some(rhs_degree)) = (self.degree(), rhs.degree()) else {
438 return Self::zero(self.max_degree, self.modulus);
439 };
440
441 let mut product = vec![0_u64; lhs_degree + rhs_degree + 1];
442 for i in 0..=lhs_degree {
443 let a = self.coeffs[i];
444 if a == 0 {
445 continue;
446 }
447 for j in 0..=rhs_degree {
448 let b = rhs.coeffs[j];
449 if b == 0 {
450 continue;
451 }
452 let idx = i + j;
453 let term = mod_mul(a, b, self.modulus);
455 product[idx] = mod_add(product[idx], term, self.modulus);
456 }
457 }
458
459 reduce_coefficients_mod_poly(&mut product, &modulus_poly.coeffs, self.modulus)?;
460
461 let mut out = vec![0_u64; self.max_degree + 1];
462 for (i, coeff) in product.into_iter().enumerate().take(self.max_degree + 1) {
463 out[i] = coeff;
464 }
465
466 Ok(Self {
467 coeffs: out,
468 modulus: self.modulus,
469 max_degree: self.max_degree,
470 })
471 }
472
473 pub fn evaluate(&self, x: u64) -> u64 {
475 let x_mod = x % self.modulus;
476 let mut acc = 0_u64;
477
478 for &coeff in self.coeffs.iter().rev() {
479 acc = mod_mul(acc, x_mod, self.modulus);
480 acc = mod_add(acc, coeff, self.modulus);
481 }
482
483 acc
484 }
485
486 pub fn derivative(&self) -> Self {
488 let mut out = vec![0_u64; self.max_degree + 1];
489 for (deg, &coeff) in self.coeffs.iter().enumerate().skip(1) {
490 let factor = deg as u64 % self.modulus;
491 out[deg - 1] = mod_mul(factor, coeff, self.modulus);
492 }
493
494 Self {
495 coeffs: out,
496 modulus: self.modulus,
497 max_degree: self.max_degree,
498 }
499 }
500
501 pub fn div_rem(&self, divisor: &Self) -> Result<(Self, Self), PolynomialError> {
510 self.ensure_compatible(divisor)?;
511 if divisor.is_zero() {
512 return Err(PolynomialError::DivisionByZeroPolynomial);
513 }
514 if self.is_zero() {
515 return Ok((
516 Self::zero(self.max_degree, self.modulus)?,
517 Self::zero(self.max_degree, self.modulus)?,
518 ));
519 }
520
521 let divisor_degree = divisor.degree().expect("checked non-zero divisor");
522 let lead = divisor.coeffs[divisor_degree];
523 let lead_inv =
524 mod_inverse(lead, self.modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
525 coefficient: lead,
526 modulus: self.modulus,
527 })?;
528
529 let mut remainder = self.clone();
530 let mut quotient = Self::zero(self.max_degree, self.modulus)?;
531
532 while let Some(rem_deg) = remainder.degree() {
533 if rem_deg < divisor_degree {
534 break;
535 }
536
537 let diff = rem_deg - divisor_degree;
539 let rem_lead = remainder.coeffs[rem_deg];
540 let factor = mod_mul(rem_lead, lead_inv, self.modulus);
541 quotient.coeffs[diff] = mod_add(quotient.coeffs[diff], factor, self.modulus);
542
543 for i in 0..=divisor_degree {
545 let idx = i + diff;
546 let scaled = mod_mul(factor, divisor.coeffs[i], self.modulus);
547 remainder.coeffs[idx] = mod_sub(remainder.coeffs[idx], scaled, self.modulus);
548 }
549 }
550
551 Ok((quotient, remainder))
552 }
553
554 fn ensure_compatible(&self, rhs: &Self) -> Result<(), PolynomialError> {
556 if self.modulus != rhs.modulus || self.max_degree != rhs.max_degree {
557 return Err(PolynomialError::IncompatiblePolynomials);
558 }
559 Ok(())
560 }
561}
562
/// Returns `(a + b) mod modulus`, widening to `u128` so the sum cannot overflow.
fn mod_add(a: u64, b: u64, modulus: u64) -> u64 {
    let sum = u128::from(a) + u128::from(b);
    (sum % u128::from(modulus)) as u64
}
567
/// Returns `(a - b) mod modulus` as a value in `0..modulus`.
///
/// `b` is reduced modulo `modulus` first, so `a + modulus - b` can never underflow
/// even when `b` is not already a canonical residue. (Previously `b > a + modulus`
/// panicked on `u128` underflow in debug builds.)
fn mod_sub(a: u64, b: u64, modulus: u64) -> u64 {
    let m = u128::from(modulus);
    ((u128::from(a) + m - u128::from(b) % m) % m) as u64
}
572
/// Returns `(a * b) mod modulus`; the product is formed in `u128`, which holds any
/// product of two `u64` values without overflow.
fn mod_mul(a: u64, b: u64, modulus: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    (product % u128::from(modulus)) as u64
}
577
578fn mod_pow(mut base: u64, mut exponent: u64, modulus: u64) -> u64 {
580 let mut acc = 1_u64 % modulus;
581 base %= modulus;
582
583 while exponent > 0 {
584 if exponent & 1 == 1 {
585 acc = mod_mul(acc, base, modulus);
586 }
587 base = mod_mul(base, base, modulus);
588 exponent >>= 1;
589 }
590
591 acc
592}
593
/// Reorders `values` so that element `i` moves to the bit-reversed index of `i`,
/// the input ordering expected by the iterative radix-2 NTT.
///
/// The reversed counter is advanced incrementally: adding one "from the top" clears
/// the run of set high bits and then sets the first clear one. Intended for
/// power-of-two lengths; lengths 0 and 1 are left unchanged.
fn bit_reverse_permute(values: &mut [u64]) {
    let len = values.len();
    let top_bit = len >> 1;
    let mut reversed = 0_usize;

    for index in 1..len {
        let mut bit = top_bit;
        while reversed & bit != 0 {
            reversed ^= bit;
            bit >>= 1;
        }
        reversed ^= bit;

        // Each pair is visited twice; swap only on the first visit.
        if index < reversed {
            values.swap(index, reversed);
        }
    }
}
614
615fn ntt_in_place(values: &mut [u64], root: u64, modulus: u64) -> Result<(), PolynomialError> {
619 let n = values.len();
620 if n == 0 || !n.is_power_of_two() {
621 return Err(PolynomialError::NttLengthMustBePowerOfTwo(n));
622 }
623
624 bit_reverse_permute(values);
626
627 let mut len = 2_usize;
628 while len <= n {
629 let wlen = mod_pow(root, (n / len) as u64, modulus);
631 for start in (0..n).step_by(len) {
632 let mut w = 1_u64;
633 for i in 0..(len / 2) {
634 let u = values[start + i];
637 let v = mod_mul(values[start + i + len / 2], w, modulus);
638 values[start + i] = mod_add(u, v, modulus);
639 values[start + i + len / 2] = mod_sub(u, v, modulus);
640 w = mod_mul(w, wlen, modulus);
641 }
642 }
643 len <<= 1;
644 }
645
646 Ok(())
647}
648
649fn convolution_ntt(
653 lhs: &[u64],
654 rhs: &[u64],
655 modulus: u64,
656 primitive_root: u64,
657) -> Result<Vec<u64>, PolynomialError> {
658 if lhs.is_empty() || rhs.is_empty() {
659 return Ok(vec![0]);
660 }
661
662 let out_len = lhs.len() + rhs.len() - 1;
664 let ntt_len = out_len.next_power_of_two();
665 if !ntt_len.is_power_of_two() {
666 return Err(PolynomialError::NttLengthMustBePowerOfTwo(ntt_len));
667 }
668
669 let ntt_len_u64 = ntt_len as u64;
670 if (modulus - 1) % ntt_len_u64 != 0 {
671 return Err(PolynomialError::NttLengthUnsupported {
672 length: ntt_len,
673 modulus,
674 });
675 }
676
677 let root = mod_pow(primitive_root, (modulus - 1) / ntt_len_u64, modulus);
679 let is_valid_root = mod_pow(root, ntt_len_u64, modulus) == 1
680 && (ntt_len == 1 || mod_pow(root, (ntt_len / 2) as u64, modulus) != 1);
681 if !is_valid_root {
682 return Err(PolynomialError::InvalidPrimitiveRoot {
683 primitive_root,
684 length: ntt_len,
685 modulus,
686 });
687 }
688
689 let root_inv = mod_inverse(root, modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
691 coefficient: root,
692 modulus,
693 })?;
694 let n_inv = mod_inverse(ntt_len_u64 % modulus, modulus).ok_or(
695 PolynomialError::NonInvertibleCoefficient {
696 coefficient: ntt_len_u64 % modulus,
697 modulus,
698 },
699 )?;
700
701 let mut fa = vec![0_u64; ntt_len];
703 let mut fb = vec![0_u64; ntt_len];
704 for (i, coeff) in lhs.iter().copied().enumerate() {
705 fa[i] = coeff % modulus;
706 }
707 for (i, coeff) in rhs.iter().copied().enumerate() {
708 fb[i] = coeff % modulus;
709 }
710
711 ntt_in_place(&mut fa, root, modulus)?;
713 ntt_in_place(&mut fb, root, modulus)?;
714
715 for (a, b) in fa.iter_mut().zip(fb.iter()) {
717 *a = mod_mul(*a, *b, modulus);
718 }
719
720 ntt_in_place(&mut fa, root_inv, modulus)?;
722 for coeff in &mut fa {
723 *coeff = mod_mul(*coeff, n_inv, modulus);
724 }
725 fa.truncate(out_len);
727
728 Ok(fa)
729}
730
/// Returns the multiplicative inverse of `a` modulo `modulus`, or `None` when
/// `gcd(a, modulus) != 1`.
///
/// Runs the iterative extended Euclidean algorithm in `i128`, which is wide enough
/// for the intermediate Bézout coefficients of any `u64` inputs.
fn mod_inverse(a: u64, modulus: u64) -> Option<u64> {
    let m = i128::from(modulus);
    let (mut old_r, mut r) = (m, i128::from(a % modulus));
    let (mut old_s, mut s) = (0_i128, 1_i128);

    while r != 0 {
        let q = old_r / r;
        (old_s, s) = (s, old_s - q * s);
        (old_r, r) = (r, old_r - q * r);
    }

    // `old_r` is now gcd(a, modulus); an inverse exists only when it is 1.
    if old_r != 1 {
        return None;
    }

    let inverse = if old_s < 0 { old_s + m } else { old_s };
    Some(inverse as u64)
}
755
/// Index of the highest non-zero entry of `coeffs`, or `None` when every entry is
/// zero (including for an empty slice).
fn degree_of(coeffs: &[u64]) -> Option<usize> {
    (0..coeffs.len()).rev().find(|&i| coeffs[i] != 0)
}
760
761fn reduce_coefficients_mod_poly(
765 coeffs: &mut [u64],
766 modulus_poly: &[u64],
767 modulus: u64,
768) -> Result<(), PolynomialError> {
769 let Some(modulus_degree) = degree_of(modulus_poly) else {
770 return Err(PolynomialError::DivisionByZeroPolynomial);
771 };
772
773 let leading = modulus_poly[modulus_degree];
774 let leading_inverse =
775 mod_inverse(leading, modulus).ok_or(PolynomialError::NonInvertibleCoefficient {
776 coefficient: leading,
777 modulus,
778 })?;
779
780 while let Some(current_degree) = degree_of(coeffs) {
781 if current_degree < modulus_degree {
782 break;
783 }
784
785 let shift = current_degree - modulus_degree;
787 let factor = mod_mul(coeffs[current_degree], leading_inverse, modulus);
788 if factor == 0 {
789 continue;
790 }
791
792 for (i, &modulus_coeff) in modulus_poly.iter().take(modulus_degree + 1).enumerate() {
793 let idx = shift + i;
794 if idx >= coeffs.len() || modulus_coeff == 0 {
795 continue;
796 }
797 let scaled = mod_mul(factor, modulus_coeff, modulus);
798 coeffs[idx] = mod_sub(coeffs[idx], scaled, modulus);
799 }
800 }
801
802 Ok(())
803}
804
#[cfg(test)]
mod tests {
    //! Unit tests and small exhaustive property checks for `Polynomial` and the NTT
    //! helpers.

    use super::{Polynomial, PolynomialError, mod_inverse, mod_mul, mod_pow, ntt_in_place};

    /// Builds a polynomial, panicking if the test parameters are invalid.
    fn p(n: usize, q: u64, coeffs: &[u64]) -> Polynomial {
        Polynomial::new(n, q, coeffs).expect("polynomial should build")
    }

    /// Advances a 64-bit linear congruential generator and returns the new state,
    /// giving reproducible pseudo-random inputs without external crates.
    fn lcg_next(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state
    }

    /// Returns every polynomial of capacity `max_degree` whose first `used_len`
    /// coefficients range over `values` (all higher coefficients are zero).
    ///
    /// The loop counter is read as a base-`values.len()` number whose digits select
    /// each coefficient, giving `values.len().pow(used_len)` polynomials.
    fn enumerate_polynomials(
        max_degree: usize,
        modulus: u64,
        values: &[u64],
        used_len: usize,
    ) -> Vec<Polynomial> {
        assert!(
            used_len <= max_degree + 1,
            "used_len exceeds polynomial capacity"
        );
        assert!(!values.is_empty(), "values cannot be empty");

        let total = values.len().pow(used_len as u32);
        let mut out = Vec::with_capacity(total);
        for mut state in 0..total {
            let mut coeffs = vec![0_u64; used_len];
            for slot in &mut coeffs {
                let digit = state % values.len();
                *slot = values[digit];
                state /= values.len();
            }
            out.push(p(max_degree, modulus, &coeffs));
        }
        out
    }

    #[test]
    fn constructor_normalizes_and_pads_coefficients() {
        let poly = p(5, 7, &[10, 15, 6]);
        assert_eq!(poly.coefficients(), &[3, 1, 6, 0, 0, 0]);
    }

    #[test]
    fn constructor_rejects_invalid_modulus() {
        let err = Polynomial::new(3, 1, &[1, 2]).expect_err("expected error");
        assert_eq!(err, PolynomialError::InvalidModulus(1));
    }

    #[test]
    fn constructor_rejects_degree_overflow() {
        let err = Polynomial::new(2, 11, &[1, 2, 3, 4]).expect_err("expected error");
        assert_eq!(
            err,
            PolynomialError::DegreeOverflow {
                requested: 3,
                max_degree: 2
            }
        );
    }

    #[test]
    fn degree_and_zero_behaviour() {
        let zero = p(4, 13, &[]);
        assert!(zero.is_zero());
        assert_eq!(zero.degree(), None);
        assert_eq!(zero.trimmed_coefficients(), vec![0]);

        let poly = p(4, 13, &[0, 2, 0, 9]);
        assert!(!poly.is_zero());
        assert_eq!(poly.degree(), Some(3));
        assert_eq!(poly.trimmed_coefficients(), vec![0, 2, 0, 9]);
    }

    #[test]
    fn addition_and_subtraction_work_modulo_q() {
        let a = p(5, 17, &[16, 5, 0, 1]);
        let b = p(5, 17, &[3, 13, 6, 10]);

        let sum = a.add(&b).expect("add should work");
        assert_eq!(sum.trimmed_coefficients(), vec![2, 1, 6, 11]);

        let diff = a.sub(&b).expect("sub should work");
        assert_eq!(diff.trimmed_coefficients(), vec![13, 9, 11, 8]);
    }

    #[test]
    fn unary_negation_and_scalar_multiplication() {
        let a = p(4, 19, &[0, 4, 18, 3]);
        let neg = a.neg();
        assert_eq!(neg.trimmed_coefficients(), vec![0, 15, 1, 16]);

        let scaled = a.scalar_mul(7);
        assert_eq!(scaled.trimmed_coefficients(), vec![0, 9, 12, 2]);
    }

    #[test]
    fn incompatible_polynomials_fail_for_binary_ops() {
        let a = p(3, 11, &[1, 2]);
        let b = p(4, 11, &[1, 2]);
        let err = a.add(&b).expect_err("expected incompatibility");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn multiplication_checked_and_truncated() {
        let a = p(5, 23, &[1, 2, 3]);
        let b = p(5, 23, &[4, 5]);

        let prod = a.mul(&b).expect("mul should work");
        assert_eq!(prod.trimmed_coefficients(), vec![4, 13, 22, 15]);

        let c = p(3, 29, &[1, 2, 3, 4]);
        let d = p(3, 29, &[1, 1, 1, 1]);
        let err = c.mul(&d).expect_err("degree should overflow");
        assert_eq!(
            err,
            PolynomialError::ProductDegreeOverflow {
                degree: 6,
                max_degree: 3
            }
        );

        let truncated = c.mul_truncated(&d).expect("mul_truncated should work");
        assert_eq!(truncated.coefficients(), &[1, 3, 6, 10]);
    }

    #[test]
    fn schoolbook_convolution_matches_expected_coefficients() {
        let a = p(10, 998_244_353, &[1, 2, 3, 4]);
        let b = p(10, 998_244_353, &[5, 6, 7]);
        let product = a.mul(&b).expect("schoolbook multiplication should work");

        assert_eq!(product.trimmed_coefficients(), vec![5, 16, 34, 52, 45, 28]);
    }

    #[test]
    fn ntt_round_trip_recovers_original_vector() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;
        let n = 8_usize;

        let omega = mod_pow(primitive_root, (modulus - 1) / n as u64, modulus);
        let omega_inv = mod_inverse(omega, modulus).expect("omega must be invertible");
        let n_inv = mod_inverse(n as u64, modulus).expect("length must be invertible");

        let mut values = vec![7, 11, 19, 23, 31, 2, 5, 13];
        let original = values.clone();

        ntt_in_place(&mut values, omega, modulus).expect("forward NTT should work");
        ntt_in_place(&mut values, omega_inv, modulus).expect("inverse NTT should work");
        for value in &mut values {
            *value = mod_mul(*value, n_inv, modulus);
        }

        assert_eq!(values, original);
    }

    #[test]
    fn ntt_multiplication_matches_schoolbook_convolution() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;

        let a = p(31, modulus, &[4, 1, 9, 16, 25, 36, 49, 64, 81, 100]);
        let b = p(31, modulus, &[3, 14, 15, 92, 65, 35, 89, 79]);

        let schoolbook = a.mul(&b).expect("schoolbook multiplication should work");
        let ntt = a
            .mul_ntt(&b, primitive_root)
            .expect("NTT multiplication should work");

        assert_eq!(ntt.coefficients(), schoolbook.coefficients());
    }

    #[test]
    fn ntt_matches_schoolbook_on_many_deterministic_cases() {
        let modulus = 998_244_353_u64;
        let primitive_root = 3_u64;
        let mut seed = 0xC0FFEE_u64;

        for _ in 0..200 {
            let len_a = (lcg_next(&mut seed) % 48 + 1) as usize;
            let len_b = (lcg_next(&mut seed) % 48 + 1) as usize;
            let max_degree = len_a + len_b;

            let mut coeffs_a = vec![0_u64; len_a];
            let mut coeffs_b = vec![0_u64; len_b];
            for coeff in &mut coeffs_a {
                *coeff = lcg_next(&mut seed) % modulus;
            }
            for coeff in &mut coeffs_b {
                *coeff = lcg_next(&mut seed) % modulus;
            }

            let a = p(max_degree, modulus, &coeffs_a);
            let b = p(max_degree, modulus, &coeffs_b);
            let schoolbook = a.mul(&b).expect("schoolbook multiplication should work");
            let ntt = a
                .mul_ntt(&b, primitive_root)
                .expect("NTT multiplication should work");

            assert_eq!(ntt.coefficients(), schoolbook.coefficients());
        }
    }

    #[test]
    fn ntt_errors_are_reported_for_bad_parameters() {
        // 21 output coefficients pad to length 32, which does not divide 17 - 1.
        let coeffs = vec![1_u64; 11];
        let a = p(20, 17, &coeffs);
        let b = p(20, 17, &coeffs);
        let err = a
            .mul_ntt(&b, 3)
            .expect_err("unsupported NTT length should fail");
        assert_eq!(
            err,
            PolynomialError::NttLengthUnsupported {
                length: 32,
                modulus: 17
            }
        );

        let c = p(16, 998_244_353, &[1, 2, 3, 4]);
        let d = p(16, 998_244_353, &[5, 6, 7, 8]);
        let err = c
            .mul_ntt(&d, 1)
            .expect_err("invalid primitive root should fail");
        assert_eq!(
            err,
            PolynomialError::InvalidPrimitiveRoot {
                primitive_root: 1,
                length: 8,
                modulus: 998_244_353
            }
        );
    }

    #[test]
    fn remainder_mod_polynomial_reduces_degree() {
        // Modulo x^4 + 1: 3x^5 + x^4 + 2 == 1 - 3x == 1 + 14x (mod 17).
        let m = p(6, 17, &[1, 0, 0, 0, 1]);
        let a = p(6, 17, &[2, 0, 0, 0, 1, 3]);
        let reduced = a.rem_mod_poly(&m).expect("reduction should work");

        assert_eq!(reduced.trimmed_coefficients(), vec![1, 14]);
    }

    #[test]
    fn quotient_ring_add_sub_and_mul() {
        let m = p(4, 17, &[1, 0, 0, 0, 1]);
        let a = p(4, 17, &[0, 0, 0, 16]);
        let b = p(4, 17, &[0, 0, 0, 2]);
        let added = a.add_mod_poly(&b, &m).expect("add mod poly should work");
        assert_eq!(added.trimmed_coefficients(), vec![0, 0, 0, 1]);

        let subbed = a.sub_mod_poly(&b, &m).expect("sub mod poly should work");
        assert_eq!(subbed.trimmed_coefficients(), vec![0, 0, 0, 14]);

        let c = p(4, 17, &[1, 0, 0, 1]);
        let d = p(4, 17, &[1, 0, 0, 1]);
        let err = c.mul(&d).expect_err("plain mul overflows max_degree");
        assert_eq!(
            err,
            PolynomialError::ProductDegreeOverflow {
                degree: 6,
                max_degree: 4
            }
        );

        // (1 + x^3)^2 = 1 + 2x^3 + x^6 == 1 - x^2 + 2x^3 modulo x^4 + 1.
        let ring_product = c.mul_mod_poly(&d, &m).expect("mul mod poly should work");
        assert_eq!(ring_product.trimmed_coefficients(), vec![1, 0, 16, 2]);
    }

    #[test]
    fn polynomial_modulus_errors_are_reported() {
        let a = p(4, 11, &[1, 2, 3]);
        let b = p(4, 11, &[3, 4]);
        let zero_poly = p(4, 11, &[]);
        let err = a
            .mul_mod_poly(&b, &zero_poly)
            .expect_err("zero modulus polynomial should fail");
        assert_eq!(err, PolynomialError::DivisionByZeroPolynomial);

        let x = p(4, 8, &[0, 1]);
        let non_invertible_modulus = p(4, 8, &[1, 2]);
        let err = x
            .rem_mod_poly(&non_invertible_modulus)
            .expect_err("non-invertible modulus lead should fail");
        assert_eq!(
            err,
            PolynomialError::NonInvertibleCoefficient {
                coefficient: 2,
                modulus: 8
            }
        );

        let c = p(5, 11, &[1, 2]);
        let err = a
            .add_mod_poly(&b, &c)
            .expect_err("incompatible polynomial settings should fail");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn evaluate_uses_horner_rule_modulo_q() {
        // 7 + 3*10^2 + 4*10^3 = 4307 == 29 (mod 31).
        let poly = p(4, 31, &[7, 0, 3, 4]);
        let value = poly.evaluate(10);
        assert_eq!(value, 29);
    }

    #[test]
    fn derivative_is_computed_modulo_q() {
        let poly = p(6, 11, &[3, 5, 7, 9, 2]);
        let deriv = poly.derivative();
        assert_eq!(deriv.trimmed_coefficients(), vec![5, 3, 5, 8]);
    }

    #[test]
    fn long_division_exact_case() {
        let dividend = p(5, 7, &[5, 6, 2, 1]);
        let divisor = p(5, 7, &[1, 1]);
        let (quotient, remainder) = dividend.div_rem(&divisor).expect("division should work");

        assert_eq!(quotient.trimmed_coefficients(), vec![5, 1, 1]);
        assert!(remainder.is_zero());
    }

    #[test]
    fn long_division_with_remainder() {
        let dividend = p(4, 5, &[3, 0, 1]);
        let divisor = p(4, 5, &[2, 1]);
        let (quotient, remainder) = dividend.div_rem(&divisor).expect("division should work");

        assert_eq!(quotient.trimmed_coefficients(), vec![3, 1]);
        assert_eq!(remainder.trimmed_coefficients(), vec![2]);

        // Multiplication is commutative, so quotient * divisor == divisor * quotient.
        let reconstructed = quotient
            .mul(&divisor)
            .expect("reconstruction product should fit")
            .add(&remainder)
            .expect("reconstruction sum should fit");
        assert_eq!(reconstructed.coefficients(), dividend.coefficients());
    }

    #[test]
    fn division_errors_are_reported() {
        let a = p(4, 9, &[1, 2, 3]);
        let zero = p(4, 9, &[]);
        let err = a.div_rem(&zero).expect_err("division by zero should fail");
        assert_eq!(err, PolynomialError::DivisionByZeroPolynomial);

        let dividend = p(4, 8, &[1, 0, 1]);
        let divisor = p(4, 8, &[0, 2]);
        let err = dividend
            .div_rem(&divisor)
            .expect_err("division should fail when inverse does not exist");
        assert_eq!(
            err,
            PolynomialError::NonInvertibleCoefficient {
                coefficient: 2,
                modulus: 8
            }
        );
    }

    #[test]
    fn mod_q_algebraic_laws_hold_on_small_exhaustive_domain() {
        let modulus = 5_u64;
        let max_degree = 2_usize;
        let all = enumerate_polynomials(max_degree, modulus, &[0, 1, 2, 3, 4], 3);
        let reps = enumerate_polynomials(max_degree, modulus, &[0, 1, 4], 3);
        let zero = p(max_degree, modulus, &[]);
        let one = p(max_degree, modulus, &[1]);

        for a in &all {
            assert_eq!(a.add(&zero).expect("a + 0 should work"), *a);
            assert_eq!(zero.add(a).expect("0 + a should work"), *a);
            assert_eq!(a.sub(&zero).expect("a - 0 should work"), *a);
            assert!(
                a.add(&a.neg()).expect("a + (-a) should work").is_zero(),
                "additive inverse should cancel"
            );
            assert_eq!(a.scalar_mul(modulus + 2), a.scalar_mul(2));
            assert_eq!(a.scalar_mul(0), zero);
            assert_eq!(
                a.mul_truncated(&one).expect("a * 1 should work"),
                *a,
                "multiplicative identity should hold"
            );
        }

        for a in &all {
            for b in &all {
                assert_eq!(
                    a.add(b).expect("a+b should work"),
                    b.add(a).expect("b+a should work"),
                    "addition should commute"
                );
                assert_eq!(
                    a.mul_truncated(b).expect("a*b should work"),
                    b.mul_truncated(a).expect("b*a should work"),
                    "truncated multiplication should commute"
                );
                assert_eq!(
                    a.sub(b).expect("a-b should work"),
                    a.add(&b.neg()).expect("a+(-b) should work"),
                    "subtraction should match addition with negation"
                );
            }
        }

        for a in &reps {
            for b in &reps {
                for c in &reps {
                    assert_eq!(
                        a.add(&b.add(c).expect("b+c should work"))
                            .expect("a+(b+c) should work"),
                        a.add(b)
                            .expect("a+b should work")
                            .add(c)
                            .expect("(a+b)+c should work"),
                        "addition should associate"
                    );

                    let lhs = a
                        .mul_truncated(&b.add(c).expect("b+c should work"))
                        .expect("a*(b+c) should work");
                    let rhs = a
                        .mul_truncated(b)
                        .expect("a*b should work")
                        .add(&a.mul_truncated(c).expect("a*c should work"))
                        .expect("ab+ac should work");
                    assert_eq!(lhs, rhs, "multiplication should distribute over addition");
                }
            }
        }
    }

    #[test]
    fn evaluate_matches_naive_formula_on_exhaustive_small_domain() {
        let modulus = 11_u64;
        let max_degree = 3_usize;
        let polys = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 4);

        for poly in &polys {
            for x in 0..modulus {
                let mut acc = 0_u64;
                let mut x_pow = 1_u64;
                for &coeff in poly.coefficients() {
                    acc = (acc + (coeff * x_pow) % modulus) % modulus;
                    x_pow = (x_pow * x) % modulus;
                }
                assert_eq!(poly.evaluate(x), acc);
            }
        }
    }

    #[test]
    fn ntt_matches_schoolbook_exhaustive_small_domains() {
        // 2 generates the multiplicative group modulo 5.
        let small_q_polys = enumerate_polynomials(2, 5, &[0, 1, 2, 3, 4], 2);
        for a in &small_q_polys {
            for b in &small_q_polys {
                let schoolbook = a.mul(b).expect("schoolbook should work");
                let ntt = a.mul_ntt(b, 2).expect("NTT should work");
                assert_eq!(ntt.coefficients(), schoolbook.coefficients());
            }
        }

        // 3 generates the multiplicative group modulo 17.
        let larger_polys = enumerate_polynomials(4, 17, &[0, 1, 2, 3], 3);
        for a in &larger_polys {
            for b in &larger_polys {
                let schoolbook = a.mul(b).expect("schoolbook should work");
                let ntt = a.mul_ntt(b, 3).expect("NTT should work");
                assert_eq!(ntt.coefficients(), schoolbook.coefficients());
            }
        }
    }

    #[test]
    fn mod_polynomial_reduction_is_bounded_and_idempotent() {
        let modulus = 17_u64;
        let max_degree = 4_usize;
        let modulus_poly = p(max_degree, modulus, &[1, 0, 0, 0, 1]);
        let all = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 5);
        let sample: Vec<_> = all.iter().take(64).cloned().collect();

        for poly in &all {
            let reduced = poly
                .rem_mod_poly(&modulus_poly)
                .expect("reduction should succeed");
            if let Some(degree) = reduced.degree() {
                assert!(
                    degree < 4,
                    "canonical representative degree must be < deg(modulus)"
                );
            }
            assert_eq!(
                reduced,
                reduced
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction idempotence should hold")
            );
        }

        for a in &sample {
            for b in &sample {
                let lhs_add = a
                    .add_mod_poly(b, &modulus_poly)
                    .expect("add mod poly should work");
                let rhs_add = a
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction should work")
                    .add_mod_poly(
                        &b.rem_mod_poly(&modulus_poly)
                            .expect("reduction should work"),
                        &modulus_poly,
                    )
                    .expect("add mod poly should work");
                assert_eq!(lhs_add, rhs_add);

                let lhs_mul = a
                    .mul_mod_poly(b, &modulus_poly)
                    .expect("mul mod poly should work");
                let rhs_mul = a
                    .rem_mod_poly(&modulus_poly)
                    .expect("reduction should work")
                    .mul_mod_poly(
                        &b.rem_mod_poly(&modulus_poly)
                            .expect("reduction should work"),
                        &modulus_poly,
                    )
                    .expect("mul mod poly should work");
                assert_eq!(lhs_mul, rhs_mul);
            }
        }
    }

    #[test]
    fn long_division_invariants_hold_on_exhaustive_small_domain() {
        let modulus = 5_u64;
        let max_degree = 4_usize;
        let dividends = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 5);
        let divisors: Vec<_> = enumerate_polynomials(max_degree, modulus, &[0, 1, 2], 3)
            .into_iter()
            .filter(|poly| !poly.is_zero())
            .collect();

        for dividend in dividends.iter() {
            for divisor in &divisors {
                let (quotient, remainder) =
                    dividend.div_rem(divisor).expect("division should succeed");

                // Multiplication is commutative, so quotient * divisor == divisor * quotient.
                let reconstructed = quotient
                    .mul(divisor)
                    .expect("product should fit in max degree")
                    .add(&remainder)
                    .expect("sum should fit in max degree");
                assert_eq!(reconstructed.coefficients(), dividend.coefficients());

                if let Some(rem_degree) = remainder.degree() {
                    let divisor_degree = divisor.degree().expect("divisors are non-zero");
                    assert!(
                        rem_degree < divisor_degree,
                        "remainder degree must be strictly less than divisor degree"
                    );
                }
            }
        }
    }
}