use crate::sqlgen::SqlGenContext;
use crate::{QError, QSource};
use codegen::{Block, Function, Scope, Struct};
use dibs_db_schema::Schema;
use dibs_query_schema::{
Decl, Delete, FieldDef, Insert, InsertMany, Meta, Params, QueryFile, Returning, Returns,
Select, SelectFields, Update, Upsert, UpsertMany,
};
use std::sync::Arc;
/// Rust source text produced by [`generate_rust_code`] for one query file.
#[derive(Debug, Clone)]
pub struct GeneratedCode {
    /// The rendered module: header comment, imports, result/param structs and
    /// one async function per query declaration.
    pub code: String,
}
/// Wraps a generated function body so its result is routed through
/// `TraceErr::trace_err`, tagged with the generated function's name.
///
/// The body is placed in an `async` block and awaited into `__dibs_result`.
/// Any `?` inside the body therefore short-circuits into that binding rather
/// than returning from the generated function directly.
fn wrap_with_trace_err(body: &str, fn_name: &str) -> String {
    let open = "let __dibs_result = async {";
    let close = "}.await;";
    let trace_call = format!("<_ as TraceErr>::trace_err(__dibs_result, \"{fn_name}\")");
    [open, body, close, trace_call.as_str()].join("\n")
}
/// Resolves the Rust type of `table.column` from `schema`.
///
/// The column's explicit `rust_type` override wins. Otherwise the type comes
/// from the column's Postgres type mapping. Nullable columns are wrapped in
/// `Option<…>`.
///
/// Returns `None` when the table or the column is not present in the schema.
fn schema_column_type(schema: &Schema, table: &str, column: &str) -> Option<String> {
    let table_info = schema.get_table(table)?;
    let column_info = table_info
        .columns
        .iter()
        .find(|candidate| candidate.name == column)?;
    let base_ty = column_info
        .rust_type
        .clone()
        .unwrap_or_else(|| column_info.pg_type.to_rust_type().to_string());
    let resolved = if column_info.nullable {
        format!("Option<{}>", base_ty)
    } else {
        base_ty
    };
    Some(resolved)
}
/// Shared, read-only state threaded through every code-generation helper.
struct CodegenContext<'a> {
    /// Database schema used to resolve column types.
    schema: &'a Schema,
    /// Query-file source, handed to `SqlGenContext` when generating SQL
    /// (presumably for error reporting — confirm in `sqlgen`).
    source: Arc<QSource>,
    // Never read in this module: every generator writes into the `&mut Scope`
    // passed by `generate_rust_code` instead, hence the `dead_code` allow.
    #[allow(dead_code)]
    scope: Scope,
}

