use crate::{
core::{
actually_used_field::ActuallyUsedField,
circuits::{
arithmetic::{PowCircuit, SqrtCircuit},
traits::arithmetic_circuit::ArithmeticCircuit,
},
compile_passes::compilation_pass::LocalCompilationPass,
expressions::{
circuit::{ArithmeticCircuitId, BaseCircuitId, GeneralCircuitId},
expr::Expr,
field_expr::FieldExpr,
other_expr::OtherExpr,
},
ir_builder::IRBuilder,
},
utils::field::{BaseField, ScalarField},
};
use ff::Field;
use rustc_hash::FxHashMap;
/// Memoized arithmetic-circuit runs: `(circuit inputs, circuit id)` maps to the
/// outputs returned by running that circuit on those inputs.
type ArithmeticCircuitRunCache = FxHashMap<(Vec<usize>, ArithmeticCircuitId), Vec<usize>>;
/// Memoized base-circuit runs: `(circuit inputs, circuit id)` maps to the
/// outputs returned by running that circuit on those inputs.
type BaseCircuitRunCache = FxHashMap<(Vec<usize>, BaseCircuitId), Vec<usize>>;
/// Memoized general-circuit runs. A general circuit takes two input lists and
/// returns two output lists; the key is `(first inputs, second inputs, circuit id)`
/// and the value is the `(first outputs, second outputs)` pair.
/// NOTE(review): the first list appears to hold scalar-field ids and the second
/// base-field ids, judging by how the outputs are consumed — confirm.
type GeneralCircuitRunCache =
FxHashMap<(Vec<usize>, Vec<usize>, GeneralCircuitId), (Vec<usize>, Vec<usize>)>;
/// Local compilation pass that replaces circuit-backed expressions (sub-circuit
/// outputs, non-plaintext `Sqrt`/`Pow`, general and base circuits) with the
/// expressions obtained by running those circuits against `expr_store`.
///
/// Each circuit run is memoized per distinct input list and circuit id, so
/// several expressions that read different outputs of the same run share it.
#[derive(Default)]
pub struct ArithmeticCircuitBuilder {
// IR builder that circuits write their generated expressions into.
expr_store: IRBuilder,
// Cached outputs of arithmetic circuits, keyed by (inputs, circuit id).
arithmetic_circuit_cache: ArithmeticCircuitRunCache,
// Cached outputs of base-field arithmetic circuits, keyed by (inputs, circuit id).
base_circuit_cache: BaseCircuitRunCache,
// Cached output pairs of general circuits, keyed by (inputs, inputs, circuit id).
general_circuit_cache: GeneralCircuitRunCache,
}
impl ArithmeticCircuitBuilder {
    /// Expands circuit-backed field expressions into the expression produced by
    /// running the corresponding circuit against this pass's expression store.
    ///
    /// * `SubCircuit(inputs, circuit, idx)`: runs `circuit` on `inputs` (at most
    ///   once per distinct `(inputs, circuit)` pair; the outputs are memoized)
    ///   and returns a clone of the expression stored under output `idx`. When
    ///   `idx` is out of range, a constant `F::ZERO` expression is returned.
    /// * `Sqrt(v)` / `Pow(v, ..)` whose operand is not plaintext: runs
    ///   `SqrtCircuit` / `PowCircuit` on `v` and returns a clone of the
    ///   expression stored under the first output.
    /// * Everything else (including plaintext `Sqrt`/`Pow`) is passed through
    ///   `F::field_expr_to_expr` unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `SqrtCircuit` or `PowCircuit` returns no outputs.
    fn build_arithmetic_circuits<F: ActuallyUsedField>(
        &mut self,
        expr: FieldExpr<F, usize>,
    ) -> Expr<usize> {
        match expr {
            FieldExpr::SubCircuit(v, c, idx) => {
                // Build the cache key once and move it into the map on a miss,
                // instead of cloning the input vector for every lookup/insert.
                let key = (v, c);
                let output = match self.arithmetic_circuit_cache.get(&key) {
                    Some(outputs) => outputs.get(idx).copied(),
                    None => {
                        let outputs = c.to_circuit::<F>().run_usize(&key.0, &mut self.expr_store);
                        let output = outputs.get(idx).copied();
                        self.arithmetic_circuit_cache.insert(key, outputs);
                        output
                    }
                };
                // The zero fallback is only constructed when it is needed.
                match output {
                    Some(expr_id) => self.expr_store.get_expr(expr_id).clone(),
                    None => F::field_expr_to_expr(FieldExpr::Val(F::ZERO)),
                }
            }
            FieldExpr::Sqrt(v) if !self.expr_store.get_is_plaintext(v) => {
                let circuit_outputs = <SqrtCircuit as ArithmeticCircuit<F>>::run_usize(
                    &SqrtCircuit,
                    &[v],
                    &mut self.expr_store,
                );
                // The first output of the circuit is the one returned.
                let res = circuit_outputs[0];
                self.expr_store.get_expr(res).clone()
            }
            FieldExpr::Pow(v, e, is_expected_non_zero)
                if !self.expr_store.get_is_plaintext(v) =>
            {
                let circuit_outputs = <PowCircuit as ArithmeticCircuit<F>>::run_usize(
                    &PowCircuit {
                        exponent: e,
                        is_expected_non_zero,
                    },
                    &[v],
                    &mut self.expr_store,
                );
                // The first output of the circuit is the one returned.
                let res = circuit_outputs[0];
                self.expr_store.get_expr(res).clone()
            }
            _ => F::field_expr_to_expr(expr),
        }
    }

