use std::collections::HashMap;
use anyhow::Result;
use disjoint::DisjointSetVec;
use crate::pcl::{
Module,
expr::traits::ConstraintExpr,
opt::{MutOptResult, MutOptimizer, OptResult, PassError},
stmt::traits::{ConstraintLike as _, ExprArgs as _},
vars::{VarKind, VarStr},
};
/// Optimization pass that merges variables linked by variable-to-variable
/// equality constraints.
///
/// Temporaries in each class of equal variables are renamed to a single
/// leader (a non-temporary when one exists, preferring an input over an
/// output). Trivial self-equalities left behind by the renaming (`v == v`)
/// are then removed.
#[derive(Default, Debug)]
pub struct ConsolidateVarNamesPass;
/// Maps a variable name to the name it will be renamed to.
type RenameSet = HashMap<VarStr, VarStr>;
/// Working state of the variable-consolidation pass for a single module.
struct PassImpl<'a, K: VarKind> {
    /// The module being rewritten in place.
    module: &'a mut Module<K>,
}
/// Errors raised by the variable-consolidation pass.
#[derive(Debug, thiserror::Error)]
enum RenamePassError {
    /// A member of an equivalence class could not be resolved in the module's
    /// variable environment. `var` is `None` when the name itself could not be
    /// retrieved from the equivalence-class set.
    #[error("Variable name not in environment: {var:?}")]
    VarNotFound { var: Option<VarStr> },
}
impl<'m, K: VarKind + Copy> PassImpl<'m, K> {
    /// Iterates over the constraint expressions of the module's statements that
    /// are equalities (`is_eq()`).
    fn find_eq_constraint_exprs(&self) -> impl Iterator<Item = &dyn ConstraintExpr> {
        self.module
            .stmts()
            .iter()
            .filter_map(|stmt| stmt.constraint_expr())
            .filter(|c| c.is_eq())
    }

    /// Partitions the variables that appear in variable-to-variable equalities
    /// into equivalence classes.
    ///
    /// For every equality whose left- and right-hand sides both have a variable
    /// name, the two names are joined in a disjoint-set forest. Equalities where
    /// either side has no variable name are ignored. Each distinct name is pushed
    /// into the forest exactly once; the `seen` map remembers its index.
    fn compute_eqv_classes(&self) -> DisjointSetVec<VarStr> {
        let (set, _) = self
            .find_eq_constraint_exprs()
            .filter_map(|c| Some((c.lhs().var_name()?.clone(), c.rhs().var_name()?.clone())))
            .fold(
                (
                    DisjointSetVec::<VarStr>::new(),
                    HashMap::<VarStr, usize>::new(),
                ),
                |(mut set, mut seen), (lhs, rhs)| {
                    let lhs = *seen.entry(lhs.clone()).or_insert_with(|| set.push(lhs));
                    let rhs = *seen.entry(rhs.clone()).or_insert_with(|| set.push(rhs));
                    set.join(lhs, rhs);
                    (set, seen)
                },
            );
        set
    }

