//! lutra-codegen 0.5.1
//!
//! Code generation for Lutra.
mod encode;
mod functions;
mod program;
mod types;

use std::collections::{HashMap, VecDeque};
use std::path;

use lutra_bin::ir;
use lutra_compiler::Project;

use crate::{GenerateOptions, infer_names};

/// State threaded through Rust code generation: a small amount of mutable
/// bookkeeping plus references to the (read-only) inputs of the generator.
#[derive(Debug)]
pub struct Context<'a> {
    /// Path of the Rust module currently being written; `codegen_module` sets it
    /// to the Lutra module path it is processing.
    current_rust_mod: Vec<String>,

    /// Buffer for types that don't have their own Lutra decl, but need their own Rust decl.
    /// When such type ref is encountered, it is pushed into here and generated later.
    def_buffer: VecDeque<ir::Ty>,

    // static env
    /// Generation options (which kinds of items to emit, which modules get programs).
    options: &'a GenerateOptions,
    /// Type declarations of the compiled module keyed by path; `get_ty_mat` uses it to
    /// resolve `ir::TyKind::Ident` references.
    ty_defs: &'a HashMap<ir::Path, &'a ir::Ty>,
    /// The source project; its `root_module` supplies definition order and annotations.
    project: &'a Project,
    /// Output directory passed to `run`.
    // NOTE(review): not read in this file; presumably consumed by a codegen submodule — confirm.
    out_dir: path::PathBuf,
}

impl<'a> Context<'a> {
    /// Returns `true` once every type queued in `def_buffer` has been generated.
    fn is_done(&self) -> bool {
        self.def_buffer.is_empty()
    }

    /// Resolves `ty` to its "material" definition.
    ///
    /// If `ty` is an `Ident` (a reference to a named type), the referenced declaration is
    /// looked up in `ty_defs` and returned. Any other type is returned unchanged.
    /// Only one level of indirection is resolved.
    ///
    /// # Panics
    /// Panics if `ty` is an `Ident` whose path is not present in `ty_defs`.
    // `dead_code` is allowed because this helper is not called from every build.
    // NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    fn get_ty_mat<'t>(&'t self, ty: &'t ir::Ty) -> &'t ir::Ty {
        // `&'t self` with `Self = Context<'a>` already implies `'a: 't`, so the
        // `&'a ir::Ty` stored in `ty_defs` can be returned as `&'t ir::Ty`.
        match &ty.kind {
            ir::TyKind::Ident(path) => self
                .ty_defs
                .get(path)
                .copied()
                .expect("type ident must refer to a declaration in ty_defs"),
            _ => ty,
        }
    }
}

pub(crate) fn run(
    project: &Project,
    options: &GenerateOptions,
    out_dir: path::PathBuf,
) -> Result<String, std::fmt::Error> {
    use std::fmt::Write;

    let module = lutra_compiler::project_to_types(project);

    let ty_defs = module.iter_types_re().collect();

    let mut w = String::new();
    writeln!(w, "//# Generated by lutra-codegen\n")?;

    let mut ctx = Context {
        current_rust_mod: vec![],
        def_buffer: VecDeque::new(),

        options,
        ty_defs: &ty_defs,
        project,
        out_dir,
    };

    let module_path = vec![];
    codegen_module(&mut w, &module, module_path, &mut ctx)?;
    Ok(w)
}

/// Writes the Rust code for one Lutra module, and recursively its sub-modules, into `w`.
///
/// Definitions are visited in the order they appear in the project source
/// (`pr_mod.defs`). Output is emitted in this order:
/// 1. type declarations (if `generate_types`),
/// 2. function traits (if `generate_function_traits`),
/// 3. programs (if this module's `::`-joined path is listed in `include_programs`),
/// 4. a nested `pub mod` for each sub-module,
/// 5. encode/decode impls for the types collected in steps 1-3 (if `generate_encode_decode`).
///
/// Types queued in `ctx.def_buffer` while writing steps 2 and 3 are flushed right after each step.
///
/// # Errors
/// Propagates any [`std::fmt::Error`] from writing into `w`.
///
/// # Panics
/// Panics in two cases:
/// - `module_path` does not name a sub-module of the project's root module.
/// - `ctx.def_buffer` is not empty once this module has been written.
fn codegen_module(
    w: &mut impl std::fmt::Write,
    module: &ir::Module,
    module_path: Vec<String>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    // collect defs
    let mut tys = Vec::new();
    let mut functions = Vec::new();
    let mut sub_modules = Vec::new();

    // Iterate pr defs rather than `module.decls`, because pr defs keep the order in the source.
    let root_mod = &ctx.project.root_module;
    let pr_mod = root_mod
        .get_submodule(&module_path)
        .expect("IR module path must exist in the project's root module");
    for (name, pr_def) in &pr_mod.defs {
        // Source defs without a counterpart in the IR module are skipped.
        let Some(decl) = module.decls.iter().find(|d| &d.name == name) else {
            continue;
        };

        match &decl.decl {
            ir::Decl::Module(sub_mod) => {
                sub_modules.push((name, sub_mod));
            }
            ir::Decl::Type(ty) => {
                let mut ty = ty.clone();
                infer_names(name, &mut ty);

                tys.push((ty, pr_def.annotations.as_slice()));
            }
            ir::Decl::Var(ty) => {
                let mut ty = ty.clone();
                infer_names(name, &mut ty);

                // Only function-typed vars are collected; other vars are not emitted here.
                if let ir::TyKind::Function(func) = ty.kind {
                    functions.push((name, *func));
                }
            }
        }
    }

    ctx.current_rust_mod = module_path.clone();

    // write types
    let mut all_tys = if ctx.options.generate_types {
        types::write_tys(w, tys, ctx)?
    } else {
        vec![]
    };

    // write traits for functions
    if ctx.options.generate_function_traits {
        functions::write_functions(w, &functions, ctx)?;

        // flush types queued in `ctx.def_buffer` while writing the traits
        all_tys.extend(types::write_tys_in_buffer(w, ctx)?);
    }

    // write programs, if requested for this module
    let module_path_str = module_path.join("::");
    if let Some((_, format)) = ctx
        .options
        .include_programs
        .iter()
        .find(|(p, _)| p == &module_path_str)
    {
        program::write_rr_programs(w, &functions, *format, ctx)?;

        all_tys.extend(types::write_tys_in_buffer(w, ctx)?);
    }

    // recurse into sub modules
    for (name, sub_mod) in sub_modules {
        writeln!(w, "pub mod {name} {{")?;

        let mut path = module_path.clone();
        path.push(name.clone());
        codegen_module(w, sub_mod, path, ctx)?;
        writeln!(w, "}}\n")?;
    }

    // write encode/decode impls
    // (recursing into sub modules changed `current_rust_mod`, so restore it first;
    // this is the last use of `module_path`, so it is moved rather than cloned)
    ctx.current_rust_mod = module_path;
    if ctx.options.generate_encode_decode {
        encode::write_encode_impls(w, &all_tys, ctx)?;
    }

    // Every type queued for this module must have been written by now. Report only
    // the pending queue: the full `Context` also dumps the whole project and type table.
    assert!(
        ctx.is_done(),
        "def_buffer not flushed in module {:?}: {:?}",
        ctx.current_rust_mod,
        ctx.def_buffer
    );

    Ok(())
}