use crate::errors::AlkahestError;
use crate::kernel::ExprId;
use crate::poly::MultiPoly;
use rug::Integer;
use std::collections::BTreeMap;
/// Sparse multivariate polynomial with `u64` coefficients and an associated modulus.
///
/// `reduce_mod` builds values of this type with coefficients reduced into
/// `[0, modulus)` and with zero coefficients omitted; the type itself does not
/// enforce either property because all fields are public.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct MultiPolyFp {
/// Variable identifiers of the polynomial.
/// NOTE(review): exponent vectors are presumably indexed in this order — confirm against `MultiPoly`.
pub vars: Vec<ExprId>,
/// Modulus the coefficients are taken against.
pub modulus: u64,
/// Map from exponent vector to coefficient; an empty map is the zero polynomial.
pub terms: BTreeMap<Vec<u32>, u64>,
}
impl MultiPolyFp {
pub fn zero(vars: Vec<ExprId>, modulus: u64) -> Self {
MultiPolyFp {
vars,
modulus,
terms: BTreeMap::new(),
}
}
pub fn is_zero(&self) -> bool {
self.terms.is_empty()
}
pub fn total_degree(&self) -> u32 {
self.terms
.keys()
.map(|e| e.iter().sum::<u32>())
.max()
.unwrap_or(0)
}
pub fn compatible_with(&self, other: &Self) -> bool {
self.vars == other.vars && self.modulus == other.modulus
}
}
impl std::fmt::Display for MultiPolyFp {
    /// Renders the terms in map order as `c*x0^e0*x1^e1 + ... (mod p)`.
    ///
    /// Variables are printed positionally as `x0`, `x1`, ...; zero exponents are
    /// skipped and an exponent of one is printed without `^1`. The zero
    /// polynomial is rendered as `0 (mod p)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.is_zero() {
            return write!(f, "0 (mod {})", self.modulus);
        }
        for (position, (exponents, coeff)) in self.terms.iter().enumerate() {
            // Separator goes between terms, never before the first one.
            if position > 0 {
                f.write_str(" + ")?;
            }
            write!(f, "{coeff}")?;
            for (var, &power) in exponents.iter().enumerate() {
                match power {
                    0 => {}
                    1 => write!(f, "*x{var}")?,
                    _ => write!(f, "*x{var}^{power}")?,
                }
            }
        }
        write!(f, " (mod {})", self.modulus)
    }
}
/// A residue paired with the modulus it is taken against.
///
/// `ModularValue::new` debug-asserts `value < modulus`; because the fields are
/// public, that range is not otherwise enforced.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModularValue {
/// The residue.
pub value: u64,
/// The modulus.
pub modulus: u64,
}
impl ModularValue {
    /// Wraps `value` as a residue modulo `modulus`.
    ///
    /// Debug builds assert `value < modulus`; release builds perform no check.
    pub fn new(value: u64, modulus: u64) -> Self {
        debug_assert!(
            value < modulus,
            "ModularValue: value must be in [0, modulus)"
        );
        ModularValue { value, modulus }
    }

    /// The additive identity modulo `modulus`.
    pub fn zero(modulus: u64) -> Self {
        ModularValue { value: 0, modulus }
    }

    /// The multiplicative identity modulo `modulus` (which is `0` when `modulus <= 1`).
    pub fn one(modulus: u64) -> Self {
        ModularValue {
            value: if modulus > 1 { 1 } else { 0 },
            modulus,
        }
    }

    /// Returns `self + other` reduced modulo the shared modulus.
    ///
    /// Debug builds assert that both operands use the same modulus.
    pub fn add(&self, other: &Self) -> Self {
        debug_assert_eq!(
            self.modulus, other.modulus,
            "ModularValue: mismatched moduli"
        );
        let m = u128::from(self.modulus);
        let v = ((u128::from(self.value) + u128::from(other.value)) % m) as u64;
        ModularValue::new(v, self.modulus)
    }

