use std::collections::HashMap;
use indexmap::IndexSet;
use lutra_bin::Encode;
use lutra_bin::br::*;
use lutra_bin::bytes::BufMut;
use lutra_bin::ir;
pub fn compile_program(value: ir::Program) -> Program {
let mut b = ByteCoder {
externals: Default::default(),
include_defs: false,
defs: &value.defs,
def_map: value.defs.iter().map(|def| (&def.name, &def.ty)).collect(),
next_wrapper_id: 0xFF00,
};
let program = Program {
main: b.compile_expr(value.main),
externals: b.externals.into_iter().collect(),
defs: if b.include_defs { value.defs } else { vec![] },
};
tracing::debug!("br:\n{program:#?}");
program
}
/// Mutable state used while lowering a single IR program to bytecode.
struct ByteCoder<'t> {
    /// Type definitions of the program being compiled.
    defs: &'t [ir::TyDef],
    /// Lookup from a definition's path to its type, built from `defs`.
    def_map: HashMap<&'t ir::Path, &'t ir::Ty>,
    /// External symbols registered so far, in insertion order. A symbol's
    /// index in this set is the value of the `External` sid that refers to it.
    externals: IndexSet<ExternalSymbol>,
    /// Set when some compiled external needs the type definitions to be
    /// shipped along with the program.
    include_defs: bool,
    /// Id of the next generated wrapper function scope.
    next_wrapper_id: u32,
}
impl<'t> ByteCoder<'t> {
    /// Resolves `ty` to its materialized form by following `Ident` types
    /// through `def_map` until a non-`Ident` type is reached.
    ///
    /// Panics if an `Ident` names a path that is not in `def_map`.
    fn get_ty_mat<'a: 't>(&self, ty: &'a ir::Ty) -> &'t ir::Ty {
        let mut ty = ty;
        while let TyKind::Ident(path) = &ty.kind {
            ty = self.def_map.get(path).unwrap();
        }
        ty
    }

    /// Like `get_ty_mat`, but stops at (and returns) any `Ident` type whose
    /// path is recognized by `ir::TyStd::try_new`.
    ///
    /// Panics if a non-std `Ident` names a path that is not in `def_map`.
    fn get_ty_mat_or_std<'a: 't>(&self, ty: &'a ir::Ty) -> &'t ir::Ty {
        let mut ty = ty;
        while let TyKind::Ident(path) = &ty.kind {
            if ir::TyStd::try_new(path).is_some() {
                return ty;
            }
            ty = self.def_map.get(path).unwrap();
        }
        ty
    }

    /// Lowers one IR expression into a bytecode expression.
    fn compile_expr(&mut self, expr: ir::Expr) -> Expr {
        let kind = match expr.kind {
            ir::ExprKind::Pointer(v) => self.compile_pointer(v, &expr.ty),
            ir::ExprKind::Literal(v) => ExprKind::Literal(self.compile_literal(v)),
            ir::ExprKind::Call(v) => ExprKind::Call(Box::new(self.compile_call(*v))),
            ir::ExprKind::Function(v) => ExprKind::Function(Box::new(self.compile_function(*v))),
            ir::ExprKind::Tuple(v) => ExprKind::Tuple(Box::new(self.compile_tuple(v))),
            ir::ExprKind::Array(v) => ExprKind::Array(Box::new(self.compile_array(expr.ty, v))),
            ir::ExprKind::EnumVariant(v) => {
                ExprKind::EnumVariant(Box::new(self.compile_enum_variant(expr.ty, *v)))
            }
            // The tag expression compiles to the subject itself; the payload
            // starts after the tag bytes (see `compile_enum_unwrap`).
            ir::ExprKind::EnumTag(v) => self.compile_expr(v.subject).kind,
            ir::ExprKind::EnumUnwrap(v) => return self.compile_enum_unwrap(*v),
            ir::ExprKind::TupleLookup(v) => return self.compile_tuple_lookup(*v),
            ir::ExprKind::Binding(v) => ExprKind::Binding(Box::new(self.compile_binding(*v))),
            ir::ExprKind::Switch(v) => ExprKind::Switch(self.compile_switch(v)),
        };
        Expr { kind }
    }

    /// Lowers a pointer expression.
    ///
    /// - External pointers become (possibly specialized) external symbols,
    ///   chosen from the materialized type `ty` of the referenced value.
    /// - Binding pointers become `Var` sids of the binding id.
    /// - Parameter pointers become `FunctionScope` sids of
    ///   `function_id << 8 | param_position`, matching the `symbol_ns`
    ///   assigned in `compile_function`.
    fn compile_pointer(&mut self, ptr: ir::Pointer, ty: &ir::Ty) -> ExprKind {
        match ptr {
            ir::Pointer::External(e_ptr) => {
                let ty = self.get_ty_mat(ty);
                self.compile_external_symbol(e_ptr.id, ty)
            }
            ir::Pointer::Binding(binding_id) => {
                ExprKind::Pointer(Sid(binding_id).with_tag(SidKind::Var))
            }
            ir::Pointer::Parameter(param_ptr) => {
                // NOTE(review): a position >= 256 would spill into the function
                // id bits; this assumes fewer than 256 params per function.
                let sid = param_ptr.function_id << 8 | param_ptr.param_position as u32;
                ExprKind::Pointer(Sid(sid).with_tag(SidKind::FunctionScope))
            }
        }
    }

    /// Encodes a literal into its binary representation.
    fn compile_literal(&mut self, value: ir::Literal) -> Vec<u8> {
        match value {
            ir::Literal::Prim8(v) => v.encode(),
            ir::Literal::Prim16(v) => v.encode(),
            ir::Literal::Prim32(v) => v.encode(),
            ir::Literal::Prim64(v) => v.encode(),
            ir::Literal::Text(v) => v.encode(),
        }
    }

    /// Lowers a call by compiling the callee and each argument in order.
    fn compile_call(&mut self, value: ir::Call) -> Call {
        Call {
            function: self.compile_expr(value.function),
            args: value
                .args
                .into_iter()
                .map(|x| self.compile_expr(x))
                .collect(),
        }
    }

    /// Lowers a function. Its parameters are addressed as
    /// `FunctionScope` sids in the namespace `id << 8` (see `compile_pointer`).
    fn compile_function(&mut self, value: ir::Function) -> Function {
        Function {
            symbol_ns: Sid(value.id << 8).with_tag(SidKind::FunctionScope),
            body: self.compile_expr(value.body),
        }
    }

    /// Lowers a tuple constructor.
    ///
    /// A field marked `unpack` splices the fields of its (tuple-typed)
    /// expression into this tuple. It contributes one layout per inner field
    /// and records the inner field count in `TupleField::unpack`.
    ///
    /// Panics if an unpacked field's type does not materialize to a tuple.
    fn compile_tuple(&mut self, fields: Vec<ir::TupleField>) -> Tuple {
        let field_layouts = fields
            .iter()
            .flat_map(|f| {
                if f.unpack {
                    let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
                        panic!();
                    };
                    fields.iter().map(|f| &f.ty).collect::<Vec<_>>()
                } else {
                    vec![&f.expr.ty]
                }
            })
            .map(|ty| self.compile_ty_layout(ty.layout.clone().unwrap()))
            .collect();

        let fields = fields
            .into_iter()
            .map(|f| {
                let unpack = if f.unpack {
                    let ir::TyKind::Tuple(fields) = &self.get_ty_mat(&f.expr.ty).kind else {
                        panic!();
                    };
                    fields.len() as u8
                } else {
                    0
                };
                let expr = self.compile_expr(f.expr);
                TupleField { expr, unpack }
            })
            .collect();

        Tuple {
            fields,
            field_layouts,
        }
    }

    /// Lowers an array constructor, recording the layout of its item type.
    ///
    /// Panics if `ty` does not materialize to an array type.
    fn compile_array(&mut self, ty: ir::Ty, items: Vec<ir::Expr>) -> Array {
        let ty_item = self.get_ty_mat(&ty).kind.as_array().unwrap();
        Array {
            items: items.into_iter().map(|x| self.compile_expr(x)).collect(),
            item_layout: self.compile_ty_layout(ty_item.layout.clone().unwrap()),
        }
    }

    /// Lowers an enum variant constructor of enum type `ty`.
    ///
    /// Emits the low `tag_bytes` bytes of the little-endian tag, plus the head
    /// and variant formats computed by `lutra_bin::layout`.
    fn compile_enum_variant(&mut self, ty: Ty, v: ir::EnumVariant) -> EnumVariant {
        let ty_mat = self.get_ty_mat(&ty);
        let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
            panic!()
        };
        let ty_variant = ty_variants.get(v.tag as usize).unwrap();

        // Take `variants_recursive` from the materialized type, as
        // `compile_enum_unwrap` does. When `ty` is an `Ident`, the variant list
        // above comes from the resolved definition, so the recursion flags must
        // come from there as well.
        let head_format =
            lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);
        let variant_format = lutra_bin::layout::enum_variant_format(&head_format, &ty_variant.ty);

