use std::collections::{HashMap, HashSet, VecDeque};
use std::path::PathBuf;
use std::{borrow::Cow, fmt::Write};
use lutra_bin::{Encode, ir, layout};
use lutra_compiler::{ProgramRepr, Project, pr};
use crate::camel_to_snake;
/// Mutable state threaded through Python code generation for a whole project.
///
/// One instance is created by `run` and passed down through every module.
#[derive(Debug)]
struct Context<'a> {
    /// Path of the module currently being generated; `write_programs` uses it to build the
    /// fully-qualified path of each program it compiles.
    current_rust_mod: Vec<String>,
    /// Tuple/enum types that were referenced (via `ty_ref`) before their class was written.
    /// Filled with `push_back` and drained with `pop_front` (FIFO).
    def_buffer: VecDeque<ir::Ty>,
    /// Names of Python classes already emitted. References to names not in this set may be
    /// rendered as quoted forward references.
    tys_written: HashSet<String>,
    options: &'a super::GenerateOptions,
    /// Type definitions keyed by path, collected from `iter_types_re` in `run`.
    ty_defs: &'a HashMap<ir::Path, &'a ir::Ty>,
    project: &'a Project,
    // NOTE(review): stored but never read in this file, hence the dead_code allow.
    #[allow(dead_code)]
    out_dir: PathBuf,
}
impl<'a> Context<'a> {
    /// Returns `true` once no deferred type definitions are waiting to be emitted.
    fn is_done(&self) -> bool {
        self.def_buffer.front().is_none()
    }

