use core::ops::{Add, Mul, Sub};
use num_traits::{One, Zero};
/// Returns the inner product of the first `length` Lagrange-basis evaluations at `a` and `b`.
///
/// Writing `chi_i(x)` for the `i`-th multilinear Lagrange basis polynomial over
/// `a.len()` variables (bit `k` of `i` selects `x[k]` versus `1 - x[k]`), this computes
///
/// ```text
///     sum_{i = 0}^{length - 1} chi_i(a) * chi_i(b)
/// ```
///
/// without materialising either evaluation vector.
///
/// # Panics
/// Panics if `a.len() != b.len()`, or if `length` exceeds `2^a.len()`.
pub fn compute_truncated_lagrange_basis_inner_product<F>(length: usize, a: &[F], b: &[F]) -> F
where
    F: One + Zero + Mul<Output = F> + Add<Output = F> + Sub<Output = F> + Copy,
{
    // The helper also reports the untruncated inner product; only the prefix is wanted here.
    let (truncated, _untruncated) =
        compute_truncated_lagrange_basis_inner_product_impl(length, a, b);
    truncated
}
/// Recursive worker for `compute_truncated_lagrange_basis_inner_product`.
///
/// Returns `(part, full)`:
/// - `part` is `sum_{i < part_length} chi_i(a) * chi_i(b)`;
/// - `full` is the same sum taken over all `2^a.len()` basis indices.
///
/// The slices are split on their last coordinate. Basis indices below `2^(len - 1)` have
/// that bit clear and carry the factor `(1 - a_hi) * (1 - b_hi)`. Indices at or above it
/// carry `a_hi * b_hi`, and their lower bits are handled by the recursive call.
///
/// # Panics
/// Panics if `a.len() != b.len()`, or if `part_length` exceeds `2^a.len()`.
fn compute_truncated_lagrange_basis_inner_product_impl<F>(
    part_length: usize,
    a: &[F],
    b: &[F],
) -> (F, F)
where
    F: One + Zero + Mul<Output = F> + Add<Output = F> + Sub<Output = F> + Copy,
{
    assert_eq!(a.len(), b.len());
    match (a.split_last(), b.split_last()) {
        (Some((&a_hi, a_lo)), Some((&b_hi, b_lo))) => {
            let low_factor = (F::one() - a_hi) * (F::one() - b_hi);
            let high_factor = a_hi * b_hi;
            let half: usize = 1 << a_lo.len();
            // Once the prefix reaches the upper half, the whole lower half is included.
            let crosses_half = part_length >= half;
            let remaining = if crosses_half {
                part_length - half
            } else {
                part_length
            };
            let (sub_part, sub_full) =
                compute_truncated_lagrange_basis_inner_product_impl(remaining, a_lo, b_lo);
            let part = if crosses_half {
                sub_full * low_factor + sub_part * high_factor
            } else {
                sub_part * low_factor
            };
            (part, sub_full * (low_factor + high_factor))
        }
        _ => {
            // Zero variables: the only basis polynomial is the constant 1.
            assert!(part_length <= 1);
            let part = if part_length == 1 { F::one() } else { F::zero() };
            (part, F::one())
        }
    }
}
/// Extends a truncated Lagrange-basis sum by one variable.
///
/// `previous_accumulator` is assumed to hold `sum_{j < length mod 2^i} chi_j(point[..i])`,
/// and `alpha` is `point[i]`. The return value is the same sum over `point[..=i]`,
/// truncated at `length mod 2^(i + 1)`.
///
/// - If bit `i` of `length` is clear, every included index has that bit clear, so each
///   term just picks up a factor `1 - alpha`.
/// - If the bit is set, all `2^i` lower indices are included (their basis sum is 1, so
///   they contribute `1 - alpha`), plus `alpha` times the previous truncated sum.
fn next_chi_accumulator<F>(length: usize, previous_accumulator: F, i: usize, alpha: F) -> F
where
    F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
    let bit_is_set = (length >> i) & 1 != 0;
    if bit_is_set {
        F::one() - (F::one() - previous_accumulator) * alpha
    } else {
        previous_accumulator * (F::one() - alpha)
    }
}
/// Extends a truncated "rho" sum, `sum_j j * chi_j(point)`, by one variable.
///
/// As driven by `compute_rho_eval`, the inputs for step `i` are:
/// - `previous_accumulator`: `sum_{j < length mod 2^i} j * chi_j(point[..i])`
/// - `previous_chi`: `sum_{j < length mod 2^i} chi_j(point[..i])`
/// - `rho_of_power_of_two`: the untruncated `sum_{j < 2^i} j * chi_j(point[..i])`
/// - `power_of_two`: `2^i` as a field element
/// - `alpha`: `point[i]`
///
/// Returns the truncated rho sum over `point[..=i]` at `length mod 2^(i + 1)`.
#[allow(clippy::too_many_arguments)] // kept: mirrors the fold state in `compute_rho_eval`
fn next_rho_accumulator<F>(
    length: usize,
    previous_accumulator: F,
    i: usize,
    alpha: F,
    previous_chi: F,
    rho_of_power_of_two: F,
    power_of_two: F,
) -> F
where
    F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
    let one_minus_alpha = F::one() - alpha;
    match (length >> i) & 1 {
        // Bit clear: every included index keeps bit `i` at 0.
        0 => previous_accumulator * one_minus_alpha,
        // Bit set: all indices below 2^i are included (weight 1 - alpha); the rest are
        // shifted up by 2^i and weighted by alpha.
        _ => {
            one_minus_alpha * rho_of_power_of_two
                + alpha * (previous_accumulator + power_of_two * previous_chi)
        }
    }
}
/// Returns `sum_{i < length} chi_i(point)`, the sum of the first `length` multilinear
/// Lagrange basis polynomials over `point.len()` variables, evaluated at `point`.
///
/// Bit `k` of the basis index selects `point[k]` versus `1 - point[k]`. Once `length`
/// reaches `2^point.len()`, every basis polynomial is included and the sum is exactly 1.
///
/// Assumes `point.len() < usize::BITS`; otherwise `1 << point.len()` overflows.
pub fn compute_truncated_lagrange_basis_sum<F>(length: usize, point: &[F]) -> F
where
    F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
    let full_length: usize = 1 << point.len();
    if length < full_length {
        // Consume the bits of `length` from least significant upward.
        let mut chi = F::zero();
        for (i, &alpha) in point.iter().enumerate() {
            chi = next_chi_accumulator(length, chi, i, alpha);
        }
        chi
    } else {
        F::one()
    }
}
/// Evaluates `sum_{i < min(length, 2^n)} i * chi_i(point)`, where `n = point.len()` and
/// `chi_i` is the `i`-th multilinear Lagrange basis polynomial (bit `k` of `i` selects
/// `point[k]` versus `1 - point[k]`).
///
/// Equivalently, this is the multilinear extension of the vector
/// `(0, 1, ..., length - 1, 0, ..., 0)` (padded to `2^n` entries), evaluated at `point`.
///
/// Once `length >= 2^n`, every basis index is included and the result is the linear form
/// `sum_k 2^k * point[k]`. This saturating behaviour matches
/// `compute_truncated_lagrange_basis_sum`. Previously only `length == 2^n` was special-cased,
/// so larger lengths silently wrapped to `length mod 2^n`.
///
/// Assumes `point.len() < usize::BITS`; otherwise `1 << point.len()` overflows.
pub fn compute_rho_eval<F>(length: usize, point: &[F]) -> F
where
    F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
    if length >= 1 << point.len() {
        // Full domain: rho(x) = sum_k 2^k * x_k.
        let (rho, _) = point.iter().fold(
            (F::zero(), F::one()),
            |(acc, current_power_of_two), &alpha| {
                (
                    acc + current_power_of_two * alpha,
                    current_power_of_two + current_power_of_two,
                )
            },
        );
        rho
    } else {
        // Walk the bits of `length` from least significant upward. Before step `i`:
        // - `previous_rho`  = sum_{j < length mod 2^i} j * chi_j(point[..i])
        // - `previous_chi`  = sum_{j < length mod 2^i} chi_j(point[..i])
        // - `current_rho_of_power_of_two` = sum_{k < i} 2^k * point[k] (untruncated rho)
        // - `current_power_of_two` = 2^i
        let (rho, _, _, _) = point.iter().enumerate().fold(
            (F::zero(), F::zero(), F::zero(), F::one()),
            |(previous_rho, previous_chi, current_rho_of_power_of_two, current_power_of_two),
             (i, &alpha)| {
                let next_rho_of_power_of_two =
                    current_power_of_two * alpha + current_rho_of_power_of_two;
                let next_power_of_two = current_power_of_two + current_power_of_two;
                (
                    next_rho_accumulator(
                        length,
                        previous_rho,
                        i,
                        alpha,
                        previous_chi,
                        current_rho_of_power_of_two,
                        current_power_of_two,
                    ),
                    next_chi_accumulator(length, previous_chi, i, alpha),
                    next_rho_of_power_of_two,
                    next_power_of_two,
                )
            },
        );
        rho
    }
}