use std::error::Error;
use crate::ast::*;
use crate::ast::BuilderKind::*;
use crate::ast::ExprKind::*;
use crate::ast::Type::*;
use crate::error::*;
/// Largest `loopsize` annotation value for which `unroll_static_loop` will
/// unroll a loop; loops annotated with a larger size are left unchanged.
pub const UNROLL_LIMIT: u64 = 8;
/// The pieces of an unrollable `result(for(...))` expression, borrowed from
/// the expression they were matched in.
struct UnrollPattern<'a> {
    /// Value of the loop's `loopsize` annotation.
    loop_size: u64,
    /// The iterators of the `for` loop, in order.
    iters: &'a [Iter],
    /// Kind of the builder the `for` loop merges into.
    builder_kind: &'a BuilderKind,
    /// Parameters of the loop-body lambda (builder, index, element).
    merge_params: &'a [Parameter],
    /// The value merged into the builder on each iteration.
    merge_value: &'a Expr,
}
/// Read access to the `loopsize` annotation attached to an expression.
trait LoopSizeAnnotation {
    /// Returns the `loopsize` annotation parsed as a `u64`, or `None` when the
    /// annotation is absent or does not parse as an unsigned integer.
    fn loopsize(&self) -> Option<u64>;
}

impl LoopSizeAnnotation for Expr {
    fn loopsize(&self) -> Option<u64> {
        // Pure lookup: no output. The previous version printed every
        // expression's annotations to stdout as a debugging leftover.
        self.annotations
            .get("loopsize")
            .and_then(|value| value.parse::<u64>().ok())
    }
}
impl<'a> UnrollPattern<'a> {
    /// Matches `expr` against the unrollable loop shape
    ///
    /// ```text
    /// result(@(loopsize:N) for(iters, builder, |b, ...| merge(b, value)))
    /// ```
    ///
    /// and returns its parts. Returns `None` when any part is missing:
    /// - `expr` is not a `result`;
    /// - the loop has no parseable `loopsize`, or `N > UNROLL_LIMIT`;
    /// - the loop's builder does not have a builder type;
    /// - the loop function is not a lambda whose body is a single `merge`;
    /// - that `merge` does not target the lambda's first parameter.
    fn extract(expr: &'a Expr) -> Option<UnrollPattern<'_>> {
        let loop_expr = match expr.kind {
            Res { ref builder } => builder,
            _ => return None,
        };
        let loop_size = loop_expr
            .loopsize()
            .filter(|&size| size <= UNROLL_LIMIT)?;
        let (iters, init_builder, func) = match loop_expr.kind {
            For {
                ref iters,
                ref builder,
                ref func,
            } => (iters, builder, func),
            _ => return None,
        };
        let builder_kind = match init_builder.ty {
            Builder(ref bk, _) => bk,
            _ => return None,
        };
        let (params, body) = match func.kind {
            Lambda {
                ref params,
                ref body,
            } => (params, body),
            _ => return None,
        };
        let (merge_target, value) = match body.kind {
            Merge {
                ref builder,
                ref value,
            } => (builder, value),
            _ => return None,
        };
        // `first()` instead of `params[0]`: a parameterless lambda is simply
        // not a match rather than a panic.
        let builder_param = params.first()?;
        match merge_target.kind {
            Ident(ref name) if *name == builder_param.name => Some(UnrollPattern {
                loop_size,
                iters,
                builder_kind,
                merge_params: params,
                merge_value: value,
            }),
            _ => None,
        }
    }
}
pub fn unroll_static_loop(expr: &mut Expr) {
use crate::util::SymbolGenerator;
if expr.uniquify().is_err() {
return;
}
let mut sym_gen = SymbolGenerator::from_expression(expr);
expr.transform_up(&mut |ref mut expr| {
if let Some(pat) = UnrollPattern::extract(expr) {
let symbols: Vec<_> = (0..pat.iters.len())
.map(|_| sym_gen.new_symbol("tmp"))
.collect();
let idents: Vec<_> = symbols
.iter()
.zip(pat.iters.iter())
.map(|ref t| Expr::new_ident(t.0.clone(), t.1.data.ty.clone()).unwrap())
.collect();
let vals = unroll_values(pat.merge_params, pat.merge_value, &idents, pat.loop_size);
if let Err(err) = vals {
trace!("Unroller error: {}", err.description());
return None;
}
let vals = vals.unwrap();
let combined_expr = combine_unrolled_values(pat.builder_kind.clone(), vals);
if let Err(err) = combined_expr {
trace!("Unroller error: {}", err.description());
return None;
}
let mut prev = combined_expr.unwrap();
for (ref sym, ref iter) in symbols.into_iter().rev().zip(pat.iters.iter().rev()) {
prev = Expr::new_let(sym.clone(), iter.data.as_ref().clone(), prev).unwrap();
}
Some(prev)
} else {
None
}
});
}
/// Returns `true` when both expressions are identifiers with the same name
/// and the same type; any non-identifier yields `false`.
fn is_same_ident(expr: &Expr, other: &Expr) -> bool {
    match (&expr.kind, &other.kind) {
        (Ident(mine), Ident(theirs)) => theirs == mine && expr.ty == other.ty,
        _ => false,
    }
}
/// Produces one copy of the merge `value` per loop iteration.
///
/// `parameters` are the loop lambda's parameters and must number exactly
/// three. `parameters[1]` is the iteration index and `parameters[2]` is the
/// element. For each `i` in `0..loopsize`, the copy of `value` has:
/// - every use of the index parameter replaced by the literal `i` (as `i64`);
/// - with a single vector, every use of the element replaced by
///   `lookup(vectors[0], i)`;
/// - with several (zipped) vectors, every `element.$k` replaced by
///   `lookup(vectors[k], i)`.
///
/// # Errors
/// Returns an error when:
/// - `parameters` does not have three entries;
/// - a zipped field index has no corresponding vector;
/// - building an identifier, literal or lookup expression fails.
///
/// NOTE(review): in the zipped case, a bare use of the element (not through a
/// `.$k` projection) is not rewritten. Confirm that callers never produce it.
fn unroll_values(
    parameters: &[Parameter],
    value: &Expr,
    vectors: &[Expr],
    loopsize: u64,
) -> WeldResult<Vec<Expr>> {
    if parameters.len() != 3 {
        return compile_err!("Expected three parameters to Merge function");
    }
    let index_symbol = &parameters[1].name;
    let elem_symbol = &parameters[2].name;
    let elem_ident = Expr::new_ident(elem_symbol.clone(), parameters[2].ty.clone())?;
    let zipped = vectors.len() > 1;

    let mut expressions = vec![];
    for i in 0..loopsize {
        // Within this file the only visible caller passes a loop size of at
        // most UNROLL_LIMIT, so this cast does not wrap.
        let i = i as i64;
        // `transform` cannot propagate errors, so the closure records the
        // first failure here and it is returned after the traversal.
        let mut failure: WeldResult<()> = Ok(());
        let mut unrolled_value = value.clone();
        unrolled_value.transform(&mut |e| {
            let replacement = match e.kind {
                Ident(ref name) if name == index_symbol => {
                    Expr::new_literal(LiteralKind::I64Literal(i))
                }
                Ident(ref name) if name == elem_symbol && vectors.len() == 1 => {
                    unrolled_lookup(&vectors[0], i)
                }
                GetField {
                    ref expr,
                    ref index,
                } if zipped && is_same_ident(expr, &elem_ident) => {
                    match vectors.get(*index as usize) {
                        Some(vector) => unrolled_lookup(vector, i),
                        None => compile_err!(
                            "Unroller: field index {} has no matching vector ({} vectors)",
                            index,
                            vectors.len()
                        ),
                    }
                }
                _ => return None,
            };
            match replacement {
                Ok(new_expr) => Some(new_expr),
                Err(err) => {
                    if failure.is_ok() {
                        failure = Err(err);
                    }
                    None
                }
            }
        });
        failure?;
        expressions.push(unrolled_value);
    }
    Ok(expressions)
}

/// Builds `lookup(vector, index)` with a literal `i64` index.
fn unrolled_lookup(vector: &Expr, index: i64) -> WeldResult<Expr> {
    Expr::new_lookup(
        vector.clone(),
        Expr::new_literal(LiteralKind::I64Literal(index))?,
    )
}
/// Combines the per-iteration values produced by the unroller into a single
/// expression for builder kind `bk`.
///
/// - `Merger(ty, binop)`: folds the values left to right with `binop`, giving
///   `((v0 binop v1) binop v2) ...`.
/// - `Appender(ty)`: builds the vector literal `[v0, v1, ...]`.
///
/// # Errors
/// Fails when:
/// - `values` is empty;
/// - any value's type differs from the builder's element type;
/// - the builder is neither a merger nor an appender;
/// - constructing the combined expression fails.
fn combine_unrolled_values(bk: BuilderKind, values: Vec<Expr>) -> WeldResult<Expr> {
    if values.is_empty() {
        return compile_err!("Need at least one value to combine in unroller");
    }
    match bk {
        Merger(ref ty, ref binop) => {
            if values.iter().any(|value| value.ty != *ty.as_ref()) {
                return compile_err!("Mismatched types in Merger and unrolled values.");
            }
            let mut terms = values.into_iter();
            let first = terms
                .next()
                .expect("values was checked to be non-empty above");
            terms.try_fold(first, |acc, term| Expr::new_bin_op(*binop, acc, term))
        }
        Appender(ref ty) => {
            if values.iter().any(|value| value.ty != *ty.as_ref()) {
                return compile_err!("Mismatched types in Appender and unrolled values.");
            }
            Expr::new_make_vector(values)
        }
        other => compile_err!(
            "Unroller transform does not support loops with builder of kind {:?}",
            other
        ),
    }
}
#[cfg(test)]
use crate::tests::*;
#[test]
fn simple_merger_loop() {
    // Two iterations of `merge(b, e)` into a `+` merger become one addition.
    let source = concat!(
        "|v:vec[i32]| result(\n",
        "@(loopsize:2L)\n",
        "for(v, merger[i32,+],\n",
        "|b,i,e| merge(b, e)))"
    );
    let mut unrolled = typed_expression(source);
    unroll_static_loop(&mut unrolled);

    let expected = typed_expression("|v:vec[i32]| let t0 = v; lookup(t0, 0L) + lookup(t0, 1L)");
    let matches = unrolled.compare_ignoring_symbols(&expected).unwrap();
    assert!(matches);
}
#[test]
fn zipped_merger_loop() {
    // A zipped merger loop projects each field into a lookup on its own vector.
    let source = concat!(
        "|v:vec[i32], w: vec[i32]| result(\n",
        "@(loopsize:2L)\n",
        "for(zip(v, w), merger[i32,+],\n",
        "|b,i,e| merge(b, e.$0 * e.$1)))"
    );
    let mut unrolled = typed_expression(source);
    unroll_static_loop(&mut unrolled);

    let expected = typed_expression(concat!(
        "|v:vec[i32], w:vec[i32]| let t0 = v; let t1 = w;\n",
        "lookup(t0, 0L) * lookup(t1, 0L) +\n",
        "lookup(t0, 1L) * lookup(t1, 1L)"
    ));
    let matches = unrolled.compare_ignoring_symbols(&expected).unwrap();
    assert!(matches);
}
#[test]
fn simple_appender_loop() {
    // An appender loop becomes a vector literal of the per-iteration values.
    let source = concat!(
        "|v:vec[i32]| result(\n",
        "@(loopsize:2L)\n",
        "for(v, appender,\n",
        "|b,i,e| merge(b, e)))"
    );
    let mut unrolled = typed_expression(source);
    unroll_static_loop(&mut unrolled);

    let expected = typed_expression("|v:vec[i32]| let t0 = v; [lookup(t0, 0L), lookup(t0, 1L)]");
    let matches = unrolled.compare_ignoring_symbols(&expected).unwrap();
    assert!(matches);
}
#[test]
fn zipped_appender_loop() {
    // Zipped appender: each vector element combines lookups from both inputs.
    let source = concat!(
        "|v:vec[i32], w: vec[i32]| result(\n",
        "@(loopsize:2L)\n",
        "for(zip(v, w), appender,\n",
        "|b,i,e| merge(b, e.$0 * e.$1)))"
    );
    let mut unrolled = typed_expression(source);
    unroll_static_loop(&mut unrolled);

    let expected = typed_expression(concat!(
        "|v:vec[i32], w:vec[i32]| let t0 = v; let t1 = w;\n",
        "[lookup(t0, 0L) * lookup(t1, 0L),\n",
        "lookup(t0, 1L) * lookup(t1, 1L)]"
    ));
    let matches = unrolled.compare_ignoring_symbols(&expected).unwrap();
    assert!(matches);
}
#[test]
fn large_merger_loop() {
    // One iteration past the limit: the expression must come back unchanged.
    let source = format!(
        concat!(
            "|v:vec[i32]| result(\n",
            "@(loopsize:{}L)\n",
            "for(v, merger[i32,+],\n",
            "|b,i,e| merge(b, e)))"
        ),
        UNROLL_LIMIT + 1
    );
    let mut unrolled = typed_expression(&source);
    let before = unrolled.clone();
    unroll_static_loop(&mut unrolled);
    let unchanged = unrolled.compare_ignoring_symbols(&before).unwrap();
    assert!(unchanged);
}