use crate::core::{
actually_used_field::ActuallyUsedField,
compile_passes::compilation_pass::LocalCompilationPass,
expressions::{
bit_expr::BitExpr,
conversion_expr::{ConversionExpr, EdaBitId},
expr::Expr,
field_expr::FieldExpr,
},
ir_builder::{ExprStore, IRBuilder},
};
use rustc_hash::FxHashMap;
use std::marker::PhantomData;
/// Compilation pass that lowers composite expressions into simpler ones.
///
/// Field `LinComb`s (scalar or base field) are expanded into `Add`/`Mul`/`Neg`/`Val`
/// nodes, and generic bit `Binop`s described by a truth table are rewritten using only
/// `Xor` and `And`. Intermediate nodes created along the way are allocated in
/// `expr_store`.
#[derive(Default)]
pub struct ComplexExprExpander {
    /// Store into which intermediate expressions created during expansion are pushed.
    expr_store: IRBuilder,
}
impl ComplexExprExpander {
    /// Expands a field expression using this pass's store; delegates to the free
    /// function `expand_lincombs`.
    fn expand_lincombs<F: ActuallyUsedField>(&mut self, expr: FieldExpr<F, usize>) -> Expr<usize> {
        expand_lincombs(expr, &mut self.expr_store)
    }
    /// Rewrites a `BitExpr::Binop` into `Xor`/`And` nodes; any other bit expression is
    /// returned unchanged.
    ///
    /// Judging by the expansions below, `truth_table[2 * a + b]` is the result for
    /// `e1 = a`, `e2 = b`. Operands must not be `Not` nodes (checked in debug builds
    /// only).
    ///
    /// # Panics
    /// Panics if the table is true on `(false, false)`, if it is constant, or if it
    /// depends on only one operand; these operations are not handled here.
    fn expand_binop(&mut self, expr: BitExpr<usize>) -> BitExpr<usize> {
        use BitExpr::*;
        use Expr::Bit;
        let expr_store = &mut self.expr_store;
        if let Binop(e1, e2, truth_table) = expr {
            debug_assert!(!is_bit_expr_not(e1, expr_store), "a NOT has escaped");
            debug_assert!(!is_bit_expr_not(e2, expr_store), "a NOT has escaped");
            // Checked in release builds on purpose: every expansion below assumes entry 0
            // is false, so a table violating it would otherwise be silently miscompiled
            // (e.g. NOR, true only at entry 0, would come out as AND).
            assert!(
                !truth_table[0],
                "a truth table is true for (false, false). This is unacceptable"
            );
            let n_true_values = truth_table.iter().filter(|x| **x).count();
            if n_true_values == 0 || n_true_values == 4 {
                panic!("unexpected constant binary operation");
            } else if n_true_values == 2 {
                // Entry 0 is false, so the two true entries are either {1, 2} (XOR) or
                // {1, 3} / {2, 3}, which are simply `e2` / `e1`. Only XOR is expanded;
                // lowering a projection to XOR would compute a different function.
                if truth_table[1] && truth_table[2] {
                    bit_xor(e1, e2)
                } else {
                    panic!("unexpected binary operation that depends on a single operand");
                }
            } else if n_true_values == 3 {
                // OR: e1 | e2 == e2 ^ (e1 ^ (e1 & e2)).
                let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                let sum = expr_store.new_expr(Bit(bit_xor(e1, and)));
                bit_xor(e2, sum)
            } else if truth_table[1] {
                // Only (e1, e2) = (false, true) is true: !e1 & e2 == e2 ^ (e1 & e2).
                let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                bit_xor(e2, and)
            } else if truth_table[2] {
                // Only (e1, e2) = (true, false) is true: e1 & !e2 == e1 ^ (e1 & e2).
                let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                bit_xor(e1, and)
            } else {
                debug_assert!(truth_table[3], "this should be and");
                bit_and(e1, e2)
            }
        } else {
            expr
        }
    }
}
/// Lowers a field expression into an [`Expr`], expanding `FieldExpr::LinComb` into a
/// chain of binary field operations; any other field expression is converted as-is via
/// `F::field_expr_to_expr`.
///
/// For `LinComb(terms, c)`, where each term `(id, coeff)` stands for `coeff * id`:
/// - no terms: the result is `Val(c)`;
/// - one term and `c == 0`: `coeff == 1` returns a clone of the expression stored at
///   `id`, `coeff == -1` yields `Neg(id)`, otherwise `Mul(id, Val(coeff))`;
/// - otherwise each term is materialised (coefficients `1` and `-1` need no `Mul`), the
///   terms are summed left to right with `Add`, and a non-zero `c` is added last.
///
/// Every intermediate node is pushed into `expr_store`, and the returned expression
/// refers to those nodes by the ids the store hands back.
fn expand_lincombs<F: ActuallyUsedField>(
    expr: FieldExpr<F, usize>,
    expr_store: &mut impl ExprStore<F>,
) -> Expr<usize> {
    use FieldExpr::*;

    // Returns an id whose expression equals `coeff * id`, pushing a `Neg` node or a
    // `Val` + `Mul` pair into `store` unless `coeff` is one.
    fn push_scaled<G: ActuallyUsedField>(
        store: &mut impl ExprStore<G>,
        id: usize,
        coeff: G,
    ) -> usize {
        if coeff == G::ONE {
            return id;
        }
        if coeff == -G::ONE {
            return store.push_field(FieldExpr::Neg(id));
        }
        let coeff_id = store.push_field(FieldExpr::Val(coeff));
        store.push_field(FieldExpr::Mul(id, coeff_id))
    }

    let lowered = match expr {
        LinComb(terms, constant) => match &terms[..] {
            &[] => Val(constant),
            &[(id, coeff)] if constant == F::ZERO => {
                if coeff == F::ONE {
                    return expr_store.get_expr(id).clone();
                }
                if coeff == -F::ONE {
                    Neg(id)
                } else {
                    Mul(id, expr_store.push_field(Val(coeff)))
                }
            }
            &[(id, coeff)] => {
                // The constant is pushed before the scaled term.
                let constant_id = expr_store.push_field(Val(constant));
                Add(push_scaled(expr_store, id, coeff), constant_id)
            }
            many => {
                // Materialise every term first, then chain the additions.
                let scaled: Vec<usize> = many
                    .iter()
                    .map(|&(id, coeff)| push_scaled(expr_store, id, coeff))
                    .collect();
                let sum = scaled[2..]
                    .iter()
                    .fold(Add(scaled[0], scaled[1]), |acc, &next| {
                        Add(expr_store.push_field(acc), next)
                    });
                if constant == F::ZERO {
                    sum
                } else {
                    let sum_id = expr_store.push_field(sum);
                    Add(sum_id, expr_store.push_field(Val(constant)))
                }
            }
        },
        other => other,
    };
    F::field_expr_to_expr(lowered)
}
/// Builds `BitExpr::Xor` of the two operand ids with the smaller id first, so both
/// argument orders produce the same node.
fn bit_xor(e1: usize, e2: usize) -> BitExpr<usize> {
    let (lo, hi) = if e1 <= e2 { (e1, e2) } else { (e2, e1) };
    BitExpr::Xor(lo, hi)
}
/// Builds `BitExpr::And` of the two operand ids with the smaller id first, so both
/// argument orders produce the same node.
fn bit_and(e1: usize, e2: usize) -> BitExpr<usize> {
    BitExpr::And(e1.min(e2), e1.max(e2))
}
/// Returns whether the expression stored at id `e` is a bit `NOT`.
///
/// Only reads the store, so a shared borrow suffices; callers holding a
/// `&mut IRBuilder` can still pass it, since it coerces to `&IRBuilder`.
fn is_bit_expr_not(e: usize, expr_store: &IRBuilder) -> bool {
    matches!(expr_store.get_expr(e), Expr::Bit(BitExpr::Not(_)))
}
impl LocalCompilationPass for ComplexExprExpander {
    /// Gives the pass framework access to the store this pass allocates into.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.expr_store
    }
    /// Expands bit binops and scalar/base field linear combinations; every other
    /// expression is returned untouched. The plaintext flag is not consulted.
    fn transform(&mut self, expr: Expr<usize>, _is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::Bit(bit) => Expr::Bit(self.expand_binop(bit)),
            Expr::Scalar(field) => self.expand_lincombs(field),
            Expr::Base(field) => self.expand_lincombs(field),
            untouched => untouched,
        }
    }
}
#[derive(Default)]
pub struct ComplexExprExpanderTestnet {
inner: ComplexExprExpander,
eda_bit_id_to_da_bit_expr_ids_map: FxHashMap<EdaBitId, Vec<usize>>,
}
impl ComplexExprExpanderTestnet {
fn eda_bit_expr_id_to_da_bit_expr_ids<F: ActuallyUsedField>(
&self,
eda_bit_expr_id: usize,
) -> &Vec<usize> {
let Some(ConversionExpr::EdaBit(eda_bit_id, _, _)) =
F::expr_to_conversion_expr(self.inner.expr_store.get_expr(eda_bit_expr_id).clone())
else {
panic!("cannot expand eda_bits_to_da_bits");
};
let Some(v) = self.eda_bit_id_to_da_bit_expr_ids_map.get(&eda_bit_id) else {
panic!("cannot expand eda_bits_to_da_bits");
};
v
}
fn expand_eda_bits_to_da_bits<F: ActuallyUsedField>(
&mut self,
expr: ConversionExpr<F, usize>,
) -> Expr<usize> {
match expr {
ConversionExpr::EdaBit(eda_bit_id, width, _) => {
let v = (0..width)
.map(|_| {
self.inner.expr_store.new_expr(F::conversion_expr_to_expr(
ConversionExpr::EdaBit(EdaBitId::new(), 1, PhantomData),
))
})
.collect::<Vec<_>>();
self.eda_bit_id_to_da_bit_expr_ids_map.insert(eda_bit_id, v);
F::conversion_expr_to_expr(ConversionExpr::EdaBit(eda_bit_id, width, PhantomData))
}
ConversionExpr::BitFromEdaBit(eda_bit_expr_id, bit_idx) => {
let v = self.eda_bit_expr_id_to_da_bit_expr_ids::<F>(eda_bit_expr_id);
assert!(bit_idx < v.len());
F::conversion_expr_to_expr(ConversionExpr::BitFromEdaBit(v[bit_idx], 0))
}
ConversionExpr::ScalarFromEdaBit(eda_bit_expr_id) => {
let v = self
.eda_bit_expr_id_to_da_bit_expr_ids::<F>(eda_bit_expr_id)
.clone();
let v = v
.into_iter()
.enumerate()
.map(|(i, x)| {
let expr_id = self.inner.expr_store.new_expr(F::conversion_expr_to_expr(
ConversionExpr::ScalarFromEdaBit(x),
));
(expr_id, F::power_of_two(i))
})
.collect::<Vec<_>>();
self.inner.expand_lincombs(FieldExpr::LinComb(v, F::ZERO))
}
_ => F::conversion_expr_to_expr(expr),
}
}
}
impl LocalCompilationPass for ComplexExprExpanderTestnet {
    /// Exposes the wrapped expander's store to the pass framework.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.inner.expr_store
    }
    /// Handles base/scalar conversions with the edaBit rewriting and hands every other
    /// expression to the wrapped [`ComplexExprExpander`]. That expander's `transform`
    /// matches the same bit/scalar/base cases this pass would otherwise repeat.
    fn transform(&mut self, expr: Expr<usize>, is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::BaseConversion(conversion) => self.expand_eda_bits_to_da_bits(conversion),
            Expr::ScalarConversion(conversion) => self.expand_eda_bits_to_da_bits(conversion),
            rest => self.inner.transform(rest, is_plaintext),
        }
    }
}