use crate::error::Error;
use crate::fft::{EvaluationDomain, Polynomial};
use alloc::vec::Vec;
use core::ops::{Add, Mul};
use dusk_bls12_381::BlsScalar;
use dusk_bytes::{DeserializableSlice, Serializable};
/// An ordered collection of `BlsScalar` field elements that may contain
/// duplicates.
///
/// The elements live in a plain `Vec`, so insertion order is preserved and
/// the derived equality compares the elements position by position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiSet(pub Vec<BlsScalar>);
impl Default for MultiSet {
fn default() -> Self {
MultiSet::new()
}
}
impl From<&[BlsScalar]> for MultiSet {
fn from(slice: &[BlsScalar]) -> MultiSet {
MultiSet(slice.to_vec())
}
}
impl MultiSet {
    /// Creates an empty multiset.
    pub fn new() -> MultiSet {
        MultiSet(vec![])
    }

    /// Deserializes a multiset from consecutive `BlsScalar::SIZE`-byte scalar
    /// encodings, the layout produced by [`MultiSet::to_var_bytes`].
    ///
    /// # Errors
    ///
    /// Returns an error if any chunk fails to decode as a `BlsScalar`. A
    /// trailing chunk shorter than `BlsScalar::SIZE` bytes is handed to
    /// `BlsScalar::from_slice` as-is, which is expected to reject it.
    pub fn from_slice(bytes: &[u8]) -> Result<MultiSet, Error> {
        let elements = bytes
            .chunks(BlsScalar::SIZE)
            .map(BlsScalar::from_slice)
            .collect::<Result<Vec<BlsScalar>, dusk_bytes::Error>>()?;
        Ok(MultiSet(elements))
    }

    /// Serializes the multiset as the concatenation of each element's
    /// `BlsScalar::SIZE`-byte encoding, in order.
    pub fn to_var_bytes(&self) -> Vec<u8> {
        // One allocation up front instead of a temporary Vec per element.
        let mut bytes = Vec::with_capacity(self.0.len() * BlsScalar::SIZE);
        for item in self.0.iter() {
            bytes.extend_from_slice(&item.to_bytes());
        }
        bytes
    }

    /// Extends the multiset to exactly `n` elements by appending copies of
    /// its first element.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two, if the multiset already holds
    /// more than `n` elements, or if it is empty (there is no first element
    /// to repeat).
    pub fn pad(&mut self, n: u32) {
        assert!(n.is_power_of_two());
        let target = n as usize;
        // Without this check `n - len` underflows: a debug panic, or in
        // release a wrapped count near u32::MAX and a huge allocation.
        assert!(
            self.len() <= target,
            "cannot pad a multiset of {} elements to {} elements",
            self.len(),
            target
        );
        let filler = *self
            .0
            .first()
            .expect("cannot pad an empty multiset: no first element to repeat");
        self.0.resize(target, filler);
    }

    /// Appends `value` to the end of the multiset.
    pub fn push(&mut self, value: BlsScalar) {
        self.0.push(value)
    }

    /// Returns the last element, or `None` if the multiset is empty.
    pub fn last(&self) -> Option<&BlsScalar> {
        self.0.last()
    }

    /// Returns the number of elements, counting duplicates.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the multiset holds no elements.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the index of the first element equal to `element`, or `None`
    /// if no element matches. This is a linear scan.
    pub fn position(&self, element: &BlsScalar) -> Option<usize> {
        self.0.iter().position(|&x| x == *element)
    }

    /// Merges `f` into a copy of `self`, placing each element of `f`
    /// immediately before the first equal element of the result so far.
    ///
    /// Copies of a value therefore sit next to that value's first occurrence
    /// in `self`. If `self` is sorted, the result is sorted as well. Runs in
    /// O(|f| · (|self| + |f|)) because of the linear search and `Vec::insert`
    /// per element.
    ///
    /// # Errors
    ///
    /// Returns `Error::ElementNotIndexed` if some element of `f` does not
    /// occur in `self`.
    pub fn sorted_concat(&self, f: &MultiSet) -> Result<MultiSet, Error> {
        let mut s = self.clone();
        s.0.reserve(f.0.len());
        for element in f.0.iter() {
            let index = s.position(element).ok_or(Error::ElementNotIndexed)?;
            s.0.insert(index, *element);
        }
        Ok(s)
    }

    /// Returns `true` if every element of `other` occurs in `self`.
    /// Multiplicities are ignored: one occurrence in `self` covers any
    /// number of occurrences in `other`.
    pub fn contains_all(&self, other: &MultiSet) -> bool {
        other.0.iter().all(|item| self.contains(item))
    }

    /// Returns `true` if some element equals `entry`.
    pub fn contains(&self, entry: &BlsScalar) -> bool {
        self.0.contains(entry)
    }

    /// Splits the multiset into two halves that share the middle element at
    /// index `len / 2`. That element is the last of the first half and the
    /// first of the second half.
    ///
    /// For odd length `2k + 1`, both halves have `k + 1` elements. For even
    /// length `2k`, the first half has `k + 1` elements and the second `k`.
    /// An empty multiset yields two empty halves.
    pub fn halve(&self) -> (MultiSet, MultiSet) {
        if self.is_empty() {
            // `0..=0` would index past the end of an empty slice.
            return (MultiSet::new(), MultiSet::new());
        }
        let mid = self.len() / 2;
        let first_half = MultiSet::from(&self.0[..=mid]);
        let second_half = MultiSet::from(&self.0[mid..]);
        (first_half, second_half)
    }

