use lutra_bin::ir;
use postgres_types as pg_ty;
#[cfg(feature = "postgres")]
use postgres::Error;
#[cfg(not(feature = "postgres"))]
use tokio_postgres::Error;
#[cfg(feature = "tokio-postgres")]
use crate::RunnerAsync;
// Code pulled in at compile time from the `lutra.rs` file located in the
// directory named by the `OUT_DIR` environment variable. The rest of this
// file uses `lutra::pull_interface()` and the `lutra::PullInterfaceOutput*`
// types from it.
// NOTE(review): `dead_code` is presumably allowed because the included file
// defines items this module never references — confirm.
#[allow(dead_code)]
mod lutra {
include!(concat!(env!("OUT_DIR"), "/lutra.rs"));
}
/// Introspects the database behind `runner` and renders it as Lutra source text.
///
/// The schemas returned by the `lutra::pull_interface()` program are visited
/// with `public` first and all other schemas in alphabetical order. Schemas
/// whose name starts with `pg_toast_` or `pg_temp_` are skipped. Tables of
/// `public` are emitted at the top level; tables of any other schema are
/// wrapped in a `module <schema> { ... }` block. For every table the output
/// contains a row type, a `from_*` reader, an `insert_*` writer and one lookup
/// function per index.
///
/// # Errors
///
/// Propagates any error returned by `tuple_from_pg_columns`.
///
/// # Panics
///
/// Panics if running the introspection program fails: both result layers of
/// `runner.run(..)` are unwrapped.
#[cfg(feature = "tokio-postgres")]
pub async fn pull_interface<C>(runner: &RunnerAsync<C>) -> Result<String, tokio_postgres::Error>
where
    C: tokio_postgres::GenericClient,
{
    use lutra_runner::Run;

    const DEFAULT_SCHEMA_NAME: &str = "public";

    let program = lutra::pull_interface();
    let mut schemas = runner.run(&program, &()).await.unwrap().unwrap();

    // `false < true`, so the default schema sorts ahead of every other one;
    // the remaining schemas are ordered by name.
    schemas.sort_by(|lhs, rhs| {
        let lhs_key = (lhs.schema_name != DEFAULT_SCHEMA_NAME, &lhs.schema_name);
        let rhs_key = (rhs.schema_name != DEFAULT_SCHEMA_NAME, &rhs.schema_name);
        lhs_key.cmp(&rhs_key)
    });

    let mut output = String::new();

    let visible_schemas = schemas.into_iter().filter(|schema| {
        let name = &schema.schema_name;
        !(name.starts_with("pg_toast_") || name.starts_with("pg_temp_"))
    });

    for schema in visible_schemas {
        // Only non-default schemas get their own module (and indentation).
        let in_module = schema.schema_name != DEFAULT_SCHEMA_NAME;
        let indent = if in_module { " " } else { "" };

        if in_module {
            output.push('\n');
            output.push_str(&format!("module {} {{\n", schema.schema_name));
        }

        for table in schema.tables {
            let t_name = &table.table_name;
            let table_ty = tuple_from_pg_columns(&table.columns, "")?;
            let ty_name = table_type_name(t_name);
            let snake = crate::case::to_snake_case(t_name);

            output.push('\n');

            output.push_str(&format!("{indent}## Row of table {t_name}\n"));
            output.push_str(&format!(
                "{indent}type {ty_name}: {}\n",
                lutra_bin::ir::print_ty(&table_ty)
            ));

            output.push_str(&format!("{indent}## Read from table {t_name}\n"));
            output.push_str(&format!(
                "{indent}func from_{snake}(): [{ty_name}] -> std::sql::from(\"{t_name}\")\n",
            ));

            output.push_str(&format!("{indent}## Write into table {t_name}\n"));
            output.push_str(&format!(
                "{indent}func insert_{snake}(values: [{ty_name}]) -> std::sql::insert(values, \"{t_name}\")\n"
            ));

            generate_index_lookup(&mut output, indent, table, ty_name, snake);
        }

        if in_module {
            output.push_str("}\n");
        }
    }

    Ok(output)
}
/// Builds a tuple type from a run of table columns.
///
/// `namespace` is the prefix that every column name in `columns` is expected
/// to carry; it is stripped before the name is used as a field name. When
/// `group_tuple_columns` reports that the leading columns form a group, those
/// columns are converted recursively into a nested tuple field named after the
/// group, using `"{namespace}{group}."` as the nested namespace. Every other
/// column becomes a field typed by `ty_from_pg_column`.
///
/// # Panics
///
/// Panics if an ungrouped column name does not start with `namespace`.
fn tuple_from_pg_columns(
    columns: &[lutra::PullInterfaceOutputItemstablesItemscolumnsItems],
    namespace: &str,
) -> Result<ir::Ty, Error> {
    let mut fields = Vec::new();

    // `remaining` is the not-yet-consumed tail of `columns`.
    let mut remaining = columns;
    while let Some((column, _)) = remaining.split_first() {
        match group_tuple_columns(remaining, namespace) {
            Some((group_name, group_len)) => {
                let (group, tail) = remaining.split_at(group_len);
                let nested_namespace = format!("{namespace}{group_name}.");
                let ty = tuple_from_pg_columns(group, &nested_namespace)?;
                fields.push(ir::TyTupleField {
                    name: Some(group_name.to_string()),
                    ty,
                });
                remaining = tail;
            }
            None => {
                let field_name = column.name.strip_prefix(namespace).unwrap().to_string();
                let ty = ty_from_pg_column(column);
                fields.push(ir::TyTupleField {
                    name: Some(field_name),
                    ty,
                });
                remaining = &remaining[1..];
            }
        }
    }

    Ok(ir::Ty::new(ir::TyKind::Tuple(fields)))
}
/// Detects a group of dotted columns at the start of `columns`.
///
/// After stripping `namespace`, a column name of the form `<prefix>.<rest>`
/// opens a group. The group extends over all directly following columns whose
/// namespace-stripped name also starts with `<prefix>.`.
///
/// Returns the group prefix and the number of columns in the group, or `None`
/// when `columns` is empty, when the first column does not start with
/// `namespace`, or when its stripped name contains no `.`.
fn group_tuple_columns<'a>(
    columns: &'a [lutra::PullInterfaceOutputItemstablesItemscolumnsItems],
    namespace: &str,
) -> Option<(&'a str, usize)> {
    let (first, rest) = columns.split_first()?;
    let (prefix, _) = first.name.strip_prefix(namespace)?.split_once('.')?;

    // A following column belongs to the group only if it carries the very same
    // `<prefix>.`; merely containing a dot is not enough, because the column
    // could open a sibling group (`a.x` followed by `b.y`). The explicit `.`
    // check keeps `ab.x` out of group `a`.
    let len = 1 + rest
        .iter()
        .take_while(|column| {
            let after_prefix = column
                .name
                .strip_prefix(namespace)
                .and_then(|name| name.strip_prefix(prefix));
            matches!(after_prefix, Some(tail) if tail.starts_with('.'))
        })
        .count();

    Some((prefix, len))
}
fn ty_from_pg_column(c: &lutra::PullInterfaceOutputItemstablesItemscolumnsItems) -> ir::Ty {
let pg_ty = match c.typ_id {
13226 => pg_ty::Type::INT4, 13229 => pg_ty::Type::TEXT, 13231 => pg_ty::Type::TEXT, 13237 => pg_ty::Type::TIMESTAMPTZ, 13239 => pg_ty::Type::TEXT, 10029 => pg_ty::Type::TEXT,
_ => pg_ty::Type::from_oid(c.typ_id as u32)
.unwrap_or_else(|| panic!("unknown type with oid: {}", c.typ_id)),
};
let tn = pg_ty.name();
match tn {
"boolean" | "bool" => ir::Ty::new(ir::TyPrimitive::bool),
"smallint" | "int2" => ir::Ty::new(ir::TyPrimitive::int16),
"integer" | "int4" | "oid" => ir::Ty::new(ir::TyPrimitive::int32),
"bigint" | "int8" => ir::Ty::new(ir::TyPrimitive::int64),
"real" | "float4" => ir::Ty::new(ir::TyPrimitive::float32),
"double precision" | "float8" => ir::Ty::new(ir::TyPrimitive::float64),
"date" => ir::Ty::new(ir::Path(vec!["std".into(), "Date".into()])),
"timestamp" | "timestamp without time zone" => {
ir::Ty::new(ir::Path(vec!["std".into(), "Timestamp".into()]))
}
"time" | "time without time zone" => {
ir::Ty::new(ir::Path(vec!["std".into(), "Time".into()]))
}
"text" | "varchar" | "char" | "bpchar" | "name" => ir::Ty::new(ir::TyPrimitive::text),
_ if tn.starts_with("varchar") | tn.starts_with("char") | tn.starts_with("bpchar") => {
ir::Ty::new(ir::TyPrimitive::text)
}
_ if tn.starts_with("decimal") | tn.starts_with("numeric") => {
ir::Ty::new(ir::Path(vec!["std".into(), "Decimal".into()]))
}
_ => ir::Ty::new(ir::TyPrimitive::text),
}
}
/// Derives the name of the row type generated for table `table_name`.
///
/// A table name ending in `s` is treated as a plural: the trailing `s` is
/// dropped and the remainder is converted to PascalCase (`users` -> `User`).
/// Any other name is converted to PascalCase and suffixed with `Row`
/// (`person` -> `PersonRow`).
///
/// A table named exactly `s` would leave nothing after dropping the suffix and
/// thus produce an empty type name, so it takes the `Row` form as well.
fn table_type_name(table_name: &str) -> String {
    match table_name.strip_suffix('s') {
        Some(singular) if !singular.is_empty() => crate::case::to_pascal_case(singular),
        _ => format!("{}Row", crate::case::to_pascal_case(table_name)),
    }
}
/// Appends one lookup function per index of `table` to `output`.
///
/// Each function is named `from_{snake}_by_{columns joined with "_and_"}` and
/// takes one parameter per index column. For a unique index it returns
/// `enum {none, some: {ty_name}}` via `std::find`; otherwise it returns
/// `[{ty_name}]` via `std::filter`. Every emitted line is prefixed with
/// `indent`.
fn generate_index_lookup(
    output: &mut String,
    indent: &'static str,
    table: lutra::PullInterfaceOutputItemstablesItems,
    ty_name: String,
    snake: String,
) {
    for index in table.indexes.iter() {
        let by = index.columns.join("_and_");

        // NOTE(review): every parameter is declared as `int32` regardless of
        // the indexed column's actual type — confirm this is intended.
        let params = index
            .columns
            .iter()
            .map(|c| format!("{c}: int32"))
            .collect::<Vec<_>>()
            .join(", ");
        let cond = index
            .columns
            .iter()
            .map(|c| format!("x.{c} == {c}"))
            .collect::<Vec<_>>()
            .join(" && ");

        // Unique indexes match at most one row; others may match many.
        let (return_ty, combinator) = if index.is_unique {
            (format!("enum {{none, some: {ty_name}}}"), "find")
        } else {
            (format!("[{ty_name}]"), "filter")
        };

        output.push_str(&format!(
            "{indent}## Lookup in {} by index {}\n",
            table.table_name, index.index_name
        ));
        output.push_str(&format!(
            "{indent}func from_{snake}_by_{by}({params}): {return_ty} -> (\n"
        ));
        output.push_str(&format!(
            "{indent} from_{snake}() | std::{combinator}(x -> {cond})\n"
        ));
        output.push_str(&format!("{indent})\n"));
    }
}