    /// Expands circuit-backed `OtherExpr` variants into the expression produced
    /// by running the corresponding circuit against this pass's expression store.
    ///
    /// * `ScalarGeneralCircuit(s, b, c, idx)` / `BaseGeneralCircuit(s, b, c, idx)`:
    ///   run general circuit `c` on the input lists `s` and `b` (memoized per
    ///   `(s, b, c)`; both variants share one cache entry) and return a clone of
    ///   the expression stored under element `idx` of the first, respectively
    ///   second, output list.
    /// * `BaseArithmeticCircuit(v, c, idx)`: runs base circuit `c` on `v`
    ///   (memoized per `(v, c)`) and returns a clone of the expression stored
    ///   under output `idx`.
    /// * Everything else is returned unchanged, wrapped in `Expr::Other`.
    ///
    /// When `idx` is out of range, a constant zero expression of the matching
    /// field is returned (`ScalarField::ZERO` for the scalar variant,
    /// `BaseField::ZERO` for the two base variants).
    fn build_general_circuits(&mut self, expr: OtherExpr<usize>) -> Expr<usize> {
        match expr {
            OtherExpr::ScalarGeneralCircuit(s, b, c, idx) => {
                // Key is built once and moved into the cache on a miss.
                let key = (s, b, c);
                let output = match self.general_circuit_cache.get(&key) {
                    Some((scalar_outputs, _)) => scalar_outputs.get(idx).copied(),
                    None => {
                        let outputs =
                            c.to_circuit().run_usize(&key.0, &key.1, &mut self.expr_store);
                        // Read the requested id first so the result can be moved
                        // into the cache without cloning both output lists.
                        let output = outputs.0.get(idx).copied();
                        self.general_circuit_cache.insert(key, outputs);
                        output
                    }
                };
                match output {
                    Some(expr_id) => self.expr_store.get_expr(expr_id).clone(),
                    None => Expr::Scalar(FieldExpr::Val(ScalarField::ZERO)),
                }
            }
            OtherExpr::BaseGeneralCircuit(s, b, c, idx) => {
                // Key is built once and moved into the cache on a miss.
                let key = (s, b, c);
                let output = match self.general_circuit_cache.get(&key) {
                    Some((_, base_outputs)) => base_outputs.get(idx).copied(),
                    None => {
                        let outputs =
                            c.to_circuit().run_usize(&key.0, &key.1, &mut self.expr_store);
                        // Read the requested id first so the result can be moved
                        // into the cache without cloning both output lists.
                        let output = outputs.1.get(idx).copied();
                        self.general_circuit_cache.insert(key, outputs);
                        output
                    }
                };
                match output {
                    Some(expr_id) => self.expr_store.get_expr(expr_id).clone(),
                    None => Expr::Base(FieldExpr::Val(BaseField::ZERO)),
                }
            }
            OtherExpr::BaseArithmeticCircuit(v, c, idx) => {
                // Key is built once and moved into the cache on a miss.
                let key = (v, c);
                let output = match self.base_circuit_cache.get(&key) {
                    Some(outputs) => outputs.get(idx).copied(),
                    None => {
                        let outputs = c.to_circuit().run_usize(&key.0, &mut self.expr_store);
                        let output = outputs.get(idx).copied();
                        self.base_circuit_cache.insert(key, outputs);
                        output
                    }
                };
                match output {
                    Some(expr_id) => self.expr_store.get_expr(expr_id).clone(),
                    None => Expr::Base(FieldExpr::Val(BaseField::ZERO)),
                }
            }
            _ => Expr::Other(expr),
        }
    }
}
impl LocalCompilationPass for ArithmeticCircuitBuilder {
    /// Hands out mutable access to the IR builder owned by this pass.
    fn expr_store(&mut self) -> &mut IRBuilder {
        let Self { expr_store, .. } = self;
        expr_store
    }

    /// Rewrites a single expression by expanding any circuit it refers to.
    ///
    /// Scalar and base field expressions are handled by
    /// `build_arithmetic_circuits`, `Expr::Other` by `build_general_circuits`;
    /// any remaining variant is returned as is. The `_is_plaintext` flag is
    /// not consulted.
    fn transform(&mut self, expr: Expr<usize>, _is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::Other(other_expr) => self.build_general_circuits(other_expr),
            Expr::Base(base_expr) => self.build_arithmetic_circuits(base_expr),
            Expr::Scalar(scalar_expr) => self.build_arithmetic_circuits(scalar_expr),
            untouched => untouched,
        }
    }
}