use crate::{
    evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension},
    Polynomial,
};
use ark_ff::{Field, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{
    cfg_iter,
    cfg_iter_mut,
    fmt::{self, Formatter},
    iter::IntoIterator,
    log2,
    ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign},
    rand::Rng,
    slice::{Iter, IterMut},
    vec,
    vec::*,
};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
22
23#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
25pub struct DenseMultilinearExtension<F: Field> {
26 pub evaluations: Vec<F>,
28 pub num_vars: usize,
30}
31
32impl<F: Field> DenseMultilinearExtension<F> {
33 pub fn from_evaluations_slice(num_vars: usize, evaluations: &[F]) -> Self {
37 Self::from_evaluations_vec(num_vars, evaluations.to_vec())
38 }
39
40 pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<F>) -> Self {
59 assert_eq!(
61 evaluations.len(),
62 1 << num_vars,
63 "The size of evaluations should be 2^num_vars."
64 );
65
66 Self {
67 num_vars,
68 evaluations,
69 }
70 }
71 pub fn relabel_in_place(&mut self, mut a: usize, mut b: usize, k: usize) {
77 if a > b {
79 ark_std::mem::swap(&mut a, &mut b);
80 }
81 if a == b || k == 0 {
82 return;
83 }
84 assert!(b + k <= self.num_vars, "invalid relabel argument");
85 assert!(a + k <= b, "overlapped swap window is not allowed");
86 for i in 0..self.evaluations.len() {
87 let j = swap_bits(i, a, b, k);
88 if i < j {
89 self.evaluations.swap(i, j);
90 }
91 }
92 }
93
94 pub fn iter(&self) -> Iter<'_, F> {
96 self.evaluations.iter()
97 }
98
99 pub fn iter_mut(&mut self) -> IterMut<'_, F> {
101 self.evaluations.iter_mut()
102 }
103
104 pub fn concat(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> Self {
134 let polys_iter_cloned = polys.clone().into_iter();
137
138 let total_len: usize = polys
139 .into_iter()
140 .map(|poly| poly.as_ref().evaluations.len())
141 .sum();
142
143 let next_pow_of_two = total_len.next_power_of_two();
144 let num_vars = log2(next_pow_of_two);
145 let mut evaluations: Vec<F> = Vec::with_capacity(next_pow_of_two);
146
147 for poly in polys_iter_cloned {
148 evaluations.extend_from_slice(poly.as_ref().evaluations.as_slice());
149 }
150
151 evaluations.resize(next_pow_of_two, F::zero());
152
153 Self::from_evaluations_slice(num_vars as usize, &evaluations)
154 }
155}
156
157impl<'a, F: Field> IntoIterator for &'a DenseMultilinearExtension<F> {
158 type IntoIter = ark_std::slice::Iter<'a, F>;
159 type Item = &'a F;
160
161 fn into_iter(self) -> Self::IntoIter {
162 self.iter()
163 }
164}
165
166impl<'a, F: Field> IntoIterator for &'a mut DenseMultilinearExtension<F> {
167 type IntoIter = ark_std::slice::IterMut<'a, F>;
168 type Item = &'a mut F;
169
170 fn into_iter(self) -> Self::IntoIter {
171 self.iter_mut()
172 }
173}
174
175impl<F: Field> AsRef<Self> for DenseMultilinearExtension<F> {
176 fn as_ref(&self) -> &Self {
177 self
178 }
179}
180
181impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
182 fn num_vars(&self) -> usize {
183 self.num_vars
184 }
185
186 fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
187 Self::from_evaluations_vec(
188 num_vars,
189 (0..(1usize << num_vars)).map(|_| F::rand(rng)).collect(),
190 )
191 }
192
193 fn relabel(&self, a: usize, b: usize, k: usize) -> Self {
194 let mut copied = self.clone();
195 copied.relabel_in_place(a, b, k);
196 copied
197 }
198
199 fn fix_variables(&self, partial_point: &[F]) -> Self {
223 assert!(
224 partial_point.len() <= self.num_vars,
225 "invalid size of partial point"
226 );
227 let mut poly = self.evaluations.clone();
228 let nv = self.num_vars;
229 let dim = partial_point.len();
230 for i in 1..dim + 1 {
232 let r = partial_point[i - 1];
233 for b in 0..(1 << (nv - i)) {
234 let left = poly[b << 1];
235 let right = poly[(b << 1) + 1];
236 poly[b] = left + r * (right - left);
237 }
238 }
239 Self::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
240 }
241
242 fn to_evaluations(&self) -> Vec<F> {
243 self.evaluations.clone()
244 }
245}
246
247impl<F: Field> Index<usize> for DenseMultilinearExtension<F> {
248 type Output = F;
249
250 fn index(&self, index: usize) -> &Self::Output {
257 &self.evaluations[index]
258 }
259}
260
261impl<F: Field> Add for DenseMultilinearExtension<F> {
262 type Output = Self;
263
264 fn add(self, other: Self) -> Self {
265 &self + &other
266 }
267}
268
269impl<'a, F: Field> Add<&'a DenseMultilinearExtension<F>> for &DenseMultilinearExtension<F> {
270 type Output = DenseMultilinearExtension<F>;
271
272 fn add(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
273 if rhs.is_zero() {
275 return self.clone();
276 }
277 if self.is_zero() {
278 return rhs.clone();
279 }
280 assert_eq!(self.num_vars, rhs.num_vars);
281 let result: Vec<F> = cfg_iter!(self.evaluations)
282 .zip(&rhs.evaluations)
283 .map(|(a, b)| *a + *b)
284 .collect();
285
286 Self::Output::from_evaluations_vec(self.num_vars, result)
287 }
288}
289
290impl<F: Field> AddAssign for DenseMultilinearExtension<F> {
291 fn add_assign(&mut self, other: Self) {
292 *self = &*self + &other;
293 }
294}
295
296impl<'a, F: Field> AddAssign<&'a Self> for DenseMultilinearExtension<F> {
297 fn add_assign(&mut self, other: &'a Self) {
298 *self = &*self + other;
299 }
300}
301
302impl<'a, F: Field> AddAssign<(F, &'a Self)> for DenseMultilinearExtension<F> {
303 fn add_assign(&mut self, (f, other): (F, &'a Self)) {
304 let other = Self {
305 num_vars: other.num_vars,
306 evaluations: cfg_iter!(other.evaluations).map(|x| f * x).collect(),
307 };
308 *self = &*self + &other;
309 }
310}
311
312impl<F: Field> Neg for DenseMultilinearExtension<F> {
313 type Output = Self;
314
315 fn neg(self) -> Self::Output {
316 Self::Output {
317 num_vars: self.num_vars,
318 evaluations: cfg_iter!(self.evaluations).map(|x| -*x).collect(),
319 }
320 }
321}
322
323impl<F: Field> Sub for DenseMultilinearExtension<F> {
324 type Output = Self;
325
326 fn sub(self, other: Self) -> Self {
327 &self - &other
328 }
329}
330
331impl<'a, F: Field> Sub<&'a DenseMultilinearExtension<F>> for &DenseMultilinearExtension<F> {
332 type Output = DenseMultilinearExtension<F>;
333
334 fn sub(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
335 self + &rhs.clone().neg()
336 }
337}
338
339impl<F: Field> SubAssign for DenseMultilinearExtension<F> {
340 fn sub_assign(&mut self, other: Self) {
341 *self = &*self - &other;
342 }
343}
344
345impl<'a, F: Field> SubAssign<&'a Self> for DenseMultilinearExtension<F> {
346 fn sub_assign(&mut self, other: &'a Self) {
347 *self = &*self - other;
348 }
349}
350
351impl<F: Field> Mul<F> for DenseMultilinearExtension<F> {
352 type Output = Self;
353
354 fn mul(self, scalar: F) -> Self::Output {
355 &self * &scalar
356 }
357}
358
359impl<'a, F: Field> Mul<&'a F> for &DenseMultilinearExtension<F> {
360 type Output = DenseMultilinearExtension<F>;
361
362 fn mul(self, scalar: &'a F) -> Self::Output {
363 if scalar.is_zero() {
364 return DenseMultilinearExtension::zero();
365 } else if scalar.is_one() {
366 return self.clone();
367 }
368 let result: Vec<F> = self.evaluations.iter().map(|&x| x * scalar).collect();
369
370 DenseMultilinearExtension {
371 num_vars: self.num_vars,
372 evaluations: result,
373 }
374 }
375}
376
377impl<F: Field> MulAssign<F> for DenseMultilinearExtension<F> {
378 fn mul_assign(&mut self, scalar: F) {
379 *self = &*self * &scalar
380 }
381}
382
383impl<'a, F: Field> MulAssign<&'a F> for DenseMultilinearExtension<F> {
384 fn mul_assign(&mut self, scalar: &'a F) {
385 *self = &*self * scalar
386 }
387}
388
389impl<F: Field> fmt::Debug for DenseMultilinearExtension<F> {
390 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
391 write!(f, "DenseML(nv = {}, evaluations = [", self.num_vars)?;
392 for i in 0..ark_std::cmp::min(4, self.evaluations.len()) {
393 write!(f, "{:?} ", self.evaluations[i])?;
394 }
395 if self.evaluations.len() < 4 {
396 write!(f, "])")?;
397 } else {
398 write!(f, "...])")?;
399 }
400 Ok(())
401 }
402}
403
404impl<F: Field> Zero for DenseMultilinearExtension<F> {
405 fn zero() -> Self {
406 Self {
407 num_vars: 0,
408 evaluations: vec![F::zero()],
409 }
410 }
411
412 fn is_zero(&self) -> bool {
413 self.num_vars == 0 && self.evaluations[0].is_zero()
414 }
415}
416
417impl<F: Field> Polynomial<F> for DenseMultilinearExtension<F> {
418 type Point = Vec<F>;
419
420 fn degree(&self) -> usize {
421 self.num_vars
422 }
423
424 fn evaluate(&self, point: &Self::Point) -> F {
444 assert!(point.len() == self.num_vars);
445 self.fix_variables(point)[0]
446 }
447}
448
#[cfg(test)]
mod tests {
    use crate::{DenseMultilinearExtension, MultilinearExtension, Polynomial};
    use ark_ff::{Field, One, Zero};
    use ark_std::{ops::Neg, test_rng, vec::*, UniformRand};
    use ark_test_curves::bls12_381::Fr;

