weld 0.4.0

Weld is a language and runtime for improving the performance of data-intensive applications.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
//! Various inlining transforms.
//!
//! These transforms take a set of nested expressions and fuse them into a single one.

use fnv;

use crate::ast::ExprKind::*;
use crate::ast::*;

use fnv::FnvHashMap;

use std::mem;

#[cfg(test)]
use crate::tests::*;

/// Replaces `GetField(MakeStruct(...), i)` with the `i`-th element of the struct literal.
///
/// Such patterns commonly appear after loop fusion, when the fused loops zip several column
/// vectors together and immediately project fields back out of the zipped struct.
pub fn inline_get_field(expr: &mut Expr) {
    expr.transform(&mut |node: &mut Expr| match node.kind {
        GetField {
            expr: ref target,
            index,
        } => match target.kind {
            MakeStruct { ref elems } => Some(elems[index as usize].clone()),
            _ => None,
        },
        _ => None,
    });
}

/// Inlines `Zip` expressions as collections of iters.
///
/// Using `Zip` outside of a `For` loop is currently unsupported behavior. This transform
/// handles the simple case produced by macros such as `map` and `filter`: a `For` loop whose
/// single iter is over a `Zip` of vectors is rewritten into a `For` loop with one iter per
/// zipped vector.
///
/// Every new iter inherits the original iter's kind and its `start`/`end`/`stride` (and
/// `shape`/`strides`) expressions. Iterating `zip(a, b)` over some range therefore becomes
/// iterating `a` and `b` over that same range in lockstep, rather than silently widening the
/// range to the full vectors. The loop's annotations are preserved.
///
/// TODO(shoumik): Perhaps Zip should just be a macro? Then macros need to be ordered.
pub fn inline_zips(expr: &mut Expr) {
    expr.transform(&mut |ref mut e| {
        if let For {
            ref iters,
            ref builder,
            ref func,
        } = e.kind
        {
            if iters.len() == 1 {
                let zipped = &iters[0];
                if let Zip { ref vectors } = zipped.data.kind {
                    let new_iters = vectors
                        .iter()
                        .map(|v| Iter {
                            data: Box::new(v.clone()),
                            // Carry the bounds over so the loop covers the same elements.
                            start: zipped.start.clone(),
                            end: zipped.end.clone(),
                            stride: zipped.stride.clone(),
                            kind: zipped.kind.clone(),
                            shape: zipped.shape.clone(),
                            strides: zipped.strides.clone(),
                        })
                        .collect::<Vec<_>>();
                    return Some(Expr {
                        ty: e.ty.clone(),
                        kind: For {
                            iters: new_iters,
                            builder: builder.clone(),
                            func: func.clone(),
                        },
                        // Same loop, so the same loop-level annotations still apply.
                        annotations: e.annotations.clone(),
                    });
                }
            }
        }
        None
    });
}

/// Inlines Apply nodes whose argument is a Lambda expression. These often arise during macro
/// expansion but it's simpler to inline them before doing type inference.
/// Unlike many of the other transformations, we make this one independent of types so that
/// we can apply it before type inference.
///
/// Caveats:
/// - Functions that reuse a parameter twice have its expansion appear twice, instead of assigning
///   it to a temporary as would happen with function application.
/// - Does not complete inlining if some of the functions take functions as arguments (in that
///   case, the expressions after inlining may lead to more inlining).
pub fn inline_apply(expr: &mut Expr) {
    expr.transform(&mut |ref mut expr| {
        if let Apply {
            ref func,
            params: ref args,
        } = expr.kind
        {
            if let Lambda {
                ref params,
                ref body,
            } = func.kind
            {
                let mut new = *body.clone();
                for (param, arg) in params.iter().zip(args) {
                    new.substitute(&param.name, &arg);
                }
                return Some(new);
            }
        }
        None
    });
}

pub fn inline_let(expr: &mut Expr) {
    expr.uniquify().unwrap();
    let usages = &mut FnvHashMap::default();
    count_symbols(expr, usages);
    trace!("Symbol count: {:?}", usages);
    inline_let_helper(expr, usages)
}

/// Per-symbol bookkeeping used while inlining `Let` bindings.
#[derive(Debug, Default)]
struct SymbolTracker {
    /// Weighted number of uses: 1 for a use outside any loop/function body entered after the
    /// definition, 3 for a use inside one.
    count: i32,
    /// Number of loop/function bodies entered since the symbol was defined, while counting.
    loop_nest: i32,
    /// The bound value, moved out of its `Let` once the binding has been chosen for inlining.
    value: Option<Box<Expr>>,
}

