Skip to main content

alkahest_cas/poly/
interp.rs

1//! V2-3 — Sparse polynomial interpolation (Ben-Or/Tiwari, Zippel).
2//!
3//! Recovers a sparse multivariate polynomial over `F_p = ℤ/pℤ` from
4//! black-box evaluations using far fewer queries than dense interpolation.
5//!
6//! # Algorithms
7//!
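//! - **Univariate Ben-Or/Tiwari (Prony-style)** — [`sparse_interpolate_univariate`]:
//!   given that `f ∈ F_p[x]` has at most `T` nonzero terms, recovers `f` from
//!   exactly `2T` evaluations at consecutive powers of a primitive root `g`.
//!   Berlekamp–Massey yields the connection polynomial `Λ`; its roots are found
//!   via `gcd(Λ, X^p − X)` plus Cantor–Zassenhaus-style splitting over `F_p`
//!   (a brute-force scan is only a small-case fallback); discrete logarithms turn
//!   the roots into exponents, and a Vandermonde solve recovers the coefficients.
//!   Cost: `2T` oracle calls.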
14//!
15//! - **Multivariate Zippel** — [`sparse_interpolate`]: variable-by-variable
16//!   reduction.  At each variable level:
17//!     1. Evaluate `f(x₁, a₂, …, aₙ)` at random `aᵢ` and run Ben-Or/Tiwari
18//!        to find the `x₁`-exponent skeleton.
19//!     2. Lift all sibling coefficients simultaneously with one Vandermonde solve
20//!        per oracle call (`zippel_helper_multi`) when the stacked vector stays
21//!        small (`O(term_bound²)` budget); otherwise recurse per skeleton term like
22//!        classic Zippel.
23//!
//! - **Dense fallback** — applied when `degree_bound ≤ term_bound` (dense
//!   interpolation then needs no more evaluations than the sparse method).
//!   Solves the dense Vandermonde system at the consecutive points
//!   `1, 2, …, D+1` by Gaussian elimination.
27//!
28//! # Public API
29//!
30//! ```text
31//! sparse_interpolate_univariate(eval, term_bound, prime) → Vec<(coeff, exp)>
32//! sparse_interpolate(eval, vars, term_bound, degree_bound, prime, seed)
33//!     → MultiPolyFp
34//! ```
35
36use crate::errors::AlkahestError;
37use crate::kernel::ExprId;
38use crate::modular::{is_prime, MultiPolyFp};
39use std::collections::BTreeMap;
40
41// ---------------------------------------------------------------------------
42// Error type
43// ---------------------------------------------------------------------------
44
/// Error returned by sparse interpolation functions.
///
/// All payloads are plain integers, so the type derives `Eq` in addition to
/// `PartialEq`: errors can be compared exactly (e.g. in `assert_eq!`) and used
/// wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SparseInterpError {
    /// The prime is ≤ 2 or composite.
    InvalidPrime(u64),
    /// The prime must be `> 2 * term_bound` for Ben-Or/Tiwari to work.
    PrimeTooSmall {
        /// The modulus that was rejected.
        prime: u64,
        /// The caller's bound `T` on the number of nonzero terms.
        term_bound: usize,
    },
    /// Root-finding or the subsequent discrete-log step failed: too few roots of
    /// the connection polynomial in `F_p`, a zero root, or a root without a
    /// discrete logarithm.  Should not happen for correct evaluations; indicates
    /// either colliding exponents or an inconsistent evaluation oracle.
    RootFindingFailed,
    /// The Vandermonde / linear system is singular.  Retry with a different
    /// seed or a larger prime.
    SingularSystem,
}
60
61impl std::fmt::Display for SparseInterpError {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            SparseInterpError::InvalidPrime(p) => {
65                write!(f, "invalid prime {p}: must be a prime ≥ 3")
66            }
67            SparseInterpError::PrimeTooSmall { prime, term_bound } => write!(
68                f,
69                "prime {prime} is too small for term_bound {term_bound}: need prime > 2·T = {}",
70                2 * term_bound
71            ),
72            SparseInterpError::RootFindingFailed => write!(
73                f,
74                "could not find the expected number of roots in F_p; \
75                 the prime may be too small or the oracle is inconsistent"
76            ),
77            SparseInterpError::SingularSystem => write!(
78                f,
79                "Vandermonde system is singular; try a different seed or a larger prime"
80            ),
81        }
82    }
83}
84
85impl std::error::Error for SparseInterpError {}
86
87impl AlkahestError for SparseInterpError {
88    fn code(&self) -> &'static str {
89        match self {
90            SparseInterpError::InvalidPrime(_) => "E-INTERP-001",
91            SparseInterpError::PrimeTooSmall { .. } => "E-INTERP-002",
92            SparseInterpError::RootFindingFailed => "E-INTERP-003",
93            SparseInterpError::SingularSystem => "E-INTERP-004",
94        }
95    }
96
97    fn remediation(&self) -> Option<&'static str> {
98        match self {
99            SparseInterpError::InvalidPrime(_) => {
100                Some("choose a prime p ≥ 3, e.g. 1009, 32749, 1000003")
101            }
102            SparseInterpError::PrimeTooSmall { .. } => {
103                Some("increase the prime so that p > 2 * term_bound")
104            }
105            SparseInterpError::RootFindingFailed => {
106                Some("choose a prime larger than the maximum degree in the polynomial")
107            }
108            SparseInterpError::SingularSystem => {
109                Some("retry with a different seed or use a larger prime")
110            }
111        }
112    }
113}
114
115// ---------------------------------------------------------------------------
116// Minimal PRNG (xorshift64) — no external crate needed
117// ---------------------------------------------------------------------------
118
/// Simple xorshift64 PRNG for reproducible random evaluation points.
///
/// This is the 64-bit xorshift generator with shift triple `(13, 7, 17)`.  Its
/// output stream is a pure function of the seed, so identically seeded runs are
/// reproducible.  It is **not** suitable for cryptographic use.
#[derive(Debug, Clone)]
pub struct Xorshift64 {
    /// Current generator state; never zero (see [`Xorshift64::new`]).
    state: u64,
}

impl Xorshift64 {
    /// Create a generator seeded with `seed`.
    ///
    /// Zero is a fixed point of the xorshift map, so a zero seed is replaced by
    /// the constant `0xdeadbeef_cafebabe`.
    pub fn new(seed: u64) -> Self {
        let s = if seed == 0 { 0xdeadbeef_cafebabe } else { seed };
        Xorshift64 { state: s }
    }

    /// Advance the state once and return the new 64-bit output.
    pub fn step(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Return a value in `[lo, hi)`.
    ///
    /// Uses plain modulo reduction, so there is a slight bias whenever
    /// `hi - lo` does not divide `2^64`.
    ///
    /// # Panics
    ///
    /// Panics if `hi <= lo`.  In release builds an empty range would otherwise
    /// divide by zero or wrap `hi - lo` and return out-of-range values.
    pub fn next_range(&mut self, lo: u64, hi: u64) -> u64 {
        assert!(hi > lo, "Xorshift64::next_range: empty range [{lo}, {hi})");
        lo + self.step() % (hi - lo)
    }

    /// Return a non-zero value in `[1, p)` by rejection sampling.
    ///
    /// # Panics
    ///
    /// Panics if `p < 2`.  `[1, p)` is then empty: the rejection loop would never
    /// terminate for `p == 1` and would divide by zero for `p == 0`.
    pub fn nonzero(&mut self, p: u64) -> u64 {
        assert!(p >= 2, "Xorshift64::nonzero: modulus must be at least 2, got {p}");
        loop {
            let v = self.step() % p;
            if v != 0 {
                return v;
            }
        }
    }
}
156
157// ---------------------------------------------------------------------------
158// Modular arithmetic helpers
159// ---------------------------------------------------------------------------
160
/// `a · b (mod p)`, computed in 128-bit arithmetic so the product cannot overflow.
#[inline]
fn mul_mod(a: u64, b: u64, p: u64) -> u64 {
    let wide = u128::from(a) * u128::from(b);
    (wide % u128::from(p)) as u64
}
165
/// `a + b (mod p)` for reduced operands `a, b < p`.
///
/// Valid for every modulus up to `u64::MAX`.  When `p > 2^63` the raw sum can
/// exceed `u64::MAX`, so the carry is tracked explicitly.  A plain `a + b` would
/// overflow there: a panic in debug builds, a wrong result in release.
#[inline]
fn add_mod(a: u64, b: u64, p: u64) -> u64 {
    let (s, carried) = a.overflowing_add(b);
    // With a carry the true sum is `s + 2^64`, which lies in `[p, 2p)`, so a single
    // subtraction of `p` (mod 2^64) yields the reduced value.
    if carried || s >= p {
        s.wrapping_sub(p)
    } else {
        s
    }
}
175
/// `a − b (mod p)` for reduced operands `a, b < p`.
///
/// In the borrow case it computes `a + (p − b)`.  Here `p − b` lies in `(0, p]`
/// and the sum stays below `p` (because `a < b`), so nothing overflows, even for
/// moduli above `2^63`.
#[inline]
fn sub_mod(a: u64, b: u64, p: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        a + (p - b)
    }
}
184
185fn pow_mod(mut base: u64, mut exp: u64, p: u64) -> u64 {
186    let mut result = 1u64;
187    base %= p;
188    while exp > 0 {
189        if exp & 1 == 1 {
190            result = mul_mod(result, base, p);
191        }
192        base = mul_mod(base, base, p);
193        exp >>= 1;
194    }
195    result
196}
197
/// Modular inverse of `a` modulo `p` via the extended Euclidean algorithm.
///
/// Returns `x ∈ [0, p)` with `a·x ≡ 1 (mod p)`.  `a` does not need to be reduced
/// modulo `p` first.
///
/// # Panics
///
/// Panics if `gcd(a, p) ≠ 1` (in particular when `a ≡ 0 (mod p)`), because no
/// inverse exists.  Without this check the routine would silently return a
/// meaningless value.
fn mod_inv(a: u64, p: u64) -> u64 {
    debug_assert!(a != 0, "mod_inv: a must be non-zero");
    // Invariant: old_r ≡ old_s·a and r ≡ s·a (mod p).
    let mut old_r = i128::from(a);
    let mut r = i128::from(p);
    let mut old_s: i128 = 1;
    let mut s: i128 = 0;
    while r != 0 {
        let q = old_r / r;
        let tmp = r;
        r = old_r - q * r;
        old_r = tmp;
        let tmp = s;
        s = old_s - q * s;
        old_s = tmp;
    }
    // `old_r` is now gcd(a, p); an inverse exists only when it is 1.
    assert_eq!(old_r, 1, "mod_inv: {a} has no inverse modulo {p}");
    old_s.rem_euclid(i128::from(p)) as u64
}
216
217// ---------------------------------------------------------------------------
218// Polynomial evaluation over F_p
219// ---------------------------------------------------------------------------
220
221/// Evaluate `poly[0] + poly[1]*x + ... + poly[d]*x^d` at `x` modulo `p`.
222fn poly_eval(poly: &[u64], x: u64, p: u64) -> u64 {
223    let mut acc = 0u64;
224    let mut pw = 1u64;
225    for &c in poly {
226        acc = add_mod(acc, mul_mod(c, pw, p), p);
227        pw = mul_mod(pw, x, p);
228    }
229    acc
230}
231
232// ---------------------------------------------------------------------------
233// Primitive root of F_p
234// ---------------------------------------------------------------------------
235
236/// Find the smallest primitive root (generator) of `F_p*`.
237///
238/// A primitive root `g` satisfies `g^{(p-1)/q} ≢ 1 (mod p)` for every
239/// prime factor `q` of `p-1`.
240pub fn primitive_root(p: u64) -> u64 {
241    debug_assert!(is_prime(p), "primitive_root: p must be prime");
242    if p == 2 {
243        return 1;
244    }
245    if p == 3 {
246        return 2;
247    }
248    let factors = prime_factors(p - 1);
249    'outer: for g in 2..p {
250        for &q in &factors {
251            if pow_mod(g, (p - 1) / q, p) == 1 {
252                continue 'outer;
253            }
254        }
255        return g;
256    }
257    panic!("primitive_root: no root found for prime {p}");
258}
259
/// Distinct prime factors of `n` (trial division), in ascending order.
///
/// `0` and `1` have no prime factors and yield an empty vector.  After the factor
/// 2 is removed, only odd trial divisors are tested.  The loop bound is written
/// as `d <= n / d`, which is equivalent to `d² ≤ n` but cannot overflow.  The
/// naive `d * d` overflows once `d` reaches `2^32`, which happens when `n` is
/// close to `u64::MAX` and has a large prime factor.
fn prime_factors(mut n: u64) -> Vec<u64> {
    let mut factors = Vec::new();
    if n == 0 {
        return factors;
    }
    if n % 2 == 0 {
        factors.push(2);
        while n % 2 == 0 {
            n /= 2;
        }
    }
    let mut d = 3u64;
    while d <= n / d {
        if n % d == 0 {
            factors.push(d);
            while n % d == 0 {
                n /= d;
            }
        }
        d += 2;
    }
    // Whatever remains above 1 has no divisor ≤ its square root, so it is prime.
    if n > 1 {
        factors.push(n);
    }
    factors
}
278
279// ---------------------------------------------------------------------------
280// Berlekamp–Massey over F_p
281// ---------------------------------------------------------------------------
282
/// Berlekamp–Massey algorithm over `F_p`.
///
/// Given a sequence `s[0], …, s[N-1]`, returns the minimal LFSR connection
/// polynomial `Λ = [1, λ₁, …, λ_L]` (index = degree) such that
///
/// ```text
/// s[n] + λ₁·s[n-1] + … + λ_L·s[n-L] = 0   for all n ≥ L.
/// ```
///
/// The caller must supply `N ≥ 2L` for the result to be unique.
///
/// NOTE(review): callers such as `bt_univariate` read `len() - 1` as the LFSR
/// length `L`.  The resize below appears to keep `c.len() == L + 1`, but the top
/// coefficient can still be zero for degenerate sequences — confirm before
/// relying on it.
fn berlekamp_massey(seq: &[u64], p: u64) -> Vec<u64> {
    let n = seq.len();
    // `l`: current LFSR length L.
    let mut l = 0usize;
    // `c`: current connection polynomial C (ascending degree), starts as 1.
    let mut c: Vec<u64> = vec![1];
    // `b`: copy of C saved at the last length change; `b_disc` is the
    // discrepancy observed at that step (never 0: initialised to 1 and only
    // overwritten by a non-zero `d`).
    let mut b: Vec<u64> = vec![1];
    let mut b_disc: u64 = 1;
    // `x`: number of steps since the last length change, i.e. the power of z
    // by which B is shifted in the update.
    let mut x: usize = 1;

    for n_idx in 0..n {
        // Discrepancy d = s[n] + Σ_{i=1}^{L} c[i]·s[n-i]
        let mut d = seq[n_idx];
        let bound = l.min(c.len().saturating_sub(1));
        for i in 1..=bound {
            d = add_mod(d, mul_mod(c[i], seq[n_idx - i], p), p);
        }

        // C already predicts s[n]; only the shift for B grows.
        if d == 0 {
            x += 1;
            continue;
        }

        // Keep C before the update: it becomes the new B on a length change.
        let t = c.clone();
        let factor = mul_mod(d, mod_inv(b_disc, p), p);

        // C ← C − factor·z^x·B
        let needed = x + b.len();
        if c.len() < needed {
            c.resize(needed, 0);
        }
        for j in 0..b.len() {
            let sub = mul_mod(factor, b[j], p);
            c[x + j] = sub_mod(c[x + j], sub, p);
        }

        // Length change rule: when 2L ≤ n, the new length is n + 1 − L.
        if 2 * l <= n_idx {
            l = n_idx + 1 - l;
            b = t;
            b_disc = d;
            x = 1;
        } else {
            x += 1;
        }
    }

    c
}
339
340// ---------------------------------------------------------------------------
341// Dense polynomials mod p (Cantor–Zassenhaus / probabilistic splitting)
342// ---------------------------------------------------------------------------
343
/// Drop trailing (highest-degree) zero coefficients.
///
/// At least one entry is kept, so the zero polynomial is represented as `[0]`.
/// An empty input stays empty.
fn poly_trim(mut a: Vec<u64>) -> Vec<u64> {
    let keep = match a.iter().rposition(|&c| c != 0) {
        Some(top) => top + 1,
        None => a.len().min(1),
    };
    a.truncate(keep);
    a
}
350
/// Degree of `poly`, i.e. the index of its highest nonzero coefficient.
///
/// Returns `-1` for the zero polynomial (an all-zero or empty slice).  The scan
/// starts from the top and allocates nothing.  That matters because this function
/// is called on every iteration of the division and gcd loops.
#[inline]
fn poly_deg(poly: &[u64]) -> i32 {
    poly.iter()
        .rposition(|&c| c != 0)
        .map_or(-1, |top| top as i32)
}
359
360/// `a + b` in ascending order (may over-allocate briefly).
361fn poly_add(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
362    let n = a.len().max(b.len());
363    let mut out = vec![0u64; n];
364    for i in 0..n {
365        let x = if i < a.len() { a[i] } else { 0 };
366        let y = if i < b.len() { b[i] } else { 0 };
367        out[i] = add_mod(x, y, p);
368    }
369    poly_trim(out)
370}
371
372fn poly_sub_(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
373    let n = a.len().max(b.len());
374    let mut out = vec![0u64; n];
375    for i in 0..n {
376        let x = if i < a.len() { a[i] } else { 0 };
377        let y = if i < b.len() { b[i] } else { 0 };
378        out[i] = sub_mod(x, y, p);
379    }
380    poly_trim(out)
381}
382
383fn poly_mul(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
384    if a.is_empty() || b.is_empty() || (a.len() == 1 && a[0] == 0) || (b.len() == 1 && b[0] == 0) {
385        return vec![0];
386    }
387    let da = poly_deg(a);
388    let db = poly_deg(b);
389    if da < 0 || db < 0 {
390        return vec![0];
391    }
392    let mut out = vec![0u64; (da + db + 1) as usize];
393    for i in 0..=da as usize {
394        for j in 0..=db as usize {
395            out[i + j] = add_mod(out[i + j], mul_mod(a[i], b[j], p), p);
396        }
397    }
398    poly_trim(out)
399}
400
401/// Euclidean division over `F_p`; returns `(q, r)` with `a = q·b + r`, `deg r < deg b`.
402fn poly_divmod(dividend: &[u64], divisor: &[u64], p: u64) -> Option<(Vec<u64>, Vec<u64>)> {
403    let mut a = poly_trim(dividend.to_vec());
404    let b = poly_trim(divisor.to_vec());
405    if poly_deg(&b) < 0 {
406        return None;
407    }
408    let db = b.len() - 1;
409    let lb = *b.last().unwrap();
410    let inv_lb = mod_inv(lb, p);
411
412    let deg_a = poly_deg(&a);
413    if deg_a < db as i32 {
414        return Some((vec![0], a));
415    }
416
417    let q_len = (deg_a - db as i32 + 1) as usize;
418    let mut quot = vec![0u64; q_len];
419
420    while poly_deg(&a) >= db as i32 {
421        let da = poly_deg(&a) as usize;
422        let shift = da - db;
423        let scale = mul_mod(*a.last().unwrap(), inv_lb, p);
424        quot[shift] = add_mod(quot[shift], scale, p);
425        for j in 0..b.len() {
426            a[j + shift] = sub_mod(a[j + shift], mul_mod(scale, b[j], p), p);
427        }
428        a = poly_trim(a);
429    }
430
431    Some((poly_trim(quot), a))
432}
433
434fn polygcd(a_: &[u64], b_: &[u64], p: u64) -> Vec<u64> {
435    let mut a = poly_trim(a_.to_vec());
436    let mut b = poly_trim(b_.to_vec());
437    while poly_deg(&b) >= 0 {
438        let (_, r) = match poly_divmod(&a, &b, p) {
439            Some(x) => x,
440            None => break,
441        };
442        a = b;
443        b = r;
444    }
445    if poly_deg(&a) < 0 {
446        return vec![0];
447    }
448    poly_make_monic(&a, p)
449}
450
451fn poly_derivative(f: &[u64], p: u64) -> Vec<u64> {
452    let f = poly_trim(f.to_vec());
453    if f.len() <= 1 {
454        return vec![0];
455    }
456    let mut out = Vec::with_capacity(f.len() - 1);
457    for (k, &coeff) in f.iter().enumerate().skip(1) {
458        let d = mul_mod(coeff, k as u64, p);
459        out.push(d);
460    }
461    poly_trim(out)
462}
463
464fn poly_make_monic(f: &[u64], p: u64) -> Vec<u64> {
465    let f = poly_trim(f.to_vec());
466    if f.is_empty() {
467        return f;
468    }
469    let lc = *f.last().unwrap();
470    if lc == 0 {
471        return f;
472    }
473    let inv = mod_inv(lc, p);
474    f.iter().map(|&c| mul_mod(c, inv, p)).collect()
475}
476
477/// Remove repeated roots until `gcd(f, f′) = 1`.
478fn poly_squarefree(mut f: Vec<u64>, p: u64) -> Vec<u64> {
479    f = poly_make_monic(&f, p);
480    loop {
481        let dp = poly_derivative(&f, p);
482        let g = polygcd(&f, &dp, p);
483        let dg = poly_deg(&g);
484        if dg <= 0 {
485            break;
486        }
487        let (_, r) = poly_divmod(&f, &g, p).unwrap();
488        f = poly_make_monic(&r, p);
489    }
490    f
491}
492
493fn poly_mul_mod(a: &[u64], b: &[u64], modulo: &[u64], p: u64) -> Vec<u64> {
494    let prod = poly_mul(a, b, p);
495    poly_divmod(&prod, modulo, p)
496        .map(|(_, r)| r)
497        .unwrap_or(vec![0])
498}
499
500/// `base^exp (mod m)` in `F_p[X]`.
501fn poly_pow_mod(base: &[u64], mut exp: u64, m: &[u64], p: u64) -> Vec<u64> {
502    let m = poly_trim(m.to_vec());
503    if poly_deg(&m) < 0 {
504        return vec![0];
505    }
506    let mut acc = vec![1u64];
507    let mut b = poly_divmod(&poly_trim(base.to_vec()), &m, p)
508        .map(|(_, r)| r)
509        .unwrap_or(vec![0]);
510    while exp > 0 {
511        if exp & 1 != 0 {
512            acc = poly_mul_mod(&acc, &b, &m, p);
513        }
514        b = poly_mul_mod(&b, &b, &m, p);
515        exp >>= 1;
516    }
517    acc
518}
519
520/// Random dense polynomial of degree `< deg(f)` (for Cantor–Zassenhaus splitting).
521fn poly_random_below(max_deg: usize, p: u64, rng: &mut Xorshift64) -> Vec<u64> {
522    if max_deg == 0 {
523        return vec![0];
524    }
525    let mut c: Vec<u64> = (0..max_deg).map(|_| rng.next_range(0, p)).collect();
526    if c.iter().all(|&x| x == 0) {
527        c[rng.next_range(0, max_deg as u64) as usize] = rng.nonzero(p);
528    }
529    poly_trim(c)
530}
531
532/// Find all roots of `poly` in `F_p` using `gcd(f, X^p−X)` + probabilistic split.
533/// Assumes `p` is an odd prime and `deg(f) < p` (always true for Ben-Or/Tiwari Λ).
534fn find_roots(poly: &[u64], p: u64, rng: &mut Xorshift64) -> Result<Vec<u64>, SparseInterpError> {
535    let mut f = poly_trim(poly.to_vec());
536    if poly_deg(&f) < 0 {
537        return Ok(vec![]);
538    }
539    if p == 2 {
540        let mut r = Vec::new();
541        for v in 0..p {
542            if poly_eval(&f, v, p) == 0 {
543                r.push(v);
544            }
545        }
546        return Ok(r);
547    }
548    f = poly_squarefree(f, p);
549    if poly_deg(&f) < 0 {
550        return Ok(vec![]);
551    }
552    if poly_deg(&f) == 0 {
553        return Ok(vec![]);
554    }
555
556    // Split off the `F_p`-rational part: gcd(f, X^p − X).
557    let xp = poly_pow_mod(&[0, 1], p, &f, p);
558    let diff = poly_sub_(&xp, &[0, 1], p);
559    let mut h = polygcd(&f, &diff, p);
560    if poly_deg(&h) < 0 {
561        h = f;
562    }
563
564    let mut roots = Vec::new();
565    split_find_roots(&h, p, rng, &mut roots)?;
566    roots.sort_unstable();
567    roots.dedup();
568    Ok(roots)
569}
570
/// Append every root of `f` in `F_p` to `roots`, by recursive random splitting.
///
/// Intended input: a polynomial over an odd-prime field that splits into
/// distinct linear factors over `F_p`.  `find_roots` passes `gcd(f, X^p − X)` of a
/// squarefree `f` and handles `p == 2` itself.
///
/// `f` is made monic first.  Degree ≤ 0 contributes nothing; degree 1
/// contributes `−f[0]`.  Larger degrees are split with a random `u`: a proper
/// factor is taken from `gcd(f, u^{(p−1)/2} ∓ 1)`, and the function recurses on
/// that factor and its quotient cofactor.  Roots may be appended in any order.
///
/// # Errors
///
/// Returns [`SparseInterpError::RootFindingFailed`] when all `MAX_TRIES` random
/// splits fail and `deg(f)·p > 2_500_000`, which is too large for the brute-force
/// scan.
///
/// NOTE(review): when the brute-force scan runs it may append fewer than
/// `deg f` roots without error (e.g. if `f` has irreducible nonlinear factors);
/// callers are expected to check the count.
fn split_find_roots(
    f: &[u64],
    p: u64,
    rng: &mut Xorshift64,
    roots: &mut Vec<u64>,
) -> Result<(), SparseInterpError> {
    let f = poly_make_monic(f, p);
    let d = poly_deg(&f);
    if d < 0 {
        return Ok(());
    }
    if d == 0 {
        return Ok(());
    }
    if d == 1 {
        // Monic linear factor X + f[0] has the single root −f[0].
        let a0 = sub_mod(0, f[0], p);
        roots.push(a0);
        return Ok(());
    }

    // Probabilistic split (Cantor–Zassenhaus / Rabin): for odd `p`, each nontrivial
    // gcd( U^{(p−1)/2} ± 1 , f ) succeeds with probability ~1/2 per try.
    const MAX_TRIES: usize = 256;
    for _ in 0..MAX_TRIES {
        let u = poly_random_below(d as usize, p, rng);
        let exp = (p - 1) / 2;
        let up = poly_pow_mod(&u, exp, &f, p);
        for g in [poly_sub_(&up, &[1], p), poly_add(&up, &[1], p)] {
            let d1 = polygcd(&f, &g, p);
            let d1deg = poly_deg(&d1);
            // Only a proper factor (0 < deg < d) makes progress.
            if d1deg > 0 && d1deg < d {
                let (cofactor, rem) = poly_divmod(&f, &d1, p).unwrap();
                // `d1` must be a genuine divisor; use the **quotient** cofactor.
                if poly_deg(&rem) >= 0 {
                    continue;
                }
                split_find_roots(&d1, p, rng, roots)?;
                split_find_roots(&poly_make_monic(&cofactor, p), p, rng, roots)?;
                return Ok(());
            }
        }
    }
    // Λ rarely has degree > ~20; brute force is negligible vs failing outright.
    if (d as u128) * (p as u128) <= 2_500_000 {
        for v in 0..p {
            if poly_eval(&f, v, p) == 0 {
                roots.push(v);
            }
        }
        return Ok(());
    }
    Err(SparseInterpError::RootFindingFailed)
}
624
625// ---------------------------------------------------------------------------
626// Baby-step giant-step discrete logarithm
627// ---------------------------------------------------------------------------
628
629/// Compute `e` such that `g^e ≡ target (mod p)`, or `None` if no such `e`
630/// exists in `{0, …, p-2}`.
631///
632/// Uses the Baby-step / Giant-step algorithm in `O(√p)` time and space.
633pub fn bsgs_dlog(g: u64, target: u64, p: u64) -> Option<u64> {
634    if target == 0 {
635        return None; // g is never 0 in F_p*
636    }
637    let order = p - 1; // order of F_p* (g is a generator)
638    let m = (order as f64).sqrt().ceil() as u64 + 1;
639
640    // Baby steps: table[g^j] = j  for j = 0 … m-1
641    let mut table = std::collections::HashMap::with_capacity(m as usize);
642    let mut gj = 1u64;
643    for j in 0..m {
644        table.insert(gj, j);
645        gj = mul_mod(gj, g, p);
646    }
647
648    // Giant steps: find i such that target · (g^{-m})^i is in table
649    let gm = pow_mod(g, m, p);
650    let gm_inv = mod_inv(gm, p);
651    let mut y = target;
652    for i in 0..m {
653        if let Some(&j) = table.get(&y) {
654            let e = i * m + j;
655            let e_mod = e % order;
656            // Verify
657            if pow_mod(g, e_mod, p) == target {
658                return Some(e_mod);
659            }
660        }
661        y = mul_mod(y, gm_inv, p);
662    }
663    None
664}
665
666// ---------------------------------------------------------------------------
667// Vandermonde solve (generalised)
668// ---------------------------------------------------------------------------
669
670/// Solve the generalised Vandermonde system:
671///
672/// ```text
673/// Σ_j  c[j] · pts[i]^{exps[j]}  =  vals[i]   for i = 0, …, t-1
674/// ```
675///
676/// Returns `Some(c)` if the system is non-singular, or `None` otherwise.
677fn vandermonde_solve(pts: &[u64], exps: &[u32], vals: &[u64], p: u64) -> Option<Vec<u64>> {
678    let t = pts.len();
679    debug_assert_eq!(exps.len(), t);
680    debug_assert_eq!(vals.len(), t);
681
682    // Build the t×t matrix A where A[i][j] = pts[i]^exps[j]
683    let mut mat: Vec<Vec<u64>> = (0..t)
684        .map(|i| (0..t).map(|j| pow_mod(pts[i], exps[j] as u64, p)).collect())
685        .collect();
686    let mut rhs: Vec<u64> = vals.to_vec();
687
688    gaussian_elim(&mut mat, &mut rhs, p)
689}
690
691/// Gaussian elimination with partial pivoting over `F_p`.
692/// Modifies `mat` and `rhs` in place; returns the solution or `None` if
693/// the system is singular.
694fn gaussian_elim(mat: &mut [Vec<u64>], rhs: &mut [u64], p: u64) -> Option<Vec<u64>> {
695    let n = mat.len();
696    for col in 0..n {
697        // Find pivot (first non-zero entry in column col, at or below row col)
698        let pivot_row = (col..n).find(|&r| mat[r][col] != 0)?;
699        mat.swap(col, pivot_row);
700        rhs.swap(col, pivot_row);
701
702        let inv = mod_inv(mat[col][col], p);
703        // Scale pivot row
704        for entry in &mut mat[col][col..] {
705            *entry = mul_mod(*entry, inv, p);
706        }
707        rhs[col] = mul_mod(rhs[col], inv, p);
708
709        // Eliminate column in all other rows
710        for row in 0..n {
711            if row == col {
712                continue;
713            }
714            let factor = mat[row][col];
715            if factor == 0 {
716                continue;
717            }
718            // Gather the pivot row values to avoid borrow conflict.
719            let pivot_row_vals: Vec<u64> = mat[col][col..].to_vec();
720            for (j, &pv) in pivot_row_vals.iter().enumerate() {
721                let sub = mul_mod(factor, pv, p);
722                mat[row][col + j] = sub_mod(mat[row][col + j], sub, p);
723            }
724            let sub = mul_mod(factor, rhs[col], p);
725            rhs[row] = sub_mod(rhs[row], sub, p);
726        }
727    }
728    Some(rhs.to_owned())
729}
730
731// ---------------------------------------------------------------------------
732// Univariate Ben-Or/Tiwari (internal)
733// ---------------------------------------------------------------------------
734
735/// Internal Ben-Or/Tiwari.  Evaluates at `g^0, …, g^{2T-1}` and runs the
736/// full Prony pipeline.  Returns `(coeff, exponent)` pairs, or an error.
737fn bt_univariate(
738    eval: &dyn Fn(u64) -> u64,
739    term_bound: usize,
740    prime: u64,
741    g: u64, // primitive root of F_p
742    rng: &mut Xorshift64,
743) -> Result<Vec<(u64, u32)>, SparseInterpError> {
744    if term_bound == 0 {
745        return Ok(vec![]);
746    }
747    let two_t = 2 * term_bound;
748
749    // --- Step 1: Evaluate at g^0, g^1, …, g^{2T-1} ---
750    let mut seq = Vec::with_capacity(two_t);
751    let mut gj = 1u64; // g^j
752    for _ in 0..two_t {
753        seq.push(eval(gj));
754        gj = mul_mod(gj, g, prime);
755    }
756
757    // --- Step 2: Berlekamp–Massey to find connection polynomial Λ ---
758    let lambda = berlekamp_massey(&seq, prime);
759    let ell = lambda.len() - 1; // LFSR length L ≤ T
760
761    if ell == 0 {
762        // Only the trivial polynomial: the sequence is identically zero.
763        return Ok(vec![]);
764    }
765
766    // --- Step 3: Find roots ρ of Λ in `F_p` (Cantor–Zassenhaus-style split) ---
767    let rho_roots = find_roots(&lambda, prime, rng)?;
768
769    if rho_roots.len() < ell {
770        return Err(SparseInterpError::RootFindingFailed);
771    }
772    // Use only the first `ell` roots (should be exactly ell distinct ones).
773    let rho: &[u64] = &rho_roots[..ell];
774
775    // --- Step 4: Map roots → frequencies → exponents ---
776    // Λ has roots ρ_j = g^{-e_j} (the inverses of the frequencies r_j).
777    // r_j = ρ_j^{-1} = g^{e_j}.
778    let mut exps: Vec<u32> = Vec::with_capacity(ell);
779    for &ro in rho {
780        if ro == 0 {
781            return Err(SparseInterpError::RootFindingFailed);
782        }
783        let r = mod_inv(ro, prime); // r = g^{e_j}
784        let e = bsgs_dlog(g, r, prime).ok_or(SparseInterpError::RootFindingFailed)?;
785        exps.push(e as u32);
786    }
787
788    // --- Step 5: Solve Vandermonde for coefficients ---
789    // The evaluation sequence satisfies: s[n] = Σ_j c_j · (g^{e_j})^n.
790    // As a matrix system with A[i][j] = pts[i]^{exps[j]}:
791    //   pts[i] = g^i  (i-th evaluation point)
792    //   exps[j] = e_j  (j-th monomial exponent)
793    //   vals[i] = s[i]  (i-th sequence value)
794    // This is the correct generalised-Vandermonde formulation.
795    let pts_for_vdm: Vec<u64> = (0..ell).map(|i| pow_mod(g, i as u64, prime)).collect();
796    let vals_for_vdm: Vec<u64> = seq[..ell].to_vec();
797    let coeffs = vandermonde_solve(&pts_for_vdm, &exps, &vals_for_vdm, prime)
798        .ok_or(SparseInterpError::SingularSystem)?;
799
800    Ok(coeffs
801        .into_iter()
802        .zip(exps)
803        .filter(|(c, _)| *c != 0)
804        .collect())
805}
806
807// ---------------------------------------------------------------------------
808// Dense univariate interpolation (fallback)
809// ---------------------------------------------------------------------------
810
811/// Dense Lagrange interpolation over `F_p`.
812///
813/// Given evaluations `f(1), f(2), …, f(D+1)` (at the first `D+1` non-zero
814/// field elements), returns the polynomial coefficients in ascending degree
815/// order.
816fn dense_interpolate(vals: &[u64], prime: u64) -> Vec<(u64, u32)> {
817    let n = vals.len();
818    // Evaluation points: 1, 2, …, n
819    let pts: Vec<u64> = (1..=n as u64).collect();
820    // Build Vandermonde system: pts[i]^j * c[j] = vals[i]
821    let mut mat: Vec<Vec<u64>> = (0..n)
822        .map(|i| (0..n).map(|j| pow_mod(pts[i], j as u64, prime)).collect())
823        .collect();
824    let mut rhs = vals.to_vec();
825    match gaussian_elim(&mut mat, &mut rhs, prime) {
826        Some(coeffs) => coeffs
827            .into_iter()
828            .enumerate()
829            .filter(|(_, c)| *c != 0)
830            .map(|(j, c)| (c, j as u32))
831            .collect(),
832        None => vec![], // degenerate; return empty
833    }
834}
835
836// ---------------------------------------------------------------------------
837// Multivariate Zippel (recursive)
838// ---------------------------------------------------------------------------
839
840/// One Vandermonde lift applied coherently across `dim` sibling components.
841fn lifted_eval_union(
842    x_pts: &[u64],
843    joint_exps: &[u32],
844    eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
845    prime: u64,
846    dim: usize,
847    m_count: usize,
848    x_suffix: &[u64],
849) -> Vec<u64> {
850    let mut new_vec = Vec::with_capacity(dim * m_count);
851    for j in 0..dim {
852        let f_vals: Vec<u64> = x_pts
853            .iter()
854            .map(|&xk| {
855                let mut args = vec![xk];
856                args.extend_from_slice(x_suffix);
857                eval_multi(&args).get(j).copied().unwrap_or(0)
858            })
859            .collect();
860        let coeffs = vandermonde_solve(x_pts, joint_exps, &f_vals, prime)
861            .unwrap_or_else(|| vec![0u64; m_count]);
862        debug_assert_eq!(coeffs.len(), m_count);
863        new_vec.extend(coeffs);
864    }
865    new_vec
866}
867
868/// Batched lifting: recover `dim` sibling coefficient polynomials simultaneously.
869/// Each map entry is `sparse exponents → coeff` in the remaining variables only.
870#[allow(clippy::too_many_arguments)] // recursion driver: shared oracle plus dimension bounds
871fn zippel_helper_multi(
872    eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
873    n_vars: usize,
874    dim: usize,
875    term_bound: usize,
876    degree_bound: u32,
877    prime: u64,
878    g: u64,
879    rng: &mut Xorshift64,
880) -> Result<Vec<BTreeMap<Vec<u32>, u64>>, SparseInterpError> {
881    if dim == 0 {
882        return Ok(vec![]);
883    }
884
885    if n_vars == 0 {
886        let v = eval_multi(&[]);
887        let mut out = Vec::with_capacity(dim);
888        for j in 0..dim {
889            let mut m = BTreeMap::new();
890            let c = *v.get(j).unwrap_or(&0);
891            if c != 0 {
892                m.insert(vec![], c);
893            }
894            out.push(m);
895        }
896        return Ok(out);
897    }
898
899    if n_vars == 1 {
900        let mut out = Vec::with_capacity(dim);
901        for j in 0..dim {
902            let terms = if degree_bound <= term_bound as u32 {
903                let d = degree_bound as usize + 1;
904                let vals: Vec<u64> = (1..=d as u64)
905                    .map(|x| eval_multi(&[x % prime]).get(j).copied().unwrap_or(0))
906                    .collect();
907                dense_interpolate(&vals, prime)
908            } else {
909                bt_univariate(
910                    &|t| eval_multi(&[t]).get(j).copied().unwrap_or(0),
911                    term_bound,
912                    prime,
913                    g,
914                    rng,
915                )?
916            };
917            let mut m = BTreeMap::new();
918            for (c, e) in terms {
919                if c != 0 {
920                    m.insert(vec![e], c);
921                }
922            }
923            out.push(m);
924        }
925        return Ok(out);
926    }
927
928    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();
929
930    let mut per_comp_skeletons: Vec<Vec<(u64, u32)>> = Vec::with_capacity(dim);
931    for j in 0..dim {
932        let sk = {
933            let f1 = |t: u64| -> u64 {
934                let mut args = vec![t];
935                args.extend_from_slice(&a_rest);
936                eval_multi(&args).get(j).copied().unwrap_or(0)
937            };
938            if degree_bound <= term_bound as u32 {
939                let d = degree_bound as usize + 1;
940                let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
941                dense_interpolate(&v, prime)
942            } else {
943                bt_univariate(&f1, term_bound, prime, g, rng)?
944            }
945        };
946        per_comp_skeletons.push(sk);
947    }
948
949    let mut joint_exps: Vec<u32> = Vec::new();
950    for sk in &per_comp_skeletons {
951        for &(_, e) in sk {
952            joint_exps.push(e);
953        }
954    }
955    joint_exps.sort_unstable();
956    joint_exps.dedup();
957    let m_count = joint_exps.len();
958
959    let empty_maps = || (0..dim).map(|_| BTreeMap::new()).collect::<Vec<_>>();
960
961    if m_count == 0 {
962        return Ok(empty_maps());
963    }
964
965    // Fully batched recursion uses vector dimension `dim · |joint|`; union can be
966    // large across many siblings (`dim ≈ term_bound`).  Above this budget fall back to
967    // the classic nested `zippel_helper` lifts — oracle depth improves over the legacy
968    // implementation (shared lift at outer peel) while keeping tests bounded.
969    // Allow large batched lifts for realistic `term_bound` (~20–50); tighter caps
970    // force the scalar fallback whose constant factor dominates on large `n_vars`.
971    let vec_budget = term_bound.saturating_mul(512).clamp(8192usize, 131072usize);
972    if dim.saturating_mul(m_count) > vec_budget {
973        let mut stacked: Vec<BTreeMap<Vec<u32>, u64>> = Vec::with_capacity(dim);
974        for (j, sk) in per_comp_skeletons.iter().enumerate().take(dim) {
975            if sk.is_empty() {
976                stacked.push(BTreeMap::new());
977                continue;
978            }
979            let exps_j: Vec<u32> = sk.iter().map(|(_, e)| *e).collect();
980            let tj = exps_j.len();
981            let mut pts: Vec<u64> = Vec::with_capacity(tj);
982            {
983                let mut used = std::collections::HashSet::new();
984                while pts.len() < tj {
985                    let v = rng.nonzero(prime);
986                    if used.insert(v) {
987                        pts.push(v);
988                    }
989                }
990            }
991            let mut comp_map = BTreeMap::new();
992            for k in 0..tj {
993                let e_cur = exps_j[k];
994                let sub_terms = zippel_helper(
995                    &|x_rest: &[u64]| -> u64 {
996                        let f_vals: Vec<u64> = pts
997                            .iter()
998                            .map(|&xk| {
999                                let mut args = vec![xk];
1000                                args.extend_from_slice(x_rest);
1001                                eval_multi(&args).get(j).copied().unwrap_or(0)
1002                            })
1003                            .collect();
1004                        vandermonde_solve(&pts, &exps_j, &f_vals, prime)
1005                            .map(|v| v[k])
1006                            .unwrap_or(0)
1007                    },
1008                    n_vars - 1,
1009                    term_bound,
1010                    degree_bound,
1011                    prime,
1012                    g,
1013                    rng,
1014                )?;
1015                for (mut sub_exp, coeff) in sub_terms {
1016                    if coeff != 0 {
1017                        let mut full = vec![e_cur];
1018                        full.append(&mut sub_exp);
1019                        comp_map.insert(full, coeff);
1020                    }
1021                }
1022            }
1023            stacked.push(comp_map);
1024        }
1025        return Ok(stacked);
1026    }
1027
1028    let mut x_pts: Vec<u64> = Vec::with_capacity(m_count);
1029    {
1030        let mut used = std::collections::HashSet::new();
1031        while x_pts.len() < m_count {
1032            let v = rng.nonzero(prime);
1033            if used.insert(v) {
1034                x_pts.push(v);
1035            }
1036        }
1037    }
1038
1039    let dim_next = dim * m_count;
1040    let sub = zippel_helper_multi(
1041        &|x_suffix: &[u64]| {
1042            lifted_eval_union(
1043                &x_pts,
1044                &joint_exps,
1045                eval_multi,
1046                prime,
1047                dim,
1048                m_count,
1049                x_suffix,
1050            )
1051        },
1052        n_vars - 1,
1053        dim_next,
1054        term_bound,
1055        degree_bound,
1056        prime,
1057        g,
1058        rng,
1059    )?;
1060
1061    let mut result: Vec<BTreeMap<Vec<u32>, u64>> = empty_maps();
1062    for (j, res_j) in result.iter_mut().enumerate().take(dim) {
1063        for (r, &e1) in joint_exps.iter().enumerate().take(m_count) {
1064            let slot = j * m_count + r;
1065            for (sub_exp, coeff) in &sub[slot] {
1066                if *coeff != 0 {
1067                    let mut full_exp = vec![e1];
1068                    full_exp.extend_from_slice(sub_exp);
1069                    res_j.insert(full_exp, *coeff);
1070                }
1071            }
1072        }
1073    }
1074
1075    Ok(result)
1076}
1077
1078/// Recursive Zippel helper.  Returns a map from exponent vectors to
1079/// coefficients in `F_p`.
1080fn zippel_helper(
1081    eval: &dyn Fn(&[u64]) -> u64,
1082    n_vars: usize,
1083    term_bound: usize,
1084    degree_bound: u32,
1085    prime: u64,
1086    g: u64,
1087    rng: &mut Xorshift64,
1088) -> Result<BTreeMap<Vec<u32>, u64>, SparseInterpError> {
1089    // --- Base case: constant polynomial ---
1090    if n_vars == 0 {
1091        let c = eval(&[]);
1092        let mut m = BTreeMap::new();
1093        if c != 0 {
1094            m.insert(vec![], c);
1095        }
1096        return Ok(m);
1097    }
1098
1099    // --- Base case: univariate ---
1100    if n_vars == 1 {
1101        // Use dense fallback if degree_bound is small (avoids BSGS overhead).
1102        let terms = if degree_bound <= term_bound as u32 {
1103            // Dense path: evaluate at degree_bound+1 points.
1104            let d = degree_bound as usize + 1;
1105            let v: Vec<u64> = (1..=d as u64).map(|x| eval(&[x % prime])).collect();
1106            dense_interpolate(&v, prime)
1107        } else {
1108            bt_univariate(&|t| eval(&[t]), term_bound, prime, g, rng)?
1109        };
1110        let mut m = BTreeMap::new();
1111        for (c, e) in terms {
1112            m.insert(vec![e], c);
1113        }
1114        return Ok(m);
1115    }
1116
1117    // --- Multivariate Zippel ---
1118
1119    // Step 1: Evaluate f(x₁, a₂, …, aₙ) for random aᵢ to get x₁-skeleton.
1120    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();
1121
1122    let skeleton: Vec<(u64, u32)> = {
1123        let f1 = |t: u64| -> u64 {
1124            let mut args = vec![t];
1125            args.extend_from_slice(&a_rest);
1126            eval(&args)
1127        };
1128        if degree_bound <= term_bound as u32 {
1129            let d = degree_bound as usize + 1;
1130            let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
1131            dense_interpolate(&v, prime)
1132        } else {
1133            bt_univariate(&f1, term_bound, prime, g, rng)?
1134        }
1135    };
1136
1137    if skeleton.is_empty() {
1138        return Ok(BTreeMap::new());
1139    }
1140
1141    let x1_exps: Vec<u32> = skeleton.iter().map(|(_, e)| *e).collect();
1142    let t = x1_exps.len();
1143
1144    // Step 2: Choose t distinct evaluation points for x₁.
1145    let mut x1_pts: Vec<u64> = Vec::with_capacity(t);
1146    {
1147        let mut used = std::collections::HashSet::new();
1148        while x1_pts.len() < t {
1149            let v = rng.nonzero(prime);
1150            if used.insert(v) {
1151                x1_pts.push(v);
1152            }
1153        }
1154    }
1155
1156    // Step 3: batched Vandermonde lift → single recursion (shared oracle).
1157    let eval_multi = |x_rest: &[u64]| -> Vec<u64> {
1158        let mut f_vals: Vec<u64> = Vec::with_capacity(t);
1159        for &xk in &x1_pts {
1160            let mut args = vec![xk];
1161            args.extend_from_slice(x_rest);
1162            f_vals.push(eval(&args));
1163        }
1164        vandermonde_solve(&x1_pts, &x1_exps, &f_vals, prime).unwrap_or_else(|| vec![0u64; t])
1165    };
1166
1167    let sub_maps = zippel_helper_multi(
1168        &eval_multi,
1169        n_vars - 1,
1170        t,
1171        term_bound,
1172        degree_bound,
1173        prime,
1174        g,
1175        rng,
1176    )?;
1177
1178    let mut result: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
1179    for j in 0..t {
1180        let e1 = x1_exps[j];
1181        for (sub_exp, coeff) in &sub_maps[j] {
1182            if *coeff != 0 {
1183                let mut full_exp = vec![e1];
1184                full_exp.extend_from_slice(sub_exp);
1185                result.insert(full_exp, *coeff);
1186            }
1187        }
1188    }
1189
1190    Ok(result)
1191}
1192
1193// ---------------------------------------------------------------------------
1194// Public API
1195// ---------------------------------------------------------------------------
1196
1197/// Recover a sparse univariate polynomial `f ∈ F_p[x]` from black-box
1198/// evaluations using the Ben-Or/Tiwari (Prony-style) algorithm.
1199///
1200/// # Parameters
1201///
1202/// - `eval` — black-box oracle: `x ↦ f(x) mod p`.  Called `2·term_bound`
1203///   times (at consecutive powers of a primitive root of `F_p`).
1204/// - `term_bound` — `T`: upper bound on the number of nonzero terms in `f`.
1205/// - `prime` — field characteristic `p`.  Must satisfy `p > 2·T` and
1206///   `p > max_degree(f)` (so all exponents are representable as discrete
1207///   logarithms in `{0, …, p-2}`).
1208///
1209/// # Returns
1210///
1211/// A vector of `(coefficient, exponent)` pairs in arbitrary order.
1212///
1213/// # Errors
1214///
1215/// - [`SparseInterpError::InvalidPrime`] if `p` is not prime.
1216/// - [`SparseInterpError::PrimeTooSmall`] if `p ≤ 2·T`.
1217/// - [`SparseInterpError::RootFindingFailed`] if fewer roots were found than
1218///   expected (the prime may be smaller than `max_degree(f)`).
1219/// - [`SparseInterpError::SingularSystem`] if the Vandermonde system is
1220///   degenerate (extremely rare; retry with a different prime).
1221///
1222/// # Example
1223///
1224/// ```text
1225/// // Recover  x^100 + 3·x^17 + 5  from 6 evaluations (T=3).
1226/// let eval = |x: u64| { ... };
1227/// let terms = sparse_interpolate_univariate(&eval, 3, 1009)?;
1228/// // terms ≈ [(1, 100), (3, 17), (5, 0)]
1229/// ```
1230pub fn sparse_interpolate_univariate(
1231    eval: &dyn Fn(u64) -> u64,
1232    term_bound: usize,
1233    prime: u64,
1234) -> Result<Vec<(u64, u32)>, SparseInterpError> {
1235    if !is_prime(prime) {
1236        return Err(SparseInterpError::InvalidPrime(prime));
1237    }
1238    if prime <= 2 * term_bound as u64 {
1239        return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1240    }
1241    let g = primitive_root(prime);
1242    let mut rng = Xorshift64::new(prime.wrapping_mul(0x5851_f42d_4c95_7f2d));
1243    bt_univariate(eval, term_bound, prime, g, &mut rng)
1244}
1245
1246/// Recover a sparse multivariate polynomial `f ∈ F_p[x₁, …, xₙ]` from
1247/// black-box evaluations using Zippel's variable-by-variable algorithm.
1248///
1249/// # Parameters
1250///
1251/// - `eval` — black-box oracle: `(x₁, …, xₙ) ↦ f(x₁, …, xₙ) mod p`.
1252///   Coordinates are given in the same order as `vars`.
1253/// - `vars` — symbolic variable identifiers (used to label the result).
1254/// - `term_bound` — `T`: upper bound on the number of nonzero terms.
1255/// - `degree_bound` — `D`: upper bound on the degree of each individual
1256///   variable.  Polynomials with lower `D` converge faster.  For the dense
1257///   fallback to kick in, set `D ≤ T`.
1258/// - `prime` — field characteristic `p`.  Must satisfy `p > 2·T` and
1259///   `p > D` (so exponents are representable as discrete logs).
1260/// - `seed` — seed for the internal PRNG.  Changing the seed helps recover
1261///   from occasional failures due to unlucky random evaluation points.
1262///
1263/// # Returns
1264///
1265/// A [`MultiPolyFp`] with the recovered polynomial.  Oracle complexity is
1266/// polynomial in the number of variables, `term_bound`, and `degree_bound` on
1267/// typical sparse inputs — unlike dense interpolation at `Ω((D+1)^n)`.
1268///
1269/// # Errors
1270///
1271/// See [`SparseInterpError`].
1272pub fn sparse_interpolate(
1273    eval: &dyn Fn(&[u64]) -> u64,
1274    vars: Vec<ExprId>,
1275    term_bound: usize,
1276    degree_bound: u32,
1277    prime: u64,
1278    seed: u64,
1279) -> Result<MultiPolyFp, SparseInterpError> {
1280    if !is_prime(prime) {
1281        return Err(SparseInterpError::InvalidPrime(prime));
1282    }
1283    if prime <= 2 * term_bound as u64 {
1284        return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1285    }
1286
1287    let n_vars = vars.len();
1288    let g = primitive_root(prime);
1289    let mut rng = Xorshift64::new(seed);
1290
1291    let terms = zippel_helper(eval, n_vars, term_bound, degree_bound, prime, g, &mut rng)?;
1292
1293    let trimmed_terms: BTreeMap<Vec<u32>, u64> = terms
1294        .into_iter()
1295        .map(|(mut exp, c)| {
1296            // Trim trailing zeros in exponent vector.
1297            while exp.last() == Some(&0) {
1298                exp.pop();
1299            }
1300            (exp, c)
1301        })
1302        .filter(|(_, c)| *c != 0)
1303        .collect();
1304
1305    Ok(MultiPolyFp {
1306        vars,
1307        modulus: prime,
1308        terms: trimmed_terms,
1309    })
1310}
1311
1312// ---------------------------------------------------------------------------
1313// Sparse modular GCD — "substrate for faster modular algorithms"
1314// ---------------------------------------------------------------------------
1315
/// Error returned by [`gcd_sparse_modular`].
///
/// Each variant maps to a stable diagnostic code (`E-INTERP-010` …
/// `E-INTERP-012`) through this type's `AlkahestError` implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseGcdError {
    /// The two polynomials have incompatible variable lists
    /// (their `vars` differ, including in order).
    IncompatiblePolynomials,
    /// Sparse interpolation failed while computing a modular GCD image;
    /// wraps the underlying [`SparseInterpError`].
    InterpFailed(SparseInterpError),
    /// CRT lifting of the modular images failed; wraps the error reported by
    /// the lifting step.
    CrtFailed(crate::modular::ModularError),
}
1326
1327impl std::fmt::Display for SparseGcdError {
1328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1329        match self {
1330            SparseGcdError::IncompatiblePolynomials => {
1331                write!(f, "polynomials have incompatible variable lists")
1332            }
1333            SparseGcdError::InterpFailed(e) => write!(f, "interpolation step failed: {e}"),
1334            SparseGcdError::CrtFailed(e) => write!(f, "CRT lifting failed: {e}"),
1335        }
1336    }
1337}
1338
1339impl std::error::Error for SparseGcdError {
1340    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
1341        match self {
1342            SparseGcdError::InterpFailed(e) => Some(e),
1343            SparseGcdError::CrtFailed(e) => Some(e),
1344            _ => None,
1345        }
1346    }
1347}
1348
1349impl AlkahestError for SparseGcdError {
1350    fn code(&self) -> &'static str {
1351        match self {
1352            SparseGcdError::IncompatiblePolynomials => "E-INTERP-010",
1353            SparseGcdError::InterpFailed(_) => "E-INTERP-011",
1354            SparseGcdError::CrtFailed(_) => "E-INTERP-012",
1355        }
1356    }
1357
1358    fn remediation(&self) -> Option<&'static str> {
1359        match self {
1360            SparseGcdError::IncompatiblePolynomials => {
1361                Some("ensure both polynomials share the same variable list in the same order")
1362            }
1363            SparseGcdError::InterpFailed(_) => {
1364                Some("retry with a larger term_bound, degree_bound, or a different seed")
1365            }
1366            SparseGcdError::CrtFailed(_) => {
1367                Some("provide more primes or use a larger prime product threshold")
1368            }
1369        }
1370    }
1371}
1372
1373/// Evaluate all variables except `x₁` at `vals = [a₂, …, aₙ]`, returning a
1374/// dense coefficient vector `[c₀, c₁, …]` where `cₖ` is the coefficient of `x₁^k`.
1375fn specialize_except_first(fp: &MultiPolyFp, vals: &[u64]) -> Vec<u64> {
1376    let p = fp.modulus;
1377    let max_x1 = fp
1378        .terms
1379        .keys()
1380        .map(|e| e.first().copied().unwrap_or(0))
1381        .max()
1382        .unwrap_or(0) as usize;
1383    let mut result = vec![0u64; max_x1 + 1];
1384    for (exp, &coeff) in &fp.terms {
1385        let k = exp.first().copied().unwrap_or(0) as usize;
1386        let mut factor = coeff;
1387        for (i, &e) in exp.iter().skip(1).enumerate() {
1388            if e > 0 {
1389                let ai = *vals.get(i).unwrap_or(&0);
1390                factor = mul_mod(factor, pow_mod(ai, e as u64, p), p);
1391            }
1392        }
1393        result[k] = add_mod(result[k], factor, p);
1394    }
1395    poly_trim(result)
1396}
1397
1398/// Compute the monic GCD image over `Fₚ[x₁,…,xₙ]` via evaluation/interpolation.
1399/// Uses [`sparse_interpolate`] to recover each `x₁^k`-coefficient polynomial.
1400fn gcd_sparse_mod_p(
1401    f_p: &MultiPolyFp,
1402    g_p: &MultiPolyFp,
1403    sub_vars: Vec<ExprId>,
1404    term_bound: usize,
1405    degree_bound: u32,
1406    prime: u64,
1407    seed: u64,
1408) -> Result<MultiPolyFp, SparseInterpError> {
1409    let p = prime;
1410    let vars_full = f_p.vars.clone();
1411
1412    // Probe one random specialization to find the GCD degree in x₁.
1413    let n_sub = sub_vars.len();
1414    let mut rng = Xorshift64::new(seed ^ p.wrapping_mul(0x9e37_79b9_7f4a_7c15));
1415    let probe_vals: Vec<u64> = (0..n_sub).map(|_| rng.nonzero(p)).collect();
1416    let f1 = specialize_except_first(f_p, &probe_vals);
1417    let g1 = specialize_except_first(g_p, &probe_vals);
1418    let h1 = polygcd(&f1, &g1, p);
1419    let gcd_deg_x1 = poly_deg(&h1).max(0) as usize;
1420
1421    let mut h_terms: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
1422
1423    if sub_vars.is_empty() {
1424        // Univariate: GCD is already determined from the probe.
1425        for (k, &c) in h1.iter().enumerate() {
1426            if c != 0 {
1427                let mut exp = vec![k as u32];
1428                while exp.last() == Some(&0) {
1429                    exp.pop();
1430                }
1431                h_terms.insert(exp, c);
1432            }
1433        }
1434    } else {
1435        // Multivariate: for each x₁-degree k, interpolate the coefficient polynomial.
1436        for k in 0..=gcd_deg_x1 {
1437            let oracle = |vals: &[u64]| -> u64 {
1438                let fa = specialize_except_first(f_p, vals);
1439                let ga = specialize_except_first(g_p, vals);
1440                let hk = polygcd(&fa, &ga, p);
1441                hk.get(k).copied().unwrap_or(0)
1442            };
1443            let ck = sparse_interpolate(
1444                &oracle,
1445                sub_vars.clone(),
1446                term_bound,
1447                degree_bound,
1448                p,
1449                seed.wrapping_add(k as u64 + 1),
1450            )?;
1451            for (sub_exp, &c) in &ck.terms {
1452                if c == 0 {
1453                    continue;
1454                }
1455                let mut full_exp = vec![k as u32];
1456                full_exp.extend_from_slice(sub_exp);
1457                while full_exp.last() == Some(&0) {
1458                    full_exp.pop();
1459                }
1460                h_terms.insert(full_exp, c);
1461            }
1462        }
1463    }
1464
1465    Ok(MultiPolyFp {
1466        vars: vars_full,
1467        modulus: prime,
1468        terms: h_terms,
1469    })
1470}
1471
1472/// Compute the primitive GCD of `f` and `g` in `ℤ[x₁,…,xₙ]` using sparse
1473/// interpolation and the Chinese Remainder Theorem.
1474///
1475/// # Algorithm (Zippel evaluation–interpolation GCD)
1476///
1477/// For each lucky prime `p` (skipping primes that collapse either polynomial's
1478/// integer content):
1479///  1. Reduce `f` and `g` modulo `p`.
1480///  2. For each degree `k` in `x₁`, build the oracle
1481///     `(a₂,…,aₙ) ↦ [x₁^k] gcd(f(x₁,a₂,…), g(x₁,a₂,…))`.
1482///  3. Call [`sparse_interpolate`] to recover the coefficient polynomial
1483///     `c_k(x₂,…,xₙ)`.
1484///  4. Assemble the modular GCD image `h mod p = Σ c_k · x₁^k`.
1485///
1486/// Once the product of chosen primes exceeds `2 · Mignotte(min(f, g))`, apply
1487/// CRT lifting ([`crate::modular::lift_crt`]) to recover the integer GCD, then
1488/// return the primitive part.
1489///
1490/// # Parameters
1491///
1492/// - `term_bound` — upper bound on the number of nonzero terms in the GCD.
1493///   Should be at most `min(terms(f), terms(g))` for efficiency.
1494/// - `degree_bound` — upper bound on the per-variable degree of the GCD (for
1495///   variables `x₂,…,xₙ`; degree in `x₁` is probed automatically).
1496/// - `seed` — PRNG seed for [`sparse_interpolate`]; change on failure.
1497///
1498/// # Errors
1499///
1500/// - [`SparseGcdError::IncompatiblePolynomials`] — different variable lists.
1501/// - [`SparseGcdError::InterpFailed`] — interpolation failure.
1502/// - [`SparseGcdError::CrtFailed`] — CRT reconstruction failure.
1503pub fn gcd_sparse_modular(
1504    f: &super::multipoly::MultiPoly,
1505    g: &super::multipoly::MultiPoly,
1506    term_bound: usize,
1507    degree_bound: u32,
1508    seed: u64,
1509) -> Result<super::multipoly::MultiPoly, SparseGcdError> {
1510    use crate::modular::{lift_crt, mignotte_bound, reduce_mod};
1511    use rug::Integer;
1512
1513    if f.vars != g.vars {
1514        return Err(SparseGcdError::IncompatiblePolynomials);
1515    }
1516    if f.is_zero() {
1517        return Ok(g.clone());
1518    }
1519    if g.is_zero() {
1520        return Ok(f.clone());
1521    }
1522
1523    let vars = f.vars.clone();
1524    let sub_vars = if vars.len() > 1 {
1525        vars[1..].to_vec()
1526    } else {
1527        vec![]
1528    };
1529
1530    let b_f = mignotte_bound(f);
1531    let b_g = mignotte_bound(g);
1532    let bound = b_f.min(b_g);
1533    let two_bound = bound.clone() << 1u32;
1534
1535    // Minimum prime: p > 2·T (for Ben-Or/Tiwari) and p > D (for discrete-log exponent map).
1536    let min_p = ((2 * term_bound + 2) as u64).max(degree_bound as u64 + 2);
1537
1538    // Content for divisibility avoidance.
1539    let content = f.integer_content() * g.integer_content();
1540
1541    let mut images: Vec<(MultiPolyFp, u64)> = Vec::new();
1542    let mut used: Vec<u64> = Vec::new();
1543    let mut m = Integer::from(1u64);
1544    let mut candidate = min_p.max(3);
1545
1546    while m <= two_bound {
1547        // Find next prime ≥ candidate that doesn't divide content.
1548        loop {
1549            if is_prime(candidate) && !used.contains(&candidate) {
1550                if content == 0 {
1551                    break;
1552                }
1553                let p_int = Integer::from(candidate);
1554                let r = content.clone() % p_int.clone();
1555                let r = if r < 0 { r + p_int } else { r };
1556                if r != 0 {
1557                    break;
1558                }
1559            }
1560            candidate += 1;
1561            if candidate > 1_000_003 {
1562                break;
1563            }
1564        }
1565        let p = candidate;
1566        candidate += 1;
1567
1568        let f_p = match reduce_mod(f, p) {
1569            Ok(x) if !x.is_zero() => x,
1570            _ => continue,
1571        };
1572        let g_p = match reduce_mod(g, p) {
1573            Ok(x) if !x.is_zero() => x,
1574            _ => continue,
1575        };
1576
1577        used.push(p);
1578
1579        let h_p = gcd_sparse_mod_p(
1580            &f_p,
1581            &g_p,
1582            sub_vars.clone(),
1583            term_bound,
1584            degree_bound,
1585            p,
1586            seed.wrapping_add(p),
1587        )
1588        .map_err(SparseGcdError::InterpFailed)?;
1589
1590        images.push((h_p, p));
1591        m *= Integer::from(p);
1592    }
1593
1594    let mut result = lift_crt(&images).map_err(SparseGcdError::CrtFailed)?;
1595
1596    // Normalise: positive leading coefficient, then take primitive part.
1597    if let Some((_, lc)) = result.terms.iter().next_back() {
1598        if lc.cmp0() == std::cmp::Ordering::Less {
1599            result = -result;
1600        }
1601    }
1602    Ok(result.primitive_part())
1603}
1604
1605// ---------------------------------------------------------------------------
1606// Unit tests
1607// ---------------------------------------------------------------------------
1608
1609#[cfg(test)]
1610mod tests {
1611    use super::*;
1612    use crate::kernel::{Domain, ExprPool};
1613
1614    // ---- helpers ------------------------------------------------------------
1615
1616    fn make_poly_eval(coeffs: &[(u64, Vec<u32>)], prime: u64) -> impl Fn(&[u64]) -> u64 + '_ {
1617        move |pt: &[u64]| -> u64 {
1618            let mut acc = 0u64;
1619            for (c, exp) in coeffs {
1620                let mut term = *c % prime;
1621                for (i, &e) in exp.iter().enumerate() {
1622                    let xi = if i < pt.len() { pt[i] } else { 0 };
1623                    term = mul_mod(term, pow_mod(xi, e as u64, prime), prime);
1624                }
1625                acc = add_mod(acc, term, prime);
1626            }
1627            acc
1628        }
1629    }
1630
1631    fn vars(n: usize) -> (ExprPool, Vec<ExprId>) {
1632        let pool = ExprPool::new();
1633        let vs: Vec<ExprId> = (0..n)
1634            .map(|i| pool.symbol(format!("x{i}"), Domain::Real))
1635            .collect();
1636        (pool, vs)
1637    }
1638
1639    // ---- primitive_root -----------------------------------------------------
1640
1641    #[test]
1642    fn prim_root_small_primes() {
1643        for p in [2u64, 3, 5, 7, 11, 13, 17, 19, 23] {
1644            let g = primitive_root(p);
1645            // Verify: g^{p-1} = 1 and g^{(p-1)/q} ≠ 1 for each prime q | p-1
1646            assert_eq!(pow_mod(g, p - 1, p), 1, "g^(p-1)=1 for p={p}");
1647            for q in prime_factors(p - 1) {
1648                assert_ne!(pow_mod(g, (p - 1) / q, p), 1, "g^((p-1)/{q}) ≠ 1 for p={p}");
1649            }
1650        }
1651    }
1652
1653    // ---- berlekamp_massey ---------------------------------------------------
1654
1655    #[test]
1656    fn bm_geometric_sequence() {
1657        // s[n] = 2^n mod 7.  LFSR: s[n] = 2·s[n-1], connection poly = 1 + 5z (since
1658        // 2·(1 + 5z) → 2·1 + 2·5·g = 0 means the root is 2^{-1} = 4 in F_7).
1659        let p = 7u64;
1660        let seq: Vec<u64> = (0..6).map(|n| pow_mod(2, n, p)).collect();
1661        let lambda = berlekamp_massey(&seq, p);
1662        assert_eq!(lambda.len() - 1, 1, "LFSR length should be 1");
1663        // Verify Λ(2^{-1}) = 0
1664        let inv2 = mod_inv(2, p);
1665        assert_eq!(poly_eval(&lambda, inv2, p), 0);
1666    }
1667
1668    #[test]
1669    fn bm_two_term_sequence() {
1670        // s[n] = 3·2^n + 5·3^n  mod 11
1671        let p = 11u64;
1672        let seq: Vec<u64> = (0..4)
1673            .map(|n| {
1674                add_mod(
1675                    mul_mod(3, pow_mod(2, n, p), p),
1676                    mul_mod(5, pow_mod(3, n, p), p),
1677                    p,
1678                )
1679            })
1680            .collect();
1681        let lambda = berlekamp_massey(&seq, p);
1682        assert_eq!(lambda.len() - 1, 2, "two-term sequence has LFSR length 2");
1683        // Roots of Λ should include inv(2) and inv(3)
1684        let mut rng = Xorshift64::new(0xbeef);
1685        let roots = find_roots(&lambda, p, &mut rng).unwrap();
1686        assert_eq!(roots.len(), 2);
1687        let expected: std::collections::HashSet<u64> =
1688            [mod_inv(2, p), mod_inv(3, p)].into_iter().collect();
1689        let got: std::collections::HashSet<u64> = roots.into_iter().collect();
1690        assert_eq!(got, expected);
1691    }
1692
1693    // ---- bsgs_dlog ----------------------------------------------------------
1694
1695    #[test]
1696    fn dlog_basic() {
1697        let p = 13u64;
1698        let g = primitive_root(p);
1699        for e in 0..p - 1 {
1700            let target = pow_mod(g, e, p);
1701            let found = bsgs_dlog(g, target, p).expect("dlog should succeed");
1702            assert_eq!(
1703                pow_mod(g, found, p),
1704                target,
1705                "g^{found} ≠ {target} for p={p}"
1706            );
1707        }
1708    }
1709
1710    // ---- sparse_interpolate_univariate --------------------------------------
1711
1712    #[test]
1713    fn uni_zero_polynomial() {
1714        let terms = sparse_interpolate_univariate(&|_| 0, 5, 101).unwrap();
1715        assert!(terms.is_empty());
1716    }
1717
1718    #[test]
1719    fn uni_constant() {
1720        // f(x) = 7.  One term (coeff=7, exp=0).
1721        let terms = sparse_interpolate_univariate(&|_| 7, 3, 101).unwrap();
1722        assert_eq!(terms.len(), 1);
1723        let (c, e) = terms[0];
1724        assert_eq!(c, 7);
1725        assert_eq!(e, 0);
1726    }
1727
1728    #[test]
1729    fn uni_single_monomial() {
1730        // f(x) = 3·x^5 mod 101
1731        let p = 101u64;
1732        let eval = |x: u64| mul_mod(3, pow_mod(x, 5, p), p);
1733        let terms = sparse_interpolate_univariate(&eval, 3, p).unwrap();
1734        assert_eq!(terms.len(), 1);
1735        let (c, e) = terms[0];
1736        assert_eq!(c, 3);
1737        assert_eq!(e, 5);
1738    }
1739
1740    #[test]
1741    fn uni_two_terms() {
1742        // f(x) = x^10 + 2·x^3 mod 101
1743        let p = 101u64;
1744        let eval = |x: u64| {
1745            let a = pow_mod(x, 10, p);
1746            let b = mul_mod(2, pow_mod(x, 3, p), p);
1747            add_mod(a, b, p)
1748        };
1749        let terms = sparse_interpolate_univariate(&eval, 3, p).unwrap();
1750        assert_eq!(terms.len(), 2);
1751        let mut sorted = terms.clone();
1752        sorted.sort_by_key(|&(_, e)| e);
1753        assert_eq!(sorted[0], (2, 3));
1754        assert_eq!(sorted[1], (1, 10));
1755    }
1756
1757    #[test]
1758    fn uni_roadmap_example() {
1759        // ROADMAP: recover x^100 + 3·x^17 + 5 from T=3 (6 evaluations).
1760        // Needs prime p > 100.  Use p = 997 (prime > 100 and > 2*3=6).
1761        let p = 997u64;
1762        let eval = |x: u64| {
1763            let a = pow_mod(x, 100, p);
1764            let b = mul_mod(3, pow_mod(x, 17, p), p);
1765            let c = 5u64;
1766            add_mod(add_mod(a, b, p), c, p)
1767        };
1768        let terms = sparse_interpolate_univariate(&eval, 4, p).unwrap();
1769        let mut sorted = terms.clone();
1770        sorted.sort_by_key(|&(_, e)| e);
1771        // Expect: [(5,0), (3,17), (1,100)]
1772        assert!(
1773            sorted.iter().any(|&(c, e)| c == 5 && e == 0),
1774            "missing constant 5"
1775        );
1776        assert!(
1777            sorted.iter().any(|&(c, e)| c == 3 && e == 17),
1778            "missing 3·x^17"
1779        );
1780        assert!(
1781            sorted.iter().any(|&(c, e)| c == 1 && e == 100),
1782            "missing x^100"
1783        );
1784    }
1785
1786    #[test]
1787    fn uni_error_invalid_prime() {
1788        let err = sparse_interpolate_univariate(&|_| 0, 3, 4);
1789        assert!(matches!(err, Err(SparseInterpError::InvalidPrime(4))));
1790    }
1791
1792    #[test]
1793    fn uni_error_prime_too_small() {
1794        // T=10 needs p > 20; use p=19.
1795        let err = sparse_interpolate_univariate(&|_| 0, 10, 19);
1796        assert!(matches!(
1797            err,
1798            Err(SparseInterpError::PrimeTooSmall {
1799                prime: 19,
1800                term_bound: 10
1801            })
1802        ));
1803    }
1804
1805    // ---- sparse_interpolate (multivariate) ----------------------------------
1806
1807    #[test]
1808    fn multi_constant() {
1809        let (_, vs) = vars(2);
1810        let result = sparse_interpolate(&|_| 42, vs, 3, 10, 101, 0).unwrap();
1811        assert_eq!(result.terms.len(), 1);
1812        assert_eq!(*result.terms.get(&vec![]).unwrap(), 42u64);
1813    }
1814
1815    #[test]
1816    fn multi_univariate_via_multi() {
1817        // x^2 + 3·x + 1 in one variable
1818        let p = 101u64;
1819        let (_, vs) = vars(1);
1820        let eval = |pt: &[u64]| {
1821            let x = pt[0];
1822            add_mod(add_mod(pow_mod(x, 2, p), mul_mod(3, x, p), p), 1, p)
1823        };
1824        let result = sparse_interpolate(&eval, vs, 5, 10, p, 0).unwrap();
1825        // Expect terms: exp=[0]→1, exp=[1]→3, exp=[2]→1
1826        assert_eq!(*result.terms.get(&vec![2]).unwrap(), 1u64, "x^2 coeff");
1827        assert_eq!(*result.terms.get(&vec![1]).unwrap(), 3u64, "x^1 coeff");
1828        assert_eq!(*result.terms.get(&vec![]).unwrap_or(&0), 1u64, "x^0 coeff");
1829    }
1830
1831    #[test]
1832    fn multi_bivariate_xy() {
1833        // f = x·y + 3 over F_101
1834        let p = 101u64;
1835        let (_, vs) = vars(2);
1836        let eval = |pt: &[u64]| add_mod(mul_mod(pt[0], pt[1], p), 3, p);
1837        let result = sparse_interpolate(&eval, vs, 4, 5, p, 1).unwrap();
1838        // Expect: {[1,1]→1, []→3} (or [0,0]→3)
1839        assert_eq!(
1840            *result.terms.get(&vec![1, 1]).unwrap_or(&0),
1841            1u64,
1842            "x*y coeff"
1843        );
1844        assert_eq!(*result.terms.get(&vec![]).unwrap_or(&0), 3u64, "constant");
1845    }
1846
1847    #[test]
1848    fn multi_bivariate_x_squared_y() {
1849        // f = x^2·y + 5·y + 2·x  over F_101
1850        let p = 101u64;
1851        let (_, vs) = vars(2);
1852        let eval = |pt: &[u64]| {
1853            let x = pt[0];
1854            let y = pt[1];
1855            let a = mul_mod(pow_mod(x, 2, p), y, p);
1856            let b = mul_mod(5, y, p);
1857            let c = mul_mod(2, x, p);
1858            add_mod(add_mod(a, b, p), c, p)
1859        };
1860        let result = sparse_interpolate(&eval, vs, 5, 6, p, 42).unwrap();
1861        assert_eq!(*result.terms.get(&vec![2, 1]).unwrap_or(&0), 1, "x^2*y");
1862        assert_eq!(*result.terms.get(&vec![0, 1]).unwrap_or(&0), 5, "5*y");
1863        assert_eq!(*result.terms.get(&vec![1]).unwrap_or(&0), 2, "2*x");
1864    }
1865
1866    #[test]
1867    fn multi_three_variables() {
1868        // f = x·y·z + x^2 + z  over F_1009
1869        let p = 1009u64;
1870        let (_, vs) = vars(3);
1871        let eval = |pt: &[u64]| {
1872            let x = pt[0];
1873            let y = pt[1];
1874            let z = pt[2];
1875            let xyz = mul_mod(mul_mod(x, y, p), z, p);
1876            let x2 = pow_mod(x, 2, p);
1877            add_mod(add_mod(xyz, x2, p), z, p)
1878        };
1879        let result = sparse_interpolate(&eval, vs, 5, 4, p, 7).unwrap();
1880        assert_eq!(*result.terms.get(&vec![1, 1, 1]).unwrap_or(&0), 1, "x*y*z");
1881        assert_eq!(*result.terms.get(&vec![2]).unwrap_or(&0), 1, "x^2");
1882        assert_eq!(*result.terms.get(&vec![0, 0, 1]).unwrap_or(&0), 1, "z");
1883    }
1884
1885    #[test]
1886    fn multi_roundtrip_via_multipoly() {
1887        // Build a MultiPoly, reduce mod p, then interpolate and verify agreement.
1888        use crate::poly::multipoly::MultiPoly;
1889        let p = 1009u64;
1890        let pool = ExprPool::new();
1891        let x = pool.symbol("x", Domain::Real);
1892        let y = pool.symbol("y", Domain::Real);
1893
1894        // f = x^3 + 2·x·y - y^2 + 4
1895        let x3 = pool.pow(x, pool.integer(3_i32));
1896        let xy = pool.mul(vec![pool.integer(2_i32), x, y]);
1897        let y2 = pool.mul(vec![pool.integer(-1_i32), pool.pow(y, pool.integer(2_i32))]);
1898        let expr = pool.add(vec![x3, xy, y2, pool.integer(4_i32)]);
1899
1900        let mp = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
1901        let fp_ref = crate::modular::reduce_mod(&mp, p).unwrap();
1902
1903        // Oracle evaluates the MultiPoly at a point over F_p.
1904        let vars_for_interp = vec![x, y];
1905        let eval = |pt: &[u64]| {
1906            let mut acc = 0u64;
1907            for (exp, coeff) in &mp.terms {
1908                let c_mod = {
1909                    let r = coeff.clone() % rug::Integer::from(p);
1910                    let r = if r < 0 { r + rug::Integer::from(p) } else { r };
1911                    r.to_u64().unwrap()
1912                };
1913                let mut term = c_mod;
1914                for (i, &e) in exp.iter().enumerate() {
1915                    let xi = if i < pt.len() { pt[i] } else { 0 };
1916                    term = mul_mod(term, pow_mod(xi, e as u64, p), p);
1917                }
1918                acc = add_mod(acc, term, p);
1919            }
1920            acc
1921        };
1922
1923        let recovered = sparse_interpolate(&eval, vars_for_interp, 6, 5, p, 0).unwrap();
1924
1925        // Compare term by term.
1926        for (exp, &coeff) in &recovered.terms {
1927            let ref_coeff = fp_ref.terms.get(exp).copied().unwrap_or(0);
1928            assert_eq!(coeff, ref_coeff, "mismatch at exp {:?}", exp);
1929        }
1930        // Check no terms were missed.
1931        for (exp, &ref_coeff) in &fp_ref.terms {
1932            let got = recovered.terms.get(exp).copied().unwrap_or(0);
1933            assert_eq!(got, ref_coeff, "missed term at exp {:?}", exp);
1934        }
1935    }
1936
1937    #[test]
1938    fn multi_diag_15term_three_var_smoke() {
1939        // Mirrors the benchmark diagonal structure (sparse_interp_multivar) at a CI-friendly size.
1940        let p = 32749u64;
1941        let n_vars = 3;
1942        let n_terms = n_vars;
1943        let mut terms = Vec::new();
1944        for i in 0..n_terms {
1945            let mut coeff = (((i + 1) as u64) * 7) % p;
1946            if coeff == 0 {
1947                coeff = 1;
1948            }
1949            let mut exp = vec![0u32; n_vars];
1950            exp[i % n_vars] = (i % 3) as u32 + 1;
1951            terms.push((coeff, exp));
1952        }
1953        let eval_fn = make_poly_eval(&terms, p);
1954        let (_, vs) = vars(n_vars);
1955        let mut expected: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
1956        for (c, exp) in &terms {
1957            let mut e = exp.clone();
1958            while e.last() == Some(&0) {
1959                e.pop();
1960            }
1961            let nc = *c % p;
1962            expected
1963                .entry(e)
1964                .and_modify(|v| {
1965                    *v = add_mod(*v, nc, p);
1966                })
1967                .or_insert(nc);
1968        }
1969
1970        let mut successes = 0usize;
1971        for seed in [0_u64, 1, 2, 41] {
1972            let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, seed)
1973                .expect("smoke interpolate should succeed");
1974            let mut ok = result.terms.len() == expected.len();
1975            for (exp, &ec) in &expected {
1976                if result.terms.get(exp).copied().unwrap_or(0) != ec {
1977                    ok = false;
1978                }
1979            }
1980            if ok {
1981                successes += 1;
1982            }
1983        }
1984        assert!(successes >= 3, "expected ≥ 3 successes on diagonal smoke");
1985    }
1986
1987    #[test]
1988    #[ignore]
1989    fn multi_interp_diag_large_stress_slow() {
1990        // `cargo test -p alkahest-core poly::interp --release -- --ignored`
1991        //
1992        // 6-variable workload (benchmark-shaped diagonal polynomial).  Larger `size`
1993        // dimensions are exercised by benchmarks; CI keeps only a lightweight 3-var smoke.
1994        let p = 32749u64;
1995        let n_vars = 6;
1996        let n_terms = 15;
1997        let mut terms = Vec::new();
1998        for i in 0..n_terms {
1999            let mut coeff = (((i + 1) as u64) * 7) % p;
2000            if coeff == 0 {
2001                coeff = 1;
2002            }
2003            let mut exp = vec![0u32; n_vars];
2004            exp[i % n_vars] = (i % 3) as u32 + 1;
2005            terms.push((coeff, exp));
2006        }
2007        let eval_fn = make_poly_eval(&terms, p);
2008        let (_, vs) = vars(n_vars);
2009        let mut expected: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
2010        for (c, exp) in &terms {
2011            let mut e = exp.clone();
2012            while e.last() == Some(&0) {
2013                e.pop();
2014            }
2015            let nc = *c % p;
2016            expected
2017                .entry(e)
2018                .and_modify(|v| {
2019                    *v = add_mod(*v, nc, p);
2020                })
2021                .or_insert(nc);
2022        }
2023
2024        let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, 7)
2025            .expect("stress interpolate should succeed");
2026        assert_eq!(result.terms.len(), expected.len());
2027        for (exp, &ec) in &expected {
2028            assert_eq!(result.terms.get(exp).copied().unwrap_or(0), ec);
2029        }
2030    }
2031
2032    // ---- gcd_sparse_modular tests -------------------------------------------
2033
2034    fn vars_n(n: usize) -> (ExprPool, Vec<crate::kernel::ExprId>) {
2035        let pool = ExprPool::new();
2036        let vs: Vec<_> = (0..n)
2037            .map(|i| pool.symbol(format!("x{i}"), Domain::Real))
2038            .collect();
2039        (pool, vs)
2040    }
2041
2042    fn mp(
2043        expr: crate::kernel::ExprId,
2044        vids: Vec<crate::kernel::ExprId>,
2045        pool: &ExprPool,
2046    ) -> crate::poly::multipoly::MultiPoly {
2047        crate::poly::multipoly::MultiPoly::from_symbolic(expr, vids, pool)
2048            .expect("valid polynomial")
2049    }
2050
2051    #[test]
2052    fn gcd_sparse_univariate_linear_factor() {
2053        // gcd((x-1)(x+1), (x+1)(x-2)) = x+1
2054        let (pool, vs) = vars_n(1);
2055        let x = vs[0];
2056        let neg1 = pool.integer(-1i32);
2057        let neg2 = pool.integer(-2i32);
2058        let _one = pool.integer(1i32);
2059        // f = x^2 - 1 = (x-1)(x+1)
2060        let f = mp(
2061            pool.add(vec![pool.pow(x, pool.integer(2i32)), neg1]),
2062            vec![x],
2063            &pool,
2064        );
2065        // g = x^2 - x - 2 = (x+1)(x-2)
2066        let g = mp(
2067            pool.add(vec![
2068                pool.pow(x, pool.integer(2i32)),
2069                pool.mul(vec![neg1, x]),
2070                neg2,
2071            ]),
2072            vec![x],
2073            &pool,
2074        );
2075        let h = gcd_sparse_modular(&f, &g, 3, 3, 0).expect("gcd should succeed");
2076        // h = x + 1 (primitive, positive leading coeff)
2077        assert_eq!(h.terms.len(), 2, "GCD should have 2 terms: {h:?}");
2078        assert_eq!(
2079            h.terms.get(&vec![1u32]).cloned(),
2080            Some(rug::Integer::from(1)),
2081            "leading coeff of x should be 1"
2082        );
2083        let empty: Vec<u32> = vec![];
2084        assert_eq!(
2085            h.terms.get(&empty).cloned(),
2086            Some(rug::Integer::from(1)),
2087            "constant should be 1"
2088        );
2089    }
2090
2091    #[test]
2092    fn gcd_sparse_univariate_coprime() {
2093        // gcd(x, x+1) = 1
2094        let (pool, vs) = vars_n(1);
2095        let x = vs[0];
2096        let f = mp(x, vec![x], &pool);
2097        let g = mp(pool.add(vec![x, pool.integer(1i32)]), vec![x], &pool);
2098        let h = gcd_sparse_modular(&f, &g, 2, 2, 0).expect("gcd should succeed");
2099        // gcd(x, x+1) = 1 — a constant polynomial with one term {[]: 1}
2100        let empty: Vec<u32> = vec![];
2101        let constant = h.terms.get(&empty).cloned().unwrap_or_default();
2102        assert_eq!(
2103            constant,
2104            rug::Integer::from(1),
2105            "GCD of coprime polys should be 1, got {h:?}"
2106        );
2107    }
2108
2109    #[test]
2110    fn gcd_sparse_bivariate_common_factor() {
2111        // gcd((x+y)(x-y), (x+y)*(x+1)) = x+y
2112        let (pool, vs) = vars_n(2);
2113        let x = vs[0];
2114        let y = vs[1];
2115        let xpy = pool.add(vec![x, y]);
2116        let _xmy = pool.add(vec![x, pool.mul(vec![pool.integer(-1i32), y])]);
2117        let xp1 = pool.add(vec![x, pool.integer(1i32)]);
2118        // f = (x+y)(x-y) = x^2 - y^2
2119        let f = mp(
2120            pool.add(vec![
2121                pool.pow(x, pool.integer(2i32)),
2122                pool.mul(vec![pool.integer(-1i32), pool.pow(y, pool.integer(2i32))]),
2123            ]),
2124            vec![x, y],
2125            &pool,
2126        );
2127        // g = (x+y)(x+1) = x^2 + x + xy + y
2128        let g = mp(pool.mul(vec![xpy, xp1]), vec![x, y], &pool);
2129        let h = gcd_sparse_modular(&f, &g, 3, 2, 0).expect("gcd should succeed");
2130        // h = x + y  (primitive)
2131        assert_eq!(h.terms.len(), 2, "GCD = x+y should have 2 terms, got {h:?}");
2132        // Leading monomial in BTreeMap order is (1, 1) for xy or (1,) for x
2133        // Actually for x+y: x has exp [1], y has exp [0,1]
2134        let coeff_x = h.terms.get(&vec![1u32]).cloned();
2135        let coeff_y = h.terms.get(&vec![0u32, 1u32]).cloned();
2136        assert_eq!(
2137            coeff_x,
2138            Some(rug::Integer::from(1)),
2139            "coeff of x should be 1"
2140        );
2141        assert_eq!(
2142            coeff_y,
2143            Some(rug::Integer::from(1)),
2144            "coeff of y should be 1"
2145        );
2146    }
2147}