use crate::{
errors::AlgebraError,
prelude::*,
rand::{CryptoRng, RngCore},
};
use ark_ff::FftField;
use ark_std::fmt::Debug;
use digest::{generic_array::typenum::U64, Digest};
use num_bigint::BigUint;
use serde::{Deserialize, Serialize};
/// A scalar field element as used throughout this crate.
///
/// Scalars are cheap `Copy` values that support ring arithmetic both by value and by
/// reference, conversion to/from `BigUint` and small machine integers, and serde
/// (de)serialization. They must also be `Send + Sync`.
pub trait Scalar:
    Copy
    + Default
    + Debug
    + PartialEq
    + Eq
    + Serialize
    + for<'de> Deserialize<'de>
    + Into<BigUint>
    + for<'a> From<&'a BigUint>
    + Clone
    + One
    + Zero
    + Sized
    + Add<Self, Output = Self>
    + Mul<Self, Output = Self>
    + Sum<Self>
    + for<'a> Add<&'a Self, Output = Self>
    + for<'a> AddAssign<&'a Self>
    + for<'a> Mul<&'a Self, Output = Self>
    + for<'a> MulAssign<&'a Self>
    + for<'a> Sub<&'a Self, Output = Self>
    + for<'a> SubAssign<&'a Self>
    + for<'a> Sum<&'a Self>
    + From<u32>
    + From<u64>
    + Neg<Output = Self>
    + Sync
    + Send
{
    /// Samples a random scalar using the cryptographic RNG `rng`.
    fn random<R: CryptoRng + RngCore>(rng: &mut R) -> Self;

    /// Derives a scalar from a digest whose output is 64 bytes long.
    fn from_hash<D>(hash: D) -> Self
    where
        D: Digest<OutputSize = U64> + Default;

    /// Returns the implementation's multiplicative generator
    /// (presumably a generator of the field's multiplicative group — confirm per impl).
    fn multiplicative_generator() -> Self;

    /// Implementation-defined capacity of a scalar
    /// (NOTE(review): likely the number of bits storable without reduction — confirm).
    fn capacity() -> usize;

    /// The field size encoded as little-endian bytes.
    fn get_field_size_le_bytes() -> Vec<u8>;

    /// The field size as a `BigUint`.
    fn get_field_size_biguint() -> BigUint;

    /// Little-endian bytes of the field size after applying `shift_u8_vec` to them.
    ///
    /// NOTE(review): assumes `shift_u8_vec` is a one-bit right shift, which for an odd
    /// field size `q` yields `(q - 1) / 2` — confirm against the prelude.
    fn field_size_minus_one_half() -> Vec<u8> {
        let mut halved = Self::get_field_size_le_bytes();
        shift_u8_vec(&mut halved);
        halved
    }

    /// The scalar's value as little-endian `u64` limbs.
    fn get_little_endian_u64(&self) -> Vec<u64>;

    /// Length of the byte encoding produced by `to_bytes`
    /// (presumably fixed per implementation — confirm).
    fn bytes_len() -> usize;

    /// Serializes the scalar to bytes.
    fn to_bytes(&self) -> Vec<u8>;

    /// Parses a scalar from bytes, failing on malformed input.
    fn from_bytes(bytes: &[u8]) -> Result<Self>;

    /// Multiplicative inverse; presumably returns an error for zero — confirm per impl.
    fn inv(&self) -> Result<Self>;

    /// The square of `self`.
    fn square(&self) -> Self;

    /// Raises `self` to `exponent`, given as little-endian 64-bit limbs.
    ///
    /// Limb `k` contributes a factor of `2^(64·k)` to the exponent, and every limb is
    /// scanned over all 64 bits. An empty `exponent` yields `Self::one()`.
    fn pow(&self, exponent: &[u64]) -> Self {
        let mut acc = Self::one();
        // Invariant: `power == self^(2^j)` where j is the number of bits consumed so far.
        let mut power = *self;
        for &limb in exponent.iter() {
            for bit in 0..64 {
                if ((limb >> bit) & 1) == 1 {
                    acc.mul_assign(&power);
                }
                power = power.mul(&power);
            }
        }
        acc
    }
}
/// A [`Scalar`] that is backed by an arkworks [`FftField`] element.
///
/// `get_field` and `from_field` convert between the crate's scalar type and the
/// underlying arkworks field type named by [`Domain::Field`].
pub trait Domain: Scalar {
    /// The arkworks FFT-capable field this scalar wraps.
    type Field: FftField;

