Skip to main content

alkahest_cas/modular/
mod.rs

//! V2-1 — Modular / CRT framework as a first-class primitive.
//!
//! Provides three core operations over sparse multivariate polynomials:
//!
//! - [`reduce_mod`] — reduce `f ∈ ℤ[x₁,…,xₙ]` to `F_p = ℤ/pℤ`
//! - [`lift_crt`] — reconstruct `f` from modular images via the Chinese Remainder Theorem
//! - [`rational_reconstruction`] — recover `a/b` from `n ≡ a·b⁻¹ (mod M)`
//!
//! Plus utilities used by higher-level algorithms (GCDs, factorization, Gröbner bases):
//!
//! - [`mignotte_bound`] — Cauchy–Mignotte-style coefficient bound
//! - [`select_lucky_prime`] — choose a prime that does not divide a given integer (e.g. so the leading coefficient does not vanish mod p)
13
14use crate::errors::AlkahestError;
15use crate::kernel::ExprId;
16use crate::poly::MultiPoly;
17use rug::Integer;
18use std::collections::BTreeMap;
19
20// ---------------------------------------------------------------------------
21// MultiPolyFp — sparse multivariate polynomial over F_p = ℤ/pℤ
22// ---------------------------------------------------------------------------
23
24/// Sparse multivariate polynomial over the prime field `F_p = ℤ/pℤ`.
25///
26/// Coefficients are stored as `u64` in `[0, p)`.  The prime modulus is stored
27/// alongside the polynomial so that callers can check consistency before
28/// combining images with [`lift_crt`].
29#[derive(Clone, PartialEq, Eq, Debug)]
30pub struct MultiPolyFp {
31    /// Variable identifiers — same ordering as the originating [`MultiPoly`].
32    pub vars: Vec<ExprId>,
33    /// The prime modulus `p`.
34    pub modulus: u64,
35    /// Exponent vector → coefficient in `[0, p)`.  Zero terms are never stored.
36    pub terms: BTreeMap<Vec<u32>, u64>,
37}
38
39impl MultiPolyFp {
40    pub fn zero(vars: Vec<ExprId>, modulus: u64) -> Self {
41        MultiPolyFp {
42            vars,
43            modulus,
44            terms: BTreeMap::new(),
45        }
46    }
47
48    pub fn is_zero(&self) -> bool {
49        self.terms.is_empty()
50    }
51
52    pub fn total_degree(&self) -> u32 {
53        self.terms
54            .keys()
55            .map(|e| e.iter().sum::<u32>())
56            .max()
57            .unwrap_or(0)
58    }
59
60    pub fn compatible_with(&self, other: &Self) -> bool {
61        self.vars == other.vars && self.modulus == other.modulus
62    }
63}
64
65impl std::fmt::Display for MultiPolyFp {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        if self.is_zero() {
68            return write!(f, "0 (mod {})", self.modulus);
69        }
70        let mut first = true;
71        for (exp, coeff) in &self.terms {
72            if !first {
73                write!(f, " + ")?;
74            }
75            first = false;
76            write!(f, "{coeff}")?;
77            for (i, &e) in exp.iter().enumerate() {
78                if e == 0 {
79                    continue;
80                }
81                if e == 1 {
82                    write!(f, "*x{i}")?;
83                } else {
84                    write!(f, "*x{i}^{e}")?;
85                }
86            }
87        }
88        write!(f, " (mod {})", self.modulus)
89    }
90}
91
92// ---------------------------------------------------------------------------
93// ModularValue — a tagged element of ℤ/pℤ for derivation traces
94// ---------------------------------------------------------------------------
95
96/// A single element of `ℤ/pℤ`, tagged with its modulus.
97///
98/// Used as a tracer value to tag which modular image produced a given
99/// coefficient during GCD or resultant computation.
100#[derive(Clone, Debug, PartialEq, Eq)]
101pub struct ModularValue {
102    /// The residue, in `[0, modulus)`.
103    pub value: u64,
104    /// The prime modulus.
105    pub modulus: u64,
106}
107
108impl ModularValue {
109    pub fn new(value: u64, modulus: u64) -> Self {
110        debug_assert!(
111            value < modulus,
112            "ModularValue: value must be in [0, modulus)"
113        );
114        ModularValue { value, modulus }
115    }
116
117    pub fn zero(modulus: u64) -> Self {
118        ModularValue { value: 0, modulus }
119    }
120
121    pub fn one(modulus: u64) -> Self {
122        ModularValue {
123            value: if modulus > 1 { 1 } else { 0 },
124            modulus,
125        }
126    }
127
128    pub fn add(&self, other: &Self) -> Self {
129        debug_assert_eq!(
130            self.modulus, other.modulus,
131            "ModularValue: mismatched moduli"
132        );
133        let v = ((self.value as u128 + other.value as u128) % self.modulus as u128) as u64;
134        ModularValue::new(v, self.modulus)
135    }
136
137    pub fn sub(&self, other: &Self) -> Self {
138        debug_assert_eq!(
139            self.modulus, other.modulus,
140            "ModularValue: mismatched moduli"
141        );
142        let v = (self.value + self.modulus - other.value % self.modulus) % self.modulus;
143        ModularValue::new(v, self.modulus)
144    }
145
146    pub fn mul(&self, other: &Self) -> Self {
147        debug_assert_eq!(
148            self.modulus, other.modulus,
149            "ModularValue: mismatched moduli"
150        );
151        let v = ((self.value as u128 * other.value as u128) % self.modulus as u128) as u64;
152        ModularValue::new(v, self.modulus)
153    }
154
155    pub fn neg(&self) -> Self {
156        if self.value == 0 {
157            self.clone()
158        } else {
159            ModularValue::new(self.modulus - self.value, self.modulus)
160        }
161    }
162
163    /// Multiplicative inverse. Returns `None` if `self.value == 0`.
164    pub fn inverse(&self) -> Option<Self> {
165        if self.value == 0 {
166            return None;
167        }
168        Some(ModularValue::new(
169            mod_inverse_u64(self.value, self.modulus),
170            self.modulus,
171        ))
172    }
173}
174
// ---------------------------------------------------------------------------
// ModularError
// ---------------------------------------------------------------------------

