use crate::core::{
actually_used_field::ActuallyUsedField,
bounds::IsBounds,
compile_passes::optimizer::Optimizer,
expressions::{
circuit::ArithmeticCircuitId,
expr::Expr,
field_expr::{
expr_lincomb,
FieldExpr,
FieldExpr::{LinComb, Val},
},
},
ir_builder::ExprStore,
};
use std::collections::BTreeMap;
impl Optimizer {
    /// Returns `Some(value)` when the bounds the expression store reports for
    /// expression `e` resolve to a single constant (via `as_constant`), `None` otherwise.
    fn is_constant<F: ActuallyUsedField>(&self, e: usize) -> Option<F> {
        self.expr_store.bounds(e).as_constant()
    }

    /// Builds a simplified expression equal to `sum(factor * e for (e, factor) in v) + c`.
    ///
    /// Simplifications applied:
    /// - a term whose expression is itself a `LinComb` with fewer than 16 terms is
    ///   flattened into this combination;
    /// - a term whose expression is a constant `Val` is folded into the constant part;
    /// - coefficients of repeated expressions are summed, and terms whose coefficient
    ///   ends up zero are dropped;
    /// - if no terms remain, the result is the constant `Val(c')`;
    /// - if the result is exactly `1 * e + 0`, the expression of `e` itself is returned;
    /// - otherwise a `LinComb` of the surviving terms is returned.
    ///
    /// When `is_result_plaintext` is true, operands are revealed in the expression
    /// store where the plaintext result determines them. That covers the single
    /// operand in the `1 * e` case, or every operand when exactly one of them is
    /// not already plaintext.
    fn merge_lincomb<F: ActuallyUsedField>(
        &mut self,
        v: Vec<(usize, F)>,
        c: F,
        is_result_plaintext: bool,
    ) -> Expr<usize> {
        /// Adds `factor` to the coefficient accumulated for `e`, inserting `e` if absent.
        fn add_expr<F: ActuallyUsedField>(coeffs: &mut BTreeMap<usize, F>, e: usize, factor: F) {
            coeffs
                .entry(e)
                .and_modify(|coeff| *coeff += factor)
                .or_insert(factor);
        }

        // A BTreeMap keeps terms ordered by expression id, so the produced
        // LinComb has a deterministic term order.
        let mut coeffs: BTreeMap<usize, F> = BTreeMap::new();
        let mut new_c = c;
        for (e, factor) in v {
            match F::expr_to_field_expr(self.expr_store.get_expr(e).clone()) {
                // Inline small nested linear combinations. The 16-term cap
                // presumably bounds growth from repeated inlining.
                Some(LinComb(v2, c2)) if v2.len() < 16 => {
                    for (e2, factor2) in v2 {
                        add_expr(&mut coeffs, e2, factor * factor2);
                    }
                    new_c += factor * c2;
                }
                // Constant operands fold straight into the constant part.
                Some(Val(c2)) => new_c += factor * c2,
                _ => add_expr(&mut coeffs, e, factor),
            }
        }

        // Drop terms whose coefficients cancelled out.
        let coeffs: Vec<(usize, F)> = coeffs
            .into_iter()
            .filter(|&(_, coeff)| coeff != F::ZERO)
            .collect();

        if coeffs.is_empty() {
            // Everything cancelled or was constant: the result is a plain value.
            F::field_expr_to_expr(Val(new_c))
        } else if new_c == F::ZERO && coeffs.len() == 1 && coeffs[0].1 == F::ONE {
            // The combination is exactly `1 * e`, so reuse `e`'s own expression.
            let e = coeffs[0].0;
            if is_result_plaintext {
                self.expr_store.reveal(e);
            }
            self.expr_store.get_expr(e).clone()
        } else {
            if is_result_plaintext {
                // If only one operand is not plaintext, its value is determined by
                // the plaintext result and the other operands: every coefficient is
                // nonzero after the filter above. So all operands are revealed.
                let n_non_plaintext = coeffs
                    .iter()
                    .filter(|&&(e, _)| !self.expr_store.get_is_plaintext(e))
                    .count();
                if n_non_plaintext == 1 {
                    for &(e, _) in &coeffs {
                        self.expr_store.reveal(e);
                    }
                }
            }
            F::field_expr_to_expr(LinComb(coeffs, new_c))
        }
    }