    /// Returns the arkworks field element corresponding to `self`.
    fn get_field(&self) -> Self::Field;

    /// Builds a scalar from an arkworks field element.
    fn from_field(field: Self::Field) -> Self;
}
/// A group, written additively, whose elements can be multiplied by a [`Scalar`].
///
/// Elements are `Copy`, serializable, and support addition/subtraction by reference,
/// negation, and scalar multiplication via `Mul<&Self::ScalarType>`.
pub trait Group:
    Debug
    + Default
    + Copy
    + Sized
    + PartialEq
    + Eq
    + Clone
    + for<'a> Add<&'a Self, Output = Self>
    + for<'a> Mul<&'a Self::ScalarType, Output = Self>
    + for<'a> Sub<&'a Self, Output = Self>
    + for<'a> AddAssign<&'a Self>
    + for<'a> SubAssign<&'a Self>
    + Serialize
    + Neg
    + for<'de> Deserialize<'de>
{
    /// The scalar type acting on this group.
    type ScalarType: Scalar;

    /// Length of the compressed encoding (presumably in bytes — confirm per impl).
    const COMPRESSED_LEN: usize;

    /// Doubling of `self`.
    fn double(&self) -> Self;

    /// The identity element; `multi_exp` returns it for an empty input.
    fn get_identity() -> Self;

    /// The implementation's fixed base element (presumably the standard generator).
    fn get_base() -> Self;

    /// Samples a random group element using the cryptographic RNG `rng`.
    fn random<R: CryptoRng + RngCore>(rng: &mut R) -> Self;

    /// Serializes `self` in compressed form.
    fn to_compressed_bytes(&self) -> Vec<u8>;

    /// Parses a compressed encoding, failing on malformed input.
    fn from_compressed_bytes(bytes: &[u8]) -> Result<Self>;

    /// Serializes `self` in the "unchecked" form
    /// (NOTE(review): presumably an uncompressed encoding — confirm per impl).
    fn to_unchecked_bytes(&self) -> Vec<u8>;

    /// Parses the "unchecked" encoding
    /// (NOTE(review): presumably skips some validity checks — confirm per impl).
    fn from_unchecked_bytes(bytes: &[u8]) -> Result<Self>;

    /// Size of the "unchecked" encoding.
    fn unchecked_size() -> usize;

    /// Derives a group element from a digest whose output is 64 bytes long.
    fn from_hash<D>(hash: D) -> Self
    where
        D: Digest<OutputSize = U64> + Default;

    /// Multi-scalar multiplication `Σ scalars[i] · points[i]`.
    ///
    /// An empty `scalars` yields the identity. Otherwise the work is delegated to
    /// [`pippenger`], whose only error case (empty input) is excluded here.
    #[inline]
    fn multi_exp(scalars: &[&Self::ScalarType], points: &[&Self]) -> Self {
        match scalars {
            // `pippenger` rejects empty input, so the empty sum is answered directly.
            [] => Self::get_identity(),
            _ => pippenger(scalars, points).unwrap(),
        }
    }
}
/// A Pedersen-style commitment scheme over the group `G`.
///
/// NOTE(review): conventionally `commit(v, r) = v·generator() + r·blinding_generator()`;
/// confirm that implementations follow this.
pub trait PedersenCommitment<G: Group>: Default {
    /// The generator that multiplies the committed value.
    fn generator(&self) -> G;

    /// The generator that multiplies the blinding factor.
    fn blinding_generator(&self) -> G;

