use crate::{
core::{
bounds::FieldBounds,
expressions::{
bit_expr::{BitInputInfo, GenBitExpr, RandomBitId},
circuit::{ArithmeticCircuitId, BaseCircuitId, GeneralCircuitId},
conversion_expr::{EdaBitId, GenConversionExpr},
curve_expr::{self, GenCurveExpr},
expr::Expr,
field_expr::{InputId, InputInfo, *},
other_expr::GenOtherExpr,
InputKind,
},
mxe_input::{MxeFieldInput, MxeInput},
},
utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
number::Number,
used_field::UsedField,
},
};
use rand::Rng;
use std::{marker::PhantomData, rc::Rc};
/// Supplier of random leaf values for the expression generators in this module.
///
/// The associated types fix how each kind of value is represented (scalar-field values,
/// base-field values, bits and curve points). Each method takes the RNG to draw from as a
/// parameter instead of owning one.
pub trait ExprGenHelper {
    /// Representation of scalar-field values.
    type ScalarType;
    /// Representation of base-field values.
    type BaseType;
    /// Representation of single bits.
    type BitType;
    /// Representation of curve points.
    type CurveType;

    // ----- scalar field -----

    /// A general scalar-field value (the `t` leaf of the scalar generators).
    fn scalar<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::ScalarType;
    /// Scalar value used for the `c` slot of the scalar field generator.
    /// NOTE(review): presumably a "condition" value (e.g. 0/1) — confirm with implementors.
    fn scalar_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::ScalarType;
    /// Scalar value used for the `p` slot of the scalar field generator.
    /// NOTE(review): the name suggests a positive value — confirm with implementors.
    fn scalar_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::ScalarType;
    /// Scalar value used for the `e` slot of the scalar conversion generator.
    fn scalar_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::ScalarType;
    /// A concrete [`ScalarField`] constant (not the abstract `ScalarType`).
    fn scalar_int<R: Rng + ?Sized>(&self, rng: &mut R) -> ScalarField;

    // ----- base field -----

    /// A general base-field value (the `t` leaf of the base generators).
    fn base<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType;
    /// Base-field value used for the `c` slot of the base field generator.
    /// NOTE(review): presumably a "condition" value — confirm with implementors.
    fn base_field_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType;
    /// Base-field value used for the `p` slot of the base field generator.
    fn base_field_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType;
    /// Base-field value used for the `e` slot of the base conversion generator.
    fn base_field_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType;
    /// A concrete [`BaseField`] constant (not the abstract `BaseType`).
    fn base_field_int<R: Rng + ?Sized>(&self, rng: &mut R) -> BaseField;

    // ----- bits -----

    /// A single bit value.
    fn bit<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BitType;

    // ----- curve -----

    /// A curve point in the abstract `CurveType` representation.
    fn curve_point<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::CurveType;
    /// A concrete [`CurvePoint`] constant.
    fn curve_val<R: Rng + ?Sized>(&self, rng: &mut R) -> CurvePoint;
}
/// Draws a small random length/count.
///
/// Counts how many consecutive `gen_bool(0.875)` draws succeed before the first failure.
/// The result is geometrically distributed with mean `0.875 / 0.125 = 7`, and
/// `P(result >= k) = 0.875^k`.
fn gen_usize<R: Rng + ?Sized>(rng: &mut R) -> usize {
    std::iter::repeat_with(|| rng.gen_bool(0.875))
        .take_while(|&keep_going| keep_going)
        .count()
}
/// Scalar-field expression generation, backed by an [`ExprGenHelper`] and an RNG.
///
/// Leaf values come from the helper's scalar methods. Structural choices (lengths,
/// variant indices, booleans) are drawn directly from the RNG.
impl<T: Clone, Gen, R: Rng + ?Sized> GenFieldExpr<ScalarField, T> for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<ScalarType = T>,
{
    // ----- leaf values delegated to the helper -----

    fn t(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar(*rng)
    }

    fn p(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar_pos(*rng)
    }

