use alloc::vec::Vec;
use p3_field::Field;
/// Univariate Lagrange basis weights on the point set `{0, 1, ∞}`, evaluated at `r`.
///
/// Returns `[w_0, w_1, w_inf]` such that, for any polynomial `p` of degree at most 2,
/// `p(r) = p(0) * w_0 + p(1) * w_1 + p(∞) * w_inf`, where `p(∞)` denotes the
/// coefficient of `x^2` (the leading coefficient).
pub(crate) fn lagrange_weights_01inf<F: Field>(r: F) -> [F; 3] {
    // Basis for 0: vanishes at 1 and has no x^2 term.
    let weight_zero = F::ONE - r;
    // Basis for 1: vanishes at 0 and has no x^2 term.
    let weight_one = r;
    // Basis for ∞: vanishes at both 0 and 1, with leading coefficient 1.
    let weight_inf = r * (r - F::ONE);
    [weight_zero, weight_one, weight_inf]
}
/// Tensor-product Lagrange weights on the grid `{0, 1, ∞}^k`, evaluated at `rs`.
///
/// The result has `3^k` entries, where `k = rs.len()`. The entry at index
/// `d_0 + 3*d_1 + 9*d_2 + ...` (each digit `d_i ∈ {0, 1, 2}` selecting the point
/// `0`, `1` or `∞` for variable `i`) equals the product over `i` of
/// `lagrange_weights_01inf(rs[i])[d_i]`. The first variable is therefore the
/// least-significant base-3 digit.
///
/// For an empty `rs` the result is `[F::ONE]`.
///
/// # Panics
///
/// Panics if `3^rs.len()` does not fit in a `usize`.
pub fn lagrange_weights_01inf_multi<F: Field>(rs: &[F]) -> Vec<F> {
    // Avoid silent truncation of the exponent and silent wrap-around of the
    // power, either of which would yield a bogus capacity.
    let total = u32::try_from(rs.len())
        .ok()
        .and_then(|k| 3usize.checked_pow(k))
        .expect("lagrange_weights_01inf_multi: 3^rs.len() overflows usize");

    // Two buffers are swapped each round so no allocation happens inside the loop.
    let mut current = Vec::with_capacity(total);
    let mut next = Vec::with_capacity(total);
    current.push(F::ONE);

    for &r in rs {
        let uni = lagrange_weights_01inf(r);
        next.clear();
        // The new variable's weight is the outer loop, so it becomes the
        // most-significant digit of the index built so far.
        for &li in &uni {
            next.extend(current.iter().map(|&w| w * li));
        }
        core::mem::swap(&mut current, &mut next);
    }
    current
}
/// Evaluates at `r` the degree-≤2 polynomial determined by its evaluations
/// `e0 = p(0)`, `e1 = p(1)` and `e_inf = p(∞)` (the coefficient of `x^2`).
pub fn extrapolate_01inf<F: Field>(e0: F, e1: F, e_inf: F, r: F) -> F {
    let weights = lagrange_weights_01inf(r);
    let from_zero = e0 * weights[0];
    let from_one = e1 * weights[1];
    let from_inf = e_inf * weights[2];
    from_zero + from_one + from_inf
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// At the finite nodes the weights collapse to indicator vectors.
    #[test]
    fn test_lagrange_weights_at_finite_points() {
        let [l0, l1, l_inf] = lagrange_weights_01inf(F::ZERO);
        assert_eq!(l0, F::ONE);
        assert_eq!(l1, F::ZERO);
        assert_eq!(l_inf, F::ZERO);
        let [l0, l1, l_inf] = lagrange_weights_01inf(F::ONE);
        assert_eq!(l0, F::ZERO);
        assert_eq!(l1, F::ONE);
        assert_eq!(l_inf, F::ZERO);
    }

    /// With no variables the product is empty, so the only weight is ONE.
    #[test]
    fn test_lagrange_weights_multi_empty() {
        let weights = lagrange_weights_01inf_multi::<F>(&[]);
        assert_eq!(weights.len(), 1);
        assert_eq!(weights[0], F::ONE);
    }

    /// A single variable reproduces the univariate weights exactly.
    #[test]
    fn test_lagrange_weights_multi_k1() {
        for i in 0..5 {
            let r = F::from_u64(i);
            let single = lagrange_weights_01inf(r);
            let multi = lagrange_weights_01inf_multi(&[r]);
            assert_eq!(multi.len(), 3);
            assert_eq!(multi[0], single[0]);
            assert_eq!(multi[1], single[1]);
            assert_eq!(multi[2], single[2]);
        }
    }

    /// Two variables: the first variable is the least-significant digit.
    #[test]
    fn test_lagrange_weights_multi_k2() {
        let r0 = F::from_u64(5);
        let r1 = F::from_u64(7);
        let weights = lagrange_weights_01inf_multi(&[r0, r1]);
        assert_eq!(weights.len(), 9);
        let w0 = lagrange_weights_01inf(r0);
        let w1 = lagrange_weights_01inf(r1);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(weights[3 * j + i], w0[i] * w1[j]);
            }
        }
    }

    /// Three variables: index = d_0 + 3*d_1 + 9*d_2.
    #[test]
    fn test_lagrange_weights_multi_k3() {
        let rs = [F::from_u64(2), F::from_u64(3), F::from_u64(11)];
        let weights = lagrange_weights_01inf_multi(&rs);
        assert_eq!(weights.len(), 27);
        let w0 = lagrange_weights_01inf(rs[0]);
        let w1 = lagrange_weights_01inf(rs[1]);
        let w2 = lagrange_weights_01inf(rs[2]);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    assert_eq!(weights[9 * k + 3 * j + i], w0[i] * w1[j] * w2[k]);
                }
            }
        }
    }

    /// The multivariate weights extrapolate a product of univariate quadratics:
    /// q(x) = x^2 + 1 and s(y) = 2y + 3, so q(3) * s(4) = 10 * 11 = 110.
    #[test]
    fn test_lagrange_weights_multi_extrapolates_product() {
        // q at (0, 1, ∞) -> (1, 2, 1).
        let q = [F::from_u64(1), F::from_u64(2), F::from_u64(1)];
        // s at (0, 1, ∞) -> (3, 5, 0); s has no y^2 term.
        let s = [F::from_u64(3), F::from_u64(5), F::ZERO];
        let weights = lagrange_weights_01inf_multi(&[F::from_u64(3), F::from_u64(4)]);
        let mut acc = F::ZERO;
        for j in 0..3 {
            for i in 0..3 {
                acc = acc + weights[3 * j + i] * q[i] * s[j];
            }
        }
        assert_eq!(acc, F::from_u64(110));
    }

    #[test]
    fn test_extrapolate_at_finite_points() {
        let e0 = F::from_u64(7);
        let e1 = F::from_u64(13);
        let e_inf = F::from_u64(3);
        assert_eq!(extrapolate_01inf(e0, e1, e_inf, F::ZERO), e0);
        assert_eq!(extrapolate_01inf(e0, e1, e_inf, F::ONE), e1);
    }

    /// p(x) = x^2 + 1: p(0) = 1, p(1) = 2, leading coefficient 1.
    #[test]
    fn test_extrapolate_known_quadratic() {
        let e0 = F::from_u64(1);
        let e1 = F::from_u64(2);
        let e_inf = F::from_u64(1);
        assert_eq!(
            extrapolate_01inf(e0, e1, e_inf, F::from_u64(3)),
            F::from_u64(10)
        );
        assert_eq!(
            extrapolate_01inf(e0, e1, e_inf, F::from_u64(4)),
            F::from_u64(17)
        );
    }

    proptest! {
        #[test]
        fn prop_extrapolate_identity(
            e0 in 0u32..1_000_000,
            e1 in 0u32..1_000_000,
            e_inf in 0u32..1_000_000,
        ) {
            let e0 = F::from_u32(e0);
            let e1 = F::from_u32(e1);
            let e_inf = F::from_u32(e_inf);
            prop_assert_eq!(extrapolate_01inf(e0, e1, e_inf, F::ZERO), e0);
            prop_assert_eq!(extrapolate_01inf(e0, e1, e_inf, F::ONE), e1);
        }

        /// For p(x) = a x^2 + b x + c: p(0) = c, p(1) = a + b + c, p(∞) = a.
        #[test]
        fn prop_extrapolate_matches_monomial(
            a_val in 0u32..1_000_000,
            b_val in 0u32..1_000_000,
            c_val in 0u32..1_000_000,
            r_val in 0u32..1_000_000,
        ) {
            let a = F::from_u32(a_val);
            let b = F::from_u32(b_val);
            let c = F::from_u32(c_val);
            let r = F::from_u32(r_val);
            let e0 = c;
            let e1 = a + b + c;
            let e_inf = a;
            // Horner evaluation of a r^2 + b r + c.
            let monomial = c + r * (b + r * a);
            prop_assert_eq!(extrapolate_01inf(e0, e1, e_inf, r), monomial);
        }
    }
}