1use crate::errors::AlkahestError;
37use crate::kernel::ExprId;
38use crate::modular::{is_prime, MultiPolyFp};
39use std::collections::BTreeMap;
40
/// Failure modes reported by the sparse-interpolation routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseInterpError {
    /// The supplied modulus failed the primality check; carries the rejected value.
    InvalidPrime(u64),
    /// The prime does not exceed `2 * term_bound`; carries both values.
    PrimeTooSmall { prime: u64, term_bound: usize },
    /// Root finding or the discrete-log step did not yield the expected roots in F_p.
    RootFindingFailed,
    /// The linear system solved for the coefficients had no usable pivot.
    SingularSystem,
}
60
61impl std::fmt::Display for SparseInterpError {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 SparseInterpError::InvalidPrime(p) => {
65 write!(f, "invalid prime {p}: must be a prime ≥ 3")
66 }
67 SparseInterpError::PrimeTooSmall { prime, term_bound } => write!(
68 f,
69 "prime {prime} is too small for term_bound {term_bound}: need prime > 2·T = {}",
70 2 * term_bound
71 ),
72 SparseInterpError::RootFindingFailed => write!(
73 f,
74 "could not find the expected number of roots in F_p; \
75 the prime may be too small or the oracle is inconsistent"
76 ),
77 SparseInterpError::SingularSystem => write!(
78 f,
79 "Vandermonde system is singular; try a different seed or a larger prime"
80 ),
81 }
82 }
83}
84
// Empty impl: no `source()` override, so the trait's default methods apply.
impl std::error::Error for SparseInterpError {}
86
87impl AlkahestError for SparseInterpError {
88 fn code(&self) -> &'static str {
89 match self {
90 SparseInterpError::InvalidPrime(_) => "E-INTERP-001",
91 SparseInterpError::PrimeTooSmall { .. } => "E-INTERP-002",
92 SparseInterpError::RootFindingFailed => "E-INTERP-003",
93 SparseInterpError::SingularSystem => "E-INTERP-004",
94 }
95 }
96
97 fn remediation(&self) -> Option<&'static str> {
98 match self {
99 SparseInterpError::InvalidPrime(_) => {
100 Some("choose a prime p ≥ 3, e.g. 1009, 32749, 1000003")
101 }
102 SparseInterpError::PrimeTooSmall { .. } => {
103 Some("increase the prime so that p > 2 * term_bound")
104 }
105 SparseInterpError::RootFindingFailed => {
106 Some("choose a prime larger than the maximum degree in the polynomial")
107 }
108 SparseInterpError::SingularSystem => {
109 Some("retry with a different seed or use a larger prime")
110 }
111 }
112 }
113}
114
/// Small deterministic xorshift pseudo-random generator with 64 bits of state.
pub struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`; a zero seed is swapped for a fixed non-zero constant.
    pub fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0xdeadbeef_cafebabe,
            other => other,
        };
        Self { state }
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns the new state.
    pub fn step(&mut self) -> u64 {
        let first = self.state ^ (self.state << 13);
        let second = first ^ (first >> 7);
        let third = second ^ (second << 17);
        self.state = third;
        third
    }

    /// Returns a value in `lo..hi` (debug builds assert `hi > lo`).
    pub fn next_range(&mut self, lo: u64, hi: u64) -> u64 {
        debug_assert!(hi > lo);
        let raw = self.step();
        lo + raw % (hi - lo)
    }

    /// Draws values modulo `p` until one is non-zero and returns it.
    pub fn nonzero(&mut self, p: u64) -> u64 {
        loop {
            match self.step() % p {
                0 => continue,
                v => return v,
            }
        }
    }
}
156
/// Returns `a * b mod p`, widening to 128 bits so the product cannot overflow.
#[inline]
fn mul_mod(a: u64, b: u64, p: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    (wide_product % u128::from(p)) as u64
}
165
/// Returns `a + b mod p` for operands already reduced into `0..p`.
///
/// The sum is formed with explicit carry detection so moduli above 2^63 work:
/// the previous `a + b` overflowed `u64` (panicking in debug builds, wrapping
/// to a wrong value in release builds) whenever the true sum reached 2^64.
#[inline]
fn add_mod(a: u64, b: u64, p: u64) -> u64 {
    let (sum, carried) = a.overflowing_add(b);
    if carried || sum >= p {
        // When the addition carried, the true sum is `sum + 2^64`; the wrapping
        // subtraction yields exactly `sum + 2^64 - p`, which lies in `0..p`.
        sum.wrapping_sub(p)
    } else {
        sum
    }
}
175
/// Returns `a - b mod p` for operands already reduced into `0..p`.
///
/// The wrap-around case is computed as `p - (b - a)` rather than `a + p - b`,
/// because the intermediate `a + p` overflowed `u64` for moduli close to 2^64.
#[inline]
fn sub_mod(a: u64, b: u64, p: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        // Here b > a, so `b - a` is in `1..p` and the result is in `1..p`.
        p - (b - a)
    }
}
184
185fn pow_mod(mut base: u64, mut exp: u64, p: u64) -> u64 {
186 let mut result = 1u64;
187 base %= p;
188 while exp > 0 {
189 if exp & 1 == 1 {
190 result = mul_mod(result, base, p);
191 }
192 base = mul_mod(base, base, p);
193 exp >>= 1;
194 }
195 result
196}
197
/// Returns the inverse of `a` modulo `p` via the extended Euclidean algorithm.
///
/// Debug builds assert that `a` is non-zero. Arithmetic is carried out in
/// `i128` so the signed Bézout coefficient cannot overflow.
fn mod_inv(a: u64, p: u64) -> u64 {
    debug_assert!(a != 0, "mod_inv: a must be non-zero");
    let modulus = i128::from(p);
    let mut prev_rem = i128::from(a);
    let mut rem = modulus;
    let mut prev_coef: i128 = 1;
    let mut coef: i128 = 0;
    while rem != 0 {
        let quotient = prev_rem / rem;

        let next_rem = prev_rem - quotient * rem;
        prev_rem = rem;
        rem = next_rem;

        let next_coef = prev_coef - quotient * coef;
        prev_coef = coef;
        coef = next_coef;
    }
    // Normalise the (possibly negative) coefficient into `0..p`.
    prev_coef.rem_euclid(modulus) as u64
}
216
217fn poly_eval(poly: &[u64], x: u64, p: u64) -> u64 {
223 let mut acc = 0u64;
224 let mut pw = 1u64;
225 for &c in poly {
226 acc = add_mod(acc, mul_mod(c, pw, p), p);
227 pw = mul_mod(pw, x, p);
228 }
229 acc
230}
231
232pub fn primitive_root(p: u64) -> u64 {
241 debug_assert!(is_prime(p), "primitive_root: p must be prime");
242 if p == 2 {
243 return 1;
244 }
245 if p == 3 {
246 return 2;
247 }
248 let factors = prime_factors(p - 1);
249 'outer: for g in 2..p {
250 for &q in &factors {
251 if pow_mod(g, (p - 1) / q, p) == 1 {
252 continue 'outer;
253 }
254 }
255 return g;
256 }
257 panic!("primitive_root: no root found for prime {p}");
258}
259
/// Returns the distinct prime factors of `n` in increasing order, found by
/// trial division. `0` and `1` yield an empty list.
fn prime_factors(mut n: u64) -> Vec<u64> {
    let mut factors = Vec::new();
    let mut d = 2u64;
    // `d <= n / d` is equivalent to `d * d <= n` for positive integers but
    // cannot overflow, unlike the product when `n` is close to 2^64.
    while d <= n / d {
        if n % d == 0 {
            factors.push(d);
            while n % d == 0 {
                n /= d;
            }
        }
        // Once 2 has been divided out no even number can divide `n`, so only
        // odd candidates are tried afterwards.
        d += if d == 2 { 1 } else { 2 };
    }
    // Whatever remains above 1 is itself a prime factor.
    if n > 1 {
        factors.push(n);
    }
    factors
}
278
/// Berlekamp–Massey over F_p: builds a connection polynomial `c` (with
/// `c[0] == 1`) for `seq`, driving the discrepancy
/// `seq[n] + Σ_{i=1..=l} c[i]·seq[n-i]` to zero at every processed index.
///
/// The returned vector is `c` as it stands after the final update; it is not
/// trimmed, so it may end in zero coefficients.
fn berlekamp_massey(seq: &[u64], p: u64) -> Vec<u64> {
    let n = seq.len();
    // Current register length.
    let mut l = 0usize;
    // Current connection polynomial.
    let mut c: Vec<u64> = vec![1];
    // Connection polynomial saved at the last length change, and its discrepancy.
    let mut b: Vec<u64> = vec![1];
    let mut b_disc: u64 = 1;
    // Number of steps since `b` was saved (the shift applied to `b`).
    let mut x: usize = 1;

    for n_idx in 0..n {
        // Discrepancy between seq[n_idx] and what the current recurrence predicts.
        let mut d = seq[n_idx];
        let bound = l.min(c.len().saturating_sub(1));
        for i in 1..=bound {
            d = add_mod(d, mul_mod(c[i], seq[n_idx - i], p), p);
        }

        if d == 0 {
            // Recurrence still holds; only the shift grows.
            x += 1;
            continue;
        }

        // Keep the pre-update polynomial in case the length changes below.
        let t = c.clone();
        let factor = mul_mod(d, mod_inv(b_disc, p), p);

        // c <- c - (d / b_disc) · X^x · b, growing `c` if the shifted `b` is longer.
        let needed = x + b.len();
        if c.len() < needed {
            c.resize(needed, 0);
        }
        for j in 0..b.len() {
            let sub = mul_mod(factor, b[j], p);
            c[x + j] = sub_mod(c[x + j], sub, p);
        }

        if 2 * l <= n_idx {
            // Length change: remember the old polynomial and its discrepancy.
            l = n_idx + 1 - l;
            b = t;
            b_disc = d;
            x = 1;
        } else {
            x += 1;
        }
    }

    c
}
339
/// Drops trailing zero coefficients while always keeping at least one entry
/// (an empty input stays empty).
fn poly_trim(mut a: Vec<u64>) -> Vec<u64> {
    // Length to keep: one past the last non-zero coefficient, or 1 when there is none.
    let keep = a.iter().rposition(|&c| c != 0).map_or(1, |last| last + 1);
    a.truncate(keep);
    a
}
350
/// Returns the degree of `poly` (coefficients stored lowest degree first), or
/// `-1` for the zero polynomial, including the empty slice.
///
/// The degree is read straight off the slice as the index of the last
/// non-zero coefficient; no copy of the input is made.
#[inline]
fn poly_deg(poly: &[u64]) -> i32 {
    poly.iter()
        .rposition(|&c| c != 0)
        .map_or(-1, |last| last as i32)
}
359
360fn poly_add(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
362 let n = a.len().max(b.len());
363 let mut out = vec![0u64; n];
364 for i in 0..n {
365 let x = if i < a.len() { a[i] } else { 0 };
366 let y = if i < b.len() { b[i] } else { 0 };
367 out[i] = add_mod(x, y, p);
368 }
369 poly_trim(out)
370}
371
372fn poly_sub_(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
373 let n = a.len().max(b.len());
374 let mut out = vec![0u64; n];
375 for i in 0..n {
376 let x = if i < a.len() { a[i] } else { 0 };
377 let y = if i < b.len() { b[i] } else { 0 };
378 out[i] = sub_mod(x, y, p);
379 }
380 poly_trim(out)
381}
382
383fn poly_mul(a: &[u64], b: &[u64], p: u64) -> Vec<u64> {
384 if a.is_empty() || b.is_empty() || (a.len() == 1 && a[0] == 0) || (b.len() == 1 && b[0] == 0) {
385 return vec![0];
386 }
387 let da = poly_deg(a);
388 let db = poly_deg(b);
389 if da < 0 || db < 0 {
390 return vec![0];
391 }
392 let mut out = vec![0u64; (da + db + 1) as usize];
393 for i in 0..=da as usize {
394 for j in 0..=db as usize {
395 out[i + j] = add_mod(out[i + j], mul_mod(a[i], b[j], p), p);
396 }
397 }
398 poly_trim(out)
399}
400
/// Polynomial long division modulo `p` (coefficients stored lowest degree first).
///
/// Returns `Some((quotient, remainder))`, or `None` when `divisor` is the zero
/// polynomial. When the dividend's degree is below the divisor's, the quotient
/// is `[0]` and the remainder is the trimmed dividend.
fn poly_divmod(dividend: &[u64], divisor: &[u64], p: u64) -> Option<(Vec<u64>, Vec<u64>)> {
    let mut a = poly_trim(dividend.to_vec());
    let b = poly_trim(divisor.to_vec());
    if poly_deg(&b) < 0 {
        return None;
    }
    // `b` is trimmed and non-zero, so its last entry is the leading coefficient.
    let db = b.len() - 1;
    let lb = *b.last().unwrap();
    let inv_lb = mod_inv(lb, p);

    let deg_a = poly_deg(&a);
    if deg_a < db as i32 {
        return Some((vec![0], a));
    }

    let q_len = (deg_a - db as i32 + 1) as usize;
    let mut quot = vec![0u64; q_len];

    // Cancel the leading term of `a` until its degree falls below the divisor's.
    while poly_deg(&a) >= db as i32 {
        let da = poly_deg(&a) as usize;
        let shift = da - db;
        let scale = mul_mod(*a.last().unwrap(), inv_lb, p);
        quot[shift] = add_mod(quot[shift], scale, p);
        for j in 0..b.len() {
            a[j + shift] = sub_mod(a[j + shift], mul_mod(scale, b[j], p), p);
        }
        // Re-trim so `a.last()` is again the leading coefficient.
        a = poly_trim(a);
    }

    Some((poly_trim(quot), a))
}
433
434fn polygcd(a_: &[u64], b_: &[u64], p: u64) -> Vec<u64> {
435 let mut a = poly_trim(a_.to_vec());
436 let mut b = poly_trim(b_.to_vec());
437 while poly_deg(&b) >= 0 {
438 let (_, r) = match poly_divmod(&a, &b, p) {
439 Some(x) => x,
440 None => break,
441 };
442 a = b;
443 b = r;
444 }
445 if poly_deg(&a) < 0 {
446 return vec![0];
447 }
448 poly_make_monic(&a, p)
449}
450
451fn poly_derivative(f: &[u64], p: u64) -> Vec<u64> {
452 let f = poly_trim(f.to_vec());
453 if f.len() <= 1 {
454 return vec![0];
455 }
456 let mut out = Vec::with_capacity(f.len() - 1);
457 for (k, &coeff) in f.iter().enumerate().skip(1) {
458 let d = mul_mod(coeff, k as u64, p);
459 out.push(d);
460 }
461 poly_trim(out)
462}
463
464fn poly_make_monic(f: &[u64], p: u64) -> Vec<u64> {
465 let f = poly_trim(f.to_vec());
466 if f.is_empty() {
467 return f;
468 }
469 let lc = *f.last().unwrap();
470 if lc == 0 {
471 return f;
472 }
473 let inv = mod_inv(lc, p);
474 f.iter().map(|&c| mul_mod(c, inv, p)).collect()
475}
476
477fn poly_squarefree(mut f: Vec<u64>, p: u64) -> Vec<u64> {
479 f = poly_make_monic(&f, p);
480 loop {
481 let dp = poly_derivative(&f, p);
482 let g = polygcd(&f, &dp, p);
483 let dg = poly_deg(&g);
484 if dg <= 0 {
485 break;
486 }
487 let (_, r) = poly_divmod(&f, &g, p).unwrap();
488 f = poly_make_monic(&r, p);
489 }
490 f
491}
492
493fn poly_mul_mod(a: &[u64], b: &[u64], modulo: &[u64], p: u64) -> Vec<u64> {
494 let prod = poly_mul(a, b, p);
495 poly_divmod(&prod, modulo, p)
496 .map(|(_, r)| r)
497 .unwrap_or(vec![0])
498}
499
500fn poly_pow_mod(base: &[u64], mut exp: u64, m: &[u64], p: u64) -> Vec<u64> {
502 let m = poly_trim(m.to_vec());
503 if poly_deg(&m) < 0 {
504 return vec![0];
505 }
506 let mut acc = vec![1u64];
507 let mut b = poly_divmod(&poly_trim(base.to_vec()), &m, p)
508 .map(|(_, r)| r)
509 .unwrap_or(vec![0]);
510 while exp > 0 {
511 if exp & 1 != 0 {
512 acc = poly_mul_mod(&acc, &b, &m, p);
513 }
514 b = poly_mul_mod(&b, &b, &m, p);
515 exp >>= 1;
516 }
517 acc
518}
519
/// Draws a random polynomial with `max_deg` coefficients taken from `0..p`
/// (so its degree is below `max_deg`) and returns it trimmed.
///
/// `max_deg == 0` yields `[0]` without touching the generator. If every drawn
/// coefficient is zero, one position is overwritten with a non-zero value so
/// the result is never the zero polynomial.
fn poly_random_below(max_deg: usize, p: u64, rng: &mut Xorshift64) -> Vec<u64> {
    if max_deg == 0 {
        return vec![0];
    }
    let mut c: Vec<u64> = (0..max_deg).map(|_| rng.next_range(0, p)).collect();
    if c.iter().all(|&x| x == 0) {
        c[rng.next_range(0, max_deg as u64) as usize] = rng.nonzero(p);
    }
    poly_trim(c)
}
531
/// Finds the roots of `poly` in F_p, returned sorted and without duplicates.
///
/// The zero polynomial and constants yield an empty list. For `p == 2` both
/// field elements are simply tested. Otherwise the polynomial is reduced to its
/// square-free part, intersected with `x^p - x` via a gcd, and the result is
/// handed to `split_find_roots`.
///
/// # Errors
/// Propagates any error returned by `split_find_roots`.
fn find_roots(poly: &[u64], p: u64, rng: &mut Xorshift64) -> Result<Vec<u64>, SparseInterpError> {
    let mut f = poly_trim(poly.to_vec());
    if poly_deg(&f) < 0 {
        return Ok(vec![]);
    }
    if p == 2 {
        // Tiny field: evaluate at every element.
        let mut r = Vec::new();
        for v in 0..p {
            if poly_eval(&f, v, p) == 0 {
                r.push(v);
            }
        }
        return Ok(r);
    }
    f = poly_squarefree(f, p);
    if poly_deg(&f) < 0 {
        return Ok(vec![]);
    }
    if poly_deg(&f) == 0 {
        return Ok(vec![]);
    }

    // x^p mod f, then gcd(f, x^p - x).
    let xp = poly_pow_mod(&[0, 1], p, &f, p);
    let diff = poly_sub_(&xp, &[0, 1], p);
    let mut h = polygcd(&f, &diff, p);
    if poly_deg(&h) < 0 {
        // Fall back to `f` itself if the gcd came back as zero.
        h = f;
    }

    let mut roots = Vec::new();
    split_find_roots(&h, p, rng, &mut roots)?;
    roots.sort_unstable();
    roots.dedup();
    Ok(roots)
}
570
/// Recursively splits `f` and appends the roots it finds to `roots`.
///
/// Degree <= 0 contributes nothing; a monic linear `x + f[0]` contributes the
/// root `-f[0]`. Higher degrees are split by taking gcds of `f` with
/// `u^((p-1)/2) - 1` and `u^((p-1)/2) + 1` for random polynomials `u`; both
/// halves of a successful split are processed recursively.
///
/// # Errors
/// Returns `RootFindingFailed` when no split is found within `MAX_TRIES`
/// attempts and `deg(f) * p` is too large for the exhaustive fallback.
fn split_find_roots(
    f: &[u64],
    p: u64,
    rng: &mut Xorshift64,
    roots: &mut Vec<u64>,
) -> Result<(), SparseInterpError> {
    let f = poly_make_monic(f, p);
    let d = poly_deg(&f);
    if d < 0 {
        return Ok(());
    }
    if d == 0 {
        return Ok(());
    }
    if d == 1 {
        // Monic linear factor x + f[0]: the root is -f[0] mod p.
        let a0 = sub_mod(0, f[0], p);
        roots.push(a0);
        return Ok(());
    }

    const MAX_TRIES: usize = 256;
    for _ in 0..MAX_TRIES {
        let u = poly_random_below(d as usize, p, rng);
        let exp = (p - 1) / 2;
        let up = poly_pow_mod(&u, exp, &f, p);
        // Try both u^exp - 1 and u^exp + 1 as splitting candidates.
        for g in [poly_sub_(&up, &[1], p), poly_add(&up, &[1], p)] {
            let d1 = polygcd(&f, &g, p);
            let d1deg = poly_deg(&d1);
            // Only a proper, non-constant factor is useful.
            if d1deg > 0 && d1deg < d {
                let (cofactor, rem) = poly_divmod(&f, &d1, p).unwrap();
                if poly_deg(&rem) >= 0 {
                    // Division was not exact; discard this candidate.
                    continue;
                }
                split_find_roots(&d1, p, rng, roots)?;
                split_find_roots(&poly_make_monic(&cofactor, p), p, rng, roots)?;
                return Ok(());
            }
        }
    }
    // Random splitting failed: fall back to exhaustive evaluation when cheap enough.
    if (d as u128) * (p as u128) <= 2_500_000 {
        for v in 0..p {
            if poly_eval(&f, v, p) == 0 {
                roots.push(v);
            }
        }
        return Ok(());
    }
    Err(SparseInterpError::RootFindingFailed)
}
624
625pub fn bsgs_dlog(g: u64, target: u64, p: u64) -> Option<u64> {
634 if target == 0 {
635 return None; }
637 let order = p - 1; let m = (order as f64).sqrt().ceil() as u64 + 1;
639
640 let mut table = std::collections::HashMap::with_capacity(m as usize);
642 let mut gj = 1u64;
643 for j in 0..m {
644 table.insert(gj, j);
645 gj = mul_mod(gj, g, p);
646 }
647
648 let gm = pow_mod(g, m, p);
650 let gm_inv = mod_inv(gm, p);
651 let mut y = target;
652 for i in 0..m {
653 if let Some(&j) = table.get(&y) {
654 let e = i * m + j;
655 let e_mod = e % order;
656 if pow_mod(g, e_mod, p) == target {
658 return Some(e_mod);
659 }
660 }
661 y = mul_mod(y, gm_inv, p);
662 }
663 None
664}
665
666fn vandermonde_solve(pts: &[u64], exps: &[u32], vals: &[u64], p: u64) -> Option<Vec<u64>> {
678 let t = pts.len();
679 debug_assert_eq!(exps.len(), t);
680 debug_assert_eq!(vals.len(), t);
681
682 let mut mat: Vec<Vec<u64>> = (0..t)
684 .map(|i| (0..t).map(|j| pow_mod(pts[i], exps[j] as u64, p)).collect())
685 .collect();
686 let mut rhs: Vec<u64> = vals.to_vec();
687
688 gaussian_elim(&mut mat, &mut rhs, p)
689}
690
691fn gaussian_elim(mat: &mut [Vec<u64>], rhs: &mut [u64], p: u64) -> Option<Vec<u64>> {
695 let n = mat.len();
696 for col in 0..n {
697 let pivot_row = (col..n).find(|&r| mat[r][col] != 0)?;
699 mat.swap(col, pivot_row);
700 rhs.swap(col, pivot_row);
701
702 let inv = mod_inv(mat[col][col], p);
703 for entry in &mut mat[col][col..] {
705 *entry = mul_mod(*entry, inv, p);
706 }
707 rhs[col] = mul_mod(rhs[col], inv, p);
708
709 for row in 0..n {
711 if row == col {
712 continue;
713 }
714 let factor = mat[row][col];
715 if factor == 0 {
716 continue;
717 }
718 let pivot_row_vals: Vec<u64> = mat[col][col..].to_vec();
720 for (j, &pv) in pivot_row_vals.iter().enumerate() {
721 let sub = mul_mod(factor, pv, p);
722 mat[row][col + j] = sub_mod(mat[row][col + j], sub, p);
723 }
724 let sub = mul_mod(factor, rhs[col], p);
725 rhs[row] = sub_mod(rhs[row], sub, p);
726 }
727 }
728 Some(rhs.to_owned())
729}
730
731fn bt_univariate(
738 eval: &dyn Fn(u64) -> u64,
739 term_bound: usize,
740 prime: u64,
741 g: u64, rng: &mut Xorshift64,
743) -> Result<Vec<(u64, u32)>, SparseInterpError> {
744 if term_bound == 0 {
745 return Ok(vec![]);
746 }
747 let two_t = 2 * term_bound;
748
749 let mut seq = Vec::with_capacity(two_t);
751 let mut gj = 1u64; for _ in 0..two_t {
753 seq.push(eval(gj));
754 gj = mul_mod(gj, g, prime);
755 }
756
757 let lambda = berlekamp_massey(&seq, prime);
759 let ell = lambda.len() - 1; if ell == 0 {
762 return Ok(vec![]);
764 }
765
766 let rho_roots = find_roots(&lambda, prime, rng)?;
768
769 if rho_roots.len() < ell {
770 return Err(SparseInterpError::RootFindingFailed);
771 }
772 let rho: &[u64] = &rho_roots[..ell];
774
775 let mut exps: Vec<u32> = Vec::with_capacity(ell);
779 for &ro in rho {
780 if ro == 0 {
781 return Err(SparseInterpError::RootFindingFailed);
782 }
783 let r = mod_inv(ro, prime); let e = bsgs_dlog(g, r, prime).ok_or(SparseInterpError::RootFindingFailed)?;
785 exps.push(e as u32);
786 }
787
788 let pts_for_vdm: Vec<u64> = (0..ell).map(|i| pow_mod(g, i as u64, prime)).collect();
796 let vals_for_vdm: Vec<u64> = seq[..ell].to_vec();
797 let coeffs = vandermonde_solve(&pts_for_vdm, &exps, &vals_for_vdm, prime)
798 .ok_or(SparseInterpError::SingularSystem)?;
799
800 Ok(coeffs
801 .into_iter()
802 .zip(exps)
803 .filter(|(c, _)| *c != 0)
804 .collect())
805}
806
807fn dense_interpolate(vals: &[u64], prime: u64) -> Vec<(u64, u32)> {
817 let n = vals.len();
818 let pts: Vec<u64> = (1..=n as u64).collect();
820 let mut mat: Vec<Vec<u64>> = (0..n)
822 .map(|i| (0..n).map(|j| pow_mod(pts[i], j as u64, prime)).collect())
823 .collect();
824 let mut rhs = vals.to_vec();
825 match gaussian_elim(&mut mat, &mut rhs, prime) {
826 Some(coeffs) => coeffs
827 .into_iter()
828 .enumerate()
829 .filter(|(_, c)| *c != 0)
830 .map(|(j, c)| (c, j as u32))
831 .collect(),
832 None => vec![], }
834}
835
836fn lifted_eval_union(
842 x_pts: &[u64],
843 joint_exps: &[u32],
844 eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
845 prime: u64,
846 dim: usize,
847 m_count: usize,
848 x_suffix: &[u64],
849) -> Vec<u64> {
850 let mut new_vec = Vec::with_capacity(dim * m_count);
851 for j in 0..dim {
852 let f_vals: Vec<u64> = x_pts
853 .iter()
854 .map(|&xk| {
855 let mut args = vec![xk];
856 args.extend_from_slice(x_suffix);
857 eval_multi(&args).get(j).copied().unwrap_or(0)
858 })
859 .collect();
860 let coeffs = vandermonde_solve(x_pts, joint_exps, &f_vals, prime)
861 .unwrap_or_else(|| vec![0u64; m_count]);
862 debug_assert_eq!(coeffs.len(), m_count);
863 new_vec.extend(coeffs);
864 }
865 new_vec
866}
867
/// Vector-valued counterpart of `zippel_helper`: interpolates `dim` polynomials
/// in `n_vars` variables at once from an oracle returning one value per component.
///
/// Returns one exponent-vector → coefficient map per component.
///
/// Outline of the cases handled below:
/// * `dim == 0` — nothing to do;
/// * `n_vars == 0` — a single oracle call supplies the constants;
/// * `n_vars == 1` — each component is interpolated on its own (dense when
///   `degree_bound <= term_bound`, otherwise via `bt_univariate`);
/// * otherwise — the trailing variables are fixed at random non-zero values to
///   discover the first-variable exponents of each component, the exponents
///   are merged, and the routine recurses on the remaining variables, either
///   with a widened oracle (`dim * m_count` components) or, when that product
///   exceeds the budget, component by component through `zippel_helper`.
///
/// # Errors
/// Propagates errors from `bt_univariate`, `zippel_helper`, and the recursive call.
#[allow(clippy::too_many_arguments)]
fn zippel_helper_multi(
    eval_multi: &dyn Fn(&[u64]) -> Vec<u64>,
    n_vars: usize,
    dim: usize,
    term_bound: usize,
    degree_bound: u32,
    prime: u64,
    g: u64,
    rng: &mut Xorshift64,
) -> Result<Vec<BTreeMap<Vec<u32>, u64>>, SparseInterpError> {
    if dim == 0 {
        return Ok(vec![]);
    }

    if n_vars == 0 {
        // No variables left: each component is a constant.
        let v = eval_multi(&[]);
        let mut out = Vec::with_capacity(dim);
        for j in 0..dim {
            let mut m = BTreeMap::new();
            let c = *v.get(j).unwrap_or(&0);
            if c != 0 {
                m.insert(vec![], c);
            }
            out.push(m);
        }
        return Ok(out);
    }

    if n_vars == 1 {
        // Univariate base case, handled per component.
        let mut out = Vec::with_capacity(dim);
        for j in 0..dim {
            let terms = if degree_bound <= term_bound as u32 {
                // Dense interpolation through x = 1..=degree_bound + 1.
                let d = degree_bound as usize + 1;
                let vals: Vec<u64> = (1..=d as u64)
                    .map(|x| eval_multi(&[x % prime]).get(j).copied().unwrap_or(0))
                    .collect();
                dense_interpolate(&vals, prime)
            } else {
                bt_univariate(
                    &|t| eval_multi(&[t]).get(j).copied().unwrap_or(0),
                    term_bound,
                    prime,
                    g,
                    rng,
                )?
            };
            let mut m = BTreeMap::new();
            for (c, e) in terms {
                if c != 0 {
                    m.insert(vec![e], c);
                }
            }
            out.push(m);
        }
        return Ok(out);
    }

    // Random non-zero values for every variable except the first.
    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();

    // Per component: the (coefficient, exponent) skeleton in the first variable.
    let mut per_comp_skeletons: Vec<Vec<(u64, u32)>> = Vec::with_capacity(dim);
    for j in 0..dim {
        let sk = {
            let f1 = |t: u64| -> u64 {
                let mut args = vec![t];
                args.extend_from_slice(&a_rest);
                eval_multi(&args).get(j).copied().unwrap_or(0)
            };
            if degree_bound <= term_bound as u32 {
                let d = degree_bound as usize + 1;
                let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
                dense_interpolate(&v, prime)
            } else {
                bt_univariate(&f1, term_bound, prime, g, rng)?
            }
        };
        per_comp_skeletons.push(sk);
    }

    // Sorted union of the first-variable exponents seen in any component.
    let mut joint_exps: Vec<u32> = Vec::new();
    for sk in &per_comp_skeletons {
        for &(_, e) in sk {
            joint_exps.push(e);
        }
    }
    joint_exps.sort_unstable();
    joint_exps.dedup();
    let m_count = joint_exps.len();

    let empty_maps = || (0..dim).map(|_| BTreeMap::new()).collect::<Vec<_>>();

    if m_count == 0 {
        return Ok(empty_maps());
    }

    // Cap on the number of components the widened recursion may carry.
    let vec_budget = term_bound.saturating_mul(512).clamp(8192usize, 131072usize);
    if dim.saturating_mul(m_count) > vec_budget {
        // Over budget: interpolate each component separately, one
        // first-variable exponent at a time, through the scalar helper.
        let mut stacked: Vec<BTreeMap<Vec<u32>, u64>> = Vec::with_capacity(dim);
        for (j, sk) in per_comp_skeletons.iter().enumerate().take(dim) {
            if sk.is_empty() {
                stacked.push(BTreeMap::new());
                continue;
            }
            let exps_j: Vec<u32> = sk.iter().map(|(_, e)| *e).collect();
            let tj = exps_j.len();
            // Distinct random non-zero evaluation points, one per exponent.
            let mut pts: Vec<u64> = Vec::with_capacity(tj);
            {
                let mut used = std::collections::HashSet::new();
                while pts.len() < tj {
                    let v = rng.nonzero(prime);
                    if used.insert(v) {
                        pts.push(v);
                    }
                }
            }
            let mut comp_map = BTreeMap::new();
            for k in 0..tj {
                let e_cur = exps_j[k];
                // Oracle for the k-th first-variable coefficient as a function
                // of the remaining variables (0 if the system is singular).
                let sub_terms = zippel_helper(
                    &|x_rest: &[u64]| -> u64 {
                        let f_vals: Vec<u64> = pts
                            .iter()
                            .map(|&xk| {
                                let mut args = vec![xk];
                                args.extend_from_slice(x_rest);
                                eval_multi(&args).get(j).copied().unwrap_or(0)
                            })
                            .collect();
                        vandermonde_solve(&pts, &exps_j, &f_vals, prime)
                            .map(|v| v[k])
                            .unwrap_or(0)
                    },
                    n_vars - 1,
                    term_bound,
                    degree_bound,
                    prime,
                    g,
                    rng,
                )?;
                for (mut sub_exp, coeff) in sub_terms {
                    if coeff != 0 {
                        // Prepend the first-variable exponent.
                        let mut full = vec![e_cur];
                        full.append(&mut sub_exp);
                        comp_map.insert(full, coeff);
                    }
                }
            }
            stacked.push(comp_map);
        }
        return Ok(stacked);
    }

    // Distinct random non-zero evaluation points, one per joint exponent.
    let mut x_pts: Vec<u64> = Vec::with_capacity(m_count);
    {
        let mut used = std::collections::HashSet::new();
        while x_pts.len() < m_count {
            let v = rng.nonzero(prime);
            if used.insert(v) {
                x_pts.push(v);
            }
        }
    }

    // Recurse with every (component, joint exponent) pair as its own component.
    let dim_next = dim * m_count;
    let sub = zippel_helper_multi(
        &|x_suffix: &[u64]| {
            lifted_eval_union(
                &x_pts,
                &joint_exps,
                eval_multi,
                prime,
                dim,
                m_count,
                x_suffix,
            )
        },
        n_vars - 1,
        dim_next,
        term_bound,
        degree_bound,
        prime,
        g,
        rng,
    )?;

    // Reassemble: slot j * m_count + r holds component j, joint exponent r.
    let mut result: Vec<BTreeMap<Vec<u32>, u64>> = empty_maps();
    for (j, res_j) in result.iter_mut().enumerate().take(dim) {
        for (r, &e1) in joint_exps.iter().enumerate().take(m_count) {
            let slot = j * m_count + r;
            for (sub_exp, coeff) in &sub[slot] {
                if *coeff != 0 {
                    let mut full_exp = vec![e1];
                    full_exp.extend_from_slice(sub_exp);
                    res_j.insert(full_exp, *coeff);
                }
            }
        }
    }

    Ok(result)
}
1077
/// Recursive sparse interpolation of a single polynomial in `n_vars` variables
/// from a black-box oracle over F_`prime`.
///
/// Returns a map from exponent vectors (one entry per variable) to non-zero
/// coefficients.
///
/// * `n_vars == 0` — one oracle call gives the constant.
/// * `n_vars == 1` — dense interpolation when `degree_bound <= term_bound`,
///   otherwise `bt_univariate`.
/// * otherwise — the trailing variables are fixed at random non-zero values to
///   find the exponents of the first variable; the coefficient of each such
///   exponent is then interpolated in the remaining variables via
///   `zippel_helper_multi`.
///
/// # Errors
/// Propagates errors from `bt_univariate` and `zippel_helper_multi`.
fn zippel_helper(
    eval: &dyn Fn(&[u64]) -> u64,
    n_vars: usize,
    term_bound: usize,
    degree_bound: u32,
    prime: u64,
    g: u64,
    rng: &mut Xorshift64,
) -> Result<BTreeMap<Vec<u32>, u64>, SparseInterpError> {
    if n_vars == 0 {
        // Constant polynomial.
        let c = eval(&[]);
        let mut m = BTreeMap::new();
        if c != 0 {
            m.insert(vec![], c);
        }
        return Ok(m);
    }

    if n_vars == 1 {
        let terms = if degree_bound <= term_bound as u32 {
            // Dense interpolation through x = 1..=degree_bound + 1.
            let d = degree_bound as usize + 1;
            let v: Vec<u64> = (1..=d as u64).map(|x| eval(&[x % prime])).collect();
            dense_interpolate(&v, prime)
        } else {
            bt_univariate(&|t| eval(&[t]), term_bound, prime, g, rng)?
        };
        let mut m = BTreeMap::new();
        for (c, e) in terms {
            m.insert(vec![e], c);
        }
        return Ok(m);
    }

    // Random non-zero values for every variable except the first.
    let a_rest: Vec<u64> = (0..n_vars - 1).map(|_| rng.nonzero(prime)).collect();

    // (coefficient, exponent) terms in the first variable at that point.
    let skeleton: Vec<(u64, u32)> = {
        let f1 = |t: u64| -> u64 {
            let mut args = vec![t];
            args.extend_from_slice(&a_rest);
            eval(&args)
        };
        if degree_bound <= term_bound as u32 {
            let d = degree_bound as usize + 1;
            let v: Vec<u64> = (1..=d as u64).map(|x| f1(x % prime)).collect();
            dense_interpolate(&v, prime)
        } else {
            bt_univariate(&f1, term_bound, prime, g, rng)?
        }
    };

    if skeleton.is_empty() {
        return Ok(BTreeMap::new());
    }

    let x1_exps: Vec<u32> = skeleton.iter().map(|(_, e)| *e).collect();
    let t = x1_exps.len();

    // Distinct random non-zero evaluation points, one per exponent.
    let mut x1_pts: Vec<u64> = Vec::with_capacity(t);
    {
        let mut used = std::collections::HashSet::new();
        while x1_pts.len() < t {
            let v = rng.nonzero(prime);
            if used.insert(v) {
                x1_pts.push(v);
            }
        }
    }

    // For given values of the remaining variables, returns the `t` coefficients
    // attached to `x1_exps` (all zeros if the system is singular).
    let eval_multi = |x_rest: &[u64]| -> Vec<u64> {
        let mut f_vals: Vec<u64> = Vec::with_capacity(t);
        for &xk in &x1_pts {
            let mut args = vec![xk];
            args.extend_from_slice(x_rest);
            f_vals.push(eval(&args));
        }
        vandermonde_solve(&x1_pts, &x1_exps, &f_vals, prime).unwrap_or_else(|| vec![0u64; t])
    };

    let sub_maps = zippel_helper_multi(
        &eval_multi,
        n_vars - 1,
        t,
        term_bound,
        degree_bound,
        prime,
        g,
        rng,
    )?;

    // Prepend each first-variable exponent to the terms of its coefficient.
    let mut result: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
    for j in 0..t {
        let e1 = x1_exps[j];
        for (sub_exp, coeff) in &sub_maps[j] {
            if *coeff != 0 {
                let mut full_exp = vec![e1];
                full_exp.extend_from_slice(sub_exp);
                result.insert(full_exp, *coeff);
            }
        }
    }

    Ok(result)
}
1192
1193pub fn sparse_interpolate_univariate(
1231 eval: &dyn Fn(u64) -> u64,
1232 term_bound: usize,
1233 prime: u64,
1234) -> Result<Vec<(u64, u32)>, SparseInterpError> {
1235 if !is_prime(prime) {
1236 return Err(SparseInterpError::InvalidPrime(prime));
1237 }
1238 if prime <= 2 * term_bound as u64 {
1239 return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1240 }
1241 let g = primitive_root(prime);
1242 let mut rng = Xorshift64::new(prime.wrapping_mul(0x5851_f42d_4c95_7f2d));
1243 bt_univariate(eval, term_bound, prime, g, &mut rng)
1244}
1245
1246pub fn sparse_interpolate(
1273 eval: &dyn Fn(&[u64]) -> u64,
1274 vars: Vec<ExprId>,
1275 term_bound: usize,
1276 degree_bound: u32,
1277 prime: u64,
1278 seed: u64,
1279) -> Result<MultiPolyFp, SparseInterpError> {
1280 if !is_prime(prime) {
1281 return Err(SparseInterpError::InvalidPrime(prime));
1282 }
1283 if prime <= 2 * term_bound as u64 {
1284 return Err(SparseInterpError::PrimeTooSmall { prime, term_bound });
1285 }
1286
1287 let n_vars = vars.len();
1288 let g = primitive_root(prime);
1289 let mut rng = Xorshift64::new(seed);
1290
1291 let terms = zippel_helper(eval, n_vars, term_bound, degree_bound, prime, g, &mut rng)?;
1292
1293 let trimmed_terms: BTreeMap<Vec<u32>, u64> = terms
1294 .into_iter()
1295 .map(|(mut exp, c)| {
1296 while exp.last() == Some(&0) {
1298 exp.pop();
1299 }
1300 (exp, c)
1301 })
1302 .filter(|(_, c)| *c != 0)
1303 .collect();
1304
1305 Ok(MultiPolyFp {
1306 vars,
1307 modulus: prime,
1308 terms: trimmed_terms,
1309 })
1310}
1311
/// Failure modes of the sparse modular GCD routine.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseGcdError {
    /// The two inputs do not share the same variable list.
    IncompatiblePolynomials,
    /// A modular image could not be computed; wraps the interpolation error.
    InterpFailed(SparseInterpError),
    /// Combining the modular images failed; wraps the error from the CRT lift.
    CrtFailed(crate::modular::ModularError),
}
1326
1327impl std::fmt::Display for SparseGcdError {
1328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1329 match self {
1330 SparseGcdError::IncompatiblePolynomials => {
1331 write!(f, "polynomials have incompatible variable lists")
1332 }
1333 SparseGcdError::InterpFailed(e) => write!(f, "interpolation step failed: {e}"),
1334 SparseGcdError::CrtFailed(e) => write!(f, "CRT lifting failed: {e}"),
1335 }
1336 }
1337}
1338
1339impl std::error::Error for SparseGcdError {
1340 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
1341 match self {
1342 SparseGcdError::InterpFailed(e) => Some(e),
1343 SparseGcdError::CrtFailed(e) => Some(e),
1344 _ => None,
1345 }
1346 }
1347}
1348
1349impl AlkahestError for SparseGcdError {
1350 fn code(&self) -> &'static str {
1351 match self {
1352 SparseGcdError::IncompatiblePolynomials => "E-INTERP-010",
1353 SparseGcdError::InterpFailed(_) => "E-INTERP-011",
1354 SparseGcdError::CrtFailed(_) => "E-INTERP-012",
1355 }
1356 }
1357
1358 fn remediation(&self) -> Option<&'static str> {
1359 match self {
1360 SparseGcdError::IncompatiblePolynomials => {
1361 Some("ensure both polynomials share the same variable list in the same order")
1362 }
1363 SparseGcdError::InterpFailed(_) => {
1364 Some("retry with a larger term_bound, degree_bound, or a different seed")
1365 }
1366 SparseGcdError::CrtFailed(_) => {
1367 Some("provide more primes or use a larger prime product threshold")
1368 }
1369 }
1370 }
1371}
1372
1373fn specialize_except_first(fp: &MultiPolyFp, vals: &[u64]) -> Vec<u64> {
1376 let p = fp.modulus;
1377 let max_x1 = fp
1378 .terms
1379 .keys()
1380 .map(|e| e.first().copied().unwrap_or(0))
1381 .max()
1382 .unwrap_or(0) as usize;
1383 let mut result = vec![0u64; max_x1 + 1];
1384 for (exp, &coeff) in &fp.terms {
1385 let k = exp.first().copied().unwrap_or(0) as usize;
1386 let mut factor = coeff;
1387 for (i, &e) in exp.iter().skip(1).enumerate() {
1388 if e > 0 {
1389 let ai = *vals.get(i).unwrap_or(&0);
1390 factor = mul_mod(factor, pow_mod(ai, e as u64, p), p);
1391 }
1392 }
1393 result[k] = add_mod(result[k], factor, p);
1394 }
1395 poly_trim(result)
1396}
1397
/// Computes an image of the GCD of `f_p` and `g_p` modulo `prime`.
///
/// The GCD is treated as a polynomial in the first variable. A probe at random
/// non-zero values of the other variables fixes its degree in the first
/// variable; each coefficient up to that degree is then recovered by
/// `sparse_interpolate` over `sub_vars`, using an oracle that specialises both
/// inputs and takes their univariate GCD. With no `sub_vars`, the probe GCD is
/// the answer.
///
/// NOTE(review): `polygcd` returns a monic result, so every oracle value comes
/// from a GCD normalised to leading coefficient 1 — confirm this is the
/// intended normalisation when the true GCD's leading coefficient depends on
/// the other variables.
///
/// # Errors
/// Propagates any error from `sparse_interpolate`.
fn gcd_sparse_mod_p(
    f_p: &MultiPolyFp,
    g_p: &MultiPolyFp,
    sub_vars: Vec<ExprId>,
    term_bound: usize,
    degree_bound: u32,
    prime: u64,
    seed: u64,
) -> Result<MultiPolyFp, SparseInterpError> {
    let p = prime;
    let vars_full = f_p.vars.clone();

    let n_sub = sub_vars.len();
    // Probe at a random point to learn the GCD's degree in the first variable.
    let mut rng = Xorshift64::new(seed ^ p.wrapping_mul(0x9e37_79b9_7f4a_7c15));
    let probe_vals: Vec<u64> = (0..n_sub).map(|_| rng.nonzero(p)).collect();
    let f1 = specialize_except_first(f_p, &probe_vals);
    let g1 = specialize_except_first(g_p, &probe_vals);
    let h1 = polygcd(&f1, &g1, p);
    let gcd_deg_x1 = poly_deg(&h1).max(0) as usize;

    let mut h_terms: BTreeMap<Vec<u32>, u64> = BTreeMap::new();

    if sub_vars.is_empty() {
        // Univariate case: copy the probe GCD's non-zero coefficients.
        for (k, &c) in h1.iter().enumerate() {
            if c != 0 {
                let mut exp = vec![k as u32];
                // The constant term is keyed by the empty exponent vector.
                while exp.last() == Some(&0) {
                    exp.pop();
                }
                h_terms.insert(exp, c);
            }
        }
    } else {
        for k in 0..=gcd_deg_x1 {
            // Oracle: coefficient of x1^k in the GCD of the specialised inputs.
            let oracle = |vals: &[u64]| -> u64 {
                let fa = specialize_except_first(f_p, vals);
                let ga = specialize_except_first(g_p, vals);
                let hk = polygcd(&fa, &ga, p);
                hk.get(k).copied().unwrap_or(0)
            };
            let ck = sparse_interpolate(
                &oracle,
                sub_vars.clone(),
                term_bound,
                degree_bound,
                p,
                seed.wrapping_add(k as u64 + 1),
            )?;
            for (sub_exp, &c) in &ck.terms {
                if c == 0 {
                    continue;
                }
                // Prepend the first-variable exponent, then strip trailing zeros.
                let mut full_exp = vec![k as u32];
                full_exp.extend_from_slice(sub_exp);
                while full_exp.last() == Some(&0) {
                    full_exp.pop();
                }
                h_terms.insert(full_exp, c);
            }
        }
    }

    Ok(MultiPolyFp {
        vars: vars_full,
        modulus: prime,
        terms: h_terms,
    })
}
1471
1472pub fn gcd_sparse_modular(
1504 f: &super::multipoly::MultiPoly,
1505 g: &super::multipoly::MultiPoly,
1506 term_bound: usize,
1507 degree_bound: u32,
1508 seed: u64,
1509) -> Result<super::multipoly::MultiPoly, SparseGcdError> {
1510 use crate::modular::{lift_crt, mignotte_bound, reduce_mod};
1511 use rug::Integer;
1512
1513 if f.vars != g.vars {
1514 return Err(SparseGcdError::IncompatiblePolynomials);
1515 }
1516 if f.is_zero() {
1517 return Ok(g.clone());
1518 }
1519 if g.is_zero() {
1520 return Ok(f.clone());
1521 }
1522
1523 let vars = f.vars.clone();
1524 let sub_vars = if vars.len() > 1 {
1525 vars[1..].to_vec()
1526 } else {
1527 vec![]
1528 };
1529
1530 let b_f = mignotte_bound(f);
1531 let b_g = mignotte_bound(g);
1532 let bound = b_f.min(b_g);
1533 let two_bound = bound.clone() << 1u32;
1534
1535 let min_p = ((2 * term_bound + 2) as u64).max(degree_bound as u64 + 2);
1537
1538 let content = f.integer_content() * g.integer_content();
1540
1541 let mut images: Vec<(MultiPolyFp, u64)> = Vec::new();
1542 let mut used: Vec<u64> = Vec::new();
1543 let mut m = Integer::from(1u64);
1544 let mut candidate = min_p.max(3);
1545
1546 while m <= two_bound {
1547 loop {
1549 if is_prime(candidate) && !used.contains(&candidate) {
1550 if content == 0 {
1551 break;
1552 }
1553 let p_int = Integer::from(candidate);
1554 let r = content.clone() % p_int.clone();
1555 let r = if r < 0 { r + p_int } else { r };
1556 if r != 0 {
1557 break;
1558 }
1559 }
1560 candidate += 1;
1561 if candidate > 1_000_003 {
1562 break;
1563 }
1564 }
1565 let p = candidate;
1566 candidate += 1;
1567
1568 let f_p = match reduce_mod(f, p) {
1569 Ok(x) if !x.is_zero() => x,
1570 _ => continue,
1571 };
1572 let g_p = match reduce_mod(g, p) {
1573 Ok(x) if !x.is_zero() => x,
1574 _ => continue,
1575 };
1576
1577 used.push(p);
1578
1579 let h_p = gcd_sparse_mod_p(
1580 &f_p,
1581 &g_p,
1582 sub_vars.clone(),
1583 term_bound,
1584 degree_bound,
1585 p,
1586 seed.wrapping_add(p),
1587 )
1588 .map_err(SparseGcdError::InterpFailed)?;
1589
1590 images.push((h_p, p));
1591 m *= Integer::from(p);
1592 }
1593
1594 let mut result = lift_crt(&images).map_err(SparseGcdError::CrtFailed)?;
1595
1596 if let Some((_, lc)) = result.terms.iter().next_back() {
1598 if lc.cmp0() == std::cmp::Ordering::Less {
1599 result = -result;
1600 }
1601 }
1602 Ok(result.primitive_part())
1603}
1604
1605#[cfg(test)]
1610mod tests {
1611 use super::*;
1612 use crate::kernel::{Domain, ExprPool};
1613
1614 fn make_poly_eval(coeffs: &[(u64, Vec<u32>)], prime: u64) -> impl Fn(&[u64]) -> u64 + '_ {
1617 move |pt: &[u64]| -> u64 {
1618 let mut acc = 0u64;
1619 for (c, exp) in coeffs {
1620 let mut term = *c % prime;
1621 for (i, &e) in exp.iter().enumerate() {
1622 let xi = if i < pt.len() { pt[i] } else { 0 };
1623 term = mul_mod(term, pow_mod(xi, e as u64, prime), prime);
1624 }
1625 acc = add_mod(acc, term, prime);
1626 }
1627 acc
1628 }
1629 }
1630
1631 fn vars(n: usize) -> (ExprPool, Vec<ExprId>) {
1632 let pool = ExprPool::new();
1633 let vs: Vec<ExprId> = (0..n)
1634 .map(|i| pool.symbol(format!("x{i}"), Domain::Real))
1635 .collect();
1636 (pool, vs)
1637 }
1638
1639 #[test]
1642 fn prim_root_small_primes() {
1643 for p in [2u64, 3, 5, 7, 11, 13, 17, 19, 23] {
1644 let g = primitive_root(p);
1645 assert_eq!(pow_mod(g, p - 1, p), 1, "g^(p-1)=1 for p={p}");
1647 for q in prime_factors(p - 1) {
1648 assert_ne!(pow_mod(g, (p - 1) / q, p), 1, "g^((p-1)/{q}) ≠ 1 for p={p}");
1649 }
1650 }
1651 }
1652
1653 #[test]
1656 fn bm_geometric_sequence() {
1657 let p = 7u64;
1660 let seq: Vec<u64> = (0..6).map(|n| pow_mod(2, n, p)).collect();
1661 let lambda = berlekamp_massey(&seq, p);
1662 assert_eq!(lambda.len() - 1, 1, "LFSR length should be 1");
1663 let inv2 = mod_inv(2, p);
1665 assert_eq!(poly_eval(&lambda, inv2, p), 0);
1666 }
1667
1668 #[test]
1669 fn bm_two_term_sequence() {
1670 let p = 11u64;
1672 let seq: Vec<u64> = (0..4)
1673 .map(|n| {
1674 add_mod(
1675 mul_mod(3, pow_mod(2, n, p), p),
1676 mul_mod(5, pow_mod(3, n, p), p),
1677 p,
1678 )
1679 })
1680 .collect();
1681 let lambda = berlekamp_massey(&seq, p);
1682 assert_eq!(lambda.len() - 1, 2, "two-term sequence has LFSR length 2");
1683 let mut rng = Xorshift64::new(0xbeef);
1685 let roots = find_roots(&lambda, p, &mut rng).unwrap();
1686 assert_eq!(roots.len(), 2);
1687 let expected: std::collections::HashSet<u64> =
1688 [mod_inv(2, p), mod_inv(3, p)].into_iter().collect();
1689 let got: std::collections::HashSet<u64> = roots.into_iter().collect();
1690 assert_eq!(got, expected);
1691 }
1692
1693 #[test]
1696 fn dlog_basic() {
1697 let p = 13u64;
1698 let g = primitive_root(p);
1699 for e in 0..p - 1 {
1700 let target = pow_mod(g, e, p);
1701 let found = bsgs_dlog(g, target, p).expect("dlog should succeed");
1702 assert_eq!(
1703 pow_mod(g, found, p),
1704 target,
1705 "g^{found} ≠ {target} for p={p}"
1706 );
1707 }
1708 }
1709
1710 #[test]
1713 fn uni_zero_polynomial() {
1714 let terms = sparse_interpolate_univariate(&|_| 0, 5, 101).unwrap();
1715 assert!(terms.is_empty());
1716 }
1717
1718 #[test]
1719 fn uni_constant() {
1720 let terms = sparse_interpolate_univariate(&|_| 7, 3, 101).unwrap();
1722 assert_eq!(terms.len(), 1);
1723 let (c, e) = terms[0];
1724 assert_eq!(c, 7);
1725 assert_eq!(e, 0);
1726 }
1727
1728 #[test]
1729 fn uni_single_monomial() {
1730 let p = 101u64;
1732 let eval = |x: u64| mul_mod(3, pow_mod(x, 5, p), p);
1733 let terms = sparse_interpolate_univariate(&eval, 3, p).unwrap();
1734 assert_eq!(terms.len(), 1);
1735 let (c, e) = terms[0];
1736 assert_eq!(c, 3);
1737 assert_eq!(e, 5);
1738 }
1739
1740 #[test]
1741 fn uni_two_terms() {
1742 let p = 101u64;
1744 let eval = |x: u64| {
1745 let a = pow_mod(x, 10, p);
1746 let b = mul_mod(2, pow_mod(x, 3, p), p);
1747 add_mod(a, b, p)
1748 };
1749 let terms = sparse_interpolate_univariate(&eval, 3, p).unwrap();
1750 assert_eq!(terms.len(), 2);
1751 let mut sorted = terms.clone();
1752 sorted.sort_by_key(|&(_, e)| e);
1753 assert_eq!(sorted[0], (2, 3));
1754 assert_eq!(sorted[1], (1, 10));
1755 }
1756
1757 #[test]
1758 fn uni_roadmap_example() {
1759 let p = 997u64;
1762 let eval = |x: u64| {
1763 let a = pow_mod(x, 100, p);
1764 let b = mul_mod(3, pow_mod(x, 17, p), p);
1765 let c = 5u64;
1766 add_mod(add_mod(a, b, p), c, p)
1767 };
1768 let terms = sparse_interpolate_univariate(&eval, 4, p).unwrap();
1769 let mut sorted = terms.clone();
1770 sorted.sort_by_key(|&(_, e)| e);
1771 assert!(
1773 sorted.iter().any(|&(c, e)| c == 5 && e == 0),
1774 "missing constant 5"
1775 );
1776 assert!(
1777 sorted.iter().any(|&(c, e)| c == 3 && e == 17),
1778 "missing 3·x^17"
1779 );
1780 assert!(
1781 sorted.iter().any(|&(c, e)| c == 1 && e == 100),
1782 "missing x^100"
1783 );
1784 }
1785
1786 #[test]
1787 fn uni_error_invalid_prime() {
1788 let err = sparse_interpolate_univariate(&|_| 0, 3, 4);
1789 assert!(matches!(err, Err(SparseInterpError::InvalidPrime(4))));
1790 }
1791
1792 #[test]
1793 fn uni_error_prime_too_small() {
1794 let err = sparse_interpolate_univariate(&|_| 0, 10, 19);
1796 assert!(matches!(
1797 err,
1798 Err(SparseInterpError::PrimeTooSmall {
1799 prime: 19,
1800 term_bound: 10
1801 })
1802 ));
1803 }
1804
1805 #[test]
1808 fn multi_constant() {
1809 let (_, vs) = vars(2);
1810 let result = sparse_interpolate(&|_| 42, vs, 3, 10, 101, 0).unwrap();
1811 assert_eq!(result.terms.len(), 1);
1812 assert_eq!(*result.terms.get(&vec![]).unwrap(), 42u64);
1813 }
1814
1815 #[test]
1816 fn multi_univariate_via_multi() {
1817 let p = 101u64;
1819 let (_, vs) = vars(1);
1820 let eval = |pt: &[u64]| {
1821 let x = pt[0];
1822 add_mod(add_mod(pow_mod(x, 2, p), mul_mod(3, x, p), p), 1, p)
1823 };
1824 let result = sparse_interpolate(&eval, vs, 5, 10, p, 0).unwrap();
1825 assert_eq!(*result.terms.get(&vec![2]).unwrap(), 1u64, "x^2 coeff");
1827 assert_eq!(*result.terms.get(&vec![1]).unwrap(), 3u64, "x^1 coeff");
1828 assert_eq!(*result.terms.get(&vec![]).unwrap_or(&0), 1u64, "x^0 coeff");
1829 }
1830
1831 #[test]
1832 fn multi_bivariate_xy() {
1833 let p = 101u64;
1835 let (_, vs) = vars(2);
1836 let eval = |pt: &[u64]| add_mod(mul_mod(pt[0], pt[1], p), 3, p);
1837 let result = sparse_interpolate(&eval, vs, 4, 5, p, 1).unwrap();
1838 assert_eq!(
1840 *result.terms.get(&vec![1, 1]).unwrap_or(&0),
1841 1u64,
1842 "x*y coeff"
1843 );
1844 assert_eq!(*result.terms.get(&vec![]).unwrap_or(&0), 3u64, "constant");
1845 }
1846
1847 #[test]
1848 fn multi_bivariate_x_squared_y() {
1849 let p = 101u64;
1851 let (_, vs) = vars(2);
1852 let eval = |pt: &[u64]| {
1853 let x = pt[0];
1854 let y = pt[1];
1855 let a = mul_mod(pow_mod(x, 2, p), y, p);
1856 let b = mul_mod(5, y, p);
1857 let c = mul_mod(2, x, p);
1858 add_mod(add_mod(a, b, p), c, p)
1859 };
1860 let result = sparse_interpolate(&eval, vs, 5, 6, p, 42).unwrap();
1861 assert_eq!(*result.terms.get(&vec![2, 1]).unwrap_or(&0), 1, "x^2*y");
1862 assert_eq!(*result.terms.get(&vec![0, 1]).unwrap_or(&0), 5, "5*y");
1863 assert_eq!(*result.terms.get(&vec![1]).unwrap_or(&0), 2, "2*x");
1864 }
1865
1866 #[test]
1867 fn multi_three_variables() {
1868 let p = 1009u64;
1870 let (_, vs) = vars(3);
1871 let eval = |pt: &[u64]| {
1872 let x = pt[0];
1873 let y = pt[1];
1874 let z = pt[2];
1875 let xyz = mul_mod(mul_mod(x, y, p), z, p);
1876 let x2 = pow_mod(x, 2, p);
1877 add_mod(add_mod(xyz, x2, p), z, p)
1878 };
1879 let result = sparse_interpolate(&eval, vs, 5, 4, p, 7).unwrap();
1880 assert_eq!(*result.terms.get(&vec![1, 1, 1]).unwrap_or(&0), 1, "x*y*z");
1881 assert_eq!(*result.terms.get(&vec![2]).unwrap_or(&0), 1, "x^2");
1882 assert_eq!(*result.terms.get(&vec![0, 0, 1]).unwrap_or(&0), 1, "z");
1883 }
1884
1885 #[test]
1886 fn multi_roundtrip_via_multipoly() {
1887 use crate::poly::multipoly::MultiPoly;
1889 let p = 1009u64;
1890 let pool = ExprPool::new();
1891 let x = pool.symbol("x", Domain::Real);
1892 let y = pool.symbol("y", Domain::Real);
1893
1894 let x3 = pool.pow(x, pool.integer(3_i32));
1896 let xy = pool.mul(vec![pool.integer(2_i32), x, y]);
1897 let y2 = pool.mul(vec![pool.integer(-1_i32), pool.pow(y, pool.integer(2_i32))]);
1898 let expr = pool.add(vec![x3, xy, y2, pool.integer(4_i32)]);
1899
1900 let mp = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
1901 let fp_ref = crate::modular::reduce_mod(&mp, p).unwrap();
1902
1903 let vars_for_interp = vec![x, y];
1905 let eval = |pt: &[u64]| {
1906 let mut acc = 0u64;
1907 for (exp, coeff) in &mp.terms {
1908 let c_mod = {
1909 let r = coeff.clone() % rug::Integer::from(p);
1910 let r = if r < 0 { r + rug::Integer::from(p) } else { r };
1911 r.to_u64().unwrap()
1912 };
1913 let mut term = c_mod;
1914 for (i, &e) in exp.iter().enumerate() {
1915 let xi = if i < pt.len() { pt[i] } else { 0 };
1916 term = mul_mod(term, pow_mod(xi, e as u64, p), p);
1917 }
1918 acc = add_mod(acc, term, p);
1919 }
1920 acc
1921 };
1922
1923 let recovered = sparse_interpolate(&eval, vars_for_interp, 6, 5, p, 0).unwrap();
1924
1925 for (exp, &coeff) in &recovered.terms {
1927 let ref_coeff = fp_ref.terms.get(exp).copied().unwrap_or(0);
1928 assert_eq!(coeff, ref_coeff, "mismatch at exp {:?}", exp);
1929 }
1930 for (exp, &ref_coeff) in &fp_ref.terms {
1932 let got = recovered.terms.get(exp).copied().unwrap_or(0);
1933 assert_eq!(got, ref_coeff, "missed term at exp {:?}", exp);
1934 }
1935 }
1936
1937 #[test]
1938 fn multi_diag_15term_three_var_smoke() {
1939 let p = 32749u64;
1941 let n_vars = 3;
1942 let n_terms = n_vars;
1943 let mut terms = Vec::new();
1944 for i in 0..n_terms {
1945 let mut coeff = (((i + 1) as u64) * 7) % p;
1946 if coeff == 0 {
1947 coeff = 1;
1948 }
1949 let mut exp = vec![0u32; n_vars];
1950 exp[i % n_vars] = (i % 3) as u32 + 1;
1951 terms.push((coeff, exp));
1952 }
1953 let eval_fn = make_poly_eval(&terms, p);
1954 let (_, vs) = vars(n_vars);
1955 let mut expected: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
1956 for (c, exp) in &terms {
1957 let mut e = exp.clone();
1958 while e.last() == Some(&0) {
1959 e.pop();
1960 }
1961 let nc = *c % p;
1962 expected
1963 .entry(e)
1964 .and_modify(|v| {
1965 *v = add_mod(*v, nc, p);
1966 })
1967 .or_insert(nc);
1968 }
1969
1970 let mut successes = 0usize;
1971 for seed in [0_u64, 1, 2, 41] {
1972 let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, seed)
1973 .expect("smoke interpolate should succeed");
1974 let mut ok = result.terms.len() == expected.len();
1975 for (exp, &ec) in &expected {
1976 if result.terms.get(exp).copied().unwrap_or(0) != ec {
1977 ok = false;
1978 }
1979 }
1980 if ok {
1981 successes += 1;
1982 }
1983 }
1984 assert!(successes >= 3, "expected ≥ 3 successes on diagonal smoke");
1985 }
1986
1987 #[test]
1988 #[ignore]
1989 fn multi_interp_diag_large_stress_slow() {
1990 let p = 32749u64;
1995 let n_vars = 6;
1996 let n_terms = 15;
1997 let mut terms = Vec::new();
1998 for i in 0..n_terms {
1999 let mut coeff = (((i + 1) as u64) * 7) % p;
2000 if coeff == 0 {
2001 coeff = 1;
2002 }
2003 let mut exp = vec![0u32; n_vars];
2004 exp[i % n_vars] = (i % 3) as u32 + 1;
2005 terms.push((coeff, exp));
2006 }
2007 let eval_fn = make_poly_eval(&terms, p);
2008 let (_, vs) = vars(n_vars);
2009 let mut expected: BTreeMap<Vec<u32>, u64> = BTreeMap::new();
2010 for (c, exp) in &terms {
2011 let mut e = exp.clone();
2012 while e.last() == Some(&0) {
2013 e.pop();
2014 }
2015 let nc = *c % p;
2016 expected
2017 .entry(e)
2018 .and_modify(|v| {
2019 *v = add_mod(*v, nc, p);
2020 })
2021 .or_insert(nc);
2022 }
2023
2024 let result = sparse_interpolate(&eval_fn, vs.clone(), n_terms + 5, 4, p, 7)
2025 .expect("stress interpolate should succeed");
2026 assert_eq!(result.terms.len(), expected.len());
2027 for (exp, &ec) in &expected {
2028 assert_eq!(result.terms.get(exp).copied().unwrap_or(0), ec);
2029 }
2030 }
2031
2032 fn vars_n(n: usize) -> (ExprPool, Vec<crate::kernel::ExprId>) {
2035 let pool = ExprPool::new();
2036 let vs: Vec<_> = (0..n)
2037 .map(|i| pool.symbol(format!("x{i}"), Domain::Real))
2038 .collect();
2039 (pool, vs)
2040 }
2041
2042 fn mp(
2043 expr: crate::kernel::ExprId,
2044 vids: Vec<crate::kernel::ExprId>,
2045 pool: &ExprPool,
2046 ) -> crate::poly::multipoly::MultiPoly {
2047 crate::poly::multipoly::MultiPoly::from_symbolic(expr, vids, pool)
2048 .expect("valid polynomial")
2049 }
2050
2051 #[test]
2052 fn gcd_sparse_univariate_linear_factor() {
2053 let (pool, vs) = vars_n(1);
2055 let x = vs[0];
2056 let neg1 = pool.integer(-1i32);
2057 let neg2 = pool.integer(-2i32);
2058 let _one = pool.integer(1i32);
2059 let f = mp(
2061 pool.add(vec![pool.pow(x, pool.integer(2i32)), neg1]),
2062 vec![x],
2063 &pool,
2064 );
2065 let g = mp(
2067 pool.add(vec![
2068 pool.pow(x, pool.integer(2i32)),
2069 pool.mul(vec![neg1, x]),
2070 neg2,
2071 ]),
2072 vec![x],
2073 &pool,
2074 );
2075 let h = gcd_sparse_modular(&f, &g, 3, 3, 0).expect("gcd should succeed");
2076 assert_eq!(h.terms.len(), 2, "GCD should have 2 terms: {h:?}");
2078 assert_eq!(
2079 h.terms.get(&vec![1u32]).cloned(),
2080 Some(rug::Integer::from(1)),
2081 "leading coeff of x should be 1"
2082 );
2083 let empty: Vec<u32> = vec![];
2084 assert_eq!(
2085 h.terms.get(&empty).cloned(),
2086 Some(rug::Integer::from(1)),
2087 "constant should be 1"
2088 );
2089 }
2090
2091 #[test]
2092 fn gcd_sparse_univariate_coprime() {
2093 let (pool, vs) = vars_n(1);
2095 let x = vs[0];
2096 let f = mp(x, vec![x], &pool);
2097 let g = mp(pool.add(vec![x, pool.integer(1i32)]), vec![x], &pool);
2098 let h = gcd_sparse_modular(&f, &g, 2, 2, 0).expect("gcd should succeed");
2099 let empty: Vec<u32> = vec![];
2101 let constant = h.terms.get(&empty).cloned().unwrap_or_default();
2102 assert_eq!(
2103 constant,
2104 rug::Integer::from(1),
2105 "GCD of coprime polys should be 1, got {h:?}"
2106 );
2107 }
2108
2109 #[test]
2110 fn gcd_sparse_bivariate_common_factor() {
2111 let (pool, vs) = vars_n(2);
2113 let x = vs[0];
2114 let y = vs[1];
2115 let xpy = pool.add(vec![x, y]);
2116 let _xmy = pool.add(vec![x, pool.mul(vec![pool.integer(-1i32), y])]);
2117 let xp1 = pool.add(vec![x, pool.integer(1i32)]);
2118 let f = mp(
2120 pool.add(vec![
2121 pool.pow(x, pool.integer(2i32)),
2122 pool.mul(vec![pool.integer(-1i32), pool.pow(y, pool.integer(2i32))]),
2123 ]),
2124 vec![x, y],
2125 &pool,
2126 );
2127 let g = mp(pool.mul(vec![xpy, xp1]), vec![x, y], &pool);
2129 let h = gcd_sparse_modular(&f, &g, 3, 2, 0).expect("gcd should succeed");
2130 assert_eq!(h.terms.len(), 2, "GCD = x+y should have 2 terms, got {h:?}");
2132 let coeff_x = h.terms.get(&vec![1u32]).cloned();
2135 let coeff_y = h.terms.get(&vec![0u32, 1u32]).cloned();
2136 assert_eq!(
2137 coeff_x,
2138 Some(rug::Integer::from(1)),
2139 "coeff of x should be 1"
2140 );
2141 assert_eq!(
2142 coeff_y,
2143 Some(rug::Integer::from(1)),
2144 "coeff of y should be 1"
2145 );
2146 }
2147}