use std::collections::{HashMap, HashSet};
use haloumi_ir::Felt;
use crate::pcl::{
Module,
expr::{Expr, traits::ConstraintExpr},
opt::{MutOptResult, MutOptimizer},
stmt::traits::{CallLike as _, ConstraintLike as _, ExprArgs as _, MaybeCallLike as _},
vars::{VarKind, VarStr},
};
/// Optimization pass that propagates constants pinned by equality constraints.
///
/// A temporary variable qualifies when both of these hold:
/// - equality constraints equate it with exactly one distinct constant;
/// - it is not an output of a call statement.
///
/// Its occurrences in statement arguments are replaced by that constant. Afterwards,
/// constant-vs-constant equality constraints that hold trivially are removed.
#[derive(Debug, Default)]
pub struct ReplaceKnownConstsPass;
/// Variables selected for substitution, each mapped to the single constant
/// that replaces it.
type ReplacementSet = HashMap<VarStr, Felt>;

/// Holds the module that `ReplaceKnownConstsPass` is rewriting in place.
struct PassImpl<'m, K>
where
    K: VarKind,
{
    /// The module whose statements are inspected and rewritten.
    module: &'m mut Module<K>,
}
impl<'m, K: VarKind + Copy> PassImpl<'m, K> {
    /// Returns the constraint expressions of the module's statements that are
    /// equality constraints (other constraint kinds are filtered out).
    fn find_eq_constraint_exprs(&self) -> impl Iterator<Item = &dyn ConstraintExpr> {
        self.module
            .stmts()
            .iter()
            .filter_map(|stmt| stmt.constraint_expr())
            .filter(|c| c.is_eq())
    }

    /// Collects every variable that appears among the outputs of a call
    /// statement. These variables are never chosen for substitution.
    fn find_call_output_vars(&self) -> HashSet<VarStr> {
        self.module
            .stmts()
            .iter()
            .flat_map(|stmt| stmt.as_call())
            .flat_map(|c| c.outputs().to_vec())
            .collect()
    }

    /// Checks whether the equality `lhs == rhs` pins a variable to a constant.
    ///
    /// Matches when all of these hold:
    /// - `rhs` is a constant;
    /// - `lhs` is a variable whose key in the module's variable table is a temporary;
    /// - that variable is not in `call_output_vars`.
    ///
    /// Returns the variable together with the constant on a match, `None` otherwise.
    /// Callers try both operand orders.
    fn try_find_pattern(
        &self,
        lhs: Expr,
        rhs: Expr,
        call_output_vars: &HashSet<VarStr>,
    ) -> Option<(VarStr, Felt)> {
        let f = rhs.as_const()?;
        lhs.var_name()
            .filter(|var| {
                self.module
                    .vars()
                    .lookup_key(var)
                    .is_some_and(|k| k.is_temp())
                    && !call_output_vars.contains(var)
            })
            .map(|var| (var.clone(), f))
    }

    /// Builds the variable -> constant substitution map from the module's
    /// equality constraints.
    ///
    /// A variable is included only if every matching constraint equates it with
    /// the same constant. A variable equated with two or more distinct
    /// constants is left untouched rather than arbitrarily picking one value.
    fn collect_replacements(&self) -> ReplacementSet {
        let output_vars = self.find_call_output_vars();

        // Gather every distinct constant each candidate variable is equated with.
        let mut candidates: HashMap<VarStr, HashSet<Felt>> = HashMap::new();
        for c in self.find_eq_constraint_exprs() {
            let found = self
                .try_find_pattern(c.lhs(), c.rhs(), &output_vars)
                .or_else(|| self.try_find_pattern(c.rhs(), c.lhs(), &output_vars));
            if let Some((var, felt)) = found {
                candidates.entry(var).or_default().insert(felt);
            }
        }

        // Keep only variables with exactly one distinct constant; no unwrap needed.
        candidates
            .into_iter()
            .filter_map(|(var, values)| {
                let mut values = values.into_iter();
                match (values.next(), values.next()) {
                    (Some(felt), None) => Some((var, felt)),
                    _ => None,
                }
            })
            .collect()
    }

    /// Rewrites every statement argument for which `replaced_by_const` yields a
    /// replacement under `replacement_set`.
    ///
    /// # Errors
    /// Returns the first error produced by `replace_arg`. Replacements applied
    /// before the failure (in earlier statements and earlier arguments of the
    /// same statement) are not rolled back.
    fn replace_stmts(&mut self, replacement_set: &ReplacementSet) -> MutOptResult {
        for stmt in self.module.stmts_mut() {
            stmt.args()
                .iter()
                .enumerate()
                .filter_map(|(idx, expr)| {
                    expr.replaced_by_const(replacement_set)
                        .map(|expr| (idx, expr))
                })
                .try_for_each(|(idx, new_arg)| stmt.replace_arg(idx, new_arg))?;
        }
        Ok(())
    }

    /// Removes equality constraints whose two sides are the same constant,
    /// logging each removed statement at debug level.
    fn remove_tautos(&mut self) {
        /// `true` only for an equality constraint between two equal constants.
        fn is_tauto(expr: &dyn ConstraintExpr) -> bool {
            // Only equality constraints qualify. The module can hold other
            // constraint kinds (`find_eq_constraint_exprs` has to filter them
            // out), and for those, equal constant operands do not imply the
            // constraint holds. For example, a `!=` or strict-ordering
            // constraint, if the IR has one, would be false. Dropping such a
            // constraint would silently weaken the module.
            expr.is_eq()
                && match (expr.lhs().as_const(), expr.rhs().as_const()) {
                    (Some(lhs), Some(rhs)) => lhs == rhs,
                    _ => false,
                }
        }
        self.module.remove_stmt_if(|stmt| {
            stmt.constraint_expr()
                .map(is_tauto)
                .inspect(|remove| {
                    if *remove {
                        log::debug!("Removing {stmt:?}")
                    }
                })
                .unwrap_or_default()
        })
    }
}
impl<K> MutOptimizer<Module<K>> for ReplaceKnownConstsPass
where
    K: VarKind + Copy,
{
    /// Runs the pass on `module` in three steps:
    /// 1. Compute the variable -> constant substitutions.
    /// 2. Apply them to every statement's arguments.
    /// 3. Remove the constraints that `remove_tautos` classifies as tautologies.
    ///
    /// # Errors
    /// Propagates the first error raised while rewriting a statement argument.
    /// In that case the tautology-removal step does not run.
    fn optimize(&mut self, module: &mut Module<K>) -> MutOptResult {
        let mut state = PassImpl { module };
        let substitutions = state.collect_replacements();
        state.replace_stmts(&substitutions)?;
        state.remove_tautos();
        Ok(())
    }
}