use crate::{
core::{
expressions::expr::EvalFailure,
global_value::{
curve_value::CurveValue,
global_expr_store::with_local_expr_store_as_global,
value::FieldValue,
},
ir_builder::IRBuilder,
},
utils::{curve_point::CurvePoint, field::ScalarField},
};
use std::fmt::Debug;
/// A circuit that maps curve points and scalar-field elements to output curve points.
///
/// Implementors describe the same computation twice: once on concrete values
/// ([`CurveCircuit::eval`]) and once on id-backed `CurveValue` / `FieldValue` handles
/// ([`CurveCircuit::run`]). [`CurveCircuit::run_usize`] adapts `run` to raw ids.
pub trait CurveCircuit: Debug {
    /// Evaluates the circuit directly on concrete curve points and scalars.
    ///
    /// # Errors
    /// Returns an [`EvalFailure`] when the evaluation cannot be carried out.
    // NOTE(review): dead_code is allowed here; confirm which builds leave this uncalled.
    #[allow(dead_code)]
    fn eval(
        &self,
        curve_points: Vec<CurvePoint>,
        scalars: Vec<ScalarField>,
    ) -> Result<Vec<CurvePoint>, EvalFailure>;

    /// Runs the circuit on `CurveValue` / `FieldValue` handles and returns the output handles.
    #[allow(dead_code)]
    fn run(
        &self,
        curve_vals: Vec<CurveValue>,
        scalar_vals: Vec<FieldValue<ScalarField>>,
    ) -> Vec<CurveValue>;

