// Source: ark_poly/evaluations/multivariate/multilinear/sparse.rs

1//! multilinear polynomial represented in sparse evaluation form.
2
3use crate::{
4    evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension,
5    MultilinearExtension, Polynomial,
6};
7use ark_ff::{Field, Zero};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use ark_std::{
10    cfg_iter,
11    collections::BTreeMap,
12    fmt::{self, Debug, Formatter},
13    ops::{Add, AddAssign, Index, Neg, Sub, SubAssign},
14    rand::Rng,
15    vec,
16    vec::*,
17    UniformRand,
18};
19use hashbrown::HashMap;
20#[cfg(feature = "parallel")]
21use rayon::prelude::*;
22
23use super::DefaultHasher;
24
25/// Stores a multilinear polynomial in sparse evaluation form.
26#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
27pub struct SparseMultilinearExtension<F: Field> {
28    /// tuples of index and value
29    pub evaluations: BTreeMap<usize, F>,
30    /// number of variables
31    pub num_vars: usize,
32    zero: F,
33}
34
35impl<F: Field> SparseMultilinearExtension<F> {
36    pub fn from_evaluations<'a>(
37        num_vars: usize,
38        evaluations: impl IntoIterator<Item = &'a (usize, F)>,
39    ) -> Self {
40        let bit_mask = 1 << num_vars;
41        // check
42        let evaluations = evaluations.into_iter();
43        let evaluations: Vec<_> = evaluations
44            .map(|(i, v): &(usize, F)| {
45                assert!(*i < bit_mask, "index out of range");
46                (*i, *v)
47            })
48            .collect();
49
50        Self {
51            evaluations: tuples_to_treemap(&evaluations),
52            num_vars,
53            zero: F::zero(),
54        }
55    }
56
57    /// Outputs an `l`-variate multilinear extension where value of evaluations
58    /// are sampled uniformly at random. The number of nonzero entries is
59    /// `num_nonzero_entries` and indices of those nonzero entries are
60    /// distributed uniformly at random.
61    ///
62    /// Note that this function uses rejection sampling. As number of nonzero
63    /// entries approach `2 ^ num_vars`, sampling will be very slow due to
64    /// large number of collisions.
65    pub fn rand_with_config<R: Rng>(
66        num_vars: usize,
67        num_nonzero_entries: usize,
68        rng: &mut R,
69    ) -> Self {
70        assert!(num_nonzero_entries <= (1 << num_vars));
71
72        let mut map =
73            HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
74        for _ in 0..num_nonzero_entries {
75            let mut index = usize::rand(rng) & ((1usize << num_vars) - 1);
76            while map.get(&index).is_some() {
77                index = usize::rand(rng) & ((1usize << num_vars) - 1);
78            }
79            map.entry(index).or_insert(F::rand(rng));
80        }
81        let evaluations = hashmap_to_treemap(&map);
82        Self {
83            num_vars,
84            evaluations,
85            zero: F::zero(),
86        }
87    }
88
89    /// Convert the sparse multilinear polynomial to dense form.
90    pub fn to_dense_multilinear_extension(&self) -> DenseMultilinearExtension<F> {
91        let mut evaluations: Vec<_> = (0..(1usize << self.num_vars)).map(|_| F::zero()).collect();
92        for (&i, &v) in &self.evaluations {
93            evaluations[i] = v;
94        }
95        DenseMultilinearExtension::from_evaluations_vec(self.num_vars, evaluations)
96    }
97}
98
99/// utility: precompute f(x) = eq(g,x)
100fn precompute_eq<F: Field>(g: &[F]) -> Vec<F> {
101    let dim = g.len();
102    let mut dp = vec![F::zero(); 1 << dim];
103    dp[0] = F::one() - g[0];
104    dp[1] = g[0];
105    for i in 1..dim {
106        for b in 0..(1 << i) {
107            let prev = dp[b];
108            dp[b + (1 << i)] = prev * g[i];
109            dp[b] = prev - dp[b + (1 << i)];
110        }
111    }
112    dp
113}
114
115impl<F: Field> MultilinearExtension<F> for SparseMultilinearExtension<F> {
116    fn num_vars(&self) -> usize {
117        self.num_vars
118    }
119
120    /// Outputs an `l`-variate multilinear extension where value of evaluations
121    /// are sampled uniformly at random. The number of nonzero entries is
122    /// `sqrt(2^num_vars)` and indices of those nonzero entries are distributed
123    /// uniformly at random.
124    fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
125        Self::rand_with_config(num_vars, 1usize << (num_vars / 2), rng)
126    }
127
128    fn relabel(&self, mut a: usize, mut b: usize, k: usize) -> Self {
129        if a > b {
130            // swap
131            core::mem::swap(&mut a, &mut b);
132        }
133        // sanity check
134        assert!(
135            a + k < self.num_vars && b + k < self.num_vars,
136            "invalid relabel argument"
137        );
138        if a == b || k == 0 {
139            return self.clone();
140        }
141        assert!(a + k <= b, "overlapped swap window is not allowed");
142        let ev: Vec<_> = cfg_iter!(self.evaluations)
143            .map(|(&i, &v)| (swap_bits(i, a, b, k), v))
144            .collect();
145        Self {
146            num_vars: self.num_vars,
147            evaluations: tuples_to_treemap(&ev),
148            zero: F::zero(),
149        }
150    }
151
152    fn fix_variables(&self, partial_point: &[F]) -> Self {
153        let dim = partial_point.len();
154        assert!(dim <= self.num_vars, "invalid partial point dimension");
155
156        let mut window = ark_std::log2(self.evaluations.len()) as usize;
157        if window == 0 {
158            window = 1;
159        }
160        let mut point = partial_point;
161        let mut last = treemap_to_hashmap(&self.evaluations);
162
163        // batch evaluation
164        while !point.is_empty() {
165            let focus_length = if point.len() > window {
166                window
167            } else {
168                point.len()
169            };
170            let focus = &point[..focus_length];
171            point = &point[focus_length..];
172            let pre = precompute_eq(focus);
173            let dim = focus.len();
174            let mut result =
175                HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
176            for src_entry in &last {
177                let old_idx = *src_entry.0;
178                let gz = pre[old_idx & ((1 << dim) - 1)];
179                let new_idx = old_idx >> dim;
180                let dst_entry = result.entry(new_idx).or_insert(F::zero());
181                *dst_entry += gz * src_entry.1;
182            }
183            last = result;
184        }
185        let evaluations = hashmap_to_treemap(&last);
186        Self {
187            num_vars: self.num_vars - dim,
188            evaluations,
189            zero: F::zero(),
190        }
191    }
192
193    fn to_evaluations(&self) -> Vec<F> {
194        let mut evaluations = vec![F::zero(); 1 << self.num_vars];
195        self.evaluations
196            .iter()
197            .for_each(|(&i, &v)| evaluations[i] = v);
198        evaluations
199    }
200}
201
202impl<F: Field> Index<usize> for SparseMultilinearExtension<F> {
203    type Output = F;
204
205    /// Returns the evaluation of the polynomial at a point represented by
206    /// index.
207    ///
208    /// Index represents a vector in {0,1}^`num_vars` in little endian form. For
209    /// example, `0b1011` represents `P(1,1,0,1)`
210    ///
211    /// For Sparse multilinear polynomial, Lookup_evaluation takes log time to
212    /// the size of polynomial.
213    fn index(&self, index: usize) -> &Self::Output {
214        if let Some(v) = self.evaluations.get(&index) {
215            v
216        } else {
217            &self.zero
218        }
219    }
220}
221
222impl<F: Field> Polynomial<F> for SparseMultilinearExtension<F> {
223    type Point = Vec<F>;
224
225    fn degree(&self) -> usize {
226        self.num_vars
227    }
228
229    fn evaluate(&self, point: &Self::Point) -> F {
230        assert!(point.len() == self.num_vars);
231        self.fix_variables(point)[0]
232    }
233}
234
235impl<F: Field> Add for SparseMultilinearExtension<F> {
236    type Output = Self;
237
238    fn add(self, other: Self) -> Self {
239        &self + &other
240    }
241}
242
243impl<'a, F: Field> Add<&'a SparseMultilinearExtension<F>> for &SparseMultilinearExtension<F> {
244    type Output = SparseMultilinearExtension<F>;
245
246    fn add(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
247        // handle zero case
248        if self.is_zero() {
249            return rhs.clone();
250        }
251        if rhs.is_zero() {
252            return self.clone();
253        }
254
255        assert_eq!(
256            rhs.num_vars, self.num_vars,
257            "trying to add non-zero polynomial with different number of variables"
258        );
259        // simply merge the evaluations
260        let mut evaluations =
261            HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
262        for (&i, &v) in self.evaluations.iter().chain(rhs.evaluations.iter()) {
263            *(evaluations.entry(i).or_insert(F::zero())) += v;
264        }
265        let evaluations: Vec<_> = evaluations
266            .into_iter()
267            .filter(|(_, v)| !v.is_zero())
268            .collect();
269
270        Self::Output {
271            evaluations: tuples_to_treemap(&evaluations),
272            num_vars: self.num_vars,
273            zero: F::zero(),
274        }
275    }
276}
277
278impl<F: Field> AddAssign for SparseMultilinearExtension<F> {
279    fn add_assign(&mut self, other: Self) {
280        *self = &*self + &other;
281    }
282}
283
284impl<'a, F: Field> AddAssign<&'a Self> for SparseMultilinearExtension<F> {
285    fn add_assign(&mut self, other: &'a Self) {
286        *self = &*self + other;
287    }
288}
289
290impl<'a, F: Field> AddAssign<(F, &'a Self)> for SparseMultilinearExtension<F> {
291    fn add_assign(&mut self, (f, other): (F, &'a Self)) {
292        if !self.is_zero() && !other.is_zero() {
293            assert_eq!(
294                other.num_vars, self.num_vars,
295                "trying to add non-zero polynomial with different number of variables"
296            );
297        }
298        let ev: Vec<_> = cfg_iter!(other.evaluations)
299            .map(|(i, v)| (*i, f * v))
300            .collect();
301        let other = Self {
302            num_vars: other.num_vars,
303            evaluations: tuples_to_treemap(&ev),
304            zero: F::zero(),
305        };
306        *self += &other;
307    }
308}
309
310impl<F: Field> Neg for SparseMultilinearExtension<F> {
311    type Output = Self;
312
313    fn neg(self) -> Self::Output {
314        let ev: Vec<_> = cfg_iter!(self.evaluations)
315            .map(|(i, v)| (*i, -*v))
316            .collect();
317        Self::Output {
318            num_vars: self.num_vars,
319            evaluations: tuples_to_treemap(&ev),
320            zero: F::zero(),
321        }
322    }
323}
324
325impl<F: Field> Sub for SparseMultilinearExtension<F> {
326    type Output = Self;
327
328    fn sub(self, other: Self) -> Self {
329        &self - &other
330    }
331}
332
333impl<'a, F: Field> Sub<&'a SparseMultilinearExtension<F>> for &SparseMultilinearExtension<F> {
334    type Output = SparseMultilinearExtension<F>;
335
336    fn sub(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
337        self + &rhs.clone().neg()
338    }
339}
340
341impl<F: Field> SubAssign for SparseMultilinearExtension<F> {
342    fn sub_assign(&mut self, other: Self) {
343        *self = &*self - &other;
344    }
345}
346
347impl<'a, F: Field> SubAssign<&'a Self> for SparseMultilinearExtension<F> {
348    fn sub_assign(&mut self, other: &'a Self) {
349        *self = &*self - other;
350    }
351}
352
353impl<F: Field> Zero for SparseMultilinearExtension<F> {
354    fn zero() -> Self {
355        Self {
356            num_vars: 0,
357            evaluations: tuples_to_treemap(&Vec::new()),
358            zero: F::zero(),
359        }
360    }
361
362    fn is_zero(&self) -> bool {
363        self.num_vars == 0 && self.evaluations.is_empty()
364    }
365}
366
367impl<F: Field> Debug for SparseMultilinearExtension<F> {
368    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
369        write!(
370            f,
371            "SparseMultilinearPolynomial(num_vars = {}, evaluations = [",
372            self.num_vars
373        )?;
374        let mut ev_iter = self.evaluations.iter();
375        for _ in 0..ark_std::cmp::min(8, self.evaluations.len()) {
376            write!(f, "{:?}", ev_iter.next())?;
377        }
378        if self.evaluations.len() > 8 {
379            write!(f, "...")?;
380        }
381        write!(f, "])")?;
382        Ok(())
383    }
384}
385
386/// Utility: Convert tuples to hashmap.
387fn tuples_to_treemap<F: Field>(tuples: &[(usize, F)]) -> BTreeMap<usize, F> {
388    tuples.iter().map(|(i, v)| (*i, *v)).collect()
389}
390
391fn treemap_to_hashmap<F: Field>(
392    map: &BTreeMap<usize, F>,
393) -> HashMap<usize, F, core::hash::BuildHasherDefault<DefaultHasher>> {
394    map.iter().map(|(i, v)| (*i, *v)).collect()
395}
396
397fn hashmap_to_treemap<F: Field, S>(map: &HashMap<usize, F, S>) -> BTreeMap<usize, F> {
398    map.iter().map(|(i, v)| (*i, *v)).collect()
399}
400
#[cfg(test)]
mod tests {
    use crate::{
        evaluations::multivariate::multilinear::MultilinearExtension, Polynomial,
        SparseMultilinearExtension,
    };
    use ark_ff::{One, Zero};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{ops::Neg, test_rng, vec, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    /// Sanity test: two random sparse polynomials differ, and the number of
    /// stored entries is within a factor of two of `2^(NV / 2)`.
    #[test]
    fn random_poly() {
        const NV: usize = 16;

        let mut rng = test_rng();
        // two random poly should be different
        let poly1 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        let poly2 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        assert_ne!(poly1, poly2);
        // test sparsity
        assert!(
            ((1 << (NV / 2)) >> 1) <= poly1.evaluations.len()
                && poly1.evaluations.len() <= ((1 << (NV / 2)) << 1),
            "polynomial size out of range: expected: [{},{}] ,actual: {}",
            ((1 << (NV / 2)) >> 1),
            ((1 << (NV / 2)) << 1),
            poly1.evaluations.len()
        );
    }

    /// Tests that full and partial evaluation of a sparse polynomial agree
    /// with its dense counterpart. This assumes the dense multilinear
    /// polynomial is implemented correctly.
    #[test]
    fn evaluate() {
        const NV: usize = 12;
        let mut rng = test_rng();
        for _ in 0..20 {
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense = sparse.to_dense_multilinear_extension();
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(sparse.evaluate(&point), dense.evaluate(&point));
            let sparse_partial = sparse.fix_variables(&point[..3]);
            let dense_partial = dense.fix_variables(&point[..3]);
            let point2: Vec<_> = (0..(NV - 3)).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                sparse_partial.evaluate(&point2),
                dense_partial.evaluate(&point2)
            );
        }
    }