    /// Commits to `value` using the blinding factor `blinding`.
    fn commit(&self, value: G::ScalarType, blinding: G::ScalarType) -> G;
}
/// A bilinear pairing `G1 × G2 → Gt`, where all three groups share one scalar field.
pub trait Pairing {
    /// Scalar field shared by `G1`, `G2` and `Gt`.
    type ScalarField: Scalar;
    /// Left source group of the pairing.
    type G1: Group<ScalarType = Self::ScalarField>;
    /// Right source group of the pairing.
    type G2: Group<ScalarType = Self::ScalarField>;
    /// Target group of the pairing.
    type Gt: Group<ScalarType = Self::ScalarField>;

    /// Evaluates the pairing on `a` and `b`.
    fn pairing(a: &Self::G1, b: &Self::G2) -> Self::Gt;

    /// Combines the pairings of `a` and `b`
    /// (NOTE(review): presumably the product of e(a[i], b[i]) over aligned indices —
    /// confirm how implementations treat unequal lengths).
    fn product_of_pairings(a: &[Self::G1], b: &[Self::G2]) -> Self::Gt;
}
/// Recodes `scalar` into signed digits of radix `2^w`.
///
/// The scalar's value is taken from its little-endian `u64` limbs
/// (`get_little_endian_u64`). The result `d` is ordered least-significant digit first
/// and satisfies `value = Σ d[i] · 2^(w·i)`.
///
/// Every digit lies in `[-2^(w-1), 2^(w-1))`, except that the last digit may be a final
/// carry of `1`. Trailing zero digits are removed, but at least one digit is always
/// returned, so the zero scalar yields `[0]`.
///
/// Signed digits let a caller such as [`pippenger`] use only `2^(w-1)` buckets, because
/// a negative digit is handled by subtracting the point instead of adding it.
///
/// # Panics
///
/// Panics unless `1 <= w <= 8`. With `w == 0` the scan would never advance through the
/// scalar's bits, and for `w > 8` the digits would no longer fit in an `i8`.
pub fn scalar_to_radix_2_power_w<S: Scalar>(scalar: &S, w: usize) -> Vec<i8> {
    // Recentred digits lie in [-2^(w-1), 2^(w-1) - 1]. That range still fits in an `i8`
    // for w = 8 ([-128, 127]). `pippenger` picks w = 8 for large batches, so rejecting
    // it here made every such call panic.
    assert!(
        (1..=8).contains(&w),
        "radix window width must be in 1..=8, got {}",
        w
    );
    if *scalar == S::from(0u32) {
        return vec![0i8];
    }
    let limbs = scalar.get_little_endian_u64();
    let radix: u64 = 1 << w;
    let window_mask: u64 = radix - 1;
    let mut carry = 0u64;
    let mut digits = Vec::new();
    let mut bit_offset = 0usize;
    loop {
        let limb_idx = bit_offset / 64;
        let bit_idx = bit_offset % 64;
        if limb_idx >= limbs.len() {
            // All bits have been consumed; the pending carry (0 or 1) becomes the top digit.
            digits.push(carry as i8);
            break;
        }
        // Gather at least `w` bits starting at `bit_offset`. If the window straddles a
        // limb boundary, splice in the low bits of the next limb. That branch only runs
        // when bit_idx >= 64 - w > 0, so the `64 - bit_idx` shift is in range.
        let is_last = limb_idx == limbs.len() - 1;
        let bit_buf = if bit_idx < 64 - w || is_last {
            limbs[limb_idx] >> bit_idx
        } else {
            (limbs[limb_idx] >> bit_idx) | (limbs[limb_idx + 1] << (64 - bit_idx))
        };
        // coef ∈ [0, 2^w]. Recentre it into [-2^(w-1), 2^(w-1)) and carry the excess
        // (at most 1) into the next window.
        let coef = carry + (bit_buf & window_mask);
        carry = (coef + radix / 2) >> w;
        digits.push((coef as i64 - (carry << w) as i64) as i8);
        bit_offset += w;
    }
    // Drop high-order zero digits, but keep at least one digit.
    while digits.len() > 1 && digits.last() == Some(&0) {
        digits.pop();
    }
    digits
}
/// Computes the multi-scalar multiplication `Σ scalars[i] · elems[i]` with Pippenger's
/// bucket method.
///
/// Steps:
/// 1. Each scalar is recoded by [`scalar_to_radix_2_power_w`] into signed base-`2^w`
///    digits. The window width `w` (6, 7 or 8) grows with the batch size.
/// 2. For every digit position, starting from the most significant, each point is added
///    to (positive digit) or subtracted from (negative digit) one of `2^(w-1)` buckets,
///    chosen by the digit's magnitude.
/// 3. The buckets are weighted and summed with a running sum.
/// 4. The per-position results are combined Horner-style: the accumulator is multiplied
///    by `2^w` between positions.
///
/// `scalars` and `elems` are paired index by index. If their lengths differ, the surplus
/// entries of the longer slice are ignored.
///
/// # Errors
///
/// Returns `AlgebraError::ParameterError` if `scalars` is empty.
pub fn pippenger<G: Group>(scalars: &[&G::ScalarType], elems: &[&G]) -> Result<G> {
    let size = scalars.len();
    if size == 0 {
        return Err(eg!(AlgebraError::ParameterError));
    }
    let w = if size < 500 {
        6
    } else if size < 800 {
        7
    } else {
        8
    };
    let two_power_w: usize = 1 << w;
    let digits_vec: Vec<Vec<i8>> = scalars
        .iter()
        .map(|s| scalar_to_radix_2_power_w::<G::ScalarType>(s, w))
        .collect();
    // Number of digit positions in the longest recoding. This is at least 1, because
    // `scalars` is non-empty and every recoding has at least one digit.
    let digits_count = digits_vec.iter().map(Vec::len).max().unwrap_or(0);

    // Bucket j accumulates the points whose digit at the current position has magnitude j + 1.
    let mut buckets = vec![G::get_identity(); two_power_w / 2];

    // Returns Σ digit_i · elems[i] over all i, for the digit position `index`.
    let mut column_sum = |index: usize| -> G {
        buckets.fill(G::get_identity());
        for (digits, elem) in digits_vec.iter().zip(elems) {
            // Shorter recodings have implicit zero digits in the high positions.
            let digit = match digits.get(index) {
                Some(&d) => d,
                None => continue,
            };
            if digit > 0 {
                buckets[(digit - 1) as usize].add_assign(*elem);
            } else if digit < 0 {
                // Computing -(digit + 1) instead of -digit - 1 avoids i8 overflow at -128.
                buckets[(-(digit + 1)) as usize].sub_assign(*elem);
            }
        }
        // Σ_j (j + 1) · bucket[j], computed as a running sum from the top bucket down.
        let mut running = buckets[buckets.len() - 1];
        let mut total = running;
        for bucket in buckets.iter().rev().skip(1) {
            running = running.add(bucket);
            total = total.add(&running);
        }
        total
    };

    // Horner evaluation from the most significant position. Scaling the accumulator by
    // 2^w is done with w self-additions instead of a general scalar multiplication by
    // the full-width scalar 2^w.
    let mut acc = column_sum(digits_count - 1);
    for index in (0..digits_count - 1).rev() {
        for _ in 0..w {
            acc = acc.add(&acc);
        }
        acc = acc.add(&column_sum(index));
    }
    Ok(acc)
}
#[cfg(test)]
pub(crate) mod group_tests {
    use crate::traits::{scalar_to_radix_2_power_w, Scalar};