/// Count the occurances of each symbol defined by a `Let` statement.
fn count_symbols(expr: &Expr, usage: &mut FnvHashMap<Symbol, SymbolTracker>) {
    match expr.kind {
        For { ref func, .. }
        | Iterate {
            update_func: ref func,
            ..
        }
        | Sort {
            cmpfunc: ref func, ..
        }
        | Apply { ref func, .. } => {
            // Mark all symbols seen so far as "in a loop"
            for value in usage.values_mut() {
                value.loop_nest += 1;
            }

            count_symbols(func, usage);

            for value in usage.values_mut() {
                value.loop_nest -= 1;
            }
        }
        Let { ref name, .. } => {
            debug_assert!(!usage.contains_key(name));
            let _ = usage.insert(name.clone(), SymbolTracker::default());
        }
        Ident(ref symbol) => {
            if let Some(ref mut tracker) = usage.get_mut(symbol) {
                if tracker.loop_nest == 0 {
                    tracker.count += 1;
                } else {
                    // Used in a loop!
                    tracker.count += 3;
                }
            }
        }
        _ => (),
    };

    // Recurse into children - skip functions that expressions may call repeatedly. We handled
    // those already.
    for child in expr.children() {
        match child.kind {
            Lambda { .. } => (),
            _ => count_symbols(child, usage),
        }
    }
}

/// Performs the rewriting step of `inline_let`, using the per-symbol counts that
/// `count_symbols` collected into `usages`.
///
/// A `Let` whose symbol has a count of at most 1 is replaced by its body, and the bound value
/// is stashed in the symbol's tracker. When the `Ident` referring to that symbol is reached
/// later in the walk, that node is replaced by the stashed value. A binding with a count of 0
/// is removed and its value is never spliced anywhere.
///
/// Whatever replaces the current node (a `Let`'s body or an inlined value) is processed again
/// recursively, so bindings nested inside it are inlined too. Otherwise the walk continues
/// into the node's children.
fn inline_let_helper(expr: &mut Expr, usages: &mut FnvHashMap<Symbol, SymbolTracker>) {
    // The expression that should replace `expr`, if any.
    let mut taken_body = None;
    match expr.kind {
        Let {
            ref mut name,
            ref mut value,
            ref mut body,
        } => {
            // Check whether the symbol is used one or fewer times.
            if let Some(tracker) = usages.get_mut(name) {
                if tracker.count <= 1 {
                    // Move the body out so it replaces this `Let`, and stash the value for the
                    // use site. NOTE(review): `take` presumably leaves a placeholder behind,
                    // which the `is_placeholder` check below relies on — confirm.
                    taken_body = Some(body.take());
                    tracker.value = Some(value.take());
                }
            }
        }
        Ident(ref name) => {
            // Check if the identifier maps to one that should be inlined.
            if let Some(tracker) = usages.get_mut(name) {
                if tracker.count <= 1 {
                    // Value should have been set by a preceding Let.
                    debug_assert!(tracker.value.is_some());
                    // Value should only be swapped once.
                    debug_assert!(!tracker.value.as_ref().unwrap().is_placeholder());
                    // Move the stashed value out of the tracker (leaving `None` there) so it
                    // replaces this identifier below.
                    mem::swap(&mut taken_body, &mut tracker.value);
                }
            }
        }
        _ => (),
    }

    // Replace this node with the taken expression (the old node ends up in `val` and is
    // dropped), then keep inlining inside the replacement.
    if let Some(mut val) = taken_body {
        mem::swap(expr, val.as_mut());
        inline_let_helper(expr, usages);
    } else {
        for child in expr.children_mut() {
            inline_let_helper(child, usages);
        }
    }
}