/// Error type for modular arithmetic operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModularError {
    /// The given modulus is not a prime ≥ 2.
    InvalidModulus(u64),
    /// The input polynomials have incompatible variable lists or moduli.
    IncompatiblePolynomials,
    /// CRT lifting requires at least one modular image.
    EmptyImageList,
    /// Rational reconstruction failed: no `a/b` with small norm exists.
    ReconstructionFailed,
}

impl std::fmt::Display for ModularError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ModularError::InvalidModulus(p) => {
                write!(f, "invalid modulus {p}: must be prime ≥ 2")
            }
            ModularError::IncompatiblePolynomials => {
                write!(f, "polynomials have incompatible variable lists or moduli")
            }
            ModularError::EmptyImageList => {
                write!(f, "CRT lifting requires at least one modular image")
            }
            // The bound applies to |a| and b separately, not to the value a/b.
            ModularError::ReconstructionFailed => write!(
                f,
                "rational reconstruction failed: no reduced a/b with |a|, b ≤ ⌊√(M/2)⌋ and a/b ≡ n (mod M)"
            ),
        }
    }
}

impl std::error::Error for ModularError {}
213
214impl AlkahestError for ModularError {
215    fn code(&self) -> &'static str {
216        match self {
217            ModularError::InvalidModulus(_) => "E-MOD-001",
218            ModularError::IncompatiblePolynomials => "E-MOD-002",
219            ModularError::EmptyImageList => "E-MOD-003",
220            ModularError::ReconstructionFailed => "E-MOD-004",
221        }
222    }
223
224    fn remediation(&self) -> Option<&'static str> {
225        match self {
226            ModularError::InvalidModulus(_) => {
227                Some("use a prime modulus p ≥ 2, e.g. 101, 1009, 32749")
228            }
229            ModularError::IncompatiblePolynomials => {
230                Some("ensure all images share the same variable ordering and modulus")
231            }
232            ModularError::EmptyImageList => Some("provide at least one (MultiPolyFp, prime) pair"),
233            ModularError::ReconstructionFailed => {
234                Some("provide more modular images so the prime product M exceeds 2 * max_coeff²")
235            }
236        }
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Public API
242// ---------------------------------------------------------------------------
243
244/// Reduce a polynomial over ℤ to a polynomial over `F_p = ℤ/pℤ`.
245///
246/// Each coefficient `c` is mapped to the representative in `[0, p)`.
247/// Terms whose reduced coefficient is zero are dropped.
248///
249/// # Errors
250///
251/// Returns [`ModularError::InvalidModulus`] if `p` is not a prime ≥ 2.
252pub fn reduce_mod(poly: &MultiPoly, p: u64) -> Result<MultiPolyFp, ModularError> {
253    if !is_prime(p) {
254        return Err(ModularError::InvalidModulus(p));
255    }
256
257    let mut terms = BTreeMap::new();
258    for (exp, coeff) in &poly.terms {
259        let c_mod = rug_mod_u64(coeff, p);
260        if c_mod != 0 {
261            terms.insert(exp.clone(), c_mod);
262        }
263    }
264
265    Ok(MultiPolyFp {
266        vars: poly.vars.clone(),
267        modulus: p,
268        terms,
269    })
270}
271
272/// Reconstruct a polynomial over ℤ from modular images via the Chinese Remainder Theorem.
273///
274/// Given images `[(f mod p₁, p₁), …, (f mod pₖ, pₖ)]` with distinct primes `pᵢ`,
275/// returns the unique polynomial `f` with coefficients centered in `(-M/2, M/2]`
276/// where `M = p₁ · … · pₖ`.
277///
278/// All images must share the same variable list.  Terms absent from an image are
279/// treated as zero.
280///
281/// # Errors
282///
283/// - [`ModularError::EmptyImageList`] — no images provided.
284/// - [`ModularError::IncompatiblePolynomials`] — images have different variable lists.
285pub fn lift_crt(images: &[(MultiPolyFp, u64)]) -> Result<MultiPoly, ModularError> {
286    if images.is_empty() {
287        return Err(ModularError::EmptyImageList);
288    }
289
290    let vars = images[0].0.vars.clone();
291    for (img, _) in images {
292        if img.vars != vars {
293            return Err(ModularError::IncompatiblePolynomials);
294        }
295    }
296
297    // Collect every exponent vector that appears in any image.
298    let mut all_exps: std::collections::BTreeSet<Vec<u32>> = std::collections::BTreeSet::new();
299    for (img, _) in images {
300        for exp in img.terms.keys() {
301            all_exps.insert(exp.clone());
302        }
303    }
304
305    let mut terms: BTreeMap<Vec<u32>, Integer> = BTreeMap::new();
306
307    for exp in &all_exps {
308        let residues: Vec<(u64, u64)> = images
309            .iter()
310            .map(|(img, p)| (img.terms.get(exp).copied().unwrap_or(0), *p))
311            .collect();
312
313        let (combined, m) = crt_combine(&residues);
314        let centered = center_mod(&combined, &m);
315
316        if centered != 0 {
317            terms.insert(exp.clone(), centered);
318        }
319    }
320
321    Ok(MultiPoly { vars, terms })
322}
323
324/// Rational number reconstruction from a modular representative.
325///
326/// Given `n ∈ [0, M)` and modulus `M > 1`, finds the unique rational `a/b`
327/// (with `b > 0`, `gcd(|a|, b) = 1`) such that:
328///
329/// - `b · n ≡ a (mod M)`
330/// - `|a| ≤ T` and `b ≤ T`, where `T = ⌊√(M/2)⌋`
331///
332/// Returns `None` if no such rational exists (the prime product `M` is too
333/// small to uniquely determine the value).
334pub fn rational_reconstruction(n: &Integer, m: &Integer) -> Option<(Integer, Integer)> {
335    if *m <= 1 {
336        return None;
337    }
338
339    // Map n to [0, m)
340    let n_mod = {
341        let r = n.clone() % m.clone();
342        if r < 0 {
343            r + m
344        } else {
345            r
346        }
347    };
348
349    if n_mod == 0 {
350        return Some((Integer::from(0), Integer::from(1)));
351    }
352
353    // T = ⌊√(M/2)⌋
354    let half_m = m.clone() >> 1u32;
355    let t = half_m.sqrt();
356
357    // Extended Euclidean: r₋₁ = m, r₀ = n; s₋₁ = 0, s₀ = 1
358    let mut r_prev = m.clone();
359    let mut r_curr = n_mod;
360    let mut s_prev = Integer::from(0);
361    let mut s_curr = Integer::from(1);
362
363    while r_curr > t {
364        if r_curr == 0 {
365            return None;
366        }
367        let q = r_prev.clone() / r_curr.clone();
368        let r_next = r_prev.clone() - q.clone() * r_curr.clone();
369        let s_next = s_prev.clone() - q * s_curr.clone();
370        r_prev = r_curr;
371        r_curr = r_next;
372        s_prev = s_curr;
373        s_curr = s_next;
374    }
375
376    if r_curr == 0 {
377        return None;
378    }
379
380    let b_abs = s_curr.clone().abs();
381    if b_abs == 0 || b_abs > t {
382        return None;
383    }
384    if r_curr.clone().abs() > t {
385        return None;
386    }
387
388    // Normalise so the denominator is positive.
389    let (a, b) = if s_curr < 0 {
390        (-r_curr, -s_curr)
391    } else {
392        (r_curr, s_curr)
393    };
394
395    Some((a, b))
396}
397
398/// Compute a Cauchy–Mignotte coefficient bound for `poly`.
399///
400/// Returns `B = ‖f‖₁ · 2^d` where `‖f‖₁ = Σ|aᵢ|` is the L¹ norm and
401/// `d = total_degree(f)`.  For CRT reconstruction to succeed, the product of
402/// primes must exceed `2B`.
403pub fn mignotte_bound(poly: &MultiPoly) -> Integer {
404    if poly.is_zero() {
405        return Integer::from(1);
406    }
407
408    let l1: Integer = poly
409        .terms
410        .values()
411        .map(|c| Integer::from(c.abs_ref()))
412        .fold(Integer::from(0), |acc, x| acc + x);
413
414    let d = poly.total_degree();
415    let scale = Integer::from(1) << d;
416    l1 * scale
417}
418
419/// Select the smallest prime not in `used` that does not divide `avoid_divisor`.
420///
421/// Pass the integer content of the polynomial as `avoid_divisor` to skip primes
422/// that would cause leading-coefficient collapse (unlucky primes).  Pass
423/// `&Integer::from(0)` to apply no divisibility constraint.
424///
425/// # Panics
426///
427/// Panics if no suitable prime can be found below 1 000 000 (should never
428/// happen in practice).
429pub fn select_lucky_prime(avoid_divisor: &Integer, used: &[u64]) -> u64 {
430    let mut candidate = 2u64;
431    loop {
432        if is_prime(candidate) && !used.contains(&candidate) {
433            let lucky = if *avoid_divisor == 0 {
434                true
435            } else {
436                let p_int = Integer::from(candidate);
437                let rem = avoid_divisor.clone() % p_int.clone();
438                let rem = if rem < 0 { rem + p_int } else { rem };
439                rem != 0
440            };
441            if lucky {
442                return candidate;
443            }
444        }
445        candidate += 1;
446        if candidate > 1_000_000 {
447            panic!("select_lucky_prime: no suitable prime found below 1_000_000");
448        }
449    }
450}
451
452// ---------------------------------------------------------------------------
453// Internal helpers
454// ---------------------------------------------------------------------------
455
456/// Iterative CRT combination.
457///
458/// Returns `(a, M)` where `a ∈ [0, M)` is the CRT representative and
459/// `M = p₁ · … · pₖ`.
460fn crt_combine(pairs: &[(u64, u64)]) -> (Integer, Integer) {
461    if pairs.is_empty() {
462        return (Integer::from(0), Integer::from(1));
463    }
464
465    let (a0, p0) = pairs[0];
466    let mut a = Integer::from(a0); // invariant: a ∈ [0, M) throughout
467    let mut m = Integer::from(p0);
468
469    for &(ai, pi) in &pairs[1..] {
470        // a_new ≡ a (mod m) and a_new ≡ ai (mod pi)
471        // a_new = a + m * t, where t ≡ (ai − a) · m⁻¹ (mod pi)
472        let a_mod_pi = rug_mod_u64(&a, pi);
473        let diff = ((ai as u128 + pi as u128 - a_mod_pi as u128) % pi as u128) as u64;
474        let m_mod_pi = rug_mod_u64(&m, pi);
475        let m_inv = mod_inverse_u64(m_mod_pi, pi);
476        let t = ((diff as u128 * m_inv as u128) % pi as u128) as u64;
477        // a_new = a + m*t; since t < pi, a_new < m*pi = new_m  ✓
478        a += m.clone() * t;
479        m *= Integer::from(pi);
480    }
481
482    (a, m)
483}
484
485/// Center `a ∈ [0, M)` in the symmetric range `(-M/2, M/2]`.
486fn center_mod(a: &Integer, m: &Integer) -> Integer {
487    let half = m.clone() >> 1u32; // ⌊M/2⌋
488    if *a > half {
489        a.clone() - m
490    } else {
491        a.clone()
492    }
493}
494
495/// Reduce a `rug::Integer` to a `u64` representative in `[0, p)`.
496fn rug_mod_u64(a: &Integer, p: u64) -> u64 {
497    let p_big = Integer::from(p);
498    let r = a.clone() % p_big.clone();
499    let r = if r < 0 { r + p_big } else { r };
500    r.to_u64().expect("modular result fits in u64")
501}
502
/// Inverse of `a` modulo `m`, computed with the extended Euclidean algorithm.
///
/// Precondition: `gcd(a, m) = 1`; otherwise the returned value is meaningless.
/// For `m == 1` the answer is `0`.  Intermediate values use `i128`, so every
/// `u64` input is handled without overflow.
fn mod_inverse_u64(a: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let modulus = i128::from(m);

    // Invariant: rem_hi ≡ coef_hi · a and rem_lo ≡ coef_lo · a (mod m).
    let (mut rem_hi, mut rem_lo) = (i128::from(a), modulus);
    let (mut coef_hi, mut coef_lo) = (1i128, 0i128);

    while rem_lo != 0 {
        let quotient = rem_hi / rem_lo;

        let next_rem = rem_hi - quotient * rem_lo;
        rem_hi = rem_lo;
        rem_lo = next_rem;

        let next_coef = coef_hi - quotient * coef_lo;
        coef_hi = coef_lo;
        coef_lo = next_coef;
    }

    coef_hi.rem_euclid(modulus) as u64
}
527
/// Deterministic Miller–Rabin primality test for any `u64`.
///
/// After trial division by 2, 3 and 5, it runs strong-probable-prime rounds:
/// - bases `{2, 3, 5, 7}` when `n < 3_215_031_751` (the smallest strong
///   pseudoprime to all four bases);
/// - bases `{2, 3, …, 37}` (the first twelve primes) otherwise, which is
///   exact for every 64-bit integer.
pub fn is_prime(n: u64) -> bool {
    if n < 2 {
        return false;
    }
    if matches!(n, 2 | 3 | 5 | 7) {
        return true;
    }
    if [2u64, 3, 5].iter().any(|&q| n % q == 0) {
        return false;
    }

    // Write n − 1 = d · 2^s with d odd.
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;

    let bases: &[u64] = if n < 3_215_031_751 {
        &[2, 3, 5, 7]
    } else {
        &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
    };

    // A base "passes" if a^d ≡ 1 or a^(d·2^j) ≡ −1 (mod n) for some 0 ≤ j < s.
    let passes = |a: u64| -> bool {
        let mut x = pow_mod(a, d, n);
        if x == 1 || x == n - 1 {
            return true;
        }
        for _ in 1..s {
            x = mul_mod(x, x, n);
            if x == n - 1 {
                return true;
            }
        }
        false
    };

    bases.iter().copied().filter(|&a| a < n).all(passes)
}

