lutra-codegen 0.5.1

Code generation for Lutra
Documentation
use std::{collections::HashMap, fmt::Write, path};

use lutra_bin::ir;
use lutra_compiler::{Project, pr};

use crate::{GenerateOptions, infer_names};

pub(crate) fn run(
    project: &Project,
    options: &GenerateOptions,
    _out_dir: path::PathBuf,
) -> Result<String, std::fmt::Error> {
    let module = lutra_compiler::project_to_types(project);

    let ty_defs = module.iter_types_re().collect();

    let mut w = String::new();
    w += "-- Generated by lutra-codegen";

    let mut ctx = Context {
        options,
        ty_defs: &ty_defs,
        project,
    };

    let module_path = vec![];
    codegen_module(&mut w, &module, module_path, &mut ctx)?;
    Ok(w)
}

/// Accumulates generated SQL text for a batch of type definitions.
#[derive(Debug)]
// NOTE(review): both fields appear to be read below; this allow may be stale — confirm.
#[allow(dead_code)]
struct Printer<'a> {
    /// Output buffer that `write_ty_def` appends `CREATE TABLE` statements to.
    buf: String,
    /// Shared codegen state (type definitions, project, options).
    ctx: &'a Context<'a>,
}

/// State shared by every step of one codegen run.
#[derive(Debug)]
// `options` is not read in the code visible here, presumably why dead_code is allowed.
#[allow(dead_code)]
struct Context<'a> {
    /// Generator options passed in by the caller.
    options: &'a GenerateOptions,
    /// Type definitions keyed by path; used to resolve `TyKind::Ident` references.
    ty_defs: &'a HashMap<ir::Path, &'a ir::Ty>,
    /// The project being generated; its `pr` modules supply definition order and annotations.
    project: &'a Project,
}

/// Emits DDL for every type declared in `module`, then recurses into its sub-modules.
///
/// `module_path` locates the matching `pr` module in `ctx.project`. Iteration is driven by
/// that module's definitions, which preserve source order, rather than by `module.decls`.
/// Definitions with no ir declaration of the same name are skipped, as are declarations
/// that are neither types nor modules.
///
/// # Panics
/// Panics if `module_path` does not resolve to a sub-module of the project's root module.
fn codegen_module(
    w: &mut impl std::fmt::Write,
    module: &ir::Module,
    module_path: Vec<String>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let mut type_defs = Vec::new();
    let mut children = Vec::new();

    let pr_mod = ctx
        .project
        .root_module
        .get_submodule(&module_path)
        .unwrap();

    // Drive the loop from pr defs (source order) and look up each ir decl by name.
    for (def_name, pr_def) in &pr_mod.defs {
        let found = module.decls.iter().find(|d| &d.name == def_name);
        let Some(decl) = found else {
            continue;
        };

        match &decl.decl {
            ir::Decl::Module(child) => children.push((def_name, child)),
            ir::Decl::Type(ty) => {
                let mut named = ty.clone();
                infer_names(def_name, &mut named);
                type_defs.push((named, pr_def.annotations.as_slice()));
            }
            _ => {}
        }
    }

    write_tys(w, type_defs, ctx)?;

    for (child_name, child) in children {
        let mut child_path = module_path.clone();
        child_path.push(child_name.clone());
        codegen_module(w, child, child_path, ctx)?;
    }

    Ok(())
}

/// Renders each `(type, annotations)` pair as a `CREATE TABLE` statement and writes them to `w`.
///
/// Output is accumulated in a `Printer` buffer and handed to `w` in a single `write_str`.
fn write_tys(
    w: &mut impl Write,
    tys: Vec<(ir::Ty, &[pr::Annotation])>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let mut printer = Printer {
        ctx,
        buf: String::new(),
    };

    for (ty, annotations) in &tys {
        printer.write_ty_def(ty, annotations);
    }

    w.write_str(&printer.buf)
}