    /// Returns `self - other` reduced modulo the shared modulus.
    ///
    /// Debug builds assert that both operands use the same modulus.
    pub fn sub(&self, other: &Self) -> Self {
        debug_assert_eq!(
            self.modulus, other.modulus,
            "ModularValue: mismatched moduli"
        );
        // Widen to u128: `value + modulus` does not fit in u64 once the modulus
        // exceeds 2^63, which previously overflowed.
        let m = u128::from(self.modulus);
        let v = ((u128::from(self.value) + m - u128::from(other.value) % m) % m) as u64;
        ModularValue::new(v, self.modulus)
    }

    /// Returns `self * other` reduced modulo the shared modulus.
    ///
    /// Debug builds assert that both operands use the same modulus.
    pub fn mul(&self, other: &Self) -> Self {
        debug_assert_eq!(
            self.modulus, other.modulus,
            "ModularValue: mismatched moduli"
        );
        let m = u128::from(self.modulus);
        let v = ((u128::from(self.value) * u128::from(other.value)) % m) as u64;
        ModularValue::new(v, self.modulus)
    }

    /// Returns the additive inverse modulo `modulus`.
    pub fn neg(&self) -> Self {
        if self.value == 0 {
            self.clone()
        } else {
            ModularValue::new(self.modulus - self.value, self.modulus)
        }
    }