    /// Resolves a type to its definition. An `Ident` is looked up in `ty_defs`; any other
    /// kind is returned unchanged.
    ///
    /// # Panics
    /// Panics if an `Ident` path has no entry in `ty_defs`.
    // Not called anywhere in this file at the moment, hence the allow.
    #[allow(dead_code)]
    fn get_ty_mat<'t: 'a>(&'t self, ty: &'t ir::Ty) -> &'t ir::Ty {
        match &ty.kind {
            ir::TyKind::Ident(path) => *self.ty_defs.get(path).unwrap(),
            _ => ty,
        }
    }
}
/// Maps a `std` scalar type path (e.g. `std::Int32`) to the Python type used in annotations.
///
/// Returns `None` for any path that is not one of the known `std` scalars.
fn scalar_std_type_ref(ident: &ir::Path) -> Option<&'static str> {
    const PY_TYPES: [(&str, &str); 12] = [
        ("Bool", "bool"),
        ("Int8", "int"),
        ("Int16", "int"),
        ("Int32", "int"),
        ("Int64", "int"),
        ("Uint8", "int"),
        ("Uint16", "int"),
        ("Uint32", "int"),
        ("Uint64", "int"),
        ("Float32", "float"),
        ("Float64", "float"),
        ("Text", "str"),
    ];

    PY_TYPES
        .iter()
        .find(|(lutra_name, _)| ident.is(&["std", *lutra_name]))
        .map(|&(_, py_ty)| py_ty)
}
/// Maps a `std` scalar type path to the Python expression constructing its `lutra_bin` codec.
///
/// Returns `None` for any path that is not one of the known `std` scalars.
fn scalar_std_codec(ident: &ir::Path) -> Option<&'static str> {
    const CODECS: [(&str, &str); 12] = [
        ("Bool", "lutra_bin.BoolCodec()"),
        ("Int8", "lutra_bin.Int8Codec()"),
        ("Int16", "lutra_bin.Int16Codec()"),
        ("Int32", "lutra_bin.Int32Codec()"),
        ("Int64", "lutra_bin.Int64Codec()"),
        ("Uint8", "lutra_bin.Uint8Codec()"),
        ("Uint16", "lutra_bin.Uint16Codec()"),
        ("Uint32", "lutra_bin.Uint32Codec()"),
        ("Uint64", "lutra_bin.Uint64Codec()"),
        ("Float32", "lutra_bin.Float32Codec()"),
        ("Float64", "lutra_bin.Float64Codec()"),
        ("Text", "lutra_bin.TextCodec()"),
    ];

    CODECS
        .iter()
        .find_map(|&(lutra_name, codec)| ident.is(&["std", lutra_name]).then_some(codec))
}
/// Generates a complete Python module (as a string) for every type and program in `project`.
///
/// The output starts with a fixed header of Python imports, followed by the code produced by
/// `codegen_module` for the root module and, recursively, all sub-modules.
///
/// # Errors
/// Propagates `std::fmt::Error` from writing into the output buffer.
pub(crate) fn run(
    project: &Project,
    options: &super::GenerateOptions,
    out_dir: PathBuf,
) -> Result<String, std::fmt::Error> {
    use std::fmt::Write;

    let root = lutra_compiler::project_to_types(project);
    let ty_defs = root.iter_types_re().collect();

    let mut out = String::new();
    writeln!(out, "# Generated by lutra-codegen\n")?;
    // `noqa: F401` keeps Python linters quiet when a generated module leaves an import unused.
    for stdlib_import in [
        "import base64 # noqa: F401",
        "import dataclasses # noqa: F401",
        "import functools # noqa: F401",
        "import typing # noqa: F401",
    ] {
        writeln!(out, "{stdlib_import}")?;
    }
    writeln!(out)?;
    writeln!(out, "import lutra_bin")?;
    writeln!(out)?;

    let mut ctx = Context {
        current_rust_mod: Vec::new(),
        def_buffer: VecDeque::new(),
        tys_written: Default::default(),
        options,
        ty_defs: &ty_defs,
        project,
        out_dir,
    };
    codegen_module(&mut out, &root, Vec::new(), &mut ctx)?;
    Ok(out)
}
/// Generates Python code for `module` (located at `module_path`) and, recursively, its
/// sub-modules.
///
/// The output order for one module is fixed:
/// 1. type classes declared in the module, plus any nested tuple/enum types they pull in;
/// 2. programs, only when `ctx.options` includes this module, plus the types they pull in;
/// 3. sub-modules;
/// 4. codec classes for every type written in steps 1–2.
///
/// # Panics
/// Panics if `module_path` is not a module of the project, or if deferred type definitions
/// remain in `ctx.def_buffer` at the end.
fn codegen_module(
    w: &mut impl std::fmt::Write,
    module: &ir::Module,
    module_path: Vec<String>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let pr_mod = ctx.project.root_module.get_module(&module_path).unwrap();

    let mut tys = Vec::new();
    let mut functions = Vec::new();
    let mut sub_modules = Vec::new();
    for (name, pr_def) in &pr_mod.defs {
        // Only definitions that also exist in the IR module are emitted.
        let Some(decl) = module.decls.iter().find(|d| &d.name == name) else {
            continue;
        };
        match &decl.decl {
            ir::Decl::Mod(sub_mod) => sub_modules.push((name, sub_mod)),
            ir::Decl::Ty(decl_ty) => {
                let mut named = decl_ty.clone();
                super::infer_names(name, &mut named);
                tys.push((named, pr_def.annotations.as_slice()));
            }
            ir::Decl::Var(var_ty) => {
                let mut named = var_ty.clone();
                super::infer_names(name, &mut named);
                if let ir::TyKind::Function(func) = named.kind {
                    functions.push((name, *func));
                }
            }
        }
    }

    ctx.current_rust_mod = module_path.clone();
    let mut written = write_tys(w, tys, ctx)?;

    let mod_key = module_path.join("::");
    if let Some(repr) = ctx.options.included_program_repr(&mod_key) {
        write_programs(w, &functions, repr, ctx)?;
        let pulled_in = write_tys_in_buffer(w, ctx)?;
        written.extend(pulled_in);
    }

    for (name, sub_mod) in sub_modules {
        let child_path = [module_path.as_slice(), std::slice::from_ref(name)].concat();
        codegen_module(w, sub_mod, child_path, ctx)?;
    }

    for ty in written.iter() {
        write_ty_def_codec(w, ty, ctx)?;
    }
    assert!(ctx.is_done(), "{ctx:?}");
    Ok(())
}
/// Emits class definitions for the given declared types. After each one, it also emits any
/// nested tuple/enum types that type queued in `ctx.def_buffer`.
///
/// Returns every type written, in emission order, so their codecs can be generated later.
fn write_tys(
    w: &mut impl Write,
    tys: Vec<(ir::Ty, &[pr::Anno])>,
    ctx: &mut Context,
) -> Result<Vec<ir::Ty>, std::fmt::Error> {
    let mut emitted = Vec::with_capacity(tys.len());
    for (declared, annotations) in tys.into_iter() {
        write_ty_def(w, &declared, annotations, ctx)?;
        emitted.push(declared);

        let dependents = write_tys_in_buffer(w, ctx)?;
        emitted.extend(dependents);
    }
    Ok(emitted)
}
/// Drains `ctx.def_buffer`, emitting a class for each queued type that has not been written yet.
///
/// Writing a type can queue further types, so the loop runs until the buffer is empty.
/// Returns the newly written types in emission order.
///
/// # Panics
/// Panics if a queued type has no name.
fn write_tys_in_buffer(
    w: &mut impl Write,
    ctx: &mut Context<'_>,
) -> Result<Vec<ir::Ty>, std::fmt::Error> {
    let mut flushed = Vec::new();
    while let Some(pending) = ctx.def_buffer.pop_front() {
        let pending_name = pending.name.as_ref().unwrap();
        if !ctx.tys_written.contains(pending_name) {
            write_ty_def(w, &pending, &[], ctx)?;
            flushed.push(pending);
        }
    }
    Ok(flushed)
}
/// Emits the Python dataclass for the named type `ty` and records its name in `ctx.tys_written`.
///
/// The class shape depends on the type kind:
/// - primitive, ident and array types become a wrapper with a single `value` attribute;
/// - tuples get one attribute per field;
/// - enums get one optional attribute per variant, plus a custom `__repr__` that shows the
///   first set variant (presumably exactly one is expected to be set).
///
/// Every class also gets a `codec()` classmethod returning `{name}Codec()`. That codec class is
/// emitted separately by `write_ty_def_codec`.
///
/// Rendering field types through `ty_ref` may queue nested tuple/enum types in `ctx.def_buffer`.
///
/// # Panics
/// Panics if `ty` has no name, or for kinds other than primitive/ident/array/tuple/enum.
#[rustfmt::skip::macros(writeln)]
#[rustfmt::skip::macros(write)]
fn write_ty_def(
    w: &mut impl Write,
    ty: &ir::Ty,
    // NOTE(review): accepted but not used yet.
    _annotations: &[pr::Anno],
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let name = ty.name.as_ref().unwrap();
    let codec_name = format!("{name}Codec");
    writeln!(w)?;
    // Class decorator. Enums are keyword-only and opt out of the generated repr, because a
    // custom `__repr__` is written below.
    match &ty.kind {
        ir::TyKind::Primitive(_)
        | ir::TyKind::Ident(_)
        | ir::TyKind::Array(_)
        | ir::TyKind::Tuple(_) => {
            writeln!(w, "@dataclasses.dataclass")?;
        }
        ir::TyKind::Enum(_) => {
            writeln!(w, "@dataclasses.dataclass(kw_only=True, repr=False)")?;
        }
        _ => unimplemented!(),
    }
    writeln!(w, "class {name}(lutra_bin.Encodable):")?;
    // Class body: fields (and, for enums, `__repr__`).
    match &ty.kind {
        ir::TyKind::Primitive(_) | ir::TyKind::Ident(_) | ir::TyKind::Array(_) => {
            writeln!(w, " value: {}", ty_ref(ty, true, ctx))?;
        }
        ir::TyKind::Tuple(fields) => {
            for (index, field) in fields.iter().enumerate() {
                // Unnamed fields fall back to `field{index}`.
                let name = tuple_field_name(&field.name, index);
                writeln!(w, " {name}: {}", ty_ref(&field.ty, true, ctx))?;
            }
        }
        ir::TyKind::Enum(variants) => {
            for variant in variants {
                let name = camel_to_snake(&variant.name);
                // Unit variants are flags: `True` when selected (see the codec's decode), else `None`.
                let ty = if variant.ty.is_unit() {
                    Cow::Borrowed("typing.Literal[True] | None")
                } else {
                    let ty = ty_ref(&variant.ty, true, ctx);
                    // A quoted forward reference ('Name') presumably can't be combined with
                    // `| None` at runtime, so use `typing.Optional` in that case.
                    if ty.contains('\'') {
                        format!("typing.Optional[{ty}]").into()
                    } else {
                        format!("{ty} | None").into()
                    }
                };
                writeln!(w, " {name}: {ty} = None")?;
            }
            writeln!(w)?;
            // `__repr__` reports the first variant attribute that is not None.
            writeln!(w, " def __repr__(self) -> str:")?;
            for variant in variants {
                let variant_name = camel_to_snake(&variant.name);
                writeln!(w, " if self.{variant_name} is not None:")?;
                write!(w, r" return ")?;
                if variant.ty.is_unit() {
                    writeln!(w, r#""{name}({variant_name})""#)?;
                } else {
                    writeln!(w, r#"f"{name}({variant_name}={{self.{variant_name}!r}})""#)?;
                }
            }
            writeln!(w, " raise AssertionError()")?;
        }
        _ => unimplemented!(),
    }
    writeln!(w)?;
    // Hook from the value class to its codec class (emitted later by `write_ty_def_codec`).
    writeln!(w, " @classmethod")?;
    writeln!(w, " def codec(cls) -> '{codec_name}':")?;
    writeln!(w, " return {codec_name}()")?;
    writeln!(w)?;
    // From now on, references to this type no longer need to be quoted.
    ctx.tys_written.insert(name.clone());
    Ok(())
}
/// Emits the `{name}Codec` Python class for the named type `ty`.
///
/// The codec exposes `head_bytes()`, `decode(buf)`, `encode_head(obj, buf)` and
/// `encode_body(obj, residual, buf)`. Enum codecs also carry an `EnumCodecHelper` built from
/// the enum's layout format, which is embedded as base85.
///
/// Codec expressions for field types come from `ty_codec`, which may queue tuple/enum types
/// that have not been written yet in `ctx.def_buffer`.
///
/// # Panics
/// Panics if `ty` has no name or no layout, or for unsupported type kinds.
#[rustfmt::skip::macros(writeln)]
#[rustfmt::skip::macros(write)]
fn write_ty_def_codec(
    w: &mut impl Write,
    ty: &ir::Ty,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let name = ty.name.as_ref().unwrap();
    let codec_name = format!("{name}Codec");
    writeln!(w)?;
    writeln!(w, "class {codec_name}:")?;
    if let ir::TyKind::Enum(variants) = &ty.kind {
        let enum_format = lutra_bin::layout::enum_format(variants, &ty.variants_recursive);
        let buf = enum_format.encode();
        let format_base85 = base85::encode(&buf);
        writeln!(w, " helper = lutra_bin.EnumCodecHelper(")?;
        writeln!(w, " base64.b85decode(b'{format_base85}'),")?;
        writeln!(w, " )")?;
    }
    // `head_size` appears to be in bits; `div_ceil(8)` rounds it up to whole bytes.
    let head_bytes = ty.layout.as_ref().unwrap().head_size.div_ceil(8);
    writeln!(w, " def head_bytes(self) -> int:")?;
    writeln!(w, " return {head_bytes}")?;
    match &ty.kind {
        ir::TyKind::Primitive(_) | ir::TyKind::Ident(_) | ir::TyKind::Array(_) => {
            // Single-`value` wrapper: every method delegates to the wrapped type's codec.
            let inner = ty_codec(ty, ctx);
            writeln!(w, " def decode(self, buf: bytes) -> {name}:")?;
            writeln!(w, " return {name}({inner}.decode(buf))")?;
            writeln!(w, " def encode_head(self, obj: {name}, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            writeln!(w, " return {inner}.encode_head(obj.value, buf)")?;
            writeln!(w, " def encode_body(self, obj: {name}, residual: typing.Any, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            writeln!(w, " return {inner}.encode_body(obj.value, residual, buf)")?;
        }
        ir::TyKind::Tuple(fields) => {
            writeln!(w)?;
            writeln!(w, " def decode(self, buf: bytes) -> {name}:")?;
            writeln!(w, " buf = memoryview(buf)")?;
            // Each field is decoded from its fixed head offset within the tuple.
            let offsets = layout::tuple_field_offsets(ty);
            for (index, field) in fields.iter().enumerate() {
                let buf = format!("buf[{}:]", offsets[index]);
                writeln!(w, " f{index} = {}", ty_decode(&field.ty, &buf, ctx))?;
            }
            write!(w, " return {name}(")?;
            for index in 0..fields.len() {
                write!(w, "f{index}, ")?;
            }
            writeln!(w, ")")?;
            writeln!(w)?;
            // encode_head returns a tuple of per-field residuals, which encode_body consumes.
            writeln!(w, " def encode_head(self, obj: {name}, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            writeln!(w, " return (")?;
            for (index, field) in fields.iter().enumerate() {
                let field_name = tuple_field_name(&field.name, index);
                let val_ref = format!("obj.{field_name}");
                write!(w, " ")?;
                writeln!(w, "{}.encode_head({val_ref}, buf),", ty_codec(&field.ty, ctx))?;
            }
            writeln!(w, " )")?;
            writeln!(w)?;
            writeln!(w, " def encode_body(self, obj: {name}, residuals: typing.Any, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            let mut statements = Vec::with_capacity(fields.len());
            for (index, field) in fields.iter().enumerate() {
                let field_name = tuple_field_name(&field.name, index);
                let val_ref = format!("obj.{field_name}");
                let res_ref = format!("residuals[{index}]");
                let encode_body = ty_encode_body(&field.ty, &val_ref, &res_ref, ctx);
                if !encode_body.is_empty() {
                    statements.push(encode_body);
                }
            }
            // A Python `def` needs at least one statement. Without this, a field-less tuple
            // would leave `encode_body` with an empty (syntactically invalid) body.
            if statements.is_empty() {
                statements.push("pass".to_string());
            }
            for stmt in &statements {
                writeln!(w, " {stmt}")?;
            }
        }
        ir::TyKind::Enum(variants) => {
            let head = lutra_bin::layout::enum_head_format(variants, &ty.variants_recursive);
            writeln!(w)?;
            writeln!(w, " def decode(self, buf: bytes) -> {name}:")?;
            writeln!(w, " tag, inner_offset = self.helper.decode_head(buf)")?;
            for (tag, variant) in variants.iter().enumerate() {
                let variant_name = camel_to_snake(&variant.name);
                writeln!(w, " if tag == {tag}:")?;
                if variant.ty.is_unit() {
                    writeln!(w, " return {name}({variant_name}=True)")?;
                } else {
                    writeln!(w, " return {name}({variant_name}={}.decode(buf[inner_offset:]))", ty_codec(&variant.ty, ctx))?;
                }
            }
            writeln!(w, " raise AssertionError()")?;
            writeln!(w)?;
            writeln!(w, " def encode_head(self, obj: {name}, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            writeln!(w, " tag: int")?;
            writeln!(w, " res: typing.Any")?;
            for (tag, variant) in variants.iter().enumerate() {
                let variant_name = camel_to_snake(&variant.name);
                writeln!(w, " if obj.{variant_name} is not None:")?;
                writeln!(w, " tag = {tag}")?;
                writeln!(w, " res = self.helper.encode_head_tag(tag, buf)")?;
                // Without a pointer in the head, the variant's head is written inline.
                if !head.has_ptr && !variant.ty.is_unit() {
                    writeln!(w, " res = {}.encode_head(obj.{variant_name}, buf)", ty_codec(&variant.ty, ctx))?;
                }
            }
            writeln!(w, " self.helper.encode_head_padding(tag, buf)")?;
            writeln!(w, " return res")?;
            writeln!(w)?;
            writeln!(w, " def encode_body(self, obj: {name}, residual: typing.Any, buf: lutra_bin.BytesMut) -> typing.Any:")?;
            if head.has_ptr {
                // With a pointer, the variant's head and body both go into the body area.
                writeln!(w, " self.helper.encode_body_ptr(residual, buf)")?;
                for variant in variants {
                    if variant.ty.is_unit() {
                        continue;
                    }
                    let variant_name = camel_to_snake(&variant.name);
                    writeln!(w, " if obj.{variant_name} is not None:")?;
                    writeln!(w, " r = {}.encode_head(obj.{variant_name}, buf)", ty_codec(&variant.ty, ctx))?;
                    writeln!(w, " {}.encode_body(obj.{variant_name}, r, buf)", ty_codec(&variant.ty, ctx))?;
                }
            } else {
                writeln!(w, " if residual is None:")?;
                writeln!(w, " return")?;
                for variant in variants {
                    if variant.ty.is_unit() {
                        continue;
                    }
                    let variant_name = camel_to_snake(&variant.name);
                    writeln!(w, " if obj.{variant_name} is not None:")?;
                    writeln!(w, " {}.encode_body(obj.{variant_name}, residual, buf)", ty_codec(&variant.ty, ctx))?;
                }
            }
        }
        _ => unimplemented!(),
    }
    writeln!(w)?;
    Ok(())
}
/// Python attribute name for tuple field number `index`.
///
/// Returns the declared field name (borrowed) when present, otherwise `field{index}`.
fn tuple_field_name(name: &Option<String>, index: usize) -> Cow<'_, str> {
    match name {
        Some(declared) => Cow::Borrowed(declared.as_str()),
        None => Cow::Owned(format!("field{index}")),
    }
}
/// Renders a Python expression that names the type `ty`.
///
/// `as_ty` selects the context of use:
/// - `true`: the result goes into a type annotation. A class that is not written yet is quoted
///   as a forward reference (`'Name'`).
/// - `false`: the result is used as a value (e.g. `Name.codec()`). The bare name is returned,
///   because a quoted name would be a string literal, not the class. The surrounding generated
///   code is a function body, where the name resolves lazily at call time.
///
/// Tuple and enum types not yet written are queued in `ctx.def_buffer`. Types referenced by
/// path (`Ident`) are defined by their own declaration and are never queued.
///
/// # Panics
/// Panics if a tuple/enum has no inferred name, if an `Ident` path is empty, or for unsupported
/// type kinds.
fn ty_ref(ty: &ir::Ty, as_ty: bool, ctx: &mut Context) -> Cow<'static, str> {
    if ty.is_unit() {
        return "tuple[()]".into();
    }
    match &ty.kind {
        ir::TyKind::Primitive(_) => "int".into(),
        ir::TyKind::Array(items_ty) => format!("list[{}]", ty_ref(items_ty, as_ty, ctx)).into(),
        ir::TyKind::Ident(ident) => {
            if let Some(scalar) = scalar_std_type_ref(ident) {
                return scalar.into();
            }
            let name = ident.0.last().unwrap().clone();
            // Quote only in annotation position; in value position (`as_ty == false`) a
            // quoted name would produce e.g. `'Name'.codec()`, which fails at runtime.
            if as_ty && !ctx.tys_written.contains(&name) {
                format!("'{name}'").into()
            } else {
                name.into()
            }
        }
        ir::TyKind::Tuple(_) | ir::TyKind::Enum(_) => {
            let name = ty
                .name
                .clone()
                .unwrap_or_else(|| panic!("no name for {ty:?}"));
            let written = ctx.tys_written.contains(&name);
            if !written {
                ctx.def_buffer.push_back(ty.clone());
            }
            if as_ty && !written {
                format!("'{name}'").into()
            } else {
                name.into()
            }
        }
        _ => unimplemented!(),
    }
}
/// Renders a Python expression that constructs the `lutra_bin` codec for `ty`.
///
/// Unit, primitive, array and `std` scalar types map to built-in codecs. Other named types
/// become `Name.codec()`, which may queue the type in `ctx.def_buffer` via `ty_ref`.
///
/// # Panics
/// Panics for unsupported type kinds.
fn ty_codec(ty: &ir::Ty, ctx: &mut Context) -> Cow<'static, str> {
    /// Codec expression for a user-defined type: its class's `codec()` classmethod.
    fn named_codec(ty: &ir::Ty, ctx: &mut Context) -> Cow<'static, str> {
        Cow::Owned(format!("{}.codec()", ty_ref(ty, false, ctx)))
    }

    if ty.is_unit() {
        return Cow::Borrowed("lutra_bin.UnitCodec()");
    }
    match &ty.kind {
        ir::TyKind::Primitive(ir::TyPrimitive::Prim8) => Cow::Borrowed("lutra_bin.Int8Codec()"),
        ir::TyKind::Primitive(ir::TyPrimitive::Prim16) => Cow::Borrowed("lutra_bin.Int16Codec()"),
        ir::TyKind::Primitive(ir::TyPrimitive::Prim32) => Cow::Borrowed("lutra_bin.Int32Codec()"),
        ir::TyKind::Primitive(ir::TyPrimitive::Prim64) => Cow::Borrowed("lutra_bin.Int64Codec()"),
        ir::TyKind::Array(item_ty) => {
            let item_codec = ty_codec(item_ty, ctx);
            Cow::Owned(format!("lutra_bin.ArrayCodec({item_codec})"))
        }
        ir::TyKind::Ident(ident) => match scalar_std_codec(ident) {
            Some(scalar) => Cow::Borrowed(scalar),
            None => named_codec(ty, ctx),
        },
        ir::TyKind::Tuple(_) | ir::TyKind::Enum(_) => named_codec(ty, ctx),
        _ => unimplemented!(),
    }
}
/// Renders a Python expression that decodes a value of type `ty` from the buffer expression
/// `buf_ref`.
///
/// # Panics
/// Panics for type kinds other than primitive/array/ident/tuple/enum.
fn ty_decode(ty: &ir::Ty, buf_ref: &str, ctx: &mut Context) -> String {
    match &ty.kind {
        ir::TyKind::Primitive(_)
        | ir::TyKind::Array(_)
        | ir::TyKind::Ident(_)
        | ir::TyKind::Tuple(_)
        | ir::TyKind::Enum(_) => {
            let codec = ty_codec(ty, ctx);
            [&*codec, ".decode(", buf_ref, ")"].concat()
        }
        _ => unimplemented!(),
    }
}
/// Renders the Python statement that encodes the body of `val_ref` (of type `ty`) into `buf`,
/// using the residual produced by the matching `encode_head` call (`residual_ref`).
fn ty_encode_body(ty: &ir::Ty, val_ref: &str, residual_ref: &str, ctx: &mut Context) -> String {
    let codec = ty_codec(ty, ctx);
    [&*codec, ".encode_body(", val_ref, ", ", residual_ref, ", buf)"].concat()
}
/// Compiles each function of the current module into a program with representation `repr`.
/// Each program is emitted as a cached Python factory returning a typed `lutra_bin.Program`.
///
/// The compiled program bytes are embedded as base85. The input/output types are rendered
/// through `ty_ref`/`ty_codec`, which may queue types in `ctx.def_buffer`.
///
/// # Panics
/// Panics if compilation of any function fails.
fn write_programs(
    w: &mut impl Write,
    functions: &[(&String, ir::TyFunction)],
    repr: ProgramRepr,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    for (name, _) in functions {
        let fq_path = {
            let mut path = pr::Path::new(&ctx.current_rust_mod);
            path.push(name.to_string());
            path.to_string()
        };
        let params = lutra_compiler::CompileParams::new(&fq_path, repr);
        let (program, mut program_ty) = lutra_compiler::compile(ctx.project, &params).unwrap();
        let program_base85 = base85::encode(&program.encode());
        super::infer_names_of_program_ty(&mut program_ty, name);

        // Resolve type names/codecs in the same order the generated text mentions them.
        let input_ref = ty_ref(&program_ty.input, true, ctx);
        let output_ref = ty_ref(&program_ty.output, true, ctx);
        let input_codec = ty_codec(&program_ty.input, ctx);
        let output_codec = ty_codec(&program_ty.output, ctx);

        writeln!(w)?;
        writeln!(w, "@functools.cache")?;
        writeln!(w, "def {name}() -> lutra_bin.Program[{input_ref}, {output_ref}]:")?;
        writeln!(w, " return lutra_bin.Program(")?;
        writeln!(w, " base64.b85decode(b'{program_base85}'),")?;
        writeln!(w, " {input_codec},")?;
        writeln!(w, " {output_codec},")?;
        writeln!(w, " )")?;
    }
    Ok(())
}