use commonware_parallel::Strategy as ParStrategy;
use core::{
fmt::Debug,
iter,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use rand_core::CryptoRngCore;
/// Yields all 64 bits of `x`, least significant bit first.
fn yield_bits_le(x: u64) -> impl Iterator<Item = bool> {
    (0..u64::BITS).map(move |i| x & (1u64 << i) != 0)
}
/// Yields the bits of `x` least significant first, stopping after the highest
/// set bit — an empty iterator when `x == 0`.
fn yield_bits_le_until_zeroes(x: u64) -> impl Iterator<Item = bool> {
    let significant = u64::BITS - x.leading_zeros();
    (0..significant).map(move |i| x & (1u64 << i) != 0)
}
/// Yields the bits of a little-endian multi-limb integer, least significant
/// first: all 64 bits of every limb except the last, then the last limb's
/// bits up to its highest set bit (so trailing zero bits are not yielded).
/// An empty slice behaves like the single limb `0`.
fn yield_bits_le_arr(xs: &[u64]) -> impl Iterator<Item = bool> + use<'_> {
    let (&top, rest) = xs.split_last().unwrap_or((&0, &[]));
    let full_limbs = rest
        .iter()
        .flat_map(|&limb| (0..64).map(move |i| (limb >> i) & 1 != 0));
    let top_bits = (0..64 - top.leading_zeros()).map(move |i| (top >> i) & 1 != 0);
    full_limbs.chain(top_bits)
}
/// Generic square-and-multiply (double-and-add) over a monoid.
///
/// Folds over the little-endian bits of `bits_le`: the running value starts
/// at `x` and is repeatedly combined with itself via `self_op`
/// (doubling/squaring); `op` merges the running value into the accumulator
/// whenever the corresponding bit is set. The bit iterator skips trailing
/// zero bits of the final limb, so work is proportional to the effective bit
/// length of the exponent. With no set bits, returns `zero` unchanged.
fn monoid_exp<T: Clone>(
    zero: T,
    op: impl Fn(&mut T, &T),
    self_op: impl Fn(&mut T),
    x: &T,
    bits_le: &[u64],
) -> T {
    let (result, _) = yield_bits_le_arr(bits_le).fold(
        (zero, x.clone()),
        |(mut acc, mut running), bit| {
            if bit {
                op(&mut acc, &running);
            }
            self_op(&mut running);
            (acc, running)
        },
    );
    result
}
/// Returns the infinite sequence `shift, shift * base, shift * base^2, ...`
/// (the i-th item is `shift * base^i`).
pub fn powers<R: Ring>(shift: R, base: &R) -> impl Iterator<Item = R> + '_ {
    let mut pending = Some(shift);
    iter::from_fn(move || {
        let current = pending.take()?;
        pending = Some(current.clone() * base);
        Some(current)
    })
}
/// Marker trait bundling the baseline requirements shared by every algebraic
/// object in this module: cloneable, debuggable, comparable, and thread-safe.
pub trait Object: Clone + Debug + PartialEq + Eq + Send + Sync {}
/// An additively-written abelian group: `+`, `-`, negation, and an identity.
/// (Commutativity and associativity are checked by the fuzz suite, not by
/// the compiler.)
pub trait Additive:
Object
+ for<'a> AddAssign<&'a Self>
+ for<'a> Add<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ Neg<Output = Self>
{
/// The additive identity: `a + zero() == a`.
fn zero() -> Self;
/// Doubles in place: `self = self + self`.
fn double(&mut self) {
*self += &self.clone();
}
/// Multiplies `self` by a nonnegative integer given as little-endian
/// 64-bit limbs, via double-and-add. An empty slice means zero.
fn scale(&self, bits_le: &[u64]) -> Self {
monoid_exp(Self::zero(), |a, b| *a += b, |a| a.double(), self, bits_le)
}
}
/// A multiplicative structure; the fuzz suite additionally checks that `*`
/// is commutative and associative.
pub trait Multiplicative:
Object + for<'a> MulAssign<&'a Self> + for<'a> Mul<&'a Self, Output = Self>
{
/// Squares in place: `self = self * self`.
fn square(&mut self) {
*self *= &self.clone();
}
}
/// An additive group that can be scaled by elements of `R` — a module over
/// `R` in the algebraic sense.
pub trait Space<R>:
Additive + for<'a> MulAssign<&'a R> + for<'a> Mul<&'a R, Output = Self>
{
/// Multi-scalar multiplication: the sum of `points[i] * scalars[i]`.
///
/// The default ignores `_strategy` and delegates to [`msm_naive`];
/// implementations may override with a parallel or windowed algorithm.
fn msm(points: &[Self], scalars: &[R], _strategy: &impl ParStrategy) -> Self {
msm_naive(points, scalars)
}
}
/// Reference multi-scalar multiplication: accumulates `points[i] * scalars[i]`
/// pairwise. If the slices have different lengths, the surplus elements of the
/// longer one are ignored (`zip` truncates).
pub fn msm_naive<R, K: Space<R>>(points: &[K], scalars: &[R]) -> K {
    scalars
        .iter()
        .zip(points)
        .fold(K::zero(), |mut acc, (scalar, point)| {
            acc += &(point.clone() * scalar);
            acc
        })
}
// Anything that is both additive and multiplicative is a space over itself:
// scaling is just multiplication.
impl<R: Additive + Multiplicative> Space<R> for R {}
/// A ring: an additive group with multiplication and a multiplicative
/// identity.
pub trait Ring: Additive + Multiplicative {
/// The multiplicative identity: `a * one() == a`.
fn one() -> Self;
/// Raises `self` to a nonnegative integer power given as little-endian
/// 64-bit limbs, via square-and-multiply. An empty slice means exponent
/// zero, so the result is `one()`.
fn exp(&self, bits_le: &[u64]) -> Self {
monoid_exp(Self::one(), |a, b| *a *= b, |a| a.square(), self, bits_le)
}
}
/// A field: a ring where every nonzero element is invertible.
pub trait Field: Ring {
/// The multiplicative inverse. Per the fuzz suite's contract
/// (`check_inv`), implementations map zero to zero rather than panic.
fn inv(&self) -> Self;
}
/// A field with power-of-two roots of unity (for radix-2 NTTs) and a coset
/// shift.
pub trait FieldNTT: Field {
/// Largest `lg` for which [`Self::root_of_unity`] returns `Some`.
const MAX_LG_ROOT_ORDER: u8;
/// Returns a root of unity of order `2^lg`, or `None` when
/// `lg > MAX_LG_ROOT_ORDER`.
fn root_of_unity(lg: u8) -> Option<Self>;
/// Multiplicative shift used to move an evaluation domain onto a coset.
fn coset_shift() -> Self;
/// Inverse of [`Self::coset_shift`]; the default recomputes it via `inv`,
/// so implementors with a known constant may override for speed.
fn coset_shift_inv() -> Self {
Self::coset_shift().inv()
}
/// Halves `self` by multiplying with `(1 + 1).inv()`. Requires 2 to be
/// invertible, i.e. odd characteristic — which holds for any field with
/// nontrivial 2-power roots of unity.
fn div_2(&self) -> Self {
(Self::one() + &Self::one()).inv() * self
}
}
/// A group used for cryptography: a space over its own scalar field with a
/// distinguished generator.
pub trait CryptoGroup: Space<Self::Scalar> {
/// The scalar field acting on the group.
type Scalar: Field;
/// A fixed generator element.
fn generator() -> Self;
}
/// A group supporting hashing of byte strings to elements.
pub trait HashToGroup: CryptoGroup {
/// Deterministically hashes `message` to a group element, domain-separated
/// by `domain_separator` (the fuzz suite checks that equal inputs yield
/// equal outputs).
fn hash_to_group(domain_separator: &[u8], message: &[u8]) -> Self;
/// Samples a group element by hashing 32 bytes drawn from `rng`.
/// NOTE(review): uses an empty domain separator — confirm this cannot
/// collide with other protocol uses of `hash_to_group`.
fn rand_to_group(mut rng: impl CryptoRngCore) -> Self {
let mut bytes = [0u8; 32];
rng.fill_bytes(&mut bytes);
Self::hash_to_group(&[], &bytes)
}
}
/// Types that can be sampled from a cryptographically secure RNG
/// (implementations define the distribution; presumably uniform).
pub trait Random {
/// Draws a fresh value from `rng`.
fn random(rng: impl CryptoRngCore) -> Self;
}
#[cfg(any(test, feature = "arbitrary"))]
pub mod test_suites {
use super::*;
use arbitrary::{Arbitrary, Unstructured};
/// `a += b` must agree with `a + b`.
fn check_add_assign<T: Additive>(a: T, b: T) {
    let sum = a.clone() + &b;
    let mut a_mut = a;
    a_mut += &b;
    assert_eq!(a_mut, sum, "+= does not match +");
}
/// Addition must commute: `a + b == b + a`.
fn check_add_commutes<T: Additive>(a: T, b: T) {
    let ab = a.clone() + &b;
    let ba = b + &a;
    assert_eq!(ab, ba, "+ not commutative");
}
/// Addition must associate: `(a + b) + c == a + (b + c)`.
fn check_add_associates<T: Additive>(a: T, b: T, c: T) {
    let left_grouped = (a.clone() + &b) + &c;
    let right_grouped = a + &(b + &c);
    assert_eq!(left_grouped, right_grouped, "+ not associative");
}
/// Zero must be the additive identity.
fn check_add_zero<T: Additive>(a: T) {
    let shifted = T::zero() + &a;
    assert_eq!(shifted, a, "a + 0 != a");
}
/// An element plus its negation must be zero.
fn check_add_neg_self<T: Additive>(a: T) {
    let negated = -a.clone();
    let total = a + &negated;
    assert_eq!(T::zero(), total, "a - a != 0");
}
/// Subtraction must match addition of the negation.
fn check_sub_vs_add_neg<T: Additive>(a: T, b: T) {
    let direct = a.clone() - &b;
    let via_neg = a + &-b;
    assert_eq!(direct, via_neg, "a - b != a + (-b)");
}
/// `a -= b` must agree with `a - b`.
fn check_sub_assign<T: Additive>(a: T, b: T) {
    let difference = a.clone() - &b;
    let mut a_mut = a;
    a_mut -= &b;
    assert_eq!(a_mut, difference, "-= different from -");
}
/// Draws three elements and exercises every additive-group law.
pub fn fuzz_additive<T: Additive + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    // Draw order (a, then b, then c) is preserved for reproducibility.
    let (a, b, c): (T, T, T) = (u.arbitrary()?, u.arbitrary()?, u.arbitrary()?);
    check_add_assign(a.clone(), b.clone());
    check_add_commutes(a.clone(), b.clone());
    check_add_associates(a.clone(), b.clone(), c);
    check_add_zero(a.clone());
    check_add_neg_self(a.clone());
    check_sub_vs_add_neg(a.clone(), b.clone());
    check_sub_assign(a, b);
    Ok(())
}
/// `a *= b` must agree with `a * b`.
fn check_mul_assign<T: Multiplicative>(a: T, b: T) {
    let product = a.clone() * &b;
    let mut a_mut = a;
    a_mut *= &b;
    assert_eq!(a_mut, product, "*= different from *");
}
/// Multiplication must commute: `a * b == b * a`.
fn check_mul_commutes<T: Multiplicative>(a: T, b: T) {
    let ab = a.clone() * &b;
    let ba = b * &a;
    assert_eq!(ab, ba, "* not commutative");
}
/// Multiplication must associate: `(a * b) * c == a * (b * c)`.
fn check_mul_associative<T: Multiplicative>(a: T, b: T, c: T) {
    let left_grouped = (a.clone() * &b) * &c;
    let right_grouped = a * &(b * &c);
    assert_eq!(left_grouped, right_grouped, "* not associative");
}
/// Draws three elements and exercises the multiplicative laws.
pub fn fuzz_multiplicative<T: Multiplicative + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    // Draw order (a, then b, then c) is preserved for reproducibility.
    let (a, b, c): (T, T, T) = (u.arbitrary()?, u.arbitrary()?, u.arbitrary()?);
    check_mul_assign(a.clone(), b.clone());
    check_mul_commutes(a.clone(), b.clone());
    check_mul_associative(a, b, c);
    Ok(())
}
/// One must be the multiplicative identity.
fn check_mul_one<T: Ring>(a: T) {
    let scaled = T::one() * &a;
    assert_eq!(scaled, a, "a * 1 != a");
}
/// Multiplication must distribute over addition.
fn check_mul_distributes<T: Ring>(a: T, b: T, c: T) {
    let factored = (a.clone() + &b) * &c;
    let expanded = a * &c + &(b * &c);
    assert_eq!(factored, expanded, "(a + b) * c != a * c + b * c");
}
/// Exercises the full ring contract: additive and multiplicative laws plus
/// identity and distributivity.
pub fn fuzz_ring<T: Ring + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    fuzz_additive::<T>(u)?;
    fuzz_multiplicative::<T>(u)?;
    // Draw order (a, then b, then c) is preserved for reproducibility.
    let (a, b, c): (T, T, T) = (u.arbitrary()?, u.arbitrary()?, u.arbitrary()?);
    check_mul_one(a.clone());
    check_mul_distributes(a, b, c);
    Ok(())
}
/// `inv` must return a true multiplicative inverse for nonzero inputs and
/// zero for zero.
fn check_inv<T: Field>(a: T) {
    if a != T::zero() {
        assert_eq!(a.inv() * &a, T::one(), "a * a.inv() != 1");
    } else {
        assert_eq!(T::zero(), a.inv(), "0.inv() != 0");
    }
}
/// Exercises the field contract: all ring laws plus inversion.
pub fn fuzz_field<T: Field + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    fuzz_ring::<T>(u)?;
    check_inv(u.arbitrary::<T>()?);
    Ok(())
}
/// Scaling must distribute over addition of space elements.
fn check_scale_distributes<R, K: Space<R>>(a: K, b: K, x: R) {
    let scaled_sum = (a.clone() + &b) * &x;
    let sum_of_scaled = a * &x + &(b * &x);
    assert_eq!(scaled_sum, sum_of_scaled);
}
/// `a *= r` must agree with `a * r`.
fn check_scale_assign<R, K: Space<R>>(a: K, b: R) {
    let product = a.clone() * &b;
    let mut a_mut = a;
    a_mut *= &b;
    assert_eq!(a_mut, product);
}
/// The (possibly overridden) `msm` must match the naive reference.
fn check_msm_eq_naive<R, K: Space<R>>(points: &[K], scalars: &[R]) {
    use commonware_parallel::Sequential;
    let expected = msm_naive(points, scalars);
    let actual = K::msm(points, scalars, &Sequential);
    assert_eq!(expected, actual);
}
/// Exercises the module (space) laws: distributivity, `*=` consistency, and
/// msm-vs-naive agreement on up to 16 point/scalar pairs.
pub fn fuzz_space<R: Debug + for<'a> Arbitrary<'a>, K: Space<R> + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    // Draw order is preserved exactly for reproducibility:
    // a, b, x, c, len, points[0..len], scalars[0..len].
    let a: K = u.arbitrary()?;
    let b: K = u.arbitrary()?;
    let x: R = u.arbitrary()?;
    check_scale_distributes(a.clone(), b, x);
    let c: R = u.arbitrary()?;
    check_scale_assign(a, c);
    let len: usize = u.int_in_range(0..=16)?;
    let mut points: Vec<K> = Vec::with_capacity(len);
    for _ in 0..len {
        points.push(u.arbitrary()?);
    }
    let mut scalars: Vec<R> = Vec::with_capacity(len);
    for _ in 0..len {
        scalars.push(u.arbitrary()?);
    }
    check_msm_eq_naive(&points, &scalars);
    Ok(())
}
/// Scaling twice must match scaling by the product of the scalars.
fn check_scale_compat<R: Multiplicative, K: Space<R>>(a: K, b: R, c: R) {
    let stepwise = (a.clone() * &b) * &c;
    let combined = a * &(b * &c);
    assert_eq!(stepwise, combined);
}
/// Exercises the space laws plus compatibility of scaling with scalar
/// multiplication.
pub fn fuzz_space_multiplicative<
    R: Multiplicative + for<'a> Arbitrary<'a>,
    K: Space<R> + for<'a> Arbitrary<'a>,