    /// Reference evaluator used to cross-check `evaluate`. It folds `data` one
    /// variable at a time with the textbook `(1 - r) * lo + r * hi` formula.
    fn evaluate_data_array<F: Field>(data: &[F], point: &[F]) -> F {
        assert!(
            data.len() == (1 << point.len()),
            "Data size mismatch with number of variables. "
        );

        let nv = point.len();
        let mut a = data.to_vec();

        for i in 1..nv + 1 {
            let r = point[i - 1];
            for b in 0..(1 << (nv - i)) {
                a[b] = a[b << 1] * (F::one() - r) + a[(b << 1) + 1] * r;
            }
        }
        a[0]
    }

    #[test]
    fn evaluate_at_a_point() {
        let mut rng = test_rng();
        let poly = DenseMultilinearExtension::rand(10, &mut rng);
        for _ in 0..10 {
            let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
            assert_eq!(
                evaluate_data_array(&poly.evaluations, &point),
                poly.evaluate(&point)
            )
        }
    }

    #[test]
    fn relabel_polynomial() {
        let mut rng = test_rng();
        for _ in 0..20 {
            let mut poly = DenseMultilinearExtension::rand(10, &mut rng);
            let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

            let expected = poly.evaluate(&point);

            // Relabelling a window onto itself is a no-op.
            poly.relabel_in_place(2, 2, 1);
            assert_eq!(expected, poly.evaluate(&point));

            poly.relabel_in_place(3, 4, 1);
            point.swap(3, 4);
            assert_eq!(expected, poly.evaluate(&point));

            poly.relabel_in_place(7, 5, 1);
            point.swap(7, 5);
            assert_eq!(expected, poly.evaluate(&point));

            poly.relabel_in_place(2, 5, 3);
            point.swap(2, 5);
            point.swap(3, 6);
            point.swap(4, 7);
            assert_eq!(expected, poly.evaluate(&point));

            poly.relabel_in_place(7, 0, 2);
            point.swap(0, 7);
            point.swap(1, 8);
            assert_eq!(expected, poly.evaluate(&point));

            poly.relabel_in_place(0, 9, 1);
            point.swap(0, 9);
            assert_eq!(expected, poly.evaluate(&point));
        }
    }