    fn c(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar_cond(*rng)
    }

    fn f(&mut self) -> ScalarField {
        let (helper, rng) = self;
        helper.scalar_int(*rng)
    }

    /// A [`Number`] built from the little-endian bytes of a freshly drawn constant.
    fn number(&mut self) -> Number {
        let bytes = self.f().to_le_bytes();
        Number::from(bytes)
    }

    // ----- collections -----

    fn vec_t(&mut self) -> Vec<T> {
        let count = self.usize();
        let mut items = Vec::with_capacity(count);
        for _ in 0..count {
            items.push(self.t());
        }
        items
    }

    fn vec_t_f(&mut self) -> Vec<(T, ScalarField)> {
        let count = self.usize();
        let mut pairs = Vec::with_capacity(count);
        for _ in 0..count {
            // The value is drawn before the constant within each pair.
            let value = self.t();
            let constant = self.f();
            pairs.push((value, constant));
        }
        pairs
    }

    // ----- bounds / input metadata -----

    /// Bounds from two independent draws, passed in draw order (they are not sorted here).
    fn field_bounds_f(&mut self) -> FieldBounds<ScalarField> {
        let first = self.f();
        let second = self.f();
        FieldBounds::new(first, second)
    }

    /// Input metadata whose range is two random constants, sorted via `sort_pair`.
    fn rc_input_info_f(&mut self) -> Rc<InputInfo<ScalarField>> {
        let (helper, rng) = self;
        let lhs = helper.scalar_int(*rng);
        let rhs = helper.scalar_int(*rng);
        let (min, max) = lhs.sort_pair(rhs);
        let min = min.to_unsigned_number();
        let max = max.to_unsigned_number();
        InputInfo::generate(*rng, &min, &max)
    }

    fn input_id(&mut self) -> InputId {
        0
    }

    // ----- structural choices drawn straight from the RNG -----

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    fn bool(&mut self) -> bool {
        self.1.r#gen::<bool>()
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers -----

    fn arithmetic_circuit_id(&mut self) -> ArithmeticCircuitId {
        ArithmeticCircuitId::Min
    }

    fn random_val_id(&mut self) -> RandomValId {
        RandomValId::new()
    }
}
/// Base-field expression generation, backed by an [`ExprGenHelper`] and an RNG.
///
/// Leaf values come from the helper's base-field methods. Structural choices (lengths,
/// variant indices, booleans) are drawn directly from the RNG.
impl<T: Clone, Gen, R: Rng + ?Sized> GenFieldExpr<BaseField, T> for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<BaseType = T>,
{
    // ----- leaf values delegated to the helper -----

    fn t(&mut self) -> T {
        let (helper, rng) = self;
        helper.base(*rng)
    }

    fn p(&mut self) -> T {
        let (helper, rng) = self;
        helper.base_field_pos(*rng)
    }

    fn c(&mut self) -> T {
        let (helper, rng) = self;
        helper.base_field_cond(*rng)
    }

    fn f(&mut self) -> BaseField {
        let (helper, rng) = self;
        helper.base_field_int(*rng)
    }

    /// A [`Number`] built from the little-endian bytes of a freshly drawn constant.
    fn number(&mut self) -> Number {
        let bytes = self.f().to_le_bytes();
        Number::from(bytes)
    }

    // ----- collections -----

    fn vec_t(&mut self) -> Vec<T> {
        let count = self.usize();
        let mut items = Vec::with_capacity(count);
        for _ in 0..count {
            items.push(self.t());
        }
        items
    }

    fn vec_t_f(&mut self) -> Vec<(T, BaseField)> {
        let count = self.usize();
        let mut pairs = Vec::with_capacity(count);
        for _ in 0..count {
            // The value is drawn before the constant within each pair.
            let value = self.t();
            let constant = self.f();
            pairs.push((value, constant));
        }
        pairs
    }

    // ----- bounds / input metadata -----

    /// Bounds from two independent draws, passed in draw order (they are not sorted here).
    fn field_bounds_f(&mut self) -> FieldBounds<BaseField> {
        let first = self.f();
        let second = self.f();
        FieldBounds::new(first, second)
    }

