1use crate::{
4 evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension,
5 MultilinearExtension, Polynomial,
6};
7use ark_ff::{Field, Zero};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use ark_std::{
10 cfg_iter,
11 collections::BTreeMap,
12 fmt::{self, Debug, Formatter},
13 ops::{Add, AddAssign, Index, Neg, Sub, SubAssign},
14 rand::Rng,
15 vec,
16 vec::*,
17 UniformRand,
18};
19use hashbrown::HashMap;
20#[cfg(feature = "parallel")]
21use rayon::prelude::*;
22
23use super::DefaultHasher;
24
/// Sparse representation of a multilinear polynomial, stored as its evaluations
/// over the boolean hypercube `{0,1}^num_vars`.
///
/// Bit `j` of an index corresponds to variable `j` (`fix_variables` consumes the
/// lowest index bits first). Indices with no stored entry evaluate to zero.
#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
pub struct SparseMultilinearExtension<F: Field> {
    /// Map from hypercube index to the evaluation at that point. Absent indices
    /// are zero; explicitly stored zero values are allowed (e.g. `from_evaluations`
    /// does not filter them out).
    pub evaluations: BTreeMap<usize, F>,
    /// Number of variables; meaningful indices lie in `0..1 << num_vars`.
    pub num_vars: usize,
    // Cached zero so that `Index` can hand out `&F` for indices with no entry.
    // The derived (de)serialization follows field declaration order: do not reorder.
    zero: F,
}
34
35impl<F: Field> SparseMultilinearExtension<F> {
36 pub fn from_evaluations<'a>(
37 num_vars: usize,
38 evaluations: impl IntoIterator<Item = &'a (usize, F)>,
39 ) -> Self {
40 let bit_mask = 1 << num_vars;
41 let evaluations = evaluations.into_iter();
43 let evaluations: Vec<_> = evaluations
44 .map(|(i, v): &(usize, F)| {
45 assert!(*i < bit_mask, "index out of range");
46 (*i, *v)
47 })
48 .collect();
49
50 Self {
51 evaluations: tuples_to_treemap(&evaluations),
52 num_vars,
53 zero: F::zero(),
54 }
55 }
56
57 pub fn rand_with_config<R: Rng>(
66 num_vars: usize,
67 num_nonzero_entries: usize,
68 rng: &mut R,
69 ) -> Self {
70 assert!(num_nonzero_entries <= (1 << num_vars));
71
72 let mut map =
73 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
74 for _ in 0..num_nonzero_entries {
75 let mut index = usize::rand(rng) & ((1usize << num_vars) - 1);
76 while map.get(&index).is_some() {
77 index = usize::rand(rng) & ((1usize << num_vars) - 1);
78 }
79 map.entry(index).or_insert(F::rand(rng));
80 }
81 let evaluations = hashmap_to_treemap(&map);
82 Self {
83 num_vars,
84 evaluations,
85 zero: F::zero(),
86 }
87 }
88
89 pub fn to_dense_multilinear_extension(&self) -> DenseMultilinearExtension<F> {
91 let mut evaluations: Vec<_> = (0..(1usize << self.num_vars)).map(|_| F::zero()).collect();
92 for (&i, &v) in &self.evaluations {
93 evaluations[i] = v;
94 }
95 DenseMultilinearExtension::from_evaluations_vec(self.num_vars, evaluations)
96 }
97}
98
99fn precompute_eq<F: Field>(g: &[F]) -> Vec<F> {
101 let dim = g.len();
102 let mut dp = vec![F::zero(); 1 << dim];
103 dp[0] = F::one() - g[0];
104 dp[1] = g[0];
105 for i in 1..dim {
106 for b in 0..(1 << i) {
107 let prev = dp[b];
108 dp[b + (1 << i)] = prev * g[i];
109 dp[b] = prev - dp[b + (1 << i)];
110 }
111 }
112 dp
113}
114
115impl<F: Field> MultilinearExtension<F> for SparseMultilinearExtension<F> {
116 fn num_vars(&self) -> usize {
117 self.num_vars
118 }
119
120 fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
125 Self::rand_with_config(num_vars, 1usize << (num_vars / 2), rng)
126 }
127
128 fn relabel(&self, mut a: usize, mut b: usize, k: usize) -> Self {
129 if a > b {
130 core::mem::swap(&mut a, &mut b);
132 }
133 assert!(
135 a + k < self.num_vars && b + k < self.num_vars,
136 "invalid relabel argument"
137 );
138 if a == b || k == 0 {
139 return self.clone();
140 }
141 assert!(a + k <= b, "overlapped swap window is not allowed");
142 let ev: Vec<_> = cfg_iter!(self.evaluations)
143 .map(|(&i, &v)| (swap_bits(i, a, b, k), v))
144 .collect();
145 Self {
146 num_vars: self.num_vars,
147 evaluations: tuples_to_treemap(&ev),
148 zero: F::zero(),
149 }
150 }
151
152 fn fix_variables(&self, partial_point: &[F]) -> Self {
153 let dim = partial_point.len();
154 assert!(dim <= self.num_vars, "invalid partial point dimension");
155
156 let mut window = ark_std::log2(self.evaluations.len()) as usize;
157 if window == 0 {
158 window = 1;
159 }
160 let mut point = partial_point;
161 let mut last = treemap_to_hashmap(&self.evaluations);
162
163 while !point.is_empty() {
165 let focus_length = if point.len() > window {
166 window
167 } else {
168 point.len()
169 };
170 let focus = &point[..focus_length];
171 point = &point[focus_length..];
172 let pre = precompute_eq(focus);
173 let dim = focus.len();
174 let mut result =
175 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
176 for src_entry in &last {
177 let old_idx = *src_entry.0;
178 let gz = pre[old_idx & ((1 << dim) - 1)];
179 let new_idx = old_idx >> dim;
180 let dst_entry = result.entry(new_idx).or_insert(F::zero());
181 *dst_entry += gz * src_entry.1;
182 }
183 last = result;
184 }
185 let evaluations = hashmap_to_treemap(&last);
186 Self {
187 num_vars: self.num_vars - dim,
188 evaluations,
189 zero: F::zero(),
190 }
191 }
192
193 fn to_evaluations(&self) -> Vec<F> {
194 let mut evaluations = vec![F::zero(); 1 << self.num_vars];
195 self.evaluations
196 .iter()
197 .for_each(|(&i, &v)| evaluations[i] = v);
198 evaluations
199 }
200}
201
202impl<F: Field> Index<usize> for SparseMultilinearExtension<F> {
203 type Output = F;
204
205 fn index(&self, index: usize) -> &Self::Output {
214 if let Some(v) = self.evaluations.get(&index) {
215 v
216 } else {
217 &self.zero
218 }
219 }
220}
221
222impl<F: Field> Polynomial<F> for SparseMultilinearExtension<F> {
223 type Point = Vec<F>;
224
225 fn degree(&self) -> usize {
226 self.num_vars
227 }
228
229 fn evaluate(&self, point: &Self::Point) -> F {
230 assert!(point.len() == self.num_vars);
231 self.fix_variables(point)[0]
232 }
233}
234
235impl<F: Field> Add for SparseMultilinearExtension<F> {
236 type Output = Self;
237
238 fn add(self, other: Self) -> Self {
239 &self + &other
240 }
241}
242
243impl<'a, F: Field> Add<&'a SparseMultilinearExtension<F>> for &SparseMultilinearExtension<F> {
244 type Output = SparseMultilinearExtension<F>;
245
246 fn add(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
247 if self.is_zero() {
249 return rhs.clone();
250 }
251 if rhs.is_zero() {
252 return self.clone();
253 }
254
255 assert_eq!(
256 rhs.num_vars, self.num_vars,
257 "trying to add non-zero polynomial with different number of variables"
258 );
259 let mut evaluations =
261 HashMap::with_hasher(core::hash::BuildHasherDefault::<DefaultHasher>::default());
262 for (&i, &v) in self.evaluations.iter().chain(rhs.evaluations.iter()) {
263 *(evaluations.entry(i).or_insert(F::zero())) += v;
264 }
265 let evaluations: Vec<_> = evaluations
266 .into_iter()
267 .filter(|(_, v)| !v.is_zero())
268 .collect();
269
270 Self::Output {
271 evaluations: tuples_to_treemap(&evaluations),
272 num_vars: self.num_vars,
273 zero: F::zero(),
274 }
275 }
276}
277
278impl<F: Field> AddAssign for SparseMultilinearExtension<F> {
279 fn add_assign(&mut self, other: Self) {
280 *self = &*self + &other;
281 }
282}
283
284impl<'a, F: Field> AddAssign<&'a Self> for SparseMultilinearExtension<F> {
285 fn add_assign(&mut self, other: &'a Self) {
286 *self = &*self + other;
287 }
288}
289
290impl<'a, F: Field> AddAssign<(F, &'a Self)> for SparseMultilinearExtension<F> {
291 fn add_assign(&mut self, (f, other): (F, &'a Self)) {
292 if !self.is_zero() && !other.is_zero() {
293 assert_eq!(
294 other.num_vars, self.num_vars,
295 "trying to add non-zero polynomial with different number of variables"
296 );
297 }
298 let ev: Vec<_> = cfg_iter!(other.evaluations)
299 .map(|(i, v)| (*i, f * v))
300 .collect();
301 let other = Self {
302 num_vars: other.num_vars,
303 evaluations: tuples_to_treemap(&ev),
304 zero: F::zero(),
305 };
306 *self += &other;
307 }
308}
309
310impl<F: Field> Neg for SparseMultilinearExtension<F> {
311 type Output = Self;
312
313 fn neg(self) -> Self::Output {
314 let ev: Vec<_> = cfg_iter!(self.evaluations)
315 .map(|(i, v)| (*i, -*v))
316 .collect();
317 Self::Output {
318 num_vars: self.num_vars,
319 evaluations: tuples_to_treemap(&ev),
320 zero: F::zero(),
321 }
322 }
323}
324
325impl<F: Field> Sub for SparseMultilinearExtension<F> {
326 type Output = Self;
327
328 fn sub(self, other: Self) -> Self {
329 &self - &other
330 }
331}
332
333impl<'a, F: Field> Sub<&'a SparseMultilinearExtension<F>> for &SparseMultilinearExtension<F> {
334 type Output = SparseMultilinearExtension<F>;
335
336 fn sub(self, rhs: &'a SparseMultilinearExtension<F>) -> Self::Output {
337 self + &rhs.clone().neg()
338 }
339}
340
341impl<F: Field> SubAssign for SparseMultilinearExtension<F> {
342 fn sub_assign(&mut self, other: Self) {
343 *self = &*self - &other;
344 }
345}
346
347impl<'a, F: Field> SubAssign<&'a Self> for SparseMultilinearExtension<F> {
348 fn sub_assign(&mut self, other: &'a Self) {
349 *self = &*self - other;
350 }
351}
352
353impl<F: Field> Zero for SparseMultilinearExtension<F> {
354 fn zero() -> Self {
355 Self {
356 num_vars: 0,
357 evaluations: tuples_to_treemap(&Vec::new()),
358 zero: F::zero(),
359 }
360 }
361
362 fn is_zero(&self) -> bool {
363 self.num_vars == 0 && self.evaluations.is_empty()
364 }
365}
366
367impl<F: Field> Debug for SparseMultilinearExtension<F> {
368 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
369 write!(
370 f,
371 "SparseMultilinearPolynomial(num_vars = {}, evaluations = [",
372 self.num_vars
373 )?;
374 let mut ev_iter = self.evaluations.iter();
375 for _ in 0..ark_std::cmp::min(8, self.evaluations.len()) {
376 write!(f, "{:?}", ev_iter.next())?;
377 }
378 if self.evaluations.len() > 8 {
379 write!(f, "...")?;
380 }
381 write!(f, "])")?;
382 Ok(())
383 }
384}
385
386fn tuples_to_treemap<F: Field>(tuples: &[(usize, F)]) -> BTreeMap<usize, F> {
388 tuples.iter().map(|(i, v)| (*i, *v)).collect()
389}
390
391fn treemap_to_hashmap<F: Field>(
392 map: &BTreeMap<usize, F>,
393) -> HashMap<usize, F, core::hash::BuildHasherDefault<DefaultHasher>> {
394 map.iter().map(|(i, v)| (*i, *v)).collect()
395}
396
397fn hashmap_to_treemap<F: Field, S>(map: &HashMap<usize, F, S>) -> BTreeMap<usize, F> {
398 map.iter().map(|(i, v)| (*i, *v)).collect()
399}
400
#[cfg(test)]
mod tests {
    use crate::{
        evaluations::multivariate::multilinear::MultilinearExtension, Polynomial,
        SparseMultilinearExtension,
    };
    use ark_ff::{One, Zero};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{ops::Neg, test_rng, vec, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    #[test]
    fn random_poly() {
        const NV: usize = 16;

        let mut rng = test_rng();
        // Two consecutive samples from the same RNG stream should differ.
        let poly1 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        let poly2 = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
        assert_ne!(poly1, poly2);
        // The number of stored entries should be within 2x of 2^(NV / 2).
        assert!(
            ((1 << (NV / 2)) >> 1) <= poly1.evaluations.len()
                && poly1.evaluations.len() <= ((1 << (NV / 2)) << 1),
            "polynomial size out of range: expected: [{},{}] ,actual: {}",
            ((1 << (NV / 2)) >> 1),
            ((1 << (NV / 2)) << 1),
            poly1.evaluations.len()
        );
    }

    #[test]
    fn evaluate() {
        // Sparse and dense representations must agree on full and partial evaluation.
        const NV: usize = 12;
        let mut rng = test_rng();
        for _ in 0..20 {
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense = sparse.to_dense_multilinear_extension();
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(sparse.evaluate(&point), dense.evaluate(&point));
            let sparse_partial = sparse.fix_variables(&point[..3]);
            let dense_partial = dense.fix_variables(&point[..3]);
            let point2: Vec<_> = (0..(NV - 3)).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                sparse_partial.evaluate(&point2),
                dense_partial.evaluate(&point2)
            );
        }
    }