>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    fuzz_space::<R, K>(u)?;
    // Draw order (a, then b, then c) is preserved for reproducibility.
    let a: K = u.arbitrary()?;
    let (b, c): (R, R) = (u.arbitrary()?, u.arbitrary()?);
    check_scale_compat(a, b, c);
    Ok(())
}
/// Scaling by one must be the identity.
fn check_scale_one<R: Ring, K: Space<R>>(a: K) {
    let scaled = a.clone() * &R::one();
    assert_eq!(a, scaled);
}
/// Scaling by zero must annihilate any element.
fn check_scale_zero<R: Ring, K: Space<R>>(a: K) {
    let annihilated = a * &R::zero();
    assert_eq!(K::zero(), annihilated);
}
/// Exercises the space laws over a ring of scalars, including the unit and
/// zero scalars.
pub fn fuzz_space_ring<R: Ring + for<'a> Arbitrary<'a>, K: Space<R> + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    fuzz_space_multiplicative::<R, K>(u)?;
    let sample: K = u.arbitrary()?;
    check_scale_one::<R, K>(sample.clone());
    check_scale_zero::<R, K>(sample);
    Ok(())
}
/// Hashing must be deterministic (equal inputs give equal outputs) and, for
/// these tiny 4-byte inputs, collision-free (distinct inputs give distinct
/// outputs).
fn check_hash_to_group<G: HashToGroup>(data: [[u8; 4]; 4]) {
    let [dst0, m0, dst1, m1] = &data;
    let inputs_equal = (dst0, m0) == (dst1, m1);
    let outputs_equal = G::hash_to_group(dst0, m0) == G::hash_to_group(dst1, m1);
    assert_eq!(inputs_equal, outputs_equal);
}
/// Draws four 4-byte strings and checks the hash-to-group contract on them.
pub fn fuzz_hash_to_group<G: HashToGroup>(u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
    check_hash_to_group::<G>(u.arbitrary()?);
    Ok(())
}
/// Builds `2^lg` as little-endian u64 limbs: `lg / 64` zero limbs followed by
/// a top limb with bit `lg % 64` set.
fn pow2_limbs(lg: u8) -> Vec<u64> {
    let mut limbs = vec![0u64; usize::from(lg / 64)];
    limbs.push(1u64 << (lg % 64));
    limbs
}
/// Checks the order contract of `root_of_unity(lg)`:
/// - `None` for `lg > MAX_LG_ROOT_ORDER`;
/// - otherwise a root with `root^(2^lg) == 1` and, for `lg > 0`,
///   `root^(2^(lg-1)) != 1`, i.e. multiplicative order exactly `2^lg`.
fn check_root_of_unity_order<T: FieldNTT>(lg: u8) {
    if lg > T::MAX_LG_ROOT_ORDER {
        assert!(
            T::root_of_unity(lg).is_none(),
            "root_of_unity should be None for lg > MAX"
        );
        return;
    }
    let root = T::root_of_unity(lg).expect("root_of_unity should be Some for lg <= MAX");
    assert_eq!(
        root.exp(&pow2_limbs(lg)),
        T::one(),
        "root^(2^lg) should equal 1"
    );
    if lg > 0 {
        // The order of `root` divides 2^lg, so it equals 2^lg exactly iff the
        // half-order power differs from one. The previous exponent, 2^lg - 1,
        // gave root^(2^lg - 1) = root^(-1), which only ruled out root == 1:
        // a root of any smaller power-of-two order still passed that check.
        assert_ne!(
            root.exp(&pow2_limbs(lg - 1)),
            T::one(),
            "root^(2^(lg-1)) should not equal 1"
        );
    }
}
/// `div_2` must invert doubling: `div_2(a) * 2 == a`.
fn check_div_2<T: FieldNTT>(a: T) {
    let doubled_back = a.div_2() * &(T::one() + &T::one());
    assert_eq!(doubled_back, a, "div_2(a) * 2 should equal a");
}
/// `coset_shift_inv` must actually invert `coset_shift`.
fn check_coset_shift_inv<T: FieldNTT>() {
    let product = T::coset_shift() * &T::coset_shift_inv();
    assert_eq!(
        product,
        T::one(),
        "coset_shift * coset_shift_inv should equal 1"
    );
}
/// Exercises the `FieldNTT` contract: all field laws, then one NTT-specific
/// property per run (coset-shift inverse, root-of-unity order, or div_2;
/// div_2 gets most of the weight as the only input-dependent check).
pub fn fuzz_field_ntt<T: FieldNTT + for<'a> Arbitrary<'a>>(
    u: &mut Unstructured<'_>,
) -> arbitrary::Result<()> {
    fuzz_field::<T>(u)?;
    match u.int_in_range(0u8..=9)? {
        0 => {
            check_coset_shift_inv::<T>();
        }
        1 => {
            // `saturating_add` avoids u8 overflow (debug panic / release
            // wrap-around to an empty probe range) when an implementation
            // sets MAX_LG_ROOT_ORDER to u8::MAX; in that case we simply
            // never probe the out-of-range `None` branch.
            let lg = u.int_in_range(0..=T::MAX_LG_ROOT_ORDER.saturating_add(1))?;
            check_root_of_unity_order::<T>(lg);
        }
        _ => {
            check_div_2(T::arbitrary(u)?);
        }
    }
    Ok(())
}
}
commonware_macros::stability_scope!(ALPHA {
// Fuzzing of the generic algebra helpers (`exp`, `scale`, `powers`, `msm`),
// instantiated with the Goldilocks field `F`.
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
use super::*;
use crate::fields::goldilocks::F;
use arbitrary::{Arbitrary, Unstructured};
use commonware_parallel::Sequential;
// One variant per fuzzed property; the fuzzer picks the variant and its
// inputs via `Arbitrary`.
#[derive(Debug, Arbitrary)]
pub enum Plan {
ExpOne(F),
ExpZero(F),
Exp(F, u32, u32),
PowersMatchesExp(F, F, u16),
ScaleOne(F),
ScaleZero(F),
Scale(F, u32, u32),
Msm2([F; 2], [F; 2]),
}
impl Plan {
/// Runs the selected property check. `_u` (the remaining fuzz entropy)
/// is currently unused but supplied by the minifuzz driver below.
pub fn run(self, _u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
match self {
// x^1 == x
Self::ExpOne(x) => {
assert_eq!(x.exp(&[1]), x);
}
// x^0 == 1 (an empty limb slice encodes exponent zero)
Self::ExpZero(x) => {
assert_eq!(x.exp(&[]), F::one());
}
// x^(a+b) == x^a * x^b; u32 inputs keep `a + b` from overflowing u64
Self::Exp(x, a, b) => {
let a = u64::from(a);
let b = u64::from(b);
assert_eq!(x.exp(&[a + b]), x.exp(&[a]) * x.exp(&[b]));
}
// The item at `index` of powers(shift, base) is shift * base^index
Self::PowersMatchesExp(shift, base, index) => {
let pow_i = powers(shift, &base)
.take(usize::from(index) + 1)
.last()
.expect("len=index+1 guarantees at least one item");
assert_eq!(pow_i, shift * base.exp(&[u64::from(index)]));
}
// scaling by 1 is the identity
Self::ScaleOne(x) => {
assert_eq!(x.scale(&[1]), x);
}
// scaling by 0 (empty limb slice) yields zero
Self::ScaleZero(x) => {
assert_eq!(x.scale(&[]), F::zero());
}
// scaling is additive in the integer: (a+b)·x == a·x + b·x
Self::Scale(x, a, b) => {
let a = u64::from(a);
let b = u64::from(b);
assert_eq!(x.scale(&[a + b]), x.scale(&[a]) + x.scale(&[b]));
}
// a 2-element MSM equals the expanded dot product
Self::Msm2(a, b) => {
assert_eq!(F::msm(&a, &b, &Sequential), a[0] * b[0] + a[1] * b[1]);
}
}
Ok(())
}
}
#[test]
fn test_fuzz() {
commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
}
}
});