use std::collections::{BTreeMap, BTreeSet};
use crate::{
parser::NUM_RESERVED,
types::{UnrolledExpr, UnrolledStatement, VarId},
};
/// Applies `let_useonce_once` repeatedly until the expression reaches a fixpoint.
///
/// One unnumbered pass is always run first. After that, every further pass is
/// announced with a `debug` log line numbered from 1. The loop stops at the first
/// pass whose result is equal to its input, and that (unchanged) expression is
/// returned.
///
/// NOTE(review): there is no iteration cap; termination relies on
/// `let_useonce_once` eventually producing no change.
pub fn let_useonce(input: UnrolledExpr) -> UnrolledExpr {
    let mut current = let_useonce_once(input);
    let mut pass = 1;
    loop {
        log::debug!("pass {}...", pass);
        let next = let_useonce_once(current.clone());
        if next == current {
            return current;
        }
        current = next;
        pass += 1;
    }
}
/// Performs a single round of "use-once" `let` inlining over `input`.
///
/// The round consists of two `structural_map` traversals.
///
/// 1. An analysis traversal returns every node unchanged and records three things:
///    - how many times each variable with id greater than `NUM_RESERVED` is read;
///    - which variables are the target of an `UnrolledStatement::Set`;
///    - the bound expression of every `let` binding.
///
///    If the same `VarId` is bound by several `Let`s, the last one visited wins.
///    This presumably relies on ids being unique after unrolling (TODO confirm).
/// 2. A rewrite traversal makes two kinds of change:
///    - It drops `let` bindings whose variable is neither read nor `Set`. A
///      resulting `Let` with no bindings and no statements collapses to its body.
///    - It replaces each read of a variable that is read at most
///      `INLINE_THRESHOLD` times and never `Set` with a copy of its bound
///      expression.
///
/// Both decisions use the counts gathered *before* rewriting. A binding whose
/// only use gets inlined in this round is therefore removed only on the next
/// round. Callers are expected to iterate to a fixpoint (as `let_useonce` does).
///
/// NOTE(review): inlining moves the bound expression to the use site. This is only
/// sound if no variable the expression reads is `Set` between binding and use. It
/// also requires that evaluating the expression at the use site (for example,
/// inside a loop body) is equivalent to evaluating it at the binding. Neither
/// condition is checked here; confirm against the language semantics.
fn let_useonce_once(input: UnrolledExpr) -> UnrolledExpr {
    // Variables read at most this many times are inlined.
    const INLINE_THRESHOLD: usize = 1;
    let mut varid_counts: BTreeMap<VarId, usize> = BTreeMap::new();
    let mut mutated_varids: BTreeSet<VarId> = BTreeSet::new();
    let mut varid_bindings: BTreeMap<VarId, UnrolledExpr> = BTreeMap::new();

    // Analysis traversal: the mappers hand every node back unchanged and only
    // collect usage information.
    let input = input.structural_map(
        &mut |expr| {
            match &expr {
                // Ids <= NUM_RESERVED are never counted, so they are never inlined
                // (and their bindings are kept only if they are `Set`).
                UnrolledExpr::Var(varid) => {
                    if varid > &NUM_RESERVED {
                        *varid_counts.entry(*varid).or_default() += 1;
                    }
                }
                UnrolledExpr::Let(bindings, _, _) => {
                    for (k, v) in bindings {
                        varid_bindings.insert(*k, v.clone());
                    }
                }
                _ => {}
            }
            expr
        },
        &mut |stmt| {
            if let UnrolledStatement::Set(vid, _) = &stmt {
                mutated_varids.insert(*vid);
            }
            stmt
        },
    );
    log::trace!("varid usages: {:#?}", varid_counts);

    // Rewrite traversal: prune dead bindings and inline rarely-read variables.
    input.structural_map(
        &mut |expr| match expr {
            UnrolledExpr::Let(mut bindings, stmt, body) => {
                // Keep a binding if its variable is read anywhere or assigned via `Set`.
                // (Membership is tested with `contains`; `BTreeSet::get` yields the
                // id itself, not a flag, and must not be added to the count.)
                bindings.retain(|(vid, _)| {
                    varid_counts.get(vid).copied().unwrap_or_default() > 0
                        || mutated_varids.contains(vid)
                });
                if bindings.is_empty() && stmt.is_empty() {
                    *body
                } else {
                    UnrolledExpr::Let(bindings, stmt, body)
                }
            }
            UnrolledExpr::Var(vid) => {
                let read_rarely = varid_counts
                    .get(&vid)
                    .map(|&cnt| cnt <= INLINE_THRESHOLD)
                    .unwrap_or(false);
                if read_rarely && !mutated_varids.contains(&vid) {
                    // Fall back to the variable itself if no binding was recorded
                    // (e.g. a function parameter or otherwise externally bound id).
                    varid_bindings
                        .get(&vid)
                        .cloned()
                        .unwrap_or_else(|| UnrolledExpr::Var(vid))
                } else {
                    UnrolledExpr::Var(vid)
                }
            }
            other => other,
        },
        &mut |stmt| stmt,
    )
}