use crate::{
core::{
bounds::FieldBounds,
expressions::expr::EvalFailure,
global_value::{global_expr_store::with_local_expr_store_as_global, value::FieldValue},
ir_builder::IRBuilder,
},
utils::field::{BaseField, ScalarField},
};
use std::fmt::Debug;
/// A circuit over two kinds of field elements: scalar-field values and
/// base-field values.
///
/// Every method takes the scalar inputs and the base inputs as two separate
/// lists, and produces the scalar outputs and the base outputs as two
/// separate lists, in that order.
// NOTE(review): the reason for `allow(dead_code)` is not visible here —
// presumably the trait is only exercised from some configurations; confirm.
#[allow(dead_code)]
pub trait GeneralCircuit: Debug {
    /// Evaluates the circuit directly on concrete field elements.
    ///
    /// Returns `(scalar_outputs, base_outputs)`, or an [`EvalFailure`] when the
    /// inputs cannot be evaluated. The test harness in this file uses this as
    /// the reference result that the IR produced by [`GeneralCircuit::run`]
    /// must reproduce (a failure makes the harness skip that run).
    fn eval(
        &self,
        scalars: Vec<ScalarField>,
        bases: Vec<BaseField>,
    ) -> Result<(Vec<ScalarField>, Vec<BaseField>), EvalFailure>;

    /// Maps bounds on the inputs to bounds on the outputs.
    ///
    /// The test harness in this file asserts that every output of a run lies
    /// within the bound returned here for that output position.
    fn bounds(
        &self,
        scalar_bounds: Vec<FieldBounds<ScalarField>>,
        base_bounds: Vec<FieldBounds<BaseField>>,
    ) -> (Vec<FieldBounds<ScalarField>>, Vec<FieldBounds<BaseField>>);

    /// Builds the circuit over [`FieldValue`] handles, returning the handles
    /// of the scalar outputs and base outputs.
    fn run(
        &self,
        scalar_vals: Vec<FieldValue<ScalarField>>,
        base_vals: Vec<FieldValue<BaseField>>,
    ) -> (Vec<FieldValue<ScalarField>>, Vec<FieldValue<BaseField>>);

