lutra-codegen 0.6.0

Code generation for Lutra
Documentation
use std::{collections::HashMap, fmt::Write, path};

use lutra_bin::ir;
use lutra_compiler::{Project, pr};

use crate::{GenerateOptions, infer_names};

pub(crate) fn run(
    project: &Project,
    options: &GenerateOptions,
    _out_dir: path::PathBuf,
) -> Result<String, std::fmt::Error> {
    let module = lutra_compiler::project_to_types(project);

    let ty_defs = module.iter_types_re().collect();

    let mut w = String::new();
    w += "-- Generated by lutra-codegen";

    let mut ctx = Context {
        options,
        ty_defs: &ty_defs,
        project,
    };

    let module_path = vec![];
    codegen_module(&mut w, &module, module_path, &mut ctx)?;
    Ok(w)
}

/// Buffers the SQL generated for one batch of type definitions.
// NOTE(review): both fields are read in this file, so this `allow` may be
// unnecessary; kept so lint behavior is unchanged.
#[allow(dead_code)]
#[derive(Debug)]
struct Printer<'a> {
    /// Generated SQL text; `write_tys` copies it to the output writer.
    buf: String,
    /// Shared code generation context, borrowed read-only.
    ctx: &'a Context<'a>,
}

/// State shared by every step of SQL code generation.
// `options` is stored but never read by the SQL generator in this file; the
// `allow` presumably silences the resulting `dead_code` warning.
#[allow(dead_code)]
#[derive(Debug)]
struct Context<'a> {
    /// Options the generator was invoked with.
    options: &'a GenerateOptions,
    /// Type definitions collected from `iter_types_re`, keyed by type path.
    ty_defs: &'a HashMap<ir::Path, &'a ir::Ty>,
    /// The project being generated; its module tree supplies source-ordered definitions.
    project: &'a Project,
}

/// Scalar types that map onto a single SQL column type.
///
/// Unsigned `std` integers share the variant of the signed integer with the same
/// width (see `TyStd::from_ident`).
#[derive(Copy, Clone)]
enum TyStd {
    Bool,
    // integers
    Int8,
    Int16,
    Int32,
    Int64,
    // floating point
    Float32,
    Float64,
    // text and temporal
    Text,
    Date,
    Time,
    Timestamp,
    // other
    Decimal,
}

impl TyStd {
    fn from_ident(ident: &ir::Path) -> Option<Self> {
        if ident.is(&["std", "Bool"]) {
            Some(Self::Bool)
        } else if ident.is(&["std", "Int8"]) || ident.is(&["std", "Uint8"]) {
            Some(Self::Int8)
        } else if ident.is(&["std", "Int16"]) || ident.is(&["std", "Uint16"]) {
            Some(Self::Int16)
        } else if ident.is(&["std", "Int32"]) || ident.is(&["std", "Uint32"]) {
            Some(Self::Int32)
        } else if ident.is(&["std", "Int64"]) || ident.is(&["std", "Uint64"]) {
            Some(Self::Int64)
        } else if ident.is(&["std", "Float32"]) {
            Some(Self::Float32)
        } else if ident.is(&["std", "Float64"]) {
            Some(Self::Float64)
        } else if ident.is(&["std", "Text"]) {
            Some(Self::Text)
        } else if ident.is(&["std", "Date"]) {
            Some(Self::Date)
        } else if ident.is(&["std", "Time"]) {
            Some(Self::Time)
        } else if ident.is(&["std", "Timestamp"]) {
            Some(Self::Timestamp)
        } else if ident.is(&["std", "Decimal"]) {
            Some(Self::Decimal)
        } else {
            None
        }
    }

    fn from_primitive(prim: ir::TyPrimitive) -> Self {
        match prim {
            ir::TyPrimitive::Prim8 => Self::Int8,
            ir::TyPrimitive::Prim16 => Self::Int16,
            ir::TyPrimitive::Prim32 => Self::Int32,
            ir::TyPrimitive::Prim64 => Self::Int64,
        }
    }

    fn sql_name(self) -> &'static str {
        match self {
            Self::Bool => "BOOLEAN",
            Self::Int8 => "SMALLINT",
            Self::Int16 => "SMALLINT",
            Self::Int32 => "INTEGER",
            Self::Int64 => "BIGINT",
            Self::Float32 => "REAL",
            Self::Float64 => "FLOAT",
            Self::Text => "TEXT",
            Self::Date => "DATE",
            Self::Time => "TIME",
            Self::Timestamp => "TIMESTAMP",
            Self::Decimal => "DECIMAL",
        }
    }
}

/// Writes `CREATE TABLE` statements for the type declarations of `module`, then
/// recurses into its sub-modules.
///
/// `module_path` locates the matching module in `ctx.project.root_module`. Its
/// definitions are iterated, rather than `module.decls`, so that output follows
/// source order. Definitions without a matching declaration in `module`, and
/// declarations that are neither types nor modules, are skipped.
///
/// # Panics
/// Panics if `module_path` does not name a module of the project.
///
/// # Errors
/// Propagates any `std::fmt::Error` from writing to `w`.
fn codegen_module(
    w: &mut impl std::fmt::Write,
    module: &ir::Module,
    module_path: Vec<String>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let source_mod = ctx.project.root_module.get_module(&module_path).unwrap();

    let mut types = Vec::new();
    let mut children = Vec::new();
    for (name, pr_def) in &source_mod.defs {
        let decl = module
            .decls
            .iter()
            .find(|d| &d.name == name)
            .map(|d| &d.decl);

        match decl {
            Some(ir::Decl::Mod(child)) => children.push((name, child)),
            Some(ir::Decl::Ty(ty)) => {
                let mut named = ty.clone();
                infer_names(name, &mut named);
                types.push((named, pr_def.annotations.as_slice()));
            }
            _ => {}
        }
    }

    write_tys(w, types, ctx)?;

    for (name, child) in children {
        let child_path: Vec<String> = module_path
            .iter()
            .cloned()
            .chain(std::iter::once(name.clone()))
            .collect();
        codegen_module(w, child, child_path, ctx)?;
    }

    Ok(())
}