    /// Checks the basic arithmetic of a `Scalar` implementation on small values.
    pub(crate) fn test_scalar_operations<S: Scalar>() {
        let forty = S::from(40u32);
        let sixty = S::from(60u32);

        // Addition, by value and in place.
        let hundred = S::from(100u32);
        assert_eq!(forty.add(&sixty), hundred);
        let mut acc = S::from(0u32);
        acc.add_assign(&forty);
        acc.add_assign(&sixty);
        assert_eq!(acc, hundred);

        // Multiplication, by value and in place.
        let ten = S::from(10u32);
        let four_hundred = S::from(400u32);
        assert_eq!(ten.mul(&forty), four_hundred);
        let mut acc = S::from(1u32);
        acc.mul_assign(&ten);
        acc.mul_assign(&forty);
        assert_eq!(acc, four_hundred);

        // Values that cross the 32-bit boundary.
        let u32_max = S::from(0xFFFFFFFFu32);
        let one = S::from(1u32);
        assert_eq!(u32_max.add(&one), S::from(0x100000000u64));
        assert_eq!(u32_max.mul(&one), S::from(0xFFFFFFFFu32));

        // Subtraction, by value and in place.
        let twenty = S::from(20u32);
        assert_eq!(sixty.sub(&forty), twenty);
        let mut acc = S::from(120u32);
        acc.sub_assign(&sixty);
        acc.sub_assign(&forty);
        assert_eq!(acc, twenty);

        // Negation, inversion and exponentiation (3^20 = 3486784401).
        assert_eq!(forty.neg().add(&forty), S::from(0u32));
        assert_eq!(forty.inv().unwrap().mul(&forty), S::from(1u32));
        assert_eq!(S::from(3u32).pow(&[20]), S::from(3486784401u64));

        // The two field-size accessors must agree.
        assert_eq!(
            S::get_field_size_biguint().to_bytes_le(),
            S::get_field_size_le_bytes()
        );
    }

