weld 0.4.0

Weld is a language and runtime for improving the performance of data-intensive applications.
Documentation
//! Transforms which fuse loops to reduce memory movement and prevent unncessary
//! traversals of data.

use crate::ast::BuilderKind::*;
use crate::ast::ExprKind::*;
use crate::ast::LiteralKind::*;
use crate::ast::Type::*;
use crate::ast::*;
use crate::error::*;

use super::inliner::inline_apply;

use crate::util::SymbolGenerator;

#[cfg(test)]
use crate::tests::*;

/// Fuses for loops over the same vector in a zip into a single for loop which produces a vector of
/// structs directly.
///
/// Some examples:
///
/// for(zip(
///     result(for(a, appender, ...))
///     result(for(a, appender, ...))
/// ), ...)
///
/// will become for(result(for(a, ...))) where the nested for will produce a vector of structs with
/// two elements.
///
/// Caveats:
///     - Like all Zip-based transforms, this function currently assumes that the output of each
///     expression in the Zip is the same length.
///
pub fn fuse_loops_horizontal(expr: &mut Expr) {
    expr.transform(&mut |ref mut expr| {
        let mut sym_gen = SymbolGenerator::from_expression(expr);
        if let For {
            iters: ref all_iters,
            builder: ref outer_bldr,
            func: ref outer_func,
        } = expr.kind
        {
            if all_iters.len() > 1 {
                // Vector of tuples containing the params and expressions of functions in nested lambdas.
                let mut lambdas = vec![];
                let mut common_data = None;
                // Used to check if the same rows of each output are touched by the outer for.
                let first_iter = (&all_iters[0].start, &all_iters[0].end, &all_iters[0].stride);
                // First, check if all the lambdas are over the same vector and have a pattern we can merge.
                // Below, for each iterator in the for loop, we checked if each nested for loop is
                // over the same vector and has the same Iter parameters (i.e., same start, end, stride).
                let iters_same = all_iters.iter().all(|ref iter| {
                    if (&iter.start, &iter.end, &iter.stride) == first_iter {
                        // Make sure each nested for loop follows the ``result(for(a, appender, ...)) pattern.
                        if let Res {
                            builder: ref res_bldr,
                        } = iter.data.kind
                        {
                            if let For {
                                iters: ref iters2,
                                builder: ref bldr2,
                                func: ref lambda,
                            } = res_bldr.kind
                            {
                                if common_data.is_none() {
                                    common_data = Some(iters2.clone());
                                }
                                if iters2 == common_data.as_ref().unwrap() {
                                    if let NewBuilder(_) = bldr2.kind {
                                        if let Builder(ref kind, _) = bldr2.ty {
                                            if let Appender(_) = *kind {
                                                if let Lambda {
                                                    params: ref args,
                                                    ref body,
                                                } = lambda.kind
                                                {
                                                    if let Merge {
                                                        ref builder,
                                                        ref value,
                                                    } = body.kind
                                                    {
                                                        if let Ident(ref n) = builder.kind {
                                                            if *n == args[0].name {
                                                                // Save the arguments and expressions for the function so
                                                                // they can be used for fusion later.
                                                                lambdas.push((
                                                                    args.clone(),
                                                                    value.clone(),
                                                                ));
                                                                return true;
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    // The pattern doesn't match for some Iter -- abort the transform.
                    false
                });

                if iters_same {
                    // All Iters are over the same range and same vector, with a pattern we can
                    // transform. Produce the new expression by zipping the functions of each
                    // nested for into a single merge into a struct.

                    // Zip the expressions to create an appender whose merge (value) type is a struct.
                    let merge_type = Struct(
                        lambdas
                            .iter()
                            .map(|ref e| e.1.ty.clone())
                            .collect::<Vec<_>>(),
                    );
                    // TODO(Deepak): Fix this to something meaningful.
                    let builder_type =
                        Builder(Appender(Box::new(merge_type.clone())), Annotations::new());
                    // The element type remains unchanged.
                    let func_elem_type = lambdas[0].0[2].ty.clone();

                    // Parameters for the new fused function. Symbols names are generated using symbol
                    // names for the builder and element from an existing function.
                    let new_params = vec![
                        Parameter {
                            ty: builder_type.clone(),
                            name: sym_gen.new_symbol(&lambdas[0].0[0].name.name()),
                        },
                        Parameter {
                            ty: Scalar(ScalarKind::I64),
                            name: sym_gen.new_symbol(&lambdas[0].0[1].name.name()),
                        },
                        Parameter {
                            ty: func_elem_type.clone(),
                            name: sym_gen.new_symbol(&lambdas[0].0[2].name.name()),
                        },
                    ];

                    // Generate Ident expressions for the new symbols and substitute them in the
                    // functions' merge expressions.
                    let new_bldr_expr = Expr {
                        ty: builder_type.clone(),
                        kind: Ident(new_params[0].name.clone()),
                        annotations: Annotations::new(),
                    };
                    let new_index_expr = Expr {
                        ty: Scalar(ScalarKind::I64),
                        kind: Ident(new_params[1].name.clone()),
                        annotations: Annotations::new(),
                    };
                    let new_elem_expr = Expr {
                        ty: func_elem_type,
                        kind: Ident(new_params[2].name.clone()),
                        annotations: Annotations::new(),
                    };
                    for &mut (ref mut args, ref mut expr) in lambdas.iter_mut() {
                        expr.substitute(&args[0].name, &new_bldr_expr);
                        expr.substitute(&args[1].name, &new_index_expr);
                        expr.substitute(&args[2].name, &new_elem_expr);
                    }

                    // Build up the new expression. The new expression merges structs into an
                    // appender, where each struct field is an expression which was merged into an
                    // appender in one of the original functions. For example, if there were two
                    // zipped fors in the original expression with lambdas |b1,e1| merge(b1,
                    // e1+1) and |b2,e2| merge(b2, e2+2), the new expression would be merge(b,
                    // {e+1,e+2}) into a new builder b of type appender[{i32,i32}]. e1, e2, and e
                    // refer to the same element in the expressions above since we check to ensure
                    // each zipped for is over the same input data.
                    let new_merge_expr = Expr {
                        ty: builder_type.clone(),
                        kind: Merge {
                            builder: Box::new(new_bldr_expr),
                            value: Box::new(Expr {
                                ty: merge_type.clone(),
                                kind: MakeStruct {
                                    elems: lambdas
                                        .iter()
                                        .map(|ref lambda| *lambda.1.clone())
                                        .collect::<Vec<_>>(),
                                },
                                annotations: Annotations::new(),
                            }),
                        },
                        annotations: Annotations::new(),
                    };
                    let new_func = Expr {
                        ty: Function(
                            new_params
                                .iter()
                                .map(|ref p| p.ty.clone())
                                .collect::<Vec<_>>(),
                            Box::new(builder_type.clone()),
                        ),
                        kind: Lambda {
                            params: new_params,
                            body: Box::new(new_merge_expr),
                        },
                        annotations: Annotations::new(),
                    };
                    let new_iter_expr = Expr {
                        ty: Vector(Box::new(merge_type)),
                        kind: Res {
                            builder: Box::new(Expr {
                                ty: builder_type.clone(),
                                kind: For {
                                    iters: common_data.unwrap(),
                                    builder: Box::new(Expr {
                                        ty: builder_type,
                                        kind: NewBuilder(None),
                                        annotations: Annotations::new(),
                                    }),
                                    func: Box::new(new_func),
                                },
                                annotations: Annotations::new(),
                            }),
                        },
                        annotations: Annotations::new(),
                    };

                    // TODO(shoumik): Any way to avoid the clones here?
                    return Some(Expr {
                        ty: expr.ty.clone(),
                        kind: For {
                            iters: vec![Iter {
                                data: Box::new(new_iter_expr),
                                start: all_iters[0].start.clone(),
                                end: all_iters[0].end.clone(),
                                stride: all_iters[0].stride.clone(),
                                kind: all_iters[0].kind.clone(),
                                shape: all_iters[0].shape.clone(),
                                strides: all_iters[0].strides.clone(),
                            }],
                            builder: outer_bldr.clone(),
                            func: outer_func.clone(),
                        },
                        annotations: Annotations::new(),
                    });
                }
            }
        }
        None
    });
}

/// Fuses loops where one for loop takes another as it's input, which prevents intermediate results
/// from being materialized.
pub fn fuse_loops_vertical(expr: &mut Expr) {
    expr.transform_and_continue_res(&mut |ref mut expr| {
        let mut sym_gen = SymbolGenerator::from_expression(expr);
        if let For {
            iters: ref all_iters,
            builder: ref bldr1,
            func: ref nested,
        } = expr.kind
        {
            if all_iters.len() == 1 {
                let iter1 = &all_iters[0];
                if let Res {
                    builder: ref res_bldr,
                } = iter1.data.kind
                {
                    if let For {
                        iters: ref iters2,
                        builder: ref bldr2,
                        func: ref lambda,
                    } = res_bldr.kind
                    {
                        if iters2.iter().all(|ref i| consumes_all(&i)) {
                            if let NewBuilder(_) = bldr2.kind {
                                if let Builder(ref kind, _) = bldr2.ty {
                                    if let Appender(_) = *kind {
                                        let mut e = Expr::new_for(
                                            iters2.clone(),
                                            *bldr1.clone(),
                                            replace_builder(lambda, nested, &mut sym_gen)?,
                                        )?;
                                        e.annotations = expr.annotations.clone();
                                        return Ok((Some(e), true));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok((None, true))
    });
}

/// Given an iterator, returns whether the iterator consumes every element of its data vector.
fn consumes_all(iter: &Iter) -> bool {
    if let Iter {
        start: None,
        end: None,
        stride: None,
        ..
    } = *iter
    {
        return true;
    } else if let Iter {
        ref data,
        start: Some(ref start),
        end: Some(ref end),
        stride: Some(ref stride),
        ..
    } = *iter
    {
        // Checks if the stride is 1 and an entire vector represented by a symbol is consumed.
        if let (
            &Literal(I64Literal(1)),
            &Literal(I64Literal(0)),
            &Ident(ref name),
            &Length { data: ref v },
        ) = (&stride.kind, &start.kind, &data.kind, &end.kind)
        {
            if let Ident(ref vsym) = v.kind {
                return vsym == name;
            }
        }
        // Checks if an entire vector literal is consumed.
        if let (&Literal(I64Literal(1)), &Literal(I64Literal(0)), &MakeVector { ref elems }) =
            (&stride.kind, &start.kind, &data.kind)
        {
            let num_elems = elems.len() as i64;
            if let Literal(I64Literal(x)) = end.kind {
                return num_elems == x;
            }
        }
    }
    false
}

/// Given a lambda which takes a builder and an argument, returns a new function which takes a new
/// builder type and calls nested on the values it would've merged into its old builder. This
/// allows us to "compose" merge functions and avoid creating intermediate results.
fn replace_builder(
    lambda: &Expr,
    nested: &Expr,
    sym_gen: &mut SymbolGenerator,
) -> WeldResult<Expr> {
    // Tests whether an identifier and symbol refer to the same value by
    // comparing the symbols.
    fn same_iden(a: &ExprKind, b: &Symbol) -> bool {
        if let Ident(ref symbol) = *a {
            symbol == b
        } else {
            false
        }
    }

    if let Lambda {
        params: ref args,
        ref body,
    } = lambda.kind
    {
        if let Lambda {
            params: ref nested_args,
            ..
        } = nested.kind
        {
            let mut new_body = *body.clone();
            let old_bldr = &args[0];
            let old_index = &args[1];
            let old_arg = &args[2];
            let new_bldr_sym = sym_gen.new_symbol(&old_bldr.name.name());
            let new_index_sym = sym_gen.new_symbol(&old_index.name.name());
            let new_bldr = Expr::new_ident(new_bldr_sym.clone(), nested_args[0].ty.clone())?;
            let new_index = Expr::new_ident(new_index_sym.clone(), nested_args[1].ty.clone())?;

            // Fix expressions to use the new builder.
            new_body.transform_and_continue_res(&mut |ref mut e| match e.kind {
                Merge {
                    ref builder,
                    ref value,
                } if same_iden(&(*builder).kind, &old_bldr.name) => {
                    let params: Vec<Expr> =
                        vec![new_bldr.clone(), new_index.clone(), *value.clone()];
                    let mut expr = Expr::new_apply(nested.clone(), params)?;
                    inline_apply(&mut expr);
                    Ok((Some(expr), true))
                }
                For {
                    iters: ref data,
                    builder: ref bldr,
                    ref func,
                } if same_iden(&(*bldr).kind, &old_bldr.name) => {
                    let expr = Expr::new_for(
                        data.clone(),
                        new_bldr.clone(),
                        replace_builder(func, nested, sym_gen)?,
                    )?;
                    Ok((Some(expr), false))
                }
                Ident(ref mut symbol) if *symbol == old_bldr.name => {
                    Ok((Some(new_bldr.clone()), false))
                }
                Ident(ref mut symbol) if *symbol == old_index.name => {
                    Ok((Some(new_index.clone()), false))
                }
                _ => Ok((None, true)),
            });

            // Fix types to make sure the return type propagates through all subexpressions.
            match_types(&new_bldr.ty, &mut new_body);

            let new_params = vec![
                Parameter {
                    ty: new_bldr.ty,
                    name: new_bldr_sym,
                },
                Parameter {
                    ty: Scalar(ScalarKind::I64),
                    name: new_index_sym,
                },
                Parameter {
                    ty: old_arg.ty.clone(),
                    name: old_arg.name.clone(),
                },
            ];
            return Expr::new_lambda(new_params, new_body);
        }
    }
    return compile_err!("Inconsistency in replace_builder");
}

/// Given a root type, forces each expression to return that type. TODO For now, only supporting
/// expressions which can be builders. We might want to factor this out to be somewhere else.
fn match_types(root_ty: &Type, expr: &mut Expr) {
    expr.ty = root_ty.clone();
    match expr.kind {
        If {
            ref mut on_true,
            ref mut on_false,
            ..
        } => {
            match_types(root_ty, on_true);
            match_types(root_ty, on_false);
        }
        Select {
            ref mut on_true,
            ref mut on_false,
            ..
        } => {
            match_types(root_ty, on_true);
            match_types(root_ty, on_false);
        }
        Let { ref mut body, .. } => {
            match_types(root_ty, body);
        }
        _ => {}
    };
}

#[test]
fn simple_horizontal_loop_fusion() {
    // Two loops.
    let mut e1 = typed_expression(
        "for(zip(
            result(for([1,2,3], appender, |b,i,e| merge(b, e+1))),
            result(for([1,2,3], appender,|b2,i2,e2| merge(b2,e2+1)))
        ), appender, |b,i,e| merge(b, e.$0+1))",
    );
    fuse_loops_horizontal(&mut e1);
    let e2 = typed_expression(
        "for(result(for([1,2,3], appender, |b,i,e| merge(b, {e+1,e+1}))), \
         appender, |b,i,e| merge(b, e.$0+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Three loops.
    let mut e1 = typed_expression(
        "for(zip(
            result(for([1,2,3], appender, |b,i,e| merge(b, e+1))),
            result(for([1,2,3], appender,|b2,i2,e2| merge(b2,e2+2))),
            result(for([1,2,3], appender,|b3,i3,e3| merge(b3,e3+3)))
        ), appender, |b,i,e| merge(b, e.$0+1))",
    );
    fuse_loops_horizontal(&mut e1);
    let e2 = typed_expression(
        "for(result(for([1,2,3], appender, |b,i,e| merge(b, \
         {e+1,e+2,e+3}))), appender, |b,i,e| merge(b, e.$0+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Iters in inner loop
    let mut e1 = typed_expression(
        "for(zip(
            result(for(iter([1,2,3], 0L, 2L, 1L), appender, |b,i,e| merge(b, e+1))),
            result(for(iter([1,2,3], 0L, 2L, 1L), appender, |b,i,e| merge(b, e+2)))
        ), appender, |b,i,e| merge(b, e.$0+1))",
    );
    fuse_loops_horizontal(&mut e1);
    let e2 = typed_expression(
        "for(result(for(iter([1,2,3], 0L, 2L, 1L), appender, |b,i,e| \
         merge(b, {e+1,e+2}))), appender, |b,i,e| merge(b, e.$0+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Iters in outer loop.
    let mut e1 = typed_expression(
        "for(zip(
            iter(result(for([1,2,3], appender, |b,i,e| merge(b, e+1))), 0L, 2L, 1L),
            iter(result(for([1,2,3], appender, |b,i,e| merge(b, e+2))), 0L, 2L, 1L)
        ), appender, |b,i,e| merge(b, e.$0+1))",
    );
    fuse_loops_horizontal(&mut e1);
    let e2 = typed_expression(
        "for(iter(result(for([1,2,3], appender, |b,i,e| merge(b, \
         {e+1,e+2}))), 0L, 2L, 1L), appender, |b,i,e| merge(b, e.$0+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Two loops with different vectors; should fail.
    let mut e1 = typed_expression(
        "for(zip(
            result(for([1,2,3], appender, |b,i,e| merge(b, e+1))),
            result(for([1,2,4], appender,|b2,i2,e2| merge(b2,e2+1)))
        ), appender, |b,i,e| merge(b, e.$0+1))",
    );
    let e2 = e1.clone();
    fuse_loops_horizontal(&mut e1);
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());
}

#[test]
fn simple_vertical_loop_fusion() {
    // Two loops.
    let mut e1 = typed_expression(
        "for(result(for([1,2,3], appender, |b,i,e| merge(b,e+2))), \
         appender, |b,h,f| merge(b, f+1))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression("for([1,2,3], appender, |b,i,e| merge(b, (e+2)+1))");
    println!("{}", print_expr_without_indent(&e1));
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Three loops.
    let mut e1 = typed_expression(
        "for(result(for(result(for([1,2,3], appender, |b,i,e| \
         merge(b,e+3))), appender, |b,i,e| merge(b,e+2))), appender, \
         |b,h,f| merge(b, f+1))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression("for([1,2,3], appender, |b,i,e| merge(b, (((e+3)+2)+1)))");
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Merges in other positions, replace builder identifiers.
    let mut e1 = typed_expression(
        "for(result(for([1,2,3], appender, |b,i,e| if(e>5, \
         merge(b,e+2), b))), appender, |b,h,f| merge(b, f+1))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression("for([1,2,3], appender, |b,i,e| if(e>5, merge(b, (e+2)+1), b))");
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Make sure correct builder is chosen.
    let mut e1 = typed_expression(
        "for(result(for([1,2,3], appender[i32], |b,i,e| \
         merge(b,e+2))), appender[f64], |b,h,f| merge(b, 1.0))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression("for([1,2,3], appender[f64], |b,i,e| merge(b, 1.0))");
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Multiple inner loops.
    let mut e1 = typed_expression(
        "for(result(for(zip([1,2,3],[4,5,6]), appender, |b,i,e| \
         merge(b,e.$0+2))), appender, |b,h,f| merge(b, f+1))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression("for(zip([1,2,3],[4,5,6]), appender, |b,i,e| merge(b, (e.$0+2)+1))");
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Iter where inner data consumed fully.
    let mut e1 = typed_expression(
        "let a = [1,2,3]; for(result(for(iter(a, 0L, len(a), 1L), \
         appender, |b,i,e| merge(b,e+2))), appender, |b,h,f| merge(b, \
         f+1))",
    );
    fuse_loops_vertical(&mut e1);
    let e2 = typed_expression(
        "let a = [1,2,3]; for(iter(a,0L,len(a),1L), appender, |b,i,e| \
         merge(b, (e+2)+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());

    // Inner data not consumed fully.
    let mut e1 = typed_expression(
        "for(result(for(iter([1,2,3], 0L, 1L, 1L), appender, |b,i,e| \
         merge(b,e+2))), appender, |b,h,f| merge(b, f+1))",
    );
    fuse_loops_vertical(&mut e1);
    // Loop fusion should fail.
    let e2 = typed_expression(
        "for(result(for(iter([1,2,3], 0L, 1L, 1L), appender, |b,i,e| \
         merge(b,e+2))), appender, |b,h,f| merge(b, f+1))",
    );
    assert!(e1.compare_ignoring_symbols(&e2).unwrap());
}