use crate::base::{map::IndexMap, scalar::Scalar};
use alloc::{rc::Rc, vec::Vec};
use core::cmp::max;
#[cfg(test)]
use core::iter;
#[cfg(test)]
use itertools::Itertools;
/// A sum of products of multiplicand vectors, each product scaled by a coefficient.
///
/// Each term is stored as a coefficient together with a list of indices into
/// `flattened_ml_extensions`, so a multiplicand that appears in several terms
/// (or several times in one term) is stored only once.
#[derive(Clone, Debug)]
pub struct CompositePolynomial<S: Scalar> {
/// Largest number of multiplicands seen in any single product passed to
/// `add_product` so far; zero for a freshly created polynomial.
pub max_multiplicands: usize,
/// Number of variables of the polynomial, as given at construction.
pub num_variables: usize,
/// The terms: each entry is `(coefficient, indices)`, where every index refers
/// to an entry of `flattened_ml_extensions`.
pub products: Vec<(S, Vec<usize>)>,
/// Storage for the distinct multiplicand vectors referenced by `products`
/// (presumably multilinear extensions, going by the field name).
pub flattened_ml_extensions: Vec<Rc<Vec<S>>>,
// Maps the address of each stored `Rc` allocation to its position in
// `flattened_ml_extensions`; `add_product` uses it to recognise multiplicands
// it has already stored.
raw_pointers_lookup_table: IndexMap<*const Vec<S>, usize>,
}
impl<S: Scalar> CompositePolynomial<S> {
    /// Creates an empty polynomial over `num_variables` variables.
    ///
    /// The result has no products, no stored multiplicands, and a
    /// `max_multiplicands` of zero.
    pub fn new(num_variables: usize) -> Self {
        Self {
            max_multiplicands: 0,
            num_variables,
            products: Vec::new(),
            flattened_ml_extensions: Vec::new(),
            raw_pointers_lookup_table: IndexMap::default(),
        }
    }

    /// Appends the term `coefficient * product[0] * product[1] * ...` to the polynomial.
    ///
    /// Multiplicands are identified by `Rc` pointer identity: an `Rc` pointing at an
    /// allocation that is already stored reuses the existing index, while a vector with
    /// equal contents in a different allocation is stored as a new multiplicand.
    /// `max_multiplicands` is raised to the length of `product` if that is larger.
    ///
    /// # Panics
    ///
    /// Panics if `product` yields no multiplicands.
    pub fn add_product(&mut self, product: impl IntoIterator<Item = Rc<Vec<S>>>, coefficient: S) {
        let product: Vec<Rc<Vec<S>>> = product.into_iter().collect();
        // Check the precondition before doing any further work or mutating `self`.
        assert!(!product.is_empty());
        self.max_multiplicands = max(self.max_multiplicands, product.len());

        let mut indexed_product = Vec::with_capacity(product.len());
        for m in product {
            let m_ptr: *const Vec<S> = Rc::as_ptr(&m);
            if let Some(index) = self.raw_pointers_lookup_table.get(&m_ptr) {
                indexed_product.push(*index);
            } else {
                let curr_index = self.flattened_ml_extensions.len();
                // The address was captured above, so the owned `Rc` can be moved into
                // storage directly; no extra reference-count bump is needed.
                self.flattened_ml_extensions.push(m);
                self.raw_pointers_lookup_table.insert(m_ptr, curr_index);
                indexed_product.push(curr_index);
            }
        }
        self.products.push((coefficient, indexed_product));
    }

    /// Builds a polynomial with random coefficients and random multiplicand entries.
    ///
    /// * `multiplicands_length` gives the length of each multiplicand vector to generate.
    /// * `products` gives, for each term, the multiplicand indices it refers to; the
    ///   indices are stored as given and are not checked against the number of
    ///   multiplicands.
    ///
    /// Coefficients are drawn from `rng` first, then the multiplicand entries.
    /// The pointer lookup table used by `add_product` is left empty.
    #[cfg(test)]
    pub fn rand(
        num_variables: usize,
        max_multiplicands: usize,
        multiplicands_length: impl IntoIterator<Item = usize>,
        products: impl IntoIterator<Item = impl IntoIterator<Item = usize>>,
        rng: &mut (impl ark_std::rand::Rng + ?Sized),
    ) -> Self {
        let mut result = CompositePolynomial::new(num_variables);
        result.max_multiplicands = max_multiplicands;
        result.products = products
            .into_iter()
            .map(|p| (S::rand(rng), p.into_iter().collect()))
            .collect();
        result.flattened_ml_extensions = multiplicands_length
            .into_iter()
            .map(|length| Rc::new(iter::repeat_with(|| S::rand(rng)).take(length).collect()))
            .collect();
        result
    }

    /// Multiplies together the `i`-th entries of the multiplicands listed in `terms`.
    ///
    /// A multiplicand shorter than `i + 1` contributes `S::ZERO`.
    #[cfg(test)]
    fn term_product(&self, terms: &[usize], i: usize) -> S {
        terms
            .iter()
            .map(|&j| {
                self.flattened_ml_extensions[j]
                    .get(i)
                    .copied()
                    .unwrap_or(S::ZERO)
            })
            .product::<S>()
    }

    /// Sums `coefficient * term_product(terms, i)` over every `i` in `0..length`
    /// and every stored product.
    #[cfg(test)]
    pub fn hypercube_sum(&self, length: usize) -> S {
        (0..length)
            .cartesian_product(&self.products)
            .map(|(i, (coeff, terms))| *coeff * self.term_product(terms, i))
            .sum::<S>()
    }

    /// Evaluates the polynomial at `point`.
    ///
    /// A vector of `1 << num_variables` entries is filled by
    /// `compute_evaluation_vector` from `point`. Each term then contributes its
    /// coefficient times the product, over its multiplicands, of the inner product
    /// of that vector with the multiplicand; the contributions are summed.
    #[cfg(test)]
    pub fn evaluate(&self, point: &[S]) -> S {
        let mut evaluation_vector = vec![S::default(); 1 << self.num_variables];
        super::evaluation_vector::compute_evaluation_vector(&mut evaluation_vector, point);
        self.products
            .iter()
            .map(|(c, p)| {
                *c * p
                    .iter()
                    .map(|&i| {
                        crate::base::slice_ops::inner_product(
                            &evaluation_vector,
                            &self.flattened_ml_extensions[i],
                        )
                    })
                    .product::<S>()
            })
            .sum::<S>()
    }

    /// Emits one `tracing` info event per stored product, showing its position,
    /// its coefficient, and the indices of its multiplicands.
    #[tracing::instrument(
        name = "CompositePolynomial::annotate_trace",
        level = "debug",
        skip_all
    )]
    pub fn annotate_trace(&self) {
        for (i, (coefficient, multiplicand_indices)) in self.products.iter().enumerate() {
            tracing::info!(
                "Product #{:?}: {:#} * {:?}",
                i,
                coefficient,
                multiplicand_indices
            );
        }
    }
}