    /// `to_evaluations` must produce the same vector as converting to the
    /// dense representation.
    #[test]
    fn sparse_to_evaluations_matches_to_dense() {
        let mut rng = test_rng();
        const NV: usize = 8; // 2^8 = 256, small and fast

        for _ in 0..25 {
            // Make a sparse poly with ~sqrt(2^NV) non-zeros at random indices.
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense_via_sparse = sparse.to_dense_multilinear_extension().evaluations;
            let dense_via_to_evals = sparse.to_evaluations();
            assert_eq!(
                dense_via_to_evals, dense_via_sparse,
                "to_evaluations must reproduce the dense vector exactly"
            );
        }
    }

    /// Evaluation of constant and single-variate polynomials, including one
    /// with a missing (implicitly zero) entry.
    #[test]
    fn evaluate_edge_cases() {
        // test constant polynomial
        let mut rng = test_rng();
        let ev1 = Fr::rand(&mut rng);
        let poly1 = SparseMultilinearExtension::from_evaluations(0, &vec![(0, ev1)]);
        assert_eq!(poly1.evaluate(&[].into()), ev1);

        // test single-variate polynomial
        let ev2 = [Fr::rand(&mut rng), Fr::rand(&mut rng)];
        let poly2 =
            SparseMultilinearExtension::from_evaluations(1, &vec![(0, ev2[0]), (1, ev2[1])]);

        let x = Fr::rand(&mut rng);
        assert_eq!(
            poly2.evaluate(&[x].into()),
            x * ev2[1] + (Fr::one() - x) * ev2[0]
        );

        // test single-variate polynomial with one entry missing
        let ev3 = Fr::rand(&mut rng);
        let poly3 = SparseMultilinearExtension::from_evaluations(1, &vec![(1, ev3)]);

        let x = Fr::rand(&mut rng);
        assert_eq!(poly3.evaluate(&[x].into()), x * ev3);
    }

