use crate::ast::ExprKind::*;
use crate::ast::*;
use crate::util::SymbolGenerator;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
/// Common subexpression elimination for a top-level lambda.
///
/// Expressions that are not a `Lambda` are left untouched. For a lambda,
/// symbol names are first made unique with `uniquify`, repeated
/// subexpressions are then hoisted into `cse` let-bindings by `Cse::apply`,
/// and finally `inliner::inline_let` is run over the result.
///
/// # Panics
///
/// Panics if `uniquify` fails.
pub fn common_subexpression_elimination(expr: &mut Expr) {
    use super::inliner;

    let is_lambda = match expr.kind {
        Lambda { .. } => true,
        _ => false,
    };
    if !is_lambda {
        return;
    }

    expr.uniquify().unwrap();
    Cse::apply(expr);
    inliner::inline_let(expr);
}
/// State for one run of common subexpression elimination.
#[derive(Debug)]
struct Cse {
/// Produces the fresh `cse` symbols that name hoisted expressions.
sym_gen: SymbolGenerator,
/// Scope-id counter. `Cse::apply` resets it to 0 before each tree walk so
/// that the site-building and binding-generation walks number scopes alike.
counter: i32,
}
/// Eligibility test for common subexpression elimination.
trait UseCse {
    /// Returns `true` when `self` may be replaced by a shared `cse` binding.
    fn use_cse(&self) -> bool;
}

impl UseCse for Expr {
    fn use_cse(&self) -> bool {
        // Anything whose type contains a builder is never shared. Otherwise,
        // trivial expressions (literals, identifiers) and binding forms
        // (let, lambda) are excluded; everything else is a candidate.
        !self.ty.contains_builder()
            && match self.kind {
                Let { .. } | Literal(_) | Ident(_) | Lambda { .. } => false,
                _ => true,
            }
    }
}
/// A position in the expression tree, written as the path of scope ids
/// entered to reach it (outermost first).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Site {
    /// Scope ids from the outermost scope inward.
    site: Vec<i32>,
    /// Counter value attached by `SiteList::add_site`; `i32::max_value()`
    /// on a freshly created site.
    first_seen: i32,
}

impl Site {
    /// Returns the empty (root) site.
    fn new() -> Site {
        Site {
            site: Vec::new(),
            first_seen: i32::max_value(),
        }
    }

    /// Descends into the scope with id `index`.
    fn push(&mut self, index: i32) {
        self.site.push(index);
    }

    /// Leaves the innermost scope; does nothing on the root site.
    fn pop(&mut self) {
        let _ = self.site.pop();
    }

    /// Number of scopes on the path.
    fn depth(&self) -> usize {
        self.site.len()
    }

    /// Returns `true` when `other` is at or beneath `self`, i.e. `self`'s
    /// path is a prefix of `other`'s. `first_seen` plays no part.
    fn contains(&self, other: &Site) -> bool {
        other.site.starts_with(&self.site)
    }
}

/// All sites at which a single `cse` symbol is referenced.
///
/// `add_site` maintains the invariant that no stored site contains another,
/// so only the outermost sites are kept.
#[derive(Debug)]
struct SiteList {
    sites: Vec<Site>,
}

impl SiteList {
    /// Returns an empty list.
    fn new() -> SiteList {
        SiteList { sites: Vec::new() }
    }

    /// Returns `true` when some stored site contains `site`.
    fn contains(&self, site: &Site) -> bool {
        self.sites.iter().any(|stored| stored.contains(site))
    }

    /// Returns the first stored site containing `site`, with its position.
    fn get_with_index(&self, site: &Site) -> Option<(usize, &Site)> {
        self.sites
            .iter()
            .position(|stored| stored.contains(site))
            .map(|pos| (pos, &self.sites[pos]))
    }

    /// Removes the stored site at position `site`. List order is not kept.
    fn delete_index(&mut self, site: usize) {
        self.sites.swap_remove(site);
    }