    /// Input metadata whose range is two random constants, sorted via `sort_pair`.
    fn rc_input_info_f(&mut self) -> Rc<InputInfo<BaseField>> {
        let (helper, rng) = self;
        let lhs = helper.base_field_int(*rng);
        let rhs = helper.base_field_int(*rng);
        let (min, max) = lhs.sort_pair(rhs);
        let min = min.to_unsigned_number();
        let max = max.to_unsigned_number();
        InputInfo::generate(*rng, &min, &max)
    }

    fn input_id(&mut self) -> InputId {
        0
    }

    // ----- structural choices drawn straight from the RNG -----

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    fn bool(&mut self) -> bool {
        self.1.r#gen::<bool>()
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers -----

    fn arithmetic_circuit_id(&mut self) -> ArithmeticCircuitId {
        ArithmeticCircuitId::Min
    }

    fn random_val_id(&mut self) -> RandomValId {
        RandomValId::new()
    }
}
/// Bit expression generation, backed by an [`ExprGenHelper`] and an RNG.
impl<B: Clone, Gen, R: Rng + ?Sized> GenBitExpr<B> for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<BitType = B>,
{
    // ----- leaf values delegated to the helper -----

    fn b(&mut self) -> B {
        let (helper, rng) = self;
        helper.bit(*rng)
    }

    /// Always exactly 1600 bits, unlike the variable-length vectors elsewhere in this module.
    /// NOTE(review): the fixed size presumably matches an operation that needs a 1600-bit
    /// input (e.g. a Keccak-sized state) — confirm before changing.
    fn vec_b(&mut self) -> Vec<B> {
        const BIT_VEC_LEN: usize = 1600;
        let (helper, rng) = self;
        let mut bits = Vec::with_capacity(BIT_VEC_LEN);
        for _ in 0..BIT_VEC_LEN {
            bits.push(helper.bit(*rng));
        }
        bits
    }

    // ----- structural choices drawn straight from the RNG -----

    fn bool(&mut self) -> bool {
        self.1.gen_bool(0.5)
    }

    /// Four booleans, drawn in index order.
    fn bool_4(&mut self) -> [bool; 4] {
        let mut flags = [false; 4];
        for flag in flags.iter_mut() {
            *flag = self.bool();
        }
        flags
    }

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers / metadata -----

    fn input_id(&mut self) -> InputId {
        0
    }

    fn rc_bit_input_info(&mut self) -> Rc<BitInputInfo> {
        Rc::new(BitInputInfo::default())
    }

