use crate::field::PrimeField;
use crate::bigint::BigUint;
/// Evaluates the polynomial `coeffs` at the point `x` using Horner's rule.
///
/// `coeffs[k]` is the coefficient of `x^k` (lowest degree first), so `[5, 3, 2]`
/// denotes `5 + 3x + 2x^2`. All arithmetic goes through `field`. An empty slice is
/// the zero polynomial and evaluates to zero.
#[must_use]
pub fn horner(field: &PrimeField, coeffs: &[BigUint], x: &BigUint) -> BigUint {
    // Walk from the highest-degree coefficient down: acc <- acc * x + c.
    coeffs
        .iter()
        .rev()
        .fold(BigUint::zero(), |partial, coeff| {
            field.add(&field.mul(&partial, x), coeff)
        })
}
/// Evaluates at `x_eval` the Lagrange interpolating polynomial through `points`,
/// after validating the x-coordinates.
///
/// Each x-coordinate is first reduced with `field.reduce`. Validation then checks
/// the reduced values.
///
/// Returns `None` when:
/// - any reduced x-coordinate is zero, or
/// - two reduced x-coordinates are equal.
///
/// Otherwise returns `Some` of the value computed by [`lagrange_eval_unchecked`].
/// An empty `points` slice passes validation and yields `Some` of zero.
///
/// NOTE(review): rejecting `x = 0` is not required by Lagrange interpolation itself.
/// Presumably the points are secret-sharing shares whose secret sits at `x = 0`;
/// confirm with callers.
#[must_use]
pub fn lagrange_eval(
    field: &PrimeField,
    points: &[(BigUint, BigUint)],
    x_eval: &BigUint,
) -> Option<BigUint> {
    let xs: Vec<BigUint> = points.iter().map(|(x, _)| field.reduce(x)).collect();

    if xs.iter().any(|x| x.is_zero()) {
        return None;
    }

    // Pairwise comparison of reduced coordinates; quadratic, which matches the
    // cost of the interpolation that follows.
    let has_duplicate = xs
        .iter()
        .enumerate()
        .any(|(i, a)| xs[i + 1..].iter().any(|b| a == b));
    if has_duplicate {
        return None;
    }

    Some(lagrange_eval_unchecked(field, points, x_eval))
}
/// Evaluates at `x_eval` the Lagrange interpolating polynomial through `points`,
/// without validating the x-coordinates.
///
/// Computes `sum_j y_j * prod_{i != j} (x_eval - x_i) / (x_j - x_i)`, with all
/// arithmetic performed through `field`. Returns zero when `points` is empty.
/// For `n` points this performs O(n^2) field multiplications and `n` calls to
/// `field.inv`.
///
/// # Panics
///
/// Panics if `field.inv` fails on some denominator. Assuming `inv` fails only on
/// zero, that happens exactly when two x-coordinates are equal modulo the field's
/// prime. Use [`lagrange_eval`] for a checked variant.
#[must_use]
pub fn lagrange_eval_unchecked(
    field: &PrimeField,
    points: &[(BigUint, BigUint)],
    x_eval: &BigUint,
) -> BigUint {
    if points.is_empty() {
        return BigUint::zero();
    }

    let mut total = BigUint::zero();
    for (j, (xj, yj)) in points.iter().enumerate() {
        // Accumulate the numerator and denominator of the j-th basis polynomial
        // over every other point.
        let (numerator, denominator) = points
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != j)
            .fold(
                (BigUint::one(), BigUint::one()),
                |(num, den), (_, (xi, _))| {
                    (
                        field.mul(&num, &field.sub(x_eval, xi)),
                        field.mul(&den, &field.sub(xj, xi)),
                    )
                },
            );
        let inv = field
            .inv(&denominator)
            .expect("Lagrange denominator nonzero given distinct x");
        let basis = field.mul(&numerator, &inv);
        total = field.add(&total, &field.mul(yj, &basis));
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    // The prime field GF(257) shared by every test.
    fn f257() -> PrimeField {
        PrimeField::new(BigUint::from_u64(257))
    }

    // Shorthand for small constants.
    fn big(v: u64) -> BigUint {
        BigUint::from_u64(v)
    }

    #[test]
    fn horner_matches_manual_eval() {
        let f = f257();
        let coeffs = vec![big(5), big(3), big(2)];
        // 5 + 3*4 + 2*16 = 49
        let v = horner(&f, &coeffs, &big(4));
        assert_eq!(v, big(49));
    }

    #[test]
    fn horner_empty_is_zero() {
        let f = f257();
        assert_eq!(horner(&f, &[], &big(4)), BigUint::zero());
    }

    #[test]
    fn lagrange_recovers_polynomial() {
        let f = f257();
        let coeffs = vec![big(7), big(11), big(5)];
        let pts: Vec<(BigUint, BigUint)> = (1..=3)
            .map(|i| {
                let x = big(i);
                let y = horner(&f, &coeffs, &x);
                (x, y)
            })
            .collect();
        assert_eq!(lagrange_eval(&f, &pts, &BigUint::zero()), Some(big(7)));
        // 7 + 11*4 + 5*16 = 131
        assert_eq!(lagrange_eval(&f, &pts, &big(4)), Some(big(131)));
    }

    #[test]
    fn lagrange_rejects_duplicate_x() {
        let f = f257();
        let pts = vec![(big(1), big(7)), (big(1), big(8))];
        assert!(lagrange_eval(&f, &pts, &BigUint::zero()).is_none());
    }

    #[test]
    fn lagrange_rejects_congruent_x() {
        // 258 = 1 (mod 257): distinct integers, but the same field element.
        let f = f257();
        let pts = vec![(big(1), big(7)), (big(258), big(8))];
        assert!(lagrange_eval(&f, &pts, &BigUint::zero()).is_none());
    }

    #[test]
    fn lagrange_rejects_zero_x() {
        let f = f257();
        let pts = vec![(BigUint::zero(), big(7)), (big(2), big(8))];
        assert!(lagrange_eval(&f, &pts, &big(5)).is_none());
        // 257 reduces to zero in GF(257) and must be rejected as well.
        let pts = vec![(big(257), big(7)), (big(2), big(8))];
        assert!(lagrange_eval(&f, &pts, &big(5)).is_none());
    }

    #[test]
    fn lagrange_empty_points_is_zero() {
        let f = f257();
        assert_eq!(lagrange_eval(&f, &[], &big(5)), Some(BigUint::zero()));
    }
}