    #[test]
    fn arithmetic() {
        const NV: usize = 10;
        let mut rng = test_rng();
        for _ in 0..20 {
            let scalar = Fr::rand(&mut rng);
            let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
            let poly1 = DenseMultilinearExtension::rand(NV, &mut rng);
            let poly2 = DenseMultilinearExtension::rand(NV, &mut rng);
            let v1 = poly1.evaluate(&point);
            let v2 = poly2.evaluate(&point);
            // Add.
            assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
            // Sub.
            assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
            // Neg.
            assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
            // Scalar multiplication.
            assert_eq!((&poly1 * &scalar).evaluate(&point), v1 * scalar);
            // AddAssign.
            {
                let mut poly1 = poly1.clone();
                poly1 += &poly2;
                assert_eq!(poly1.evaluate(&point), v1 + v2)
            }
            // SubAssign.
            {
                let mut poly1 = poly1.clone();
                poly1 -= &poly2;
                assert_eq!(poly1.evaluate(&point), v1 - v2)
            }
            // Scaled AddAssign.
            {
                let mut poly1 = poly1.clone();
                let scalar = Fr::rand(&mut rng);
                poly1 += (scalar, &poly2);
                assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
            }
            // The zero polynomial is an additive identity for any arity.
            {
                assert_eq!(&poly1 + &DenseMultilinearExtension::zero(), poly1);
                assert_eq!(&DenseMultilinearExtension::zero() + &poly1, poly1);
                {
                    let mut poly1_cloned = poly1.clone();
                    poly1_cloned += &DenseMultilinearExtension::zero();
                    assert_eq!(&poly1_cloned, &poly1);
                    let mut zero = DenseMultilinearExtension::zero();
                    let scalar = Fr::rand(&mut rng);
                    zero += (scalar, &poly1);
                    assert_eq!(zero.evaluate(&point), scalar * v1);
                }
            }
            // MulAssign, including the one and zero special cases.
            {
                let mut poly1_cloned = poly1.clone();
                poly1_cloned *= Fr::one();
                assert_eq!(poly1_cloned.evaluate(&point), v1);
                poly1_cloned *= scalar;
                assert_eq!(poly1_cloned.evaluate(&point), v1 * scalar);
                poly1_cloned *= Fr::zero();
                assert_eq!(poly1_cloned, DenseMultilinearExtension::zero());
            }
        }
    }