/// Folds negations of numeric literals into negated literal values.
///
/// Integer literals are negated with two's-complement wrapping, so negating the minimum value
/// of a type yields that same value instead of overflowing, which would panic in debug
/// builds. Float literals are stored as raw bits and are negated through their float value.
/// Negations of any other literal kind are left unchanged.
pub fn inline_negate(expr: &mut Expr) {
    use crate::ast::LiteralKind::*;
    expr.transform(&mut |ref mut expr| {
        if let Negate(ref child_expr) = expr.kind {
            if let Literal(ref literal_kind) = child_expr.kind {
                let negated = match *literal_kind {
                    I8Literal(a) => I8Literal(a.wrapping_neg()),
                    I16Literal(a) => I16Literal(a.wrapping_neg()),
                    I32Literal(a) => I32Literal(a.wrapping_neg()),
                    I64Literal(a) => I64Literal(a.wrapping_neg()),
                    F32Literal(a) => F32Literal((-f32::from_bits(a)).to_bits()),
                    F64Literal(a) => F64Literal((-f64::from_bits(a)).to_bits()),
                    _ => return None,
                };
                return Some(Expr::new_literal(negated).unwrap());
            }
        }
        None
    });
}

/// Inline casts.
///
/// This changes casts of literal values to be literal values of the casted type, for the
/// supported conversions (`i32 -> f64`, `i32 -> i64`, `i64 -> f64`). It additionally removes
/// "self casts" (e.g., `i64(x: i64)` becomes `x`). Self casts of literals are removed too,
/// even when no literal conversion applies (e.g. `i32(5)` becomes `5`).
pub fn inline_cast(expr: &mut Expr) {
    use crate::ast::LiteralKind::*;
    use crate::ast::ScalarKind::*;
    use crate::ast::Type::Scalar;
    expr.transform(&mut |ref mut expr| {
        if let Cast {
            kind: ref scalar_kind,
            ref child_expr,
        } = expr.kind
        {
            // Fold casts of literals into a literal of the target type.
            if let Literal(ref literal_kind) = child_expr.kind {
                let folded = match (scalar_kind, literal_kind) {
                    (&F64, &I32Literal(a)) => Some(F64Literal(f64::from(a).to_bits())),
                    (&I64, &I32Literal(a)) => Some(I64Literal(i64::from(a))),
                    (&F64, &I64Literal(a)) => Some(F64Literal((a as f64).to_bits())),
                    _ => None,
                };
                if let Some(literal) = folded {
                    return Some(Expr::new_literal(literal).unwrap());
                }
                // No folding rule applies: fall through so a self cast is still removed.
            }
            // Remove casts whose child already has the target scalar type.
            if let Scalar(ref kind) = child_expr.ty {
                if kind == scalar_kind {
                    // TODO: move the child out of the cast instead of cloning it.
                    return Some(*child_expr.clone());
                }
            }
        }
        None
    });
}

/// Returns the field index `i` when `expr` is a field access `sym.$i`, i.e.
/// `GetField(Ident(sym), i)`, and `None` for any other expression.
fn getfield_on_symbol(expr: &Expr, sym: &Symbol) -> Option<u32> {
    match expr.kind {
        GetField {
            expr: ref target,
            index,
        } => match target.kind {
            Ident(ref name) if name == sym => Some(index),
            _ => None,
        },
        _ => None,
    }
}

/// Simplifies branches whose condition is `<expr> == false` (or `false == <expr>`) into
/// branches over `<expr>` alone.
///
/// To keep the meaning of the branch, its true and false arms are swapped. The rewrite is done
/// in place, bottom-up, after uniquifying the expression.
pub fn simplify_branch_conditions(expr: &mut Expr) {
    use crate::ast::LiteralKind::BoolLiteral;
    expr.uniquify().unwrap();
    expr.transform_up(&mut |node: &mut Expr| {
        if let If {
            ref mut cond,
            ref mut on_true,
            ref mut on_false,
        } = node.kind
        {
            // The operand compared against `false`, if the condition has that shape.
            let operand = match cond.kind {
                BinOp {
                    kind: BinOpKind::Equal,
                    ref mut left,
                    ref mut right,
                } => {
                    if let Literal(BoolLiteral(false)) = left.kind {
                        Some(right.take())
                    } else if let Literal(BoolLiteral(false)) = right.kind {
                        Some(left.take())
                    } else {
                        None
                    }
                }
                _ => None,
            };

            if let Some(mut new_cond) = operand {
                mem::swap(cond, &mut new_cond);
                mem::swap(on_true, on_false);
            }
        }
        // The node was edited in place, so there is never a replacement to return.
        None
    });
}

