use fnv;
use crate::ast::ExprKind::*;
use crate::ast::*;
use fnv::FnvHashMap;
use std::mem;
#[cfg(test)]
use crate::tests::*;
/// Folds a field access on a struct literal into the selected element.
///
/// `GetField(MakeStruct([e0, e1, ...]), i)` is replaced by a clone of `ei`. Any other
/// expression is left untouched. The index is assumed to be in range (it is for
/// well-typed programs); an out-of-range index panics.
pub fn inline_get_field(expr: &mut Expr) {
    expr.transform(&mut |node| match node.kind {
        GetField {
            expr: ref target,
            index,
        } => match target.kind {
            MakeStruct { ref elems } => Some(elems[index as usize].clone()),
            _ => None,
        },
        _ => None,
    });
}
/// Rewrites a `For` loop over a single zipped iterator into a `For` loop with one
/// iterator per zipped vector, e.g. `for(zip(a, b), ...)` becomes `for(a, b, ...)`.
///
/// The bounds of the original iterator (`start`/`end`/`stride`, plus `shape`/`strides`
/// for multi-dimensional iterators) and its iterator kind are copied onto every new
/// iterator. A zip pairs its inputs element-by-element at the same index, so iterating
/// each input over the same index range visits exactly the elements the bounded zip
/// iteration would have visited. Dropping the bounds would instead widen a bounded
/// iteration to the full vectors.
///
/// Loops with more than one iterator, or whose only iterator is not over a `Zip`, are
/// left unchanged. The rewritten loop keeps the original type but gets fresh (empty)
/// annotations.
pub fn inline_zips(expr: &mut Expr) {
    expr.transform(&mut |e| {
        if let For {
            ref iters,
            ref builder,
            ref func,
        } = e.kind
        {
            if iters.len() == 1 {
                let first_iter = &iters[0];
                if let Zip { ref vectors } = first_iter.data.kind {
                    let new_iters = vectors
                        .iter()
                        .map(|v| Iter {
                            data: Box::new(v.clone()),
                            // Preserve the iteration bounds of the zipped iterator.
                            start: first_iter.start.clone(),
                            end: first_iter.end.clone(),
                            stride: first_iter.stride.clone(),
                            kind: first_iter.kind.clone(),
                            shape: first_iter.shape.clone(),
                            strides: first_iter.strides.clone(),
                        })
                        .collect::<Vec<_>>();
                    return Some(Expr {
                        ty: e.ty.clone(),
                        kind: For {
                            iters: new_iters,
                            builder: builder.clone(),
                            func: func.clone(),
                        },
                        annotations: Annotations::new(),
                    });
                }
            }
        }
        None
    });
}
/// Beta-reduces applications of lambda literals.
///
/// `(|p1, p2| body)(a1, a2)` is replaced by a copy of `body` in which each parameter is
/// substituted by the corresponding argument. Applications whose callee is not a
/// `Lambda` literal are left unchanged.
///
/// NOTE(review): parameters are substituted one after another. If `substitute` replaces
/// every free occurrence of a name (presumably it does), then an argument mentioning a
/// name equal to a *later* parameter is rewritten again. That is only safe when symbols
/// are unique, e.g. after `uniquify`. This pass does not enforce uniqueness itself, so
/// confirm that callers guarantee it.
pub fn inline_apply(expr: &mut Expr) {
    expr.transform(&mut |ref mut expr| {
        if let Apply {
            ref func,
            params: ref args,
        } = expr.kind
        {
            if let Lambda {
                ref params,
                ref body,
            } = func.kind
            {
                let mut new = *body.clone();
                for (param, arg) in params.iter().zip(args) {
                    new.substitute(&param.name, arg);
                }
                return Some(new);
            }
        }
        None
    });
}
/// Inlines `Let` bindings whose symbol is used at most once outside any function body.
///
/// Symbols are first made unique (panicking if `uniquify` fails). Then `count_symbols`
/// computes a weighted use count per bound symbol, and `inline_let_helper` performs the
/// substitution based on those counts.
pub fn inline_let(expr: &mut Expr) {
    expr.uniquify().unwrap();

    let mut symbol_usage = FnvHashMap::default();
    count_symbols(expr, &mut symbol_usage);
    trace!("Symbol count: {:?}", symbol_usage);

    inline_let_helper(expr, &mut symbol_usage)
}
/// Per-symbol bookkeeping for let-inlining. It records how heavily a `Let`-bound name is
/// used and, once its binding has been removed, holds the value waiting to be substituted
/// at the use site.
///
/// `Default` yields a zero count, zero nesting and no stashed value.
#[derive(Debug, Default)]
struct SymbolTracker {
    /// Weighted use count. A use while `loop_nest == 0` adds 1; a use nested inside a
    /// function body adds 3, which always exceeds the `<= 1` inlining threshold.
    count: i32,
    /// Number of enclosing function bodies (relative to where the symbol was bound) at
    /// the point currently being visited by `count_symbols`.
    loop_nest: i32,
    /// The bound value, moved here when its `Let` is removed and moved out again when the
    /// single use site is replaced.
    value: Option<Box<Expr>>,
}
/// Walks `expr` and records, for every `Let`-bound symbol, a weighted count of how often
/// it is referenced.
///
/// Each `Let` registers its name with a zeroed `SymbolTracker` before its value and body
/// are visited. An `Ident` naming a registered symbol adds 1 to its count when that
/// symbol's `loop_nest` is 0, and 3 otherwise. The larger weight keeps
/// `inline_let_helper`, which only inlines symbols with `count <= 1`, from moving a value
/// into a loop body or other function.
///
/// Function bodies are reached through the `func` of `For`, the `update_func` of
/// `Iterate`, the `cmpfunc` of `Sort` and the `func` of `Apply`. While one is visited,
/// `loop_nest` is raised by one for every symbol registered so far. Symbols first
/// registered *inside* that function start at `loop_nest == 0`, so a single use within
/// the same body still counts as 1. Those symbols end at -1 afterwards, which is
/// presumably harmless because unique names cannot be referenced outside that function.
///
/// Debug builds panic if a name is bound twice, so callers must run `uniquify` first
/// (as `inline_let` does).
///
/// NOTE(review): the generic child walk skips every `Lambda` child. A lambda that is not
/// reached through one of the four arms above (e.g. one bound by a `Let`) therefore has
/// its uses left uncounted. Conversely, an `Apply` whose callee is not a `Lambda` has
/// that callee visited twice (once in the arm, once as a child). Confirm both are
/// intended.
fn count_symbols(expr: &Expr, usage: &mut FnvHashMap<Symbol, SymbolTracker>) {
    match expr.kind {
        For { ref func, .. }
        | Iterate {
            update_func: ref func,
            ..
        }
        | Sort {
            cmpfunc: ref func, ..
        }
        | Apply { ref func, .. } => {
            // Everything bound so far is now one function level further out.
            for value in usage.values_mut() {
                value.loop_nest += 1;
            }
            count_symbols(func, usage);
            for value in usage.values_mut() {
                value.loop_nest -= 1;
            }
        }
        Let { ref name, .. } => {
            debug_assert!(!usage.contains_key(name));
            let _ = usage.insert(name.clone(), SymbolTracker::default());
        }
        Ident(ref symbol) => {
            if let Some(ref mut tracker) = usage.get_mut(symbol) {
                if tracker.loop_nest == 0 {
                    tracker.count += 1;
                } else {
                    // Weight uses inside a function so they can never satisfy `count <= 1`.
                    tracker.count += 3;
                }
            }
        }
        _ => (),
    };
    for child in expr.children() {
        match child.kind {
            // Lambdas reached via the function-carrying arms were already visited above.
            Lambda { .. } => (),
            _ => count_symbols(child, usage),
        }
    }
}
/// Performs the let-inlining decided by the counts `count_symbols` stored in `usages`.
///
/// Walks `expr` top-down:
/// * A `Let` whose name has `count <= 1` is removed. Its value is moved into the
///   symbol's tracker, the `Let` node is replaced in place by its body, and the walk
///   resumes on that body.
/// * An `Ident` whose symbol has `count <= 1` is replaced by the stashed value, which is
///   moved out and leaves `None` behind. The walk then resumes on the substituted value,
///   so lets nested inside it are processed too.
/// * Everything else just recurses into its children.
///
/// A binding with `count == 0` has its value stashed but never consumed. It stays in
/// `usages` and is dropped with the map, so unused lets are eliminated.
///
/// Relies on the walk reaching a `Let` before any use of its name, and on each inlined
/// value being consumed at most once. The `debug_assert!`s check that a value is present
/// (and is not a placeholder) when a use is replaced.
fn inline_let_helper(expr: &mut Expr, usages: &mut FnvHashMap<Symbol, SymbolTracker>) {
    // The subtree, if any, that should replace `expr`: the body of a removed `Let`, or the
    // stashed value for an inlined identifier.
    let mut taken_body = None;
    match expr.kind {
        Let {
            ref mut name,
            ref mut value,
            ref mut body,
        } => {
            if let Some(tracker) = usages.get_mut(name) {
                if tracker.count <= 1 {
                    // NOTE(review): `take()` presumably leaves a placeholder in the old
                    // `Let`, which is discarded below once `expr` is swapped out.
                    taken_body = Some(body.take());
                    tracker.value = Some(value.take());
                }
            }
        }
        Ident(ref name) => {
            if let Some(tracker) = usages.get_mut(name) {
                if tracker.count <= 1 {
                    debug_assert!(tracker.value.is_some());
                    debug_assert!(!tracker.value.as_ref().unwrap().is_placeholder());
                    // Move the value out so it can only ever be inlined once.
                    mem::swap(&mut taken_body, &mut tracker.value);
                }
            }
        }
        _ => (),
    }
    if let Some(mut val) = taken_body {
        // `expr` now holds the replacement; the old node ends up in `val` and is dropped.
        mem::swap(expr, val.as_mut());
        inline_let_helper(expr, usages);
    } else {
        for child in expr.children_mut() {
            inline_let_helper(child, usages);
        }
    }
}
/// Constant-folds `Negate` applied directly to a numeric literal.
///
/// `-(lit)` is replaced by a single literal holding the negated value:
/// * Signed integer literals are negated with wrapping (two's-complement) semantics, so
///   negating `MIN` yields `MIN` instead of panicking with "attempt to negate with
///   overflow" in debug builds. Release builds already produced this result.
/// * `f32`/`f64` literals are stored as raw bits, so they are decoded, negated and
///   re-encoded.
///
/// Any other literal kind, and anything that is not `Negate(Literal(..))`, is left
/// unchanged.
pub fn inline_negate(expr: &mut Expr) {
    use crate::ast::LiteralKind::*;
    expr.transform(&mut |node| {
        let child = match node.kind {
            Negate(ref child) => child,
            _ => return None,
        };
        let negated = match child.kind {
            Literal(I8Literal(a)) => I8Literal(a.wrapping_neg()),
            Literal(I16Literal(a)) => I16Literal(a.wrapping_neg()),
            Literal(I32Literal(a)) => I32Literal(a.wrapping_neg()),
            Literal(I64Literal(a)) => I64Literal(a.wrapping_neg()),
            Literal(F32Literal(bits)) => F32Literal((-f32::from_bits(bits)).to_bits()),
            Literal(F64Literal(bits)) => F64Literal((-f64::from_bits(bits)).to_bits()),
            _ => return None,
        };
        Some(Expr::new_literal(negated).unwrap())
    });
}
/// Constant-folds and simplifies `Cast` expressions.
///
/// Two rewrites are attempted, in order:
///
/// 1. A cast of a literal for one of the supported conversions (`i32 -> f64`,
///    `i32 -> i64`, `i64 -> f64`) is replaced by a literal of the target type. `f64`
///    literals are stored as raw bits, hence the `to_bits()` calls. `i64 -> f64` uses
///    Rust's `as`, which rounds to the nearest representable value.
/// 2. A cast whose operand already has the target scalar type is the identity and is
///    replaced by the operand itself. This check also runs for literal operands that
///    step 1 does not fold, e.g. casting an `i32` literal to `i32`.
///
/// Other casts are left unchanged.
pub fn inline_cast(expr: &mut Expr) {
    use crate::ast::LiteralKind::*;
    use crate::ast::ScalarKind::*;
    use crate::ast::Type::Scalar;
    expr.transform(&mut |ref mut expr| {
        if let Cast {
            kind: ref scalar_kind,
            ref child_expr,
        } = expr.kind
        {
            if let Literal(ref literal_kind) = child_expr.kind {
                let folded = match (scalar_kind, literal_kind) {
                    (&F64, &I32Literal(a)) => Some(F64Literal(f64::from(a).to_bits())),
                    (&I64, &I32Literal(a)) => Some(I64Literal(i64::from(a))),
                    (&F64, &I64Literal(a)) => Some(F64Literal((a as f64).to_bits())),
                    _ => None,
                };
                if let Some(literal) = folded {
                    return Some(Expr::new_literal(literal).unwrap());
                }
                // Not foldable: fall through so identity casts on literals still simplify.
            }
            if let Scalar(ref kind) = child_expr.ty {
                if kind == scalar_kind {
                    return Some(*child_expr.clone());
                }
            }
        }
        None
    });
}
/// If `expr` is `GetField(Ident(sym), index)`, i.e. a field read directly on the symbol
/// `sym`, returns `Some(index)`. Otherwise returns `None`.
fn getfield_on_symbol(expr: &Expr, sym: &Symbol) -> Option<u32> {
    match expr.kind {
        GetField {
            expr: ref inner,
            index,
        } => match inner.kind {
            Ident(ref name) if name == sym => Some(index),
            _ => None,
        },
        _ => None,
    }
}
/// Removes comparisons against `false` from `If` conditions.
///
/// `if(x == false, a, b)` and `if(false == x, a, b)` become `if(x, b, a)`: the comparison
/// is replaced by its non-literal operand and the two branches are swapped. The rewrite
/// is done in place during a bottom-up traversal. The expression is `uniquify`d first,
/// panicking on failure.
pub fn simplify_branch_conditions(expr: &mut Expr) {
    use crate::ast::LiteralKind::BoolLiteral;
    expr.uniquify().unwrap();
    expr.transform_up(&mut |node| {
        if let If {
            ref mut cond,
            ref mut on_true,
            ref mut on_false,
        } = node.kind
        {
            // Pull out `x` from `x == false` / `false == x`, if the condition has that shape.
            let operand = match cond.kind {
                BinOp {
                    kind: BinOpKind::Equal,
                    ref mut left,
                    ref mut right,
                } => {
                    if let Literal(BoolLiteral(false)) = left.kind {
                        Some(right.take())
                    } else if let Literal(BoolLiteral(false)) = right.kind {
                        Some(left.take())
                    } else {
                        None
                    }
                }
                _ => None,
            };
            if let Some(mut inner) = operand {
                // The condition becomes `x`; negation is expressed by swapping the branches.
                mem::swap(cond, &mut inner);
                mem::swap(on_true, on_false);
            }
        }
        None
    });
}
/// Scalarizes `let`-bound struct literals that are only ever accessed field-by-field.
///
/// `let s = {e0, e1}; ... s.$0 ... s.$1 ...` becomes
/// `let us0 = e0; let us1 = e1; ... us0 ... us1 ...`, with fresh symbols prefixed `us`.
/// The rewrite is skipped when the body references `s` in any other way (for example
/// passing the whole struct somewhere), detected by comparing the total number of
/// references with the number of direct field accesses. The expression is `uniquify`d
/// first, panicking on failure.
pub fn unroll_structs(expr: &mut Expr) {
    use crate::util::SymbolGenerator;
    expr.uniquify().unwrap();
    let mut sym_gen = SymbolGenerator::from_expression(expr);
    expr.transform_up(&mut |e| {
        let (name, value, body) = match e.kind {
            Let {
                ref name,
                ref value,
                ref body,
            } => (name, value, body),
            _ => return None,
        };
        let elems = match value.kind {
            MakeStruct { ref elems } => elems,
            _ => return None,
        };

        // Every use of `name` must be a direct field access for the rewrite to be valid.
        let mut field_accesses: i32 = 0;
        let mut all_refs: i32 = 0;
        body.traverse(&mut |node| {
            if getfield_on_symbol(node, name).is_some() {
                field_accesses += 1;
            }
            if let Ident(ref id) = node.kind {
                if id == name {
                    all_refs += 1;
                }
            }
        });
        if all_refs != field_accesses {
            return None;
        }

        // One fresh symbol per struct element, generated in element order.
        let fresh: Vec<_> = elems.iter().map(|_| sym_gen.new_symbol("us")).collect();

        let mut rewritten = (**body).clone();
        rewritten.transform(&mut |node| {
            getfield_on_symbol(node, name).map(|idx| {
                Expr::new_ident(fresh[idx as usize].clone(), node.ty.clone()).unwrap()
            })
        });

        // Wrap the body innermost-first so the lets appear in element order.
        let unrolled = fresh
            .into_iter()
            .zip(elems.iter())
            .rev()
            .fold(rewritten, |inner, (sym, elem)| {
                Expr::new_let(sym, elem.clone(), inner).unwrap()
            });
        Some(unrolled)
    });
}
/// Checks `inline_let` on a table of (input, expected) Weld programs.
///
/// Lets used once are inlined. Lets used more than once, or used inside a loop body
/// while bound outside it, are kept. A let bound and used once inside the same loop body
/// is inlined. Each failing case reports its input and both pretty-printed expressions.
#[test]
fn inline_lets() {
    let cases = [
        ("let a = 1; a + 2", "1 + 2"),
        ("let a = 1; a + a + 2", "let a = 1; a + a + 2"),
        (
            "let a = 1L; for([1L,2L,3L], appender, |b,i,e| merge(b, e + a + 2L))",
            "let a = 1L; for([1L,2L,3L], appender, |b,i,e| merge(b, e + a + 2L))",
        ),
        ("let a = 1; let b = 2; let c = 3; a + b + c", "1 + 2 + 3"),
        (
            "|input: vec[i32]| let b = 1; \
             result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a))) + b",
            "|input: vec[i32]| \
             result(for(input, merger[i32,+], |b,i,e| merge(b, e + 1))) + 1",
        ),
        (
            "|input: vec[i32]| let b = 1; \
             result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a + a))) + b",
            "|input: vec[i32]| \
             result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a + a))) + 1",
        ),
    ];
    for &(input, expected) in cases.iter() {
        let mut e1 = typed_expression(input);
        inline_let(&mut e1);
        let e2 = typed_expression(expected);
        assert!(
            e1.compare_ignoring_symbols(&e2).unwrap(),
            "inline_let({}) produced {}, expected {}",
            input,
            e1.pretty_print(),
            e2.pretty_print()
        );
    }
}