ark_linear_sumcheck/ml_sumcheck/
data_structures.rs1use ark_ff::Field;
4use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
5use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
6use ark_std::cmp::max;
7use ark_std::rc::Rc;
8use ark_std::vec::Vec;
9use hashbrown::HashMap;
10#[derive(Clone)]
25pub struct ListOfProductsOfPolynomials<F: Field> {
26 pub max_multiplicands: usize,
28 pub num_variables: usize,
30 pub products: Vec<(F, Vec<usize>)>,
32 pub flattened_ml_extensions: Vec<Rc<DenseMultilinearExtension<F>>>,
34 raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension<F>, usize>,
35}
36
37impl<F: Field> ListOfProductsOfPolynomials<F> {
38 pub fn info(&self) -> PolynomialInfo {
40 PolynomialInfo {
41 max_multiplicands: self.max_multiplicands,
42 num_variables: self.num_variables,
43 }
44 }
45}
46
47#[derive(CanonicalSerialize, CanonicalDeserialize, Clone)]
48pub struct PolynomialInfo {
51 pub max_multiplicands: usize,
53 pub num_variables: usize,
55}
56
57impl<F: Field> ListOfProductsOfPolynomials<F> {
58 pub fn new(num_variables: usize) -> Self {
60 ListOfProductsOfPolynomials {
61 max_multiplicands: 0,
62 num_variables,
63 products: Vec::new(),
64 flattened_ml_extensions: Vec::new(),
65 raw_pointers_lookup_table: HashMap::new(),
66 }
67 }
68
69 pub fn add_product(
72 &mut self,
73 product: impl IntoIterator<Item = Rc<DenseMultilinearExtension<F>>>,
74 coefficient: F,
75 ) {
76 let product: Vec<Rc<DenseMultilinearExtension<F>>> = product.into_iter().collect();
77 let mut indexed_product = Vec::with_capacity(product.len());
78 assert!(!product.is_empty());
79 self.max_multiplicands = max(self.max_multiplicands, product.len());
80 for m in product {
81 assert_eq!(
82 m.num_vars, self.num_variables,
83 "product has a multiplicand with wrong number of variables"
84 );
85 let m_ptr: *const DenseMultilinearExtension<F> = Rc::as_ptr(&m);
86 if let Some(index) = self.raw_pointers_lookup_table.get(&m_ptr) {
87 indexed_product.push(*index)
88 } else {
89 let curr_index = self.flattened_ml_extensions.len();
90 self.flattened_ml_extensions.push(m.clone());
91 self.raw_pointers_lookup_table.insert(m_ptr, curr_index);
92 indexed_product.push(curr_index);
93 }
94 }
95 self.products.push((coefficient, indexed_product));
96 }
97
98 pub fn evaluate(&self, point: &[F]) -> F {
100 self.products
101 .iter()
102 .map(|(c, p)| {
103 *c * p
104 .iter()
105 .map(|&i| self.flattened_ml_extensions[i].evaluate(point).unwrap())
106 .product::<F>()
107 })
108 .sum()
109 }
110}