Skip to main content

alkahest_cas/poly/
interp.rs

1//! V2-3 — Sparse polynomial interpolation (Ben-Or/Tiwari, Zippel).
2//!
3//! Recovers a sparse multivariate polynomial over `F_p = ℤ/pℤ` from
4//! black-box evaluations using far fewer queries than dense interpolation.
5//!
6//! # Algorithms
7//!
8//! - **Univariate Ben-Or/Tiwari (Prony-style)** — [`sparse_interpolate_univariate`]:
9//!   given that `f ∈ F_p[x]` has at most `T` nonzero terms, recovers `f` from
//!   exactly `2T` evaluations via Berlekamp–Massey, root finding on the
//!   connection polynomial `Λ` (`gcd(Λ, X^p − X)` followed by
//!   Cantor–Zassenhaus-style splitting over `F_p`, with a brute-force scan as a
//!   fallback for tiny cases), discrete logarithms, and a Vandermonde solve.
//!   Cost: `2T` oracle calls.
14//!
15//! - **Multivariate Zippel** — [`sparse_interpolate`]: variable-by-variable
16//!   reduction.  At each variable level:
17//!     1. Evaluate `f(x₁, a₂, …, aₙ)` at random `aᵢ` and run Ben-Or/Tiwari
18//!        to find the `x₁`-exponent skeleton.
19//!     2. Lift all sibling coefficients simultaneously with one Vandermonde solve
20//!        per oracle call (`zippel_helper_multi`) when the stacked vector stays
21//!        small (`O(term_bound²)` budget); otherwise recurse per skeleton term like
22//!        classic Zippel.
23//!
24//! - **Dense fallback** — applied when `degree_bound ≤ term_bound` (dense
25//!   and sparse costs coincide).  Uses Lagrange interpolation at consecutive
26//!   integers.
27//!
28//! # Public API
29//!
30//! ```text
31//! sparse_interpolate_univariate(eval, term_bound, prime) → Vec<(coeff, exp)>
32//! sparse_interpolate(eval, vars, term_bound, degree_bound, prime, seed)
33//!     → MultiPolyFp
34//! ```
35
36use crate::errors::AlkahestError;
37use crate::kernel::ExprId;
38use crate::modular::{is_prime, MultiPolyFp};
39use std::collections::BTreeMap;
40
41// ---------------------------------------------------------------------------
42// Error type
43// ---------------------------------------------------------------------------
44
/// Failure modes reported by the sparse interpolation routines.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseInterpError {
    /// The supplied modulus is composite or not at least 3.
    InvalidPrime(u64),
    /// Ben-Or/Tiwari requires `prime > 2 * term_bound`; the given prime is too small.
    PrimeTooSmall { prime: u64, term_bound: usize },
    /// Fewer roots were found than the connection polynomial promised.
    /// This should not occur for consistent evaluations; it points at
    /// colliding exponents or an inconsistent evaluation oracle.
    RootFindingFailed,
    /// The Vandermonde / linear system has no unique solution.  Retrying with
    /// another seed or a larger prime may help.
    SingularSystem,
}

impl std::fmt::Display for SparseInterpError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidPrime(p) => write!(f, "invalid prime {p}: must be a prime ≥ 3"),
            Self::PrimeTooSmall { prime, term_bound } => {
                let needed = 2 * term_bound;
                write!(
                    f,
                    "prime {prime} is too small for term_bound {term_bound}: need prime > 2·T = {needed}"
                )
            }
            Self::RootFindingFailed => f.write_str(
                "could not find the expected number of roots in F_p; \
                 the prime may be too small or the oracle is inconsistent",
            ),
            Self::SingularSystem => f.write_str(
                "Vandermonde system is singular; try a different seed or a larger prime",
            ),
        }
    }
}