    #[test]
    fn concat_two_equal_polys() {
        let mut rng = test_rng();
        let degree = 10;

        let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
        let poly_r = DenseMultilinearExtension::rand(degree, &mut rng);

        let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
        for _ in 0..10 {
            let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();

            let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
                + point[10] * poly_r.evaluate(&point[..10].to_vec());
            assert_eq!(expected, merged.evaluate(&point));
        }
    }

    #[test]
    fn concat_unequal_polys() {
        let mut rng = test_rng();
        let degree = 10;
        let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
        // `poly_r` has half as many evaluations; the tail of the merged table is zero-padded.
        let poly_r = DenseMultilinearExtension::rand(degree - 1, &mut rng);

        let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);

        for _ in 0..10 {
            let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();

            // Merged layout: [poly_l | poly_r | zeros], so poly_r sits where the
            // top variable is 1 and the next one down is 0.
            let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
                + point[10] * ((Fr::ONE - point[9]) * poly_r.evaluate(&point[..9].to_vec()));
            assert_eq!(expected, merged.evaluate(&point));
        }
    }

    #[test]
    fn concat_two_iterators() {
        let mut rng = test_rng();
        let degree = 10;

        // Draw every polynomial from the same RNG stream. Calling `test_rng()`
        // afresh per polynomial re-seeds it, making all four inputs identical
        // and hiding any ordering bug in `concat`.
        let polys_l: Vec<_> = (0..2)
            .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
            .collect();
        let polys_r: Vec<_> = (0..2)
            .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
            .collect();
        assert_ne!(polys_l[0], polys_l[1]);
        assert_ne!(polys_l[0], polys_r[0]);

        let merged = DenseMultilinearExtension::<Fr>::concat(polys_l.iter().chain(polys_r.iter()));

        for _ in 0..10 {
            let point: Vec<_> = (0..(degree)).map(|_| Fr::rand(&mut rng)).collect();

            let expected = (Fr::ONE - point[9])
                * ((Fr::ONE - point[8]) * polys_l[0].evaluate(&point[..8].to_vec())
                    + point[8] * polys_l[1].evaluate(&point[..8].to_vec()))
                + point[9]
                    * ((Fr::ONE - point[8]) * polys_r[0].evaluate(&point[..8].to_vec())
                        + point[8] * polys_r[1].evaluate(&point[..8].to_vec()));

            assert_eq!(expected, merged.evaluate(&point));
        }
    }
}