1use std::collections::HashMap;
2
3use indexmap::IndexSet;
4use lutra_bin::Encode;
5use lutra_bin::br::*;
6use lutra_bin::bytes::BufMut;
7use lutra_bin::ir;
8
9pub fn compile_program(value: ir::Program) -> Program {
10 let mut b = ByteCoder {
11 externals: Default::default(),
12 include_defs: false,
13
14 defs: &value.defs,
15 def_map: value.defs.iter().map(|def| (&def.name, &def.ty)).collect(),
16 };
17
18 Program {
19 main: b.compile_expr(value.main),
20 externals: b.externals.into_iter().collect(),
21 defs: if b.include_defs { value.defs } else { vec![] },
22 }
23}
24
25struct ByteCoder<'t> {
26 externals: IndexSet<ExternalSymbol>,
27 defs: &'t [ir::TyDef],
28 def_map: HashMap<&'t ir::Path, &'t ir::Ty>,
29
30 include_defs: bool,
32}
33
34impl<'t> ByteCoder<'t> {
35 fn get_ty_mat<'a: 't>(&self, ty: &'a ir::Ty) -> &'t ir::Ty {
36 match &ty.kind {
37 TyKind::Ident(path) => self.def_map.get(path).unwrap(),
38 _ => ty,
39 }
40 }
41
42 fn compile_expr(&mut self, expr: ir::Expr) -> Expr {
43 let kind = match expr.kind {
44 ir::ExprKind::Pointer(v) => ExprKind::Pointer(self.compile_pointer(v, &expr.ty)),
45 ir::ExprKind::Literal(v) => ExprKind::Literal(self.compile_literal(v)),
46 ir::ExprKind::Call(v) => ExprKind::Call(Box::new(self.compile_call(*v))),
47 ir::ExprKind::Function(v) => ExprKind::Function(Box::new(self.compile_function(*v))),
48 ir::ExprKind::Tuple(v) => ExprKind::Tuple(Box::new(self.compile_tuple(v))),
49 ir::ExprKind::Array(v) => ExprKind::Array(Box::new(self.compile_array(expr.ty, v))),
50 ir::ExprKind::EnumVariant(v) => {
51 ExprKind::EnumVariant(Box::new(self.compile_enum_variant(expr.ty, *v)))
52 }
53 ir::ExprKind::EnumEq(v) => ExprKind::EnumEq(Box::new(self.compile_enum_eq(*v))),
54 ir::ExprKind::EnumUnwrap(v) => return self.compile_enum_unwrap(*v),
55 ir::ExprKind::TupleLookup(v) => return self.compile_tuple_lookup(*v),
56 ir::ExprKind::Binding(v) => ExprKind::Binding(Box::new(self.compile_binding(*v))),
57 ir::ExprKind::Switch(v) => ExprKind::Switch(self.compile_switch(v)),
58 };
59
60 Expr { kind }
61 }
62
63 fn compile_pointer(&mut self, ptr: ir::Pointer, ty: &ir::Ty) -> Sid {
64 match ptr {
65 ir::Pointer::External(e_ptr) => {
66 let ty = self.get_ty_mat(ty);
67 let e_symbol = self.compile_external_symbol(e_ptr.id, ty);
68 let (index, _) = self.externals.insert_full(e_symbol);
69
70 Sid(index as u32).with_tag(SidKind::External)
71 }
72 #[rustfmt::skip]
73 ir::Pointer::Binding(binding_id) => {
74 Sid(binding_id).with_tag(SidKind::Var)
75 },
76 ir::Pointer::Parameter(param_ptr) => {
77 let sid = param_ptr.function_id << 8 | param_ptr.param_position as u32;
78
79 Sid(sid).with_tag(SidKind::FunctionScope)
80 }
81 }
82 }
83
84 fn compile_literal(&mut self, value: ir::Literal) -> Vec<u8> {
85 match value {
86 ir::Literal::bool(v) => v.encode(),
87 ir::Literal::int8(v) => v.encode(),
88 ir::Literal::int16(v) => v.encode(),
89 ir::Literal::int32(v) => v.encode(),
90 ir::Literal::int64(v) => v.encode(),
91 ir::Literal::uint8(v) => v.encode(),
92 ir::Literal::uint16(v) => v.encode(),
93 ir::Literal::uint32(v) => v.encode(),
94 ir::Literal::uint64(v) => v.encode(),
95 ir::Literal::float32(v) => v.encode(),
96 ir::Literal::float64(v) => v.encode(),
97 ir::Literal::text(v) => v.encode(),
98 }
99 }
100
101 fn compile_call(&mut self, value: ir::Call) -> Call {
102 Call {
103 function: self.compile_expr(value.function),
104 args: value
105 .args
106 .into_iter()
107 .map(|x| self.compile_expr(x))
108 .collect(),
109 }
110 }
111
112 fn compile_function(&mut self, value: ir::Function) -> Function {
113 Function {
114 symbol_ns: Sid(value.id << 8).with_tag(SidKind::FunctionScope),
115 body: self.compile_expr(value.body),
116 }
117 }
118
119 fn compile_tuple(&mut self, fields: Vec<ir::TupleField>) -> Tuple {
120 let field_layouts = fields
121 .iter()
122 .flat_map(|f| {
123 if f.unpack {
124 let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
125 panic!();
126 };
127 fields.iter().map(|f| &f.ty).collect::<Vec<_>>()
128 } else {
129 vec![&f.expr.ty]
130 }
131 })
132 .map(|ty| self.compile_ty_layout(ty.layout.clone().unwrap()))
133 .collect();
134
135 let fields = fields
136 .into_iter()
137 .map(|f| {
138 let unpack = if f.unpack {
139 let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
140 panic!();
141 };
142 fields.len() as u8
143 } else {
144 0
145 };
146 let expr = self.compile_expr(f.expr);
147
148 TupleField { expr, unpack }
149 })
150 .collect();
151 Tuple {
152 fields,
153 field_layouts,
154 }
155 }
156
157 fn compile_array(&mut self, ty: ir::Ty, items: Vec<ir::Expr>) -> Array {
158 Array {
159 items: items.into_iter().map(|x| self.compile_expr(x)).collect(),
160 item_layout: self.compile_ty_layout(ty.kind.into_array().unwrap().layout.unwrap()),
161 }
162 }
163
164 fn compile_enum_variant(&mut self, ty: Ty, v: ir::EnumVariant) -> EnumVariant {
165 let ty_mat = self.get_ty_mat(&ty);
166 let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
167 panic!()
168 };
169 let ty_variant = ty_variants.get(v.tag as usize).unwrap();
170 let head_format = lutra_bin::layout::enum_head_format(ty_variants, &ty.variants_recursive);
171 let variant_format = lutra_bin::layout::enum_variant_format(&head_format, &ty_variant.ty);
172
173 EnumVariant {
174 tag: v.tag.to_le_bytes()[0..head_format.tag_bytes as usize].to_vec(),
175 inner_bytes: head_format.inner_bytes as u8,
176 has_ptr: head_format.has_ptr,
177 padding_bytes: variant_format.padding_bytes,
178 inner: self.compile_expr(v.inner),
179 }
180 }
181
182 fn compile_enum_eq(&mut self, v: ir::EnumEq) -> EnumEq {
183 let ty_mat = self.get_ty_mat(&v.subject.ty);
184 let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
185 panic!()
186 };
187 let head_format =
188 lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);
189
190 let tag = v.tag.to_le_bytes()[0..head_format.tag_bytes as usize].to_vec();
191 EnumEq {
192 tag,
193 expr: self.compile_expr(v.subject),
194 }
195 }
196
197 fn compile_enum_unwrap(&mut self, v: ir::EnumUnwrap) -> Expr {
198 let ty_mat = self.get_ty_mat(&v.subject.ty);
199 let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
200 panic!()
201 };
202
203 let head_format =
204 lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);
205
206 let mut expr = self.compile_expr(v.subject);
207
208 expr = Expr {
210 kind: ExprKind::Offset(Box::new(Offset {
211 base: expr,
212 offset: head_format.tag_bytes,
213 })),
214 };
215
216 if head_format.has_ptr {
218 expr = Expr {
219 kind: ExprKind::Deref(Box::new(Deref { ptr: expr })),
220 };
221 }
222
223 expr
224 }
225
226 fn compile_tuple_lookup(&mut self, value: ir::TupleLookup) -> Expr {
227 let base_ty = self.get_ty_mat(&value.base.ty);
228 let offset = lutra_bin::layout::tuple_field_offset(base_ty, value.position);
229
230 let kind = ExprKind::Offset(Box::new(Offset {
231 base: self.compile_expr(value.base),
232 offset,
233 }));
234 Expr { kind }
235 }
236
237 fn compile_binding(&mut self, value: ir::Binding) -> Binding {
238 Binding {
239 symbol: Sid(value.id).with_tag(SidKind::Var),
240 expr: self.compile_expr(value.expr),
241 main: self.compile_expr(value.main),
242 }
243 }
244
245 fn compile_switch(&mut self, branches: Vec<ir::SwitchBranch>) -> Vec<SwitchBranch> {
246 branches
247 .into_iter()
248 .map(|b| SwitchBranch {
249 condition: self.compile_expr(b.condition),
250 value: self.compile_expr(b.value),
251 })
252 .collect()
253 }
254
255 fn compile_ty_layout(&self, value: ir::TyLayout) -> TyLayout {
256 TyLayout {
257 head_size: value.head_size,
258 body_ptrs: value.body_ptrs,
259 }
260 }
261
262 fn compile_external_symbol(&mut self, id: String, ty_mat: &ir::Ty) -> ExternalSymbol {
263 let layout_args: Vec<u32> = match id.as_str() {
264 "std::to_int8" | "std::to_int16" | "std::to_int32" | "std::to_int64"
265 | "std::to_uint8" | "std::to_uint16" | "std::to_uint32" | "std::to_uint64"
266 | "std::to_float32" | "std::to_float64" | "std::to_text" | "std::mul" | "std::div"
267 | "std::mod" | "std::add" | "std::sub" | "std::neg" | "std::cmp" | "std::eq"
268 | "std::lt" | "std::lte" | "std::sequence" | "std::math::abs" | "std::math::pow" => {
269 let param_ty = as_ty_of_param(ty_mat);
270 let primitive = param_ty.kind.as_primitive().unwrap();
271
272 vec![encode_prim(primitive)]
273 }
274
275 "std::fold" => {
276 let item_layout = as_layout_of_param_array(ty_mat);
277 vec![
278 item_layout.head_size.div_ceil(8), ]
280 }
281
282 "std::min"
283 | "std::max"
284 | "std::sum"
285 | "std::mean"
286 | "std::rolling_mean"
287 | "std::rank"
288 | "std::rank_dense"
289 | "std::rank_percentile"
290 | "std::cume_dist" => {
291 let param_ty = as_ty_of_param(ty_mat);
292 let item_ty = self.get_ty_mat(param_ty).kind.as_array().unwrap();
293
294 let item_layout = item_ty.layout.as_ref().unwrap();
295 let item_ty = item_ty.kind.as_primitive().unwrap();
296
297 vec![
298 item_layout.head_size.div_ceil(8), encode_prim(item_ty),
300 ]
301 }
302
303 "std::index" => {
304 let item_layout = as_layout_of_param_array(ty_mat);
305
306 let ty_func = ty_mat.kind.as_function().unwrap();
307 let ty_out_variants = ty_func.body.kind.as_enum().unwrap();
308 let ty_out_format = lutra_bin::layout::enum_format(
309 ty_out_variants,
310 &ty_func.body.variants_recursive,
311 );
312 let ty_out_format = ty_out_format.encode();
313
314 let mut r = vec![
315 item_layout.head_size.div_ceil(8), ];
317
318 pack_bytes_to_u32(ty_out_format, &mut r);
319 r
320 }
321
322 "std::filter" | "std::slice" | "std::append" | "std::apply_until_empty" => {
323 let item_layout = as_layout_of_param_array(ty_mat);
324
325 let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
326 r.push(item_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&item_layout.body_ptrs)); r
329 }
330 "std::sort" => {
331 let item_layout = as_layout_of_param_array(ty_mat);
332
333 let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
334 r.push(item_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&item_layout.body_ptrs)); let ty_func = ty_mat.kind.as_function().unwrap();
339 let ty_key_extractor = self.get_ty_mat(&ty_func.params[1]);
340 let ty_key_extractor = ty_key_extractor.kind.as_function().unwrap();
341 let ty_key = self.get_ty_mat(&ty_key_extractor.body);
342 r.push(encode_prim(ty_key.kind.as_primitive().unwrap()));
343
344 r
345 }
346
347 "std::lag" | "std::lead" => {
348 let item_layout = as_layout_of_param_array(ty_mat);
349
350 let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
351 r.push(item_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&item_layout.body_ptrs)); let ty_func = ty_mat.kind.as_function().unwrap();
356 let ty_item = ty_func.body.kind.as_array().unwrap();
357 let default_val = self.construct_default_for_ty(ty_item);
358 let default_val = default_val.encode(ty_item, self.defs).unwrap();
359 pack_bytes_to_u32(default_val, &mut r);
360
361 r
362 }
363
364 "std::map" | "std::flat_map" | "std::scan" => {
365 let input_layout = as_layout_of_param_array(ty_mat);
366 let output_layout = as_layout_of_return_array(ty_mat);
367
368 let mut r = Vec::with_capacity(2 + 1 + output_layout.body_ptrs.len());
369 r.push(input_layout.head_size.div_ceil(8)); r.push(output_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&output_layout.body_ptrs)); r
373 }
374
375 "std::to_columnar" => {
376 let ty_func = ty_mat.kind.as_function().unwrap();
377
378 let input_item = ty_func.params[0].kind.as_array().unwrap();
379 let input_layout = input_item.layout.as_ref().unwrap();
380
381 let mut r = Vec::new();
382 r.push(input_layout.head_size.div_ceil(8)); let input_field_offsets = lutra_bin::layout::tuple_field_offsets(input_item);
385 r.extend(as_len_and_items(&input_field_offsets)); let fields = input_item.kind.as_tuple().unwrap();
389 r.push(fields.len() as u32);
390 for field in fields {
391 let field_layout = field.ty.layout.as_ref().unwrap();
392 r.push(field_layout.head_size.div_ceil(8));
393 }
394
395 for field in fields {
397 let field_layout = field.ty.layout.as_ref().unwrap();
398 r.extend(as_len_and_items(&field_layout.body_ptrs));
399 }
400
401 r
402 }
403 "std::from_columnar" => {
404 let ty_func = ty_mat.kind.as_function().unwrap();
405
406 let output_item = ty_func.body.kind.as_array().unwrap();
407 let output_layout = output_item.layout.as_ref().unwrap();
408
409 let mut r = Vec::new();
410 r.push(output_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&output_layout.body_ptrs)); let fields = output_item.kind.as_tuple().unwrap();
416 r.push(fields.len() as u32);
417 for field in fields {
418 let field_layout = field.ty.layout.as_ref().unwrap();
419 r.push(field_layout.head_size.div_ceil(8));
420 }
421
422 for field in fields {
424 let field_layout = field.ty.layout.as_ref().unwrap();
425 r.extend(as_len_and_items(&field_layout.body_ptrs));
426 }
427
428 r
429 }
430
431 "std::zip" => {
432 let ty_func = ty_mat.kind.as_function().unwrap();
433
434 let a_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
435 let a_layout = a_item.layout.as_ref().unwrap();
436
437 let b_item = self.get_ty_mat(&ty_func.params[1]).kind.as_array().unwrap();
438 let b_layout = b_item.layout.as_ref().unwrap();
439
440 let mut r = Vec::new();
441 r.push(a_layout.head_size.div_ceil(8));
442 r.extend(as_len_and_items(&a_layout.body_ptrs));
443 r.push(b_layout.head_size.div_ceil(8));
444 r.extend(as_len_and_items(&b_layout.body_ptrs));
445 r
446 }
447
448 "std::group" => {
449 let ty_func = ty_mat.kind.as_function().unwrap();
450
451 let input_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
452 let input_layout = input_item.layout.as_ref().unwrap();
453
454 let output_item = self.get_ty_mat(&ty_func.body).kind.as_array().unwrap();
455 let output_layout = output_item.layout.as_ref().unwrap();
456
457 let key = &self.get_ty_mat(output_item).kind.as_tuple().unwrap()[0].ty;
458 let key_layout = key.layout.as_ref().unwrap();
459
460 let mut r = Vec::new();
461 r.push(input_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&input_layout.body_ptrs)); r.push(output_layout.head_size.div_ceil(8)); r.extend(as_len_and_items(&output_layout.body_ptrs)); let fields = output_item.kind.as_tuple().unwrap();
469 r.push(fields.len() as u32);
470 for field in fields {
471 let field_layout = field.ty.layout.as_ref().unwrap();
472 r.push(field_layout.head_size.div_ceil(8));
473 }
474
475 for field in fields {
477 let field_layout = field.ty.layout.as_ref().unwrap();
478 r.extend(as_len_and_items(&field_layout.body_ptrs));
479 }
480
481 r.push(key_layout.head_size.div_ceil(8)); r
484 }
485
486 "std::fs::read_parquet" => {
487 let ty_func = ty_mat.kind.as_function().unwrap();
489 let ty_return = &ty_func.body;
490
491 self.include_defs = true;
492
493 let mut r = Vec::new();
494 pack_bytes_to_u32(ty_return.encode(), &mut r);
495 r
496 }
497 "std::fs::write_parquet" => {
498 let array = self.get_ty_mat(as_ty_of_param(ty_mat));
499 let array_item = array.kind.as_array().unwrap();
500
501 let mut r = Vec::new();
502 pack_bytes_to_u32(array_item.encode(), &mut r);
503 r
504 }
505
506 _ => vec![],
507 };
508 ExternalSymbol { id, layout_args }
509 }
510
511 fn construct_default_for_ty(&self, ty: &ir::Ty) -> lutra_bin::Value {
512 match &self.get_ty_mat(ty).kind {
513 ir::TyKind::Primitive(prim) => match prim {
514 ir::TyPrimitive::bool | ir::TyPrimitive::int8 | ir::TyPrimitive::uint8 => {
515 lutra_bin::Value::Prim8(0)
516 }
517
518 ir::TyPrimitive::int16 | ir::TyPrimitive::uint16 => lutra_bin::Value::Prim16(0),
519
520 ir::TyPrimitive::int32 | ir::TyPrimitive::uint32 | ir::TyPrimitive::float32 => {
521 lutra_bin::Value::Prim32(0)
522 }
523
524 ir::TyPrimitive::int64 | ir::TyPrimitive::uint64 | ir::TyPrimitive::float64 => {
525 lutra_bin::Value::Prim64(0)
526 }
527 ir::TyPrimitive::text => lutra_bin::Value::Text("".into()),
528 },
529 ir::TyKind::Array(_) => lutra_bin::Value::Array(vec![]),
530 ir::TyKind::Tuple(ty_fields) => lutra_bin::Value::Tuple(
531 ty_fields
532 .iter()
533 .map(|f| self.construct_default_for_ty(&f.ty))
534 .collect(),
535 ),
536 ir::TyKind::Enum(ty_enum_variants) => {
537 let variant = ty_enum_variants.iter().next().unwrap();
538 lutra_bin::Value::Enum(0, Box::new(self.construct_default_for_ty(&variant.ty)))
539 }
540
541 ir::TyKind::Function(_) => panic!(),
542 ir::TyKind::Ident(_) => unreachable!(),
543 }
544 }
545}
546
547fn encode_prim(primitive: &ir::TyPrimitive) -> u32 {
548 let mut buf = primitive.encode();
549 buf.put_bytes(0, 3);
550 u32::from_be_bytes(buf[0..4].try_into().unwrap())
552}
553
/// Yields `items.len()` followed by every element of `items`, i.e. a
/// length-prefixed list.
fn as_len_and_items(items: &[u32]) -> impl Iterator<Item = u32> + '_ {
    let len_prefix = std::iter::once(items.len() as u32);
    len_prefix.chain(items.iter().copied())
}
559
560fn as_layout_of_param_array(ty: &Ty) -> &ir::TyLayout {
561 let ty_func = ty.kind.as_function().unwrap();
562 let ty_array = ty_func.params[0].kind.as_array().unwrap();
563
564 ty_array.layout.as_ref().unwrap()
565}
566
567fn as_layout_of_return_array(ty: &Ty) -> &ir::TyLayout {
568 let ty_func = ty.kind.as_function().unwrap();
569 let ty_array = ty_func.body.kind.as_array().unwrap();
570
571 ty_array.layout.as_ref().unwrap()
572}
573
574fn as_ty_of_param(ty: &Ty) -> &ir::Ty {
575 let ty_func = ty.kind.as_function().unwrap();
576 &ty_func.params[0]
577}
578
/// Appends `input` to `output` as little-endian `u32` words, prefixed by a
/// header:
///
/// 1. the number of words that follow this one (the length word plus the
///    data words),
/// 2. the original byte length of `input`,
/// 3. the data, zero-padded up to a multiple of 4 bytes.
fn pack_bytes_to_u32(mut input: Vec<u8>, output: &mut Vec<u32>) {
    let byte_len = input.len();
    let word_count = byte_len.div_ceil(4);

    // zero-fill the tail so the data splits into whole words
    input.resize(word_count * 4, 0);

    output.reserve(2 + word_count);
    output.push(word_count as u32 + 1);
    output.push(byte_len as u32);
    output.extend(
        input
            .chunks_exact(4)
            .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]])),
    );
}