impl CodegenContext<'_> {
    /// Rust type for `table.column`, or `None` if the schema doesn't know it.
    fn column_type(&self, table: &str, column: &str) -> Option<String> {
        schema_column_type(self.schema, table, column)
    }

    /// Builds a fresh SQL-generation context sharing this context's schema
    /// and source (the `Arc` is cloned, not the source itself).
    fn sqlgen_ctx(&self) -> SqlGenContext<'_> {
        let shared_source = Arc::clone(&self.source);
        SqlGenContext::new(self.schema, shared_source)
    }
}
/// Generates the Rust module for every declaration in `file`.
///
/// The output starts with a "do not edit" header and the `dibs_runtime`
/// imports. It then contains, in declaration order, the structs and async
/// function generated for each SELECT / INSERT / UPSERT / UPDATE / DELETE.
///
/// # Errors
/// Propagates the `QError` returned by UPDATE or DELETE SQL generation.
///
/// # Panics
/// SELECT SQL generation failures panic (see `generate_query_body`) instead
/// of being returned as `QError`.
pub fn generate_rust_code(
    file: &QueryFile,
    schema: &Schema,
    source: Arc<QSource>,
) -> Result<GeneratedCode, QError> {
    let ctx = CodegenContext {
        schema,
        source,
        scope: Scope::new(),
    };

    let mut module = Scope::new();
    module.raw("// Generated by dibs-qgen. Do not edit.");
    module.raw("");
    module.import("dibs_runtime::prelude", "*");
    module.import("dibs_runtime", "tokio_postgres");

    for (name_meta, decl) in &file.0 {
        match decl {
            Decl::Select(q) => generate_select_code(&ctx, name_meta, q, &mut module),
            Decl::Insert(q) => generate_insert_code(&ctx, name_meta, q, &mut module),
            Decl::InsertMany(q) => generate_insert_many_code(&ctx, name_meta, q, &mut module),
            Decl::Upsert(q) => generate_upsert_code(&ctx, name_meta, q, &mut module),
            Decl::UpsertMany(q) => generate_upsert_many_code(&ctx, name_meta, q, &mut module),
            Decl::Update(q) => generate_update_code(&ctx, name_meta, q, &mut module)?,
            Decl::Delete(q) => generate_delete_code(&ctx, name_meta, q, &mut module)?,
        }
    }

    let code = module.to_string();
    Ok(GeneratedCode { code })
}
/// Emits the result struct(s) and the async query function for one SELECT.
///
/// Which structs are emitted depends on the query shape:
/// - Table-backed with explicit fields: `{name}Result` plus nested relation
///   structs. When relations are present, an extra flat `{name}Row` struct is
///   emitted, which the joined rows are decoded into first.
/// - Raw SQL with `returns`: `{name}Result` built from the declared columns.
///
/// The function itself is emitted in every case.
/// NOTE(review): when neither branch applies (a table-backed query without
/// `fields`, or raw SQL without `returns`), no `{name}Result` struct is
/// emitted, yet the function still returns it.
fn generate_select_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    select: &Select,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let struct_name = format!("{}Result", name);

    match (&select.from, &select.returns) {
        (Some(from), _) => {
            if select.fields.is_some() {
                generate_result_struct(ctx, select, name_meta, &struct_name, from, scope);
                if select.has_relations() {
                    let flat_struct_name = format!("{}Row", name);
                    generate_flat_row_struct(ctx, select, &flat_struct_name, from, scope);
                }
            }
        }
        (None, Some(returns)) => generate_raw_sql_result_struct(&struct_name, returns, scope),
        (None, None) => {}
    }

    generate_select_function(ctx, name_meta, select, &struct_name, scope);
}
/// Emits the private flat struct that each joined SQL row is decoded into.
///
/// Relation columns are flattened to `{relation}_{column}` field names by
/// `add_flat_fields_for_select`. The struct is private to the generated
/// module; it is only an intermediate step before nesting.
fn generate_flat_row_struct(
    ctx: &CodegenContext,
    select: &Select,
    struct_name: &str,
    table: &Meta<dibs_sql::TableName>,
    scope: &mut Scope,
) {
    let mut flat = Struct::new(struct_name);
    flat.derive("Debug").derive("Clone").derive("Facet");
    flat.attr("facet(crate = dibs_runtime::facet)");

    if let Some(select_fields) = &select.fields {
        let root_table = table.value.as_str();
        add_flat_fields_for_select(ctx, &mut flat, root_table, "", select_fields);
    }

    scope.push_struct(flat);
}
/// Recursively adds one flat field per selected column or count to `st`.
///
/// Field names are the selected name joined onto `prefix` with `_`. Relation
/// fields are not emitted themselves. Instead, their selected columns are
/// added recursively, using the relation name as the new prefix.
///
/// Columns under a non-empty prefix belong to a related table and are made
/// `Option<…>` when not already optional, because a missing related row
/// yields NULLs.
fn add_flat_fields_for_select(
    ctx: &CodegenContext,
    st: &mut Struct,
    table_name: &str,
    prefix: &str,
    select_fields: &SelectFields,
) {
    let is_root = prefix.is_empty();
    for (field_name_meta, field_def) in &select_fields.fields {
        let field_name = field_name_meta.value.as_str();
        // Same joined name serves as the field name (columns, counts) and as
        // the prefix passed down for nested relations.
        let qualified = if is_root {
            field_name.to_string()
        } else {
            format!("{}_{}", prefix, field_name)
        };

        match field_def {
            None => {
                let column_ty = ctx
                    .column_type(table_name, field_name)
                    .unwrap_or_else(|| "String".to_string());
                let needs_option = !is_root && !column_ty.starts_with("Option<");
                let field_ty = if needs_option {
                    format!("Option<{}>", column_ty)
                } else {
                    column_ty
                };
                st.field(&qualified, &field_ty);
            }
            Some(FieldDef::Rel(rel)) => {
                if let Some(rel_fields) = &rel.fields {
                    let rel_table = rel.table_name().unwrap_or(field_name);
                    add_flat_fields_for_select(ctx, st, rel_table, &qualified, rel_fields);
                }
            }
            Some(FieldDef::Count(_)) => {
                st.field(&qualified, "i64");
            }
        }
    }
}
/// Emits the public result struct for a raw-SQL SELECT.
///
/// Fields come from the `returns` declaration, each typed via
/// `param_type_to_rust`, because a raw query has no schema table to consult.
fn generate_raw_sql_result_struct(struct_name: &str, returns: &Returns, scope: &mut Scope) {
    let mut result = Struct::new(struct_name);
    result.vis("pub");
    result.derive("Debug").derive("Clone").derive("Facet");
    result.attr("facet(crate = dibs_runtime::facet)");

    for (column_meta, column_ty) in &returns.fields {
        let rust_ty = param_type_to_rust(column_ty);
        result.field(format!("pub {}", column_meta.value.as_str()), &rust_ty);
    }

    scope.push_struct(result);
}
/// Emits the public `struct_name` result struct for a table-backed SELECT.
///
/// It then emits the nested structs for its relations. Field types:
/// - Plain columns use their schema type, falling back to `String`.
/// - Counts are `i64`.
/// - Relations are `Option<Nested>` for `first` relations and `Vec<Nested>`
///   otherwise. `Nested` is the query name followed by the PascalCase field
///   name.
fn generate_result_struct(
    ctx: &CodegenContext,
    select: &Select,
    name_meta: &Meta<String>,
    struct_name: &str,
    table: &Meta<dibs_sql::TableName>,
    scope: &mut Scope,
) {
    let mut result = Struct::new(struct_name);
    result.vis("pub");
    result.derive("Debug").derive("Clone").derive("Facet");
    result.attr("facet(crate = dibs_runtime::facet)");

    let Some(select_fields) = &select.fields else {
        scope.push_struct(result);
        return;
    };

    let parent_prefix = name_meta.value.as_str();
    let table_name = table.value.as_str();
    for (field_name_meta, field_def) in &select_fields.fields {
        let field_name = field_name_meta.value.as_str();
        let field_ty = match field_def {
            None => ctx
                .column_type(table_name, field_name)
                .unwrap_or_else(|| "String".to_string()),
            Some(FieldDef::Rel(rel)) => {
                let nested = format!("{}{}", parent_prefix, to_pascal_case(field_name));
                if rel.first.is_some() {
                    format!("Option<{}>", nested)
                } else {
                    format!("Vec<{}>", nested)
                }
            }
            Some(FieldDef::Count(_)) => "i64".to_string(),
        };
        result.field(format!("pub {}", field_name), &field_ty);
    }

    scope.push_struct(result);
    generate_nested_structs(ctx, parent_prefix, select_fields, scope);
}
/// Recursively emits one public struct per relation in `select_fields`.
///
/// Each struct is named `{parent_prefix}{PascalCaseField}`. Nested relations
/// extend that name further, and their field types follow the same rules as
/// `generate_result_struct`. A relation without an explicit field list yields
/// an empty struct.
fn generate_nested_structs(
    ctx: &CodegenContext,
    parent_prefix: &str,
    select_fields: &SelectFields,
    scope: &mut Scope,
) {
    for (field_name_meta, field_def) in &select_fields.fields {
        let Some(FieldDef::Rel(rel)) = field_def else {
            continue;
        };
        let field_name = field_name_meta.value.as_str();
        let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
        let rel_table = rel.table_name().unwrap_or(field_name);

        let mut nested = Struct::new(&nested_name);
        nested.vis("pub");
        nested.derive("Debug").derive("Clone").derive("Facet");
        nested.attr("facet(crate = dibs_runtime::facet)");

        let Some(rel_fields) = &rel.fields else {
            scope.push_struct(nested);
            continue;
        };

        for (rel_field_name_meta, rel_field_def) in &rel_fields.fields {
            let rel_field_name = rel_field_name_meta.value.as_str();
            let field_ty = match rel_field_def {
                None => ctx
                    .column_type(rel_table, rel_field_name)
                    .unwrap_or_else(|| "String".to_string()),
                Some(FieldDef::Rel(nested_rel)) => {
                    let child = format!("{}{}", nested_name, to_pascal_case(rel_field_name));
                    if nested_rel.first.is_some() {
                        format!("Option<{}>", child)
                    } else {
                        format!("Vec<{}>", child)
                    }
                }
                Some(FieldDef::Count(_)) => "i64".to_string(),
            };
            nested.field(format!("pub {}", rel_field_name), &field_ty);
        }

        scope.push_struct(nested);
        generate_nested_structs(ctx, &nested_name, rel_fields, scope);
    }
}
/// Emits the public async function that runs a SELECT.
///
/// The generated function is generic over `C: tokio_postgres::GenericClient`.
/// It takes `client: &C` followed by one `&T` argument per declared query
/// parameter. It returns `Result<Option<struct_name>, QueryError>` for `first`
/// queries and `Result<Vec<struct_name>, QueryError>` otherwise. The body is
/// generated from raw SQL when the query has `sql`, otherwise from the
/// structured query, and is wrapped by `wrap_with_trace_err`.
fn generate_select_function(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    query: &Select,
    struct_name: &str,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let return_ty = if query.first.is_some() {
        format!("Result<Option<{}>, QueryError>", struct_name)
    } else {
        format!("Result<Vec<{}>, QueryError>", struct_name)
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    // The nesting code calls `.clone()` on id fields, which may be `Copy`.
    func.attr("allow(clippy::clone_on_copy)");
    if let Some(params) = &query.params {
        for (param_name_meta, param_type) in &params.params {
            let param_name = &param_name_meta.value;
            let rust_ty = param_type_to_rust(param_type);
            func.arg(param_name, format!("&{}", rust_ty));
        }
    }
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let body = if let Some(raw_sql_meta) = &query.sql {
        block_to_string(&generate_raw_query_body(query, &raw_sql_meta.value))
    } else {
        generate_query_body(ctx, query, struct_name)
    };
    func.line(wrap_with_trace_err(&body, &fn_name));
    scope.push_fn(func);
}
/// Renders the body of a structured (non-raw) SELECT function.
///
/// The body declares the generated SQL as a `const`, then runs it with the
/// bound parameters. Parameters named `__literal_*` are skipped; presumably
/// sqlgen renders those into the SQL text itself. Rows are then decoded:
/// - Without relations, each row goes straight through `from_row`.
/// - With relations, rows are decoded into the flat `{Query}Row` struct and
///   folded into the nested result by `generate_flat_to_nested_transform`.
///
/// # Panics
/// Panics if SELECT SQL generation fails.
fn generate_query_body(ctx: &CodegenContext, query: &Select, struct_name: &str) -> String {
    let sqlgen_ctx = ctx.sqlgen_ctx();
    let generated = crate::sqlgen::generate_select_sql(&sqlgen_ctx, query)
        .unwrap_or_else(|e| panic!("SELECT SQL generation failed: {}", e));

    let bound_params = generated
        .param_order
        .iter()
        .map(|p| p.as_str())
        .filter(|p| !p.starts_with("__literal_"))
        .collect::<Vec<_>>()
        .join(", ");

    let mut body = Block::new("");
    body.line(format!("const SQL: &str = r#\"{}\"#;", generated.sql));
    body.line("");
    body.line(format!(
        "let rows = client.query(SQL, &[{}]).await?;",
        bound_params
    ));

    if !query.has_relations() {
        if query.first.is_some() {
            let mut first_row = Block::new("match rows.into_iter().next()");
            first_row.line("Some(row) => Ok(Some(from_row(&row)?)),");
            first_row.line("None => Ok(None),");
            body.push_block(first_row);
        } else {
            body.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
        }
        return block_to_string(&body);
    }

    let query_name = struct_name.strip_suffix("Result").unwrap_or(struct_name);
    body.line("");
    body.line("// Deserialize all rows into flat structs using facet reflection");
    body.line(format!(
        "let flat_rows: Vec<{query_name}Row> = rows.iter().map(from_row).collect::<Result<Vec<_>, _>>()?;"
    ));
    body.line("");

    match &query.fields {
        None => {
            body.line("Ok(vec![])".to_string());
        }
        Some(select_fields) => {
            let root_table = query
                .from
                .as_ref()
                .map_or("unknown", |m| m.value.as_str());
            let transform = generate_flat_to_nested_transform(
                ctx,
                select_fields,
                struct_name,
                root_table,
                query.is_first(),
            );
            body.line(transform);
        }
    }
    block_to_string(&body)
}
/// Emits code that folds the flat `{Query}Row` values in `flat_rows` into the
/// nested `struct_name` result shape.
///
/// Two strategies are generated:
/// - If any relation is a `Vec`, rows are grouped by the root id column
///   (`select_fields.id_column()`, default `id`; its type comes from
///   `root_table`, default `i64`). Vec relation items are de-duplicated via
///   the `seen_*` sets. Parents are returned in the order they first appear
///   in `flat_rows`, so the SQL row order (including any ORDER BY) survives.
/// - Otherwise each flat row maps 1:1 to one result. An `Option` relation is
///   `Some` when its first selected column is non-NULL.
///
/// With `is_first`, the generated expression yields the first result as
/// `Option<_>`; otherwise it yields a `Vec<_>`.
fn generate_flat_to_nested_transform(
    ctx: &CodegenContext,
    select_fields: &SelectFields,
    struct_name: &str,
    root_table: &str,
    is_first: bool,
) -> String {
    let mut block = Block::new("");
    let id_column = select_fields
        .id_column()
        .map(|c| c.to_string())
        .unwrap_or_else(|| "id".to_string());
    let id_type = ctx
        .column_type(root_table, &id_column)
        .unwrap_or_else(|| "i64".to_string());
    if select_fields.has_vec_relations() {
        block.line("// Group flat rows by parent ID and assemble nested structs");
        block.line(format!(
            "let mut grouped: std::collections::HashMap<{id_type}, {struct_name}> = std::collections::HashMap::new();"
        ));
        // HashMap iteration order is unspecified, so the generated code records
        // the order in which parent ids first appear and emits results in that
        // order instead of hash order.
        block.line(format!(
            "let mut parent_order: Vec<{id_type}> = Vec::new();"
        ));
        generate_seen_id_declarations(&mut block, ctx, select_fields, &id_type, "");
        block.line("");
        let mut for_block = Block::new("for flat_row in flat_rows");
        for_block.line(format!("let parent_id = flat_row.{id_column}.clone();"));
        for_block.line("");
        let mut first_seen_block = Block::new("if !grouped.contains_key(&parent_id)");
        first_seen_block.line("parent_order.push(parent_id.clone());");
        for_block.push_block(first_seen_block);
        let mut entry_block = Block::new(format!(
            "let entry = grouped.entry(parent_id.clone()).or_insert_with(|| {struct_name}"
        ));
        for (field_name_meta, field_def) in &select_fields.fields {
            let field_name = field_name_meta.value.as_str();
            match field_def {
                None => {
                    entry_block.line(format!("{field_name}: flat_row.{field_name}.clone(),"));
                }
                Some(FieldDef::Rel(rel)) => {
                    // Relations start empty and are filled by the assembly code below.
                    if rel.is_first() {
                        entry_block.line(format!("{field_name}: None,"));
                    } else {
                        entry_block.line(format!("{field_name}: Vec::new(),"));
                    }
                }
                Some(FieldDef::Count(_)) => {
                    entry_block.line(format!("{field_name}: flat_row.{field_name},"));
                }
            }
        }
        entry_block.after(");");
        for_block.push_block(entry_block);
        for_block.line("");
        let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
        generate_relation_assembly(
            &mut for_block,
            ctx,
            select_fields,
            parent_prefix,
            "",
            &id_type,
        );
        block.push_block(for_block);
        block.line("");
        if is_first {
            block.line(
                "Ok(parent_order.into_iter().next().and_then(|id| grouped.remove(&id)))",
            );
        } else {
            block.line(
                "Ok(parent_order.into_iter().filter_map(|id| grouped.remove(&id)).collect())",
            );
        }
    } else {
        block.line("// Transform flat rows into nested structs (Option relations only)");
        let mut map_block = Block::new(
            "let results: Result<Vec<_>, QueryError> = flat_rows.into_iter().map(|flat_row| {",
        );
        let mut result_block = Block::new(format!("Ok({struct_name}"));
        let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
        for (field_name_meta, field_def) in &select_fields.fields {
            let field_name = field_name_meta.value.as_str();
            match field_def {
                None => {
                    result_block.line(format!("{field_name}: flat_row.{field_name},"));
                }
                Some(FieldDef::Rel(rel)) => {
                    if rel.is_first() {
                        if let Some(rel_fields) = &rel.fields {
                            let rel_table = rel.table_name().unwrap_or(field_name);
                            // The relation is considered present when its first
                            // selected column is non-NULL.
                            let first_col = rel_fields
                                .first_column()
                                .map(|c| c.as_str())
                                .unwrap_or("id");
                            let first_alias = format!("{field_name}_{first_col}");
                            let nested_struct =
                                format!("{}{}", parent_prefix, to_pascal_case(field_name));
                            let mut map_inner = Block::new(format!(
                                "{field_name}: flat_row.{first_alias}.as_ref().map(|_| {nested_struct}"
                            ));
                            for (inner_field_meta, inner_def) in &rel_fields.fields {
                                let inner_name = inner_field_meta.value.as_str();
                                if inner_def.is_none() {
                                    let alias = format!("{field_name}_{inner_name}");
                                    let rust_ty = ctx
                                        .column_type(rel_table, inner_name)
                                        .unwrap_or_else(|| "String".to_string());
                                    if rust_ty.starts_with("Option<") {
                                        map_inner.line(format!(
                                            "{inner_name}: flat_row.{alias}.clone(),"
                                        ));
                                    } else {
                                        map_inner.line(format!(
                                            "{inner_name}: flat_row.{alias}.clone().expect(\"non-null column from LEFT JOIN\"),"
                                        ));
                                    }
                                }
                            }
                            map_inner.after("),");
                            result_block.push_block(map_inner);
                        }
                    } else {
                        result_block.line(format!("{field_name}: Vec::new(),"));
                    }
                }
                Some(FieldDef::Count(_)) => {
                    result_block.line(format!("{field_name}: flat_row.{field_name},"));
                }
            }
        }
        result_block.after(")");
        map_block.push_block(result_block);
        map_block.after("}).collect();");
        block.push_block(map_block);
        block.line("");
        if is_first {
            // Take the FIRST row (matching the non-relation `first` path),
            // not the last one as `Vec::pop` would.
            block.line("results.map(|v| v.into_iter().next())");
        } else {
            block.line("results");
        }
    }
    block_to_string(&block)
}
/// Emits a `let mut seen_<path>: HashSet<(parent_id, child_id)>` declaration
/// for every `Vec` relation with an explicit field list, recursing into
/// nested relations.
///
/// `<path>` is the relation's field name joined onto `prefix` with `_`. The
/// child id column is the relation's `id_column()` (default `id`). Its type
/// comes from the related table's schema (default `i64`).
/// NOTE(review): the nested sets are declared, but `generate_relation_assembly`
/// does not populate nested relations yet, so those sets go unused.
fn generate_seen_id_declarations(
    block: &mut Block,
    ctx: &CodegenContext,
    select_fields: &SelectFields,
    parent_id_type: &str,
    prefix: &str,
) {
    for (field_name_meta, field_def) in &select_fields.fields {
        let Some(FieldDef::Rel(rel)) = field_def else {
            continue;
        };
        if rel.is_first() {
            continue;
        }
        let Some(rel_fields) = &rel.fields else {
            continue;
        };

        let field_name = field_name_meta.value.as_str();
        let path = if prefix.is_empty() {
            field_name.to_string()
        } else {
            format!("{prefix}_{field_name}")
        };
        let rel_table = rel.table_name().unwrap_or(field_name);
        let child_id_col = rel_fields.id_column().map(|c| c.as_str()).unwrap_or("id");
        let child_id_type = ctx
            .column_type(rel_table, child_id_col)
            .unwrap_or_else(|| "i64".to_string());

        block.line(format!(
            "let mut seen_{path}: std::collections::HashSet<({parent_id_type}, {child_id_type})> = std::collections::HashSet::new();"
        ));
        generate_seen_id_declarations(block, ctx, rel_fields, &child_id_type, &path);
    }
}
/// Emits, inside the per-row loop, the code that attaches each top-level
/// relation to the current `entry`.
///
/// A relation's key column is its `id_column()`, else its first selected
/// column, else `id`; the flat alias is `{flat_prefix}_{relation}_{key}`.
/// - `Option` relations are set once: when still `None` and the key is non-NULL.
/// - `Vec` relations push a new item when the key is non-NULL and the
///   `(parent_id, rel_id)` pair is new to `seen_{flat_prefix}_{relation}`.
///
/// Relations without an explicit field list are skipped.
fn generate_relation_assembly(
    for_block: &mut Block,
    ctx: &CodegenContext,
    select_fields: &SelectFields,
    parent_prefix: &str,
    flat_prefix: &str,
    _parent_id_type: &str,
) {
    for (field_name_meta, field_def) in &select_fields.fields {
        let Some(FieldDef::Rel(rel)) = field_def else {
            continue;
        };
        let Some(rel_fields) = &rel.fields else {
            continue;
        };

        let field_name = field_name_meta.value.as_str();
        let rel_table = rel.table_name().unwrap_or(field_name);
        let nested_struct = format!("{}{}", parent_prefix, to_pascal_case(field_name));
        let flat_field_prefix = if flat_prefix.is_empty() {
            field_name.to_string()
        } else {
            format!("{flat_prefix}_{field_name}")
        };
        let key_col = rel_fields
            .id_column()
            .map(|c| c.as_str())
            .or_else(|| rel_fields.first_column().map(|c| c.as_str()))
            .unwrap_or("id");
        let id_alias = format!("{flat_field_prefix}_{key_col}");

        if rel.is_first() {
            for_block.line(format!("// Populate {field_name} (Option relation)"));
            let mut assign = Block::new(format!("entry.{field_name} = Some({nested_struct}"));
            generate_relation_fields(&mut assign, ctx, rel_fields, rel_table, &flat_field_prefix);
            assign.after(");");
            let mut guard = Block::new(format!(
                "if entry.{field_name}.is_none() && flat_row.{id_alias}.is_some()"
            ));
            guard.push_block(assign);
            for_block.push_block(guard);
        } else {
            for_block.line(format!("// Append to {field_name} (Vec relation)"));
            let mut push = Block::new(format!("entry.{field_name}.push({nested_struct}"));
            generate_relation_fields(&mut push, ctx, rel_fields, rel_table, &flat_field_prefix);
            push.after(");");
            let mut dedupe = Block::new(format!("if seen_{flat_field_prefix}.insert(key)"));
            dedupe.push_block(push);
            let mut has_rel = Block::new(format!("if let Some(ref rel_id) = flat_row.{id_alias}"));
            has_rel.line("let key = (parent_id.clone(), rel_id.clone());".to_string());
            has_rel.push_block(dedupe);
            for_block.push_block(has_rel);
        }
        for_block.line("");
    }
}
/// Emits the `field: value,` initializer lines for one relation struct.
///
/// Values are read from the flat row aliases `{flat_prefix}_{field}`:
/// - Columns the schema marks nullable are cloned as-is. Others are unwrapped
///   with `expect`, since the flat field is optional only because of the join.
/// - Counts are copied directly.
/// - Nested relations are emitted as empty placeholders (marked TODO).
fn generate_relation_fields(
    block: &mut Block,
    ctx: &CodegenContext,
    select_fields: &SelectFields,
    table_name: &str,
    flat_prefix: &str,
) {
    for (field_name_meta, field_def) in &select_fields.fields {
        let field_name = field_name_meta.value.as_str();
        let alias = format!("{flat_prefix}_{field_name}");
        let initializer = match field_def {
            None => {
                let column_ty = ctx
                    .column_type(table_name, field_name)
                    .unwrap_or_else(|| "String".to_string());
                if column_ty.starts_with("Option<") {
                    format!("{field_name}: flat_row.{alias}.clone(),")
                } else {
                    format!(
                        "{field_name}: flat_row.{alias}.clone().expect(\"non-null from LEFT JOIN\"),"
                    )
                }
            }
            Some(FieldDef::Rel(rel)) if rel.is_first() => {
                format!("{field_name}: None, // TODO: nested Option relation")
            }
            Some(FieldDef::Rel(_)) => {
                format!("{field_name}: Vec::new(), // TODO: nested Vec relation")
            }
            Some(FieldDef::Count(_)) => format!("{field_name}: flat_row.{alias},"),
        };
        block.line(initializer);
    }
}
/// Builds the body of a raw-SQL SELECT function.
///
/// Each line of `raw_sql` is trimmed before the SQL is embedded as a `const`.
/// The query is run with every declared parameter, in declaration order.
/// Rows are decoded with `from_row`: only the first row for `first` queries,
/// all rows otherwise.
fn generate_raw_query_body(query: &Select, raw_sql: &str) -> Block {
    let normalized_sql = raw_sql
        .lines()
        .map(str::trim)
        .collect::<Vec<_>>()
        .join("\n");

    // An absent or empty parameter list both render as `&[]`.
    let bound_params = query
        .params
        .as_ref()
        .map(|params| {
            params
                .iter()
                .map(|(meta, _)| meta.value.as_str())
                .collect::<Vec<&str>>()
                .join(", ")
        })
        .unwrap_or_default();

    let mut body = Block::new("");
    body.line(format!("const SQL: &str = r#\"{}\"#;", normalized_sql.trim()));
    body.line("");
    body.line(format!(
        "let rows = client.query(SQL, &[{}]).await?;",
        bound_params
    ));

    if query.first.is_some() {
        let mut first_row = Block::new("match rows.into_iter().next()");
        first_row.line("Some(row) => Ok(Some(from_row(&row)?)),");
        first_row.line("None => Ok(None),");
        body.push_block(first_row);
    } else {
        body.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
    }
    body
}
/// Maps a declared query parameter type to the Rust type name used in the
/// generated signatures and structs.
///
/// `Optional` wraps its first inner type in `Option<…>`, defaulting to
/// `Option<String>` when no inner type is given.
/// NOTE(review): `Jsonb` maps to `String`; confirm that binding a `String`
/// to a jsonb parameter is what the runtime expects.
fn param_type_to_rust(ty: &dibs_query_schema::ParamType) -> String {
    use dibs_query_schema::ParamType;
    let scalar = match ty {
        ParamType::String | ParamType::Jsonb => "String",
        ParamType::Int => "i64",
        ParamType::Float => "f64",
        ParamType::Bool => "bool",
        ParamType::Uuid => "Uuid",
        ParamType::Decimal => "Decimal",
        ParamType::Timestamp => "Timestamp",
        ParamType::Bytes => "Vec<u8>",
        ParamType::Optional(inner_vec) => {
            let inner = inner_vec
                .first()
                .map_or_else(|| "String".to_string(), param_type_to_rust);
            return format!("Option<{}>", inner);
        }
    };
    scalar.to_string()
}
/// Renders a `codegen::Block` to its source text.
///
/// # Panics
/// Panics if the formatter reports an error.
fn block_to_string(block: &Block) -> String {
    let mut rendered = String::new();
    block
        .fmt(&mut codegen::Formatter::new(&mut rendered))
        .expect("formatting failed");
    rendered
}
/// Converts `snake_case` to `PascalCase`.
///
/// Underscores are dropped and the character after each one is ASCII-
/// uppercased, along with the very first character. Repeated, leading or
/// trailing underscores simply vanish.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .map(|segment| {
            let mut chars = segment.chars();
            match chars.next() {
                Some(head) => {
                    let mut word = String::with_capacity(segment.len());
                    word.push(head.to_ascii_uppercase());
                    word.push_str(chars.as_str());
                    word
                }
                None => String::new(),
            }
        })
        .collect()
}
/// Converts `PascalCase`/`camelCase` to `snake_case`.
///
/// Every uppercase character except one at the very start gets a leading `_`
/// and is ASCII-lowercased. Consecutive capitals are therefore split one per
/// letter (`"ABc"` → `"a_bc"`).
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len());
    for (position, ch) in s.chars().enumerate() {
        if !ch.is_uppercase() {
            snake.push(ch);
            continue;
        }
        if position != 0 {
            snake.push('_');
        }
        snake.push(ch.to_ascii_lowercase());
    }
    snake
}
/// Emits the async function (and optional result struct) for an INSERT.
///
/// Without `returning`, the function returns the affected-row count as
/// `Result<u64, QueryError>`. With `returning`, a `{name}Result` struct is
/// emitted, typed from the target table's schema, and the function returns
/// the first returned row as `Result<Option<{name}Result>, QueryError>`.
fn generate_insert_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    insert: &Insert,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let generated = crate::sqlgen::generate_insert_sql(insert);
    let return_ty = match &insert.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                insert.into.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Option<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    if let Some(params) = &insert.params {
        for (param_name_meta, param_type) in &params.params {
            let param_name = param_name_meta.value.as_str();
            let rust_ty = param_type_to_rust(param_type);
            func.arg(param_name, format!("&{}", rust_ty));
        }
    }
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = insert.returning.is_none();
    let body = generate_mutation_body(&generated.sql, &generated.params, execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
}
/// Emits the async function (and optional result struct) for an UPSERT.
///
/// Without `returning`, the function returns the affected-row count as
/// `Result<u64, QueryError>`. With `returning`, a `{name}Result` struct is
/// emitted, typed from the target table's schema, and the function returns
/// the first returned row as `Result<Option<{name}Result>, QueryError>`.
fn generate_upsert_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    upsert: &Upsert,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let generated = crate::sqlgen::generate_upsert_sql(upsert);
    let return_ty = match &upsert.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                upsert.into.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Option<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    if let Some(params) = &upsert.params {
        for (param_name_meta, param_type) in &params.params {
            let param_name = param_name_meta.value.as_str();
            let rust_ty = param_type_to_rust(param_type);
            func.arg(param_name, format!("&{}", rust_ty));
        }
    }
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = upsert.returning.is_none();
    let body = generate_mutation_body(&generated.sql, &generated.params, execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
}
/// Emits the bulk INSERT function, its `{name}Params` item struct and an
/// optional `{name}Result` struct.
///
/// The generated function takes `items: &[{name}Params]`. It returns the
/// affected-row count (`u64`) or, with `returning`, every returned row as
/// `Vec<{name}Result>`.
/// NOTE(review): when the declaration has no `params`, `{name}Params` is not
/// emitted although the signature still references it.
fn generate_insert_many_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    insert: &InsertMany,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let generated = crate::sqlgen::generate_insert_many_sql(insert);
    let params_struct_name = format!("{}Params", name);
    if let Some(params) = &insert.params {
        generate_bulk_params_struct(
            ctx,
            &params_struct_name,
            insert.into.value.as_str(),
            params,
            scope,
        );
    }
    let return_ty = match &insert.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                insert.into.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Vec<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    func.arg("items", format!("&[{}]", params_struct_name));
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = insert.returning.is_none();
    let body = generate_bulk_mutation_body(&generated.sql, insert.params.as_ref(), execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
}
/// Emits the bulk UPSERT function, its `{name}Params` item struct and an
/// optional `{name}Result` struct.
///
/// The generated function takes `items: &[{name}Params]`. It returns the
/// affected-row count (`u64`) or, with `returning`, every returned row as
/// `Vec<{name}Result>`.
/// NOTE(review): when the declaration has no `params`, `{name}Params` is not
/// emitted although the signature still references it.
fn generate_upsert_many_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    upsert: &UpsertMany,
    scope: &mut Scope,
) {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let generated = crate::sqlgen::generate_upsert_many_sql(upsert);
    let params_struct_name = format!("{}Params", name);
    if let Some(params) = &upsert.params {
        generate_bulk_params_struct(
            ctx,
            &params_struct_name,
            upsert.into.value.as_str(),
            params,
            scope,
        );
    }
    let return_ty = match &upsert.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                upsert.into.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Vec<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    func.arg("items", format!("&[{}]", params_struct_name));
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = upsert.returning.is_none();
    let body = generate_bulk_mutation_body(&generated.sql, upsert.params.as_ref(), execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
}
/// Emits the public per-item struct taken by bulk INSERT/UPSERT functions.
///
/// The struct has one `pub` field per declared parameter. A parameter that
/// names a column of `table` uses that column's schema type. Any other
/// parameter falls back to its declared parameter type.
fn generate_bulk_params_struct(
    ctx: &CodegenContext,
    struct_name: &str,
    table: &str,
    params: &Params,
    scope: &mut Scope,
) {
    let mut st = Struct::new(struct_name);
    st.vis("pub");
    st.derive("Debug");
    st.derive("Clone");
    for (param_name_meta, param_type) in &params.params {
        let param_name = param_name_meta.value.as_str();
        let rust_ty = ctx
            .column_type(table, param_name)
            .unwrap_or_else(|| param_type_to_rust(param_type));
        st.field(format!("pub {}", param_name), &rust_ty);
    }
    scope.push_struct(st);
}
/// Builds the body of a bulk INSERT/UPSERT function.
///
/// Each parameter column of `items` is collected into a parallel
/// `<param>_arr` vector (for UNNEST), and all arrays are bound in
/// declaration order. With `execute_only`, the affected-row count is
/// returned; otherwise every returned row is decoded with `from_row`.
///
/// The array element type is left to inference (`Vec<_>`) so it always
/// matches the params-struct field. Those fields may use the schema column
/// type (e.g. `Option<…>` for nullable columns), which can differ from the
/// declared parameter type; an explicit annotation built from the declared
/// type would then fail to compile.
/// NOTE(review): with `params == None` only the SQL constant is emitted and
/// the generated body produces no return value.
fn generate_bulk_mutation_body(sql: &str, params: Option<&Params>, execute_only: bool) -> Block {
    let mut block = Block::new("");
    block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
    block.line("");
    if let Some(params) = params {
        block.line("// Convert items to parallel arrays for UNNEST");
        for param_name_meta in params.params.keys() {
            let param_name = param_name_meta.value.as_str();
            block.line(format!(
                "let {}_arr: Vec<_> = items.iter().map(|i| i.{}.clone()).collect();",
                param_name, param_name
            ));
        }
        block.line("");
        let param_refs: Vec<String> = params
            .params
            .keys()
            .map(|p| format!("&{}_arr", p.value))
            .collect();
        if execute_only {
            block.line(format!(
                "let affected = client.execute(SQL, &[{}]).await?;",
                param_refs.join(", ")
            ));
            block.line("Ok(affected)");
        } else {
            block.line(format!(
                "let rows = client.query(SQL, &[{}]).await?;",
                param_refs.join(", ")
            ));
            block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
        }
    }
    block
}
/// Emits the async function (and optional result struct) for an UPDATE.
///
/// Without `returning`, the function returns the affected-row count as
/// `Result<u64, QueryError>`. With `returning`, a `{name}Result` struct is
/// emitted, typed from the updated table's schema, and the function returns
/// the first returned row as `Result<Option<{name}Result>, QueryError>`.
///
/// # Errors
/// Propagates the `QError` from UPDATE SQL generation.
fn generate_update_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    update: &Update,
    scope: &mut Scope,
) -> Result<(), QError> {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let sqlgen_ctx = ctx.sqlgen_ctx();
    let generated = crate::sqlgen::generate_update_sql(&sqlgen_ctx, update)?;
    let return_ty = match &update.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                update.table.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Option<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    if let Some(params) = &update.params {
        for (param_name_meta, param_type) in &params.params {
            let param_name = param_name_meta.value.as_str();
            let rust_ty = param_type_to_rust(param_type);
            func.arg(param_name, format!("&{}", rust_ty));
        }
    }
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = update.returning.is_none();
    let body = generate_mutation_body(&generated.sql, &generated.params, execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
    Ok(())
}
/// Emits the async function (and optional result struct) for a DELETE.
///
/// Without `returning`, the function returns the affected-row count as
/// `Result<u64, QueryError>`. With `returning`, a `{name}Result` struct is
/// emitted, typed from the source table's schema, and the function returns
/// the first returned row as `Result<Option<{name}Result>, QueryError>`.
///
/// # Errors
/// Propagates the `QError` from DELETE SQL generation.
fn generate_delete_code(
    ctx: &CodegenContext,
    name_meta: &Meta<String>,
    delete: &Delete,
    scope: &mut Scope,
) -> Result<(), QError> {
    let name = &name_meta.value;
    let fn_name = to_snake_case(name);
    let sqlgen_ctx = ctx.sqlgen_ctx();
    let generated = crate::sqlgen::generate_delete_sql(&sqlgen_ctx, delete)?;
    let return_ty = match &delete.returning {
        None => "Result<u64, QueryError>".to_string(),
        Some(returning) => {
            let struct_name = format!("{}Result", name);
            generate_mutation_result_struct(
                ctx,
                &struct_name,
                delete.from.value.as_str(),
                returning,
                scope,
            );
            format!("Result<Option<{}>, QueryError>", struct_name)
        }
    };

    let mut func = Function::new(&fn_name);
    if let Some(doc) = &name_meta.doc {
        func.doc(&doc.join("\n"));
    }
    func.vis("pub");
    func.set_async(true);
    func.generic("C");
    func.arg("client", "&C");
    if let Some(params) = &delete.params {
        for (param_name_meta, param_type) in &params.params {
            let param_name = param_name_meta.value.as_str();
            let rust_ty = param_type_to_rust(param_type);
            func.arg(param_name, format!("&{}", rust_ty));
        }
    }
    func.ret(&return_ty);
    func.bound("C", "tokio_postgres::GenericClient");

    let execute_only = delete.returning.is_none();
    let body = generate_mutation_body(&generated.sql, &generated.params, execute_only);
    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
    scope.push_fn(func);
    Ok(())
}
/// Emits the public struct holding one row of a mutation's RETURNING clause.
///
/// There is one `pub` field per returned column, typed from `table`'s schema
/// and falling back to `String` for columns the schema does not know.
fn generate_mutation_result_struct(
    ctx: &CodegenContext,
    struct_name: &str,
    table: &str,
    returning: &Returning,
    scope: &mut Scope,
) {
    let mut result = Struct::new(struct_name);
    result.vis("pub");
    result.derive("Debug").derive("Clone").derive("Facet");
    result.attr("facet(crate = dibs_runtime::facet)");

    for (column_meta, _) in &returning.columns {
        let column = column_meta.value.as_str();
        let column_ty = ctx
            .column_type(table, column)
            .unwrap_or_else(|| String::from("String"));
        result.field(format!("pub {column}"), &column_ty);
    }

    scope.push_struct(result);
}
/// Builds the body of a single-row mutation function (INSERT, UPSERT, UPDATE
/// or DELETE).
///
/// Parameters named `__literal_*` are not bound; presumably sqlgen inlines
/// those into the SQL. With `execute_only`, the affected-row count is
/// returned. Otherwise only the FIRST returned row is decoded and returned
/// as `Option<_>`, even if the statement touched several rows.
fn generate_mutation_body(
    sql: &str,
    param_order: &[dibs_sql::ParamName],
    execute_only: bool,
) -> Block {
    let bound_params = param_order
        .iter()
        .map(|p| p.as_str())
        .filter(|p| !p.starts_with("__literal_"))
        .collect::<Vec<_>>()
        .join(", ");

    let mut body = Block::new("");
    body.line(format!("const SQL: &str = r#\"{}\"#;", sql));
    body.line("");

    if execute_only {
        body.line(format!(
            "let affected = client.execute(SQL, &[{}]).await?;",
            bound_params
        ));
        body.line("Ok(affected)");
        return body;
    }

    body.line(format!(
        "let rows = client.query(SQL, &[{}]).await?;",
        bound_params
    ));
    let mut first_row = Block::new("match rows.into_iter().next()");
    first_row.line("Some(row) => Ok(Some(from_row(&row)?)),");
    first_row.line("None => Ok(None),");
    body.push_block(first_row);
    body
}
#[cfg(test)]
mod tests;