/// Renders each `(type, annotations)` pair as a `CREATE TABLE` statement.
///
/// The statements are accumulated in a buffer and written to `w` with a single
/// `write_str` call.
///
/// # Errors
/// Propagates the `std::fmt::Error` returned by `w.write_str`.
fn write_tys(
    w: &mut impl Write,
    tys: Vec<(ir::Ty, &[pr::Anno])>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let mut printer = Printer {
        buf: String::new(),
        ctx,
    };
    for (ty, annotations) in &tys {
        printer.write_ty_def(ty, annotations);
    }
    w.write_str(&printer.buf)
}

impl<'a> Printer<'a> {
    fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
        if let ir::TyKind::Ident(ident) = &ty.kind {
            let ty_mat = self.ctx.ty_defs.get(ident).unwrap();
            self.get_ty_mat(ty_mat)
        } else {
            ty
        }
    }

    fn get_ty_std(&self, ty: &'a ir::Ty) -> Option<TyStd> {
        let mut ty = ty;
        while let ir::TyKind::Ident(ident) = &ty.kind {
            if let Some(ty_std) = TyStd::from_ident(ident) {
                return Some(ty_std);
            }
            ty = self.ctx.ty_defs.get(ident).unwrap();
        }

        match ty.kind {
            ir::TyKind::Primitive(prim) => Some(TyStd::from_primitive(prim)),
            _ => None,
        }
    }

    /// Generates a type definition.
    pub fn write_ty_def(&mut self, ty: &ir::Ty, _annotations: &[pr::Anno]) {
        let name = ty.name.as_ref().unwrap();
        let table_name = format!("{}s", crate::camel_to_snake(name));
        self.buf += &format!("\nCREATE TABLE {table_name} (\n");

        let cols = match &self.get_ty_mat(ty).kind {
            ir::TyKind::Array(inner) => self.get_table_columns(inner, ""),
            _ => self.get_table_columns(ty, ""),
        };

        let cols: Vec<_> = cols.iter().map(Col::to_string).collect();
        self.buf += &cols.join(",\n");
        self.buf += "\n);\n";
    }

    pub fn get_table_columns(&mut self, ty: &ir::Ty, prefix: &str) -> Vec<Col> {
        let mut r = Vec::new();

        if let Some(ty) = self.get_ty_std(ty) {
            r.push(Col {
                name: get_name_terminal(prefix).into(),
                ty: ty.sql_name().into(),
                nullable: false,
            });
            return r;
        }

        match &self.get_ty_mat(ty).kind {
            ir::TyKind::Tuple(fields) => {
                for (i, f) in fields.iter().enumerate() {
                    let prefix = get_name_prefix(prefix, f.name.as_ref(), i);

                    r.extend(self.get_table_columns(&f.ty, &prefix));
                }
            }

            ir::TyKind::Enum(variants) => {
                r.push(Col {
                    name: get_name_terminal(prefix).into(),
                    ty: "TEXT".into(),
                    nullable: false,
                });

                for v in variants {
                    let prefix = get_name_prefix(prefix, Some(&v.name), 0);

                    for mut c in self.get_table_columns(&v.ty, &prefix) {
                        c.nullable = true;
                        r.push(c);
                    }
                }
            }

            _ => {}
        }
        r
    }
}

/// A single column of a generated `CREATE TABLE` statement.
struct Col {
    /// Unquoted column name; nested fields are joined with `.`.
    name: String,
    /// SQL type name, e.g. `TEXT` or `INTEGER`.
    ty: String,
    /// When `false`, the column is rendered with `NOT NULL`.
    nullable: bool,
}

impl std::fmt::Display for Col {
    /// Renders the column as one line of a `CREATE TABLE` body.
    ///
    /// The line has a two-space indent, the column name (double-quoted by
    /// `lutra_sql::Ident::with_quote_if_needed` when required) and the SQL type.
    /// It ends with `NOT NULL` unless the column is nullable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ident = lutra_sql::Ident::with_quote_if_needed('"', &self.name);
        std::write!(f, "  {ident} {}", self.ty)?;
        if self.nullable {
            Ok(())
        } else {
            f.write_str(" NOT NULL")
        }
    }
}

/// Builds the column-name prefix for a nested field.
///
/// The field segment is `name`, or `c{i}` for an unnamed field. It is appended to
/// `prefix` with a `.` separator; when `prefix` is empty the segment is returned alone.
fn get_name_prefix(prefix: &str, name: Option<&String>, i: usize) -> String {
    let segment = match name {
        Some(n) => n.clone(),
        None => format!("c{i}"),
    };

    match prefix {
        "" => segment,
        _ => format!("{prefix}.{segment}"),
    }
}

/// Returns the column name for a leaf value.
///
/// This is `prefix` itself, or `"value"` when `prefix` is empty (the leaf is the
/// table's root type).
fn get_name_terminal(prefix: &str) -> &str {
    match prefix {
        "" => "value",
        named => named,
    }
}