    #[test]
    fn sparse_to_evaluations_matches_to_dense() {
        let mut rng = test_rng();
        const NV: usize = 8;
        for _ in 0..25 {
            let sparse = SparseMultilinearExtension::<Fr>::rand(NV, &mut rng);
            let dense_via_sparse = sparse.to_dense_multilinear_extension().evaluations;
            let dense_via_to_evals = sparse.to_evaluations();
            assert_eq!(
                dense_via_to_evals, dense_via_sparse,
                "to_evaluations must reproduce the dense vector exactly"
            );
        }
    }

    #[test]
    fn evaluate_edge_cases() {
        let mut rng = test_rng();
        // Zero variables: the polynomial is a constant.
        let ev1 = Fr::rand(&mut rng);
        let poly1 = SparseMultilinearExtension::from_evaluations(0, &vec![(0, ev1)]);
        assert_eq!(poly1.evaluate(&[].into()), ev1);

        // One variable, both entries populated.
        let ev2 = [Fr::rand(&mut rng), Fr::rand(&mut rng)];
        let poly2 =
            SparseMultilinearExtension::from_evaluations(1, &vec![(0, ev2[0]), (1, ev2[1])]);

        let x = Fr::rand(&mut rng);
        assert_eq!(
            poly2.evaluate(&[x].into()),
            x * ev2[1] + (Fr::one() - x) * ev2[0]
        );

        // One variable, only index 1 populated.
        let ev3 = Fr::rand(&mut rng);
        let poly2 = SparseMultilinearExtension::from_evaluations(1, &vec![(1, ev3)]);

        let x = Fr::rand(&mut rng);
        assert_eq!(poly2.evaluate(&[x].into()), x * ev3);
    }

