use fxhash::FxHashMap;
use crate::{
error::Handler,
expr::{evaluate::Evaluator, BinaryOp, Expr, Immediate, UnaryOp},
predicate::{Contract, ExprKey},
span::{empty_span, Spanned},
types::Type,
};
/// Repeatedly runs constant folding and identity folding over `contract` until neither pass
/// reports a change.
///
/// In each iteration `fold_identities` only runs when `fold_consts` made no change. If the passes
/// are still reporting changes after 10,000 iterations, an internal error is emitted on `handler`
/// and folding stops.
pub(crate) fn const_folding(handler: &Handler, contract: &mut Contract) {
    const ITERATION_LIMIT: usize = 10_000;

    let mut iteration: usize = 0;
    while fold_consts(contract) || fold_identities(contract) {
        // Guard against passes that never converge.
        if iteration > ITERATION_LIMIT {
            handler.emit_internal_err("Infinite loop in const_folding", empty_span());
            break;
        }
        iteration += 1;
    }
}
/// Replaces every non-immediate expression that can be fully evaluated with an immediate.
///
/// For each predicate, the evaluator is seeded with the values of that predicate's variables
/// whose expression is already an immediate. Expressions that fail to evaluate are left
/// untouched. Replacements are applied only to the predicate they were found in.
///
/// Returns `true` if at least one expression in any predicate was replaced.
pub(crate) fn fold_consts(contract: &mut Contract) -> bool {
    let pred_keys: Vec<_> = contract.preds.keys().collect();
    let mut any_folded = false;

    for pred_key in pred_keys {
        // Seed values: variables whose defining expression is already a literal.
        let mut known_values: FxHashMap<String, Immediate> = FxHashMap::default();
        if let Some(pred) = contract.preds.get(pred_key) {
            for (_, var) in pred.variables() {
                if let Expr::Immediate { value, .. } = var.expr.get(contract) {
                    known_values.insert(var.name.clone(), value.clone());
                }
            }
        }

        let evaluator = Evaluator::from_values(contract, known_values);
        let mut folded: FxHashMap<ExprKey, (Expr, Type)> = FxHashMap::default();

        for expr_key in contract.exprs(pred_key) {
            let expr = expr_key.get(contract);
            if expr.is_immediate() {
                continue;
            }

            // A throwaway handler per evaluation: a failure here only means "not a constant"
            // and must not be reported to the user.
            if let Ok(imm) = evaluator.evaluate(expr_key, &Handler::default(), contract) {
                let replacement = Expr::Immediate {
                    value: imm.clone(),
                    span: expr.span().clone(),
                };
                folded.insert(expr_key, (replacement, expr_key.get_ty(contract).clone()));
            }
        }

        any_folded |= !folded.is_empty();

        for (old_key, (replacement, ty)) in folded {
            let new_key = contract.exprs.insert(replacement, ty);
            contract.replace_exprs(Some(pred_key), old_key, new_key);
        }
    }

    any_folded
}
pub(crate) fn fold_identities(contract: &mut Contract) -> bool {
let mut expr_keys_to_replace: Vec<(ExprKey, ExprKey)> = vec![];
let mut new_exprs_to_replace: Vec<(ExprKey, Expr, Type)> = vec![];
let mut did_fold_const: bool = false;
for pred_key in contract.preds.keys().collect::<Vec<_>>() {
for expr_key in contract.exprs(pred_key) {
if let Expr::BinaryOp { op, lhs, rhs, span } = expr_key.get(contract) {
let lhs_imm = if let Expr::Immediate { value, .. } = lhs.get(contract) {
Some(value)
} else {
None
};
let rhs_imm = if let Expr::Immediate { value, .. } = rhs.get(contract) {
Some(value)
} else {
None
};
let replacement_expr_key = match (op, lhs_imm, rhs_imm) {
(BinaryOp::LogicalAnd, Some(Immediate::Bool(true)), _) => Some(*rhs),
(BinaryOp::LogicalAnd, _, Some(Immediate::Bool(true))) => Some(*lhs),
(BinaryOp::LogicalAnd, Some(Immediate::Bool(false)), _) => Some(*lhs),
(BinaryOp::LogicalAnd, _, Some(Immediate::Bool(false))) => Some(*rhs),
(BinaryOp::LogicalOr, Some(Immediate::Bool(true)), _) => Some(*lhs),
(BinaryOp::LogicalOr, _, Some(Immediate::Bool(true))) => Some(*rhs),
(BinaryOp::LogicalOr, Some(Immediate::Bool(false)), _) => Some(*rhs),
(BinaryOp::LogicalOr, _, Some(Immediate::Bool(false))) => Some(*lhs),
(BinaryOp::Add, Some(Immediate::Int(0)), _) => Some(*rhs),
(BinaryOp::Add, _, Some(Immediate::Int(0))) => Some(*lhs),
(BinaryOp::Add, Some(Immediate::Real(0.0)), _) => Some(*rhs),
(BinaryOp::Add, _, Some(Immediate::Real(0.0))) => Some(*lhs),
(BinaryOp::Sub, _, Some(Immediate::Int(0))) => Some(*lhs),
(BinaryOp::Sub, _, Some(Immediate::Real(0.0))) => Some(*lhs),
(BinaryOp::Mul, Some(Immediate::Int(0)), _) => Some(*lhs),
(BinaryOp::Mul, _, Some(Immediate::Int(0))) => Some(*rhs),
(BinaryOp::Mul, Some(Immediate::Real(0.0)), _) => Some(*lhs),
(BinaryOp::Mul, _, Some(Immediate::Real(0.0))) => Some(*rhs),
(BinaryOp::Div, Some(Immediate::Int(0)), _) => Some(*lhs),
(BinaryOp::Div, _, Some(Immediate::Int(1))) => Some(*lhs),
(BinaryOp::Div, Some(Immediate::Real(0.0)), _) => Some(*lhs),
(BinaryOp::Div, _, Some(Immediate::Real(1.0))) => Some(*lhs),
(BinaryOp::Mod, _, Some(Immediate::Int(1))) => Some(*lhs),
(BinaryOp::Mod, Some(Immediate::Int(1)), _) => Some(*lhs),
_ => None,
};
if let Some(replacement_expr_key) = replacement_expr_key {
expr_keys_to_replace.push((expr_key, replacement_expr_key));
}
let replacement_expr = match (op, lhs_imm, rhs_imm) {
(BinaryOp::Sub, Some(Immediate::Int(0)), _) => {
let new_expr = Expr::UnaryOp {
op: UnaryOp::Neg,
expr: *rhs,
span: span.clone(),
};
Some((new_expr, lhs.get_ty(contract).clone()))
}
(BinaryOp::Sub, Some(Immediate::Real(0.0)), _) => {
let new_expr = Expr::UnaryOp {
op: UnaryOp::Neg,
expr: *rhs,
span: span.clone(),
};
Some((new_expr, lhs.get_ty(contract).clone()))
}
_ => None,
};
if let Some(replacement_expr) = replacement_expr {
new_exprs_to_replace.push((expr_key, replacement_expr.0, replacement_expr.1))
}
}
}
for (old_expr_key, new_expr, new_expr_ty) in &new_exprs_to_replace {
let new_expr_key = contract.exprs.insert(new_expr.clone(), new_expr_ty.clone());
expr_keys_to_replace.push((*old_expr_key, new_expr_key));
}
for (old_expr_key, new_expr_key) in &expr_keys_to_replace {
contract.replace_exprs(Some(pred_key), *old_expr_key, *new_expr_key);
}
if !expr_keys_to_replace.is_empty() {
did_fold_const = true;
}
expr_keys_to_replace = vec![];
}
did_fold_const
}