    /// Resolves each index of `class` to its variable name in `set`, paired with
    /// the `K` key that the module's variable environment holds for that name.
    ///
    /// # Errors
    ///
    /// Returns a `PassError` wrapping `RenamePassError::VarNotFound` if an index
    /// is not present in `set` (reported with `var: None`) or if a name has no
    /// entry in the environment.
    fn find_vars<'a>(
        &self,
        class: &[usize],
        set: &'a DisjointSetVec<VarStr>,
    ) -> OptResult<Vec<(K, &'a VarStr)>> {
        class
            .iter()
            .map(|idx| set.get(*idx))
            .map(|var| {
                var.and_then(|var| self.module.vars().lookup_key(var).map(|k| (*k, var)))
                    .ok_or_else(|| {
                        PassError::new(RenamePassError::VarNotFound { var: var.cloned() })
                    })
            })
            .collect::<Result<Vec<_>, _>>()
    }

    /// Returns the members of `vars` whose kind is temporary, preserving order.
    fn collect_temps<'a>(&self, vars: &[(K, &'a VarStr)]) -> Vec<(K, &'a VarStr)> {
        vars.iter()
            .copied()
            .filter(|(k, _)| k.is_temp())
            .collect::<Vec<_>>()
    }

    /// Picks the non-temporary member that the class's temporaries will be
    /// renamed to.
    ///
    /// The first non-temporary is chosen. A later input replaces the current
    /// choice when that choice is an output. Returns `None` if every member is
    /// temporary.
    fn select_leader<'a>(&self, vars: &[(K, &'a VarStr)]) -> Option<(K, &'a VarStr)> {
        vars.iter()
            .filter(|(k, _)| !k.is_temp())
            .copied()
            .fold(None, |leader, (k, v)| match leader {
                // The first non-temporary becomes the provisional leader.
                None => Some((k, v)),
                // Prefer an input over an output.
                Some((leader_k, _)) if leader_k.is_output() && k.is_input() => Some((k, v)),
                keep => keep,
            })
    }

    /// Computes the rename pairs `(from, to)` contributed by one equivalence class.
    ///
    /// The class leader is chosen by `select_leader`. If every member is a
    /// temporary, the first temporary is used instead. Every temporary other than
    /// the leader is mapped to the leader. Non-temporary members are never renamed,
    /// so equalities between two distinct non-temporaries are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `class` is empty.
    ///
    /// # Errors
    ///
    /// Propagates errors from `find_vars`.
    fn handle_eqv_class(
        &self,
        class: Vec<usize>,
        set: &DisjointSetVec<VarStr>,
    ) -> OptResult<impl Iterator<Item = (VarStr, VarStr)>> {
        assert!(!class.is_empty());
        let vars = self.find_vars(&class, set)?;
        let temps = self.collect_temps(&vars);
        let leader = self.select_leader(&vars).or_else(|| temps.first().copied());
        // Pair every temporary with the (single) leader; an absent leader yields no pairs.
        Ok(temps
            .into_iter()
            .zip(leader.into_iter().cycle())
            .filter(|(to_rename, leader)| to_rename != leader)
            .map(|((_, renamed), (_, with))| (renamed.clone(), with.clone()))
            .inspect(|(renamed, with)| log::debug!("Variable {renamed} will be renamed to {with}")))
    }

    /// Builds the full rename map for the module.
    ///
    /// For every class of variables linked by equalities, each temporary
    /// (other than the leader) is mapped to its class leader.
    ///
    /// # Errors
    ///
    /// Propagates errors from `handle_eqv_class`.
    fn compute_rename_set(&self) -> OptResult<RenameSet> {
        let ec = self.compute_eqv_classes();
        let mut rename_set = RenameSet::new();
        for class in ec.indices().sets() {
            rename_set.extend(self.handle_eqv_class(class, &ec)?);
        }
        Ok(rename_set)
    }

    /// Applies `rename_set` to every argument of every statement in the module.
    ///
    /// Arguments for which `renamed` yields `None` are written back as a clone
    /// of the original.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `replace_arg`. Statements processed
    /// before the failure remain rewritten.
    fn rename_stmts(&mut self, rename_set: &RenameSet) -> MutOptResult {
        for stmt in self.module.stmts_mut() {
            for (idx, new_arg) in stmt
                .args()
                .iter()
                // Clone lazily: only arguments that were not renamed need a copy.
                .map(|expr| expr.renamed(rename_set).unwrap_or_else(|| expr.clone()))
                .enumerate()
            {
                stmt.replace_arg(idx, new_arg)?;
            }
        }
        Ok(())
    }

    /// Removes constraint statements of the form `v == v`.
    ///
    /// Such statements appear, for example, when both sides of an equality are
    /// renamed to the same leader by `rename_stmts`. Only equalities are
    /// removed. A constraint relating a variable to itself with any other
    /// operator (e.g. `v != v` or `v < v`) does not hold in general, and deleting
    /// it would change the module's meaning.
    fn remove_tautos(&mut self) {
        /// Returns `true` when `expr` is an equality whose two sides name the
        /// same variable.
        fn is_tauto(expr: &dyn ConstraintExpr) -> bool {
            if !expr.is_eq() {
                return false;
            }
            match (expr.lhs().var_name(), expr.rhs().var_name()) {
                (Some(lhs), Some(rhs)) => lhs == rhs,
                _ => false,
            }
        }
        self.module.remove_stmt_if(|stmt| {
            stmt.constraint_expr()
                .map(is_tauto)
                .inspect(|remove| {
                    if *remove {
                        log::debug!("Removing {stmt:?}")
                    }
                })
                .unwrap_or_default()
        })
    }
}
impl<K: VarKind + Copy> MutOptimizer<Module<K>> for ConsolidateVarNamesPass {
    /// Consolidates variable names in `module` in place.
    ///
    /// The phases run in order:
    /// 1. Compute the rename map from the module's equality constraints.
    /// 2. Apply the map to every statement argument.
    /// 3. Drop the statements that `remove_tautos` identifies as trivial.
    ///
    /// # Errors
    ///
    /// Fails if the rename map cannot be computed or an argument cannot be
    /// replaced. In the latter case the module may already be partially
    /// rewritten.
    fn optimize(&mut self, module: &mut Module<K>) -> MutOptResult {
        let mut state = PassImpl { module };
        let renames = state.compute_rename_set()?;
        // Tautology removal only runs once every rename has succeeded.
        state
            .rename_stmts(&renames)
            .map(|()| state.remove_tautos())
    }
}