    /// Checks that `to_bytes` / `from_bytes` round-trip a scalar.
    pub(crate) fn test_scalar_serialization<S: Scalar>() {
        let original = S::from(100u32);
        let decoded = S::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, decoded);
    }

    /// Checks the signed radix-2^w recoding on a few known values.
    pub(crate) fn test_to_radix<S: Scalar>() {
        // (value, window width, expected digits, least significant first)
        let cases: [(u32, usize, &[i8]); 3] = [
            (41, 2, &[1, -2, -1, 1]),
            (0, 2, &[0]),
            (1000, 6, &[-24, 16]),
        ];
        for &(value, w, expected) in cases.iter() {
            let digits = scalar_to_radix_2_power_w(&S::from(value), w);
            assert_eq!(digits.as_slice(), expected);
        }
    }
}
#[cfg(test)]
mod multi_exp_tests {
    use crate::bls12_381::{BLSGt, BLSG1, BLSG2};
    use crate::ristretto::RistrettoPoint;
    use crate::traits::Group;

    #[test]
    fn test_multiexp_ristretto() {
        run_multiexp_test::<RistrettoPoint>();
    }

    #[test]
    fn test_multiexp_blsg1() {
        run_multiexp_test::<BLSG1>();
    }

    #[test]
    fn test_multiexp_blsg2() {
        run_multiexp_test::<BLSG2>();
    }

    #[test]
    fn test_multiexp_blsgt() {
        run_multiexp_test::<BLSGt>();
    }

    /// Exercises `Group::multi_exp` on small hand-checked inputs for the group `G`.
    fn run_multiexp_test<G: Group>() {
        let base = G::get_base();
        let identity = G::get_identity();
        let zero = G::ScalarType::from(0u32);
        let one = G::ScalarType::from(1u32);

        // The empty sum is the identity.
        assert_eq!(G::multi_exp(&[], &[]), identity);

        // A single term with scalar 0, then with scalar 1.
        assert_eq!(G::multi_exp(&[&zero], &[&base]), identity);
        assert_eq!(G::multi_exp(&[&one], &[&base]), base);

        // A term whose scalar is zero does not contribute.
        assert_eq!(G::multi_exp(&[&one, &zero], &[&base, &base]), base);

        // 1000·g + 2·(2g) + 3·(500g) = 2504·g
        let two_g = base.add(&base);
        let five_hundred_g = base.mul(&G::ScalarType::from(500u32));
        let thousand = G::ScalarType::from(1000u32);
        let two = G::ScalarType::from(2u32);
        let three = G::ScalarType::from(3u32);
        let combined = G::multi_exp(&[&thousand, &two, &three], &[&base, &two_g, &five_hundred_g]);
        assert_eq!(combined, base.mul(&G::ScalarType::from(2504u32)));
    }
}