    /// Splits the multiset by index parity: elements at even indices go to
    /// the first result and elements at odd indices go to the second. Order
    /// is preserved in both.
    pub fn halve_alternating(&self) -> (MultiSet, MultiSet) {
        let evens = self.0.iter().step_by(2).copied().collect();
        let odds = self.0.iter().skip(1).step_by(2).copied().collect();
        (MultiSet(evens), MultiSet(odds))
    }

    /// Converts the elements to a coefficient-form polynomial via
    /// `domain.ifft`, treating them as values over the domain's points.
    ///
    /// NOTE(review): presumably `ifft` zero-pads inputs shorter than the
    /// domain; confirm.
    pub(crate) fn to_polynomial(
        &self,
        domain: &EvaluationDomain,
    ) -> Polynomial {
        Polynomial::from_coefficients_vec(domain.ifft(&self.0))
    }

    /// Compresses three multisets into one using powers of `alpha`.
    ///
    /// Element `i` of the result is `a_i + alpha·b_i + alpha²·c_i`. The
    /// result length is that of the shortest input (zip semantics).
    pub fn compress_three_arity(
        multisets: [&MultiSet; 3],
        alpha: BlsScalar,
    ) -> MultiSet {
        // Loop-invariant: compute the power once, not once per element.
        let alpha_sq = alpha.square();
        MultiSet(
            multisets[0]
                .0
                .iter()
                .zip(multisets[1].0.iter())
                .zip(multisets[2].0.iter())
                .map(|((a, b), c)| a + b * alpha + c * alpha_sq)
                .collect::<Vec<BlsScalar>>(),
        )
    }

    /// Compresses four multisets into one using powers of `alpha`.
    ///
    /// Element `i` of the result is `a_i + alpha·b_i + alpha²·c_i +
    /// alpha³·d_i`. The result length is that of the shortest input (zip
    /// semantics).
    pub fn compress_four_arity(
        multisets: [&MultiSet; 4],
        alpha: BlsScalar,
    ) -> MultiSet {
        // Loop-invariant: compute the powers once, not once per element.
        let alpha_sq = alpha.square();
        let alpha_cu = alpha.pow(&[3u64, 0u64, 0u64, 0u64]);
        MultiSet(
            multisets[0]
                .0
                .iter()
                .zip(multisets[1].0.iter())
                .zip(multisets[2].0.iter())
                .zip(multisets[3].0.iter())
                .map(|(((a, b), c), d)| {
                    a + b * alpha + c * alpha_sq + d * alpha_cu
                })
                .collect::<Vec<BlsScalar>>(),
        )
    }
}
impl Add for MultiSet {
type Output = MultiSet;
fn add(self, other: MultiSet) -> Self::Output {
let result = self
.0
.into_iter()
.zip(other.0.iter())
.map(|(x, y)| x + y)
.collect();
MultiSet(result)
}
}
impl Mul for MultiSet {
type Output = MultiSet;
fn mul(self, other: MultiSet) -> Self::Output {
let result = self
.0
.into_iter()
.zip(other.0.iter())
.map(|(x, y)| x * y)
.collect();
MultiSet(result)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::fft::EvaluationDomain;
    use crate::plonkup::WitnessTable;

    /// Builds a multiset by pushing each integer, converted to a scalar, in
    /// order.
    fn multiset_of(values: &[u64]) -> MultiSet {
        let mut set = MultiSet::new();
        for &v in values {
            set.push(BlsScalar::from(v));
        }
        set
    }

    #[test]
    fn test_halve() {
        let s = multiset_of(&[0, 1, 2, 3, 4, 5, 6]);
        let (h_1, h_2) = s.halve();

        assert_eq!(h_1.len(), 4);
        assert_eq!(h_2.len(), 4);
        assert_eq!(multiset_of(&[0, 1, 2, 3]), h_1);
        assert_eq!(multiset_of(&[3, 4, 5, 6]), h_2);
        // The two halves overlap on the middle element.
        assert_eq!(h_1.last(), Some(&h_2.0[0]));
    }

    #[test]
    fn test_to_polynomial() {
        let s = multiset_of(&[1, 2, 3, 4, 5, 6, 7]);
        let domain = EvaluationDomain::new(s.len() + 1).unwrap();
        let s_poly = s.to_polynomial(&domain);
        assert_eq!(s_poly.degree(), 7)
    }

    #[test]
    fn test_is_subset() {
        let t = multiset_of(&[1, 2, 3, 4, 5, 6, 7]);
        let f = multiset_of(&[1, 2]);
        let n = multiset_of(&[8]);

        assert!(t.contains_all(&f));
        assert!(!t.contains_all(&n));
    }

    #[test]
    fn test_full_compression_into_s() {
        let t = multiset_of(&[0, 1, 2, 3, 4, 5, 6, 7]);
        let f = multiset_of(&[3, 6, 0, 5, 4, 3, 2, 0, 0, 1, 2]);

        assert!(t.contains_all(&f));
        assert!(t.contains(&BlsScalar::from(2)));

        let expected = multiset_of(&[
            0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7,
        ]);
        assert_eq!(t.sorted_concat(&f).unwrap(), expected);
    }

    #[test]
    fn multiset_compression_input() {
        let alpha = BlsScalar::from(2);
        let alpha_squared = alpha * alpha;

        let mut table = WitnessTable::default();
        table.from_wire_values(
            BlsScalar::from(1),
            BlsScalar::from(2),
            BlsScalar::from(3),
            BlsScalar::from(4),
        );

        let compressed_element = MultiSet::compress_three_arity(
            [&table.f_1, &table.f_2, &table.f_3],
            alpha,
        );

        let expected_element = BlsScalar::from(1)
            + (BlsScalar::from(2) * alpha)
            + (BlsScalar::from(3) * alpha_squared);
        let mut expected_set = MultiSet::new();
        expected_set.push(expected_element);

        assert_eq!(expected_set, compressed_element);
    }
}