use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::boolean::utils::{equal, CircuitType},
compile::Compiler,
compile_passes::{Optimizer, BUILD_FN_SEQ},
expressions::{
bit_expr::{BitExpr, BitInputInfo, GenBitExpr, RandomBitId},
conversion_expr::ConversionExpr,
curve_expr::CurveExpr,
expr::{
EvalValue,
Expr::{self, *},
UndefinedBehavior,
},
field_expr::{FieldExpr, InputInfo},
random_expr::{ExprGenHelper, ExprGenerator},
InputKind,
},
global_value::{global_expr_store::with_local_expr_store_as_global, value::FieldValue},
ir::IntermediateRepresentation,
ir_builder::IRBuilder,
},
traits::GetBit,
utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
number::Number,
used_field::UsedField,
},
};
use ff::{Field, PrimeField};
use num_bigint::{BigInt, BigUint, ToBigUint};
use num_traits::FromBytes;
use rand::Rng;
use rustc_hash::{FxHashMap, FxHashSet};
use std::rc::Rc;
/// Pools of expression ids, grouped by category, from which the random
/// expression generator (the `ExprGenHelper` impl below) picks operands,
/// plus the inclusive value ranges used when sampling field constants.
pub struct RNGIds {
    // Inclusive [min, max] range for sampled scalar-field constants.
    scalar_min: ScalarField,
    scalar_max: ScalarField,
    // Inclusive [min, max] range for sampled base-field constants.
    base_field_min: BaseField,
    base_field_max: BaseField,
    // Ids of boolean (bit) expressions.
    bool_ids: Vec<usize>,
    // Scalar-field ids whose bounds are an arithmetic boolean (provably 0 or 1).
    scalar_cond_ids: Vec<usize>,
    // All scalar-field ids except EdaBit conversions.
    scalar_ids: Vec<usize>,
    // Scalar-field ids whose bounds exclude field zero.
    scalar_pos_ids: Vec<usize>,
    // Scalar EdaBit conversion ids.
    scalar_eda_ids: Vec<usize>,
    // Same categories, for base-field expressions.
    base_field_cond_ids: Vec<usize>,
    base_field_ids: Vec<usize>,
    base_field_pos_ids: Vec<usize>,
    base_field_eda_ids: Vec<usize>,
    // Curve-point expression ids.
    curve_ids: Vec<usize>,
}
impl RNGIds {
    /// Creates an empty id pool. `min`/`max` set the inclusive range used by
    /// `scalar_int`/`base_field_int` when sampling field constants.
    pub fn new(min: &Number, max: &Number) -> RNGIds {
        RNGIds {
            scalar_min: min.clone().into(),
            scalar_max: max.clone().into(),
            bool_ids: Vec::new(),
            scalar_cond_ids: Vec::new(),
            scalar_ids: Vec::new(),
            scalar_pos_ids: Vec::new(),
            scalar_eda_ids: Vec::new(),
            base_field_min: min.clone().into(),
            base_field_max: max.clone().into(),
            base_field_cond_ids: Vec::new(),
            base_field_ids: Vec::new(),
            base_field_pos_ids: Vec::new(),
            base_field_eda_ids: Vec::new(),
            curve_ids: Vec::new(),
        }
    }
    /// Adds `expr` to `ir_builder` and returns its id. When the id equals the
    /// pre-insert length — i.e. the builder created a brand-new node instead
    /// of deduplicating to an existing one — the id is also filed into every
    /// category pool it qualifies for.
    pub fn add_expr(&mut self, ir_builder: &mut IRBuilder, expr: Expr<usize>) -> usize {
        let n = ir_builder.len();
        // Classify BEFORE handing the expression to the builder, since
        // `new_expr` consumes it.
        let is_bool = expr.is_boolean();
        let is_scalar = !is_bool && matches!(expr, ScalarConversion(..) | Scalar(..));
        let is_base = !is_bool && matches!(expr, BaseConversion(..) | Base(..));
        // NOTE(review): only the *scalar* EdaBit conversion is detected here,
        // so `is_eda` and `is_base` are mutually exclusive and the
        // `base_field_eda_ids` push below can never fire (`compile_test` even
        // asserts that pool stays empty). Confirm whether
        // BaseConversion(EdaBit) was meant to be covered as well.
        let is_eda = matches!(expr, ScalarConversion(ConversionExpr::EdaBit(..)));
        let is_curve = !is_bool && matches!(expr, Curve(..));
        let expr_id = ir_builder.new_expr(expr);
        let expr_bounds = ir_builder.get_bounds(expr_id);
        // Bounds prove the value is 0 or 1, so it can serve as a condition.
        let is_arith_bool = expr_bounds.is_arithmetic_boolean();
        // Bounds exclude field zero: provably non-zero arithmetic value.
        let is_arith_and_positive = !is_bool && !expr_bounds.contains_field_zero();
        if expr_id == n {
            if is_bool {
                self.bool_ids.push(n)
            } else if is_scalar {
                if !is_eda {
                    self.scalar_ids.push(n);
                }
                if is_arith_bool {
                    self.scalar_cond_ids.push(n);
                }
                if is_arith_and_positive {
                    self.scalar_pos_ids.push(n);
                }
                if is_eda {
                    self.scalar_eda_ids.push(n);
                }
            } else if is_base {
                if !is_eda {
                    self.base_field_ids.push(n);
                }
                if is_arith_bool {
                    self.base_field_cond_ids.push(n);
                }
                if is_arith_and_positive {
                    self.base_field_pos_ids.push(n);
                }
                if is_eda {
                    self.base_field_eda_ids.push(n);
                }
            } else if is_curve {
                self.curve_ids.push(n);
            }
        }
        expr_id
    }
    // The `gen_*` helpers map an arbitrary random value `r` onto an id from
    // the corresponding pool via modulo indexing. They panic (mod by zero) on
    // an empty pool — callers seed each pool before generating.
    fn gen_bool(&self, r: usize) -> usize {
        self.bool_ids[r % self.bool_ids.len()]
    }
    fn gen_scalar_cond(&self, r: usize) -> usize {
        self.scalar_cond_ids[r % self.scalar_cond_ids.len()]
    }
    fn gen_scalar(&self, r: usize) -> usize {
        self.scalar_ids[r % self.scalar_ids.len()]
    }
    fn gen_scalar_pos(&self, r: usize) -> usize {
        self.scalar_pos_ids[r % self.scalar_pos_ids.len()]
    }
    // NOTE(review): ignores `scalar_eda_ids` and always returns expression
    // id 0 — presumably a deliberate placeholder; confirm intended.
    fn gen_scalar_eda(&self, _: usize) -> usize {
        0
    }
    fn gen_base_field_cond(&self, r: usize) -> usize {
        self.base_field_cond_ids[r % self.base_field_cond_ids.len()]
    }
    fn gen_base_field(&self, r: usize) -> usize {
        self.base_field_ids[r % self.base_field_ids.len()]
    }
    fn gen_base_field_pos(&self, r: usize) -> usize {
        self.base_field_pos_ids[r % self.base_field_pos_ids.len()]
    }
    // NOTE(review): same placeholder pattern as `gen_scalar_eda` — ignores
    // `base_field_eda_ids` and always returns id 0; confirm intended.
    fn gen_base_field_eda(&self, _: usize) -> usize {
        0
    }
    fn gen_curve(&self, r: usize) -> usize {
        self.curve_ids[r % self.curve_ids.len()]
    }
    /// Returns true when a randomly generated `expr` should be discarded by
    /// the compile test: non-deterministic expressions, inputs, and a few
    /// constructs the harness does not round-trip (bounds expressions,
    /// EdaBit-derived conversions, Keccak, key-recovery error computation).
    fn throw_expr(&self, expr: &Expr<usize>) -> bool {
        // Conversions derived from plaintext bits or EdaBits are rejected.
        fn throw_conversion_expr<F: ActuallyUsedField>(expr: &ConversionExpr<F, usize>) -> bool {
            matches!(expr, ConversionExpr::ScalarFromPlaintextBit(..)) || matches!(expr, ConversionExpr::BitFromEdaBit(..)) || matches!(expr, ConversionExpr::ScalarFromEdaBit(..))
        }
        // Field expressions that are raw bounds assertions.
        fn uses_bounds<F: ActuallyUsedField>(expr: &FieldExpr<F, usize>) -> bool {
            matches!(expr, FieldExpr::Bounds(..))
        }
        // Keccak permutations are excluded from random generation.
        fn uses_keccak(expr: &BitExpr<usize>) -> bool {
            matches!(expr, BitExpr::KeccakF1600(..))
        }
        fn uses_key_recovery_compute_errors<F: ActuallyUsedField>(
            expr: &FieldExpr<F, usize>,
        ) -> bool {
            matches!(expr, FieldExpr::KeyRecoveryComputeErrors(..))
        }
        !expr.is_eval_deterministic_fn_from_deps()
            || expr.get_input().is_some()
            || match expr {
                Scalar(e) => uses_bounds(e) || uses_key_recovery_compute_errors(e),
                ScalarConversion(e) => throw_conversion_expr(e),
                Base(e) => uses_bounds(e) || uses_key_recovery_compute_errors(e),
                BaseConversion(e) => throw_conversion_expr(e),
                Bit(e) => uses_keccak(e),
                _ => false,
            }
    }
}
/// Wires `RNGIds` into the random expression generator. Every id hook draws
/// exactly one random `usize` and maps it onto the matching pool, while the
/// `*_int` hooks sample a field constant from the configured inclusive range.
impl ExprGenHelper for RNGIds {
    type ScalarType = usize;
    type BitType = usize;
    type BaseType = usize;
    type CurveType = usize;
    fn scalar<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_scalar(rng.r#gen::<usize>())
    }
    fn bit<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_bool(rng.r#gen::<usize>())
    }
    fn base<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType {
        self.gen_base_field(rng.r#gen::<usize>())
    }
    fn curve_point<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::CurveType {
        self.gen_curve(rng.r#gen::<usize>())
    }
    fn curve_val<R: Rng + ?Sized>(&self, rng: &mut R) -> CurvePoint {
        // Curve values are drawn directly from the RNG, not from a pool.
        rng.r#gen()
    }
    fn scalar_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_scalar_cond(rng.r#gen::<usize>())
    }
    fn scalar_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_scalar_pos(rng.r#gen::<usize>())
    }
    fn scalar_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_scalar_eda(rng.r#gen::<usize>())
    }
    fn scalar_int<R: Rng + ?Sized>(&self, rng: &mut R) -> ScalarField {
        ScalarField::gen_inclusive_range(rng, self.scalar_min, self.scalar_max)
    }
    fn base_field_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_base_field_cond(rng.r#gen::<usize>())
    }
    fn base_field_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_base_field_pos(rng.r#gen::<usize>())
    }
    fn base_field_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        self.gen_base_field_eda(rng.r#gen::<usize>())
    }
    fn base_field_int<R: Rng + ?Sized>(&self, rng: &mut R) -> BaseField {
        BaseField::gen_inclusive_range(rng, self.base_field_min, self.base_field_max)
    }
}
/// Shared harness: optimizes `unopt_ir` into a circuitable IR, lowers it to
/// an async MPC circuit, and checks that the unoptimized IR, the optimized
/// IR, every intermediate optimization pass (on divergence), and the
/// mock-evaluated circuit all agree on randomly sampled inputs.
fn test_compilation<R: Rng + ?Sized>(rng: &mut R, unopt_ir: IntermediateRepresentation) {
    let opt_ir = Compiler::optimize_into_circuitable(unopt_ir.clone());
    let async_circuit = opt_ir.to_async_mpc_circuit();
    for _ in 0..4 {
        let mut input_vals = FxHashMap::<usize, _>::default();
        // Evaluation fills `input_vals` with the sampled input assignment so
        // the same assignment can be replayed against the optimized versions.
        let unopt_vals = unopt_ir.eval(rng, &mut input_vals);
        if unopt_vals.is_err() {
            // Undefined behavior in the reference run — skip this sample.
            continue;
        }
        let opt_vals = opt_ir.eval(rng, &mut input_vals);
        if opt_vals != unopt_vals {
            // Divergence: dump both IRs with evaluation logs, then re-run the
            // optimization pipeline pass by pass to pinpoint where the
            // results first drift from the reference.
            println!("unopt: {}", unopt_ir);
            let _ = unopt_ir.eval_with_log(
                rng,
                &mut input_vals,
                false,
                true,
                false,
                std::iter::empty(),
            );
            println!("opt: {}", opt_ir);
            let _ =
                opt_ir.eval_with_log(rng, &mut input_vals, false, true, false, std::iter::empty());
            let mut ir = unopt_ir.clone();
            // Re-evaluates `ir` 16 times against the reference values,
            // logging and asserting on the first mismatch.
            fn test_ir<R: Rng + ?Sized>(
                rng: &mut R,
                ir: &IntermediateRepresentation,
                unopt_vals: &Result<Vec<EvalValue>, UndefinedBehavior>,
                input_vals: &mut FxHashMap<usize, EvalValue>,
            ) {
                for _ in 0..16 {
                    let opt_vals = ir.eval(rng, input_vals);
                    if *unopt_vals != opt_vals {
                        println!("opt: {}", ir);
                        let _ = ir.eval_with_log(
                            rng,
                            input_vals,
                            false,
                            true,
                            false,
                            std::iter::empty(),
                        );
                        assert_eq!(opt_vals, *unopt_vals);
                    }
                }
            }
            // Interleave the optimizer with every build pass, checking after
            // each step.
            for build_fn in BUILD_FN_SEQ {
                ir = Optimizer::optimize(ir);
                test_ir(rng, &ir, &unopt_vals, &mut input_vals);
                ir = build_fn(ir);
                test_ir(rng, &ir, &unopt_vals, &mut input_vals);
            }
            // Still report the original mismatch even if the per-pass replay
            // did not reproduce it.
            assert_eq!(unopt_vals, opt_vals)
        }
        // Feed the same assignment to the mock circuit evaluator; input ids
        // absent from the map (unused inputs) default to zero.
        let vec_inputs: Vec<_> = (0..async_circuit.circuit.input_indices().len())
            .map(|i| {
                input_vals
                    .get(&i)
                    .map(|x| {
                        BigInt::from(match x {
                            EvalValue::Scalar(a) => a.to_unsigned_number(),
                            EvalValue::Base(a) => a.to_unsigned_number(),
                            EvalValue::Bit(a) => a.into(),
                            EvalValue::Curve(_) => panic!("curve inputs not supported here"),
                        })
                        .to_biguint()
                        .expect("unreachable, big_int to big_uint conversion always succeeds")
                    })
                    .unwrap_or(BigUint::from(0u32))
            })
            .collect();
        let async_circuit_vals = async_circuit.circuit.mock_eval_big_uint(vec_inputs, rng);
        // Circuit outputs come back as big integers; lift them into
        // base-field values for comparison with the IR evaluation.
        let async_circuit_vals = async_circuit_vals
            .into_iter()
            .map(|x| EvalValue::Base(BaseField::from(&x.to_biguint().unwrap())))
            .collect();
        assert_eq!(unopt_vals, Ok(async_circuit_vals));
    }
}
#[test]
fn compile_test() {
    // Fuzz test: builds thousands of random IRs over several input ranges
    // (pure booleans up to 16-bit values) and round-trips each through
    // `test_compilation`.
    let rng = &mut crate::utils::test_rng::get();
    for bound in [Number::from(1), 4.into(), 16.into(), 65536.into()] {
        for _ in 0..4096 {
            let mut ir_builder = IRBuilder::new(true);
            let mut rng_ids = RNGIds::new(&Number::from(0), &bound);
            let min = Number::from(0);
            let max = bound.clone();
            // Two boolean-range base-field inputs seed the `cond` pool.
            let base_field_bool_input_info = Rc::new(InputInfo {
                min: 0.into(),
                max: 1.into(),
                kind: InputKind::SecretFromPlayer(0),
                ..InputInfo::default()
            });
            rng_ids.add_expr(
                &mut ir_builder,
                Base(FieldExpr::Input(0, base_field_bool_input_info.clone())),
            );
            rng_ids.add_expr(
                &mut ir_builder,
                Base(FieldExpr::Input(1, base_field_bool_input_info)),
            );
            assert!(!rng_ids.base_field_cond_ids.is_empty());
            // Cover every input kind at least once.
            let input_kinds = [
                InputKind::Secret,
                InputKind::Plaintext,
                InputKind::SecretFromPlayer(2),
            ];
            let bool_id =
                rng_ids.add_expr(&mut ir_builder, Bit(BitExpr::Input(2, Default::default())));
            assert!(!rng_ids.bool_ids.is_empty());
            for (i, kind) in input_kinds.into_iter().enumerate() {
                let expr = Base(FieldExpr::Input(
                    i + 3, Rc::new(InputInfo {
                        min: min.clone().into(),
                        max: max.clone().into(),
                        kind,
                        ..InputInfo::default()
                    }),
                ));
                rng_ids.add_expr(&mut ir_builder, expr);
            }
            // `input + 1` is provably non-zero, seeding the `pos` pool.
            let base_field_one = rng_ids.add_expr(&mut ir_builder, Base(FieldExpr::Val(1.into())));
            rng_ids.add_expr(&mut ir_builder, Base(FieldExpr::Add(1, base_field_one)));
            assert!(!rng_ids.base_field_pos_ids.is_empty());
            // Seed the scalar pools via a bit-to-number conversion of the
            // boolean input.
            let scalar_id = rng_ids.add_expr(
                &mut ir_builder,
                ScalarConversion(ConversionExpr::BitToBitNum(vec![bool_id], false)),
            );
            let scalar_one = rng_ids.add_expr(&mut ir_builder, Scalar(FieldExpr::Val(1.into())));
            rng_ids.add_expr(
                &mut ir_builder,
                Scalar(FieldExpr::Add(scalar_id, scalar_one)),
            );
            // Seed the curve pool with the identity point.
            rng_ids.add_expr(
                &mut ir_builder,
                Curve(CurveExpr::Val(CurvePoint::identity())),
            );
            // Grow the IR with random expressions, skipping the kinds the
            // harness cannot verify (see `RNGIds::throw_expr`).
            for _ in 0..16 {
                let expr = rng_ids.expr(rng);
                if rng_ids.throw_expr(&expr) {
                    continue;
                }
                rng_ids.add_expr(&mut ir_builder, expr);
            }
            // `add_expr` only detects scalar EdaBit conversions, so the
            // base-field EdaBit pool is always empty here.
            assert!(rng_ids.base_field_eda_ids.is_empty());
            // Output every base-field expression except the two raw inputs.
            let n = rng_ids.base_field_ids.len();
            let unopt_ir = ir_builder.into_ir(rng_ids.base_field_ids[2..n].to_owned());
            test_compilation(rng, unopt_ir);
        }
    }
}
#[test]
fn test_reveal_comparison() {
    // Regression test: comparing a revealed, negated secret bit against zero
    // must survive optimization and circuit compilation unchanged.
    let mut builder = IRBuilder::new(true);
    // Expression 0: the public constant zero.
    let zero = builder.new_expr(Base(FieldExpr::Val(0.into())));
    // Expression 1: a secret input constrained to the range [0, 1].
    let bit_info = Rc::new(InputInfo {
        kind: InputKind::Secret,
        min: 0.into(),
        max: 1.into(),
        ..InputInfo::default()
    });
    let secret_bit = builder.new_expr(Base(FieldExpr::Input(0, bit_info)));
    // Negate, reveal, then test the revealed value for equality with zero.
    let negated = builder.new_expr(Base(FieldExpr::Neg(secret_bit)));
    let revealed = builder.new_expr(Base(FieldExpr::Reveal(negated)));
    let is_zero = builder.new_expr(Base(FieldExpr::Equal(revealed, zero)));
    // Expose both the negated secret and the comparison result as outputs.
    let ir = builder.into_ir(vec![negated, is_zero]);
    let rng = &mut crate::utils::test_rng::get();
    test_compilation(rng, ir);
}
#[test]
fn optimize_modulo() {
    // For every bit width i, reducing a random 128-bit number modulo 2^i
    // should optimize down to roughly the low i bits, so the compiled
    // expression count must stay linear in i.
    for i in 0..=128 {
        let mut builder = IRBuilder::new(true);
        // 128 independent random bits ...
        let mut random_bits = Vec::with_capacity(128);
        for _ in 0..128 {
            random_bits.push(builder.new_expr(Bit(BitExpr::Random(RandomBitId::default()))));
        }
        // ... assembled into an unsigned base-field number.
        let number = builder.new_expr(BaseConversion(ConversionExpr::BitToBitNum(
            random_bits, false,
        )));
        let divisor = builder.new_expr(Base(FieldExpr::Val(BaseField::power_of_two(i))));
        let remainder = builder.new_expr(Base(FieldExpr::Rem(number, divisor)));
        let ir = builder.into_ir(vec![remainder]);
        let compiled = Compiler::optimize_into_circuitable(ir);
        // Empirical linear size bound on the optimized IR.
        let limit = std::cmp::max(15 * i, 2) - 1;
        assert!(compiled.get_exprs().len() <= limit);
    }
}
#[test]
fn test_full_conversion() {
    // Full bit-decomposition round trip: the compiled circuit extracting
    // every bit of a base-field input must agree with the host-side
    // `GetBit::get_bit` on special values (0, ±1, ±1/2) and random field
    // elements, for both signed and unsigned bit semantics.
    let rng = &mut crate::utils::test_rng::get();
    for signed in [false, true] {
        let mut expr_store = IRBuilder::new(true);
        // Build the outputs through the high-level `FieldValue` API, which
        // records expressions into `expr_store` via the global-store shim.
        let outputs = with_local_expr_store_as_global(
            || {
                let x = FieldValue::new(FieldExpr::Input(
                    0,
                    FieldBounds::<BaseField>::All.as_input_info(InputKind::Secret),
                ));
                (0..BaseField::NUM_BITS as usize)
                    .map(|i| x.get_bit(i, signed).get_id())
                    .collect::<Vec<usize>>()
            },
            &mut expr_store,
        );
        let ir = expr_store.into_ir(outputs);
        let compiled_circuit = Compiler::optimize_into_circuitable(ir).to_async_mpc_circuit();
        // NOTE(review): `input_vals` is populated below but never read in
        // this test (the circuit is fed `x` directly) — presumably leftover;
        // confirm it can be removed.
        let mut input_vals = FxHashMap::<usize, EvalValue>::default();
        // Five fixed edge-case values, then four random draws (`None`).
        for x in [
            Some(BaseField::ZERO),
            Some(BaseField::ONE),
            Some(-BaseField::ONE),
            Some(BaseField::TWO_INV),
            Some(-BaseField::TWO_INV),
        ]
        .into_iter()
        .chain(std::iter::repeat_n(None, 4))
        {
            let x = x.unwrap_or(BaseField::random(&mut *rng));
            input_vals.insert(0, EvalValue::Base(x));
            // Mock-evaluate the circuit; each output wire is a single bit.
            let result = compiled_circuit
                .circuit
                .mock_eval_big_uint(vec![BigUint::from_le_bytes(&x.to_le_bytes())], rng)
                .into_iter()
                .map(|bit| bit == BigUint::from(1u8))
                .collect::<Vec<bool>>();
            // Reference: bit extraction computed directly on the value.
            let expected = (0..BaseField::NUM_BITS as usize)
                .map(|i| x.get_bit(i, signed))
                .collect::<Vec<bool>>();
            assert_eq!(result, expected);
        }
    }
}
/// Samples a random boolean expression, drawing any sub-expression ids from
/// the pools in `RNGIds`. Relies on the generator machinery implemented for
/// the `(&mut RNGIds, &mut R)` tuple — presumably via the imported
/// `GenBitExpr` trait; confirm against `random_expr`/`bit_expr`.
fn gen_bit_expr<R: Rng + ?Sized>(r#gen: &mut RNGIds, rng: &mut R) -> BitExpr<usize> {
    (r#gen, rng).r#gen()
}
/// Counts the `Bit(BitExpr::And(..))` nodes in `ir`'s expression list —
/// a proxy for the multiplicative cost of the boolean circuit.
fn count_bit_and(ir: &IntermediateRepresentation) -> usize {
    let mut total = 0;
    for expr in ir.get_exprs().iter() {
        if let Bit(BitExpr::And(..)) = expr {
            total += 1;
        }
    }
    total
}
#[test]
fn test_boolean_optimize() {
    // Generates random boolean circuits over two secret bit inputs and
    // checks that the optimizer (applied twice) shrinks them to a bounded
    // size while preserving behavior on all four input assignments.
    let rng = &mut crate::utils::test_rng::get();
    for i in 0..16 {
        // Alternate between "every intermediate bit is an output" (limits
        // pruning) and "only the last generated bit is".
        let all_are_outputs = i % 2 == 1;
        let mut expr_store = IRBuilder::new(true);
        let input_info = Rc::new(BitInputInfo {
            kind: InputKind::Secret,
            ..BitInputInfo::default()
        });
        let mut expr_gen = RNGIds::new(&Number::from(0), &Number::from(1));
        let input_0 =
            expr_gen.add_expr(&mut expr_store, Bit(BitExpr::Input(0, input_info.clone())));
        expr_gen.add_expr(&mut expr_store, Bit(BitExpr::Input(1, input_info)));
        let mut last_bool_expr_id = input_0;
        // Pile on up to 1024 random bit expressions, keeping only the
        // deterministic ones and skipping Keccak (far too large for the
        // size bounds below).
        for _ in 0..1024 {
            let bool_expr = gen_bit_expr(&mut expr_gen, rng);
            if !bool_expr.is_eval_deterministic_fn_from_deps() {
                continue;
            }
            if matches!(bool_expr, BitExpr::KeccakF1600(..)) {
                continue;
            }
            let bool_expr_id = expr_gen.add_expr(&mut expr_store, Bit(bool_expr));
            last_bool_expr_id = bool_expr_id;
        }
        // Outputs are scalar conversions, so each chosen bit is wrapped in a
        // single-bit BitToBitNum.
        let outputs = if all_are_outputs {
            let n_expr = expr_store.len();
            (input_0..n_expr)
                .map(|i| {
                    expr_store.new_expr(ScalarConversion(ConversionExpr::BitToBitNum(
                        vec![i],
                        false,
                    )))
                })
                .collect()
        } else {
            let output = expr_store.new_expr(ScalarConversion(ConversionExpr::BitToBitNum(
                vec![last_bool_expr_id],
                false,
            )));
            vec![output]
        };
        let ir = expr_store.into_ir(outputs);
        // Optimize twice — the bounds must hold at the fixed point.
        let opt_ir = Optimizer::optimize(ir.clone());
        let opt_ir = Optimizer::optimize(opt_ir);
        // Any boolean function of two inputs needs at most one AND gate.
        assert!(count_bit_and(&opt_ir) <= 1);
        // There are only 16 distinct boolean functions of two inputs, so at
        // most 16 distinct output expressions may remain.
        let n_different = FxHashSet::from_iter(opt_ir.get_outputs().iter().cloned()).len();
        assert!(
            n_different <= 16,
            "n_different is {}, but it should be equal or below 16.\n ir is {opt_ir}",
            n_different
        );
        // Empirical size bounds on the optimized IR.
        assert!(opt_ir.get_exprs().len() <= if all_are_outputs { 202 } else { 37 });
        // Exhaustively verify all four assignments of the two input bits.
        for f_0 in [EvalValue::Bit(false), EvalValue::Bit(true)] {
            for f_1 in [EvalValue::Bit(false), EvalValue::Bit(true)] {
                let mut input_values = FxHashMap::from_iter([(0, f_0), (1, f_1)].into_iter());
                let unopt_val = ir.eval(rng, &mut input_values);
                let opt_val = opt_ir.eval(rng, &mut input_values);
                assert_eq!(unopt_val, opt_val);
            }
        }
    }
}
#[test]
fn test_boolean_optimize_example() {
    // A "first match wins" selector over four inputs, written twice: once
    // accumulating with XOR and once with OR built from AND/NOT (De Morgan).
    // The OR version starts with more AND gates; after optimization both
    // encodings must cost the same number of ANDs.
    fn build_example(with_xor: bool) -> IntermediateRepresentation {
        let mut expr_store = IRBuilder::new(true);
        let mut outputs = vec![];
        let input_info = Rc::new(InputInfo {
            kind: InputKind::Secret,
            min: ScalarField::ZERO,
            max: ScalarField::ONE,
            ..InputInfo::default()
        });
        // `found` tracks whether an earlier input already matched.
        let mut found = expr_store.new_expr(Bit(BitExpr::Val(false)));
        for i in 0usize..4 {
            let order = expr_store.new_expr(Scalar(FieldExpr::Input(i, input_info.clone())));
            // eq := (input[i] == 0), built through the high-level circuit API
            // and then unpacked back into a raw bit id below.
            let eq = with_local_expr_store_as_global(
                || {
                    FieldValue::<ScalarField>::from(equal(
                        FieldValue::<ScalarField>::from_id(order),
                        FieldValue::<ScalarField>::from(0),
                        true,
                        CircuitType::default(),
                    ))
                    .expr()
                },
                &mut expr_store,
            );
            // The equality helper is expected to return a single-bit
            // conversion; anything else means the API changed shape.
            let ScalarConversion(ConversionExpr::BitToBitNum(eq, false)) = eq else {
                panic!("this test needs a rewrite")
            };
            let eq = eq[0];
            // overwrite := eq AND NOT found — "this is the first match".
            let not_found = expr_store.new_expr(Bit(BitExpr::Not(found)));
            let overwrite = expr_store.new_expr(Bit(BitExpr::And(eq, not_found)));
            let output = expr_store.new_expr(ScalarConversion(ConversionExpr::BitToBitNum(
                vec![overwrite],
                false,
            )));
            outputs.push(output);
            // found := found OR overwrite, in two equivalent encodings.
            if with_xor {
                // `overwrite` implies NOT `found`, so XOR acts as OR here.
                found = expr_store.new_expr(Bit(BitExpr::Xor(overwrite, found)));
            } else {
                // Plain OR via De Morgan: NOT(NOT a AND NOT b).
                let not_a = expr_store.new_expr(Bit(BitExpr::Not(overwrite)));
                let not_b = expr_store.new_expr(Bit(BitExpr::Not(found)));
                let and = expr_store.new_expr(Bit(BitExpr::And(not_a, not_b)));
                found = expr_store.new_expr(Bit(BitExpr::Not(and)));
            }
        }
        // Also expose the final `found` flag as an output.
        let output = expr_store.new_expr(ScalarConversion(ConversionExpr::BitToBitNum(
            vec![found],
            false,
        )));
        outputs.push(output);
        expr_store.into_ir(outputs)
    }
    let or_example = build_example(false);
    let xor_example = build_example(true);
    // Unoptimized, the De Morgan encoding uses strictly more AND gates.
    assert!(count_bit_and(&or_example) > count_bit_and(&xor_example));
    let opt_or_example = Optimizer::optimize(or_example);
    let opt_xor_example = Optimizer::optimize(xor_example);
    // After optimization both encodings converge to the same AND count.
    assert_eq!(
        count_bit_and(&opt_or_example),
        count_bit_and(&opt_xor_example)
    );
}