use std::collections::HashSet;
use crate::ast::ExprKind::*;
use crate::ast::Type::*;
use crate::ast::*;
use crate::error::*;
use crate::util::SymbolGenerator;
#[cfg(test)]
use crate::tests::*;
/// Query for whether a value has been marked for predication.
///
/// In this file the trait is implemented for `Expr`, where the answer comes from the
/// expression's `predicate` annotation.
pub trait ShouldPredicate {
/// Returns `true` when the implementor should be predicated.
fn should_predicate(&self) -> bool;
}
impl ShouldPredicate for Expr {
    /// Returns `true` only when the expression carries a `predicate` annotation whose
    /// value, compared case-insensitively via `to_lowercase`, equals `"true"`.
    ///
    /// A missing annotation yields `false`.
    fn should_predicate(&self) -> bool {
        match self.annotations.get("predicate") {
            Some(ref flag) => flag.to_lowercase() == "true",
            None => false,
        }
    }
}
/// Rewrites every vectorizable `For` loop inside `expr` into a SIMD loop followed by a
/// fringe loop.
///
/// For each `For` accepted by `vectorizable`, the replacement is built as follows:
///
/// 1. The lambda body is cloned and rewritten node by node with `vectorize_expr`; the third
///    lambda parameter is given its SIMD type.
/// 2. Each iterator's data expression is bound to a fresh symbol with a `let`, so both
///    loops can refer to the same data by name.
/// 3. A loop over `IterKind::SimdIter` iterators runs the vectorized lambda starting from
///    the original builder; a loop over `IterKind::FringeIter` iterators then runs the
///    original lambda, using the SIMD loop as its builder.
///
/// Loops rejected by `vectorizable` are left untouched.
///
/// If rewriting the body fails, the error is returned from the transform callback, exactly
/// like the other fallible steps here, instead of panicking in the middle of the pass.
pub fn vectorize(expr: &mut Expr) {
    let mut sym_gen = SymbolGenerator::from_expression(expr);
    expr.transform_and_continue_res(&mut |ref mut expr| {
        if let Some(ref broadcast_idens) = vectorizable(expr) {
            info!("Vectorizing For loop!");
            if let For {
                ref iters,
                builder: ref init_builder,
                ref func,
            } = expr.kind
            {
                if let Lambda {
                    ref params,
                    ref body,
                } = func.kind
                {
                    // Vectorize a copy of the body; the original is still needed for the
                    // fringe loop. The inner callback cannot return a `Result`, so the first
                    // error is stashed here and re-raised once the traversal has finished.
                    let mut vectorized_body = body.clone();
                    let mut failure = None;
                    vectorized_body.transform_and_continue(&mut |ref mut e| {
                        if failure.is_some() {
                            return (None, false);
                        }
                        match vectorize_expr(e, broadcast_idens) {
                            Ok(cont) => (None, cont),
                            Err(err) => {
                                failure = Some(err);
                                (None, false)
                            }
                        }
                    });
                    if let Some(err) = failure {
                        return Err(err);
                    }

                    // Only the element parameter (index 2) changes type.
                    let mut vectorized_params = params.clone();
                    vectorized_params[2].ty = vectorized_params[2].ty.simd_type()?;
                    let vec_func = Expr::new_lambda(vectorized_params, *vectorized_body)?;

                    // One fresh name per iterator; the data expressions are hoisted into
                    // `let` bindings below and referenced through these names.
                    let data_names = iters
                        .iter()
                        .map(|_| sym_gen.new_symbol("a"))
                        .collect::<Vec<_>>();

                    let mut vec_iters = vec![];
                    for (iter, name) in iters.iter().zip(&data_names) {
                        vec_iters.push(Iter {
                            data: Box::new(Expr::new_ident(name.clone(), iter.data.ty.clone())?),
                            start: iter.start.clone(),
                            end: iter.end.clone(),
                            stride: iter.stride.clone(),
                            kind: IterKind::SimdIter,
                            shape: iter.shape.clone(),
                            strides: iter.strides.clone(),
                        });
                    }

                    // The fringe iterators are identical apart from their kind.
                    let fringe_iters = vec_iters
                        .iter()
                        .cloned()
                        .map(|mut iter| {
                            iter.kind = IterKind::FringeIter;
                            iter
                        })
                        .collect();

                    let vectorized_loop =
                        Expr::new_for(vec_iters, *init_builder.clone(), vec_func)?;
                    let scalar_loop = Expr::new_for(fringe_iters, vectorized_loop, *func.clone())?;

                    // Wrap the loops in the `let` bindings, innermost binding last.
                    let mut prev_expr = scalar_loop;
                    for (iter, name) in iters.iter().zip(data_names).rev() {
                        prev_expr = Expr::new_let(name, *iter.data.clone(), prev_expr)?;
                    }
                    return Ok((Some(prev_expr), false));
                }
            }
        }
        Ok((None, true))
    });
}
pub fn predicate_merge_expr(e: &mut Expr) {
e.transform_and_continue_res(&mut |ref mut e| {
if !e.should_predicate() {
return Ok((None, true));
}
if let If {
ref cond,
ref on_true,
ref on_false,
} = e.kind
{
if let Merge {
ref builder,
ref value,
} = on_true.kind
{
if let Ident(ref name) = on_false.kind {
if let Ident(ref name2) = builder.kind {
if name == name2 {
if let Builder(ref bk, _) = builder.ty {
let (ty, op) = match *bk {
BuilderKind::Merger(ref ty, ref op) => (ty, op),
BuilderKind::DictMerger(_, ref ty2, ref op) => (ty2, op),
BuilderKind::VecMerger(ref ty, ref op) => (ty, op),
_ => {
return Ok((None, true));
}
};
let identity = get_id_element(ty.as_ref(), *op)?;
match identity {
Some(x) => {
match *bk {
BuilderKind::Merger(_, _)
| BuilderKind::VecMerger(_, _) => {
let expr = Expr::new_merge(
*builder.clone(),
Expr::new_select(
*cond.clone(),
*value.clone(),
x,
)?,
)?;
return Ok((Some(expr), true));
}
BuilderKind::DictMerger(_, _, _) => {
let sel_expr = make_select_for_kv(
*cond.clone(),
*value.clone(),
x,
)?;
return Ok((sel_expr, true));
}
_ => {
return Ok((None, true));
}
}
}
None => {
return Ok((None, true));
}
};
}
}
}
}
}
}
Ok((None, true))
});
}
/// Returns `true` if `e` is an identifier, a literal, or a chain of field accesses that
/// ends in one of those.
fn is_simple(e: &Expr) -> bool {
    let mut current = e;
    loop {
        match current.kind {
            Ident(_) | Literal(_) => return true,
            // Peel one field access and inspect what it projects from.
            GetField { ref expr, .. } => current = &**expr,
            _ => return false,
        }
    }
}
/// Replaces simple scalar `if` expressions with `select` expressions.
///
/// An `if` is rewritten only when all of the following hold:
///
/// * neither branch contains a sub-expression for which `is_builder_expr` is true,
/// * both branches satisfy `is_simple`,
/// * both branches have a `Scalar` type.
///
/// All other expressions are left unchanged.
pub fn predicate_simple_expr(e: &mut Expr) {
    e.transform_and_continue_res(&mut |ref mut e| {
        if let If {
            ref cond,
            ref on_true,
            ref on_false,
        } = e.kind
        {
            // Scan both branches for builder expressions.
            let mut touches_builder = false;
            for branch in &[on_true, on_false] {
                branch.traverse(&mut |node| {
                    if node.kind.is_builder_expr() {
                        touches_builder = true;
                    }
                });
            }

            let eligible = !touches_builder && is_simple(on_true) && is_simple(on_false);
            if eligible {
                match (&on_true.ty, &on_false.ty) {
                    (&Scalar(_), &Scalar(_)) => {
                        let select =
                            Expr::new_select(*cond.clone(), *on_true.clone(), *on_false.clone())?;
                        return Ok((Some(select), true));
                    }
                    _ => {}
                }
            }
        }
        Ok((None, true))
    });
}
/// Returns `true` if every iterator has no `start`, `end` or `stride` and iterates over a
/// vector whose element type is scalar. An empty slice yields `true`.
fn vectorizable_iters(iters: &[Iter]) -> bool {
    for iter in iters {
        let unbounded = iter.start.is_none() && iter.end.is_none() && iter.stride.is_none();
        if !unbounded {
            return false;
        }
        match iter.data.ty {
            Vector(ref elem) if elem.is_scalar() => {}
            _ => return false,
        }
    }
    true
}
/// Vectorizes a single expression node in place.
///
/// * A scalar identifier listed in `broadcast_idens` is replaced by a broadcast of itself,
///   and `Ok(false)` is returned.
/// * Other scalar or struct identifiers, and every literal, field access, unary op, binary
///   op, select and struct construction, have their type replaced by its SIMD type.
/// * Any other node is left untouched.
///
/// In the last two cases `Ok(true)` is returned. Errors from `simd_type` or
/// `Expr::new_broadcast` are propagated.
fn vectorize_expr(e: &mut Expr, broadcast_idens: &HashSet<Symbol>) -> WeldResult<bool> {
    let needs_broadcast = match e.kind {
        Ident(ref name) => match e.ty {
            Scalar(_) => broadcast_idens.contains(name),
            _ => false,
        },
        _ => false,
    };
    if needs_broadcast {
        *e = Expr::new_broadcast(e.clone())?;
        return Ok(false);
    }

    let widen = match e.kind {
        Ident(_) => match e.ty {
            Scalar(_) | Struct(_) => true,
            _ => false,
        },
        Literal(_)
        | GetField { .. }
        | UnaryOp { .. }
        | BinOp { .. }
        | Select { .. }
        | MakeStruct { .. } => true,
        _ => false,
    };
    if widen {
        e.ty = e.ty.simd_type()?;
    }
    Ok(true)
}
/// Classifies a builder expression for vectorization.
///
/// * `Some(true)`: an identifier or `NewBuilder` of `Appender`/`Merger` type with a scalar
///   element type, or a struct made only of such builders.
/// * `Some(false)`: a recognized builder expression that is some other builder kind or has
///   a non-scalar element type.
/// * `None`: the expression (or some struct member) is not a builder-typed identifier,
///   `NewBuilder`, or struct of those.
fn vectorizable_builder(expr: &Expr) -> Option<bool> {
    use crate::ast::BuilderKind::*;
    match expr.kind {
        MakeStruct { ref elems } => elems
            .iter()
            .try_fold(true, |all_ok, elem| vectorizable_builder(elem).map(|ok| all_ok & ok)),
        Ident(_) | NewBuilder(_) => match expr.ty {
            Builder(ref bk, _) => match *bk {
                Appender(ref elem) | Merger(ref elem, _) => Some(elem.is_scalar()),
                _ => Some(false),
            },
            _ => None,
        },
        _ => None,
    }
}
/// Decides whether `for_loop` can be vectorized.
///
/// Returns `Some(idens)` when it can, where `idens` holds the scalar identifiers used in
/// the loop body that are neither lambda parameters nor bound by a `let` inside the body
/// (these must be broadcast). Returns `None` otherwise, after logging the reason with
/// `trace!`.
///
/// The checks, in order:
///
/// 1. `for_loop` is a `For` whose iterators pass `vectorizable_iters`, whose builder yields
///    `Some(true)` from `vectorizable_builder`, and whose function is a `Lambda`.
/// 2. The body contains only literals, identifiers, unary/binary ops, `let`s, field
///    accesses, struct constructions, merges and selects, and never uses the second lambda
///    parameter.
/// 3. The third lambda parameter is a scalar or a struct of scalars.
/// 4. Every identifier not defined inside the loop is a scalar.
fn vectorizable(for_loop: &Expr) -> Option<HashSet<Symbol>> {
    // Stage 1: structural checks on the loop itself.
    let lambda_parts = match for_loop.kind {
        For {
            ref iters,
            builder: ref init_builder,
            ref func,
        } => {
            if !vectorizable_iters(iters) {
                None
            } else if vectorizable_builder(init_builder) != Some(true) {
                None
            } else {
                match func.kind {
                    Lambda {
                        ref params,
                        ref body,
                    } => Some((params, body)),
                    _ => None,
                }
            }
        }
        _ => None,
    };
    let (params, body) = match lambda_parts {
        Some(parts) => parts,
        None => {
            trace!("Vectorization failed due to unsupported pattern");
            return None;
        }
    };

    // Stage 2: only a fixed set of expression kinds may appear in the body. Names bound
    // inside the loop are collected along the way.
    let mut defined_in_loop: HashSet<_> = params.iter().map(|p| p.name.clone()).collect();
    let mut supported = true;
    body.traverse(&mut |node| {
        if !supported {
            return;
        }
        match node.kind {
            Ident(ref name) => {
                // A use of the second lambda parameter blocks vectorization.
                if node.ty == params[1].ty && *name == params[1].name {
                    supported = false;
                }
            }
            Let { ref name, .. } => {
                defined_in_loop.insert(name.clone());
            }
            Literal(_)
            | UnaryOp { .. }
            | BinOp { .. }
            | GetField { .. }
            | MakeStruct { .. }
            | Merge { .. }
            | Select { .. } => {}
            _ => supported = false,
        }
    });
    if !supported {
        trace!("Vectorization failed due to unsupported expression in loop body");
        return None;
    }

    // Stage 3: the element parameter must be a scalar or a struct of scalars.
    let arg_supported = match params[2].ty {
        Scalar(_) => true,
        Struct(ref field_tys) => field_tys.iter().all(|t| match *t {
            Scalar(_) => true,
            _ => false,
        }),
        _ => false,
    };
    if !arg_supported {
        trace!("Vectorization failed due to unsupported type");
        return None;
    }

    // Stage 4: identifiers from outside the loop must be scalars so they can be broadcast.
    let mut idens = HashSet::new();
    let mut all_scalar = true;
    body.traverse(&mut |node| {
        if let Ident(ref name) = node.kind {
            if !defined_in_loop.contains(name) {
                match node.ty {
                    Scalar(_) => {
                        idens.insert(name.clone());
                    }
                    _ => all_scalar = false,
                }
            }
        }
    });
    if !all_scalar {
        trace!("Unsupported pattern: non-scalar identifier that must be broadcast");
        return None;
    }
    Some(idens)
}
/// Returns a literal expression holding the identity element of `op` for the scalar type
/// `ty`: `0` for `Add` and `1` for `Multiply`.
///
/// Only `I8`, `I32`, `I64`, `F32` and `F64` are handled. `Ok(None)` is returned for any
/// other scalar kind, any other operator, and any non-scalar type.
fn get_id_element(ty: &Type, op: BinOpKind) -> WeldResult<Option<Expr>> {
    let kind = match *ty {
        Scalar(kind) => kind,
        _ => return Ok(None),
    };
    let literal = match (op, kind) {
        (BinOpKind::Add, ScalarKind::I8) => LiteralKind::I8Literal(0),
        (BinOpKind::Add, ScalarKind::I32) => LiteralKind::I32Literal(0),
        (BinOpKind::Add, ScalarKind::I64) => LiteralKind::I64Literal(0),
        (BinOpKind::Add, ScalarKind::F32) => LiteralKind::F32Literal(0f32.to_bits()),
        (BinOpKind::Add, ScalarKind::F64) => LiteralKind::F64Literal(0f64.to_bits()),
        (BinOpKind::Multiply, ScalarKind::I8) => LiteralKind::I8Literal(1),
        (BinOpKind::Multiply, ScalarKind::I32) => LiteralKind::I32Literal(1),
        (BinOpKind::Multiply, ScalarKind::I64) => LiteralKind::I64Literal(1),
        (BinOpKind::Multiply, ScalarKind::F32) => LiteralKind::F32Literal(1f32.to_bits()),
        (BinOpKind::Multiply, ScalarKind::F64) => LiteralKind::F64Literal(1f64.to_bits()),
        _ => return Ok(None),
    };
    Ok(Some(Expr::new_literal(literal)?))
}
/// Builds `let k = kv; select(cond, k, {k.$0, ident})`.
///
/// `kv` is bound to a fresh symbol so it appears only once in the result. When `cond` is
/// false, the selected struct keeps field 0 of `kv` and uses `ident` as its second field.
/// The result is always `Some` on success.
fn make_select_for_kv(cond: Expr, kv: Expr, ident: Expr) -> WeldResult<Option<Expr>> {
    let binding = SymbolGenerator::from_expression(&kv).new_symbol("k");
    let bound_kv = Expr::new_ident(binding.clone(), kv.ty.clone())?;
    let key = Expr::new_get_field(bound_kv.clone(), 0)?;
    let fallback = Expr::new_make_struct(vec![key, ident])?;
    let selected = Expr::new_select(cond, bound_kv, fallback)?;
    Ok(Some(Expr::new_let(binding, kv, selected)?))
}
/// Test helper: returns `true` if `expr` contains a `Merge` whose value has a type for
/// which `is_simd` is true.
#[cfg(test)]
fn has_vectorized_merge(expr: &Expr) -> bool {
    let mut found = false;
    expr.traverse(&mut |node| match node.kind {
        Merge { ref value, .. } if value.ty.is_simd() => found = true,
        _ => {}
    });
    found
}
/// A merger loop whose body is `merge(b, e+1)` must contain a SIMD-typed merge after
/// `vectorize`.
#[test]
fn simple_merger() {
    let program = "|v:vec[i32]| result(for(v, merger[i32,+], |b,i,e| merge(b,e+1)))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(has_vectorized_merge(&expr));
}
/// An annotated conditional merge into a merger must contain a SIMD-typed merge once it
/// has gone through `predicate_merge_expr` and then `vectorize`.
#[test]
fn predicated_merger() {
    let program = "|v:vec[i32]| result(for(v, merger[i32,+], |b,i,e| @(predicate:true)if(e>0, merge(b,e), b)))";
    let mut expr = typed_expression(program);
    predicate_merge_expr(&mut expr);
    vectorize(&mut expr);
    assert!(has_vectorized_merge(&expr));
}
/// A conditional merge that is not predicated first must not produce a SIMD-typed merge.
#[test]
fn unpredicated_merger() {
    let program = "|v:vec[i32]| result(for(v, merger[i32,+], |b,i,e| if(e>0, merge(b,e), b)))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(!has_vectorized_merge(&expr));
}
/// An appender loop whose body is `merge(b, e+1)` must contain a SIMD-typed merge after
/// `vectorize`.
#[test]
fn simple_appender() {
    let program = "|v:vec[i32]| result(for(v, appender[i32], |b,i,e| merge(b,e+1)))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(has_vectorized_merge(&expr));
}
/// An annotated conditional merge into an appender must not produce a SIMD-typed merge,
/// even after `predicate_merge_expr` and `vectorize` have both run.
#[test]
fn predicated_appender() {
    let program = "|v:vec[i32]| result(for(v, appender[i32], |b,i,e| @(predicate:true)if(e>0, merge(b,e), b)))";
    let mut expr = typed_expression(program);
    predicate_merge_expr(&mut expr);
    vectorize(&mut expr);
    assert!(!has_vectorized_merge(&expr));
}
/// An appender with a vector element type must not produce a SIMD-typed merge.
#[test]
fn non_vectorizable_type() {
    let program = "|v:vec[i32]| result(for(v, appender[vec[i32]], |b,i,e| merge(b,v)))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(!has_vectorized_merge(&expr));
}
/// A loop body that contains a `lookup` must not produce a SIMD-typed merge.
#[test]
fn non_vectorizable_expr() {
    let program = "|v:vec[i32]| result(for(v, appender[i32], |b,i,e| merge(b,lookup(v,i))))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(!has_vectorized_merge(&expr));
}
/// A loop over two zipped scalar vectors must contain a SIMD-typed merge after
/// `vectorize`.
#[test]
fn zipped_input() {
    let program = "|v:vec[i32]| result(for(zip(v,v), appender[i32], |b,i,e| merge(b,e.$0+e.$1)))";
    let mut expr = typed_expression(program);
    vectorize(&mut expr);
    assert!(has_vectorized_merge(&expr));
}