    /// Id-based adapter around [`CurveCircuit::run`].
    ///
    /// Each id in `curve_vals` becomes a `CurveValue` and each id in `scalar_vals` becomes a
    /// `FieldValue`. `run` is then invoked through `with_local_expr_store_as_global` with
    /// `expr_store` (presumably so the values `run` creates land in `expr_store` — confirm).
    /// Returns the ids of the output curve values, in output order.
    #[allow(dead_code)]
    fn run_usize(
        &self,
        curve_vals: &[usize],
        scalar_vals: &[usize],
        expr_store: &mut IRBuilder,
    ) -> Vec<usize> {
        let build_outputs = || {
            // Curve handles are created before scalar handles, mirroring argument order.
            let curves: Vec<CurveValue> = curve_vals
                .iter()
                .map(|&curve_id| CurveValue::new(curve_id))
                .collect();
            let scalars: Vec<FieldValue<ScalarField>> = scalar_vals
                .iter()
                .map(|&scalar_id| FieldValue::from_id(scalar_id))
                .collect();
            self.run(curves, scalars)
        };
        let outputs = with_local_expr_store_as_global(build_outputs, expr_store);
        outputs.iter().map(CurveValue::get_id).collect()
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            bounds::FieldBounds,
            expressions::{
                curve_expr::{self, CurveExpr},
                domain::Domain,
                expr::EvalValue,
                field_expr::FieldExpr,
                InputKind,
            },
            ir_builder::{ExprStore, IRBuilder},
        },
        utils::used_field::UsedField,
    };
    use ff::PrimeField;
    use rand::Rng;
    use rustc_hash::FxHashMap;
    use std::rc::Rc;

    /// Randomized differential-testing harness for [`CurveCircuit`] implementations.
    ///
    /// [`TestedCurveCircuit::test`] checks that evaluating the IR built by
    /// [`CurveCircuit::run_usize`] yields the same outputs as [`CurveCircuit::eval`] on the
    /// same random inputs.
    pub trait TestedCurveCircuit: CurveCircuit + Clone + 'static {
        /// Generates a random circuit description (an instance of `Self`) to test.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// Chooses how many curve-point inputs a single run of `self` receives.
        fn gen_n_points<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

        /// Chooses how many scalar inputs a single run receives, given the `n_points`
        /// curve-point inputs already chosen for that run.
        fn gen_n_scalars<R: Rng + ?Sized>(&self, rng: &mut R, n_points: usize) -> usize;

        /// Hook for implementation-specific assertions, called after a run's IR outputs
        /// matched `eval`. The default implementation does nothing.
        // The default body ignores its arguments; they exist for overriding implementations.
        #[allow(unused_variables)]
        fn extra_checks(
            &self,
            curve_inputs: Vec<CurvePoint>,
            scalar_inputs: Vec<ScalarField>,
            curve_outputs: Vec<CurvePoint>,
        ) {
        }

        /// Draws the bounds for one scalar input.
        ///
        /// With probability 1/8 this returns [`FieldBounds::All`]. Otherwise it picks a
        /// bit-width `size < ScalarField::NUM_BITS`. It then delegates to
        /// `FieldBounds::gen_bounds` within:
        /// - `-2^size ..= 2^size` for signed bounds (probability 1/2), or
        /// - `0 ..= 2^size` for unsigned bounds.
        ///
        /// Inclusivity is as defined by `FieldBounds::new`.
        fn gen_input_bounds<R: Rng + ?Sized>(rng: &mut R) -> FieldBounds<ScalarField> {
            if rng.gen_bool(0.125) {
                return FieldBounds::All;
            }
            let signed = rng.gen_bool(0.5);
            let size = (rng.next_u32() % ScalarField::NUM_BITS) as usize;
            let two_power_size = ScalarField::power_of_two(size);
            let bounds_bounds = if signed {
                FieldBounds::new(-two_power_size, two_power_size)
            } else {
                FieldBounds::new(ScalarField::from(0), two_power_size)
            };
            FieldBounds::gen_bounds(rng, bounds_bounds)
        }

        /// Differential test over `n_desc` random descriptions, with `n_runs_per_desc` runs each.
        ///
        /// Each run:
        /// 1. Draws random curve-point and scalar inputs. Each input independently becomes a
        ///    constant with probability 1/8; otherwise it becomes a secret input.
        /// 2. Computes the reference result with [`CurveCircuit::eval`].
        /// 3. Builds the IR with [`CurveCircuit::run_usize`] and evaluates it.
        /// 4. Asserts that the two results agree, then calls
        ///    [`TestedCurveCircuit::extra_checks`].
        ///
        /// Runs on which `eval` returns an error are skipped.
        ///
        /// # Panics
        /// Panics if evaluating the built IR fails, or if its outputs differ from `eval`'s.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            // Draws `n_inputs` random curve points and pushes one curve expression per point
            // into `expr_store`: a constant `CurveExpr::Val` with probability 1/8, otherwise a
            // secret `CurveExpr::Input` numbered `start_input_id + i`. Every point (including
            // those pushed as constants) is recorded in `inputs` under `start_input_id + i`.
            // Returns the pushed expression ids and the drawn points.
            fn gen_input_points_and_expr<R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<ScalarField>,
            ) -> (Vec<usize>, Vec<CurvePoint>) {
                let input_points = (0..n_inputs)
                    .map(|_| R::gen(rng))
                    .collect::<Vec<CurvePoint>>();
                let input_ids = input_points
                    .iter()
                    .enumerate()
                    .map(|(i, point)| {
                        if rng.gen_bool(0.125) {
                            expr_store.push_curve(CurveExpr::Val(*point))
                        } else {
                            expr_store.push_curve(CurveExpr::Input(
                                start_input_id + i,
                                Rc::new(curve_expr::InputInfo::from(InputKind::Secret)),
                            ))
                        }
                    })
                    .collect::<Vec<usize>>();
                inputs.extend(
                    input_points
                        .iter()
                        .enumerate()
                        .map(|(i, point)| (start_input_id + i, EvalValue::Curve(*point))),
                );
                (input_ids, input_points)
            }

            // Draws `n_inputs` scalars, each sampled from bounds produced by `gen_bounds`, and
            // pushes one field expression per scalar into `expr_store`: a constant
            // `FieldExpr::Val` with probability 1/8, otherwise a secret `FieldExpr::Input`
            // numbered `start_input_id + i` carrying its bounds as input info. Every scalar is
            // recorded in `inputs` under `start_input_id + i`.
            // Returns the pushed expression ids and the drawn scalars.
            fn gen_input_scalars_and_expr<R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<ScalarField>,
                mut gen_bounds: impl FnMut(&mut R) -> FieldBounds<ScalarField>,
            ) -> (Vec<usize>, Vec<ScalarField>) {
                // Bounds, samples and constant-or-input choices are drawn in three separate
                // passes, so the random stream is consumed in a fixed order.
                let bounds = (0..n_inputs)
                    .map(|_| gen_bounds(rng))
                    .collect::<Vec<FieldBounds<ScalarField>>>();
                let input_scalars = bounds
                    .iter()
                    .map(|bound| bound.sample(rng))
                    .collect::<Vec<ScalarField>>();
                let input_ids = bounds
                    .iter()
                    .zip(&input_scalars)
                    .enumerate()
                    .map(|(i, (bound, &val))| {
                        if rng.gen_bool(0.125) {
                            expr_store.push_field(FieldExpr::Val(val))
                        } else {
                            expr_store.push_field(FieldExpr::Input(
                                start_input_id + i,
                                bound.as_input_info(InputKind::Secret),
                            ))
                        }
                    })
                    .collect::<Vec<usize>>();
                inputs.extend(
                    input_scalars
                        .iter()
                        .enumerate()
                        .map(|(i, val)| (start_input_id + i, EvalValue::Scalar(*val))),
                );
                (input_ids, input_scalars)
            }

            let rng = &mut crate::utils::test_rng::get();
            for _ in 0..n_desc {
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    let mut expr_store = IRBuilder::new(false);
                    let mut input_values: FxHashMap<usize, _> = FxHashMap::default();
                    let n_points = desc.gen_n_points(rng);
                    let n_scalars = desc.gen_n_scalars(rng, n_points);
                    // Curve inputs are numbered from 0; scalar inputs are numbered after them.
                    let (input_point_ids, input_points) = gen_input_points_and_expr(
                        rng,
                        n_points,
                        0,
                        &mut input_values,
                        &mut expr_store,
                    );
                    let (input_scalar_ids, input_scalars) = gen_input_scalars_and_expr(
                        rng,
                        n_scalars,
                        input_points.len(),
                        &mut input_values,
                        &mut expr_store,
                        Self::gen_input_bounds,
                    );
                    // Reference result; runs the reference cannot evaluate are skipped.
                    let Ok(ctrl_eval_result) =
                        desc.eval(input_points.clone(), input_scalars.clone())
                    else {
                        continue;
                    };
                    let output_ids =
                        desc.run_usize(&input_point_ids, &input_scalar_ids, &mut expr_store);
                    let n_outputs = output_ids.len();
                    let run_result = expr_store.into_ir(output_ids).eval(rng, &mut input_values);
                    let Ok(run_result) = run_result else {
                        panic!("run failed: {:?}", run_result);
                    };
                    let res = (0..n_outputs)
                        .map(|i| CurvePoint::unwrap(run_result[i]))
                        .collect::<Vec<CurvePoint>>();
                    assert_eq!(ctrl_eval_result, res);
                    desc.extra_checks(input_points, input_scalars, res);
                }
            }
        }
    }
}