    /// Id-based wrapper around [`GeneralCircuit::run`].
    ///
    /// Each input id is turned into a [`FieldValue`] with `FieldValue::from_id`.
    /// `run` is then invoked through `with_local_expr_store_as_global` with
    /// `expr_store`, presumably so that the expressions it creates land in
    /// `expr_store` (confirm against that helper). Finally, every output
    /// handle is turned back into its id with `get_id`.
    ///
    /// Returns `(scalar_output_ids, base_output_ids)`.
    fn run_usize(
        &self,
        scalar_vals: &[usize],
        base_vals: &[usize],
        expr_store: &mut IRBuilder,
    ) -> (Vec<usize>, Vec<usize>) {
        // The id -> FieldValue conversion is done inside the closure, i.e. while
        // `expr_store` is installed by `with_local_expr_store_as_global`.
        let build = || {
            let scalar_inputs: Vec<FieldValue<ScalarField>> = scalar_vals
                .iter()
                .map(|&id| FieldValue::from_id(id))
                .collect();
            let base_inputs: Vec<FieldValue<BaseField>> = base_vals
                .iter()
                .map(|&id| FieldValue::from_id(id))
                .collect();
            self.run(scalar_inputs, base_inputs)
        };
        let (scalar_outputs, base_outputs) = with_local_expr_store_as_global(build, expr_store);

        let scalar_ids: Vec<usize> = scalar_outputs.iter().map(|value| value.get_id()).collect();
        let base_ids: Vec<usize> = base_outputs.iter().map(|value| value.get_id()).collect();
        (scalar_ids, base_ids)
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            actually_used_field::ActuallyUsedField,
            bounds::IsBounds,
            expressions::{domain::Domain, expr::EvalValue, field_expr::FieldExpr, InputKind},
            ir_builder::{ExprStore, IRBuilder},
        },
        utils::used_field::UsedField,
    };
    use rand::Rng;
    use rustc_hash::FxHashMap;

    /// Randomized differential-testing harness for [`GeneralCircuit`]
    /// implementations.
    ///
    /// Implementors provide random generation of circuit descriptions and input
    /// counts. [`TestedGeneralCircuit::test`] then checks three things:
    /// - the IR built by `run_usize` evaluates to the same outputs as `eval`;
    /// - those outputs lie within the bounds reported by `bounds`;
    /// - whatever `extra_checks` asserts.
    pub trait TestedGeneralCircuit: GeneralCircuit + Clone + 'static {
        /// Generates a random circuit description to test.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// Chooses how many scalar-field inputs one run of `self` gets.
        fn gen_n_scalars<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

        /// Chooses how many base-field inputs one run of `self` gets, given the
        /// number of scalar inputs already chosen for that run.
        fn gen_n_bases<R: Rng + ?Sized>(&self, rng: &mut R, n_scalars: usize) -> usize;

        /// Hook for implementation-specific assertions on one run's inputs and
        /// outputs.
        ///
        /// It is called only after the run's IR outputs have matched `eval` and
        /// have been checked against `bounds`. The default does nothing.
        #[allow(unused_variables)]
        fn extra_checks(
            &self,
            scalar_inputs: Vec<ScalarField>,
            base_inputs: Vec<BaseField>,
            scalar_outputs: Vec<ScalarField>,
            base_outputs: Vec<BaseField>,
        ) {
        }

        /// Draws random bounds for a single input of field `F`.
        ///
        /// - With probability 1/8 the input is unconstrained (`FieldBounds::All`).
        /// - Otherwise a bit size `size` in `0..F::NUM_BITS` is picked, and an
        ///   envelope is built around it: `[-2^size, 2^size]` (signed, with
        ///   probability 1/2) or `[0, 2^size]` (unsigned).
        /// - The envelope is then handed to `FieldBounds::gen_bounds`, which
        ///   presumably draws a random sub-range of it.
        fn gen_input_bounds<F: UsedField, R: Rng + ?Sized>(rng: &mut R) -> FieldBounds<F> {
            if rng.gen_bool(0.125) {
                return FieldBounds::All;
            }
            let signed = rng.gen_bool(0.5);
            let size = (rng.next_u32() % F::NUM_BITS) as usize;
            let two_power_size = F::power_of_two(size);
            let bounds_bounds = if signed {
                FieldBounds::new(-two_power_size, two_power_size)
            } else {
                FieldBounds::new(F::ZERO, two_power_size)
            };
            FieldBounds::gen_bounds(rng, bounds_bounds)
        }

        /// Runs the harness on `n_desc` random descriptions, with
        /// `n_runs_per_desc` random input sets each.
        ///
        /// For every run:
        /// 1. Random scalar and base inputs are generated and recorded in a
        ///    fresh `IRBuilder`.
        /// 2. The reference result is computed with `eval`. Runs that `eval`
        ///    rejects are skipped.
        /// 3. The circuit is built with `run_usize`, turned into IR, and
        ///    evaluated on the same input values.
        /// 4. The IR outputs are asserted to equal the reference outputs.
        /// 5. The outputs are asserted to lie within `bounds`.
        /// 6. `extra_checks` is called.
        ///
        /// # Panics
        /// Panics, with the offending description in the message, if the IR
        /// evaluation fails or if any of the checks above does not hold.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            /// Creates `n_inputs` random inputs of field `F` and returns their
            /// expression ids, their sampled values, and their bounds.
            ///
            /// Each input gets bounds from `gen_bounds` and a value sampled from
            /// those bounds. How it is pushed into `expr_store` depends on a
            /// coin flip:
            /// - With probability 1/8 it is pushed as a constant
            ///   `FieldExpr::Val`, and its bounds are tightened to `[val, val]`.
            /// - Otherwise it is pushed as a secret `FieldExpr::Input` with id
            ///   `start_input_id + i`.
            ///
            /// Every value, constants included, is also stored in `inputs` under
            /// id `start_input_id + i`.
            fn gen_input_values_and_expr<F: ActuallyUsedField, R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<F>,
                mut gen_bounds: impl FnMut(&mut R) -> FieldBounds<F>,
            ) -> (Vec<usize>, Vec<F>, Vec<FieldBounds<F>>) {
                let mut bounds: Vec<_> = (0..n_inputs).map(|_| gen_bounds(rng)).collect();
                let input_vals: Vec<F> = bounds.iter().map(|bound| bound.sample(rng)).collect();
                let input_ids: Vec<usize> = bounds
                    .iter_mut()
                    .enumerate()
                    .map(|(i, bound)| {
                        if rng.gen_bool(0.125) {
                            let val = input_vals[i];
                            // A constant's exact value is known, so its bounds are a single point.
                            *bound = FieldBounds::new(val, val);
                            expr_store.push_field(FieldExpr::Val(val))
                        } else {
                            expr_store.push_field(FieldExpr::Input(
                                start_input_id + i,
                                bound.as_input_info(InputKind::Secret),
                            ))
                        }
                    })
                    .collect();
                input_vals.iter().enumerate().for_each(|(i, val)| {
                    inputs.insert(start_input_id + i, F::field_to_eval_value(*val));
                });
                (input_ids, input_vals, bounds)
            }

            /// Asserts that there is one bound per output value and that each
            /// value lies within its bound.
            ///
            /// `kind` ("scalar" / "base") and `desc` are only used in failure
            /// messages.
            fn check_bounds<F: UsedField>(
                kind: &str,
                vals: &[F],
                bounds: &[FieldBounds<F>],
                desc: &dyn std::fmt::Debug,
            ) {
                assert_eq!(
                    vals.len(),
                    bounds.len(),
                    "number of {} outputs differs from number of {} output bounds for {:?}",
                    kind,
                    kind,
                    desc
                );
                for (i, (val, bound)) in vals.iter().zip(bounds).enumerate() {
                    assert!(
                        bound.contains(*val),
                        "{} output #{} lies outside the bounds computed by `bounds` for {:?}",
                        kind,
                        i,
                        desc
                    );
                }
            }

            let rng = &mut crate::utils::test_rng::get();
            for _ in 0..n_desc {
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    let mut expr_store = IRBuilder::new(false);
                    let mut input_values: FxHashMap<usize, EvalValue> = FxHashMap::default();
                    let n_scalars = desc.gen_n_scalars(rng);
                    let n_bases = desc.gen_n_bases(rng, n_scalars);

                    // Scalar inputs use ids 0..n_scalars; base input ids follow them.
                    let (scalar_ids, scalar_vals, scalar_bounds) =
                        gen_input_values_and_expr::<ScalarField, _>(
                            rng,
                            n_scalars,
                            0,
                            &mut input_values,
                            &mut expr_store,
                            Self::gen_input_bounds,
                        );
                    let (base_ids, base_vals, base_bounds) =
                        gen_input_values_and_expr::<BaseField, _>(
                            rng,
                            n_bases,
                            scalar_vals.len(),
                            &mut input_values,
                            &mut expr_store,
                            Self::gen_input_bounds,
                        );

                    // Reference result; inputs that `eval` rejects are skipped, not failed.
                    let Ok(ctrl_eval_result) = desc.eval(scalar_vals.clone(), base_vals.clone())
                    else {
                        continue;
                    };

                    // Build the IR and evaluate it on the same inputs.
                    let (scalar_res, base_res) =
                        desc.run_usize(&scalar_ids, &base_ids, &mut expr_store);
                    let scalar_base_border = scalar_res.len();
                    let output_ids = [scalar_res, base_res].concat();
                    let n_outputs = output_ids.len();
                    let run_result =
                        match expr_store.into_ir(output_ids).eval(rng, &mut input_values) {
                            Ok(res) => res,
                            Err(err) => panic!("run failed for {:?}: {:?}", desc, err),
                        };
                    // Outputs were concatenated scalars-first; split them back apart.
                    let scalar_res = (0..scalar_base_border)
                        .map(|i| ScalarField::unwrap(run_result[i]))
                        .collect::<Vec<_>>();
                    let base_res = (scalar_base_border..n_outputs)
                        .map(|i| BaseField::unwrap(run_result[i]))
                        .collect::<Vec<_>>();

                    let (scalar_bounds_res, base_bounds_res) =
                        desc.bounds(scalar_bounds, base_bounds);
                    assert_eq!(
                        ctrl_eval_result.0, scalar_res,
                        "scalar outputs of the built IR differ from `eval` for {:?} \
                         (scalar inputs: {:?}, base inputs: {:?})",
                        desc, scalar_vals, base_vals
                    );
                    assert_eq!(
                        ctrl_eval_result.1, base_res,
                        "base outputs of the built IR differ from `eval` for {:?} \
                         (scalar inputs: {:?}, base inputs: {:?})",
                        desc, scalar_vals, base_vals
                    );
                    check_bounds("scalar", &scalar_res, &scalar_bounds_res, &desc);
                    check_bounds("base", &base_res, &base_bounds_res, &desc);
                    desc.extra_checks(scalar_vals, base_vals, scalar_res, base_res);
                }
            }
        }
    }
}