    /// Records `new` unless it is already covered by a stored site.
    ///
    /// Any stored sites that `new` contains are absorbed into it. `new`
    /// inherits the smallest `first_seen` among them, or `counter` when it
    /// absorbed nothing. Returns that `first_seen`, or `None` if `new` was
    /// already covered and nothing changed.
    fn add_site(&mut self, mut new: Site, counter: i32) -> Option<i32> {
        if self.contains(&new) {
            return None;
        }
        let mut earliest = i32::max_value();
        let mut idx = 0;
        while idx < self.sites.len() {
            if new.contains(&self.sites[idx]) {
                // swap_remove moves a not-yet-checked element into `idx`,
                // so `idx` must not advance here.
                let absorbed = self.sites.swap_remove(idx);
                earliest = earliest.min(absorbed.first_seen);
            } else {
                idx += 1;
            }
        }
        new.first_seen = if earliest == i32::max_value() {
            counter
        } else {
            earliest
        };
        let first_seen = new.first_seen;
        self.sites.push(new);
        Some(first_seen)
    }
}
/// A hoisted binding waiting to be emitted: the `cse` symbol and its expression.
type Binding = (Symbol, Expr);
/// For every `cse` symbol, the outermost sites at which it is referenced.
type SiteMap = HashMap<Symbol, SiteList>;

impl Cse {
    /// Performs common subexpression elimination on `expr` in place.
    ///
    /// The pass runs in three phases:
    ///
    /// 1. `remove_common_subexpressions` replaces every subexpression accepted
    ///    by `use_cse` with an identifier. Expressions that compare equal share
    ///    one fresh `cse` symbol.
    /// 2. `build_site_map` walks the rewritten tree and records, for each
    ///    symbol, the outermost sites (paths of scope ids) that reference it.
    /// 3. `generate_bindings` walks the tree again and emits each symbol as a
    ///    `let` at the start of the scope named by its recorded site.
    ///
    /// `counter` is reset to 0 before phases 2 and 3 so both walks hand out
    /// the same scope ids. That is what makes the sites recorded in phase 2
    /// comparable with the current site during phase 3.
    pub fn apply(expr: &mut Expr) {
        let mut cse = Cse {
            sym_gen: SymbolGenerator::from_expression(expr),
            counter: 0,
        };
        let bindings = &mut HashMap::new();
        cse.remove_common_subexpressions(expr, bindings);
        // Invert to symbol -> expression for lookups while walking identifiers.
        let bindings = &mut bindings
            .drain()
            .map(|(k, v)| (v, k))
            .collect::<HashMap<Symbol, Expr>>();
        cse.counter = 0;
        let sites = &mut cse.build_site_map(expr, bindings);
        let generated = &mut HashSet::new();
        let stack = &mut vec![];
        cse.counter = 0;
        cse.generate_bindings(expr, bindings, generated, &mut Site::new(), stack, sites);
    }

    /// Rewrites `expr` with `transform_up`, replacing every subexpression that
    /// passes `use_cse` with an identifier of the same type.
    ///
    /// `bindings` maps each distinct (Hash/Eq-equal) expression to its symbol.
    /// A new symbol is drawn from `sym_gen` the first time an expression is seen.
    fn remove_common_subexpressions(
        &mut self,
        expr: &mut Expr,
        bindings: &mut HashMap<Expr, Symbol>,
    ) {
        expr.transform_up(&mut |ref mut e| {
            if !e.use_cse() {
                return None;
            }
            let e = e.take();
            let ty = e.ty.clone();
            let name = bindings
                .entry(e)
                .or_insert_with(|| self.sym_gen.new_symbol("cse"))
                .clone();
            Some(Expr {
                ty,
                kind: Ident(name),
                annotations: Annotations::new(),
            })
        });
    }

    /// Builds the site map for `expr`, starting from the root site.
    fn build_site_map(&mut self, expr: &Expr, bindings: &HashMap<Symbol, Expr>) -> SiteMap {
        let mut sites = SiteMap::new();
        self.build_site_map_helper(expr, bindings, &mut sites, &mut Site::new());
        sites
    }