/// `base^exp mod modulus` by square-and-multiply.
fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 {
    let mut acc = 1u64;
    let mut square = base % modulus;
    let mut remaining = exp;
    while remaining != 0 {
        if remaining & 1 != 0 {
            acc = mul_mod(acc, square, modulus);
        }
        square = mul_mod(square, square, modulus);
        remaining >>= 1;
    }
    acc
}

/// `a · b mod m`, widened to `u128` so the product cannot overflow.
#[inline]
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    (u128::from(a) * u128::from(b) % u128::from(m)) as u64
}
590
// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool with two real symbols `x` and `y`.
    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        (p, x, y)
    }

    // --- is_prime ---

    #[test]
    fn prime_small() {
        for &(n, exp) in &[
            (0u64, false),
            (1, false),
            (2, true),
            (3, true),
            (4, false),
            (5, true),
            (9, false),
            (97, true),
            (100, false),
            (101, true),
        ] {
            assert_eq!(is_prime(n), exp, "is_prime({n})");
        }
    }

    #[test]
    fn prime_large() {
        assert!(is_prime(999_983));
        assert!(!is_prime(1_000_000));
        assert!(is_prime(1_000_003));
        // Large Mersenne prime M31
        assert!(is_prime(2_147_483_647));
    }

    #[test]
    fn prime_witness_boundary_and_u64_extremes() {
        // 3_215_031_751 = 151 · 751 · 28351 is a strong pseudoprime to bases
        // 2, 3, 5, 7: exactly the threshold where the larger witness set kicks in.
        assert!(!is_prime(3_215_031_751));
        // Carmichael number.
        assert!(!is_prime(561));
        // Largest prime below 2^64, and u64::MAX itself (divisible by 3).
        assert!(is_prime(18_446_744_073_709_551_557));
        assert!(!is_prime(u64::MAX));
    }

    // --- mod_inverse_u64 ---

    #[test]
    fn mod_inverse_basic() {
        assert_eq!(mod_inverse_u64(3, 7), 5); // 3·5 = 15 ≡ 1 (mod 7)
        assert_eq!(mod_inverse_u64(2, 101), 51); // 2·51 = 102 ≡ 1 (mod 101)
        assert_eq!(mod_inverse_u64(1, 7), 1);
        // Inputs larger than the modulus are reduced implicitly: 10 ≡ 3 (mod 7).
        assert_eq!(mod_inverse_u64(10, 7), 5);
    }

    // --- reduce_mod ---

    #[test]
    fn reduce_mod_basic() {
        let (pool, x, y) = pool_xy();
        // 6x + 4 → mod 5 → x + 4
        let expr = pool.add(vec![
            pool.mul(vec![pool.integer(6_i32), x]),
            pool.integer(4_i32),
        ]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        let fp = reduce_mod(&poly, 5).unwrap();
        assert_eq!(fp.modulus, 5);
        assert_eq!(*fp.terms.get(&vec![1]).unwrap(), 1u64); // 6 mod 5 = 1
        assert_eq!(*fp.terms.get(&vec![]).unwrap(), 4u64); // 4 mod 5 = 4
    }

    #[test]
    fn reduce_mod_negative_coeff() {
        let (pool, x, y) = pool_xy();
        // -3x → mod 7 → 4x
        let expr = pool.mul(vec![pool.integer(-3_i32), x]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        let fp = reduce_mod(&poly, 7).unwrap();
        assert_eq!(*fp.terms.get(&vec![1]).unwrap(), 4u64); // -3 mod 7 = 4
    }

    #[test]
    fn reduce_mod_vanishing_term() {
        let (pool, x, y) = pool_xy();
        // 5x + 7 → mod 5 → 2 (x term vanishes)
        let expr = pool.add(vec![
            pool.mul(vec![pool.integer(5_i32), x]),
            pool.integer(7_i32),
        ]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        let fp = reduce_mod(&poly, 5).unwrap();
        assert!(!fp.terms.contains_key(&vec![1]));
        assert_eq!(*fp.terms.get(&vec![]).unwrap(), 2u64);
    }

    #[test]
    fn reduce_mod_invalid() {
        let (pool, x, y) = pool_xy();
        let poly = MultiPoly::from_symbolic(x, vec![x, y], &pool).unwrap();
        for bad in [0, 1, 4, 6, 9] {
            assert!(
                matches!(reduce_mod(&poly, bad), Err(ModularError::InvalidModulus(_))),
                "expected InvalidModulus for {bad}"
            );
        }
    }

    // --- crt_combine ---

    #[test]
    fn crt_combine_single() {
        let (a, m) = crt_combine(&[(3, 5)]);
        assert_eq!(a, Integer::from(3));
        assert_eq!(m, Integer::from(5));
    }

    #[test]
    fn crt_combine_empty() {
        let (a, m) = crt_combine(&[]);
        assert_eq!(a, Integer::from(0));
        assert_eq!(m, Integer::from(1));
    }

    #[test]
    fn crt_combine_two() {
        // x ≡ 2 (mod 3), x ≡ 3 (mod 5) → x ≡ 8 (mod 15)
        let (a, m) = crt_combine(&[(2, 3), (3, 5)]);
        assert_eq!(m, Integer::from(15));
        assert_eq!(a, Integer::from(8));
        // Check the returned representative itself, not a literal.
        assert_eq!(rug_mod_u64(&a, 3), 2);
        assert_eq!(rug_mod_u64(&a, 5), 3);
    }

    #[test]
    fn crt_combine_three() {
        // x ≡ 1 (mod 2), x ≡ 2 (mod 3), x ≡ 3 (mod 5) → x ≡ 23 (mod 30)
        let (a, m) = crt_combine(&[(1, 2), (2, 3), (3, 5)]);
        assert_eq!(m, Integer::from(30));
        assert_eq!(a, Integer::from(23));
        assert_eq!(rug_mod_u64(&a, 2), 1);
        assert_eq!(rug_mod_u64(&a, 3), 2);
        assert_eq!(rug_mod_u64(&a, 5), 3);
    }

    // --- center_mod ---

    #[test]
    fn center_mod_symmetric_range() {
        // Odd modulus: range is [-(M-1)/2, (M-1)/2].
        assert_eq!(center_mod(&Integer::from(2), &Integer::from(5)), Integer::from(2));
        assert_eq!(center_mod(&Integer::from(3), &Integer::from(5)), Integer::from(-2));
        // Even modulus: M/2 itself stays positive, range is (-M/2, M/2].
        assert_eq!(center_mod(&Integer::from(3), &Integer::from(6)), Integer::from(3));
        assert_eq!(center_mod(&Integer::from(4), &Integer::from(6)), Integer::from(-2));
        assert_eq!(center_mod(&Integer::from(0), &Integer::from(7)), Integer::from(0));
    }

    // --- lift_crt ---

    #[test]
    fn lift_crt_roundtrip_positive() {
        let (pool, x, y) = pool_xy();
        // f = 3x² + 2x + 1
        let x2 = pool.pow(x, pool.integer(2_i32));
        let expr = pool.add(vec![
            pool.mul(vec![pool.integer(3_i32), x2]),
            pool.mul(vec![pool.integer(2_i32), x]),
            pool.integer(1_i32),
        ]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();

        let p1 = 101u64;
        let p2 = 103u64;
        let fp1 = reduce_mod(&poly, p1).unwrap();
        let fp2 = reduce_mod(&poly, p2).unwrap();
        let lifted = lift_crt(&[(fp1, p1), (fp2, p2)]).unwrap();
        assert_eq!(lifted, poly);
    }

    #[test]
    fn lift_crt_negative_coeff() {
        let (pool, x, y) = pool_xy();
        // f = x - 50; coefficients in (-50, 50] → need M > 100
        let expr = pool.add(vec![x, pool.integer(-50_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();

        let p1 = 101u64;
        let p2 = 103u64; // M = 101 * 103 = 10403 > 100
        let lifted = lift_crt(&[
            (reduce_mod(&poly, p1).unwrap(), p1),
            (reduce_mod(&poly, p2).unwrap(), p2),
        ])
        .unwrap();
        assert_eq!(lifted, poly);
    }

    #[test]
    fn lift_crt_bivariate() {
        let (pool, x, y) = pool_xy();
        // f = x*y + 3
        let expr = pool.add(vec![pool.mul(vec![x, y]), pool.integer(3_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();

        let p = 7u64;
        let q = 11u64;
        let lifted = lift_crt(&[
            (reduce_mod(&poly, p).unwrap(), p),
            (reduce_mod(&poly, q).unwrap(), q),
        ])
        .unwrap();
        assert_eq!(lifted, poly);
    }

    #[test]
    fn lift_crt_empty_error() {
        assert!(matches!(lift_crt(&[]), Err(ModularError::EmptyImageList)));
    }

    // --- rational_reconstruction ---

    #[test]
    fn rat_recon_one_half() {
        // 1/2 mod 101: 2⁻¹ ≡ 51 (mod 101), so n=51
        let result = rational_reconstruction(&Integer::from(51), &Integer::from(101));
        assert!(result.is_some());
        let (a, b) = result.unwrap();
        assert_eq!(a, Integer::from(1));
        assert_eq!(b, Integer::from(2));
    }

    #[test]
    fn rat_recon_negative_numerator() {
        // -1/2 mod 101: -1 * 51 = -51 ≡ 50 (mod 101)
        let result = rational_reconstruction(&Integer::from(50), &Integer::from(101));
        assert!(result.is_some());
        let (a, b) = result.unwrap();
        assert_eq!(a, Integer::from(-1));
        assert_eq!(b, Integer::from(2));
    }

    #[test]
    fn rat_recon_zero() {
        let result = rational_reconstruction(&Integer::from(0), &Integer::from(101));
        assert!(result.is_some());
        let (a, b) = result.unwrap();
        assert_eq!(a, Integer::from(0));
        assert_eq!(b, Integer::from(1));
    }

    #[test]
    fn rat_recon_integer() {
        // n = 5, m = 101: T = 7, 5 ≤ 7, so this is just the integer 5
        let result = rational_reconstruction(&Integer::from(5), &Integer::from(101));
        assert!(result.is_some());
        let (a, b) = result.unwrap();
        assert_eq!(b, Integer::from(1));
        assert_eq!(a, Integer::from(5));
    }

    #[test]
    fn rat_recon_m_too_small() {
        // n=2, M=7: T=⌊√3⌋=1; integer 2 can't be reconstructed since |2| > T=1
        // and no other a/b with |a|≤1 and b≤1 satisfies a/b ≡ 2 (mod 7).
        let result = rational_reconstruction(&Integer::from(2), &Integer::from(7));
        assert!(result.is_none());
    }

    #[test]
    fn rat_recon_modulus_too_small_is_none() {
        assert!(rational_reconstruction(&Integer::from(0), &Integer::from(1)).is_none());
        assert!(rational_reconstruction(&Integer::from(3), &Integer::from(0)).is_none());
    }

    // --- mignotte_bound ---

    #[test]
    fn mignotte_constant() {
        let (pool, x, y) = pool_xy();
        let poly = MultiPoly::from_symbolic(pool.integer(5_i32), vec![x, y], &pool).unwrap();
        // L1=5, d=0 → B=5
        assert_eq!(mignotte_bound(&poly), Integer::from(5));
    }

    #[test]
    fn mignotte_linear() {
        let (pool, x, y) = pool_xy();
        // 3x + 2: L1=5, d=1 → B=10
        let expr = pool.add(vec![
            pool.mul(vec![pool.integer(3_i32), x]),
            pool.integer(2_i32),
        ]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        assert_eq!(mignotte_bound(&poly), Integer::from(10));
    }

    #[test]
    fn mignotte_zero_poly() {
        let (_, x, y) = pool_xy();
        let z = MultiPoly::zero(vec![x, y]);
        assert_eq!(mignotte_bound(&z), Integer::from(1));
    }

    // --- select_lucky_prime ---

    #[test]
    fn lucky_prime_no_constraint() {
        let p = select_lucky_prime(&Integer::from(0), &[]);
        assert!(is_prime(p));
        assert_eq!(p, 2);
    }

    #[test]
    fn lucky_prime_avoids_divisors() {
        // avoid_divisor=6=2×3; lucky prime must not divide 6
        let p = select_lucky_prime(&Integer::from(6), &[]);
        assert!(is_prime(p));
        assert_ne!(6 % p, 0);
        assert_eq!(p, 5); // first prime not dividing 6
    }

    #[test]
    fn lucky_prime_skips_used() {
        let p = select_lucky_prime(&Integer::from(0), &[2, 3, 5]);
        assert_eq!(p, 7);
    }

    #[test]
    fn lucky_prime_combined() {
        // avoid_divisor = 30 = 2·3·5 rules out 2, 3 and 5; 7 is marked as used,
        // so the first admissible prime is 11.
        let p = select_lucky_prime(&Integer::from(30), &[7]);
        assert!(is_prime(p));
        assert_ne!(30 % p, 0);
        assert_ne!(p, 7);
        assert_eq!(p, 11);
    }

    #[test]
    fn lucky_prime_negative_divisor() {
        // Sign of avoid_divisor is irrelevant: -6 is divisible by 2 and 3.
        assert_eq!(select_lucky_prime(&Integer::from(-6), &[]), 5);
    }

    // --- ModularValue ---

    #[test]
    fn modular_value_add() {
        let a = ModularValue::new(3, 7);
        let b = ModularValue::new(5, 7);
        assert_eq!(a.add(&b), ModularValue::new(1, 7)); // (3+5) mod 7 = 1
    }

    #[test]
    fn modular_value_sub() {
        let a = ModularValue::new(3, 7);
        let b = ModularValue::new(5, 7);
        assert_eq!(a.sub(&b), ModularValue::new(5, 7)); // (3-5) mod 7 = -2 ≡ 5
    }

    #[test]
    fn modular_value_mul() {
        let a = ModularValue::new(3, 7);
        let b = ModularValue::new(5, 7);
        assert_eq!(a.mul(&b), ModularValue::new(1, 7)); // 15 mod 7 = 1
    }

    #[test]
    fn modular_value_neg() {
        assert_eq!(ModularValue::new(3, 7).neg(), ModularValue::new(4, 7));
        assert_eq!(ModularValue::new(0, 7).neg(), ModularValue::new(0, 7));
    }

    #[test]
    fn modular_value_inverse() {
        // 3⁻¹ ≡ 5 (mod 7): 3·5 = 15 ≡ 1 (mod 7)
        assert_eq!(
            ModularValue::new(3, 7).inverse().unwrap(),
            ModularValue::new(5, 7)
        );
        assert!(ModularValue::new(0, 7).inverse().is_none());
    }

    // --- error codes ---

    #[test]
    fn error_codes() {
        assert_eq!(ModularError::InvalidModulus(4).code(), "E-MOD-001");
        assert_eq!(ModularError::IncompatiblePolynomials.code(), "E-MOD-002");
        assert_eq!(ModularError::EmptyImageList.code(), "E-MOD-003");
        assert_eq!(ModularError::ReconstructionFailed.code(), "E-MOD-004");
    }
}
944        assert_eq!(ModularError::IncompatiblePolynomials.code(), "E-MOD-002");
945        assert_eq!(ModularError::EmptyImageList.code(), "E-MOD-003");
946        assert_eq!(ModularError::ReconstructionFailed.code(), "E-MOD-004");
947    }
948}