    /// Indexing returns the stored value for every stored index and zero for
    /// indices without an entry.
    #[test]
    fn index() {
        let mut rng = test_rng();
        let points = vec![
            (11, Fr::rand(&mut rng)),
            (117, Fr::rand(&mut rng)),
            (213, Fr::rand(&mut rng)),
            (255, Fr::rand(&mut rng)),
        ];
        let poly = SparseMultilinearExtension::from_evaluations(8, &points);
        // Use an eager loop: a lazy `map(..)` driven by `next_back()` would
        // run the assertion for the last point only.
        for &(i, v) in &points {
            assert_eq!(poly[i], v);
        }
        assert_eq!(poly[0], Fr::zero());
        assert_eq!(poly[1], Fr::zero());
    }

    /// Addition, subtraction, negation, their assigning forms and the
    /// additive identity are checked against arithmetic on the evaluations.
    #[test]
    fn arithmetic() {
        const NV: usize = 18;
        let mut rng = test_rng();
        for _ in 0..20 {
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = SparseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = SparseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // test add
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // test sub
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // test negate
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            // test add assign
            {
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            // test sub assign
            {
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            // test add assign with scalar
            {
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            // test additive identity
            {
                assert_eq!(&poly1 + &SparseMultilinearExtension::zero(), poly1);
                assert_eq!(&SparseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &SparseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = SparseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
        }
    }

    /// Relabelling the polynomial and applying the same swaps to the point
    /// must leave the evaluation unchanged.
    #[test]
    fn relabel() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = SparseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            poly = poly.relabel(2, 2, 1); // should have no effect
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(3, 4, 1); // should switch 3 and 4
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    /// A polynomial that is serialized and deserialized again must evaluate
    /// to the same value.
    #[test]
    fn serialize() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut buf = Vec::new();
            let poly = SparseMultilinearExtension::<Fr>::rand(10, &mut rng);
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            let expected = poly.evaluate(&point);

            poly.serialize_compressed(&mut buf).unwrap();

            let poly2: SparseMultilinearExtension<Fr> =
                SparseMultilinearExtension::deserialize_compressed(&buf[..]).unwrap();
            assert_eq!(poly2.evaluate(&point), expected);
        }
    }
}