    /// Walks `expr` and records where each binding's identifier occurs.
    ///
    /// Scope ids come from `self.counter`. A fresh id is assigned to each
    /// `If` branch, each `Lambda` body and each `Let` body; conditions and
    /// `Let` values stay in the enclosing scope. For an identifier naming a
    /// binding, the current site is offered to that symbol's `SiteList`.
    ///
    /// If the site is accepted, the bound expression is walked as well, so
    /// that identifiers inside it are recorded at this site. When the accepted
    /// site absorbed earlier ones, the walk restarts the counter at their
    /// `first_seen` value and restores it afterwards.
    fn build_site_map_helper(
        &mut self,
        expr: &Expr,
        bindings: &HashMap<Symbol, Expr>,
        site_map: &mut SiteMap,
        current_site: &mut Site,
    ) {
        let handled = match expr.kind {
            If {
                ref cond,
                ref on_true,
                ref on_false,
            } => {
                self.build_site_map_helper(cond, bindings, site_map, current_site);
                current_site.push(self.counter);
                self.counter += 1;
                self.build_site_map_helper(on_true, bindings, site_map, current_site);
                current_site.pop();
                current_site.push(self.counter);
                self.counter += 1;
                self.build_site_map_helper(on_false, bindings, site_map, current_site);
                current_site.pop();
                true
            }
            Lambda { ref body, .. } => {
                current_site.push(self.counter);
                self.counter += 1;
                self.build_site_map_helper(body, bindings, site_map, current_site);
                current_site.pop();
                true
            }
            Let {
                ref value,
                ref body,
                ..
            } => {
                self.build_site_map_helper(value, bindings, site_map, current_site);
                current_site.push(self.counter);
                self.counter += 1;
                self.build_site_map_helper(body, bindings, site_map, current_site);
                current_site.pop();
                true
            }
            Ident(ref name) if bindings.contains_key(name) => {
                let resolved = match site_map.entry(name.clone()) {
                    Entry::Vacant(ent) => {
                        let site_list = ent.insert(SiteList::new());
                        let result = site_list.add_site(current_site.clone(), self.counter);
                        debug_assert!(result == Some(self.counter));
                        result
                    }
                    Entry::Occupied(ref mut ent) => {
                        let site_list = ent.get_mut();
                        site_list.add_site(current_site.clone(), self.counter)
                    }
                };
                if let Some(first_seen) = resolved {
                    let bound = bindings.get(name).unwrap();
                    let new = first_seen == self.counter;
                    let current_counter = self.counter;
                    // Reuse the scope ids handed out when this binding was
                    // first walked; a brand-new site just keeps counting.
                    self.counter = first_seen;
                    self.build_site_map_helper(bound, bindings, site_map, current_site);
                    if !new {
                        self.counter = current_counter;
                    }
                }
                true
            }
            _ => false,
        };
        if !handled {
            for child in expr.children() {
                self.build_site_map_helper(child, bindings, site_map, current_site);
            }
        }
    }

    /// Re-inserts hoisted expressions as `let` bindings.
    ///
    /// Scopes are numbered in the same order as `build_site_map_helper`.
    /// Emission happens when an identifier names a binding and the current
    /// site lies inside one of that symbol's recorded sites. In that case:
    ///
    /// - the bound expression is cloned;
    /// - its own nested bindings are generated first;
    /// - the result is pushed onto the `stack` frame that belongs to the
    ///   recorded site;
    /// - the recorded site is deleted, so each site emits its binding once.
    fn generate_bindings(
        &mut self,
        expr: &mut Expr,
        bindings: &mut HashMap<Symbol, Expr>,
        generated: &mut HashSet<Symbol>,
        current_site: &mut Site,
        stack: &mut Vec<Vec<Binding>>,
        sites: &mut SiteMap,
    ) {
        let handled = match expr.kind {
            Lambda { ref mut body, .. } => {
                self.generate_bindings_scoped(
                    body,
                    bindings,
                    generated,
                    current_site,
                    stack,
                    sites,
                );
                true
            }
            Let {
                ref mut value,
                ref mut body,
                ..
            } => {
                self.generate_bindings(value, bindings, generated, current_site, stack, sites);
                self.generate_bindings_scoped(
                    body,
                    bindings,
                    generated,
                    current_site,
                    stack,
                    sites,
                );
                true
            }
            If {
                ref mut cond,
                ref mut on_true,
                ref mut on_false,
            } => {
                self.generate_bindings(cond, bindings, generated, current_site, stack, sites);
                self.generate_bindings_scoped(
                    on_true,
                    bindings,
                    generated,
                    current_site,
                    stack,
                    sites,
                );
                self.generate_bindings_scoped(
                    on_false,
                    bindings,
                    generated,
                    current_site,
                    stack,
                    sites,
                );
                true
            }
            Ident(ref sym) if bindings.contains_key(sym) => {
                if sites.get(sym).unwrap().contains(current_site) {
                    let mut value = bindings.get(sym).cloned().unwrap();
                    self.generate_bindings(
                        &mut value,
                        bindings,
                        generated,
                        current_site,
                        stack,
                        sites,
                    );
                    let site_list = sites.get_mut(sym).unwrap();
                    let (index_in_list, site_depth) =
                        match site_list.get_with_index(current_site) {
                            Some((index, site)) => (index, site.depth()),
                            None => unreachable!(),
                        };
                    // `stack` gains one frame per scope pushed onto
                    // `current_site` (see `generate_bindings_scoped`), so a
                    // site of depth d owns frame d - 1.
                    stack[site_depth - 1].push((sym.clone(), value));
                    site_list.delete_index(index_in_list);
                }
                true
            }
            _ => false,
        };
        if !handled {
            for child in expr.children_mut() {
                self.generate_bindings(child, bindings, generated, current_site, stack, sites);
            }
        }
    }

