use std::{collections::HashMap, fmt::Write, path};
use lutra_bin::ir;
use lutra_compiler::{Project, pr};
use crate::{GenerateOptions, infer_names};
pub(crate) fn run(
project: &Project,
options: &GenerateOptions,
_out_dir: path::PathBuf,
) -> Result<String, std::fmt::Error> {
let module = lutra_compiler::project_to_types(project);
let ty_defs = module.iter_types_re().collect();
let mut w = String::new();
w += "-- Generated by lutra-codegen";
let mut ctx = Context {
options,
ty_defs: &ty_defs,
project,
};
let module_path = vec![];
codegen_module(&mut w, &module, module_path, &mut ctx)?;
Ok(w)
}
/// Accumulates generated SQL text for type definitions.
///
/// `buf` receives the emitted statements; `ctx` gives read access to the
/// shared generation context (notably the type-definition map used to
/// resolve type identifiers).
#[derive(Debug)]
#[allow(dead_code)]
struct Printer<'a> {
// Output buffer that `write_ty_def` appends to.
buf: String,
// Shared, read-only generation context.
ctx: &'a Context<'a>,
}
/// Shared state passed through the whole generation run.
#[derive(Debug)]
#[allow(dead_code)]
struct Context<'a> {
// Options supplied by the caller of `run`; not read anywhere in the code
// visible here (presumably why `dead_code` is allowed — TODO confirm).
options: &'a GenerateOptions,
// Type definitions keyed by path; used to resolve `ir::TyKind::Ident`.
ty_defs: &'a HashMap<ir::Path, &'a ir::Ty>,
// The source project; its root module is consulted for definition order
// and annotations while walking modules.
project: &'a Project,
}
/// Scalar type categories that the generator maps to a single SQL column
/// type (see `TyStd::sql_name`).
///
/// There are no separate unsigned variants: `TyStd::from_ident` maps each
/// `std::UintN` identifier to the signed variant of the same width.
#[derive(Clone, Copy)]
enum TyStd {
Bool,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
Text,
Date,
Time,
Timestamp,
Decimal,
}
impl TyStd {
fn from_ident(ident: &ir::Path) -> Option<Self> {
if ident.is(&["std", "Bool"]) {
Some(Self::Bool)
} else if ident.is(&["std", "Int8"]) || ident.is(&["std", "Uint8"]) {
Some(Self::Int8)
} else if ident.is(&["std", "Int16"]) || ident.is(&["std", "Uint16"]) {
Some(Self::Int16)
} else if ident.is(&["std", "Int32"]) || ident.is(&["std", "Uint32"]) {
Some(Self::Int32)
} else if ident.is(&["std", "Int64"]) || ident.is(&["std", "Uint64"]) {
Some(Self::Int64)
} else if ident.is(&["std", "Float32"]) {
Some(Self::Float32)
} else if ident.is(&["std", "Float64"]) {
Some(Self::Float64)
} else if ident.is(&["std", "Text"]) {
Some(Self::Text)
} else if ident.is(&["std", "Date"]) {
Some(Self::Date)
} else if ident.is(&["std", "Time"]) {
Some(Self::Time)
} else if ident.is(&["std", "Timestamp"]) {
Some(Self::Timestamp)
} else if ident.is(&["std", "Decimal"]) {
Some(Self::Decimal)
} else {
None
}
}
fn from_primitive(prim: ir::TyPrimitive) -> Self {
match prim {
ir::TyPrimitive::Prim8 => Self::Int8,
ir::TyPrimitive::Prim16 => Self::Int16,
ir::TyPrimitive::Prim32 => Self::Int32,
ir::TyPrimitive::Prim64 => Self::Int64,
}
}
fn sql_name(self) -> &'static str {
match self {
Self::Bool => "BOOLEAN",
Self::Int8 => "SMALLINT",
Self::Int16 => "SMALLINT",
Self::Int32 => "INTEGER",
Self::Int64 => "BIGINT",
Self::Float32 => "REAL",
Self::Float64 => "FLOAT",
Self::Text => "TEXT",
Self::Date => "DATE",
Self::Time => "TIME",
Self::Timestamp => "TIMESTAMP",
Self::Decimal => "DECIMAL",
}
}
}
/// Writes the SQL for every type declared in `module`, then recurses into
/// its sub-modules.
///
/// Iteration follows the definitions of the project module found at
/// `module_path`; a definition with no same-named declaration in `module` is
/// skipped, as is any declaration that is neither a type nor a module.
///
/// # Panics
/// Panics if `module_path` does not resolve to a module in the project.
///
/// # Errors
/// Propagates any [`std::fmt::Error`] from writing to `w`.
fn codegen_module(
    w: &mut impl std::fmt::Write,
    module: &ir::Module,
    module_path: Vec<String>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let project_root = &ctx.project.root_module;
    let pr_mod = project_root.get_module(&module_path).unwrap();

    let mut tys = Vec::new();
    let mut sub_modules = Vec::new();

    for (name, pr_def) in &pr_mod.defs {
        let found = module.decls.iter().find(|d| &d.name == name);
        match found.map(|d| &d.decl) {
            Some(ir::Decl::Mod(sub_mod)) => sub_modules.push((name, sub_mod)),
            Some(ir::Decl::Ty(ty)) => {
                // Work on a copy so names can be filled in without touching
                // the module itself.
                let mut named = ty.clone();
                infer_names(name, &mut named);
                tys.push((named, pr_def.annotations.as_slice()));
            }
            _ => {}
        }
    }

    // Types of this module are emitted before anything from sub-modules.
    write_tys(w, tys, ctx)?;

    for (name, sub_mod) in sub_modules {
        let mut sub_path = Vec::with_capacity(module_path.len() + 1);
        sub_path.extend_from_slice(&module_path);
        sub_path.push(name.clone());
        codegen_module(w, sub_mod, sub_path, ctx)?;
    }
    Ok(())
}
/// Renders all `tys` into a temporary buffer and writes it to `w` in a
/// single call.
///
/// Each entry pairs a type with the annotations of its definition.
///
/// # Errors
/// Returns the error from `w.write_str`, if any.
fn write_tys(
    w: &mut impl Write,
    tys: Vec<(ir::Ty, &[pr::Anno])>,
    ctx: &mut Context,
) -> Result<(), std::fmt::Error> {
    let mut printer = Printer {
        buf: String::new(),
        ctx,
    };

    for (ty, annotations) in &tys {
        printer.write_ty_def(ty, annotations);
    }

    w.write_str(&printer.buf)
}
impl<'a> Printer<'a> {
fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
if let ir::TyKind::Ident(ident) = &ty.kind {
let ty_mat = self.ctx.ty_defs.get(ident).unwrap();
self.get_ty_mat(ty_mat)
} else {
ty
}
}
fn get_ty_std(&self, ty: &'a ir::Ty) -> Option<TyStd> {
let mut ty = ty;
while let ir::TyKind::Ident(ident) = &ty.kind {
if let Some(ty_std) = TyStd::from_ident(ident) {
return Some(ty_std);
}
ty = self.ctx.ty_defs.get(ident).unwrap();
}
match ty.kind {
ir::TyKind::Primitive(prim) => Some(TyStd::from_primitive(prim)),
_ => None,
}
}
pub fn write_ty_def(&mut self, ty: &ir::Ty, _annotations: &[pr::Anno]) {
let name = ty.name.as_ref().unwrap();
let table_name = format!("{}s", crate::camel_to_snake(name));
self.buf += &format!("\nCREATE TABLE {table_name} (\n");
let cols = match &self.get_ty_mat(ty).kind {
ir::TyKind::Array(inner) => self.get_table_columns(inner, ""),
_ => self.get_table_columns(ty, ""),
};
let cols: Vec<_> = cols.iter().map(Col::to_string).collect();
self.buf += &cols.join(",\n");
self.buf += "\n);\n";
}
pub fn get_table_columns(&mut self, ty: &ir::Ty, prefix: &str) -> Vec<Col> {
let mut r = Vec::new();
if let Some(ty) = self.get_ty_std(ty) {
r.push(Col {
name: get_name_terminal(prefix).into(),
ty: ty.sql_name().into(),
nullable: false,
});
return r;
}
match &self.get_ty_mat(ty).kind {
ir::TyKind::Tuple(fields) => {
for (i, f) in fields.iter().enumerate() {
let prefix = get_name_prefix(prefix, f.name.as_ref(), i);
r.extend(self.get_table_columns(&f.ty, &prefix));
}
}
ir::TyKind::Enum(variants) => {
r.push(Col {
name: get_name_terminal(prefix).into(),
ty: "TEXT".into(),
nullable: false,
});
for v in variants {
let prefix = get_name_prefix(prefix, Some(&v.name), 0);
for mut c in self.get_table_columns(&v.ty, &prefix) {
c.nullable = true;
r.push(c);
}
}
}
_ => {}
}
r
}
}
/// One column of a generated table.
struct Col {
// Column name; quoted on output only if needed (see the `Display` impl).
name: String,
// SQL type name, emitted verbatim.
ty: String,
// When false, the column is emitted with ` NOT NULL`.
nullable: bool,
}
/// Formats the column as one line of a `CREATE TABLE` body: a leading space,
/// the (possibly double-quoted) name, the SQL type and, for non-nullable
/// columns, ` NOT NULL`.
impl std::fmt::Display for Col {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ident = lutra_sql::Ident::with_quote_if_needed('"', &self.name);
        // Nullable columns carry no constraint suffix at all.
        let constraint = (!self.nullable)
            .then_some(" NOT NULL")
            .unwrap_or_default();
        write!(f, " {ident} {ty}{constraint}", ty = self.ty)
    }
}
/// Extends a dotted column-name prefix by one segment.
///
/// The segment is `name` when given, otherwise `c<i>`. With an empty
/// `prefix` the segment is returned alone; otherwise the result is
/// `<prefix>.<segment>`.
fn get_name_prefix(prefix: &str, name: Option<&String>, i: usize) -> String {
    let segment = match name {
        Some(given) => given.clone(),
        None => format!("c{i}"),
    };

    match prefix {
        "" => segment,
        _ => [prefix, segment.as_str()].join("."),
    }
}
/// Returns the column name for an accumulated prefix: the prefix itself, or
/// `value` when the prefix is empty.
fn get_name_terminal(prefix: &str) -> &str {
    match prefix {
        "" => "value",
        accumulated => accumulated,
    }
}