1use crate::sqlgen::SqlGenContext;
14use crate::{QError, QSource};
15use codegen::{Block, Function, Scope, Struct};
16use dibs_db_schema::Schema;
17use dibs_query_schema::{
18 Decl, Delete, FieldDef, Insert, InsertMany, Meta, Params, QueryFile, Returning, Returns,
19 Select, SelectFields, Update, Upsert, UpsertMany,
20};
21use std::sync::Arc;
22
/// Rust source produced by `generate_rust_code` for one query file.
///
/// Equality is derived so callers (and tests) can compare generated output
/// directly, e.g. for snapshot checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedCode {
    /// The complete generated module text.
    pub code: String,
}
29
/// Wraps a generated function body so its result is passed through `TraceErr`.
///
/// The body runs inside an `async` block whose awaited value is bound to
/// `__dibs_result`. That value is then handed, together with `fn_name`, to
/// `<_ as TraceErr>::trace_err`, and the call is the final expression.
fn wrap_with_trace_err(body: &str, fn_name: &str) -> String {
    [
        "let __dibs_result = async {\n",
        body,
        "\n}.await;\n",
        "<_ as TraceErr>::trace_err(__dibs_result, \"",
        fn_name,
        "\")",
    ]
    .concat()
}
48
49fn schema_column_type(schema: &Schema, table: &str, column: &str) -> Option<String> {
51 let table_info = schema.get_table(table)?;
52 let col = table_info.columns.iter().find(|c| c.name == column)?;
53 let rust_type = col
54 .rust_type
55 .clone()
56 .unwrap_or_else(|| col.pg_type.to_rust_type().to_string());
57 if col.nullable {
58 Some(format!("Option<{}>", rust_type))
59 } else {
60 Some(rust_type)
61 }
62}
63
64struct CodegenContext<'a> {
66 schema: &'a Schema,
67 source: Arc<QSource>,
68 #[allow(dead_code)]
69 scope: Scope,
70}
71
72impl CodegenContext<'_> {
73 fn column_type(&self, table: &str, column: &str) -> Option<String> {
75 schema_column_type(self.schema, table, column)
76 }
77
78 fn sqlgen_ctx(&self) -> SqlGenContext<'_> {
80 SqlGenContext::new(self.schema, self.source.clone())
81 }
82}
83
84pub fn generate_rust_code(
86 file: &QueryFile,
87 schema: &Schema,
88 source: Arc<QSource>,
89) -> Result<GeneratedCode, QError> {
90 let mut scope = Scope::new();
91
92 scope.raw("// Generated by dibs-qgen. Do not edit.");
94 scope.raw("");
95
96 scope.import("dibs_runtime::prelude", "*");
98 scope.import("dibs_runtime", "tokio_postgres");
99
100 let ctx = CodegenContext {
101 schema,
102 source,
103 scope: Scope::new(),
104 };
105
106 for (name_meta, decl) in &file.0 {
108 match decl {
109 Decl::Select(select) => {
110 generate_select_code(&ctx, name_meta, select, &mut scope);
111 }
112 Decl::Insert(insert) => {
113 generate_insert_code(&ctx, name_meta, insert, &mut scope);
114 }
115 Decl::InsertMany(insert_many) => {
116 generate_insert_many_code(&ctx, name_meta, insert_many, &mut scope);
117 }
118 Decl::Upsert(upsert) => {
119 generate_upsert_code(&ctx, name_meta, upsert, &mut scope);
120 }
121 Decl::UpsertMany(upsert_many) => {
122 generate_upsert_many_code(&ctx, name_meta, upsert_many, &mut scope);
123 }
124 Decl::Update(update) => {
125 generate_update_code(&ctx, name_meta, update, &mut scope)?;
126 }
127 Decl::Delete(delete) => {
128 generate_delete_code(&ctx, name_meta, delete, &mut scope)?;
129 }
130 }
131 }
132
133 Ok(GeneratedCode {
134 code: scope.to_string(),
135 })
136}
137
138fn generate_select_code(
139 ctx: &CodegenContext,
140 name_meta: &Meta<String>,
141 select: &Select,
142 scope: &mut Scope,
143) {
144 let name = &name_meta.value;
145 let struct_name = format!("{}Result", name);
146
147 if let Some(from) = &select.from {
149 if select.fields.is_some() {
150 generate_result_struct(ctx, select, name_meta, &struct_name, from, scope);
151
152 if select.has_relations() {
154 let flat_struct_name = format!("{}Row", name);
155 generate_flat_row_struct(ctx, select, &flat_struct_name, from, scope);
156 }
157 }
158 } else if let Some(returns) = &select.returns {
159 generate_raw_sql_result_struct(&struct_name, returns, scope);
161 }
162
163 generate_select_function(ctx, name_meta, select, &struct_name, scope);
165}
166
167fn generate_flat_row_struct(
190 ctx: &CodegenContext,
191 select: &Select,
192 struct_name: &str,
193 table: &Meta<dibs_sql::TableName>,
194 scope: &mut Scope,
195) {
196 let mut st = Struct::new(struct_name);
197 st.derive("Debug");
199 st.derive("Clone");
200 st.derive("Facet");
201 st.attr("facet(crate = dibs_runtime::facet)");
202
203 let table_name = table.value.as_str();
204
205 if let Some(select_fields) = &select.fields {
206 add_flat_fields_for_select(ctx, &mut st, table_name, "", select_fields);
208 }
209
210 scope.push_struct(st);
211}
212
213fn add_flat_fields_for_select(
215 ctx: &CodegenContext,
216 st: &mut Struct,
217 table_name: &str,
218 prefix: &str,
219 select_fields: &SelectFields,
220) {
221 for (field_name_meta, field_def) in &select_fields.fields {
222 let field_name = field_name_meta.value.as_str();
223
224 match field_def {
225 None => {
226 let rust_ty = ctx
228 .column_type(table_name, field_name)
229 .unwrap_or_else(|| "String".to_string());
230
231 let flat_field_name = if prefix.is_empty() {
232 field_name.to_string()
233 } else {
234 format!("{}_{}", prefix, field_name)
235 };
236
237 let final_ty = if prefix.is_empty() {
239 rust_ty
240 } else if rust_ty.starts_with("Option<") {
241 rust_ty
243 } else {
244 format!("Option<{}>", rust_ty)
245 };
246
247 st.field(&flat_field_name, &final_ty);
249 }
250 Some(FieldDef::Rel(rel)) => {
251 let rel_table = rel.table_name().unwrap_or(field_name);
253 let new_prefix = if prefix.is_empty() {
254 field_name.to_string()
255 } else {
256 format!("{}_{}", prefix, field_name)
257 };
258
259 if let Some(rel_fields) = &rel.fields {
260 add_flat_fields_for_select(ctx, st, rel_table, &new_prefix, rel_fields);
261 }
262 }
263 Some(FieldDef::Count(_)) => {
264 let flat_field_name = if prefix.is_empty() {
266 field_name.to_string()
267 } else {
268 format!("{}_{}", prefix, field_name)
269 };
270 st.field(&flat_field_name, "i64");
271 }
272 }
273 }
274}
275
276fn generate_raw_sql_result_struct(struct_name: &str, returns: &Returns, scope: &mut Scope) {
277 let mut st = Struct::new(struct_name);
278 st.vis("pub");
279 st.derive("Debug");
280 st.derive("Clone");
281 st.derive("Facet");
282 st.attr("facet(crate = dibs_runtime::facet)");
283
284 for (field_name_meta, param_type) in &returns.fields {
285 let field_name = field_name_meta.value.as_str();
286 let rust_ty = param_type_to_rust(param_type);
287 st.field(format!("pub {}", field_name), &rust_ty);
288 }
289
290 scope.push_struct(st);
291}
292
293fn generate_result_struct(
294 ctx: &CodegenContext,
295 select: &Select,
296 name_meta: &Meta<String>,
297 struct_name: &str,
298 table: &Meta<dibs_sql::TableName>,
299 scope: &mut Scope,
300) {
301 let mut st = Struct::new(struct_name);
302 st.vis("pub");
303 st.derive("Debug");
304 st.derive("Clone");
305 st.derive("Facet");
306 st.attr("facet(crate = dibs_runtime::facet)");
307
308 let parent_prefix = &name_meta.value;
310 let table_name = table.value.as_str();
311
312 if let Some(select_fields) = &select.fields {
313 for (field_name_meta, field_def) in &select_fields.fields {
314 let field_name = field_name_meta.value.as_str();
315 match field_def {
316 None => {
317 let rust_ty = ctx
319 .column_type(table_name, field_name)
320 .unwrap_or_else(|| "String".to_string());
321 st.field(format!("pub {}", field_name), &rust_ty);
322 }
323 Some(FieldDef::Rel(rel)) => {
324 let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
325 let ty = if rel.first.is_some() {
326 format!("Option<{}>", nested_name)
327 } else {
328 format!("Vec<{}>", nested_name)
329 };
330 st.field(format!("pub {}", field_name), &ty);
331 }
332 Some(FieldDef::Count(_)) => {
333 st.field(format!("pub {}", field_name), "i64");
334 }
335 }
336 }
337 }
338
339 scope.push_struct(st);
340
341 if let Some(select_fields) = &select.fields {
343 generate_nested_structs(ctx, parent_prefix, select_fields, scope);
344 }
345}
346
347fn generate_nested_structs(
352 ctx: &CodegenContext,
353 parent_prefix: &str,
354 select_fields: &SelectFields,
355 scope: &mut Scope,
356) {
357 for (field_name_meta, field_def) in &select_fields.fields {
358 if let Some(FieldDef::Rel(rel)) = field_def {
359 let field_name = field_name_meta.value.as_str();
360 let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
361 let rel_table = rel.table_name().unwrap_or(field_name);
362
363 let mut nested_st = Struct::new(&nested_name);
364 nested_st.vis("pub");
365 nested_st.derive("Debug");
366 nested_st.derive("Clone");
367 nested_st.derive("Facet");
368 nested_st.attr("facet(crate = dibs_runtime::facet)");
369
370 if let Some(rel_fields) = &rel.fields {
371 for (rel_field_name_meta, rel_field_def) in &rel_fields.fields {
372 let rel_field_name = rel_field_name_meta.value.as_str();
373 match rel_field_def {
374 None => {
375 let rust_ty = ctx
377 .column_type(rel_table, rel_field_name)
378 .unwrap_or_else(|| "String".to_string());
379 nested_st.field(format!("pub {}", rel_field_name), &rust_ty);
380 }
381 Some(FieldDef::Rel(nested_rel)) => {
382 let nested_rel_name =
384 format!("{}{}", nested_name, to_pascal_case(rel_field_name));
385 let ty = if nested_rel.first.is_some() {
386 format!("Option<{}>", nested_rel_name)
387 } else {
388 format!("Vec<{}>", nested_rel_name)
389 };
390 nested_st.field(format!("pub {}", rel_field_name), &ty);
391 }
392 Some(FieldDef::Count(_)) => {
393 nested_st.field(format!("pub {}", rel_field_name), "i64");
394 }
395 }
396 }
397 }
398
399 scope.push_struct(nested_st);
400
401 if let Some(rel_fields) = &rel.fields {
403 generate_nested_structs(ctx, &nested_name, rel_fields, scope);
404 }
405 }
406 }
407}
408
409fn generate_select_function(
410 ctx: &CodegenContext,
411 name_meta: &Meta<String>,
412 query: &Select,
413 struct_name: &str,
414 scope: &mut Scope,
415) {
416 let name = &name_meta.value;
417 let fn_name = to_snake_case(name);
418
419 let return_ty = if query.first.is_some() {
420 format!("Result<Option<{}>, QueryError>", struct_name)
421 } else {
422 format!("Result<Vec<{}>, QueryError>", struct_name)
423 };
424
425 let mut func = Function::new(&fn_name);
426 if let Some(doc) = &name_meta.doc {
427 let doc_str = doc.join("\n");
428 func.doc(&doc_str);
429 }
430 func.vis("pub");
431 func.set_async(true);
432 func.generic("C");
433 func.arg("client", "&C");
434 func.attr("allow(clippy::clone_on_copy)");
436
437 if let Some(params) = &query.params {
438 for (param_name_meta, param_type) in ¶ms.params {
439 let param_name = ¶m_name_meta.value;
440 let rust_ty = param_type_to_rust(param_type);
441 func.arg(param_name, format!("&{}", rust_ty));
442 }
443 }
444
445 func.ret(&return_ty);
446 func.bound("C", "tokio_postgres::GenericClient");
447
448 let body = if let Some(raw_sql_meta) = &query.sql {
450 block_to_string(&generate_raw_query_body(query, &raw_sql_meta.value))
451 } else {
452 generate_query_body(ctx, query, struct_name)
453 };
454 func.line(wrap_with_trace_err(&body, &fn_name));
455
456 scope.push_fn(func);
457}
458
459fn generate_query_body(ctx: &CodegenContext, query: &Select, struct_name: &str) -> String {
464 let sqlgen_ctx = ctx.sqlgen_ctx();
465 let generated = match crate::sqlgen::generate_select_sql(&sqlgen_ctx, query) {
466 Ok(g) => g,
467 Err(e) => {
468 panic!("SELECT SQL generation failed: {}", e);
469 }
470 };
471
472 let mut block = Block::new("");
473
474 block.line(format!("const SQL: &str = r#\"{}\"#;", generated.sql));
476 block.line("");
477
478 let params: Vec<_> = generated
480 .param_order
481 .iter()
482 .filter(|p| !p.as_str().starts_with("__literal_"))
483 .collect();
484
485 if params.is_empty() {
486 block.line("let rows = client.query(SQL, &[]).await?;");
487 } else {
488 let params_str = params
489 .iter()
490 .map(|p| p.as_str())
491 .collect::<Vec<_>>()
492 .join(", ");
493 block.line(format!(
494 "let rows = client.query(SQL, &[{}]).await?;",
495 params_str
496 ));
497 }
498
499 if !query.has_relations() {
501 if query.first.is_some() {
502 let mut match_block = Block::new("match rows.into_iter().next()");
503 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
504 match_block.line("None => Ok(None),");
505 block.push_block(match_block);
506 } else {
507 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
508 }
509 return block_to_string(&block);
510 }
511
512 let query_name = struct_name.strip_suffix("Result").unwrap_or(struct_name);
514 let flat_struct_name = format!("{}Row", query_name);
515
516 block.line("");
517 block.line("// Deserialize all rows into flat structs using facet reflection");
518 block.line(format!(
519 "let flat_rows: Vec<{flat_struct_name}> = rows.iter().map(from_row).collect::<Result<Vec<_>, _>>()?;"
520 ));
521 block.line("");
522
523 let Some(select_fields) = &query.fields else {
525 block.line("Ok(vec![])".to_string());
527 return block_to_string(&block);
528 };
529
530 let root_table = query
531 .from
532 .as_ref()
533 .map(|m| m.value.as_str())
534 .unwrap_or("unknown");
535 let is_first = query.is_first();
536
537 block.line(generate_flat_to_nested_transform(
538 ctx,
539 select_fields,
540 struct_name,
541 root_table,
542 is_first,
543 ));
544
545 block_to_string(&block)
546}
547
548fn generate_flat_to_nested_transform(
550 ctx: &CodegenContext,
551 select_fields: &SelectFields,
552 struct_name: &str,
553 root_table: &str,
554 is_first: bool,
555) -> String {
556 let mut block = Block::new("");
557
558 let id_column = select_fields
560 .id_column()
561 .map(|c| c.to_string())
562 .unwrap_or_else(|| "id".to_string());
563
564 let id_type = ctx
565 .column_type(root_table, &id_column)
566 .unwrap_or_else(|| "i64".to_string());
567
568 if select_fields.has_vec_relations() {
570 block.line("// Group flat rows by parent ID and assemble nested structs");
572 block.line(format!(
573 "let mut grouped: std::collections::HashMap<{id_type}, {struct_name}> = std::collections::HashMap::new();"
574 ));
575
576 generate_seen_id_declarations(&mut block, ctx, select_fields, &id_type, "");
578
579 block.line("");
580
581 let mut for_block = Block::new("for flat_row in flat_rows");
582 for_block.line(format!("let parent_id = flat_row.{id_column}.clone();"));
583 for_block.line("");
584
585 let mut entry_block = Block::new(format!(
587 "let entry = grouped.entry(parent_id.clone()).or_insert_with(|| {struct_name}"
588 ));
589
590 for (field_name_meta, field_def) in &select_fields.fields {
592 let field_name = field_name_meta.value.as_str();
593 match field_def {
594 None => {
595 entry_block.line(format!("{field_name}: flat_row.{field_name}.clone(),"));
596 }
597 Some(FieldDef::Rel(rel)) => {
598 if rel.is_first() {
599 entry_block.line(format!("{field_name}: None,"));
600 } else {
601 entry_block.line(format!("{field_name}: Vec::new(),"));
602 }
603 }
604 Some(FieldDef::Count(_)) => {
605 entry_block.line(format!("{field_name}: flat_row.{field_name},"));
606 }
607 }
608 }
609 entry_block.after(");");
610 for_block.push_block(entry_block);
611 for_block.line("");
612
613 let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
615 generate_relation_assembly(
616 &mut for_block,
617 ctx,
618 select_fields,
619 parent_prefix,
620 "",
621 &id_type,
622 );
623
624 block.push_block(for_block);
625 block.line("");
626
627 if is_first {
628 block.line("Ok(grouped.into_values().next())");
629 } else {
630 block.line("Ok(grouped.into_values().collect())");
631 }
632 } else {
633 block.line("// Transform flat rows into nested structs (Option relations only)");
635
636 let mut map_block = Block::new(
637 "let results: Result<Vec<_>, QueryError> = flat_rows.into_iter().map(|flat_row| {",
638 );
639
640 let mut result_block = Block::new(format!("Ok({struct_name}"));
641 let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
642
643 for (field_name_meta, field_def) in &select_fields.fields {
644 let field_name = field_name_meta.value.as_str();
645 match field_def {
646 None => {
647 result_block.line(format!("{field_name}: flat_row.{field_name},"));
648 }
649 Some(FieldDef::Rel(rel)) => {
650 if rel.is_first() {
651 if let Some(rel_fields) = &rel.fields {
653 let rel_table = rel.table_name().unwrap_or(field_name);
654 let first_col = rel_fields
655 .first_column()
656 .map(|c| c.as_str())
657 .unwrap_or("id");
658 let first_alias = format!("{field_name}_{first_col}");
659 let nested_struct =
660 format!("{}{}", parent_prefix, to_pascal_case(field_name));
661
662 let mut map_inner = Block::new(format!(
663 "{field_name}: flat_row.{first_alias}.as_ref().map(|_| {nested_struct}"
664 ));
665
666 for (inner_field_meta, inner_def) in &rel_fields.fields {
667 let inner_name = inner_field_meta.value.as_str();
668 if inner_def.is_none() {
669 let alias = format!("{field_name}_{inner_name}");
670 let rust_ty = ctx
671 .column_type(rel_table, inner_name)
672 .unwrap_or_else(|| "String".to_string());
673
674 if rust_ty.starts_with("Option<") {
676 map_inner.line(format!(
677 "{inner_name}: flat_row.{alias}.clone(),"
678 ));
679 } else {
680 map_inner.line(format!(
681 "{inner_name}: flat_row.{alias}.clone().expect(\"non-null column from LEFT JOIN\"),"
682 ));
683 }
684 }
685 }
686
687 map_inner.after("),");
688 result_block.push_block(map_inner);
689 }
690 } else {
691 result_block.line(format!("{field_name}: Vec::new(),"));
693 }
694 }
695 Some(FieldDef::Count(_)) => {
696 result_block.line(format!("{field_name}: flat_row.{field_name},"));
697 }
698 }
699 }
700
701 result_block.after(")");
702 map_block.push_block(result_block);
703 map_block.after("}).collect();");
704 block.push_block(map_block);
705 block.line("");
706
707 if is_first {
708 block.line("results.map(|mut v| v.pop())");
709 } else {
710 block.line("results");
711 }
712 }
713
714 block_to_string(&block)
715}
716
717fn generate_seen_id_declarations(
719 block: &mut Block,
720 ctx: &CodegenContext,
721 select_fields: &SelectFields,
722 parent_id_type: &str,
723 prefix: &str,
724) {
725 for (field_name_meta, field_def) in &select_fields.fields {
726 if let Some(FieldDef::Rel(rel)) = field_def {
727 let field_name = field_name_meta.value.as_str();
728 if !rel.is_first() {
729 if let Some(rel_fields) = &rel.fields {
731 let rel_table = rel.table_name().unwrap_or(field_name);
732 let id_col = rel_fields.id_column().map(|c| c.as_str()).unwrap_or("id");
733 let id_type = ctx
734 .column_type(rel_table, id_col)
735 .unwrap_or_else(|| "i64".to_string());
736
737 let set_name = if prefix.is_empty() {
738 format!("seen_{field_name}")
739 } else {
740 format!("seen_{prefix}_{field_name}")
741 };
742
743 block.line(format!(
744 "let mut {set_name}: std::collections::HashSet<({parent_id_type}, {id_type})> = std::collections::HashSet::new();"
745 ));
746
747 let new_prefix = if prefix.is_empty() {
749 field_name.to_string()
750 } else {
751 format!("{prefix}_{field_name}")
752 };
753
754 generate_seen_id_declarations(block, ctx, rel_fields, &id_type, &new_prefix);
756 }
757 }
758 }
759 }
760}
761
762fn generate_relation_assembly(
764 for_block: &mut Block,
765 ctx: &CodegenContext,
766 select_fields: &SelectFields,
767 parent_prefix: &str,
768 flat_prefix: &str,
769 _parent_id_type: &str,
770) {
771 for (field_name_meta, field_def) in &select_fields.fields {
772 if let Some(FieldDef::Rel(rel)) = field_def {
773 let field_name = field_name_meta.value.as_str();
774 let rel_table = rel.table_name().unwrap_or(field_name);
775 let nested_struct = format!("{}{}", parent_prefix, to_pascal_case(field_name));
776
777 let flat_field_prefix = if flat_prefix.is_empty() {
778 field_name.to_string()
779 } else {
780 format!("{flat_prefix}_{field_name}")
781 };
782
783 if let Some(rel_fields) = &rel.fields {
784 let first_col = rel_fields
785 .first_column()
786 .map(|c| c.as_str())
787 .unwrap_or("id");
788 let id_col = rel_fields
789 .id_column()
790 .map(|c| c.as_str())
791 .unwrap_or(first_col);
792 let id_alias = format!("{flat_field_prefix}_{id_col}");
793
794 if rel.is_first() {
795 for_block.line(format!("// Populate {field_name} (Option relation)"));
797
798 let mut if_block = Block::new(format!(
799 "if entry.{field_name}.is_none() && flat_row.{id_alias}.is_some()"
800 ));
801
802 let mut some_block =
803 Block::new(format!("entry.{field_name} = Some({nested_struct}"));
804 generate_relation_fields(
805 &mut some_block,
806 ctx,
807 rel_fields,
808 rel_table,
809 &flat_field_prefix,
810 );
811 some_block.after(");");
812 if_block.push_block(some_block);
813
814 for_block.push_block(if_block);
815 for_block.line("");
816 } else {
817 let set_name = if flat_prefix.is_empty() {
819 format!("seen_{field_name}")
820 } else {
821 format!("seen_{flat_prefix}_{field_name}")
822 };
823
824 for_block.line(format!("// Append to {field_name} (Vec relation)"));
825
826 let mut if_block =
827 Block::new(format!("if let Some(ref rel_id) = flat_row.{id_alias}"));
828 if_block.line("let key = (parent_id.clone(), rel_id.clone());".to_string());
829
830 let mut if_insert = Block::new(format!("if {set_name}.insert(key)"));
831 let mut push_block =
832 Block::new(format!("entry.{field_name}.push({nested_struct}"));
833 generate_relation_fields(
834 &mut push_block,
835 ctx,
836 rel_fields,
837 rel_table,
838 &flat_field_prefix,
839 );
840 push_block.after(");");
841 if_insert.push_block(push_block);
842
843 if_block.push_block(if_insert);
844 for_block.push_block(if_block);
845 for_block.line("");
846 }
847 }
848 }
849 }
850}
851
852fn generate_relation_fields(
854 block: &mut Block,
855 ctx: &CodegenContext,
856 select_fields: &SelectFields,
857 table_name: &str,
858 flat_prefix: &str,
859) {
860 for (field_name_meta, field_def) in &select_fields.fields {
861 let field_name = field_name_meta.value.as_str();
862 let alias = format!("{flat_prefix}_{field_name}");
863
864 match field_def {
865 None => {
866 let rust_ty = ctx
867 .column_type(table_name, field_name)
868 .unwrap_or_else(|| "String".to_string());
869
870 if rust_ty.starts_with("Option<") {
873 block.line(format!("{field_name}: flat_row.{alias}.clone(),"));
874 } else {
875 block.line(format!(
876 "{field_name}: flat_row.{alias}.clone().expect(\"non-null from LEFT JOIN\"),"
877 ));
878 }
879 }
880 Some(FieldDef::Rel(rel)) => {
881 if rel.is_first() {
882 block.line(format!(
883 "{field_name}: None, // TODO: nested Option relation"
884 ));
885 } else {
886 block.line(format!(
887 "{field_name}: Vec::new(), // TODO: nested Vec relation"
888 ));
889 }
890 }
891 Some(FieldDef::Count(_)) => {
892 block.line(format!("{field_name}: flat_row.{alias},"));
893 }
894 }
895 }
896}
897
898fn generate_raw_query_body(query: &Select, raw_sql: &str) -> Block {
899 let cleaned: String = raw_sql
900 .lines()
901 .map(|l| l.trim())
902 .collect::<Vec<_>>()
903 .join("\n");
904
905 let mut block = Block::new("");
906
907 block.line(format!("const SQL: &str = r#\"{}\"#;", cleaned.trim()));
909 block.line("");
910
911 if let Some(params) = &query.params {
913 let param_names: Vec<&str> = params.iter().map(|(meta, _)| meta.value.as_str()).collect();
914 if !param_names.is_empty() {
915 let params_str = param_names.join(", ");
916 block.line(format!(
917 "let rows = client.query(SQL, &[{}]).await?;",
918 params_str
919 ));
920 } else {
921 block.line("let rows = client.query(SQL, &[]).await?;");
922 }
923 } else {
924 block.line("let rows = client.query(SQL, &[]).await?;");
925 }
926
927 if query.first.is_some() {
929 let mut match_block = Block::new("match rows.into_iter().next()");
930 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
931 match_block.line("None => Ok(None),");
932 block.push_block(match_block);
933 } else {
934 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
935 }
936
937 block
938}
939
940fn param_type_to_rust(ty: &dibs_query_schema::ParamType) -> String {
941 use dibs_query_schema::ParamType;
942 match ty {
943 ParamType::String => "String".to_string(),
944 ParamType::Int => "i64".to_string(),
945 ParamType::Float => "f64".to_string(),
946 ParamType::Bool => "bool".to_string(),
947 ParamType::Uuid => "Uuid".to_string(),
948 ParamType::Decimal => "Decimal".to_string(),
949 ParamType::Timestamp => "Timestamp".to_string(),
950 ParamType::Bytes => "Vec<u8>".to_string(),
951 ParamType::Jsonb => "String".to_string(),
957 ParamType::Optional(inner_vec) => {
958 if let Some(inner) = inner_vec.first() {
959 format!("Option<{}>", param_type_to_rust(inner))
960 } else {
961 "Option<String>".to_string()
962 }
963 }
964 }
965}
966
967fn block_to_string(block: &Block) -> String {
969 let mut output = String::new();
970 let mut formatter = codegen::Formatter::new(&mut output);
971 block.fmt(&mut formatter).expect("formatting failed");
972 output
973}
974
/// Converts `snake_case` to `PascalCase`.
///
/// Underscores are removed. The first character, and each character that
/// follows an underscore, is ASCII-uppercased; every other character is kept
/// as-is.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            chars
                .next()
                .map(|head| head.to_ascii_uppercase())
                .into_iter()
                .chain(chars)
        })
        .collect()
}
992
/// Converts `PascalCase` / `camelCase` to `snake_case`.
///
/// Every uppercase character (except at position 0) is preceded by `_`, and
/// every uppercase character is lowercased. Case detection and conversion are
/// both Unicode-aware, so a non-ASCII capital is lowercased as well instead of
/// being left uppercase after its `_`. Runs of capitals are split per letter
/// (`"ID"` becomes `"i_d"`), preserving existing generated names.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
1009
1010fn generate_insert_code(
1015 _ctx: &CodegenContext,
1016 name_meta: &Meta<String>,
1017 insert: &Insert,
1018 scope: &mut Scope,
1019) {
1020 let name = &name_meta.value;
1021 let fn_name = to_snake_case(name);
1022 let generated = crate::sqlgen::generate_insert_sql(insert);
1023
1024 let has_returning = insert.returning.is_some();
1026 let return_ty = if !has_returning {
1027 "Result<u64, QueryError>".to_string()
1028 } else {
1029 let struct_name = format!("{}Result", name);
1030 if let Some(returning) = &insert.returning {
1031 generate_mutation_result_struct(
1032 _ctx,
1033 &struct_name,
1034 insert.into.value.as_str(),
1035 returning,
1036 scope,
1037 );
1038 }
1039 format!("Result<Option<{}>, QueryError>", struct_name)
1040 };
1041
1042 let mut func = Function::new(&fn_name);
1043 if let Some(doc) = &name_meta.doc {
1044 let doc_str = doc.join("\n");
1045 func.doc(&doc_str);
1046 }
1047 func.vis("pub");
1048 func.set_async(true);
1049 func.generic("C");
1050 func.arg("client", "&C");
1051
1052 if let Some(params) = &insert.params {
1053 for (param_name_meta, param_type) in ¶ms.params {
1054 let param_name = param_name_meta.value.as_str();
1055 let rust_ty = param_type_to_rust(param_type);
1056 func.arg(param_name, format!("&{}", rust_ty));
1057 }
1058 }
1059
1060 func.ret(&return_ty);
1061 func.bound("C", "tokio_postgres::GenericClient");
1062
1063 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1064 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1065
1066 scope.push_fn(func);
1067}
1068
1069fn generate_upsert_code(
1070 _ctx: &CodegenContext,
1071 name_meta: &Meta<String>,
1072 upsert: &Upsert,
1073 scope: &mut Scope,
1074) {
1075 let name = &name_meta.value;
1076 let fn_name = to_snake_case(name);
1077 let generated = crate::sqlgen::generate_upsert_sql(upsert);
1078
1079 let has_returning = upsert.returning.is_some();
1080 let return_ty = if !has_returning {
1081 "Result<u64, QueryError>".to_string()
1082 } else {
1083 let struct_name = format!("{}Result", name);
1084 if let Some(returning) = &upsert.returning {
1085 generate_mutation_result_struct(
1086 _ctx,
1087 &struct_name,
1088 upsert.into.value.as_str(),
1089 returning,
1090 scope,
1091 );
1092 }
1093 format!("Result<Option<{}>, QueryError>", struct_name)
1094 };
1095
1096 let mut func = Function::new(&fn_name);
1097 if let Some(doc) = &name_meta.doc {
1098 let doc_str = doc.join("\n");
1099 func.doc(&doc_str);
1100 }
1101 func.vis("pub");
1102 func.set_async(true);
1103 func.generic("C");
1104 func.arg("client", "&C");
1105
1106 if let Some(params) = &upsert.params {
1107 for (param_name_meta, param_type) in ¶ms.params {
1108 let param_name = param_name_meta.value.as_str();
1109 let rust_ty = param_type_to_rust(param_type);
1110 func.arg(param_name, format!("&{}", rust_ty));
1111 }
1112 }
1113
1114 func.ret(&return_ty);
1115 func.bound("C", "tokio_postgres::GenericClient");
1116
1117 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1118 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1119
1120 scope.push_fn(func);
1121}
1122
1123fn generate_insert_many_code(
1124 ctx: &CodegenContext,
1125 name_meta: &Meta<String>,
1126 insert: &InsertMany,
1127 scope: &mut Scope,
1128) {
1129 let name = &name_meta.value;
1130 let fn_name = to_snake_case(name);
1131 let generated = crate::sqlgen::generate_insert_many_sql(insert);
1132
1133 let params_struct_name = format!("{}Params", name);
1135 if let Some(params) = &insert.params {
1136 generate_bulk_params_struct(
1137 ctx,
1138 ¶ms_struct_name,
1139 insert.into.value.as_str(),
1140 params,
1141 scope,
1142 );
1143 }
1144
1145 let has_returning = insert.returning.is_some();
1147 let return_ty = if !has_returning {
1148 "Result<u64, QueryError>".to_string()
1149 } else {
1150 let struct_name = format!("{}Result", name);
1151 if let Some(returning) = &insert.returning {
1152 generate_mutation_result_struct(
1153 ctx,
1154 &struct_name,
1155 insert.into.value.as_str(),
1156 returning,
1157 scope,
1158 );
1159 }
1160 format!("Result<Vec<{}>, QueryError>", struct_name)
1161 };
1162
1163 let mut func = Function::new(&fn_name);
1164 if let Some(doc) = &name_meta.doc {
1165 let doc_str = doc.join("\n");
1166 func.doc(&doc_str);
1167 }
1168 func.vis("pub");
1169 func.set_async(true);
1170 func.generic("C");
1171 func.arg("client", "&C");
1172 func.arg("items", format!("&[{}]", params_struct_name));
1173
1174 func.ret(&return_ty);
1175 func.bound("C", "tokio_postgres::GenericClient");
1176
1177 let body = generate_bulk_mutation_body(&generated.sql, insert.params.as_ref(), !has_returning);
1178 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1179
1180 scope.push_fn(func);
1181}
1182
1183fn generate_upsert_many_code(
1184 ctx: &CodegenContext,
1185 name_meta: &Meta<String>,
1186 upsert: &UpsertMany,
1187 scope: &mut Scope,
1188) {
1189 let name = &name_meta.value;
1190 let fn_name = to_snake_case(name);
1191 let generated = crate::sqlgen::generate_upsert_many_sql(upsert);
1192
1193 let params_struct_name = format!("{}Params", name);
1195 if let Some(params) = &upsert.params {
1196 generate_bulk_params_struct(
1197 ctx,
1198 ¶ms_struct_name,
1199 upsert.into.value.as_str(),
1200 params,
1201 scope,
1202 );
1203 }
1204
1205 let has_returning = upsert.returning.is_some();
1207 let return_ty = if !has_returning {
1208 "Result<u64, QueryError>".to_string()
1209 } else {
1210 let struct_name = format!("{}Result", name);
1211 if let Some(returning) = &upsert.returning {
1212 generate_mutation_result_struct(
1213 ctx,
1214 &struct_name,
1215 upsert.into.value.as_str(),
1216 returning,
1217 scope,
1218 );
1219 }
1220 format!("Result<Vec<{}>, QueryError>", struct_name)
1221 };
1222
1223 let mut func = Function::new(&fn_name);
1224 if let Some(doc) = &name_meta.doc {
1225 let doc_str = doc.join("\n");
1226 func.doc(&doc_str);
1227 }
1228 func.vis("pub");
1229 func.set_async(true);
1230 func.generic("C");
1231 func.arg("client", "&C");
1232 func.arg("items", format!("&[{}]", params_struct_name));
1233
1234 func.ret(&return_ty);
1235 func.bound("C", "tokio_postgres::GenericClient");
1236
1237 let body = generate_bulk_mutation_body(&generated.sql, upsert.params.as_ref(), !has_returning);
1238 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1239
1240 scope.push_fn(func);
1241}
1242
1243fn generate_bulk_params_struct(
1245 ctx: &CodegenContext,
1246 struct_name: &str,
1247 table: &str,
1248 params: &Params,
1249 scope: &mut Scope,
1250) {
1251 let mut st = Struct::new(struct_name);
1252 st.vis("pub");
1253 st.derive("Debug");
1254 st.derive("Clone");
1255
1256 for (param_name_meta, param_type) in ¶ms.params {
1257 let param_name = param_name_meta.value.as_str();
1258 let rust_ty = ctx
1259 .column_type(table, param_name)
1260 .unwrap_or_else(|| param_type_to_rust(param_type));
1261 st.field(format!("pub {}", param_name), &rust_ty);
1262 }
1263
1264 scope.push_struct(st);
1265}
1266
1267fn generate_bulk_mutation_body(sql: &str, params: Option<&Params>, execute_only: bool) -> Block {
1269 let mut block = Block::new("");
1270
1271 block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1273 block.line("");
1274
1275 if let Some(params) = params {
1277 block.line("// Convert items to parallel arrays for UNNEST");
1278 for (param_name_meta, param_type) in ¶ms.params {
1279 let param_name = param_name_meta.value.as_str();
1280 let rust_ty = param_type_to_rust(param_type);
1281 block.line(format!(
1282 "let {}_arr: Vec<{}> = items.iter().map(|i| i.{}.clone()).collect();",
1283 param_name, rust_ty, param_name
1284 ));
1285 }
1286 block.line("");
1287
1288 let param_refs: Vec<String> = params
1290 .params
1291 .keys()
1292 .map(|p| format!("&{}_arr", p.value))
1293 .collect();
1294
1295 if execute_only {
1296 block.line(format!(
1298 "let affected = client.execute(SQL, &[{}]).await?;",
1299 param_refs.join(", ")
1300 ));
1301 block.line("Ok(affected)");
1302 } else {
1303 block.line(format!(
1305 "let rows = client.query(SQL, &[{}]).await?;",
1306 param_refs.join(", ")
1307 ));
1308 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
1309 }
1310 }
1311
1312 block
1313}
1314
1315fn generate_update_code(
1316 ctx: &CodegenContext,
1317 name_meta: &Meta<String>,
1318 update: &Update,
1319 scope: &mut Scope,
1320) -> Result<(), QError> {
1321 let name = &name_meta.value;
1322 let fn_name = to_snake_case(name);
1323 let sqlgen_ctx = ctx.sqlgen_ctx();
1324 let generated = crate::sqlgen::generate_update_sql(&sqlgen_ctx, update)?;
1325
1326 let has_returning = update.returning.is_some();
1327 let return_ty = if !has_returning {
1328 "Result<u64, QueryError>".to_string()
1329 } else {
1330 let struct_name = format!("{}Result", name);
1331 if let Some(returning) = &update.returning {
1332 generate_mutation_result_struct(
1333 ctx,
1334 &struct_name,
1335 update.table.value.as_str(),
1336 returning,
1337 scope,
1338 );
1339 }
1340 format!("Result<Option<{}>, QueryError>", struct_name)
1341 };
1342
1343 let mut func = Function::new(&fn_name);
1344 if let Some(doc) = &name_meta.doc {
1345 let doc_str = doc.join("\n");
1346 func.doc(&doc_str);
1347 }
1348 func.vis("pub");
1349 func.set_async(true);
1350 func.generic("C");
1351 func.arg("client", "&C");
1352
1353 if let Some(params) = &update.params {
1354 for (param_name_meta, param_type) in ¶ms.params {
1355 let param_name = param_name_meta.value.as_str();
1356 let rust_ty = param_type_to_rust(param_type);
1357 func.arg(param_name, format!("&{}", rust_ty));
1358 }
1359 }
1360
1361 func.ret(&return_ty);
1362 func.bound("C", "tokio_postgres::GenericClient");
1363
1364 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1365 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1366
1367 scope.push_fn(func);
1368 Ok(())
1369}
1370
1371fn generate_delete_code(
1372 ctx: &CodegenContext,
1373 name_meta: &Meta<String>,
1374 delete: &Delete,
1375 scope: &mut Scope,
1376) -> Result<(), QError> {
1377 let name = &name_meta.value;
1378 let fn_name = to_snake_case(name);
1379 let sqlgen_ctx = ctx.sqlgen_ctx();
1380 let generated = crate::sqlgen::generate_delete_sql(&sqlgen_ctx, delete)?;
1381
1382 let has_returning = delete.returning.is_some();
1383 let return_ty = if !has_returning {
1384 "Result<u64, QueryError>".to_string()
1385 } else {
1386 let struct_name = format!("{}Result", name);
1387 if let Some(returning) = &delete.returning {
1388 generate_mutation_result_struct(
1389 ctx,
1390 &struct_name,
1391 delete.from.value.as_str(),
1392 returning,
1393 scope,
1394 );
1395 }
1396 format!("Result<Option<{}>, QueryError>", struct_name)
1397 };
1398
1399 let mut func = Function::new(&fn_name);
1400 if let Some(doc) = &name_meta.doc {
1401 let doc_str = doc.join("\n");
1402 func.doc(&doc_str);
1403 }
1404 func.vis("pub");
1405 func.set_async(true);
1406 func.generic("C");
1407 func.arg("client", "&C");
1408
1409 if let Some(params) = &delete.params {
1410 for (param_name_meta, param_type) in ¶ms.params {
1411 let param_name = param_name_meta.value.as_str();
1412 let rust_ty = param_type_to_rust(param_type);
1413 func.arg(param_name, format!("&{}", rust_ty));
1414 }
1415 }
1416
1417 func.ret(&return_ty);
1418 func.bound("C", "tokio_postgres::GenericClient");
1419
1420 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1421 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1422
1423 scope.push_fn(func);
1424 Ok(())
1425}
1426
1427fn generate_mutation_result_struct(
1428 ctx: &CodegenContext,
1429 struct_name: &str,
1430 table: &str,
1431 returning: &Returning,
1432 scope: &mut Scope,
1433) {
1434 let mut st = Struct::new(struct_name);
1435 st.vis("pub");
1436 st.derive("Debug");
1437 st.derive("Clone");
1438 st.derive("Facet");
1439 st.attr("facet(crate = dibs_runtime::facet)");
1440
1441 for (col_name_meta, _) in &returning.columns {
1442 let col_name = col_name_meta.value.as_str();
1443 let rust_ty = ctx
1444 .column_type(table, col_name)
1445 .unwrap_or_else(|| "String".to_string());
1446 st.field(format!("pub {col_name}"), &rust_ty);
1447 }
1448
1449 scope.push_struct(st);
1450}
1451
1452fn generate_mutation_body(
1453 sql: &str,
1454 param_order: &[dibs_sql::ParamName],
1455 execute_only: bool,
1456) -> Block {
1457 let mut block = Block::new("");
1458
1459 block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1461 block.line("");
1462
1463 let params: Vec<_> = param_order
1464 .iter()
1465 .filter(|p| !p.as_str().starts_with("__literal_"))
1466 .collect();
1467
1468 if execute_only {
1469 if params.is_empty() {
1471 block.line("let affected = client.execute(SQL, &[]).await?;");
1472 } else {
1473 let params_str = params
1474 .iter()
1475 .map(|p| p.as_str())
1476 .collect::<Vec<_>>()
1477 .join(", ");
1478 block.line(format!(
1479 "let affected = client.execute(SQL, &[{}]).await?;",
1480 params_str
1481 ));
1482 }
1483 block.line("Ok(affected)");
1484 } else {
1485 if params.is_empty() {
1487 block.line("let rows = client.query(SQL, &[]).await?;");
1488 } else {
1489 let params_str = params
1490 .iter()
1491 .map(|p| p.as_str())
1492 .collect::<Vec<_>>()
1493 .join(", ");
1494 block.line(format!(
1495 "let rows = client.query(SQL, &[{}]).await?;",
1496 params_str
1497 ));
1498 }
1499 let mut match_block = Block::new("match rows.into_iter().next()");
1500 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
1501 match_block.line("None => Ok(None),");
1502 block.push_block(match_block);
1503 }
1504
1505 block
1506}
1507
1508#[cfg(test)]
1509mod tests;