    /// Returns the multiplicative inverse, or `None` when none exists.
    ///
    /// No inverse exists when the value is zero or shares a common factor with
    /// the modulus (possible only for composite moduli).
    pub fn inverse(&self) -> Option<Self> {
        if self.value == 0 {
            return None;
        }
        let candidate = mod_inverse_u64(self.value, self.modulus);
        // The extended-Euclid helper yields a meaningless coefficient when
        // gcd(value, modulus) != 1, so confirm the candidate really inverts.
        let product = (u128::from(self.value) * u128::from(candidate)) % u128::from(self.modulus);
        (product == 1).then(|| ModularValue::new(candidate, self.modulus))
    }
}
/// Errors reported by the modular reduction, CRT lifting and rational
/// reconstruction routines.
#[derive(Debug, Clone, PartialEq)]
pub enum ModularError {
/// The supplied modulus (carried in the variant) failed the primality check.
InvalidModulus(u64),
/// The modular images do not agree on variable lists or moduli.
IncompatiblePolynomials,
/// CRT lifting was given an empty list of images.
EmptyImageList,
/// Rational reconstruction found no fraction within the size bound.
ReconstructionFailed,
}
impl std::fmt::Display for ModularError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `InvalidModulus` interpolates data; the rest are fixed messages.
        let message = match self {
            ModularError::InvalidModulus(p) => {
                return write!(f, "invalid modulus {p}: must be prime ≥ 2");
            }
            ModularError::IncompatiblePolynomials => {
                "polynomials have incompatible variable lists or moduli"
            }
            ModularError::EmptyImageList => "CRT lifting requires at least one modular image",
            ModularError::ReconstructionFailed => {
                "rational reconstruction failed: no a/b ≤ ⌊√(M/2)⌋ with a/b ≡ n (mod M)"
            }
        };
        f.write_str(message)
    }
}
// Empty impl: no methods are overridden, so the standard trait defaults apply.
impl std::error::Error for ModularError {}
impl AlkahestError for ModularError {
fn code(&self) -> &'static str {
match self {
ModularError::InvalidModulus(_) => "E-MOD-001",
ModularError::IncompatiblePolynomials => "E-MOD-002",
ModularError::EmptyImageList => "E-MOD-003",
ModularError::ReconstructionFailed => "E-MOD-004",
}
}
fn remediation(&self) -> Option<&'static str> {
match self {
ModularError::InvalidModulus(_) => {
Some("use a prime modulus p ≥ 2, e.g. 101, 1009, 32749")
}
ModularError::IncompatiblePolynomials => {
Some("ensure all images share the same variable ordering and modulus")
}
ModularError::EmptyImageList => Some("provide at least one (MultiPolyFp, prime) pair"),
ModularError::ReconstructionFailed => {
Some("provide more modular images so the prime product M exceeds 2 * max_coeff²")
}
}
}
}
pub fn reduce_mod(poly: &MultiPoly, p: u64) -> Result<MultiPolyFp, ModularError> {
if !is_prime(p) {
return Err(ModularError::InvalidModulus(p));
}
let mut terms = BTreeMap::new();
for (exp, coeff) in &poly.terms {
let c_mod = rug_mod_u64(coeff, p);
if c_mod != 0 {
terms.insert(exp.clone(), c_mod);
}
}
Ok(MultiPolyFp {
vars: poly.vars.clone(),
modulus: p,
terms,
})
}
pub fn lift_crt(images: &[(MultiPolyFp, u64)]) -> Result<MultiPoly, ModularError> {
if images.is_empty() {
return Err(ModularError::EmptyImageList);
}
let vars = images[0].0.vars.clone();
for (img, _) in images {
if img.vars != vars {
return Err(ModularError::IncompatiblePolynomials);
}
}
let mut all_exps: std::collections::BTreeSet<Vec<u32>> = std::collections::BTreeSet::new();
for (img, _) in images {
for exp in img.terms.keys() {
all_exps.insert(exp.clone());
}
}
let mut terms: BTreeMap<Vec<u32>, Integer> = BTreeMap::new();
for exp in &all_exps {
let residues: Vec<(u64, u64)> = images
.iter()
.map(|(img, p)| (img.terms.get(exp).copied().unwrap_or(0), *p))
.collect();
let (combined, m) = crt_combine(&residues);
let centered = center_mod(&combined, &m);
if centered != 0 {
terms.insert(exp.clone(), centered);
}
}
Ok(MultiPoly { vars, terms })
}
/// Attempts to recover a fraction `a / b` with `a ≡ n · b (mod m)`.
///
/// Runs the extended Euclidean algorithm on `(m, n mod m)` and stops at the
/// first remainder not exceeding `t = ⌊√(m / 2)⌋`. The candidate is accepted
/// only when the numerator is non-zero, `0 < |b| ≤ t`, and `b` is coprime to
/// `m` (so that `b` is actually invertible modulo `m`).
///
/// Returns `Some((a, b))` with `b > 0`, `Some((0, 1))` when `n ≡ 0 (mod m)`,
/// and `None` when `m <= 1` or no admissible fraction exists.
pub fn rational_reconstruction(n: &Integer, m: &Integer) -> Option<(Integer, Integer)> {
    if *m <= 1 {
        return None;
    }

    // Normalise n into [0, m); the remainder may be negative for negative n.
    let mut n_mod = n.clone() % m;
    if n_mod < 0 {
        n_mod += m;
    }
    if n_mod == 0 {
        return Some((Integer::from(0), Integer::from(1)));
    }

    let t = (m.clone() >> 1u32).sqrt();

    // Invariant: r_i ≡ s_i * n (mod m) for both tracked rows.
    let mut r_prev = m.clone();
    let mut r_curr = n_mod;
    let mut s_prev = Integer::from(0);
    let mut s_curr = Integer::from(1);
    // The loop condition guarantees r_curr > t >= 0, so the division is safe.
    while r_curr > t {
        let q = r_prev.clone() / &r_curr;
        let r_next = r_prev - q.clone() * &r_curr;
        let s_next = s_prev - q * &s_curr;
        r_prev = r_curr;
        r_curr = r_next;
        s_prev = s_curr;
        s_curr = s_next;
    }

    // A zero remainder means the sequence skipped over the admissible window.
    if r_curr == 0 {
        return None;
    }
    let b_abs = s_curr.clone().abs();
    if b_abs == 0 || b_abs > t {
        return None;
    }
    // A denominator sharing a factor with m is not invertible mod m, so the
    // pair would not represent n; reject it instead of returning a bogus fraction.
    if b_abs.gcd(m) != 1 {
        return None;
    }

    // Normalise the sign so the denominator is positive.
    let (a, b) = if s_curr < 0 {
        (-r_curr, -s_curr)
    } else {
        (r_curr, s_curr)
    };
    Some((a, b))
}
/// Coefficient bound computed as `2^d · ‖poly‖₁`, where `d` is the total degree
/// and `‖poly‖₁` is the sum of absolute coefficient values.
///
/// Returns `1` for the zero polynomial.
pub fn mignotte_bound(poly: &MultiPoly) -> Integer {
    if poly.is_zero() {
        return Integer::from(1);
    }
    let mut l1 = Integer::from(0);
    for coeff in poly.terms.values() {
        l1 += Integer::from(coeff.abs_ref());
    }
    // Shifting left by d multiplies by 2^d.
    l1 << poly.total_degree()
}
/// Returns the smallest prime that is not listed in `used` and does not divide
/// `avoid_divisor`.
///
/// An `avoid_divisor` of zero imposes no divisibility constraint.
///
/// # Panics
///
/// Panics when no such prime exists in `2..=1_000_000`.
pub fn select_lucky_prime(avoid_divisor: &Integer, used: &[u64]) -> u64 {
    let unconstrained = *avoid_divisor == 0;
    (2u64..=1_000_000)
        .find(|&candidate| {
            is_prime(candidate)
                && !used.contains(&candidate)
                && (unconstrained || avoid_divisor.clone() % Integer::from(candidate) != 0)
        })
        .unwrap_or_else(|| panic!("select_lucky_prime: no suitable prime found below 1_000_000"))
}
/// Incrementally combines `(residue, modulus)` pairs with the Chinese Remainder
/// Theorem.
///
/// Returns `(a, M)` where `M` is the product of all moduli and `a` matches each
/// residue modulo its modulus, assuming the moduli are pairwise coprime. An
/// empty slice yields `(0, 1)`.
fn crt_combine(pairs: &[(u64, u64)]) -> (Integer, Integer) {
    let (&(first_residue, first_modulus), rest) = match pairs.split_first() {
        Some(parts) => parts,
        None => return (Integer::from(0), Integer::from(1)),
    };
    let mut acc = Integer::from(first_residue);
    let mut product = Integer::from(first_modulus);
    for &(residue, prime) in rest {
        let wide_prime = prime as u128;
        // How far the running solution is from the new residue, modulo `prime`.
        let acc_residue = rug_mod_u64(&acc, prime) as u128;
        let gap = ((residue as u128 + wide_prime - acc_residue) % wide_prime) as u64;
        // Solve product * step ≡ gap (mod prime).
        let inverse = mod_inverse_u64(rug_mod_u64(&product, prime), prime);
        let step = ((gap as u128 * inverse as u128) % wide_prime) as u64;
        acc += product.clone() * step;
        product *= Integer::from(prime);
    }
    (acc, product)
}
/// Maps `a` to the symmetric range around zero: values above `⌊m / 2⌋` are
/// shifted down by `m`, everything else is returned unchanged.
fn center_mod(a: &Integer, m: &Integer) -> Integer {
    let threshold = m.clone() >> 1u32;
    let mut centered = a.clone();
    if centered > threshold {
        centered -= m;
    }
    centered
}
/// Returns the non-negative remainder of `a` modulo `p` as a `u64`.
///
/// # Panics
///
/// Panics if the adjusted remainder does not fit in a `u64`.
fn rug_mod_u64(a: &Integer, p: u64) -> u64 {
    let modulus = Integer::from(p);
    let mut remainder = a.clone() % modulus.clone();
    // The remainder is negative for negative `a`; shift it into [0, p).
    if remainder < 0 {
        remainder += modulus;
    }
    remainder.to_u64().expect("modular result fits in u64")
}
/// Extended-Euclid coefficient of `a` modulo `m`, normalised into `[0, m)`.
///
/// When `gcd(a, m) == 1` this is the multiplicative inverse of `a` modulo `m`.
/// For `m == 1` the result is `0`. When `a` and `m` are not coprime the
/// returned value is not an inverse; callers must ensure coprimality.
fn mod_inverse_u64(a: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let modulus = m as i128;
    // i128 holds every intermediate: coefficients stay bounded by the modulus.
    let (mut prev_rem, mut rem) = (a as i128, modulus);
    let (mut prev_coef, mut coef) = (1i128, 0i128);
    while rem != 0 {
        let quotient = prev_rem / rem;
        let next_rem = prev_rem - quotient * rem;
        let next_coef = prev_coef - quotient * coef;
        prev_rem = rem;
        rem = next_rem;
        prev_coef = coef;
        coef = next_coef;
    }
    // The coefficient may be negative; fold it into [0, m).
    prev_coef.rem_euclid(modulus) as u64
}
/// Primality test for `u64` using trial division by 2, 3, 5 followed by
/// Miller-Rabin with a fixed witness set.
///
/// Witnesses `{2, 3, 5, 7}` are used below 3 215 031 751 and the first twelve
/// primes otherwise.
pub fn is_prime(n: u64) -> bool {
    const SMALL_WITNESSES: [u64; 4] = [2, 3, 5, 7];
    const FULL_WITNESSES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];

    if n < 2 {
        return false;
    }
    if matches!(n, 2 | 3 | 5 | 7) {
        return true;
    }
    if n % 2 == 0 || n % 3 == 0 || n % 5 == 0 {
        return false;
    }

    // Write n - 1 = odd * 2^shift; n is odd here, so shift >= 1.
    let shift = (n - 1).trailing_zeros();
    let odd = (n - 1) >> shift;

    // One Miller-Rabin round: true when `a` does not prove n composite.
    let survives = |a: u64| -> bool {
        let mut x = pow_mod(a, odd, n);
        if x == 1 || x == n - 1 {
            return true;
        }
        for _ in 1..shift {
            x = mul_mod(x, x, n);
            if x == n - 1 {
                return true;
            }
        }
        false
    };

    let witnesses: &[u64] = if n < 3_215_031_751 {
        &SMALL_WITNESSES
    } else {
        &FULL_WITNESSES
    };
    witnesses.iter().copied().filter(|&a| a < n).all(survives)
}
/// Computes `base^exp mod modulus` by binary exponentiation.
///
/// An exponent of zero yields `1` regardless of the modulus.
fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 {
    let mut square = base % modulus;
    let mut acc = 1u64;
    // Walk the significant bits of the exponent from least to most significant.
    let bit_len = u64::BITS - exp.leading_zeros();
    for bit in 0..bit_len {
        if (exp >> bit) & 1 == 1 {
            acc = mul_mod(acc, square, modulus);
        }
        square = mul_mod(square, square, modulus);
    }
    acc
}
/// Computes `a * b mod m`, widening to `u128` so the product cannot overflow.
#[inline]
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    (wide_product % u128::from(m)) as u64
}
// Unit tests for the modular-arithmetic helpers, reduction, CRT lifting,
// rational reconstruction, the coefficient bound and prime selection.
#[cfg(test)]
mod tests {
use super::*;
use crate::kernel::{Domain, ExprPool};
// Creates a fresh expression pool with two real symbols `x` and `y`.
fn pool_xy() -> (ExprPool, ExprId, ExprId) {
let p = ExprPool::new();
let x = p.symbol("x", Domain::Real);
let y = p.symbol("y", Domain::Real);
(p, x, y)
}
// Table-driven check of `is_prime` on small inputs, including 0 and 1.
#[test]
fn prime_small() {
for &(n, exp) in &[
(0u64, false),
(1, false),
(2, true),
(3, true),
(4, false),
(5, true),
(9, false),
(97, true),
(100, false),
(101, true),
] {
assert_eq!(is_prime(n), exp, "is_prime({n})");
}
}
// Larger inputs around 10^6 plus the Mersenne prime 2^31 - 1.
#[test]
fn prime_large() {
assert!(is_prime(999_983));
assert!(!is_prime(1_000_000));
assert!(is_prime(1_000_003));
assert!(is_prime(2_147_483_647));
}
// 3 * 5 = 15 ≡ 1 (mod 7); 2 * 51 = 102 ≡ 1 (mod 101); 1 is its own inverse.
#[test]
fn mod_inverse_basic() {
assert_eq!(mod_inverse_u64(3, 7), 5); assert_eq!(mod_inverse_u64(2, 101), 51); assert_eq!(mod_inverse_u64(1, 7), 1);
}
// 6x + 4 reduced mod 5: the coefficient 6 becomes 1 and the constant 4 is unchanged.
#[test]
fn reduce_mod_basic() {
let (pool, x, y) = pool_xy();
let expr = pool.add(vec![
pool.mul(vec![pool.integer(6_i32), x]),
pool.integer(4_i32),
]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let fp = reduce_mod(&poly, 5).unwrap();
assert_eq!(fp.modulus, 5);
assert_eq!(*fp.terms.get(&vec![1]).unwrap(), 1u64); assert_eq!(*fp.terms.get(&vec![]).unwrap(), 4u64); }
// A negative coefficient is mapped to its non-negative residue: -3 ≡ 4 (mod 7).
#[test]
fn reduce_mod_negative_coeff() {
let (pool, x, y) = pool_xy();
let expr = pool.mul(vec![pool.integer(-3_i32), x]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let fp = reduce_mod(&poly, 7).unwrap();
assert_eq!(*fp.terms.get(&vec![1]).unwrap(), 4u64); }
// 5x + 7 reduced mod 5: the x term vanishes and is dropped; 7 becomes 2.
#[test]
fn reduce_mod_vanishing_term() {
let (pool, x, y) = pool_xy();
let expr = pool.add(vec![
pool.mul(vec![pool.integer(5_i32), x]),
pool.integer(7_i32),
]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let fp = reduce_mod(&poly, 5).unwrap();
assert!(!fp.terms.contains_key(&vec![1]));
assert_eq!(*fp.terms.get(&vec![]).unwrap(), 2u64);
}
// Non-prime moduli (0, 1 and composites) are rejected with `InvalidModulus`.
#[test]
fn reduce_mod_invalid() {
let (pool, x, y) = pool_xy();
let poly = MultiPoly::from_symbolic(x, vec![x, y], &pool).unwrap();
for bad in [0, 1, 4, 6, 9] {
assert!(
matches!(reduce_mod(&poly, bad), Err(ModularError::InvalidModulus(_))),
"expected InvalidModulus for {bad}"
);
}
}
// A single pair is returned unchanged.
#[test]
fn crt_combine_single() {
let (a, m) = crt_combine(&[(3, 5)]);
assert_eq!(a, Integer::from(3));
assert_eq!(m, Integer::from(5));
}
// x ≡ 2 (mod 3) and x ≡ 3 (mod 5) has the solution 8 modulo 15.
#[test]
fn crt_combine_two() {
let (a, m) = crt_combine(&[(2, 3), (3, 5)]);
assert_eq!(m, Integer::from(15));
assert_eq!(a, Integer::from(8));
assert_eq!(8u64 % 3, 2);
assert_eq!(8u64 % 5, 3);
}
// Three congruences modulo 2, 3 and 5 combine to 23 modulo 30.
#[test]
fn crt_combine_three() {
let (a, m) = crt_combine(&[(1, 2), (2, 3), (3, 5)]);
assert_eq!(m, Integer::from(30));
assert_eq!(a, Integer::from(23));
assert_eq!(23u64 % 2, 1);
assert_eq!(23u64 % 3, 2);
assert_eq!(23u64 % 5, 3);
}
// Round trip: 3x^2 + 2x + 1 reduced mod 101 and mod 103, then lifted back.
#[test]
fn lift_crt_roundtrip_positive() {
let (pool, x, y) = pool_xy();
let x2 = pool.pow(x, pool.integer(2_i32));
let expr = pool.add(vec![
pool.mul(vec![pool.integer(3_i32), x2]),
pool.mul(vec![pool.integer(2_i32), x]),
pool.integer(1_i32),
]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let p1 = 101u64;
let p2 = 103u64;
let fp1 = reduce_mod(&poly, p1).unwrap();
let fp2 = reduce_mod(&poly, p2).unwrap();
let lifted = lift_crt(&[(fp1, p1), (fp2, p2)]).unwrap();
assert_eq!(lifted, poly);
}
// Round trip for x - 50: the symmetric-range step must recover the negative constant.
#[test]
fn lift_crt_negative_coeff() {
let (pool, x, y) = pool_xy();
let expr = pool.add(vec![x, pool.integer(-50_i32)]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let p1 = 101u64;
let p2 = 103u64; let lifted = lift_crt(&[
(reduce_mod(&poly, p1).unwrap(), p1),
(reduce_mod(&poly, p2).unwrap(), p2),
])
.unwrap();
assert_eq!(lifted, poly);
}
// Round trip for the bivariate polynomial xy + 3 using the primes 7 and 11.
#[test]
fn lift_crt_bivariate() {
let (pool, x, y) = pool_xy();
let expr = pool.add(vec![pool.mul(vec![x, y]), pool.integer(3_i32)]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
let p = 7u64;
let q = 11u64;
let lifted = lift_crt(&[
(reduce_mod(&poly, p).unwrap(), p),
(reduce_mod(&poly, q).unwrap(), q),
])
.unwrap();
assert_eq!(lifted, poly);
}
// An empty image list is reported as `EmptyImageList`.
#[test]
fn lift_crt_empty_error() {
assert!(matches!(lift_crt(&[]), Err(ModularError::EmptyImageList)));
}
// 51 ≡ 1/2 (mod 101) because 2 * 51 = 102 ≡ 1.
#[test]
fn rat_recon_one_half() {
let result = rational_reconstruction(&Integer::from(51), &Integer::from(101));
assert!(result.is_some());
let (a, b) = result.unwrap();
assert_eq!(a, Integer::from(1));
assert_eq!(b, Integer::from(2));
}
// 50 ≡ -1/2 (mod 101) because 2 * 50 = 100 ≡ -1; the sign lands on the numerator.
#[test]
fn rat_recon_negative_numerator() {
let result = rational_reconstruction(&Integer::from(50), &Integer::from(101));
assert!(result.is_some());
let (a, b) = result.unwrap();
assert_eq!(a, Integer::from(-1));
assert_eq!(b, Integer::from(2));
}
// A zero residue reconstructs to 0/1.
#[test]
fn rat_recon_zero() {
let result = rational_reconstruction(&Integer::from(0), &Integer::from(101));
assert!(result.is_some());
let (a, b) = result.unwrap();
assert_eq!(a, Integer::from(0));
assert_eq!(b, Integer::from(1));
}
// A residue already within the bound comes back as an integer with denominator 1.
#[test]
fn rat_recon_integer() {
let result = rational_reconstruction(&Integer::from(5), &Integer::from(101));
assert!(result.is_some());
let (a, b) = result.unwrap();
assert_eq!(b, Integer::from(1));
assert_eq!(a, Integer::from(5));
}
// With m = 7 the bound is floor(sqrt(3)) = 1, too small to represent 2.
#[test]
fn rat_recon_m_too_small() {
let result = rational_reconstruction(&Integer::from(2), &Integer::from(7));
assert!(result.is_none());
}
// Constant 5: degree 0, so the bound is 5 * 2^0 = 5.
#[test]
fn mignotte_constant() {
let (pool, x, y) = pool_xy();
let poly = MultiPoly::from_symbolic(pool.integer(5_i32), vec![x, y], &pool).unwrap();
assert_eq!(mignotte_bound(&poly), Integer::from(5));
}
// 3x + 2: coefficient sum 5, degree 1, so the bound is 5 * 2 = 10.
#[test]
fn mignotte_linear() {
let (pool, x, y) = pool_xy();
let expr = pool.add(vec![
pool.mul(vec![pool.integer(3_i32), x]),
pool.integer(2_i32),
]);
let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
assert_eq!(mignotte_bound(&poly), Integer::from(10));
}
// The zero polynomial is special-cased to a bound of 1.
#[test]
fn mignotte_zero_poly() {
let (_, x, y) = pool_xy();
let z = MultiPoly::zero(vec![x, y]);
assert_eq!(mignotte_bound(&z), Integer::from(1));
}
// With no constraints the smallest prime, 2, is chosen.
#[test]
fn lucky_prime_no_constraint() {
let p = select_lucky_prime(&Integer::from(0), &[]);
assert!(is_prime(p));
assert_eq!(p, 2);
}
// 2 and 3 divide 6, so the first acceptable prime is 5.
#[test]
fn lucky_prime_avoids_divisors() {
let p = select_lucky_prime(&Integer::from(6), &[]);
assert!(is_prime(p));
assert_ne!(6 % p, 0);
assert_eq!(p, 5); }
// Primes listed in `used` are skipped.
#[test]
fn lucky_prime_skips_used() {
let p = select_lucky_prime(&Integer::from(0), &[2, 3, 5]);
assert_eq!(p, 7);
}
// Both constraints together: must not divide 30 and must not be 7.
#[test]
fn lucky_prime_combined() {
let p = select_lucky_prime(&Integer::from(30), &[7]);
assert!(is_prime(p));
assert_ne!(30 % p, 0);
assert_ne!(p, 7);
}
// 3 + 5 = 8 ≡ 1 (mod 7).
#[test]
fn modular_value_add() {
let a = ModularValue::new(3, 7);
let b = ModularValue::new(5, 7);
assert_eq!(a.add(&b), ModularValue::new(1, 7)); }
// 3 - 5 = -2 ≡ 5 (mod 7).
#[test]
fn modular_value_sub() {
let a = ModularValue::new(3, 7);
let b = ModularValue::new(5, 7);
assert_eq!(a.sub(&b), ModularValue::new(5, 7)); }
// 3 * 5 = 15 ≡ 1 (mod 7).
#[test]
fn modular_value_mul() {
let a = ModularValue::new(3, 7);
let b = ModularValue::new(5, 7);
assert_eq!(a.mul(&b), ModularValue::new(1, 7)); }
// -3 ≡ 4 (mod 7); negating zero stays zero.
#[test]
fn modular_value_neg() {
assert_eq!(ModularValue::new(3, 7).neg(), ModularValue::new(4, 7));
assert_eq!(ModularValue::new(0, 7).neg(), ModularValue::new(0, 7));
}
// 3^-1 ≡ 5 (mod 7); zero has no inverse.
#[test]
fn modular_value_inverse() {
assert_eq!(
ModularValue::new(3, 7).inverse().unwrap(),
ModularValue::new(5, 7)
);
assert!(ModularValue::new(0, 7).inverse().is_none());
}
// Each error variant maps to its fixed `E-MOD-00x` code.
#[test]
fn error_codes() {
assert_eq!(ModularError::InvalidModulus(4).code(), "E-MOD-001");
assert_eq!(ModularError::IncompatiblePolynomials.code(), "E-MOD-002");
assert_eq!(ModularError::EmptyImageList.code(), "E-MOD-003");
assert_eq!(ModularError::ReconstructionFailed.code(), "E-MOD-004");
}
}