Skip to main content

lutra_compiler/
bytecoding.rs

1use std::collections::HashMap;
2
3use indexmap::IndexSet;
4use lutra_bin::Encode;
5use lutra_bin::br::*;
6use lutra_bin::bytes::BufMut;
7use lutra_bin::ir;
8
9pub fn compile_program(value: ir::Program) -> Program {
10    let mut b = ByteCoder {
11        externals: Default::default(),
12        include_defs: false,
13
14        defs: &value.defs,
15        def_map: value.defs.iter().map(|def| (&def.name, &def.ty)).collect(),
16    };
17
18    Program {
19        main: b.compile_expr(value.main),
20        externals: b.externals.into_iter().collect(),
21        defs: if b.include_defs { value.defs } else { vec![] },
22    }
23}
24
25struct ByteCoder<'t> {
26    externals: IndexSet<ExternalSymbol>,
27    defs: &'t [ir::TyDef],
28    def_map: HashMap<&'t ir::Path, &'t ir::Ty>,
29
30    // Some externals need defs (read_parquet), but most of the time, skip them.
31    include_defs: bool,
32}
33
34impl<'t> ByteCoder<'t> {
35    fn get_ty_mat<'a: 't>(&self, ty: &'a ir::Ty) -> &'t ir::Ty {
36        match &ty.kind {
37            TyKind::Ident(path) => self.def_map.get(path).unwrap(),
38            _ => ty,
39        }
40    }
41
42    fn compile_expr(&mut self, expr: ir::Expr) -> Expr {
43        let kind = match expr.kind {
44            ir::ExprKind::Pointer(v) => ExprKind::Pointer(self.compile_pointer(v, &expr.ty)),
45            ir::ExprKind::Literal(v) => ExprKind::Literal(self.compile_literal(v)),
46            ir::ExprKind::Call(v) => ExprKind::Call(Box::new(self.compile_call(*v))),
47            ir::ExprKind::Function(v) => ExprKind::Function(Box::new(self.compile_function(*v))),
48            ir::ExprKind::Tuple(v) => ExprKind::Tuple(Box::new(self.compile_tuple(v))),
49            ir::ExprKind::Array(v) => ExprKind::Array(Box::new(self.compile_array(expr.ty, v))),
50            ir::ExprKind::EnumVariant(v) => {
51                ExprKind::EnumVariant(Box::new(self.compile_enum_variant(expr.ty, *v)))
52            }
53            ir::ExprKind::EnumEq(v) => ExprKind::EnumEq(Box::new(self.compile_enum_eq(*v))),
54            ir::ExprKind::EnumUnwrap(v) => return self.compile_enum_unwrap(*v),
55            ir::ExprKind::TupleLookup(v) => return self.compile_tuple_lookup(*v),
56            ir::ExprKind::Binding(v) => ExprKind::Binding(Box::new(self.compile_binding(*v))),
57            ir::ExprKind::Switch(v) => ExprKind::Switch(self.compile_switch(v)),
58        };
59
60        Expr { kind }
61    }
62
63    fn compile_pointer(&mut self, ptr: ir::Pointer, ty: &ir::Ty) -> Sid {
64        match ptr {
65            ir::Pointer::External(e_ptr) => {
66                let ty = self.get_ty_mat(ty);
67                let e_symbol = self.compile_external_symbol(e_ptr.id, ty);
68                let (index, _) = self.externals.insert_full(e_symbol);
69
70                Sid(index as u32).with_tag(SidKind::External)
71            }
72            #[rustfmt::skip]
73            ir::Pointer::Binding(binding_id) => {
74                Sid(binding_id).with_tag(SidKind::Var)
75            },
76            ir::Pointer::Parameter(param_ptr) => {
77                let sid = param_ptr.function_id << 8 | param_ptr.param_position as u32;
78
79                Sid(sid).with_tag(SidKind::FunctionScope)
80            }
81        }
82    }
83
84    fn compile_literal(&mut self, value: ir::Literal) -> Vec<u8> {
85        match value {
86            ir::Literal::bool(v) => v.encode(),
87            ir::Literal::int8(v) => v.encode(),
88            ir::Literal::int16(v) => v.encode(),
89            ir::Literal::int32(v) => v.encode(),
90            ir::Literal::int64(v) => v.encode(),
91            ir::Literal::uint8(v) => v.encode(),
92            ir::Literal::uint16(v) => v.encode(),
93            ir::Literal::uint32(v) => v.encode(),
94            ir::Literal::uint64(v) => v.encode(),
95            ir::Literal::float32(v) => v.encode(),
96            ir::Literal::float64(v) => v.encode(),
97            ir::Literal::text(v) => v.encode(),
98        }
99    }
100
101    fn compile_call(&mut self, value: ir::Call) -> Call {
102        Call {
103            function: self.compile_expr(value.function),
104            args: value
105                .args
106                .into_iter()
107                .map(|x| self.compile_expr(x))
108                .collect(),
109        }
110    }
111
112    fn compile_function(&mut self, value: ir::Function) -> Function {
113        Function {
114            symbol_ns: Sid(value.id << 8).with_tag(SidKind::FunctionScope),
115            body: self.compile_expr(value.body),
116        }
117    }
118
119    fn compile_tuple(&mut self, fields: Vec<ir::TupleField>) -> Tuple {
120        let field_layouts = fields
121            .iter()
122            .flat_map(|f| {
123                if f.unpack {
124                    let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
125                        panic!();
126                    };
127                    fields.iter().map(|f| &f.ty).collect::<Vec<_>>()
128                } else {
129                    vec![&f.expr.ty]
130                }
131            })
132            .map(|ty| self.compile_ty_layout(ty.layout.clone().unwrap()))
133            .collect();
134
135        let fields = fields
136            .into_iter()
137            .map(|f| {
138                let unpack = if f.unpack {
139                    let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
140                        panic!();
141                    };
142                    fields.len() as u8
143                } else {
144                    0
145                };
146                let expr = self.compile_expr(f.expr);
147
148                TupleField { expr, unpack }
149            })
150            .collect();
151        Tuple {
152            fields,
153            field_layouts,
154        }
155    }
156
157    fn compile_array(&mut self, ty: ir::Ty, items: Vec<ir::Expr>) -> Array {
158        Array {
159            items: items.into_iter().map(|x| self.compile_expr(x)).collect(),
160            item_layout: self.compile_ty_layout(ty.kind.into_array().unwrap().layout.unwrap()),
161        }
162    }
163
164    fn compile_enum_variant(&mut self, ty: Ty, v: ir::EnumVariant) -> EnumVariant {
165        let ty_mat = self.get_ty_mat(&ty);
166        let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
167            panic!()
168        };
169        let ty_variant = ty_variants.get(v.tag as usize).unwrap();
170        let head_format = lutra_bin::layout::enum_head_format(ty_variants, &ty.variants_recursive);
171        let variant_format = lutra_bin::layout::enum_variant_format(&head_format, &ty_variant.ty);
172
173        EnumVariant {
174            tag: v.tag.to_le_bytes()[0..head_format.tag_bytes as usize].to_vec(),
175            inner_bytes: head_format.inner_bytes as u8,
176            has_ptr: head_format.has_ptr,
177            padding_bytes: variant_format.padding_bytes,
178            inner: self.compile_expr(v.inner),
179        }
180    }
181
182    fn compile_enum_eq(&mut self, v: ir::EnumEq) -> EnumEq {
183        let ty_mat = self.get_ty_mat(&v.subject.ty);
184        let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
185            panic!()
186        };
187        let head_format =
188            lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);
189
190        let tag = v.tag.to_le_bytes()[0..head_format.tag_bytes as usize].to_vec();
191        EnumEq {
192            tag,
193            expr: self.compile_expr(v.subject),
194        }
195    }
196
197    fn compile_enum_unwrap(&mut self, v: ir::EnumUnwrap) -> Expr {
198        let ty_mat = self.get_ty_mat(&v.subject.ty);
199        let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
200            panic!()
201        };
202
203        let head_format =
204            lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);
205
206        let mut expr = self.compile_expr(v.subject);
207
208        // offset tag
209        expr = Expr {
210            kind: ExprKind::Offset(Box::new(Offset {
211                base: expr,
212                offset: head_format.tag_bytes,
213            })),
214        };
215
216        // dereference pointer (if there is a pointer)
217        if head_format.has_ptr {
218            expr = Expr {
219                kind: ExprKind::Deref(Box::new(Deref { ptr: expr })),
220            };
221        }
222
223        expr
224    }
225
226    fn compile_tuple_lookup(&mut self, value: ir::TupleLookup) -> Expr {
227        let base_ty = self.get_ty_mat(&value.base.ty);
228        let offset = lutra_bin::layout::tuple_field_offset(base_ty, value.position);
229
230        let kind = ExprKind::Offset(Box::new(Offset {
231            base: self.compile_expr(value.base),
232            offset,
233        }));
234        Expr { kind }
235    }
236
237    fn compile_binding(&mut self, value: ir::Binding) -> Binding {
238        Binding {
239            symbol: Sid(value.id).with_tag(SidKind::Var),
240            expr: self.compile_expr(value.expr),
241            main: self.compile_expr(value.main),
242        }
243    }
244
245    fn compile_switch(&mut self, branches: Vec<ir::SwitchBranch>) -> Vec<SwitchBranch> {
246        branches
247            .into_iter()
248            .map(|b| SwitchBranch {
249                condition: self.compile_expr(b.condition),
250                value: self.compile_expr(b.value),
251            })
252            .collect()
253    }
254
255    fn compile_ty_layout(&self, value: ir::TyLayout) -> TyLayout {
256        TyLayout {
257            head_size: value.head_size,
258            body_ptrs: value.body_ptrs,
259        }
260    }
261
262    fn compile_external_symbol(&mut self, id: String, ty_mat: &ir::Ty) -> ExternalSymbol {
263        let layout_args: Vec<u32> = match id.as_str() {
264            "std::to_int8" | "std::to_int16" | "std::to_int32" | "std::to_int64"
265            | "std::to_uint8" | "std::to_uint16" | "std::to_uint32" | "std::to_uint64"
266            | "std::to_float32" | "std::to_float64" | "std::to_text" | "std::mul" | "std::div"
267            | "std::mod" | "std::add" | "std::sub" | "std::neg" | "std::cmp" | "std::eq"
268            | "std::lt" | "std::lte" | "std::sequence" | "std::math::abs" | "std::math::pow" => {
269                let param_ty = as_ty_of_param(ty_mat);
270                let primitive = param_ty.kind.as_primitive().unwrap();
271
272                vec![encode_prim(primitive)]
273            }
274
275            "std::fold" => {
276                let item_layout = as_layout_of_param_array(ty_mat);
277                vec![
278                    item_layout.head_size.div_ceil(8), // item_head_size
279                ]
280            }
281
282            "std::min"
283            | "std::max"
284            | "std::sum"
285            | "std::mean"
286            | "std::rolling_mean"
287            | "std::rank"
288            | "std::rank_dense"
289            | "std::rank_percentile"
290            | "std::cume_dist" => {
291                let param_ty = as_ty_of_param(ty_mat);
292                let item_ty = self.get_ty_mat(param_ty).kind.as_array().unwrap();
293
294                let item_layout = item_ty.layout.as_ref().unwrap();
295                let item_ty = item_ty.kind.as_primitive().unwrap();
296
297                vec![
298                    item_layout.head_size.div_ceil(8), // item_head_size
299                    encode_prim(item_ty),
300                ]
301            }
302
303            "std::index" => {
304                let item_layout = as_layout_of_param_array(ty_mat);
305
306                let ty_func = ty_mat.kind.as_function().unwrap();
307                let ty_out_variants = ty_func.body.kind.as_enum().unwrap();
308                let ty_out_format = lutra_bin::layout::enum_format(
309                    ty_out_variants,
310                    &ty_func.body.variants_recursive,
311                );
312                let ty_out_format = ty_out_format.encode();
313
314                let mut r = vec![
315                    item_layout.head_size.div_ceil(8), // item_head_size
316                ];
317
318                pack_bytes_to_u32(ty_out_format, &mut r);
319                r
320            }
321
322            "std::filter" | "std::slice" | "std::append" | "std::apply_until_empty" => {
323                let item_layout = as_layout_of_param_array(ty_mat);
324
325                let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
326                r.push(item_layout.head_size.div_ceil(8)); // item_head_size
327                r.extend(as_len_and_items(&item_layout.body_ptrs)); // item_body_ptrs
328                r
329            }
330            "std::sort" => {
331                let item_layout = as_layout_of_param_array(ty_mat);
332
333                let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
334                r.push(item_layout.head_size.div_ceil(8)); // item_head_size
335                r.extend(as_len_and_items(&item_layout.body_ptrs)); // item_body_ptrs
336
337                // ty of key
338                let ty_func = ty_mat.kind.as_function().unwrap();
339                let ty_key_extractor = self.get_ty_mat(&ty_func.params[1]);
340                let ty_key_extractor = ty_key_extractor.kind.as_function().unwrap();
341                let ty_key = self.get_ty_mat(&ty_key_extractor.body);
342                r.push(encode_prim(ty_key.kind.as_primitive().unwrap()));
343
344                r
345            }
346
347            "std::lag" | "std::lead" => {
348                let item_layout = as_layout_of_param_array(ty_mat);
349
350                let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
351                r.push(item_layout.head_size.div_ceil(8)); // item_head_size
352                r.extend(as_len_and_items(&item_layout.body_ptrs)); // item_body_ptrs
353
354                // also encode default value
355                let ty_func = ty_mat.kind.as_function().unwrap();
356                let ty_item = ty_func.body.kind.as_array().unwrap();
357                let default_val = self.construct_default_for_ty(ty_item);
358                let default_val = default_val.encode(ty_item, self.defs).unwrap();
359                pack_bytes_to_u32(default_val, &mut r);
360
361                r
362            }
363
364            "std::map" | "std::flat_map" | "std::scan" => {
365                let input_layout = as_layout_of_param_array(ty_mat);
366                let output_layout = as_layout_of_return_array(ty_mat);
367
368                let mut r = Vec::with_capacity(2 + 1 + output_layout.body_ptrs.len());
369                r.push(input_layout.head_size.div_ceil(8)); // input_item_head
370                r.push(output_layout.head_size.div_ceil(8)); // output_item_head
371                r.extend(as_len_and_items(&output_layout.body_ptrs)); // output_item_body_ptrs
372                r
373            }
374
375            "std::to_columnar" => {
376                let ty_func = ty_mat.kind.as_function().unwrap();
377
378                let input_item = ty_func.params[0].kind.as_array().unwrap();
379                let input_layout = input_item.layout.as_ref().unwrap();
380
381                let mut r = Vec::new();
382                r.push(input_layout.head_size.div_ceil(8)); // item_head_size
383
384                let input_field_offsets = lutra_bin::layout::tuple_field_offsets(input_item);
385                r.extend(as_len_and_items(&input_field_offsets)); // field_offsets
386
387                // fields_head_bytes
388                let fields = input_item.kind.as_tuple().unwrap();
389                r.push(fields.len() as u32);
390                for field in fields {
391                    let field_layout = field.ty.layout.as_ref().unwrap();
392                    r.push(field_layout.head_size.div_ceil(8));
393                }
394
395                // fields_body_ptrs
396                for field in fields {
397                    let field_layout = field.ty.layout.as_ref().unwrap();
398                    r.extend(as_len_and_items(&field_layout.body_ptrs));
399                }
400
401                r
402            }
403            "std::from_columnar" => {
404                let ty_func = ty_mat.kind.as_function().unwrap();
405
406                let output_item = ty_func.body.kind.as_array().unwrap();
407                let output_layout = output_item.layout.as_ref().unwrap();
408
409                let mut r = Vec::new();
410                r.push(output_layout.head_size.div_ceil(8)); // output_head_bytes
411
412                r.extend(as_len_and_items(&output_layout.body_ptrs)); // output_body_ptrs
413
414                // fields_item_head_bytes
415                let fields = output_item.kind.as_tuple().unwrap();
416                r.push(fields.len() as u32);
417                for field in fields {
418                    let field_layout = field.ty.layout.as_ref().unwrap();
419                    r.push(field_layout.head_size.div_ceil(8));
420                }
421
422                // fields_body_ptrs
423                for field in fields {
424                    let field_layout = field.ty.layout.as_ref().unwrap();
425                    r.extend(as_len_and_items(&field_layout.body_ptrs));
426                }
427
428                r
429            }
430
431            "std::zip" => {
432                let ty_func = ty_mat.kind.as_function().unwrap();
433
434                let a_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
435                let a_layout = a_item.layout.as_ref().unwrap();
436
437                let b_item = self.get_ty_mat(&ty_func.params[1]).kind.as_array().unwrap();
438                let b_layout = b_item.layout.as_ref().unwrap();
439
440                let mut r = Vec::new();
441                r.push(a_layout.head_size.div_ceil(8));
442                r.extend(as_len_and_items(&a_layout.body_ptrs));
443                r.push(b_layout.head_size.div_ceil(8));
444                r.extend(as_len_and_items(&b_layout.body_ptrs));
445                r
446            }
447
448            "std::group" => {
449                let ty_func = ty_mat.kind.as_function().unwrap();
450
451                let input_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
452                let input_layout = input_item.layout.as_ref().unwrap();
453
454                let output_item = self.get_ty_mat(&ty_func.body).kind.as_array().unwrap();
455                let output_layout = output_item.layout.as_ref().unwrap();
456
457                let key = &self.get_ty_mat(output_item).kind.as_tuple().unwrap()[0].ty;
458                let key_layout = key.layout.as_ref().unwrap();
459
460                let mut r = Vec::new();
461                r.push(input_layout.head_size.div_ceil(8)); // input_head_bytes
462                r.extend(as_len_and_items(&input_layout.body_ptrs)); // input_body_ptrs
463
464                r.push(output_layout.head_size.div_ceil(8)); // output_head_bytes
465                r.extend(as_len_and_items(&output_layout.body_ptrs)); // output_body_ptrs
466
467                // output_field_head_bytes
468                let fields = output_item.kind.as_tuple().unwrap();
469                r.push(fields.len() as u32);
470                for field in fields {
471                    let field_layout = field.ty.layout.as_ref().unwrap();
472                    r.push(field_layout.head_size.div_ceil(8));
473                }
474
475                // output_fields_body_ptrs
476                for field in fields {
477                    let field_layout = field.ty.layout.as_ref().unwrap();
478                    r.extend(as_len_and_items(&field_layout.body_ptrs));
479                }
480
481                r.push(key_layout.head_size.div_ceil(8)); // key_head_bytes
482
483                r
484            }
485
486            "std::fs::read_parquet" => {
487                // Pass the output array item type so interpreter can validate nullability
488                let ty_func = ty_mat.kind.as_function().unwrap();
489                let ty_return = &ty_func.body;
490
491                self.include_defs = true;
492
493                let mut r = Vec::new();
494                pack_bytes_to_u32(ty_return.encode(), &mut r);
495                r
496            }
497            "std::fs::write_parquet" => {
498                let array = self.get_ty_mat(as_ty_of_param(ty_mat));
499                let array_item = array.kind.as_array().unwrap();
500
501                let mut r = Vec::new();
502                pack_bytes_to_u32(array_item.encode(), &mut r);
503                r
504            }
505
506            _ => vec![],
507        };
508        ExternalSymbol { id, layout_args }
509    }
510
511    fn construct_default_for_ty(&self, ty: &ir::Ty) -> lutra_bin::Value {
512        match &self.get_ty_mat(ty).kind {
513            ir::TyKind::Primitive(prim) => match prim {
514                ir::TyPrimitive::bool | ir::TyPrimitive::int8 | ir::TyPrimitive::uint8 => {
515                    lutra_bin::Value::Prim8(0)
516                }
517
518                ir::TyPrimitive::int16 | ir::TyPrimitive::uint16 => lutra_bin::Value::Prim16(0),
519
520                ir::TyPrimitive::int32 | ir::TyPrimitive::uint32 | ir::TyPrimitive::float32 => {
521                    lutra_bin::Value::Prim32(0)
522                }
523
524                ir::TyPrimitive::int64 | ir::TyPrimitive::uint64 | ir::TyPrimitive::float64 => {
525                    lutra_bin::Value::Prim64(0)
526                }
527                ir::TyPrimitive::text => lutra_bin::Value::Text("".into()),
528            },
529            ir::TyKind::Array(_) => lutra_bin::Value::Array(vec![]),
530            ir::TyKind::Tuple(ty_fields) => lutra_bin::Value::Tuple(
531                ty_fields
532                    .iter()
533                    .map(|f| self.construct_default_for_ty(&f.ty))
534                    .collect(),
535            ),
536            ir::TyKind::Enum(ty_enum_variants) => {
537                let variant = ty_enum_variants.iter().next().unwrap();
538                lutra_bin::Value::Enum(0, Box::new(self.construct_default_for_ty(&variant.ty)))
539            }
540
541            ir::TyKind::Function(_) => panic!(),
542            ir::TyKind::Ident(_) => unreachable!(),
543        }
544    }
545}
546
547fn encode_prim(primitive: &ir::TyPrimitive) -> u32 {
548    let mut buf = primitive.encode();
549    buf.put_bytes(0, 3);
550    // padding
551    u32::from_be_bytes(buf[0..4].try_into().unwrap())
552}
553
/// Yields `items.len()` followed by every element of `items` (a length-prefixed list).
fn as_len_and_items(items: &[u32]) -> impl Iterator<Item = u32> + '_ {
    let len = items.len() as u32;
    std::iter::once(len).chain(items.iter().copied())
}
559
560fn as_layout_of_param_array(ty: &Ty) -> &ir::TyLayout {
561    let ty_func = ty.kind.as_function().unwrap();
562    let ty_array = ty_func.params[0].kind.as_array().unwrap();
563
564    ty_array.layout.as_ref().unwrap()
565}
566
567fn as_layout_of_return_array(ty: &Ty) -> &ir::TyLayout {
568    let ty_func = ty.kind.as_function().unwrap();
569    let ty_array = ty_func.body.kind.as_array().unwrap();
570
571    ty_array.layout.as_ref().unwrap()
572}
573
574fn as_ty_of_param(ty: &Ty) -> &ir::Ty {
575    let ty_func = ty.kind.as_function().unwrap();
576    &ty_func.params[0]
577}
578
/// Appends `input` to `output` as a length-prefixed run of little-endian `u32` words.
///
/// Written layout: `[n, byte_len, words...]`, where `words` is `input` zero-padded to a
/// multiple of 4 bytes and `n` is the number of `u32`s that follow it
/// (`byte_len` plus the words).
fn pack_bytes_to_u32(mut input: Vec<u8>, output: &mut Vec<u32>) {
    let byte_len = input.len();
    input.resize(byte_len.next_multiple_of(4), 0);
    let word_count = input.len() / 4;

    output.reserve(2 + word_count);
    output.extend([word_count as u32 + 1, byte_len as u32]);
    output.extend(
        input
            .chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]])),
    );
}