impl<'a> Printer<'a> {
    fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
        if let ir::TyKind::Ident(ident) = &ty.kind {
            let ty_mat = self.ctx.ty_defs.get(ident).unwrap();
            self.get_ty_mat(ty_mat)
        } else {
            ty
        }
    }

    /// Generates a type definition.
    pub fn write_ty_def(&mut self, ty: &ir::Ty, _annotations: &[pr::Annotation]) {
        let name = ty.name.as_ref().unwrap();
        let table_name = format!("{}s", crate::camel_to_snake(name));
        self.buf += &format!("\nCREATE TABLE {table_name} (\n");

        let cols = match &self.get_ty_mat(ty).kind {
            ir::TyKind::Array(inner) => self.get_table_columns(inner, ""),
            _ => self.get_table_columns(ty, ""),
        };

        let cols: Vec<_> = cols.iter().map(Col::to_string).collect();
        self.buf += &cols.join(",\n");
        self.buf += "\n);\n";
    }

    pub fn get_table_columns(&mut self, ty: &ir::Ty, prefix: &str) -> Vec<Col> {
        let mut r = Vec::new();

        const FRAMED: &[(&[&str], &str)] = &[
            (&["std", "Date"], "DATE"),
            (&["std", "Time"], "TIME"),
            (&["std", "Timestamp"], "TIMESTAMP"),
            (&["std", "Decimal"], "DECIMAL"),
        ];

        if let ir::TyKind::Ident(name) = &ty.kind {
            for (lt_name, pg_ty) in FRAMED {
                if &name.0 == lt_name {
                    r.push(Col {
                        name: get_name_terminal(prefix).into(),
                        ty: pg_ty.to_string(),
                        nullable: false,
                    });
                    return r;
                }
            }
        }

        match &self.get_ty_mat(ty).kind {
            ir::TyKind::Primitive(prim) => {
                let ty = match prim {
                    ir::TyPrimitive::bool => "BOOLEAN",

                    ir::TyPrimitive::int8 => "SMALLINT",
                    ir::TyPrimitive::uint8 => "SMALLINT",
                    ir::TyPrimitive::int16 => "SMALLINT",
                    ir::TyPrimitive::uint16 => "SMALLINT",

                    ir::TyPrimitive::int32 => "INT",
                    ir::TyPrimitive::uint32 => "INT",

                    ir::TyPrimitive::int64 => "BIGINT",
                    ir::TyPrimitive::uint64 => "BIGINT",

                    ir::TyPrimitive::float32 => "REAL",
                    ir::TyPrimitive::float64 => "FLOAT",
                    ir::TyPrimitive::text => "TEXT",
                };

                r.push(Col {
                    name: get_name_terminal(prefix).into(),
                    ty: ty.into(),
                    nullable: false,
                });
            }

            ir::TyKind::Tuple(fields) => {
                for (i, f) in fields.iter().enumerate() {
                    let prefix = get_name_prefix(prefix, f.name.as_ref(), i);

                    r.extend(self.get_table_columns(&f.ty, &prefix));
                }
            }

            ir::TyKind::Enum(variants) => {
                r.push(Col {
                    name: get_name_terminal(prefix).into(),
                    ty: "TEXT".into(),
                    nullable: false,
                });

                for v in variants {
                    let prefix = get_name_prefix(prefix, Some(&v.name), 0);

                    for mut c in self.get_table_columns(&v.ty, &prefix) {
                        c.nullable = true;
                        r.push(c);
                    }
                }
            }

            _ => {}
        }
        r
    }
}

/// A single column in a generated `CREATE TABLE` statement.
struct Col {
    /// Column name; may contain `.` separators for flattened nested fields.
    name: String,
    /// SQL type name, e.g. `INT` or `TEXT`.
    ty: String,
    /// When `false`, the column is rendered with a `NOT NULL` constraint.
    nullable: bool,
}

impl std::fmt::Display for Col {
    /// Renders the column as a two-space-indented column definition, quoting the name
    /// only when required.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let quoted = lutra_sql::Ident::with_quote_if_needed('"', &self.name);
        std::write!(f, "  {quoted} {}", self.ty)?;
        if !self.nullable {
            f.write_str(" NOT NULL")?;
        }
        Ok(())
    }
}

/// Builds the dotted column path for a nested field.
///
/// `name` is the field's name; when it is `None`, the positional name `c{i}` is used.
/// The result is `prefix.name`, or just `name` when `prefix` is empty.
fn get_name_prefix(prefix: &str, name: Option<&String>, i: usize) -> String {
    let segment = match name {
        Some(n) => n.clone(),
        None => format!("c{i}"),
    };

    if !prefix.is_empty() {
        return format!("{prefix}.{segment}");
    }
    segment
}

/// Returns the column name for a leaf value: the accumulated `prefix`, or `"value"`
/// when there is no prefix (i.e. the whole row is a single scalar).
fn get_name_terminal(prefix: &str) -> &str {
    match prefix {
        "" => "value",
        named => named,
    }
}