    #[test]
    fn index() {
        let mut rng = test_rng();
        let points = vec![
            (11, Fr::rand(&mut rng)),
            (117, Fr::rand(&mut rng)),
            (213, Fr::rand(&mut rng)),
            (255, Fr::rand(&mut rng)),
        ];
        let poly = SparseMultilinearExtension::from_evaluations(8, &points);
        // Check every stored point (a lazy `.map(..).next_back()` would only have
        // run the assertion for the last element).
        for (i, v) in points {
            assert_eq!(poly[i], v);
        }
        assert_eq!(poly[0], Fr::zero());
        assert_eq!(poly[1], Fr::zero());
    }

    #[test]
    fn arithmetic() {
        const NV: usize = 18;
        let mut rng = test_rng();
        for _ in 0..20 {
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = SparseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = SparseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // Addition.
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // Subtraction.
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // Negation.
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            {
                // In-place addition.
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            {
                // In-place subtraction.
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            {
                // In-place scaled addition.
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            {
                // The zero polynomial is the additive identity.
                assert_eq!(&poly1 + &SparseMultilinearExtension::zero(), poly1);
                assert_eq!(&SparseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &SparseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = SparseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
        }
    }

    #[test]
    fn relabel() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = SparseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            poly = poly.relabel(2, 2, 1);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(3, 4, 1);
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            poly = poly.relabel(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    #[test]
    fn serialize() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut buf = Vec::new();
            let poly = SparseMultilinearExtension::<Fr>::rand(10, &mut rng);
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            let expected = poly.evaluate(&point);

            poly.serialize_compressed(&mut buf).unwrap();

            let poly2: SparseMultilinearExtension<Fr> =
                SparseMultilinearExtension::deserialize_compressed(&buf[..]).unwrap();
            assert_eq!(poly2.evaluate(&point), expected);
        }
    }
}