1use crate::error::QErrorKind;
14use crate::sqlgen::SqlGenContext;
15use crate::{QError, QSource};
16use codegen::{Block, Function, Scope, Struct};
17use dibs_db_schema::{Schema, Table};
18use dibs_query_schema::{
19 Decl, Delete, FieldDef, Insert, InsertMany, Meta, Params, QueryFile, Returning, Returns,
20 Select, SelectFields, Span, Update, Upsert, UpsertMany,
21};
22use std::sync::Arc;
23
/// Rust source text produced by [`generate_rust_code`] for one query file.
#[derive(Clone, Debug)]
pub struct GeneratedCode {
    /// The complete generated source: header comment, imports, result structs
    /// and query functions.
    pub code: String,
}
30
/// Wraps a generated function body so its result is passed through
/// `TraceErr::trace_err`, tagged with the generated function's name.
///
/// The body is placed in an `async` block that is awaited immediately, so a `?`
/// inside the body short-circuits to the block's value (`__dibs_result`) instead
/// of returning from the generated function directly.
fn wrap_with_trace_err(body: &str, fn_name: &str) -> String {
    let mut wrapped = String::with_capacity(body.len() + fn_name.len() + 96);
    wrapped.push_str("let __dibs_result = async {\n");
    wrapped.push_str(body);
    wrapped.push_str("\n}.await;\n");
    wrapped.push_str("<_ as TraceErr>::trace_err(__dibs_result, \"");
    wrapped.push_str(fn_name);
    wrapped.push_str("\")");
    wrapped
}
49
50fn schema_column_type(schema: &Schema, table: &str, column: &str) -> Option<String> {
52 let table_info = schema.get_table(table)?;
53 let col = table_info.columns.iter().find(|c| c.name == column)?;
54 let rust_type = col
55 .rust_type
56 .clone()
57 .unwrap_or_else(|| col.pg_type.to_rust_type().to_string());
58 if col.nullable {
59 Some(format!("Option<{}>", rust_type))
60 } else {
61 Some(rust_type)
62 }
63}
64
65struct CodegenContext<'a> {
67 schema: &'a Schema,
68 source: Arc<QSource>,
69 #[allow(dead_code)]
70 scope: Scope,
71}
72
73impl CodegenContext<'_> {
74 fn column_type(&self, table: &str, column: &str) -> Option<String> {
76 schema_column_type(self.schema, table, column)
77 }
78
79 fn table_not_found(&self, table: &str, span: Span) -> QError {
83 let mut available: Vec<String> = self.schema.tables.keys().cloned().collect();
84 available.sort();
85 QError {
86 source: self.source.clone(),
87 span,
88 kind: QErrorKind::TableNotFound {
89 table: table.to_string(),
90 available,
91 },
92 }
93 }
94
95 fn require_table(&self, table: &str, span: Span) -> Result<&Table, QError> {
100 self.schema
101 .get_table(table)
102 .ok_or_else(|| self.table_not_found(table, span))
103 }
104
105 fn column_type_at(&self, table: &str, column: &str, span: Span) -> Result<String, QError> {
114 let table_info = self.require_table(table, span)?;
115 if let Some(ty) = schema_column_type(self.schema, table, column) {
116 return Ok(ty);
117 }
118 Err(QError {
119 source: self.source.clone(),
120 span,
121 kind: QErrorKind::ColumnNotFound {
122 table: table.to_string(),
123 column: column.to_string(),
124 available: table_info.columns.iter().map(|c| c.name.clone()).collect(),
125 },
126 })
127 }
128
129 fn sqlgen_ctx(&self) -> SqlGenContext<'_> {
131 SqlGenContext::new(self.schema, self.source.clone())
132 }
133}
134
135pub fn generate_rust_code(
137 file: &QueryFile,
138 schema: &Schema,
139 source: Arc<QSource>,
140) -> Result<GeneratedCode, QError> {
141 let mut scope = Scope::new();
142
143 scope.raw("// Generated by dibs-qgen. Do not edit.");
145 scope.raw("");
146
147 scope.import("dibs_runtime::prelude", "*");
149 scope.import("dibs_runtime", "tokio_postgres");
150
151 let ctx = CodegenContext {
152 schema,
153 source,
154 scope: Scope::new(),
155 };
156
157 for (name_meta, decl) in &file.0 {
159 match decl {
160 Decl::Select(select) => {
161 generate_select_code(&ctx, name_meta, select, &mut scope)?;
162 }
163 Decl::Insert(insert) => {
164 generate_insert_code(&ctx, name_meta, insert, &mut scope)?;
165 }
166 Decl::InsertMany(insert_many) => {
167 generate_insert_many_code(&ctx, name_meta, insert_many, &mut scope)?;
168 }
169 Decl::Upsert(upsert) => {
170 generate_upsert_code(&ctx, name_meta, upsert, &mut scope)?;
171 }
172 Decl::UpsertMany(upsert_many) => {
173 generate_upsert_many_code(&ctx, name_meta, upsert_many, &mut scope)?;
174 }
175 Decl::Update(update) => {
176 generate_update_code(&ctx, name_meta, update, &mut scope)?;
177 }
178 Decl::Delete(delete) => {
179 generate_delete_code(&ctx, name_meta, delete, &mut scope)?;
180 }
181 }
182 }
183
184 Ok(GeneratedCode {
185 code: scope.to_string(),
186 })
187}
188
189fn generate_select_code(
190 ctx: &CodegenContext,
191 name_meta: &Meta<String>,
192 select: &Select,
193 scope: &mut Scope,
194) -> Result<(), QError> {
195 let name = &name_meta.value;
196 let struct_name = format!("{}Result", name);
197
198 if let Some(from) = &select.from {
200 if select.fields.is_some() {
201 generate_result_struct(ctx, select, name_meta, &struct_name, from, scope)?;
202
203 if select.has_relations() {
205 let flat_struct_name = format!("{}Row", name);
206 generate_flat_row_struct(ctx, select, &flat_struct_name, from, scope)?;
207 }
208 }
209 } else if let Some(returns) = &select.returns {
210 generate_raw_sql_result_struct(&struct_name, returns, scope);
212 }
213
214 generate_select_function(ctx, name_meta, select, &struct_name, scope)?;
216 Ok(())
217}
218
219fn generate_flat_row_struct(
242 ctx: &CodegenContext,
243 select: &Select,
244 struct_name: &str,
245 table: &Meta<dibs_sql::TableName>,
246 scope: &mut Scope,
247) -> Result<(), QError> {
248 let mut st = Struct::new(struct_name);
249 st.derive("Debug");
251 st.derive("Clone");
252 st.derive("Facet");
253 st.attr("facet(crate = dibs_runtime::facet)");
254
255 let table_name = table.value.as_str();
256 ctx.require_table(table_name, table.span)?;
257
258 if let Some(select_fields) = &select.fields {
259 add_flat_fields_for_select(ctx, &mut st, table_name, "", select_fields)?;
261 }
262
263 scope.push_struct(st);
264 Ok(())
265}
266
267fn add_flat_fields_for_select(
269 ctx: &CodegenContext,
270 st: &mut Struct,
271 table_name: &str,
272 prefix: &str,
273 select_fields: &SelectFields,
274) -> Result<(), QError> {
275 for (field_name_meta, field_def) in &select_fields.fields {
276 let field_name = field_name_meta.value.as_str();
277
278 match field_def {
279 None => {
280 let rust_ty = ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
282
283 let flat_field_name = if prefix.is_empty() {
284 field_name.to_string()
285 } else {
286 format!("{}_{}", prefix, field_name)
287 };
288
289 let final_ty = if prefix.is_empty() {
291 rust_ty
292 } else if rust_ty.starts_with("Option<") {
293 rust_ty
295 } else {
296 format!("Option<{}>", rust_ty)
297 };
298
299 st.field(&flat_field_name, &final_ty);
301 }
302 Some(FieldDef::Rel(rel)) => {
303 let rel_table = rel.table_name().unwrap_or(field_name);
305 let rel_span = rel
306 .from
307 .as_ref()
308 .map(|m| m.span)
309 .unwrap_or(field_name_meta.span);
310 ctx.require_table(rel_table, rel_span)?;
311 let new_prefix = if prefix.is_empty() {
312 field_name.to_string()
313 } else {
314 format!("{}_{}", prefix, field_name)
315 };
316
317 if let Some(rel_fields) = &rel.fields {
318 add_flat_fields_for_select(ctx, st, rel_table, &new_prefix, rel_fields)?;
319 }
320 }
321 Some(FieldDef::Count(_)) => {
322 let flat_field_name = if prefix.is_empty() {
324 field_name.to_string()
325 } else {
326 format!("{}_{}", prefix, field_name)
327 };
328 st.field(&flat_field_name, "i64");
329 }
330 }
331 }
332 Ok(())
333}
334
335fn generate_raw_sql_result_struct(struct_name: &str, returns: &Returns, scope: &mut Scope) {
336 let mut st = Struct::new(struct_name);
337 st.vis("pub");
338 st.derive("Debug");
339 st.derive("Clone");
340 st.derive("Facet");
341 st.attr("facet(crate = dibs_runtime::facet)");
342
343 for (field_name_meta, param_type) in &returns.fields {
344 let field_name = field_name_meta.value.as_str();
345 let rust_ty = param_type_to_rust(param_type);
346 st.field(format!("pub {}", field_name), &rust_ty);
347 }
348
349 scope.push_struct(st);
350}
351
352fn generate_result_struct(
353 ctx: &CodegenContext,
354 select: &Select,
355 name_meta: &Meta<String>,
356 struct_name: &str,
357 table: &Meta<dibs_sql::TableName>,
358 scope: &mut Scope,
359) -> Result<(), QError> {
360 let mut st = Struct::new(struct_name);
361 st.vis("pub");
362 st.derive("Debug");
363 st.derive("Clone");
364 st.derive("Facet");
365 st.attr("facet(crate = dibs_runtime::facet)");
366
367 let parent_prefix = &name_meta.value;
369 let table_name = table.value.as_str();
370 ctx.require_table(table_name, table.span)?;
373
374 if let Some(select_fields) = &select.fields {
375 for (field_name_meta, field_def) in &select_fields.fields {
376 let field_name = field_name_meta.value.as_str();
377 match field_def {
378 None => {
379 let rust_ty =
381 ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
382 st.field(format!("pub {}", field_name), &rust_ty);
383 }
384 Some(FieldDef::Rel(rel)) => {
385 let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
386 let ty = if rel.first.is_some() {
387 format!("Option<{}>", nested_name)
388 } else {
389 format!("Vec<{}>", nested_name)
390 };
391 st.field(format!("pub {}", field_name), &ty);
392 }
393 Some(FieldDef::Count(_)) => {
394 st.field(format!("pub {}", field_name), "i64");
395 }
396 }
397 }
398 }
399
400 scope.push_struct(st);
401
402 if let Some(select_fields) = &select.fields {
404 generate_nested_structs(ctx, parent_prefix, select_fields, scope)?;
405 }
406 Ok(())
407}
408
409fn generate_nested_structs(
414 ctx: &CodegenContext,
415 parent_prefix: &str,
416 select_fields: &SelectFields,
417 scope: &mut Scope,
418) -> Result<(), QError> {
419 for (field_name_meta, field_def) in &select_fields.fields {
420 if let Some(FieldDef::Rel(rel)) = field_def {
421 let field_name = field_name_meta.value.as_str();
422 let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
423 let rel_table = rel.table_name().unwrap_or(field_name);
424 let rel_span = rel
425 .from
426 .as_ref()
427 .map(|m| m.span)
428 .unwrap_or(field_name_meta.span);
429 ctx.require_table(rel_table, rel_span)?;
430
431 let mut nested_st = Struct::new(&nested_name);
432 nested_st.vis("pub");
433 nested_st.derive("Debug");
434 nested_st.derive("Clone");
435 nested_st.derive("Facet");
436 nested_st.attr("facet(crate = dibs_runtime::facet)");
437
438 if let Some(rel_fields) = &rel.fields {
439 for (rel_field_name_meta, rel_field_def) in &rel_fields.fields {
440 let rel_field_name = rel_field_name_meta.value.as_str();
441 match rel_field_def {
442 None => {
443 let rust_ty = ctx.column_type_at(
445 rel_table,
446 rel_field_name,
447 rel_field_name_meta.span,
448 )?;
449 nested_st.field(format!("pub {}", rel_field_name), &rust_ty);
450 }
451 Some(FieldDef::Rel(nested_rel)) => {
452 let nested_rel_name =
454 format!("{}{}", nested_name, to_pascal_case(rel_field_name));
455 let ty = if nested_rel.first.is_some() {
456 format!("Option<{}>", nested_rel_name)
457 } else {
458 format!("Vec<{}>", nested_rel_name)
459 };
460 nested_st.field(format!("pub {}", rel_field_name), &ty);
461 }
462 Some(FieldDef::Count(_)) => {
463 nested_st.field(format!("pub {}", rel_field_name), "i64");
464 }
465 }
466 }
467 }
468
469 scope.push_struct(nested_st);
470
471 if let Some(rel_fields) = &rel.fields {
473 generate_nested_structs(ctx, &nested_name, rel_fields, scope)?;
474 }
475 }
476 }
477 Ok(())
478}
479
480fn generate_select_function(
481 ctx: &CodegenContext,
482 name_meta: &Meta<String>,
483 query: &Select,
484 struct_name: &str,
485 scope: &mut Scope,
486) -> Result<(), QError> {
487 let name = &name_meta.value;
488 let fn_name = to_snake_case(name);
489
490 let return_ty = if query.first.is_some() {
491 format!("Result<Option<{}>, QueryError>", struct_name)
492 } else {
493 format!("Result<Vec<{}>, QueryError>", struct_name)
494 };
495
496 let mut func = Function::new(&fn_name);
497 if let Some(doc) = &name_meta.doc {
498 let doc_str = doc.join("\n");
499 func.doc(&doc_str);
500 }
501 func.vis("pub");
502 func.set_async(true);
503 func.attr("allow(clippy::too_many_arguments)");
506 func.generic("C");
507 func.arg("client", "&C");
508 func.attr("allow(clippy::clone_on_copy)");
510
511 if let Some(params) = &query.params {
512 for (param_name_meta, param_type) in ¶ms.params {
513 let param_name = ¶m_name_meta.value;
514 let rust_ty = param_type_to_rust(param_type);
515 func.arg(param_name, format!("&{}", rust_ty));
516 }
517 }
518
519 func.ret(&return_ty);
520 func.bound("C", "tokio_postgres::GenericClient");
521
522 let body = if let Some(raw_sql_meta) = &query.sql {
524 block_to_string(&generate_raw_query_body(query, &raw_sql_meta.value))
525 } else {
526 generate_query_body(ctx, query, struct_name)?
527 };
528 func.line(wrap_with_trace_err(&body, &fn_name));
529
530 scope.push_fn(func);
531 Ok(())
532}
533
534fn generate_query_body(
539 ctx: &CodegenContext,
540 query: &Select,
541 struct_name: &str,
542) -> Result<String, QError> {
543 let sqlgen_ctx = ctx.sqlgen_ctx();
544 let generated = match crate::sqlgen::generate_select_sql(&sqlgen_ctx, query) {
545 Ok(g) => g,
546 Err(e) => {
547 panic!("SELECT SQL generation failed: {}", e);
548 }
549 };
550
551 let mut block = Block::new("");
552
553 block.line(format!("const SQL: &str = r#\"{}\"#;", generated.sql));
555 block.line("");
556
557 let params: Vec<_> = generated
559 .param_order
560 .iter()
561 .filter(|p| !p.as_str().starts_with("__literal_"))
562 .collect();
563
564 if params.is_empty() {
565 block.line("let rows = client.query(SQL, &[]).await?;");
566 } else {
567 let params_str = params
568 .iter()
569 .map(|p| p.as_str())
570 .collect::<Vec<_>>()
571 .join(", ");
572 block.line(format!(
573 "let rows = client.query(SQL, &[{}]).await?;",
574 params_str
575 ));
576 }
577
578 if !query.has_relations() {
580 if query.first.is_some() {
581 let mut match_block = Block::new("match rows.into_iter().next()");
582 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
583 match_block.line("None => Ok(None),");
584 block.push_block(match_block);
585 } else {
586 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
587 }
588 return Ok(block_to_string(&block));
589 }
590
591 let query_name = struct_name.strip_suffix("Result").unwrap_or(struct_name);
593 let flat_struct_name = format!("{}Row", query_name);
594
595 block.line("");
596 block.line("// Deserialize all rows into flat structs using facet reflection");
597 block.line(format!(
598 "let flat_rows: Vec<{flat_struct_name}> = rows.iter().map(from_row).collect::<Result<Vec<_>, _>>()?;"
599 ));
600 block.line("");
601
602 let Some(select_fields) = &query.fields else {
604 block.line("Ok(vec![])".to_string());
606 return Ok(block_to_string(&block));
607 };
608
609 let root_table = query
610 .from
611 .as_ref()
612 .map(|m| m.value.as_str())
613 .unwrap_or("unknown");
614 let is_first = query.is_first();
615
616 block.line(generate_flat_to_nested_transform(
617 ctx,
618 select_fields,
619 struct_name,
620 root_table,
621 is_first,
622 )?);
623
624 Ok(block_to_string(&block))
625}
626
627fn generate_flat_to_nested_transform(
629 ctx: &CodegenContext,
630 select_fields: &SelectFields,
631 struct_name: &str,
632 root_table: &str,
633 is_first: bool,
634) -> Result<String, QError> {
635 let mut block = Block::new("");
636
637 let id_column = select_fields
639 .id_column()
640 .map(|c| c.to_string())
641 .unwrap_or_else(|| "id".to_string());
642
643 let id_type = ctx
644 .column_type(root_table, &id_column)
645 .unwrap_or_else(|| "i64".to_string());
646
647 if select_fields.has_vec_relations() {
649 block.line("// Group flat rows by parent ID and assemble nested structs");
651 block.line(format!(
652 "let mut grouped: std::collections::HashMap<{id_type}, {struct_name}> = std::collections::HashMap::new();"
653 ));
654
655 generate_seen_id_declarations(&mut block, ctx, select_fields, &id_type, "")?;
657
658 block.line("");
659
660 let mut for_block = Block::new("for flat_row in flat_rows");
661 for_block.line(format!("let parent_id = flat_row.{id_column}.clone();"));
662 for_block.line("");
663
664 let mut entry_block = Block::new(format!(
666 "let entry = grouped.entry(parent_id.clone()).or_insert_with(|| {struct_name}"
667 ));
668
669 for (field_name_meta, field_def) in &select_fields.fields {
671 let field_name = field_name_meta.value.as_str();
672 match field_def {
673 None => {
674 entry_block.line(format!("{field_name}: flat_row.{field_name}.clone(),"));
675 }
676 Some(FieldDef::Rel(rel)) => {
677 if rel.is_first() {
678 entry_block.line(format!("{field_name}: None,"));
679 } else {
680 entry_block.line(format!("{field_name}: Vec::new(),"));
681 }
682 }
683 Some(FieldDef::Count(_)) => {
684 entry_block.line(format!("{field_name}: flat_row.{field_name},"));
685 }
686 }
687 }
688 entry_block.after(");");
689 for_block.push_block(entry_block);
690 for_block.line("");
691
692 let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
694 generate_relation_assembly(
695 &mut for_block,
696 ctx,
697 select_fields,
698 parent_prefix,
699 "",
700 &id_type,
701 )?;
702
703 block.push_block(for_block);
704 block.line("");
705
706 if is_first {
707 block.line("Ok(grouped.into_values().next())");
708 } else {
709 block.line("Ok(grouped.into_values().collect())");
710 }
711 } else {
712 block.line("// Transform flat rows into nested structs (Option relations only)");
714
715 let mut map_block = Block::new(
716 "let results: Result<Vec<_>, QueryError> = flat_rows.into_iter().map(|flat_row| {",
717 );
718
719 let mut result_block = Block::new(format!("Ok({struct_name}"));
720 let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
721
722 for (field_name_meta, field_def) in &select_fields.fields {
723 let field_name = field_name_meta.value.as_str();
724 match field_def {
725 None => {
726 result_block.line(format!("{field_name}: flat_row.{field_name},"));
727 }
728 Some(FieldDef::Rel(rel)) => {
729 if rel.is_first() {
730 if let Some(rel_fields) = &rel.fields {
732 let rel_table = rel.table_name().unwrap_or(field_name);
733 let first_col = rel_fields
734 .first_column()
735 .map(|c| c.as_str())
736 .unwrap_or("id");
737 let first_alias = format!("{field_name}_{first_col}");
738 let nested_struct =
739 format!("{}{}", parent_prefix, to_pascal_case(field_name));
740
741 let mut map_inner = Block::new(format!(
742 "{field_name}: flat_row.{first_alias}.as_ref().map(|_| {nested_struct}"
743 ));
744
745 for (inner_field_meta, inner_def) in &rel_fields.fields {
746 let inner_name = inner_field_meta.value.as_str();
747 if inner_def.is_none() {
748 let alias = format!("{field_name}_{inner_name}");
749 let rust_ty = ctx.column_type_at(
750 rel_table,
751 inner_name,
752 inner_field_meta.span,
753 )?;
754
755 if rust_ty.starts_with("Option<") {
757 map_inner.line(format!(
758 "{inner_name}: flat_row.{alias}.clone(),"
759 ));
760 } else {
761 map_inner.line(format!(
762 "{inner_name}: flat_row.{alias}.clone().expect(\"non-null column from LEFT JOIN\"),"
763 ));
764 }
765 }
766 }
767
768 map_inner.after("),");
769 result_block.push_block(map_inner);
770 }
771 } else {
772 result_block.line(format!("{field_name}: Vec::new(),"));
774 }
775 }
776 Some(FieldDef::Count(_)) => {
777 result_block.line(format!("{field_name}: flat_row.{field_name},"));
778 }
779 }
780 }
781
782 result_block.after(")");
783 map_block.push_block(result_block);
784 map_block.after("}).collect();");
785 block.push_block(map_block);
786 block.line("");
787
788 if is_first {
789 block.line("results.map(|mut v| v.pop())");
790 } else {
791 block.line("results");
792 }
793 }
794
795 Ok(block_to_string(&block))
796}
797
798fn generate_seen_id_declarations(
800 block: &mut Block,
801 ctx: &CodegenContext,
802 select_fields: &SelectFields,
803 parent_id_type: &str,
804 prefix: &str,
805) -> Result<(), QError> {
806 for (field_name_meta, field_def) in &select_fields.fields {
807 if let Some(FieldDef::Rel(rel)) = field_def {
808 let field_name = field_name_meta.value.as_str();
809 if !rel.is_first() {
810 if let Some(rel_fields) = &rel.fields {
812 let rel_table = rel.table_name().unwrap_or(field_name);
813 let id_col = rel_fields.id_column().map(|c| c.as_str()).unwrap_or("id");
814 let id_type = ctx
815 .column_type(rel_table, id_col)
816 .unwrap_or_else(|| "i64".to_string());
817
818 let set_name = if prefix.is_empty() {
819 format!("seen_{field_name}")
820 } else {
821 format!("seen_{prefix}_{field_name}")
822 };
823
824 block.line(format!(
825 "let mut {set_name}: std::collections::HashSet<({parent_id_type}, {id_type})> = std::collections::HashSet::new();"
826 ));
827
828 let new_prefix = if prefix.is_empty() {
830 field_name.to_string()
831 } else {
832 format!("{prefix}_{field_name}")
833 };
834
835 generate_seen_id_declarations(block, ctx, rel_fields, &id_type, &new_prefix)?;
837 }
838 }
839 }
840 }
841 Ok(())
842}
843
844fn generate_relation_assembly(
846 for_block: &mut Block,
847 ctx: &CodegenContext,
848 select_fields: &SelectFields,
849 parent_prefix: &str,
850 flat_prefix: &str,
851 _parent_id_type: &str,
852) -> Result<(), QError> {
853 for (field_name_meta, field_def) in &select_fields.fields {
854 if let Some(FieldDef::Rel(rel)) = field_def {
855 let field_name = field_name_meta.value.as_str();
856 let rel_table = rel.table_name().unwrap_or(field_name);
857 let nested_struct = format!("{}{}", parent_prefix, to_pascal_case(field_name));
858
859 let flat_field_prefix = if flat_prefix.is_empty() {
860 field_name.to_string()
861 } else {
862 format!("{flat_prefix}_{field_name}")
863 };
864
865 if let Some(rel_fields) = &rel.fields {
866 let first_col = rel_fields
867 .first_column()
868 .map(|c| c.as_str())
869 .unwrap_or("id");
870 let id_col = rel_fields
871 .id_column()
872 .map(|c| c.as_str())
873 .unwrap_or(first_col);
874 let id_alias = format!("{flat_field_prefix}_{id_col}");
875
876 if rel.is_first() {
877 for_block.line(format!("// Populate {field_name} (Option relation)"));
879
880 let mut if_block = Block::new(format!(
881 "if entry.{field_name}.is_none() && flat_row.{id_alias}.is_some()"
882 ));
883
884 let mut some_block =
885 Block::new(format!("entry.{field_name} = Some({nested_struct}"));
886 generate_relation_fields(
887 &mut some_block,
888 ctx,
889 rel_fields,
890 rel_table,
891 &flat_field_prefix,
892 )?;
893 some_block.after(");");
894 if_block.push_block(some_block);
895
896 for_block.push_block(if_block);
897 for_block.line("");
898 } else {
899 let set_name = if flat_prefix.is_empty() {
901 format!("seen_{field_name}")
902 } else {
903 format!("seen_{flat_prefix}_{field_name}")
904 };
905
906 for_block.line(format!("// Append to {field_name} (Vec relation)"));
907
908 let mut if_block =
909 Block::new(format!("if let Some(ref rel_id) = flat_row.{id_alias}"));
910 if_block.line("let key = (parent_id.clone(), rel_id.clone());".to_string());
911
912 let mut if_insert = Block::new(format!("if {set_name}.insert(key)"));
913 let mut push_block =
914 Block::new(format!("entry.{field_name}.push({nested_struct}"));
915 generate_relation_fields(
916 &mut push_block,
917 ctx,
918 rel_fields,
919 rel_table,
920 &flat_field_prefix,
921 )?;
922 push_block.after(");");
923 if_insert.push_block(push_block);
924
925 if_block.push_block(if_insert);
926 for_block.push_block(if_block);
927 for_block.line("");
928 }
929 }
930 }
931 }
932 Ok(())
933}
934
935fn generate_relation_fields(
937 block: &mut Block,
938 ctx: &CodegenContext,
939 select_fields: &SelectFields,
940 table_name: &str,
941 flat_prefix: &str,
942) -> Result<(), QError> {
943 for (field_name_meta, field_def) in &select_fields.fields {
944 let field_name = field_name_meta.value.as_str();
945 let alias = format!("{flat_prefix}_{field_name}");
946
947 match field_def {
948 None => {
949 let rust_ty = ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
950
951 if rust_ty.starts_with("Option<") {
954 block.line(format!("{field_name}: flat_row.{alias}.clone(),"));
955 } else {
956 block.line(format!(
957 "{field_name}: flat_row.{alias}.clone().expect(\"non-null from LEFT JOIN\"),"
958 ));
959 }
960 }
961 Some(FieldDef::Rel(rel)) => {
962 if rel.is_first() {
963 block.line(format!(
964 "{field_name}: None, // TODO: nested Option relation"
965 ));
966 } else {
967 block.line(format!(
968 "{field_name}: Vec::new(), // TODO: nested Vec relation"
969 ));
970 }
971 }
972 Some(FieldDef::Count(_)) => {
973 block.line(format!("{field_name}: flat_row.{alias},"));
974 }
975 }
976 }
977 Ok(())
978}
979
980fn generate_raw_query_body(query: &Select, raw_sql: &str) -> Block {
981 let cleaned: String = raw_sql
982 .lines()
983 .map(|l| l.trim())
984 .collect::<Vec<_>>()
985 .join("\n");
986
987 let mut block = Block::new("");
988
989 block.line(format!("const SQL: &str = r#\"{}\"#;", cleaned.trim()));
991 block.line("");
992
993 if let Some(params) = &query.params {
995 let param_names: Vec<&str> = params.iter().map(|(meta, _)| meta.value.as_str()).collect();
996 if !param_names.is_empty() {
997 let params_str = param_names.join(", ");
998 block.line(format!(
999 "let rows = client.query(SQL, &[{}]).await?;",
1000 params_str
1001 ));
1002 } else {
1003 block.line("let rows = client.query(SQL, &[]).await?;");
1004 }
1005 } else {
1006 block.line("let rows = client.query(SQL, &[]).await?;");
1007 }
1008
1009 if query.first.is_some() {
1011 let mut match_block = Block::new("match rows.into_iter().next()");
1012 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
1013 match_block.line("None => Ok(None),");
1014 block.push_block(match_block);
1015 } else {
1016 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
1017 }
1018
1019 block
1020}
1021
1022fn param_type_to_rust(ty: &dibs_query_schema::ParamType) -> String {
1023 use dibs_query_schema::ParamType;
1024 match ty {
1025 ParamType::String => "String".to_string(),
1026 ParamType::Int => "i64".to_string(),
1027 ParamType::Float => "f64".to_string(),
1028 ParamType::Bool => "bool".to_string(),
1029 ParamType::Uuid => "Uuid".to_string(),
1030 ParamType::Decimal => "Decimal".to_string(),
1031 ParamType::Timestamp => "Timestamp".to_string(),
1032 ParamType::Bytes => "Vec<u8>".to_string(),
1033 ParamType::Jsonb => "String".to_string(),
1039 ParamType::Optional(inner_vec) => {
1040 if let Some(inner) = inner_vec.first() {
1041 format!("Option<{}>", param_type_to_rust(inner))
1042 } else {
1043 "Option<String>".to_string()
1044 }
1045 }
1046 }
1047}
1048
1049fn block_to_string(block: &Block) -> String {
1051 let mut output = String::new();
1052 let mut formatter = codegen::Formatter::new(&mut output);
1053 block.fmt(&mut formatter).expect("formatting failed");
1054 output
1055}
1056
/// Converts `snake_case` to `PascalCase`.
///
/// Underscores are dropped and the character following each one is
/// ASCII-uppercased, as is the first character. All other characters are kept
/// unchanged. Example: `user_posts` becomes `UserPosts`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            out.push(head.to_ascii_uppercase());
            out.extend(chars);
        }
    }
    out
}
1074
/// Converts `PascalCase`/`camelCase` to `snake_case`.
///
/// Every uppercase character except a leading one is preceded by `_`. Uppercase
/// characters are then ASCII-lowercased. Consecutive capitals are split one by one,
/// so `ABC` becomes `a_b_c`.
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len()), |mut out, (idx, ch)| {
            if !ch.is_uppercase() {
                out.push(ch);
                return out;
            }
            if idx > 0 {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
            out
        })
}
1091
1092fn generate_insert_code(
1097 _ctx: &CodegenContext,
1098 name_meta: &Meta<String>,
1099 insert: &Insert,
1100 scope: &mut Scope,
1101) -> Result<(), QError> {
1102 let name = &name_meta.value;
1103 let fn_name = to_snake_case(name);
1104 let generated = crate::sqlgen::generate_insert_sql(insert);
1105
1106 let has_returning = insert.returning.is_some();
1108 let return_ty = if !has_returning {
1109 "Result<u64, QueryError>".to_string()
1110 } else {
1111 let struct_name = format!("{}Result", name);
1112 if let Some(returning) = &insert.returning {
1113 generate_mutation_result_struct(
1114 _ctx,
1115 &struct_name,
1116 insert.into.value.as_str(),
1117 insert.into.span,
1118 returning,
1119 scope,
1120 )?;
1121 }
1122 format!("Result<Option<{}>, QueryError>", struct_name)
1123 };
1124
1125 let mut func = Function::new(&fn_name);
1126 if let Some(doc) = &name_meta.doc {
1127 let doc_str = doc.join("\n");
1128 func.doc(&doc_str);
1129 }
1130 func.vis("pub");
1131 func.set_async(true);
1132 func.attr("allow(clippy::too_many_arguments)");
1135 func.generic("C");
1136 func.arg("client", "&C");
1137
1138 if let Some(params) = &insert.params {
1139 for (param_name_meta, param_type) in ¶ms.params {
1140 let param_name = param_name_meta.value.as_str();
1141 let rust_ty = param_type_to_rust(param_type);
1142 func.arg(param_name, format!("&{}", rust_ty));
1143 }
1144 }
1145
1146 func.ret(&return_ty);
1147 func.bound("C", "tokio_postgres::GenericClient");
1148
1149 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1150 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1151
1152 scope.push_fn(func);
1153 Ok(())
1154}
1155
1156fn generate_upsert_code(
1157 _ctx: &CodegenContext,
1158 name_meta: &Meta<String>,
1159 upsert: &Upsert,
1160 scope: &mut Scope,
1161) -> Result<(), QError> {
1162 let name = &name_meta.value;
1163 let fn_name = to_snake_case(name);
1164 let generated = crate::sqlgen::generate_upsert_sql(upsert);
1165
1166 let has_returning = upsert.returning.is_some();
1167 let return_ty = if !has_returning {
1168 "Result<u64, QueryError>".to_string()
1169 } else {
1170 let struct_name = format!("{}Result", name);
1171 if let Some(returning) = &upsert.returning {
1172 generate_mutation_result_struct(
1173 _ctx,
1174 &struct_name,
1175 upsert.into.value.as_str(),
1176 upsert.into.span,
1177 returning,
1178 scope,
1179 )?;
1180 }
1181 format!("Result<Option<{}>, QueryError>", struct_name)
1182 };
1183
1184 let mut func = Function::new(&fn_name);
1185 if let Some(doc) = &name_meta.doc {
1186 let doc_str = doc.join("\n");
1187 func.doc(&doc_str);
1188 }
1189 func.vis("pub");
1190 func.set_async(true);
1191 func.attr("allow(clippy::too_many_arguments)");
1194 func.generic("C");
1195 func.arg("client", "&C");
1196
1197 if let Some(params) = &upsert.params {
1198 for (param_name_meta, param_type) in ¶ms.params {
1199 let param_name = param_name_meta.value.as_str();
1200 let rust_ty = param_type_to_rust(param_type);
1201 func.arg(param_name, format!("&{}", rust_ty));
1202 }
1203 }
1204
1205 func.ret(&return_ty);
1206 func.bound("C", "tokio_postgres::GenericClient");
1207
1208 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1209 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1210
1211 scope.push_fn(func);
1212 Ok(())
1213}
1214
1215fn generate_insert_many_code(
1216 ctx: &CodegenContext,
1217 name_meta: &Meta<String>,
1218 insert: &InsertMany,
1219 scope: &mut Scope,
1220) -> Result<(), QError> {
1221 let name = &name_meta.value;
1222 let fn_name = to_snake_case(name);
1223 let generated = crate::sqlgen::generate_insert_many_sql(insert);
1224
1225 let params_struct_name = format!("{}Params", name);
1227 if let Some(params) = &insert.params {
1228 generate_bulk_params_struct(
1229 ctx,
1230 ¶ms_struct_name,
1231 insert.into.value.as_str(),
1232 params,
1233 scope,
1234 );
1235 }
1236
1237 let has_returning = insert.returning.is_some();
1239 let return_ty = if !has_returning {
1240 "Result<u64, QueryError>".to_string()
1241 } else {
1242 let struct_name = format!("{}Result", name);
1243 if let Some(returning) = &insert.returning {
1244 generate_mutation_result_struct(
1245 ctx,
1246 &struct_name,
1247 insert.into.value.as_str(),
1248 insert.into.span,
1249 returning,
1250 scope,
1251 )?;
1252 }
1253 format!("Result<Vec<{}>, QueryError>", struct_name)
1254 };
1255
1256 let mut func = Function::new(&fn_name);
1257 if let Some(doc) = &name_meta.doc {
1258 let doc_str = doc.join("\n");
1259 func.doc(&doc_str);
1260 }
1261 func.vis("pub");
1262 func.set_async(true);
1263 func.attr("allow(clippy::too_many_arguments)");
1266 func.generic("C");
1267 func.arg("client", "&C");
1268 func.arg("items", format!("&[{}]", params_struct_name));
1269
1270 func.ret(&return_ty);
1271 func.bound("C", "tokio_postgres::GenericClient");
1272
1273 let body = generate_bulk_mutation_body(&generated.sql, insert.params.as_ref(), !has_returning);
1274 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1275
1276 scope.push_fn(func);
1277 Ok(())
1278}
1279
1280fn generate_upsert_many_code(
1281 ctx: &CodegenContext,
1282 name_meta: &Meta<String>,
1283 upsert: &UpsertMany,
1284 scope: &mut Scope,
1285) -> Result<(), QError> {
1286 let name = &name_meta.value;
1287 let fn_name = to_snake_case(name);
1288 let generated = crate::sqlgen::generate_upsert_many_sql(upsert);
1289
1290 let params_struct_name = format!("{}Params", name);
1292 if let Some(params) = &upsert.params {
1293 generate_bulk_params_struct(
1294 ctx,
1295 ¶ms_struct_name,
1296 upsert.into.value.as_str(),
1297 params,
1298 scope,
1299 );
1300 }
1301
1302 let has_returning = upsert.returning.is_some();
1304 let return_ty = if !has_returning {
1305 "Result<u64, QueryError>".to_string()
1306 } else {
1307 let struct_name = format!("{}Result", name);
1308 if let Some(returning) = &upsert.returning {
1309 generate_mutation_result_struct(
1310 ctx,
1311 &struct_name,
1312 upsert.into.value.as_str(),
1313 upsert.into.span,
1314 returning,
1315 scope,
1316 )?;
1317 }
1318 format!("Result<Vec<{}>, QueryError>", struct_name)
1319 };
1320
1321 let mut func = Function::new(&fn_name);
1322 if let Some(doc) = &name_meta.doc {
1323 let doc_str = doc.join("\n");
1324 func.doc(&doc_str);
1325 }
1326 func.vis("pub");
1327 func.set_async(true);
1328 func.attr("allow(clippy::too_many_arguments)");
1331 func.generic("C");
1332 func.arg("client", "&C");
1333 func.arg("items", format!("&[{}]", params_struct_name));
1334
1335 func.ret(&return_ty);
1336 func.bound("C", "tokio_postgres::GenericClient");
1337
1338 let body = generate_bulk_mutation_body(&generated.sql, upsert.params.as_ref(), !has_returning);
1339 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1340
1341 scope.push_fn(func);
1342 Ok(())
1343}
1344
1345fn generate_bulk_params_struct(
1347 ctx: &CodegenContext,
1348 struct_name: &str,
1349 table: &str,
1350 params: &Params,
1351 scope: &mut Scope,
1352) {
1353 let mut st = Struct::new(struct_name);
1354 st.vis("pub");
1355 st.derive("Debug");
1356 st.derive("Clone");
1357
1358 for (param_name_meta, param_type) in ¶ms.params {
1359 let param_name = param_name_meta.value.as_str();
1360 let rust_ty = ctx
1361 .column_type(table, param_name)
1362 .unwrap_or_else(|| param_type_to_rust(param_type));
1363 st.field(format!("pub {}", param_name), &rust_ty);
1364 }
1365
1366 scope.push_struct(st);
1367}
1368
1369fn generate_bulk_mutation_body(sql: &str, params: Option<&Params>, execute_only: bool) -> Block {
1371 let mut block = Block::new("");
1372
1373 block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1375 block.line("");
1376
1377 if let Some(params) = params {
1379 block.line("// Convert items to parallel arrays for UNNEST");
1380 for (param_name_meta, param_type) in ¶ms.params {
1381 let param_name = param_name_meta.value.as_str();
1382 let rust_ty = param_type_to_rust(param_type);
1383 block.line(format!(
1384 "let {}_arr: Vec<{}> = items.iter().map(|i| i.{}.clone()).collect();",
1385 param_name, rust_ty, param_name
1386 ));
1387 }
1388 block.line("");
1389
1390 let param_refs: Vec<String> = params
1392 .params
1393 .keys()
1394 .map(|p| format!("&{}_arr", p.value))
1395 .collect();
1396
1397 if execute_only {
1398 block.line(format!(
1400 "let affected = client.execute(SQL, &[{}]).await?;",
1401 param_refs.join(", ")
1402 ));
1403 block.line("Ok(affected)");
1404 } else {
1405 block.line(format!(
1407 "let rows = client.query(SQL, &[{}]).await?;",
1408 param_refs.join(", ")
1409 ));
1410 block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
1411 }
1412 }
1413
1414 block
1415}
1416
1417fn generate_update_code(
1418 ctx: &CodegenContext,
1419 name_meta: &Meta<String>,
1420 update: &Update,
1421 scope: &mut Scope,
1422) -> Result<(), QError> {
1423 let name = &name_meta.value;
1424 let fn_name = to_snake_case(name);
1425 let sqlgen_ctx = ctx.sqlgen_ctx();
1426 let generated = crate::sqlgen::generate_update_sql(&sqlgen_ctx, update)?;
1427
1428 let has_returning = update.returning.is_some();
1429 let return_ty = if !has_returning {
1430 "Result<u64, QueryError>".to_string()
1431 } else {
1432 let struct_name = format!("{}Result", name);
1433 if let Some(returning) = &update.returning {
1434 generate_mutation_result_struct(
1435 ctx,
1436 &struct_name,
1437 update.table.value.as_str(),
1438 update.table.span,
1439 returning,
1440 scope,
1441 )?;
1442 }
1443 format!("Result<Option<{}>, QueryError>", struct_name)
1444 };
1445
1446 let mut func = Function::new(&fn_name);
1447 if let Some(doc) = &name_meta.doc {
1448 let doc_str = doc.join("\n");
1449 func.doc(&doc_str);
1450 }
1451 func.vis("pub");
1452 func.set_async(true);
1453 func.attr("allow(clippy::too_many_arguments)");
1456 func.generic("C");
1457 func.arg("client", "&C");
1458
1459 if let Some(params) = &update.params {
1460 for (param_name_meta, param_type) in ¶ms.params {
1461 let param_name = param_name_meta.value.as_str();
1462 let rust_ty = param_type_to_rust(param_type);
1463 func.arg(param_name, format!("&{}", rust_ty));
1464 }
1465 }
1466
1467 func.ret(&return_ty);
1468 func.bound("C", "tokio_postgres::GenericClient");
1469
1470 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1471 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1472
1473 scope.push_fn(func);
1474 Ok(())
1475}
1476
1477fn generate_delete_code(
1478 ctx: &CodegenContext,
1479 name_meta: &Meta<String>,
1480 delete: &Delete,
1481 scope: &mut Scope,
1482) -> Result<(), QError> {
1483 let name = &name_meta.value;
1484 let fn_name = to_snake_case(name);
1485 let sqlgen_ctx = ctx.sqlgen_ctx();
1486 let generated = crate::sqlgen::generate_delete_sql(&sqlgen_ctx, delete)?;
1487
1488 let has_returning = delete.returning.is_some();
1489 let return_ty = if !has_returning {
1490 "Result<u64, QueryError>".to_string()
1491 } else {
1492 let struct_name = format!("{}Result", name);
1493 if let Some(returning) = &delete.returning {
1494 generate_mutation_result_struct(
1495 ctx,
1496 &struct_name,
1497 delete.from.value.as_str(),
1498 delete.from.span,
1499 returning,
1500 scope,
1501 )?;
1502 }
1503 format!("Result<Option<{}>, QueryError>", struct_name)
1504 };
1505
1506 let mut func = Function::new(&fn_name);
1507 if let Some(doc) = &name_meta.doc {
1508 let doc_str = doc.join("\n");
1509 func.doc(&doc_str);
1510 }
1511 func.vis("pub");
1512 func.set_async(true);
1513 func.attr("allow(clippy::too_many_arguments)");
1516 func.generic("C");
1517 func.arg("client", "&C");
1518
1519 if let Some(params) = &delete.params {
1520 for (param_name_meta, param_type) in ¶ms.params {
1521 let param_name = param_name_meta.value.as_str();
1522 let rust_ty = param_type_to_rust(param_type);
1523 func.arg(param_name, format!("&{}", rust_ty));
1524 }
1525 }
1526
1527 func.ret(&return_ty);
1528 func.bound("C", "tokio_postgres::GenericClient");
1529
1530 let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1531 func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1532
1533 scope.push_fn(func);
1534 Ok(())
1535}
1536
1537fn generate_mutation_result_struct(
1538 ctx: &CodegenContext,
1539 struct_name: &str,
1540 table: &str,
1541 table_span: Span,
1542 returning: &Returning,
1543 scope: &mut Scope,
1544) -> Result<(), QError> {
1545 ctx.require_table(table, table_span)?;
1548
1549 let mut st = Struct::new(struct_name);
1550 st.vis("pub");
1551 st.derive("Debug");
1552 st.derive("Clone");
1553 st.derive("Facet");
1554 st.attr("facet(crate = dibs_runtime::facet)");
1555
1556 for (col_name_meta, _) in &returning.columns {
1557 let col_name = col_name_meta.value.as_str();
1558 let rust_ty = ctx.column_type_at(table, col_name, col_name_meta.span)?;
1559 st.field(format!("pub {col_name}"), &rust_ty);
1560 }
1561
1562 scope.push_struct(st);
1563 Ok(())
1564}
1565
1566fn generate_mutation_body(
1567 sql: &str,
1568 param_order: &[dibs_sql::ParamName],
1569 execute_only: bool,
1570) -> Block {
1571 let mut block = Block::new("");
1572
1573 block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1575 block.line("");
1576
1577 let params: Vec<_> = param_order
1578 .iter()
1579 .filter(|p| !p.as_str().starts_with("__literal_"))
1580 .collect();
1581
1582 if execute_only {
1583 if params.is_empty() {
1585 block.line("let affected = client.execute(SQL, &[]).await?;");
1586 } else {
1587 let params_str = params
1588 .iter()
1589 .map(|p| p.as_str())
1590 .collect::<Vec<_>>()
1591 .join(", ");
1592 block.line(format!(
1593 "let affected = client.execute(SQL, &[{}]).await?;",
1594 params_str
1595 ));
1596 }
1597 block.line("Ok(affected)");
1598 } else {
1599 if params.is_empty() {
1601 block.line("let rows = client.query(SQL, &[]).await?;");
1602 } else {
1603 let params_str = params
1604 .iter()
1605 .map(|p| p.as_str())
1606 .collect::<Vec<_>>()
1607 .join(", ");
1608 block.line(format!(
1609 "let rows = client.query(SQL, &[{}]).await?;",
1610 params_str
1611 ));
1612 }
1613 let mut match_block = Block::new("match rows.into_iter().next()");
1614 match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
1615 match_block.line("None => Ok(None),");
1616 block.push_block(match_block);
1617 }
1618
1619 block
1620}
1621
1622#[cfg(test)]
1623mod tests;