        EnumVariant {
            tag: v.tag.to_le_bytes()[0..head_format.tag_bytes as usize].to_vec(),
            inner_bytes: head_format.inner_bytes as u8,
            has_ptr: head_format.has_ptr,
            padding_bytes: variant_format.padding_bytes,
            inner: self.compile_expr(v.inner),
        }
    }

    /// Lowers access to an enum's payload.
    ///
    /// The result offsets the subject by the tag size. When the head format
    /// stores the payload behind a pointer (`has_ptr`), it is also dereferenced.
    fn compile_enum_unwrap(&mut self, v: ir::EnumUnwrap) -> Expr {
        let ty_mat = self.get_ty_mat(&v.subject.ty);
        let ir::TyKind::Enum(ty_variants) = &ty_mat.kind else {
            panic!()
        };
        let head_format =
            lutra_bin::layout::enum_head_format(ty_variants, &ty_mat.variants_recursive);

        let mut expr = self.compile_expr(v.subject);
        expr = Expr {
            kind: ExprKind::Offset(Box::new(Offset {
                base: expr,
                offset: head_format.tag_bytes,
            })),
        };
        if head_format.has_ptr {
            expr = Expr {
                kind: ExprKind::Deref(Box::new(Deref { ptr: expr })),
            };
        }
        expr
    }

    /// Lowers a tuple field access into an offset into the base value. The
    /// offset is computed from the materialized base type.
    fn compile_tuple_lookup(&mut self, value: ir::TupleLookup) -> Expr {
        let base_ty = self.get_ty_mat(&value.base.ty);
        let offset = lutra_bin::layout::tuple_field_offset(base_ty, value.position);

        let kind = ExprKind::Offset(Box::new(Offset {
            base: self.compile_expr(value.base),
            offset,
        }));
        Expr { kind }
    }

    /// Lowers a binding of `value.expr` to the `Var` sid `value.id`, followed
    /// by `value.main`.
    fn compile_binding(&mut self, value: ir::Binding) -> Binding {
        Binding {
            symbol: Sid(value.id).with_tag(SidKind::Var),
            expr: self.compile_expr(value.expr),
            main: self.compile_expr(value.main),
        }
    }

    /// Lowers each switch branch's condition and value, preserving order.
    fn compile_switch(&mut self, branches: Vec<ir::SwitchBranch>) -> Vec<SwitchBranch> {
        branches
            .into_iter()
            .map(|b| SwitchBranch {
                condition: self.compile_expr(b.condition),
                value: self.compile_expr(b.value),
            })
            .collect()
    }

    /// Converts an IR layout into the bytecode layout representation.
    fn compile_ty_layout(&self, value: ir::TyLayout) -> TyLayout {
        TyLayout {
            head_size: value.head_size,
            body_ptrs: value.body_ptrs,
        }
    }

    /// Lowers a reference to the external function `id` whose materialized
    /// function type is `ty_mat`.
    ///
    /// Depending on `id`, the reference becomes one of three things:
    /// - an external specialized by name on a std/primitive type
    ///   (`{id}_{suffix}`, see `as_std_ty_suffix`);
    /// - a generated wrapper function that forwards its parameters and appends
    ///   a type-specialized `std::ops::cmp_*` comparator;
    /// - a plain external registered together with `layout_args`, a flat `u32`
    ///   description of the layouts it operates on.
    ///
    /// Panics if `ty_mat` does not have the shape the given `id` expects.
    fn compile_external_symbol(&mut self, id: String, ty_mat: &ir::Ty) -> ExprKind {
        let layout_args: Vec<u32> = match id.as_str() {
            // Specialized purely by the type of the first parameter.
            "std::ops::add"
            | "std::ops::sub"
            | "std::ops::mul"
            | "std::ops::div"
            | "std::ops::mod"
            | "std::ops::neg"
            | "std::ops::cmp"
            | "std::ops::eq"
            | "std::ops::lt"
            | "std::ops::lte"
            | "std::convert::to_int8"
            | "std::convert::to_int16"
            | "std::convert::to_int32"
            | "std::convert::to_int64"
            | "std::convert::to_uint8"
            | "std::convert::to_uint16"
            | "std::convert::to_uint32"
            | "std::convert::to_uint64"
            | "std::convert::to_float32"
            | "std::convert::to_float64"
            | "std::convert::to_text"
            | "std::math::abs"
            | "std::math::pow"
            | "std::array::sequence" => {
                let param_ty = as_ty_of_param(ty_mat);
                let ty_name = self.as_std_ty_suffix(param_ty);
                return self.make_external(format!("{id}_{ty_name}"), vec![]);
            }

            "std::array::fold" => {
                let item_layout = as_layout_of_param_array(ty_mat);
                vec![item_layout.head_size.div_ceil(8)]
            }

            // Need a comparator for the array item type, passed as an extra arg.
            "std::array::min"
            | "std::array::max"
            | "std::array::rank"
            | "std::array::rank_dense"
            | "std::array::rank_percentile"
            | "std::array::cume_dist" => {
                let param_ty = as_ty_of_param(ty_mat);
                let item_ty = self.get_ty_mat(param_ty).kind.as_array().unwrap();
                let item_layout = item_ty.layout.as_ref().unwrap();
                let layout_args = vec![item_layout.head_size.div_ceil(8)];

                let n_params = ty_mat.kind.as_function().unwrap().params.len();
                let cmp_id = format!("std::ops::cmp_{}", self.as_std_ty_suffix(item_ty));
                let cmp = self.make_external(cmp_id, vec![]);
                return self.wrap_external_with_extra_args(id, layout_args, n_params, vec![cmp]);
            }

            // Needs a comparator for the key produced by the key extractor
            // (the second parameter).
            "std::array::sort" => {
                let item_layout = as_layout_of_param_array(ty_mat);
                let mut layout_args = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
                layout_args.push(item_layout.head_size.div_ceil(8));
                layout_args.extend(as_len_and_items(&item_layout.body_ptrs));

                let ty_func = ty_mat.kind.as_function().unwrap();
                let key_extractor_ty = self.get_ty_mat(&ty_func.params[1]);
                let key_ty = &key_extractor_ty.kind.as_function().unwrap().body;
                let n_params = ty_func.params.len();
                let cmp_id = format!("std::ops::cmp_{}", self.as_std_ty_suffix(key_ty));
                let cmp = self.make_external(cmp_id, vec![]);
                return self.wrap_external_with_extra_args(id, layout_args, n_params, vec![cmp]);
            }

            // Specialized by item type and given the item head size.
            "std::array::sum" | "std::array::mean" | "std::array::rolling_mean" => {
                let param_ty = as_ty_of_param(ty_mat);
                let item_ty = self.get_ty_mat(param_ty).kind.as_array().unwrap();
                let item_layout = item_ty.layout.as_ref().unwrap();
                let layout_args = vec![item_layout.head_size.div_ceil(8)];
                let ty_name = self.as_std_ty_suffix(item_ty);
                return self.make_external(format!("{id}_{ty_name}"), layout_args);
            }

            "std::array::index" => {
                let item_layout = as_layout_of_param_array(ty_mat);
                let ty_func = ty_mat.kind.as_function().unwrap();

                // Variants and recursion flags both come from the materialized
                // output enum, so that a named output type is handled consistently.
                let ty_out = self.get_ty_mat(&ty_func.body);
                let ty_out_variants = ty_out.kind.as_enum().unwrap();
                let ty_out_format =
                    lutra_bin::layout::enum_format(ty_out_variants, &ty_out.variants_recursive);

                let mut r = vec![item_layout.head_size.div_ceil(8)];
                pack_bytes_to_u32(ty_out_format.encode(), &mut r);
                r
            }

            "std::array::filter"
            | "std::array::slice"
            | "std::array::append"
            | "std::array::loop_until_empty" => {
                let item_layout = as_layout_of_param_array(ty_mat);
                let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
                r.push(item_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&item_layout.body_ptrs));
                r
            }

            // Also needs an encoded default item value to fill missing positions.
            "std::array::lag" | "std::array::lead" => {
                let item_layout = as_layout_of_param_array(ty_mat);
                let mut r = Vec::with_capacity(1 + 1 + item_layout.body_ptrs.len());
                r.push(item_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&item_layout.body_ptrs));

                let ty_func = ty_mat.kind.as_function().unwrap();
                let ty_item = self.get_ty_mat(&ty_func.body).kind.as_array().unwrap();
                let default_val = self.construct_default_for_ty(ty_item);
                pack_bytes_to_u32(default_val, &mut r);
                r
            }

            "std::array::map" | "std::array::flat_map" | "std::array::scan" => {
                let input_layout = as_layout_of_param_array(ty_mat);
                let output_layout = as_layout_of_return_array(ty_mat);
                let mut r = Vec::with_capacity(2 + 1 + output_layout.body_ptrs.len());
                r.push(input_layout.head_size.div_ceil(8));
                r.push(output_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&output_layout.body_ptrs));
                r
            }

            "std::array::to_columnar" => {
                let ty_func = ty_mat.kind.as_function().unwrap();
                let input_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
                let input_layout = input_item.layout.as_ref().unwrap();
                let input_item_mat = self.get_ty_mat(input_item);
                let field_layouts: Vec<&ir::TyLayout> = input_item_mat
                    .kind
                    .as_tuple()
                    .unwrap()
                    .iter()
                    .map(|f| f.ty.layout.as_ref().unwrap())
                    .collect();

                let mut r = vec![input_layout.head_size.div_ceil(8)];
                let input_field_offsets = lutra_bin::layout::tuple_field_offsets(input_item_mat);
                r.extend(as_len_and_items(&input_field_offsets));
                Self::push_fields_layout(&field_layouts, &mut r);
                r
            }

            "std::array::from_columnar" => {
                let ty_func = ty_mat.kind.as_function().unwrap();
                let output_item = self.get_ty_mat(&ty_func.body).kind.as_array().unwrap();
                let output_layout = output_item.layout.as_ref().unwrap();
                let field_layouts: Vec<&ir::TyLayout> = self
                    .get_ty_mat(output_item)
                    .kind
                    .as_tuple()
                    .unwrap()
                    .iter()
                    .map(|f| f.ty.layout.as_ref().unwrap())
                    .collect();

                let mut r = vec![output_layout.head_size.div_ceil(8)];
                r.extend(as_len_and_items(&output_layout.body_ptrs));
                Self::push_fields_layout(&field_layouts, &mut r);
                r
            }

            "std::array::zip" => {
                let ty_func = ty_mat.kind.as_function().unwrap();
                let a_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
                let a_layout = a_item.layout.as_ref().unwrap();
                let b_item = self.get_ty_mat(&ty_func.params[1]).kind.as_array().unwrap();
                let b_layout = b_item.layout.as_ref().unwrap();

                let mut r = Vec::new();
                r.push(a_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&a_layout.body_ptrs));
                r.push(b_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&b_layout.body_ptrs));
                r
            }

            "std::array::group" => {
                let ty_func = ty_mat.kind.as_function().unwrap();
                let input_item = self.get_ty_mat(&ty_func.params[0]).kind.as_array().unwrap();
                let input_layout = input_item.layout.as_ref().unwrap();
                let output_item = self.get_ty_mat(&ty_func.body).kind.as_array().unwrap();
                let output_layout = output_item.layout.as_ref().unwrap();

                // Materialize the output item once and use it both for the
                // per-field layouts and for the key (its first field).
                let field_layouts: Vec<&ir::TyLayout> = self
                    .get_ty_mat(output_item)
                    .kind
                    .as_tuple()
                    .unwrap()
                    .iter()
                    .map(|f| f.ty.layout.as_ref().unwrap())
                    .collect();
                let key_layout = field_layouts[0];

                let mut r = Vec::new();
                r.push(input_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&input_layout.body_ptrs));
                r.push(output_layout.head_size.div_ceil(8));
                r.extend(as_len_and_items(&output_layout.body_ptrs));
                Self::push_fields_layout(&field_layouts, &mut r);
                r.push(key_layout.head_size.div_ceil(8));
                r
            }

            // Parquet I/O needs the encoded data type and the type definitions.
            "std::fs::read_parquet" => {
                let ty_func = ty_mat.kind.as_function().unwrap();
                let ty_data = &ty_func.body;
                self.include_defs = true;
                let mut r = Vec::new();
                pack_bytes_to_u32(ty_data.encode(), &mut r);
                r
            }
            "std::fs::write_parquet" => {
                let ty_data = as_ty_of_param(ty_mat);
                self.include_defs = true;
                let mut r = Vec::new();
                pack_bytes_to_u32(ty_data.encode(), &mut r);
                r
            }

            _ => vec![],
        };

        self.make_external(id, layout_args)
    }

    /// Appends the layout description of a tuple's fields to `r`.
    ///
    /// It writes, in order: the field count; each field's `head_size` divided
    /// by 8 (rounded up); then each field's `body_ptrs` as a length-prefixed
    /// list.
    fn push_fields_layout(field_layouts: &[&ir::TyLayout], r: &mut Vec<u32>) {
        r.push(field_layouts.len() as u32);
        r.extend(field_layouts.iter().map(|l| l.head_size.div_ceil(8)));
        for layout in field_layouts {
            r.extend(as_len_and_items(&layout.body_ptrs));
        }
    }

    /// Returns the name suffix used to specialize std functions on `ty`.
    ///
    /// Std identifiers map to the lowercased last path segment. Primitives map
    /// to `uint8`..`uint64`.
    ///
    /// Panics for any other type.
    fn as_std_ty_suffix(&self, ty: &ir::Ty) -> String {
        match &self.get_ty_mat_or_std(ty).kind {
            ir::TyKind::Ident(path) => path.0.last().unwrap().to_ascii_lowercase(),
            ir::TyKind::Primitive(ir::TyPrimitive::Prim8) => "uint8".into(),
            ir::TyKind::Primitive(ir::TyPrimitive::Prim16) => "uint16".into(),
            ir::TyKind::Primitive(ir::TyPrimitive::Prim32) => "uint32".into(),
            ir::TyKind::Primitive(ir::TyPrimitive::Prim64) => "uint64".into(),
            _ => {
                panic!("std specialization not supported for {}", ir::print_ty(ty));
            }
        }
    }

    /// Registers (or reuses) the external `id` with `layout_args` and returns
    /// a pointer to it. The pointer's sid is the symbol's index in `externals`.
    fn make_external(&mut self, id: String, layout_args: Vec<u32>) -> ExprKind {
        let (index, _) = self
            .externals
            .insert_full(ExternalSymbol { id, layout_args });
        ExprKind::Pointer(Sid(index as u32).with_tag(SidKind::External))
    }

    /// Builds a generated function that calls the external `id`.
    ///
    /// The call receives the wrapper's own `n_params` parameters followed by
    /// `extra_args`. Each wrapper gets a fresh function scope id taken from
    /// `next_wrapper_id`.
    fn wrap_external_with_extra_args(
        &mut self,
        id: String,
        layout_args: Vec<u32>,
        n_params: usize,
        extra_args: Vec<ExprKind>,
    ) -> ExprKind {
        let ext = self.make_external(id, layout_args);

        let wrapper_id = self.next_wrapper_id;
        self.next_wrapper_id += 1;
        let wrapper_ns = Sid(wrapper_id << 8).with_tag(SidKind::FunctionScope);

        // Forward the wrapper's parameters, then append the extra arguments.
        let args: Vec<Expr> = (0..n_params)
            .map(|i| {
                let sid = Sid(wrapper_id << 8 | i as u32).with_tag(SidKind::FunctionScope);
                Expr {
                    kind: ExprKind::Pointer(sid),
                }
            })
            .chain(extra_args.into_iter().map(|kind| Expr { kind }))
            .collect();

        ExprKind::Function(Box::new(Function {
            symbol_ns: wrapper_ns,
            body: Expr {
                kind: ExprKind::Call(Box::new(Call {
                    function: Expr { kind: ext },
                    args,
                })),
            },
        }))
    }

    /// Encodes a default value of type `ty`, using `self.defs` to resolve
    /// named types during encoding.
    fn construct_default_for_ty(&self, ty: &ir::Ty) -> Vec<u8> {
        self.construct_default_for_ty_re(ty)
            .encode(ty, self.defs)
            .unwrap()
    }

    /// Builds the default value of `ty`.
    ///
    /// Primitives are 0, arrays are empty, tuples are built field by field,
    /// and enums use their first variant.
    ///
    /// Panics for function types and for enums without variants.
    fn construct_default_for_ty_re(&self, ty: &ir::Ty) -> lutra_bin::Value {
        match &self.get_ty_mat(ty).kind {
            ir::TyKind::Primitive(prim) => match prim {
                ir::TyPrimitive::Prim8 => lutra_bin::Value::Prim8(0),
                ir::TyPrimitive::Prim16 => lutra_bin::Value::Prim16(0),
                ir::TyPrimitive::Prim32 => lutra_bin::Value::Prim32(0),
                ir::TyPrimitive::Prim64 => lutra_bin::Value::Prim64(0),
            },
            ir::TyKind::Array(_) => lutra_bin::Value::Array(vec![]),
            ir::TyKind::Tuple(ty_fields) => lutra_bin::Value::Tuple(
                ty_fields
                    .iter()
                    .map(|f| self.construct_default_for_ty_re(&f.ty))
                    .collect(),
            ),
            ir::TyKind::Enum(ty_enum_variants) => {
                let variant = ty_enum_variants.first().unwrap();
                lutra_bin::Value::Enum(0, Box::new(self.construct_default_for_ty_re(&variant.ty)))
            }
            ir::TyKind::Function(_) => panic!(),
            // `get_ty_mat` never returns an `Ident`.
            ir::TyKind::Ident(_) => unreachable!(),
        }
    }
}
/// Yields a length-prefixed encoding of `items`: first `items.len()` (as
/// `u32`), then every element of the slice in order.
fn as_len_and_items(items: &[u32]) -> impl Iterator<Item = u32> + '_ {
    let len = items.len() as u32;
    std::iter::once(len).chain(items.iter().copied())
}
/// Returns the layout of the item type of the array that is the first
/// parameter of function type `ty`.
///
/// Named types are not resolved here. Panics if `ty` is not a function, if its
/// first parameter is not syntactically an array, or if the item has no
/// computed layout.
fn as_layout_of_param_array(ty: &Ty) -> &ir::TyLayout {
    let params = &ty.kind.as_function().unwrap().params;
    let item = params[0].kind.as_array().unwrap();
    item.layout.as_ref().unwrap()
}
/// Returns the layout of the item type of the array returned by function type
/// `ty`.
///
/// Named types are not resolved here. Panics if `ty` is not a function, if its
/// body is not syntactically an array, or if the item has no computed layout.
fn as_layout_of_return_array(ty: &Ty) -> &ir::TyLayout {
    let item = ty.kind.as_function().unwrap().body.kind.as_array().unwrap();
    item.layout.as_ref().unwrap()
}
/// Returns the type of the first parameter of function type `ty`.
///
/// Panics if `ty` is not a function type or has no parameters.
fn as_ty_of_param(ty: &Ty) -> &ir::Ty {
    let params = &ty.kind.as_function().unwrap().params;
    &params[0]
}
/// Appends `input` to `output` as a length-prefixed run of little-endian
/// `u32` words.
///
/// Written layout: `[word_count + 1, input.len(), word_0, word_1, ...]`. The
/// bytes are zero-padded to a multiple of 4 before being split into words, so
/// the first value counts every word that follows it (the byte-length word
/// plus the data words).
fn pack_bytes_to_u32(mut input: Vec<u8>, output: &mut Vec<u32>) {
    let byte_len = input.len();
    let word_count = byte_len.div_ceil(4);
    // Zero-pad so the buffer splits evenly into 4-byte words.
    input.resize(word_count * 4, 0);

    output.reserve(2 + word_count);
    output.push(word_count as u32 + 1);
    output.push(byte_len as u32);
    output.extend(
        input
            .chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]])),
    );
}