impl std::error::Error for SparseInterpError {}
86
87impl AlkahestError for SparseInterpError {
88    fn code(&self) -> &'static str {
89        match self {
90            SparseInterpError::InvalidPrime(_) => "E-INTERP-001",
91            SparseInterpError::PrimeTooSmall { .. } => "E-INTERP-002",
92            SparseInterpError::RootFindingFailed => "E-INTERP-003",
93            SparseInterpError::SingularSystem => "E-INTERP-004",
94        }
95    }
96
97    fn remediation(&self) -> Option<&'static str> {
98        match self {
99            SparseInterpError::InvalidPrime(_) => {
100                Some("choose a prime p ≥ 3, e.g. 1009, 32749, 1000003")
101            }
102            SparseInterpError::PrimeTooSmall { .. } => {
103                Some("increase the prime so that p > 2 * term_bound")
104            }
105            SparseInterpError::RootFindingFailed => {
106                Some("choose a prime larger than the maximum degree in the polynomial")
107            }
108            SparseInterpError::SingularSystem => {
109                Some("retry with a different seed or use a larger prime")
110            }
111        }
112    }
113}
114
115// ---------------------------------------------------------------------------
116// Minimal PRNG (xorshift64) — no external crate needed
117// ---------------------------------------------------------------------------
118
/// Tiny xorshift64 generator used to draw reproducible evaluation points.
pub struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Builds a generator from `seed`.  A zero seed is swapped for a fixed
    /// non-zero constant so the state is never zero.
    pub fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0xdeadbeef_cafebabe,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state by one xorshift round (13 / 7 / 17) and returns it.
    pub fn step(&mut self) -> u64 {
        let a = self.state ^ (self.state << 13);
        let b = a ^ (a >> 7);
        let c = b ^ (b << 17);
        self.state = c;
        c
    }

    /// Return a value in `[lo, hi)` (plain modulo reduction of one step).
    pub fn next_range(&mut self, lo: u64, hi: u64) -> u64 {
        debug_assert!(hi > lo);
        let raw = self.step();
        lo + raw % (hi - lo)
    }

    /// Return a non-zero value in `[1, p)`, redrawing while the residue is zero.
    pub fn nonzero(&mut self, p: u64) -> u64 {
        loop {
            match self.step() % p {
                0 => continue,
                residue => return residue,
            }
        }
    }
}
156
157// ---------------------------------------------------------------------------
158// Modular arithmetic helpers
159// ---------------------------------------------------------------------------
160
/// `a · b mod p`, computed through a 128-bit intermediate so the product
/// cannot overflow.
#[inline]
fn mul_mod(a: u64, b: u64, p: u64) -> u64 {
    let wide = u128::from(a) * u128::from(b);
    (wide % u128::from(p)) as u64
}
165
/// `a + b mod p` for already-reduced operands (`a, b < p`).
///
/// The sum is formed with overflow detection so that primes above `2^63`
/// work: when the 64-bit addition wraps, the true sum is `sum + 2^64`, and
/// subtracting `p` with wrapping arithmetic yields the correct residue.
#[inline]
fn add_mod(a: u64, b: u64, p: u64) -> u64 {
    let (sum, wrapped) = a.overflowing_add(b);
    if wrapped || sum >= p {
        sum.wrapping_sub(p)
    } else {
        sum
    }
}
175
/// `a − b mod p` for already-reduced operands (`a, b < p`).
///
/// When `a < b` the result is written as `p − (b − a)` rather than
/// `a + p − b`, so the intermediate never exceeds `p` and cannot overflow
/// for primes close to `2^64`.
#[inline]
fn sub_mod(a: u64, b: u64, p: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        p - (b - a)
    }
}
184
185fn pow_mod(mut base: u64, mut exp: u64, p: u64) -> u64 {
186    let mut result = 1u64;
187    base %= p;
188    while exp > 0 {
189        if exp & 1 == 1 {
190            result = mul_mod(result, base, p);
191        }
192        base = mul_mod(base, base, p);
193        exp >>= 1;
194    }
195    result
196}
197
/// Modular inverse of `a` modulo `p` via the extended Euclidean algorithm.
///
/// Returns the `x` in `[0, p)` with `a·x ≡ 1 (mod p)`, provided
/// `gcd(a, p) = 1`.
///
/// # Panics
///
/// In debug builds, panics if `a == 0` or if `gcd(a, p) ≠ 1`.  In release
/// builds these checks are compiled out and the returned value is
/// meaningless when no inverse exists.
fn mod_inv(a: u64, p: u64) -> u64 {
    debug_assert!(a != 0, "mod_inv: a must be non-zero");
    let modulus = p as i128;

    // Invariant: old_r ≡ old_s·a and r ≡ s·a (mod p).
    let mut old_r = a as i128;
    let mut r = modulus;
    let mut old_s: i128 = 1;
    let mut s: i128 = 0;
    while r != 0 {
        let q = old_r / r;

        let next_r = old_r - q * r;
        old_r = r;
        r = next_r;

        let next_s = old_s - q * s;
        old_s = s;
        s = next_s;
    }

    // old_r now holds gcd(a, p); an inverse exists only when it is 1.
    debug_assert_eq!(old_r, 1, "mod_inv: a and p must be coprime");

    // The Bézout coefficient may be negative; map it into [0, p).
    old_s.rem_euclid(modulus) as u64
}
216
217// ---------------------------------------------------------------------------
218// Polynomial evaluation over F_p
219// ---------------------------------------------------------------------------
220
221/// Evaluate `poly[0] + poly[1]*x + ... + poly[d]*x^d` at `x` modulo `p`.
222fn poly_eval(poly: &[u64], x: u64, p: u64) -> u64 {
223    let mut acc = 0u64;
224    let mut pw = 1u64;
225    for &c in poly {
226        acc = add_mod(acc, mul_mod(c, pw, p), p);
227        pw = mul_mod(pw, x, p);
228    }
229    acc
230}
231
232// ---------------------------------------------------------------------------
233// Primitive root of F_p
234// ---------------------------------------------------------------------------
235
236/// Find the smallest primitive root (generator) of `F_p*`.
237///
238/// A primitive root `g` satisfies `g^{(p-1)/q} ≢ 1 (mod p)` for every
239/// prime factor `q` of `p-1`.
240pub fn primitive_root(p: u64) -> u64 {
241    debug_assert!(is_prime(p), "primitive_root: p must be prime");
242    if p == 2 {
243        return 1;
244    }
245    if p == 3 {
246        return 2;
247    }
248    let factors = prime_factors(p - 1);
249    'outer: for g in 2..p {
250        for &q in &factors {
251            if pow_mod(g, (p - 1) / q, p) == 1 {
252                continue 'outer;
253            }
254        }
255        return g;
256    }
257    panic!("primitive_root: no root found for prime {p}");
258}
259
/// Distinct prime factors of `n` in increasing order (trial division).
///
/// Returns an empty vector for `n ≤ 1`.  The loop bound is written as
/// `d <= n / d`, the overflow-free form of `d·d ≤ n`, so values of `n` close
/// to `u64::MAX` cannot wrap the square.  After 2, only odd candidates are
/// tried.
fn prime_factors(mut n: u64) -> Vec<u64> {
    let mut factors = Vec::new();
    let mut d = 2u64;
    while d <= n / d {
        if n % d == 0 {
            factors.push(d);
            // Strip every power of `d` so later hits are new primes.
            while n % d == 0 {
                n /= d;
            }
        }
        d += if d == 2 { 1 } else { 2 };
    }
    // Whatever remains above 1 is a single prime larger than √(original n).
    if n > 1 {
        factors.push(n);
    }
    factors
}
278
279// ---------------------------------------------------------------------------
280// Berlekamp–Massey over F_p
281// ---------------------------------------------------------------------------
282
/// Berlekamp–Massey algorithm over `F_p`.
///
/// Given a sequence `s[0], …, s[N-1]`, returns the minimal LFSR connection
/// polynomial `Λ = [1, λ₁, …, λ_L]` (index = degree) such that
///
/// ```text
/// s[n] + λ₁·s[n-1] + … + λ_L·s[n-L] = 0   for all n ≥ L.
/// ```
///
/// The caller must supply `N ≥ 2L` for the result to be unique.
///
/// Sequence entries are expected to be reduced modulo `p`.  An all-zero (or
/// empty) sequence yields `[1]`.
///
/// NOTE(review): the working vector `c` is returned as-is, without trimming;
/// callers that derive `L` from `len() - 1` should confirm this always
/// matches the tracked LFSR length `l`.
fn berlekamp_massey(seq: &[u64], p: u64) -> Vec<u64> {
    let n = seq.len();
    // Current LFSR length L.
    let mut l = 0usize;
    // Current connection polynomial C (ascending coefficients, C[0] = 1).
    let mut c: Vec<u64> = vec![1];
    // Copy of C saved at the last length change, and its discrepancy.
    let mut b: Vec<u64> = vec![1];
    let mut b_disc: u64 = 1;
    // Number of steps since `b` was saved (the shift applied to B).
    let mut x: usize = 1;

    for n_idx in 0..n {
        // Discrepancy d = s[n] + Σ_{i=1}^{L} c[i]·s[n-i]
        let mut d = seq[n_idx];
        let bound = l.min(c.len().saturating_sub(1));
        for i in 1..=bound {
            d = add_mod(d, mul_mod(c[i], seq[n_idx - i], p), p);
        }

        // C already predicts s[n]; only the shift grows.
        if d == 0 {
            x += 1;
            continue;
        }

        // Keep the pre-update C in case the length changes below.
        let t = c.clone();
        let factor = mul_mod(d, mod_inv(b_disc, p), p);

        // C ← C − factor·z^x·B
        let needed = x + b.len();
        if c.len() < needed {
            c.resize(needed, 0);
        }
        for j in 0..b.len() {
            let sub = mul_mod(factor, b[j], p);
            c[x + j] = sub_mod(c[x + j], sub, p);
        }

        if 2 * l <= n_idx {
            // Length change: the LFSR grows and the saved state is refreshed.
            l = n_idx + 1 - l;
            b = t;
            b_disc = d;
            x = 1;
        } else {
            x += 1;
        }
    }

    c
}
339
340// ---------------------------------------------------------------------------
341// Dense polynomials mod p (Cantor–Zassenhaus / probabilistic splitting)
342// ---------------------------------------------------------------------------
343
/// Drops trailing zero coefficients but never shrinks a non-empty vector
/// below one entry (the zero polynomial stays `[0]`; an empty vector stays
/// empty).
fn poly_trim(mut a: Vec<u64>) -> Vec<u64> {
    let keep = a
        .iter()
        .rposition(|&coeff| coeff != 0)
        .map_or(1, |last_nonzero| last_nonzero + 1);
    a.truncate(keep);
    a
}
350
/// Degree of an ascending-order coefficient slice, or `-1` for the zero
/// polynomial (including an empty slice).
///
/// The degree is the index of the last non-zero coefficient, found by a
/// reverse scan.  No copy of the polynomial is made, which matters because
/// this is called inside the loop conditions of division and gcd.
#[inline]
fn poly_deg(poly: &[u64]) -> i32 {
    poly.iter()
        .rposition(|&coeff| coeff != 0)
        .map_or(-1, |top| top as i32)
}
359
360/// `a + b` in ascending order (may over-allocate briefly).
361fn poly_add(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
362    let n = a.len().max(b.len());
363    let mut out = vec![0u64; n];
364    for i in 0..n {
365        let x = if i < a.len() { a[i] } else { 0 };
366        let y = if i < b.len() { b[i] } else { 0 };
367        out[i] = add_mod(x, y, p);
368    }
369    poly_trim(out)
370}
371
372fn poly_sub_(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
373    let n = a.len().max(b.len());
374    let mut out = vec![0u64; n];
375    for i in 0..n {
376        let x = if i < a.len() { a[i] } else { 0 };
377        let y = if i < b.len() { b[i] } else { 0 };
378        out[i] = sub_mod(x, y, p);
379    }
380    poly_trim(out)
381}
382
383fn poly_mul(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
384    if a.is_empty() || b.is_empty() || (a.len() == 1 && a[0] == 0) || (b.len() == 1 && b[0] == 0) {
385        return vec![0];
386    }
387    let da = poly_deg(a);
388    let db = poly_deg(b);
389    if da < 0 || db < 0 {
390        return vec![0];
391    }
392    let mut out = vec![0u64; (da + db + 1) as usize];
393    for i in 0..=da as usize {
394        for j in 0..=db as usize {
395            out[i + j] = add_mod(out[i + j], mul_mod(a[i], b[j], p), p);
396        }
397    }
398    poly_trim(out)
399}
400
/// Euclidean division over `F_p`; returns `(q, r)` with `a = q·b + r`, `deg r < deg b`.
///
/// Both operands are trimmed first.  Returns `None` when the divisor is the
/// zero polynomial.  When `deg a < deg b` the quotient is `[0]` and the
/// remainder is the (trimmed) dividend.
fn poly_divmod(dividend: &[u64], divisor: &[u64], p: u64) -> Option<(Vec<u64>, Vec<u64>)> {
    // `a` is the running remainder; it is reduced in place below.
    let mut a = poly_trim(dividend.to_vec());
    let b = poly_trim(divisor.to_vec());
    if poly_deg(&b) < 0 {
        return None;
    }
    let db = b.len() - 1;
    // Inverse of the divisor's leading coefficient, used to scale each step.
    let lb = *b.last().unwrap();
    let inv_lb = mod_inv(lb, p);

    let deg_a = poly_deg(&a);
    if deg_a < db as i32 {
        return Some((vec![0], a));
    }

    let q_len = (deg_a - db as i32 + 1) as usize;
    let mut quot = vec![0u64; q_len];

    // Classic long division: cancel the leading term of `a` each round.
    while poly_deg(&a) >= db as i32 {
        let da = poly_deg(&a) as usize;
        let shift = da - db;
        // `a` is kept trimmed, so its last entry is the leading coefficient.
        let scale = mul_mod(*a.last().unwrap(), inv_lb, p);
        quot[shift] = add_mod(quot[shift], scale, p);
        for j in 0..b.len() {
            a[j + shift] = sub_mod(a[j + shift], mul_mod(scale, b[j], p), p);
        }
        a = poly_trim(a);
    }

    Some((poly_trim(quot), a))
}
433
434fn polygcd(a_: &[u64], b_: &[u64], p: u64) -> Vec<u64> {
435    let mut a = poly_trim(a_.to_vec());
436    let mut b = poly_trim(b_.to_vec());
437    while poly_deg(&b) >= 0 {
438        let (_, r) = match poly_divmod(&a, &b, p) {
439            Some(x) => x,
440            None => break,
441        };
442        a = b;
443        b = r;
444    }
445    if poly_deg(&a) < 0 {
446        return vec![0];
447    }
448    poly_make_monic(&a, p)
449}
450
451fn poly_derivative(f: &[u64], p: u64) -> Vec<u64> {
452    let f = poly_trim(f.to_vec());
453    if f.len() <= 1 {
454        return vec![0];
455    }
456    let mut out = Vec::with_capacity(f.len() - 1);
457    for (k, &coeff) in f.iter().enumerate().skip(1) {
458        let d = mul_mod(coeff, k as u64, p);
459        out.push(d);
460    }
461    poly_trim(out)
462}
463
464fn poly_make_monic(f: &[u64], p: u64) -> Vec<u64> {
465    let f = poly_trim(f.to_vec());
466    if f.is_empty() {
467        return f;
468    }
469    let lc = *f.last().unwrap();
470    if lc == 0 {
471        return f;
472    }
473    let inv = mod_inv(lc, p);
474    f.iter().map(|&c| mul_mod(c, inv, p)).collect()
475}
476
477/// Remove repeated roots until `gcd(f, f′) = 1`.
478fn poly_squarefree(mut f: Vec<u64>, p: u64) -> Vec<u64> {
479    f = poly_make_monic(&f, p);
480    loop {
481        let dp = poly_derivative(&f, p);
482        let g = polygcd(&f, &dp, p);
483        let dg = poly_deg(&g);
484        if dg <= 0 {
485            break;
486        }
487        let (_, r) = poly_divmod(&f, &g, p).unwrap();
488        f = poly_make_monic(&r, p);
489    }
490    f
491}
492
493fn poly_mul_mod(a: &[u64], b: &[u64], modulo: &[u64], p: u64) -> Vec<u64> {
494    let prod = poly_mul(a, b, p);
495    poly_divmod(&prod, modulo, p)
496        .map(|(_, r)| r)
497        .unwrap_or(vec![0])
498}
499
500/// `base^exp (mod m)` in `F_p[X]`.
501fn poly_pow_mod(base: &[u64], mut exp: u64, m: &[u64], p: u64) -> Vec<u64> {
502    let m = poly_trim(m.to_vec());
503    if poly_deg(&m) < 0 {
504        return vec![0];
505    }
506    let mut acc = vec![1u64];
507    let mut b = poly_divmod(&poly_trim(base.to_vec()), &m, p)
508        .map(|(_, r)| r)
509        .unwrap_or(vec![0]);
510    while exp > 0 {
511        if exp & 1 != 0 {
512            acc = poly_mul_mod(&acc, &b, &m, p);
513        }
514        b = poly_mul_mod(&b, &b, &m, p);
515        exp >>= 1;
516    }
517    acc
518}
519
/// Random dense polynomial of degree `< deg(f)` (for Cantor–Zassenhaus splitting).
///
/// Draws `max_deg` coefficients from `[0, p)` (ascending order) and trims
/// the result.  If every drawn coefficient is zero, one randomly chosen
/// position is overwritten with a non-zero value so the result is never the
/// zero polynomial.  `max_deg == 0` returns `[0]` without consuming any
/// randomness.
fn poly_random_below(max_deg: usize, p: u64, rng: &mut Xorshift64) -> Vec<u64> {
    if max_deg == 0 {
        return vec![0];
    }
    let mut c: Vec<u64> = (0..max_deg).map(|_| rng.next_range(0, p)).collect();
    // Guard against the all-zero draw.
    if c.iter().all(|&x| x == 0) {
        c[rng.next_range(0, max_deg as u64) as usize] = rng.nonzero(p);
    }
    poly_trim(c)
}
531
532/// Find all roots of `poly` in `F_p` using `gcd(f, X^p−X)` + probabilistic split.
533/// Assumes `p` is an odd prime and `deg(f) < p` (always true for Ben-Or/Tiwari Λ).
534fn find_roots(poly: &[u64], p: u64, rng: &mut Xorshift64) -> Result<Vec<u64>, SparseInterpError> {
535    let mut f = poly_trim(poly.to_vec());
536    if poly_deg(&f) < 0 {
537        return Ok(vec![]);
538    }
539    if p == 2 {
540        let mut r = Vec::new();
541        for v in 0..p {
542            if poly_eval(&f, v, p) == 0 {
543                r.push(v);
544            }
545        }
546        return Ok(r);
547    }
548    f = poly_squarefree(f, p);
549    if poly_deg(&f) < 0 {
550        return Ok(vec![]);
551    }
552    if poly_deg(&f) == 0 {
553        return Ok(vec![]);
554    }
555
556    // Split off the `F_p`-rational part: gcd(f, X^p − X).
557    let xp = poly_pow_mod(&[0, 1], p, &f, p);
558    let diff = poly_sub_(&xp, &[0, 1], p);
559    let mut h = polygcd(&f, &diff, p);
560    if poly_deg(&h) < 0 {
561        h = f;
562    }
563
564    let mut roots = Vec::new();
565    split_find_roots(&h, p, rng, &mut roots)?;
566    roots.sort_unstable();
567    roots.dedup();
568    Ok(roots)
569}
570
/// Recursively splits `f` into linear factors over `F_p` and appends the
/// roots found to `roots`.
///
/// `f` is made monic first.  Constants contribute nothing; a linear factor
/// `X + f[0]` contributes the root `−f[0]`.  Higher degrees are split with
/// random trial polynomials; if all tries fail, a brute-force scan over the
/// field is used when `deg(f) · p` is small enough, otherwise
/// `RootFindingFailed` is returned.
fn split_find_roots(
    f: &[u64],
    p: u64,
    rng: &mut Xorshift64,
    roots: &mut Vec<u64>,
) -> Result<(), SparseInterpError> {
    let f = poly_make_monic(f, p);
    let d = poly_deg(&f);
    if d < 0 {
        return Ok(());
    }
    if d == 0 {
        return Ok(());
    }
    if d == 1 {
        // Monic linear factor X + f[0]: the root is −f[0] mod p.
        let a0 = sub_mod(0, f[0], p);
        roots.push(a0);
        return Ok(());
    }

    // Probabilistic split (Cantor–Zassenhaus / Rabin): for odd `p`, each nontrivial
    // gcd( U^{(p−1)/2} ± 1 , f ) succeeds with probability ~1/2 per try.
    const MAX_TRIES: usize = 256;
    for _ in 0..MAX_TRIES {
        let u = poly_random_below(d as usize, p, rng);
        let exp = (p - 1) / 2;
        let up = poly_pow_mod(&u, exp, &f, p);
        // Try both U^e − 1 and U^e + 1 as the splitting polynomial.
        for g in [poly_sub_(&up, &[1], p), poly_add(&up, &[1], p)] {
            let d1 = polygcd(&f, &g, p);
            let d1deg = poly_deg(&d1);
            // Only a proper, non-constant factor is useful.
            if d1deg > 0 && d1deg < d {
                let (cofactor, rem) = poly_divmod(&f, &d1, p).unwrap();
                // `d1` must be a genuine divisor; use the **quotient** cofactor.
                if poly_deg(&rem) >= 0 {
                    continue;
                }
                split_find_roots(&d1, p, rng, roots)?;
                split_find_roots(&poly_make_monic(&cofactor, p), p, rng, roots)?;
                return Ok(());
            }
        }
    }
    // Λ rarely has degree > ~20; brute force is negligible vs failing outright.
    if (d as u128) * (p as u128) <= 2_500_000 {
        for v in 0..p {
            if poly_eval(&f, v, p) == 0 {
                roots.push(v);
            }
        }
        return Ok(());
    }
    Err(SparseInterpError::RootFindingFailed)
}
624
625// ---------------------------------------------------------------------------
626// Baby-step giant-step discrete logarithm
627// ---------------------------------------------------------------------------
628
629/// Compute `e` such that `g^e ≡ target (mod p)`, or `None` if no such `e`
630/// exists in `{0, …, p-2}`.
631///
632/// Uses the Baby-step / Giant-step algorithm in `O(√p)` time and space.
633pub fn bsgs_dlog(g: u64, target: u64, p: u64) -> Option<u64> {
634    if target == 0 {
635        return None; // g is never 0 in F_p*
636    }
637    let order = p - 1; // order of F_p* (g is a generator)
638    let m = (order as f64).sqrt().ceil() as u64 + 1;
639
640    // Baby steps: table[g^j] = j  for j = 0 … m-1
641    let mut table = std::collections::HashMap::with_capacity(m as usize);
642    let mut gj = 1u64;
643    for j in 0..m {
644        table.insert(gj, j);
645        gj = mul_mod(gj, g, p);
646    }
647
648    // Giant steps: find i such that target · (g^{-m})^i is in table
649    let gm = pow_mod(g, m, p);
650    let gm_inv = mod_inv(gm, p);
651    let mut y = target;
652    for i in 0..m {
653        if let Some(&j) = table.get(&y) {
654            let e = i * m + j;
655            let e_mod = e % order;
656            // Verify
657            if pow_mod(g, e_mod, p) == target {
658                return Some(e_mod);
659            }
660        }
661        y = mul_mod(y, gm_inv, p);
662    }
663    None
664}
665
666// ---------------------------------------------------------------------------
667// Vandermonde solve (generalised)
668// ---------------------------------------------------------------------------
669
670/// Solve the generalised Vandermonde system:
671///
672/// ```text
673/// Σ_j  c[j] · pts[i]^{exps[j]}  =  vals[i]   for i = 0, …, t-1
674/// ```
675///
676/// Returns `Some(c)` if the system is non-singular, or `None` otherwise.
677fn vandermonde_solve(pts: &[u64], exps: &[u32], vals: &[u64], p: u64) -> Option<Vec<u64>> {
678    let t = pts.len();
679    debug_assert_eq!(exps.len(), t);
680    debug_assert_eq!(vals.len(), t);
681
682    // Build the t×t matrix A where A[i][j] = pts[i]^exps[j]
683    let mut mat: Vec<Vec<u64>> = (0..t)
684        .map(|i| (0..t).map(|j| pow_mod(pts[i], exps[j] as u64, p)).collect())
685        .collect();
686    let mut rhs: Vec<u64> = vals.to_vec();
687
688    gaussian_elim(&mut mat, &mut rhs, p)
689}
690
691/// Gaussian elimination with partial pivoting over `F_p`.
692/// Modifies `mat` and `rhs` in place; returns the solution or `None` if
693/// the system is singular.
694fn gaussian_elim(mat: &mut [Vec<u64>], rhs: &mut [u64], p: u64) -> Option<Vec<u64>> {
695    let n = mat.len();
696    for col in 0..n {
697        // Find pivot (first non-zero entry in column col, at or below row col)
698        let pivot_row = (col..n).find(|&r| mat[r][col] != 0)?;
699        mat.swap(col, pivot_row);
700        rhs.swap(col, pivot_row);
701
702        let inv = mod_inv(mat[col][col], p);
703        // Scale pivot row
704        for entry in &mut mat[col][col..] {
705            *entry = mul_mod(*entry, inv, p);
706        }
707        rhs[col] = mul_mod(rhs[col], inv, p);
708
709        // Eliminate column in all other rows
710        for row in 0..n {
711            if row == col {
712                continue;
713            }
714            let factor = mat[row][col];
715            if factor == 0 {
716                continue;
717            }
718            // Gather the pivot row values to avoid borrow conflict.
719            let pivot_row_vals: Vec<u64> = mat[col][col..].to_vec();
720            for (j, &pv) in pivot_row_vals.iter().enumerate() {
721                let sub = mul_mod(factor, pv, p);
722                mat[row][col + j] = sub_mod(mat[row][col + j], sub, p);
723            }
724            let sub = mul_mod(factor, rhs[col], p);
725            rhs[row] = sub_mod(rhs[row], sub, p);
726        }
727    }
728    Some(rhs.to_owned())
729}
730
731// ---------------------------------------------------------------------------
732// Univariate Ben-Or/Tiwari (internal)
733// ---------------------------------------------------------------------------
734
735/// Internal Ben-Or/Tiwari.  Evaluates at `g^0, …, g^{2T-1}` and runs the
736/// full Prony pipeline.  Returns `(coeff, exponent)` pairs, or an error.
737fn bt_univariate(
738    eval: &dyn Fn(u64) -> u64,
739    term_bound: usize,
740    prime: u64,
741    g: u64, // primitive root of F_p
742    rng: &mut Xorshift64,
743) -> Result<Vec<(u64, u32)>, SparseInterpError> {
744    if term_bound == 0 {
745        return Ok(vec![]);
746    }
747    let two_t = 2 * term_bound;
748
749    // --- Step 1: Evaluate at g^0, g^1, …, g^{2T-1} ---
750    let mut seq = Vec::with_capacity(two_t);
751    let mut gj = 1u64; // g^j
752    for _ in 0..two_t {
753        seq.push(eval(gj));
754        gj = mul_mod(gj, g, prime);
755    }
756
757    // --- Step 2: Berlekamp–Massey to find connection polynomial Λ ---
758    let lambda = berlekamp_massey(&seq, prime);
759    let ell = lambda.len() - 1; // LFSR length L ≤ T
760
761    if ell == 0 {
762        // Only the trivial polynomial: the sequence is identically zero.
763        return Ok(vec![]);
764    }
765
766    // --- Step 3: Find roots ρ of Λ in `F_p` (Cantor–Zassenhaus-style split) ---
767    let rho_roots = find_roots(&lambda, prime, rng)?;
768
769    if rho_roots.len() < ell {
770        return Err(SparseInterpError::RootFindingFailed);
771    }
772    // Use only the first `ell` roots (should be exactly ell distinct ones).
773    let rho: &[u64] = &rho_roots[..ell];
774
775    // --- Step 4: Map roots → frequencies → exponents ---
776    // Λ has roots ρ_j = g^{-e_j} (the inverses of the frequencies r_j).
777    // r_j = ρ_j^{-1} = g^{e_j}.
778    let mut exps: Vec<u32> = Vec::with_capacity(ell);
779    for &ro in rho {
780        if ro == 0 {
781            return Err(SparseInterpError::RootFindingFailed);
782        }
783        let r = mod_inv(ro, prime); // r = g^{e_j}
784        let e = bsgs_dlog(g, r, prime).ok_or(SparseInterpError::RootFindingFailed)?;
785        exps.push(e as u32);
786    }
787
788    // --- Step 5: Solve Vandermonde for coefficients ---
789    // The evaluation sequence satisfies: s[n] = Σ_j c_j · (g^{e_j})^n.
790    // As a matrix system with A[i][j] = pts[i]^{exps[j]}:
791    //   pts[i] = g^i  (i-th evaluation point)
792    //   exps[j] = e_j  (j-th monomial exponent)
793    //   vals[i] = s[i]  (i-th sequence value)
794    // This is the correct generalised-Vandermonde formulation.
795    let pts_for_vdm: Vec<u64> = (0..ell).map(|i| pow_mod(g, i as u64, prime)).collect();
796    let vals_for_vdm: Vec<u64> = seq[..ell].to_vec();
797    let coeffs = vandermonde_solve(&pts_for_vdm, &exps, &vals_for_vdm, prime)
798        .ok_or(SparseInterpError::SingularSystem)?;
799
800    Ok(coeffs
801        .into_iter()
802        .zip(exps)
803        .filter(|(c, _)| *c != 0)
804        .collect())
805}
806
807// ---------------------------------------------------------------------------
808// Dense univariate interpolation (fallback)
809// ---------------------------------------------------------------------------
810
811/// Dense Lagrange interpolation over `F_p`.
812///
813/// Given evaluations `f(1), f(2), …, f(D+1)` (at the first `D+1` non-zero
814/// field elements), returns the polynomial coefficients in ascending degree
815/// order.
816fn dense_interpolate(vals: &[u64], prime: u64) -> Vec<(u64, u32)> {
817    let n = vals.len();
818    // Evaluation points: 1, 2, …, n
819    let pts: Vec<u64> = (1..=n as u64).collect();
820    // Build Vandermonde system: pts[i]^j * c[j] = vals[i]
821    let mut mat: Vec<Vec<u64>> = (0..n)
822        .map(|i| (0..n).map(|j| pow_mod(pts[i], j as u64, prime)).collect())
823        .collect();
824    let mut rhs = vals.to_vec();
825    match gaussian_elim(&mut mat, &mut rhs, prime) {
826        Some(coeffs) => coeffs
827            .into_iter()
828            .enumerate()
829            .filter(|(_, c)| *c != 0)
830            .map(|(j, c)| (c, j as u32))
831            .collect(),
832        None => vec![], // degenerate; return empty
833    }
834}
835
836// ---------------------------------------------------------------------------
837// Multivariate Zippel (recursive)
838// ---------------------------------------------------------------------------
839
840/// One Vandermonde lift applied coherently across `dim` sibling components.
841fn lifted_eval_union(
842    x_pts: &[u64],
843    joint_exps: &[u32],
844    eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
845    prime: u64,
846    dim: usize,
847    m_count: usize,
848    x_suffix: &[u64],
849) -> Vec<u64> {
850    let mut new_vec = Vec::with_capacity(dim * m_count);
851    for j in 0..dim {
852        let f_vals: Vec<u64> = x_pts
853            .iter()
854            .map(|&xk| {
855                let mut args = vec![xk];
856                args.extend_from_slice(x_suffix);
857                eval_multi(&args).get(j).copied().unwrap_or(0)
858            })
859            .collect();
860        let coeffs = vandermonde_solve(x_pts, joint_exps, &f_vals, prime)
861            .unwrap_or_else(|| vec![0u64; m_count]);
862        debug_assert_eq!(coeffs.len(), m_count);
863        new_vec.extend(coeffs);
864    }
865    new_vec
866}
867
868/// Batched lifting: recover `dim` sibling coefficient polynomials simultaneously.
869/// Each map entry is `sparse exponents → coeff` in the remaining variables only.
870#[allow(clippy::too_many_arguments)] // recursion driver: shared oracle plus dimension bounds
871fn zippel_helper_multi(
872    eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
873    n_vars: usize,
874    dim: usize,
875    term_bound: usize,
876    degree_bound: u32,
877    prime: u64,
878    g: u64,
879    rng: &mut Xorshift64,
880) -> Result<Vec<BTreeMap<Vec<u32>, u64>>, SparseInterpError> {
881    if dim == 0 {
882        return Ok(vec![]);
883    }
884
885    if n_vars == 0 {
886        let v = eval_multi(&[]);
887        let mut out = Vec::with_capacity(dim);
888        for j in 0..dim {
889            let mut m = BTreeMap::new();
890            let c = *v.get(j).unwrap_or(&0);
891            if c != 0 {
892                m.insert(vec![], c);
893            }
894            out.push(m);
895        }
896        return Ok(out);
897    }
898
899    if n_vars == 1 {
900        let mut out = Vec::with_capacity(dim);
901        for j in 0..dim {
902            let terms = if degree_bound <= term_bound as u32 {
903                let d = degree_bound as usize + 1;
904                let vals: Vec<u64> = (1..=d as u64)
905                    .map(|x| eval_multi(&[x % prime]).get(j).copied().unwrap_or(0))
906                    .collect();
907                dense_interpolate(&vals, prime)
908            } else {
909                bt_univariate(
910                    &|t| eval_multi(&[t]).get(j).copied().unwrap_or(0),
911                    term_bound,
912                    prime,
913                    g,
914                    rng,
915                )?
916            };
917            let mut m = BTreeMap::new();
918            for (c, e) in terms {
919                if c != 0 {
920                    m.insert(vec![e], c);
921                }
922            }
923            out.push(m);
924        }
925        return Ok(out);
926    }
927
928    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();
929
930    let mut per_comp_skeletons: Vec<Vec<(u64, u32)>> = Vec::with_capacity(dim);
931    for j in 0..dim {
932        let sk = {
933            let f1 = |t: u64| -> u64 {
934                let mut args = vec![t];
935                args.extend_from_slice(&a_rest);
936                eval_multi(&args).get(j).copied().unwrap_or(0)
937            };
938            if degree_bound <= term_bound as u32 {
939                let d = degree_bound as usize + 1;
940                let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
941                dense_interpolate(&v, prime)
942            } else {
943                bt_univariate(&f1, term_bound, prime, g, rng)?
944            }
945        };
946        per_comp_skeletons.push(sk);
947    }
948
949    let mut joint_exps: Vec<u32> = Vec::new();
950    for sk in &per_comp_skeletons {
951        for &(_, e) in sk {
952            joint_exps.push(e);
953        }
954    }
955    joint_exps.sort_unstable();
956    joint_exps.dedup();
957    let m_count = joint_exps.len();
958
959    let empty_maps = || (0..dim).map(|_| BTreeMap::new()).collect::<Vec<_>>();
960
961    if m_count == 0 {
962        return Ok(empty_maps());
963    }
964
965    // Fully batched recursion uses vector dimension `dim · |joint|`; union can be
966    // large across many siblings (`dim ≈ term_bound`).  Above this budget fall back to
967    // the classic nested `zippel_helper` lifts — oracle depth improves over the legacy
968    // implementation (shared lift at outer peel) while keeping tests bounded.
969    // Allow large batched lifts for realistic `term_bound` (~20–50); tighter caps
970    // force the scalar fallback whose constant factor dominates on large `n_vars`.
971    let vec_budget = term_bound.saturating_mul(512).clamp(8192usize, 131072usize);
972    if dim.saturating_mul(m_count) > vec_budget {
973        let mut stacked: Vec<BTreeMap<Vec<u32>, u64>> = Vec::with_capacity(dim);
974        for (j, sk) in per_comp_skeletons.iter().enumerate().take(dim) {
975            if sk.is_empty() {
976                stacked.push(BTreeMap::new());
977                continue;
978            }
979            let exps_j: Vec<u32> = sk.iter().map(|(_, e)| *e).collect();
980            let tj = exps_j.len();
981            let mut pts: Vec<u64> = Vec::with_capacity(tj);
982            {
983                let mut used = std::collections::HashSet::new();
984                while pts.len() < tj {
985                    let v = rng.nonzero(prime);
986                    if used.insert(v) {
987                        pts.push(v);
988                    }
989                }
990            }
991            let mut comp_map = BTreeMap::new();
992            for k in 0..tj {
993                let e_cur = exps_j[k];
994                let sub_terms = zippel_helper(
995                    &|x_rest: &[u64]| -> u64 {
996                        let f_vals: Vec<u64> = pts
997                            .iter()
998                            .map(|&xk| {
999                                let mut args = vec![xk];
1000                                args.extend_from_slice(x_rest);
1001                                eval_multi(&args).get(j).copied().unwrap_or(0)
1002                            })
1003                            .collect();
1004                        vandermonde_solve(&pts, &exps_j, &f_vals, prime)
1005                            .map(|v| v[k])
1006                            .unwrap_or(0)
1007                    },
1008                    n_vars - 1,
1009                    term_bound,
1010                    degree_bound,
1011                    prime,
1012                    g,
1013                    rng,
1014                )?;
1015                for (mut sub_exp, coeff) in sub_terms {
1016                    if coeff != 0 {
1017                        let mut full = vec![e_cur];
1018                        full.append(&mut sub_exp);
1019                        comp_map.insert(full, coeff);
1020                    }
1021                }
1022            }
1023            stacked.push(comp_map);
1024        }
1025        return Ok(stacked);
1026    }
1027
1028    let mut x_pts: Vec<u64> = Vec::with_capacity(m_count);
1029    {
1030        let mut used = std::collections::HashSet::new();
1031        while x_pts.len() < m_count {
1032            let v = rng.nonzero(prime);
1033            if used.insert(v) {
1034                x_pts.push(v);
1035            }
1036        }
1037    }
1038
1039    let dim_next = dim * m_count;
1040    let sub = zippel_helper_multi(
1041        &|x_suffix: &[u64]| {
1042            lifted_eval_union(
1043                &x_pts,
1044                &joint_exps,
1045                eval_multi,
1046                prime,
1047                dim,
1048                m_count,
1049                x_suffix,
1050            )
1051        },
1052        n_vars - 1,
1053        dim_next,
1054        term_bound,
1055        degree_bound,
1056        prime,
1057        g,
1058        rng,
1059    )?;
1060
1061    let mut result: Vec<BTreeMap<Vec<u32>, u64>> = empty_maps();
1062    for (j, res_j) in result.iter_mut().enumerate().take(dim) {
1063        for (r, &e1) in joint_exps.iter().enumerate().take(m_count) {
1064            let slot = j * m_count + r;
1065            for (sub_exp, coeff) in &sub[slot] {
1066                if *coeff != 0 {
1067                    let mut full_exp = vec![e1];
1068                    full_exp.extend_from_slice(sub_exp);
1069                    res_j.insert(full_exp, *coeff);
1070                }
1071            }
1072        }
1073    }
1074
1075    Ok(result)
1076}
1077
1078/// Recursive Zippel helper.  Returns a map from exponent vectors to
1079/// coefficients in `F_p`.
1080fn zippel_helper(
1081    eval: &dyn Fn(&[u64]) -> u64,
1082    n_vars: usize,
1083    term_bound: usize,
1084    degree_bound: u32,
1085    prime: u64,
1086    g: u64,
1087    rng: &mut Xorshift64,
1088) -> Result<BTreeMap<Vec<u32>, u64>, SparseInterpError> {
1089    // --- Base case: constant polynomial ---
1090    if n_vars == 0 {
1091        let c = eval(&[]);
1092        let mut m = BTreeMap::new();
1093        if c != 0 {
1094            m.insert(vec![], c);
1095        }
1096        return Ok(m);
1097    }
1098
1099    // --- Base case: univariate ---
1100    if n_vars == 1 {
1101        // Use dense fallback if degree_bound is small (avoids BSGS overhead).
1102        let terms = if degree_bound <= term_bound as u32 {
1103            // Dense path: evaluate at degree_bound+1 points.
1104            let d = degree_bound as usize + 1;
1105            let v: Vec<u64> = (1..=d as u64).map(|x| eval(&[x % prime])).collect();
1106            dense_interpolate(&v, prime)
1107        } else {
1108            bt_univariate(&|t| eval(&[t]), term_bound, prime, g, rng)?
1109        };
1110        let mut m = BTreeMap::new();
1111        for (c, e) in terms {
1112            m.insert(vec![e], c);
1113        }
1114        return Ok(m);
1115    }
1116
1117    // --- Multivariate Zippel ---
1118
1119    // Step 1: Evaluate f(x₁, a₂, …, aₙ) for random aᵢ to get x₁-skeleton.
1120    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();
1121
1122    let skeleton: Vec<(u64, u32)> = {
1123        let f1 = |t: u64| -> u64 {
1124            let mut args = vec![t];
1125            args.extend_from_slice(&a_rest);
1126            eval(&args)
1127        };
1128        if degree_bound <= term_bound as u32 {
1129            let d = degree_bound as usize + 1;
1130            let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
1131            dense_interpolate(&v, prime)
1132        } else {
1133            bt_univariate(&f1, term_bound, prime, g, rng)?
1134        }
1135    };
1136
1137    if skeleton.is_empty() {
1138        return Ok(BTreeMap::new());
1139    }
1140
1141    let x1_exps: Vec<u32> = skeleton.iter().map(|(_, e)| *e).collect();
1142    let t = x1_exps.len();
1143
1144    // Step 2: Choose t distinct evaluation points for x₁.
1145    let mut x1_pts: Vec<u64> = Vec::with_capacity(t);
1146    {
1147        let mut used = std::collections::HashSet::new();
1148        while x1_pts.len() < t {
1149            let v = rng.nonzero(prime);
1150            if used.insert(v) {
1151                x1_pts.push(v);
1152            }
1153        }
1154    }
1155
1156    // Step 3: batched Vandermonde lift → single recursion (shared oracle).
1157    let eval_multi = |x_rest: &[u64]| -> Vec<u64> {
1158        let mut f_vals: Vec<u64> = Vec::with_capacity(t);
1159        for &xk in &x1_pts {
1160            let mut args = vec![xk];
1161            args.extend_from_slice(x_rest);
1162            f_vals.push(eval(&args));
1163        }
1164        vandermonde_solve(&x1_pts, &x1_exps, &f_vals, prime).unwrap_or_else(|| vec![0u64; t])
1165    };
1166
1167    let sub_maps = zippel_helper_multi(
1168        &eval_multi,
1169        n_vars - 1,
1170        t,
1171        term_bound,
1172        degree_bound,
1173        prime,
1174        g,
1175        rng,
1176    )?;
1177
1178    let mut result: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
1179    for j in 0..t {
1180        let e1 = x1_exps[j];
1181        for (sub_exp, coeff) in &sub_maps[j] {
1182            if *coeff != 0 {
1183                let mut full_exp = vec![e1];
1184                full_exp.extend_from_slice(sub_exp);
1185                result.insert(full_exp, *coeff);
1186            }
1187        }
1188    }
1189
1190    Ok(result)
1191}
1192
1193// ---------------------------------------------------------------------------
1194// Public API
1195// ---------------------------------------------------------------------------
1196
1197/// Recover a sparse univariate polynomial `f ∈ F_p[x]` from black-box
1198/// evaluations using the Ben-Or/Tiwari (Prony-style) algorithm.
1199///
1200/// # Parameters
1201///
1202/// - `eval` — black-box oracle: `x ↦ f(x) mod p`.  Called `2·term_bound`
1203///   times (at consecutive powers of a primitive root of `F_p`).
1204/// - `term_bound` — `T`: upper bound on the number of nonzero terms in `f`.
1205/// - `prime` — field characteristic `p`.  Must satisfy `p > 2·T` and
1206///   `p > max_degree(f)` (so all exponents are representable as discrete
1207///   logarithms in `{0, …, p-2}`).
1208///
1209/// # Returns
1210///
1211/// A vector of `(coefficient, exponent)` pairs in arbitrary order.
1212///
1213/// # Errors
1214///
1215/// - [`SparseInterpError::InvalidPrime`] if `p` is not prime.
1216/// - [`SparseInterpError::PrimeTooSmall`] if `p ≤ 2·T`.
1217/// - [`SparseInterpError::RootFindingFailed`] if fewer roots were found than
1218///   expected (the prime may be smaller than `max_degree(f)`).
1219/// - [`SparseInterpError::SingularSystem`] if the Vandermonde system is
1220///   degenerate (extremely rare; retry with a different prime).
1221///
1222/// # Example
1223///
1224/// ```text
1225/// // Recover  x^100 + 3·x^17 + 5  from 6 evaluations (T=3).
1226/// let eval = |x: u64| { ... };
1227/// let terms = sparse_interpolate_univariate(&eval, 3, 1009)?;
1228/// // terms ≈ [(1, 100), (3, 17), (5, 0)]
1229/// ```
1230pub fn sparse_interpolate_univariate(
1231    eval: &dyn Fn(u64) -> u64,
1232    term_bound: usize,
1233    prime: u64,
1234) -> Result<Vec<(u64, u32)>, SparseInterpError> {
1235    if !is_prime(prime) {
1236        return Err(SparseInterpError::InvalidPrime(prime));
1237    }
1238    if prime <= 2 * term_bound as u64 {
1239        return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1240    }
1241    let g = primitive_root(prime);
1242    let mut rng = Xorshift64::new(prime.wrapping_mul(0x5851_f42d_4c95_7f2d));
1243    bt_univariate(eval, term_bound, prime, g, &mut rng)
1244}
1245
1246/// Recover a sparse multivariate polynomial `f ∈ F_p[x₁, …, xₙ]` from
1247/// black-box evaluations using Zippel's variable-by-variable algorithm.
1248///
1249/// # Parameters
1250///
1251/// - `eval` — black-box oracle: `(x₁, …, xₙ) ↦ f(x₁, …, xₙ) mod p`.
1252///   Coordinates are given in the same order as `vars`.
1253/// - `vars` — symbolic variable identifiers (used to label the result).
1254/// - `term_bound` — `T`: upper bound on the number of nonzero terms.
1255/// - `degree_bound` — `D`: upper bound on the degree of each individual
1256///   variable.  Polynomials with lower `D` converge faster.  For the dense
1257///   fallback to kick in, set `D ≤ T`.
1258/// - `prime` — field characteristic `p`.  Must satisfy `p > 2·T` and
1259///   `p > D` (so exponents are representable as discrete logs).
1260/// - `seed` — seed for the internal PRNG.  Changing the seed helps recover
1261///   from occasional failures due to unlucky random evaluation points.
1262///
1263/// # Returns
1264///
1265/// A [`MultiPolyFp`] with the recovered polynomial.  Oracle complexity is
1266/// polynomial in the number of variables, `term_bound`, and `degree_bound` on
1267/// typical sparse inputs — unlike dense interpolation at `Ω((D+1)^n)`.
1268///
1269/// # Errors
1270///
1271/// See [`SparseInterpError`].
1272pub fn sparse_interpolate(
1273    eval: &dyn Fn(&[u64]) -> u64,
1274    vars: Vec<ExprId>,
1275    term_bound: usize,
1276    degree_bound: u32,
1277    prime: u64,
1278    seed: u64,
1279) -> Result<MultiPolyFp, SparseInterpError> {
1280    if !is_prime(prime) {
1281        return Err(SparseInterpError::InvalidPrime(prime));
1282    }
1283    if prime <= 2 * term_bound as u64 {
1284        return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1285    }
1286
1287    let n_vars = vars.len();
1288    let g = primitive_root(prime);
1289    let mut rng = Xorshift64::new(seed);
1290
1291    let terms = zippel_helper(eval, n_vars, term_bound, degree_bound, prime, g, &mut rng)?;
1292
1293    let trimmed_terms: BTreeMap<Vec<u32>, u64> = terms
1294        .into_iter()
1295        .map(|(mut exp, c)| {
1296            // Trim trailing zeros in exponent vector.
1297            while exp.last() == Some(&0) {
1298                exp.pop();
1299            }
1300            (exp, c)
1301        })
1302        .filter(|(_, c)| *c != 0)
1303        .collect();
1304
1305    Ok(MultiPolyFp {
1306        vars,
1307        modulus: prime,
1308        terms: trimmed_terms,
1309    })
1310}
1311
1312// ---------------------------------------------------------------------------
1313// Unit tests
1314// ---------------------------------------------------------------------------
1315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    // ---- helpers ------------------------------------------------------------

    /// Builds an oracle evaluating `Σ c · Π xᵢ^eᵢ (mod prime)` for the given
    /// `(coefficient, exponent-vector)` list.  Coordinates missing from the
    /// query point are treated as 0.
    fn make_poly_eval(coeffs: &[(u64, Vec<u32>)], prime: u64) -> impl Fn(&[u64]) -> u64 + '_ {
        move |pt: &[u64]| -> u64 {
            coeffs.iter().fold(0u64, |sum, (c, exp)| {
                let monomial = exp.iter().enumerate().fold(*c % prime, |acc, (i, &e)| {
                    let xi = pt.get(i).copied().unwrap_or(0);
                    mul_mod(acc, pow_mod(xi, e as u64, prime), prime)
                });
                add_mod(sum, monomial, prime)
            })
        }
    }

    /// Creates a fresh pool together with `n` real symbols named `x0`, `x1`, ….
    fn vars(n: usize) -> (ExprPool, Vec<ExprId>) {
        let pool = ExprPool::new();
        let mut symbols: Vec<ExprId> = Vec::with_capacity(n);
        for i in 0..n {
            symbols.push(pool.symbol(format!("x{i}"), Domain::Real));
        }
        (pool, symbols)
    }

    /// Coefficient stored for `exp` in `poly`, or 0 when the term is absent.
    fn coeff_at(poly: &MultiPolyFp, exp: &[u32]) -> u64 {
        poly.terms.get(exp).copied().unwrap_or(0)
    }

    /// Benchmark-shaped "diagonal" term list: term `i` has coefficient
    /// `7·(i+1) mod p` (0 replaced by 1) and a single nonzero exponent
    /// `(i mod 3) + 1` on variable `i mod n_vars`.
    fn diagonal_terms(n_vars: usize, n_terms: usize, p: u64) -> Vec<(u64, Vec<u32>)> {
        (0..n_terms)
            .map(|i| {
                let raw = (((i + 1) as u64) * 7) % p;
                let coeff = if raw == 0 { 1 } else { raw };
                let mut exp = vec![0u32; n_vars];
                exp[i % n_vars] = (i % 3) as u32 + 1;
                (coeff, exp)
            })
            .collect()
    }

    /// Collapses a term list into the normalised map `sparse_interpolate` is
    /// expected to return: trailing zero exponents trimmed, coefficients of
    /// equal monomials summed mod `p`.
    fn expected_map(terms: &[(u64, Vec<u32>)], p: u64) -> BTreeMap<Vec<u32>, u64> {
        let mut expected: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
        for (c, exp) in terms {
            let keep = exp.iter().rposition(|&d| d != 0).map_or(0, |i| i + 1);
            let key = exp[..keep].to_vec();
            let nc = *c % p;
            expected
                .entry(key)
                .and_modify(|v| {
                    *v = add_mod(*v, nc, p);
                })
                .or_insert(nc);
        }
        expected
    }

    // ---- primitive_root -----------------------------------------------------

    #[test]
    fn prim_root_small_primes() {
        let primes = [2u64, 3, 5, 7, 11, 13, 17, 19, 23];
        for p in primes {
            let g = primitive_root(p);
            let group_order = p - 1;
            // g has full order iff g^(p-1) = 1 and g^((p-1)/q) ≠ 1 for every
            // prime q dividing p-1.
            assert_eq!(pow_mod(g, group_order, p), 1, "g^(p-1)=1 for p={p}");
            for q in prime_factors(group_order) {
                assert_ne!(
                    pow_mod(g, group_order / q, p),
                    1,
                    "g^((p-1)/{q}) ≠ 1 for p={p}"
                );
            }
        }
    }

    // ---- berlekamp_massey ---------------------------------------------------

    #[test]
    fn bm_geometric_sequence() {
        // s[n] = 2^n mod 7 obeys s[n] = 2·s[n-1], so the connection polynomial
        // has degree 1 and vanishes at 2^{-1} = 4 in F_7 (e.g. 1 + 5z when the
        // constant term is normalised to 1).
        let p = 7u64;
        let mut seq: Vec<u64> = Vec::with_capacity(6);
        for n in 0..6 {
            seq.push(pow_mod(2, n, p));
        }
        let lambda = berlekamp_massey(&seq, p);
        let lfsr_len = lambda.len() - 1;
        assert_eq!(lfsr_len, 1, "LFSR length should be 1");
        // Λ(2^{-1}) must be 0.
        assert_eq!(poly_eval(&lambda, mod_inv(2, p), p), 0);
    }

    #[test]
    fn bm_two_term_sequence() {
        use std::collections::HashSet;

        // s[n] = 3·2^n + 5·3^n  mod 11
        let p = 11u64;
        let scaled_power = |coeff: u64, base: u64, n: u64| mul_mod(coeff, pow_mod(base, n, p), p);
        let seq: Vec<u64> = (0..4)
            .map(|n| add_mod(scaled_power(3, 2, n), scaled_power(5, 3, n), p))
            .collect();
        let lambda = berlekamp_massey(&seq, p);
        assert_eq!(lambda.len() - 1, 2, "two-term sequence has LFSR length 2");

        // The roots of Λ are exactly inv(2) and inv(3).
        let mut rng = Xorshift64::new(0xbeef);
        let roots = find_roots(&lambda, p, &mut rng).unwrap();
        assert_eq!(roots.len(), 2);
        let expected: HashSet<u64> = HashSet::from([mod_inv(2, p), mod_inv(3, p)]);
        let got: HashSet<u64> = roots.into_iter().collect();
        assert_eq!(got, expected);
    }

    // ---- bsgs_dlog ----------------------------------------------------------

    #[test]
    fn dlog_basic() {
        let p = 13u64;
        let g = primitive_root(p);
        // Every power of g must round-trip through the discrete log.
        (0..p - 1).for_each(|e| {
            let target = pow_mod(g, e, p);
            let found = bsgs_dlog(g, target, p).expect("dlog should succeed");
            assert_eq!(
                pow_mod(g, found, p),
                target,
                "g^{found} ≠ {target} for p={p}"
            );
        });
    }

    // ---- sparse_interpolate_univariate --------------------------------------

    #[test]
    fn uni_zero_polynomial() {
        let recovered = sparse_interpolate_univariate(&|_| 0, 5, 101).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn uni_constant() {
        // f(x) = 7 is the single term (coeff=7, exp=0).
        let recovered = sparse_interpolate_univariate(&|_| 7, 3, 101).unwrap();
        assert_eq!(recovered.len(), 1);
        let (c, e) = recovered[0];
        assert_eq!(c, 7);
        assert_eq!(e, 0);
    }

    #[test]
    fn uni_single_monomial() {
        // f(x) = 3·x^5 mod 101
        let p = 101u64;
        let oracle = |x: u64| mul_mod(3, pow_mod(x, 5, p), p);
        let recovered = sparse_interpolate_univariate(&oracle, 3, p).unwrap();
        assert_eq!(recovered.len(), 1);
        let (c, e) = recovered[0];
        assert_eq!(c, 3);
        assert_eq!(e, 5);
    }

    #[test]
    fn uni_two_terms() {
        // f(x) = x^10 + 2·x^3 mod 101
        let p = 101u64;
        let oracle = |x: u64| add_mod(pow_mod(x, 10, p), mul_mod(2, pow_mod(x, 3, p), p), p);
        let recovered = sparse_interpolate_univariate(&oracle, 3, p).unwrap();
        assert_eq!(recovered.len(), 2);
        // Output order is unspecified, so compare after ordering by exponent.
        let mut by_exp = recovered;
        by_exp.sort_by(|a, b| a.1.cmp(&b.1));
        assert_eq!(by_exp[0], (2, 3));
        assert_eq!(by_exp[1], (1, 10));
    }

    #[test]
    fn uni_roadmap_example() {
        // ROADMAP example: recover x^100 + 3·x^17 + 5 over F_997 (997 > 100).
        // The call passes term_bound = 4, a loose bound for this 3-term input.
        let p = 997u64;
        let oracle = |x: u64| {
            let lead = pow_mod(x, 100, p);
            let mid = mul_mod(3, pow_mod(x, 17, p), p);
            add_mod(add_mod(lead, mid, p), 5u64, p)
        };
        let recovered = sparse_interpolate_univariate(&oracle, 4, p).unwrap();
        let wanted = [
            (5u64, 0u32, "missing constant 5"),
            (3, 17, "missing 3·x^17"),
            (1, 100, "missing x^100"),
        ];
        for (coeff, exp, msg) in wanted {
            assert!(
                recovered.iter().any(|&(c, e)| c == coeff && e == exp),
                "{msg}"
            );
        }
    }

    #[test]
    fn uni_error_invalid_prime() {
        let outcome = sparse_interpolate_univariate(&|_| 0, 3, 4);
        assert!(matches!(outcome, Err(SparseInterpError::InvalidPrime(4))));
    }

    #[test]
    fn uni_error_prime_too_small() {
        // T=10 requires p > 20, so p=19 must be rejected.
        let outcome = sparse_interpolate_univariate(&|_| 0, 10, 19);
        assert!(matches!(
            outcome,
            Err(SparseInterpError::PrimeTooSmall {
                prime: 19,
                term_bound: 10
            })
        ));
    }

    // ---- sparse_interpolate (multivariate) ----------------------------------

    #[test]
    fn multi_constant() {
        let (_, vs) = vars(2);
        let result = sparse_interpolate(&|_| 42, vs, 3, 10, 101, 0).unwrap();
        assert_eq!(result.terms.len(), 1);
        // The constant term is keyed by the empty exponent vector.
        assert_eq!(result.terms.get(&vec![]).copied().unwrap(), 42u64);
    }

    #[test]
    fn multi_univariate_via_multi() {
        // x^2 + 3·x + 1 in one variable
        let p = 101u64;
        let (_, vs) = vars(1);
        let oracle = |pt: &[u64]| {
            let x = pt[0];
            let quad_plus_lin = add_mod(pow_mod(x, 2, p), mul_mod(3, x, p), p);
            add_mod(quad_plus_lin, 1, p)
        };
        let result = sparse_interpolate(&oracle, vs, 5, 10, p, 0).unwrap();
        // Expected terms: [2]→1, [1]→3, []→1
        assert_eq!(result.terms.get(&vec![2]).copied().unwrap(), 1u64, "x^2 coeff");
        assert_eq!(result.terms.get(&vec![1]).copied().unwrap(), 3u64, "x^1 coeff");
        assert_eq!(coeff_at(&result, &[]), 1u64, "x^0 coeff");
    }

    #[test]
    fn multi_bivariate_xy() {
        // f = x·y + 3 over F_101
        let p = 101u64;
        let (_, vs) = vars(2);
        let oracle = |pt: &[u64]| add_mod(mul_mod(pt[0], pt[1], p), 3, p);
        let result = sparse_interpolate(&oracle, vs, 4, 5, p, 1).unwrap();
        // Expected: {[1,1]→1, []→3}
        assert_eq!(coeff_at(&result, &[1, 1]), 1u64, "x*y coeff");
        assert_eq!(coeff_at(&result, &[]), 3u64, "constant");
    }

    #[test]
    fn multi_bivariate_x_squared_y() {
        // f = x^2·y + 5·y + 2·x  over F_101
        let p = 101u64;
        let (_, vs) = vars(2);
        let oracle = |pt: &[u64]| {
            let (x, y) = (pt[0], pt[1]);
            let x2y = mul_mod(pow_mod(x, 2, p), y, p);
            let five_y = mul_mod(5, y, p);
            let two_x = mul_mod(2, x, p);
            add_mod(add_mod(x2y, five_y, p), two_x, p)
        };
        let result = sparse_interpolate(&oracle, vs, 5, 6, p, 42).unwrap();
        assert_eq!(coeff_at(&result, &[2, 1]), 1, "x^2*y");
        assert_eq!(coeff_at(&result, &[0, 1]), 5, "5*y");
        assert_eq!(coeff_at(&result, &[1]), 2, "2*x");
    }

    #[test]
    fn multi_three_variables() {
        // f = x·y·z + x^2 + z  over F_1009
        let p = 1009u64;
        let (_, vs) = vars(3);
        let oracle = |pt: &[u64]| {
            let (x, y, z) = (pt[0], pt[1], pt[2]);
            let xyz = mul_mod(mul_mod(x, y, p), z, p);
            let x_sq = pow_mod(x, 2, p);
            add_mod(add_mod(xyz, x_sq, p), z, p)
        };
        let result = sparse_interpolate(&oracle, vs, 5, 4, p, 7).unwrap();
        assert_eq!(coeff_at(&result, &[1, 1, 1]), 1, "x*y*z");
        assert_eq!(coeff_at(&result, &[2]), 1, "x^2");
        assert_eq!(coeff_at(&result, &[0, 0, 1]), 1, "z");
    }

    #[test]
    fn multi_roundtrip_via_multipoly() {
        // Build a MultiPoly, reduce it mod p as the reference, then interpolate
        // the same polynomial through an oracle and check both agree.
        use crate::poly::multipoly::MultiPoly;
        let p = 1009u64;
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);

        // f = x^3 + 2·x·y - y^2 + 4
        let cube = pool.pow(x, pool.integer(3_i32));
        let cross = pool.mul(vec![pool.integer(2_i32), x, y]);
        let neg_square = pool.mul(vec![pool.integer(-1_i32), pool.pow(y, pool.integer(2_i32))]);
        let expr = pool.add(vec![cube, cross, neg_square, pool.integer(4_i32)]);

        let mp = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        let fp_ref = crate::modular::reduce_mod(&mp, p).unwrap();

        // Oracle: evaluate the integer-coefficient MultiPoly at a point of F_p,
        // mapping each coefficient to its non-negative residue first.
        let oracle = |pt: &[u64]| {
            let mut total = 0u64;
            for (exp, coeff) in &mp.terms {
                let residue = {
                    let r = coeff.clone() % rug::Integer::from(p);
                    let r = if r < 0 { r + rug::Integer::from(p) } else { r };
                    r.to_u64().unwrap()
                };
                let mut monomial = residue;
                for (i, &e) in exp.iter().enumerate() {
                    let xi = if i < pt.len() { pt[i] } else { 0 };
                    monomial = mul_mod(monomial, pow_mod(xi, e as u64, p), p);
                }
                total = add_mod(total, monomial, p);
            }
            total
        };

        let recovered = sparse_interpolate(&oracle, vec![x, y], 6, 5, p, 0).unwrap();

        // Every recovered term must match the reference …
        for (exp, &coeff) in &recovered.terms {
            let ref_coeff = fp_ref.terms.get(exp).copied().unwrap_or(0);
            assert_eq!(coeff, ref_coeff, "mismatch at exp {:?}", exp);
        }
        // … and no reference term may be missing.
        for (exp, &ref_coeff) in &fp_ref.terms {
            let got = recovered.terms.get(exp).copied().unwrap_or(0);
            assert_eq!(got, ref_coeff, "missed term at exp {:?}", exp);
        }
    }

    #[test]
    fn multi_diag_15term_three_var_smoke() {
        // Mirrors the benchmark diagonal structure (sparse_interp_multivar) at a
        // CI-friendly size.  Despite the name, this instance has
        // n_terms = n_vars = 3 terms.
        let p = 32749u64;
        let n_vars = 3;
        let n_terms = n_vars;
        let terms = diagonal_terms(n_vars, n_terms, p);
        let eval_fn = make_poly_eval(&terms, p);
        let (_, vs) = vars(n_vars);
        let expected = expected_map(&terms, p);

        // A seed counts as a success when the recovered map equals `expected`.
        let successes = [0_u64, 1, 2, 41]
            .into_iter()
            .filter(|&seed| {
                let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, seed)
                    .expect("smoke interpolate should succeed");
                result.terms.len() == expected.len()
                    && expected
                        .iter()
                        .all(|(exp, &ec)| result.terms.get(exp).copied().unwrap_or(0) == ec)
            })
            .count();
        assert!(successes >= 3, "expected ≥ 3 successes on diagonal smoke");
    }

    #[test]
    #[ignore]
    fn multi_interp_diag_large_stress_slow() {
        // `cargo test -p alkahest-core poly::interp --release -- --ignored`
        //
        // 6-variable, 15-term workload (benchmark-shaped diagonal polynomial).
        // Larger sizes are exercised by benchmarks; CI keeps only the
        // lightweight 3-variable smoke test.
        let p = 32749u64;
        let n_vars = 6;
        let n_terms = 15;
        let terms = diagonal_terms(n_vars, n_terms, p);
        let eval_fn = make_poly_eval(&terms, p);
        let (_, vs) = vars(n_vars);
        let expected = expected_map(&terms, p);

        let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, 7)
            .expect("stress interpolate should succeed");
        assert_eq!(result.terms.len(), expected.len());
        for (exp, &ec) in &expected {
            assert_eq!(result.terms.get(exp).copied().unwrap_or(0), ec);
        }
    }
}