    /// Processes `expr` as a new scope.
    ///
    /// Allocates the next scope id and opens a fresh `stack` frame, then walks
    /// the scope. Finally it wraps `expr` in a `let` for every binding emitted
    /// into that frame.
    fn generate_bindings_scoped(
        &mut self,
        expr: &mut Expr,
        bindings: &mut HashMap<Symbol, Expr>,
        generated: &mut HashSet<Symbol>,
        current_site: &mut Site,
        stack: &mut Vec<Vec<Binding>>,
        sites: &mut SiteMap,
    ) {
        current_site.push(self.counter);
        self.counter += 1;
        stack.push(vec![]);
        self.generate_bindings(expr, bindings, generated, current_site, stack, sites);
        let binding_list = stack.pop().unwrap();
        let mut prev = expr.take();
        // A binding's own dependencies are pushed before it, so wrapping in
        // reverse makes the earliest-pushed binding the outermost `let`.
        for (sym, value) in binding_list.into_iter().rev() {
            // NOTE(review): nothing in this pass inserts into `generated`,
            // so this removal currently has no effect.
            generated.remove(&sym);
            prev = Expr::new_let(sym, value, prev).unwrap();
        }
        current_site.pop();
        *expr = prev;
    }
}
/// Runs `common_subexpression_elimination` on `input` and checks the result
/// against `expect` using the crate's shared `check_transform` helper.
#[cfg(test)]
fn check_cse(input: &str, expect: &str) {
    crate::tests::check_transform(input, expect, common_subexpression_elimination);
}
// A single repeated expression is hoisted into one binding.
#[test]
fn basic_test() {
    check_cse("|| (1+2) + (1+2)", "|| let cse = (1+2); cse + cse");
}

// Distinct repeated expressions each get their own binding.
#[test]
fn many_subexprs_test() {
    check_cse(
        "|| (1+2) + (1+2) + (3+4) + (3+4)",
        "|| let cse1 = (1+2); let cse2 = (3+4); cse1 + cse1 + cse2 + cse2",
    );
}

// Nested repeats produce bindings that refer to earlier bindings.
#[test]
fn nesting_test() {
    check_cse(
        "|| ((1+2) + (1+2)) + ((1+2) + (1+2))",
        "|| let cse1 = (1+2); let cse2 = (cse1+cse1); cse2 + cse2",
    );
}
// An expression appearing once in each branch is left alone.
#[test]
fn if_test() {
    let program = "|x: i32, v:vec[i32]| if(x>0, lookup(v,0L), lookup(v,0L))";
    check_cse(program, program);
}

// Repeats confined to separate branches get per-branch bindings.
#[test]
fn if_test_2() {
    let program = "|x:i32| if(x>0, (1+2) + (1+2), (1+2)+(1+2))";
    let expected = "|x:i32| if(x>0,
let cse1 = (1+2); cse1 + cse1,
let cse2 = (1+2); cse2 + cse2)";
    check_cse(program, expected);
}

// A use before the `if` hoists the binding above it.
#[test]
fn if_test_3() {
    let program = "|x:i32| (1+2) + if(x>0, (1+2) + (1+2), (1+2)+(1+2))";
    let expected = "|x:i32| let cse = (1+2);
cse + if(x>0, cse+cse, cse+cse)";
    check_cse(program, expected);
}

// A use after the `if` also hoists the binding above it.
#[test]
fn if_test_4() {
    let program = "|x:i32| if(x>0, (1+2) + (1+2), (1+2)+(1+2)) + (1+2)";
    let expected = "|x:i32| let cse = (1+2);
if(x>0, cse+cse, cse+cse) + cse";
    check_cse(program, expected);
}

