use crate::errors;
use crate::poly::variant::Error;
use alloc::vec::Vec;
use hekate_math::{Flat, HardwareField};
use zeroize::Zeroize;
/// A point `r` together with a scalar, representing the product
/// `current_scale * Π_i ((1 - x_i)(1 - r_i) + x_i r_i)` evaluated on demand.
///
/// No table of evaluations is stored; each query recomputes the product from
/// `r_coords` and `current_scale`.
#[derive(Clone, Debug, Zeroize)]
pub struct TensorProduct<F: HardwareField> {
    /// Coordinates `r_i` for the variables that have not been folded yet;
    /// `fold` drops the first entry.
    pub r_coords: Vec<Flat<F>>,
    /// Scalar multiplier applied to every evaluation; `new` sets it to one and
    /// each `fold` multiplies in the factor for the bound variable.
    pub current_scale: Flat<F>,
}
impl<F: HardwareField> TensorProduct<F> {
    /// Creates a tensor product over the point `r` with an initial scale of one.
    pub fn new(r: Vec<Flat<F>>) -> Self {
        Self {
            r_coords: r,
            current_scale: Flat::from_raw(F::ONE),
        }
    }

    /// Returns the number of variables that have not been folded yet.
    pub fn num_vars(&self) -> usize {
        self.r_coords.len()
    }

    /// Evaluates `current_scale * Π_i ((1 - x_i)(1 - r_i) + x_i r_i)` at an
    /// arbitrary point `x`.
    ///
    /// # Errors
    ///
    /// Returns `Error::PointDimensionMismatch` (converted into the crate error)
    /// when `x.len()` differs from [`Self::num_vars`].
    pub fn evaluate_extension(&self, x: &[Flat<F>]) -> errors::Result<Flat<F>> {
        let expected_len = self.r_coords.len();
        let got_len = x.len();
        if got_len != expected_len {
            return Err(Error::PointDimensionMismatch {
                expected_len,
                got_len,
            }
            .into());
        }
        // Lengths were validated above, so the shared kernel's debug assertion holds.
        Ok(self.current_scale * Self::evaluate_eq_slice(&self.r_coords, x))
    }

    /// Computes `Π_i ((1 - r_i)(1 - x_i) + r_i x_i)` over paired coordinates.
    ///
    /// Lengths are only checked with `debug_assert_eq!`; in release builds the
    /// pairing stops at the shorter slice.
    #[inline(always)]
    pub fn evaluate_eq_slice(r: &[Flat<F>], x: &[Flat<F>]) -> Flat<F> {
        debug_assert_eq!(r.len(), x.len());
        let mut res = Flat::from_raw(F::ONE);
        let one = Flat::from_raw(F::ONE);
        for (&ri, &xi) in r.iter().zip(x.iter()) {
            let term_0 = (one - ri) * (one - xi);
            let term_1 = ri * xi;
            res *= term_0 + term_1;
        }
        res
    }

    /// Evaluates at the Boolean point whose coordinate `i` is bit `i` of `index`
    /// (coordinate 0 is the least significant bit): each set bit contributes
    /// `r_i`, each clear bit contributes `1 - r_i`, all scaled by
    /// `current_scale`.
    ///
    /// Coordinates at positions `>= usize::BITS` are treated as clear bits.
    #[inline(always)]
    pub fn evaluate_at_index(&self, index: usize) -> Flat<F> {
        let one = Flat::from_raw(F::ONE);
        let mut val = self.current_scale;
        for (i, &r_val) in self.r_coords.iter().enumerate() {
            // Short-circuit before shifting: `index >> i` with `i >= usize::BITS`
            // panics in debug builds and wraps the shift amount in release builds,
            // which would silently re-read low bits. Such bits are zero.
            let bit_is_set = i < usize::BITS as usize && (index >> i) & 1 == 1;
            if bit_is_set {
                val *= r_val;
            } else {
                val *= one - r_val;
            }
        }
        val
    }

    /// Binds the first remaining variable to `u`.
    ///
    /// Returns a new tensor product without `r_0`, whose scale is multiplied by
    /// `(1 - u)(1 - r_0) + u * r_0`. With no variables left, returns an unchanged
    /// clone.
    pub fn fold(&self, u: Flat<F>) -> Self {
        match self.r_coords.split_first() {
            None => self.clone(),
            Some((&r_0, rest)) => {
                // NOTE(review): this builds ONE via `to_hardware()` while the other
                // methods use `Flat::from_raw(F::ONE)`; these agree only if the
                // hardware basis represents ONE with the same raw bits — confirm.
                let one = F::ONE.to_hardware();
                let factor = (one - u) * (one - r_0) + u * r_0;
                Self {
                    // Copy only the tail instead of cloning everything and then
                    // shifting it left with `remove(0)`.
                    r_coords: rest.to_vec(),
                    current_scale: self.current_scale * factor,
                }
            }
        }
    }
}