    /// Simplifies one field expression and returns an equivalent `Expr`.
    ///
    /// The operands of `expr` are ids of expressions held in `self.expr_store`.
    ///
    /// `is_plaintext` states whether the value of `expr` is plaintext. When it is,
    /// operands whose values can be recovered from that result are revealed in the
    /// store. This happens in the `Mul` arm and through `merge_lincomb`.
    ///
    /// Rewrites performed:
    /// - `LinComb`, `Add`, `Sub` and `Neg` are normalised through `merge_lincomb`.
    /// - `Mul` by a constant operand becomes a linear combination. Other products get
    ///   their operands ordered by id, smaller id first.
    /// - `Equal(a, b)` becomes a comparison of `a - b` (or of a simpler expression with
    ///   the same zero set) against zero. It is folded to a constant when decidable.
    ///   The rewritten comparison is only emitted when `is_plaintext` holds or a
    ///   simpler operand was found; otherwise the original `Equal` is returned.
    /// - `Gt(a, a)` folds to 0 and `Ge(a, a)` folds to 1.
    /// - `Reveal(e)` and `Bounds(e, _)` return the expression of `e`.
    /// - `Where(cond, a, b)` is lowered to `cond * (a - b) + b`.
    /// - `KeepLsBits` is dropped when applying it would not change the operand's bounds.
    /// - A `Div` subcircuit whose divisor is a `Sqrt` subcircuit is fused into `DivSqrt`.
    ///
    /// Any other expression is returned unchanged.
    pub fn optimize_field_expr<F: ActuallyUsedField>(
        &mut self,
        expr: FieldExpr<F, usize>,
        is_plaintext: bool,
    ) -> Expr<usize> {
        use FieldExpr::*;
        match expr {
            LinComb(v, c) => self.merge_lincomb(v, c, is_plaintext),
            Mul(e1, e2) => {
                // With a plaintext product and a plaintext factor whose bounds
                // exclude zero, the other factor equals product / factor, so it is
                // revealed.
                if is_plaintext
                    && self.expr_store.get_is_plaintext(e1)
                    && !F::bounds_to_field_bounds(*self.expr_store.get_bounds(e1)).contains(F::ZERO)
                {
                    self.expr_store.reveal(e2);
                }
                if is_plaintext
                    && self.expr_store.get_is_plaintext(e2)
                    && !F::bounds_to_field_bounds(*self.expr_store.get_bounds(e2)).contains(F::ZERO)
                {
                    self.expr_store.reveal(e1);
                }
                if let Some(c) = self.is_constant(e1) {
                    self.merge_lincomb(vec![(e2, c)], F::ZERO, is_plaintext)
                } else if let Some(c) = self.is_constant(e2) {
                    self.merge_lincomb(vec![(e1, c)], F::ZERO, is_plaintext)
                } else if e1 > e2 {
                    // Canonical operand order for the commutative product.
                    F::field_expr_to_expr(Mul(e2, e1))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Equal(e1, e2) => {
                // Order operands by id unless this is already a comparison with zero.
                let (e1, e2) = if self.is_constant(e2) == Some(F::ZERO) {
                    (e1, e2)
                } else {
                    (e1.min(e2), e1.max(e2))
                };
                let sub = self.merge_lincomb(expr_lincomb!((e1, 1), (e2, -1)), F::ZERO, false);
                let sub_field_expr = F::expr_to_field_expr(sub.clone());
                if let Some(Val(val)) = sub_field_expr.as_ref() {
                    // The difference is a known constant, so equality is decided.
                    F::field_expr_to_expr(Val((*val == F::ZERO).into()))
                } else {
                    // Look for an existing expression that is zero exactly when
                    // the difference is zero.
                    let e_sub = if let Some(Mul(f1, f2)) = sub_field_expr.as_ref() {
                        let f1_nonzero =
                            !F::bounds_to_field_bounds(*self.expr_store.get_bounds(*f1))
                                .contains(F::ZERO);
                        let f2_nonzero =
                            !F::bounds_to_field_bounds(*self.expr_store.get_bounds(*f2))
                                .contains(F::ZERO);
                        if f1_nonzero && f2_nonzero {
                            // A product of two nonzero field elements is nonzero,
                            // so the operands are never equal.
                            return F::field_expr_to_expr(Val(F::ZERO));
                        } else if f1_nonzero {
                            Some(*f2)
                        } else if f2_nonzero {
                            Some(*f1)
                        } else {
                            None
                        }
                    } else if let Some(LinComb(v, c)) = sub_field_expr.as_ref() {
                        // `k * x` with no constant term (k is nonzero after
                        // merge_lincomb's filtering) is zero exactly when `x` is.
                        if *c == F::ZERO && v.len() == 1 {
                            Some(v[0].0)
                        } else {
                            None
                        }
                    } else {
                        None
                    };
                    if is_plaintext || e_sub.is_some() {
                        let e_sub = e_sub.unwrap_or_else(|| self.expr_store.new_expr(sub));
                        let e_zero = self
                            .expr_store
                            .new_expr(F::field_expr_to_expr(Val(F::ZERO)));
                        F::field_expr_to_expr(Equal(e_sub, e_zero))
                    } else {
                        F::field_expr_to_expr(expr)
                    }
                }
            }
            Gt(e1, e2, _) => {
                if e1 == e2 {
                    F::field_expr_to_expr(Val(F::ZERO))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Ge(e1, e2, _) => {
                if e1 == e2 {
                    F::field_expr_to_expr(Val(1.into()))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Add(e1, e2) => {
                self.merge_lincomb(expr_lincomb!((e1, 1), (e2, 1)), F::ZERO, is_plaintext)
            }
            Sub(e1, e2) => {
                self.merge_lincomb(expr_lincomb!((e1, 1), (e2, -1)), F::ZERO, is_plaintext)
            }
            Neg(e) => self.merge_lincomb(expr_lincomb!((e, -1)), F::ZERO, is_plaintext),
            // Both wrappers are replaced by the wrapped operand's expression.
            Reveal(e) | Bounds(e, _) => self.expr_store.get_expr(e).clone(),
            Where(e1, e2, e3) => {
                // where(c, a, b) = c * (a - b) + b, which selects `a` when c == 1 and
                // `b` when c == 0 (assumes `c` is a 0/1 condition — TODO confirm).
                let sub = self.merge_lincomb(expr_lincomb!((e2, 1), (e3, -1)), F::ZERO, false);
                let e_sub = self.expr_store.new_expr(sub);
                let e_prod = self.expr_store.push_field(Mul::<F, _>(e1, e_sub));
                self.merge_lincomb(expr_lincomb!((e_prod, 1), (e3, 1)), F::ZERO, is_plaintext)
            }
            KeepLsBits(e, c, signed_output) => {
                // If truncation would leave the operand's bounds unchanged, the
                // operation is a no-op and the operand is returned directly.
                let bounds = F::bounds_to_field_bounds(*self.expr_store.get_bounds(e));
                if bounds == KeepLsBits(bounds, c, signed_output).bounds() {
                    self.expr_store.get_expr(e).clone()
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            SubCircuit(v, ArithmeticCircuitId::Div, 0) => {
                assert_eq!(v.len(), 2);
                let e_a = v[0];
                let e_b = v[1];
                let expr_b = self.expr_store.get_expr(e_b).clone();
                // Fuse `a / sqrt(b)` into a single DivSqrt(a, b) subcircuit.
                let res = if let Some(SubCircuit(v2, ArithmeticCircuitId::Sqrt, 0)) =
                    F::expr_to_field_expr(expr_b)
                {
                    SubCircuit(vec![e_a, v2[0]], ArithmeticCircuitId::DivSqrt, 0)
                } else {
                    SubCircuit(v, ArithmeticCircuitId::Div, 0)
                };
                F::field_expr_to_expr(res)
            }
            y => F::field_expr_to_expr(y),
        }
    }
}