// Only the branch that repeats the expression receives a binding.
#[test]
fn if_test_5() {
    let program = "|x:i32|
if(x > 0,
if (x > 1,
(1+2),
(2+3)
),
(1+2) + (1+2)
)";
    let expected = "|x:i32| if (x > 0,
if (x > 1, (1+2), (2+3)),
let cse = (1+2); cse + cse)";
    check_cse(program, expected);
}

// Two independent `if`s each bind their own copy.
#[test]
fn if_test_6() {
    let program = "|x:i32|
if(x > 0,
(1 + 2) + (1+2),
(2 + 3)
) +
if(x > 1,
(1 + 2) + (1+2),
(2 + 3)
)";
    let expected = "|x:i32|
if(x > 0,
let cse = (1+2); cse + cse,
(2+3)
) +
if(x > 1,
let cse = (1+2); cse + cse,
(2+3)
)";
    check_cse(program, expected);
}

// Nothing is repeated, so the program is unchanged.
#[test]
fn if_test_7() {
    let program = "|x:i32|
let c = (if (x>0, (1+2), 0));
if (c > 0,
c + 5,
(if (x>1, (2+3), 0))
)";
    check_cse(program, program);
}

// Pre-existing lets are re-placed into the branches that use them.
#[test]
fn if_test_8() {
    let program = "|x: i32|
let cse4 = (1+2);
let cse2 = if (x > 2, 1, 2);
let cse3 = if (x > 3, cse2, 2) + cse2;
let cse1 = if (x > 1, cse2, cse3);
if (x > 0, cse1, cse4)";
    let expected = "|x: i32|
let cse2 = if (x > 2, 1, 2);
if (x > 0,
if (x > 1,
cse2,
if (x > 3, cse2, 2) + cse2
),
(1+2)
)";
    check_cse(program, expected);
}
// A value used both outside and inside a loop body is bound outside.
#[test]
fn for_test() {
    check_cse(
        "|| result(for([(1+2)], merger[i32,+], |b,i,e| merge(b, e + (1+2))))",
        "|| let cse = (1+2); result(for([cse], merger[i32,+], |b,i,e| merge(b, e + cse)))",
    );
}

// A later use outside the loop still hoists the binding above the loop.
#[test]
fn for_test_2() {
    check_cse(
        "|| result(for([1], merger[i32,+], |b,i,e| merge(b, e + (1+2)))) + (1+2)",
        "|| let cse = (1+2); result(for([1], merger[i32,+], |b,i,e| merge(b, e + cse))) + cse",
    );
}

// Repeats only inside the loop body are bound inside the body.
#[test]
fn for_test_3() {
    check_cse(
        "|| result(for([1], merger[i32,+], |b,i,e| merge(b, e + (1+2) + (1+2))))",
        "|| result(for([1], merger[i32,+], |b,i,e| let cse = (1+2); merge(b, e + cse + cse)))",
    );
}

// An existing binding used in the loop is left where it is.
#[test]
fn for_test_4() {
    let program = "|| let x = (1+2); result(for([1], merger[i32,+], |b,i,e| merge(b, e + x + x)))";
    check_cse(program, program);
}
// Builder-typed expressions are never shared.
#[test]
fn builder_test() {
    let program = "|| {appender[i32], appender[i32]}";
    check_cse(program, program);
}

// Results of builders are not builder-typed and can be shared.
#[test]
fn builder_test_2() {
    check_cse(
        "|| {result(appender[i32]), result(appender[i32])}",
        "|| let cse = result(appender[i32]); {cse, cse}",
    );
}

// Structs containing builders are not shared either.
#[test]
fn builder_test_3() {
    let program = "|| {{appender[i32], appender[i32], 1}, {appender[i32], appender[i32], 1}}";
    check_cse(program, program);
}

// Non-builder parts inside builder-containing structs are still shared.
#[test]
fn builder_test_4() {
    check_cse(
        "|| {{appender[i32], appender[i32], (1+2)}, {appender[i32], appender[i32], (1+2)}}",
        "|| let cse = (1+2); {{appender[i32], appender[i32], cse}, {appender[i32], appender[i32], cse}}",
    );
}
// Plain aliases of an identifier are left as written.
#[test]
fn alias_test() {
    let program = "|x:i32| let a = x; let b = x; a + a + b + b";
    check_cse(program, program);
}

// A user-written let of a repeated expression is merged with the cse binding.
#[test]
fn let_test() {
    check_cse(
        "|x:i32| let a = (1+2); (1+2) + a + (1+2)",
        "|x:i32| let cse = (1+2); cse + cse + cse",
    );
}