    fn random_bit_id(&mut self) -> RandomBitId {
        RandomBitId::new()
    }
}
/// Scalar-field conversion expression generation, backed by an [`ExprGenHelper`] and an RNG.
impl<T: Clone, B: Clone, Gen, R: Rng + ?Sized> GenConversionExpr<ScalarField, T, B>
    for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<ScalarType = T, BitType = B>,
{
    // ----- leaf values delegated to the helper -----

    fn t(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar(*rng)
    }

    fn e(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar_eda(*rng)
    }

    fn b(&mut self) -> B {
        let (helper, rng) = self;
        helper.bit(*rng)
    }

    /// A random-length vector of bits; the length is drawn first via `gen_usize`.
    fn vec_b(&mut self) -> Vec<B> {
        let (helper, rng) = self;
        let count = gen_usize(*rng);
        let mut bits = Vec::with_capacity(count);
        for _ in 0..count {
            bits.push(helper.bit(*rng));
        }
        bits
    }

    // ----- structural choices drawn straight from the RNG -----

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    fn bool(&mut self) -> bool {
        self.1.gen_bool(0.5)
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers / markers -----

    fn eda_bit_id(&mut self) -> EdaBitId {
        EdaBitId::new()
    }

    fn phantom_data_f(&mut self) -> PhantomData<ScalarField> {
        PhantomData::<ScalarField>
    }
}
/// Base-field conversion expression generation, backed by an [`ExprGenHelper`] and an RNG.
impl<T: Clone, B: Clone, Gen, R: Rng + ?Sized> GenConversionExpr<BaseField, T, B>
    for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<BaseType = T, BitType = B>,
{
    // ----- leaf values delegated to the helper -----

    fn t(&mut self) -> T {
        let (helper, rng) = self;
        helper.base(*rng)
    }

    fn e(&mut self) -> T {
        let (helper, rng) = self;
        helper.base_field_eda(*rng)
    }

    fn b(&mut self) -> B {
        let (helper, rng) = self;
        helper.bit(*rng)
    }

    /// A random-length vector of bits; the length is drawn first via `gen_usize`.
    fn vec_b(&mut self) -> Vec<B> {
        let (helper, rng) = self;
        let count = gen_usize(*rng);
        let mut bits = Vec::with_capacity(count);
        for _ in 0..count {
            bits.push(helper.bit(*rng));
        }
        bits
    }

    // ----- structural choices drawn straight from the RNG -----

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    fn bool(&mut self) -> bool {
        self.1.gen_bool(0.5)
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers / markers -----

    fn eda_bit_id(&mut self) -> EdaBitId {
        EdaBitId::new()
    }

    fn phantom_data_f(&mut self) -> PhantomData<BaseField> {
        PhantomData::<BaseField>
    }
}
/// Curve expression generation, backed by an [`ExprGenHelper`] and an RNG.
///
/// `C` is the helper's curve-point representation and `S` its scalar representation.
impl<C: Clone, S: Clone, Gen, R: Rng + ?Sized> GenCurveExpr<C, S> for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<CurveType = C, ScalarType = S>,
{
    // ----- leaf values delegated to the helper -----

    fn c(&mut self) -> C {
        let (helper, rng) = self;
        helper.curve_point(*rng)
    }

    fn s(&mut self) -> S {
        let (helper, rng) = self;
        helper.scalar(*rng)
    }

    /// A concrete curve-point constant from the helper's `curve_val`.
    fn curve_point(&mut self) -> CurvePoint {
        let (helper, rng) = self;
        helper.curve_val(*rng)
    }

    // ----- structural choices drawn straight from the RNG -----

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers / metadata -----

    fn input_id(&mut self) -> InputId {
        0
    }

    /// Default curve input metadata, with the kind forced to `InputKind::Secret`.
    fn rc_input_info(&mut self) -> Rc<curve_expr::InputInfo> {
        let info = curve_expr::InputInfo {
            kind: InputKind::Secret,
            ..Default::default()
        };
        Rc::new(info)
    }
}
/// Miscellaneous expression generation, backed by an [`ExprGenHelper`] and an RNG.
///
/// Here `T` is the scalar representation, `B` the *base-field* representation (not bits),
/// and `C` the curve-point representation.
impl<T: Clone, B: Clone, C: Clone, Gen, R: Rng + ?Sized> GenOtherExpr<T, B, C>
    for (&mut Gen, &mut R)
where
    Gen: ExprGenHelper<BaseType = B, CurveType = C, ScalarType = T>,
{
    // ----- leaf values delegated to the helper -----

    fn t(&mut self) -> T {
        let (helper, rng) = self;
        helper.scalar(*rng)
    }

    fn b(&mut self) -> B {
        let (helper, rng) = self;
        helper.base(*rng)
    }

    fn c(&mut self) -> C {
        let (helper, rng) = self;
        helper.curve_point(*rng)
    }

    // ----- collections (length drawn first via `gen_usize`) -----

    fn vec_t(&mut self) -> Vec<T> {
        let (helper, rng) = self;
        let count = gen_usize(*rng);
        let mut scalars = Vec::with_capacity(count);
        for _ in 0..count {
            scalars.push(helper.scalar(*rng));
        }
        scalars
    }

    fn vec_b(&mut self) -> Vec<B> {
        let (helper, rng) = self;
        let count = gen_usize(*rng);
        let mut bases = Vec::with_capacity(count);
        for _ in 0..count {
            bases.push(helper.base(*rng));
        }
        bases
    }

    // ----- structural choices drawn straight from the RNG -----

    fn usize(&mut self) -> usize {
        gen_usize(self.1)
    }

    /// Uniform-ish index in `0..n_variants`. Panics if `n_variants == 0`.
    fn choose_variant(&mut self, n_variants: usize) -> usize {
        self.1.r#gen::<usize>() % n_variants
    }

    // ----- fixed identifiers -----

    fn base_circuit_id(&mut self) -> BaseCircuitId {
        let arith = ArithmeticCircuitId::Min;
        BaseCircuitId::Arith(arith)
    }

    fn general_circuit_id(&mut self) -> GeneralCircuitId {
        GeneralCircuitId::Conversion
    }

    fn mxe_input(&mut self) -> MxeInput {
        let key = MxeFieldInput::Ed25519VerifyingKey(0);
        MxeInput::Base(key)
    }
}
/// Produces random top-level [`Expr`] values.
///
/// The associated types are the representations used for scalar-field values, bits,
/// base-field values and curve points inside the generated expression. All four must be
/// `Clone`.
pub trait ExprGenerator {
    /// Representation of scalar-field values.
    type ScalarType: Clone;
    /// Representation of single bits.
    type BitType: Clone;
    /// Representation of base-field values.
    type BaseType: Clone;
    /// Representation of curve points.
    type CurveType: Clone;

    /// Generates one random expression, drawing all randomness from `rng`.
    fn expr<R>(
        &mut self,
        rng: &mut R,
    ) -> Expr<Self::ScalarType, Self::BitType, Self::BaseType, Self::CurveType>
    where
        R: Rng + ?Sized;
}
/// Every [`ExprGenHelper`] whose value representations are `Clone` is an [`ExprGenerator`].
///
/// The helper is paired with the RNG as `(&mut Gen, &mut R)`, which implements the
/// per-kind generator traits in this module.
impl<Scalar: Clone, Bit: Clone, Base: Clone, Curve: Clone, Gen> ExprGenerator for Gen
where
    Gen: ExprGenHelper<ScalarType = Scalar, BitType = Bit, BaseType = Base, CurveType = Curve>,
{
    type ScalarType = Scalar;
    type BitType = Bit;
    type BaseType = Base;
    type CurveType = Curve;

    /// Picks one of the seven top-level [`Expr`] kinds and generates it.
    ///
    /// The kind is `rng.gen::<usize>() % 7`, mapped as:
    /// 0 scalar, 1 bit, 2 scalar conversion, 3 base, 4 base conversion, 5 curve, 6 other.
    fn expr<R: Rng + ?Sized>(
        &mut self,
        rng: &mut R,
    ) -> Expr<Self::ScalarType, Self::BitType, Self::BaseType, Self::CurveType> {
        // Must equal the number of non-wildcard arms in the match below.
        const VARIANT_COUNT: usize = 7;

        let variant = rng.r#gen::<usize>() % VARIANT_COUNT;
        match variant {
            0 => Expr::Scalar(GenFieldExpr::r#gen(&mut (self, rng))),
            1 => Expr::Bit(GenBitExpr::r#gen(&mut (self, rng))),
            2 => Expr::ScalarConversion(GenConversionExpr::r#gen(&mut (self, rng))),
            3 => Expr::Base(GenFieldExpr::r#gen(&mut (self, rng))),
            4 => Expr::BaseConversion(GenConversionExpr::r#gen(&mut (self, rng))),
            5 => Expr::Curve(GenCurveExpr::r#gen(&mut (self, rng))),
            6 => Expr::Other(GenOtherExpr::r#gen(&mut (self, rng))),
            // `x % VARIANT_COUNT` is always < VARIANT_COUNT, so the first impossible value is
            // VARIANT_COUNT itself (the old message wrongly said ">=6").
            _ => unreachable!("Modulo cannot reach >={}.", VARIANT_COUNT),
        }
    }
}