/// Changes struct definitions assigned to a name and only used in `GetField` operations
/// to `Let` definitions over the struct elements themselves.
///
/// This transformation is similar to a simple SROA (scalar replacement of aggregates) transform
/// in other compilers. A struct binding that is used in any other way (for example, passed
/// whole to another expression) is left untouched.
///
/// ## Example
///
/// ```text
/// let a = {1, 2, 3, 4};
/// a.$0 + a.$1 + a.$2
/// ```
///
/// Becomes (exact symbol ids depend on the symbol generator):
///
/// ```text
/// let us = 1;
/// let us#1 = 2;
/// let us#2 = 3;
/// let us#3 = 4;
/// us + us#1 + us#2
/// ```
pub fn unroll_structs(expr: &mut Expr) {
    use crate::util::SymbolGenerator;

    expr.uniquify().unwrap();
    let mut sym_gen = SymbolGenerator::from_expression(expr);
    expr.transform_up(&mut |ref mut expr| {
        if let Let {
            ref name,
            ref value,
            ref body,
        } = expr.kind
        {
            if let MakeStruct { ref elems } = value.kind {
                // The rewrite is only valid if every use of `name` in the body is the target of
                // a `GetField`. Count both kinds of use and bail out if they differ.
                let mut total_uses: usize = 0;
                let mut getfield_uses: usize = 0;
                body.traverse(&mut |ref e| {
                    if getfield_on_symbol(e, name).is_some() {
                        getfield_uses += 1;
                    }
                    if let Ident(ref ident_name) = e.kind {
                        if ident_name == name {
                            total_uses += 1;
                        }
                    }
                });
                if total_uses != getfield_uses {
                    return None;
                }

                // One fresh symbol per struct element.
                let symbols: Vec<_> = elems.iter().map(|_| sym_gen.new_symbol("us")).collect();

                // Rewrite each `name.$i` in a copy of the body into the symbol for element `i`.
                let mut new_body = body.as_ref().clone();
                new_body.transform(&mut |ref mut expr2| {
                    if let Some(index) = getfield_on_symbol(expr2, name) {
                        let sym = symbols[index as usize].clone();
                        return Some(Expr::new_ident(sym, expr2.ty.clone()).unwrap());
                    }
                    None
                });

                // Wrap the body innermost-first, so element 0 becomes the outermost `Let` and
                // the elements are still bound in their original order.
                let unrolled = symbols
                    .into_iter()
                    .zip(elems.iter())
                    .rev()
                    .fold(new_body, |inner, (sym, elem)| {
                        Expr::new_let(sym, elem.clone(), inner).unwrap()
                    });
                return Some(unrolled);
            }
        }
        None
    });
}

#[test]
fn inline_lets() {
    // Types `input`, runs `inline_let` on it, and checks the result against the typed
    // `expected` program, ignoring symbol names. When `echo` is set, both programs are
    // printed before the comparison.
    let check = |input: &str, expected: &str, echo: bool| {
        let mut actual = typed_expression(input);
        inline_let(&mut actual);
        let wanted = typed_expression(expected);
        if echo {
            println!("{}, {}", actual.pretty_print(), wanted.pretty_print());
        }
        assert!(actual.compare_ignoring_symbols(&wanted).unwrap());
    };

    // A symbol used exactly once is inlined.
    check("let a = 1; a + 2", "1 + 2", false);

    // A symbol used more than once is kept.
    check("let a = 1; a + a + 2", "let a = 1; a + a + 2", false);

    // A symbol used inside a loop body is kept.
    check(
        "let a = 1L; for([1L,2L,3L], appender, |b,i,e| merge(b, e + a \
         + 2L))",
        "let a = 1L; for([1L,2L,3L], appender, |b,i,e| merge(b, e + a + \
         2L))",
        false,
    );

    // A chain of single-use bindings is fully inlined.
    check("let a = 1; let b = 2; let c = 3; a + b + c", "1 + 2 + 3", true);

    // A binding defined and used once inside a loop body is inlined, as is a single-use
    // binding outside the loop.
    check(
        "|input: vec[i32]|
        let b = 1;
        result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a))) + b",
        "|input: vec[i32]|
        result(for(input, merger[i32,+], |b,i,e| merge(b, e + 1))) + 1",
        true,
    );

    // A binding used twice inside a loop body is kept.
    check(
        "|input: vec[i32]|
        let b = 1;
        result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a + a))) + b",
        "|input: vec[i32]|
        result(for(input, merger[i32,+], |b,i,e| let a = 1; merge(b, e + a + a))) + 1",
        true,
    );
}