Skip to main content

dibs_qgen/rustgen/
mod.rs

1//! Rust code generation from query schema types using the `codegen` crate.
2//!
3//! # Design
4//!
5//! For queries with JOINs (relations), we generate:
6//! 1. A **flat row struct** with all columns from the SELECT (using aliased names)
7//! 2. Use `from_row()` to deserialize each row into the flat struct
8//! 3. Grouping/deduplication logic works on the deserialized values
9//!
10//! This approach uses facet-tokio-postgres for all deserialization, which
11//! properly handles complex types like `Jsonb<T>` via reflection.
12
13use crate::error::QErrorKind;
14use crate::sqlgen::SqlGenContext;
15use crate::{QError, QSource};
16use codegen::{Block, Function, Scope, Struct};
17use dibs_db_schema::{Schema, Table};
18use dibs_query_schema::{
19    Decl, Delete, FieldDef, Insert, InsertMany, Meta, Params, QueryFile, Returning, Returns,
20    Select, SelectFields, Span, Update, Upsert, UpsertMany,
21};
22use std::sync::Arc;
23
/// The Rust source produced for one query file.
#[derive(Debug, Clone)]
pub struct GeneratedCode {
    /// Complete generated source text: banner, imports, structs and query functions.
    pub code: String,
}
30
/// Route a generated function body's `Result` through
/// `TraceErr::trace_err(fn_name)` before it is returned.
///
/// This is the one place where postgres-side error detail is pushed into
/// `tracing`, no matter what callers later do with the `QueryError`.
///
/// `body` must be an expression of type `Result<_, QueryError>`; every
/// generator in this module emits such a body. It is placed inside
/// `async { … }.await` so that a `?` within it short-circuits into the
/// captured `Result` rather than out of the enclosing generated function.
/// The helper then observes any `Err` and yields the `Result` back.
fn wrap_with_trace_err(body: &str, fn_name: &str) -> String {
    let mut wrapped = String::from("let __dibs_result = async {\n");
    wrapped.push_str(body);
    wrapped.push_str("\n}.await;\n");
    wrapped.push_str("<_ as TraceErr>::trace_err(__dibs_result, \"");
    wrapped.push_str(fn_name);
    wrapped.push_str("\")");
    wrapped
}
49
50/// Look up the Rust type for a column in a schema.
51fn schema_column_type(schema: &Schema, table: &str, column: &str) -> Option<String> {
52    let table_info = schema.get_table(table)?;
53    let col = table_info.columns.iter().find(|c| c.name == column)?;
54    let rust_type = col
55        .rust_type
56        .clone()
57        .unwrap_or_else(|| col.pg_type.to_rust_type().to_string());
58    if col.nullable {
59        Some(format!("Option<{}>", rust_type))
60    } else {
61        Some(rust_type)
62    }
63}
64
65/// Context for code generation.
66struct CodegenContext<'a> {
67    schema: &'a Schema,
68    source: Arc<QSource>,
69    #[allow(dead_code)]
70    scope: Scope,
71}
72
73impl CodegenContext<'_> {
74    /// Look up the Rust type for a column.
75    fn column_type(&self, table: &str, column: &str) -> Option<String> {
76        schema_column_type(self.schema, table, column)
77    }
78
79    /// Build a `TableNotFound` error pointing at `span`, listing the tables the
80    /// schema *does* contain (empty list ⇒ the schema itself is empty, which the
81    /// error renders as a "you forgot ensure_linked()" hint).
82    fn table_not_found(&self, table: &str, span: Span) -> QError {
83        let mut available: Vec<String> = self.schema.tables.keys().cloned().collect();
84        available.sort();
85        QError {
86            source: self.source.clone(),
87            span,
88            kind: QErrorKind::TableNotFound {
89                table: table.to_string(),
90                available,
91            },
92        }
93    }
94
95    /// Resolve a table that is referenced in the query at `span`, erroring (with
96    /// the span pointing at the table reference, not some column) if it isn't in
97    /// the schema. Call this once per table reference before looking up its
98    /// columns, so a missing/empty schema is reported against the table.
99    fn require_table(&self, table: &str, span: Span) -> Result<&Table, QError> {
100        self.schema
101            .get_table(table)
102            .ok_or_else(|| self.table_not_found(table, span))
103    }
104
105    /// Look up the Rust type for a column referenced in the query at `span`,
106    /// erroring if the table or column can't be resolved from the schema.
107    ///
108    /// Use this for every column whose type ends up in a generated result/param
109    /// struct. A missing column means the schema handed to codegen is wrong or
110    /// empty (e.g. a build script that forgot to link its table definitions);
111    /// silently falling back to `String` there generates wrong-typed structs
112    /// that compile fine and corrupt data at runtime.
113    fn column_type_at(&self, table: &str, column: &str, span: Span) -> Result<String, QError> {
114        let table_info = self.require_table(table, span)?;
115        if let Some(ty) = schema_column_type(self.schema, table, column) {
116            return Ok(ty);
117        }
118        Err(QError {
119            source: self.source.clone(),
120            span,
121            kind: QErrorKind::ColumnNotFound {
122                table: table.to_string(),
123                column: column.to_string(),
124                available: table_info.columns.iter().map(|c| c.name.clone()).collect(),
125            },
126        })
127    }
128
129    /// Create an SqlGenContext for this codegen context.
130    fn sqlgen_ctx(&self) -> SqlGenContext<'_> {
131        SqlGenContext::new(self.schema, self.source.clone())
132    }
133}
134
135/// Generate Rust code for a query file.
136pub fn generate_rust_code(
137    file: &QueryFile,
138    schema: &Schema,
139    source: Arc<QSource>,
140) -> Result<GeneratedCode, QError> {
141    let mut scope = Scope::new();
142
143    // Add file header as raw code
144    scope.raw("// Generated by dibs-qgen. Do not edit.");
145    scope.raw("");
146
147    // Imports
148    scope.import("dibs_runtime::prelude", "*");
149    scope.import("dibs_runtime", "tokio_postgres");
150
151    let ctx = CodegenContext {
152        schema,
153        source,
154        scope: Scope::new(),
155    };
156
157    // Iterate through declarations and generate code for each type
158    for (name_meta, decl) in &file.0 {
159        match decl {
160            Decl::Select(select) => {
161                generate_select_code(&ctx, name_meta, select, &mut scope)?;
162            }
163            Decl::Insert(insert) => {
164                generate_insert_code(&ctx, name_meta, insert, &mut scope)?;
165            }
166            Decl::InsertMany(insert_many) => {
167                generate_insert_many_code(&ctx, name_meta, insert_many, &mut scope)?;
168            }
169            Decl::Upsert(upsert) => {
170                generate_upsert_code(&ctx, name_meta, upsert, &mut scope)?;
171            }
172            Decl::UpsertMany(upsert_many) => {
173                generate_upsert_many_code(&ctx, name_meta, upsert_many, &mut scope)?;
174            }
175            Decl::Update(update) => {
176                generate_update_code(&ctx, name_meta, update, &mut scope)?;
177            }
178            Decl::Delete(delete) => {
179                generate_delete_code(&ctx, name_meta, delete, &mut scope)?;
180            }
181        }
182    }
183
184    Ok(GeneratedCode {
185        code: scope.to_string(),
186    })
187}
188
189fn generate_select_code(
190    ctx: &CodegenContext,
191    name_meta: &Meta<String>,
192    select: &Select,
193    scope: &mut Scope,
194) -> Result<(), QError> {
195    let name = &name_meta.value;
196    let struct_name = format!("{}Result", name);
197
198    // Generate result struct(s)
199    if let Some(from) = &select.from {
200        if select.fields.is_some() {
201            generate_result_struct(ctx, select, name_meta, &struct_name, from, scope)?;
202
203            // For queries with relations, also generate a flat row struct for deserialization
204            if select.has_relations() {
205                let flat_struct_name = format!("{}Row", name);
206                generate_flat_row_struct(ctx, select, &flat_struct_name, from, scope)?;
207            }
208        }
209    } else if let Some(returns) = &select.returns {
210        // Raw SQL query with explicit returns clause
211        generate_raw_sql_result_struct(&struct_name, returns, scope);
212    }
213
214    // Generate query function
215    generate_select_function(ctx, name_meta, select, &struct_name, scope)?;
216    Ok(())
217}
218
219/// Generate a flat row struct that matches the SQL result columns exactly.
220///
221/// This struct is used with `from_row()` to deserialize each database row,
222/// then transformed into the nested result struct.
223///
224/// For a query like:
225/// ```text
226/// ProductDetails @select{
227///     from product
228///     fields { id, handle, variants @rel{ from product_variant, fields { id, sku } } }
229/// }
230/// ```
231///
232/// Generates:
233/// ```text
234/// struct ProductDetailsRow {
235///     id: i64,
236///     handle: String,
237///     variants_id: Option<i64>,    // Option because LEFT JOIN
238///     variants_sku: Option<String>,
239/// }
240/// ```
241fn generate_flat_row_struct(
242    ctx: &CodegenContext,
243    select: &Select,
244    struct_name: &str,
245    table: &Meta<dibs_sql::TableName>,
246    scope: &mut Scope,
247) -> Result<(), QError> {
248    let mut st = Struct::new(struct_name);
249    // Internal struct - not pub
250    st.derive("Debug");
251    st.derive("Clone");
252    st.derive("Facet");
253    st.attr("facet(crate = dibs_runtime::facet)");
254
255    let table_name = table.value.as_str();
256    ctx.require_table(table_name, table.span)?;
257
258    if let Some(select_fields) = &select.fields {
259        // Add root table columns
260        add_flat_fields_for_select(ctx, &mut st, table_name, "", select_fields)?;
261    }
262
263    scope.push_struct(st);
264    Ok(())
265}
266
267/// Recursively add fields to the flat row struct for a SelectFields.
268fn add_flat_fields_for_select(
269    ctx: &CodegenContext,
270    st: &mut Struct,
271    table_name: &str,
272    prefix: &str,
273    select_fields: &SelectFields,
274) -> Result<(), QError> {
275    for (field_name_meta, field_def) in &select_fields.fields {
276        let field_name = field_name_meta.value.as_str();
277
278        match field_def {
279            None => {
280                // Simple column
281                let rust_ty = ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
282
283                let flat_field_name = if prefix.is_empty() {
284                    field_name.to_string()
285                } else {
286                    format!("{}_{}", prefix, field_name)
287                };
288
289                // If we're in a relation (prefix is not empty), wrap in Option for LEFT JOIN
290                let final_ty = if prefix.is_empty() {
291                    rust_ty
292                } else if rust_ty.starts_with("Option<") {
293                    // Already optional
294                    rust_ty
295                } else {
296                    format!("Option<{}>", rust_ty)
297                };
298
299                // Use rename attribute since field names with underscores need to match SQL aliases
300                st.field(&flat_field_name, &final_ty);
301            }
302            Some(FieldDef::Rel(rel)) => {
303                // Recurse into relation
304                let rel_table = rel.table_name().unwrap_or(field_name);
305                let rel_span = rel
306                    .from
307                    .as_ref()
308                    .map(|m| m.span)
309                    .unwrap_or(field_name_meta.span);
310                ctx.require_table(rel_table, rel_span)?;
311                let new_prefix = if prefix.is_empty() {
312                    field_name.to_string()
313                } else {
314                    format!("{}_{}", prefix, field_name)
315                };
316
317                if let Some(rel_fields) = &rel.fields {
318                    add_flat_fields_for_select(ctx, st, rel_table, &new_prefix, rel_fields)?;
319                }
320            }
321            Some(FieldDef::Count(_)) => {
322                // COUNT subquery result
323                let flat_field_name = if prefix.is_empty() {
324                    field_name.to_string()
325                } else {
326                    format!("{}_{}", prefix, field_name)
327                };
328                st.field(&flat_field_name, "i64");
329            }
330        }
331    }
332    Ok(())
333}
334
335fn generate_raw_sql_result_struct(struct_name: &str, returns: &Returns, scope: &mut Scope) {
336    let mut st = Struct::new(struct_name);
337    st.vis("pub");
338    st.derive("Debug");
339    st.derive("Clone");
340    st.derive("Facet");
341    st.attr("facet(crate = dibs_runtime::facet)");
342
343    for (field_name_meta, param_type) in &returns.fields {
344        let field_name = field_name_meta.value.as_str();
345        let rust_ty = param_type_to_rust(param_type);
346        st.field(format!("pub {}", field_name), &rust_ty);
347    }
348
349    scope.push_struct(st);
350}
351
352fn generate_result_struct(
353    ctx: &CodegenContext,
354    select: &Select,
355    name_meta: &Meta<String>,
356    struct_name: &str,
357    table: &Meta<dibs_sql::TableName>,
358    scope: &mut Scope,
359) -> Result<(), QError> {
360    let mut st = Struct::new(struct_name);
361    st.vis("pub");
362    st.derive("Debug");
363    st.derive("Clone");
364    st.derive("Facet");
365    st.attr("facet(crate = dibs_runtime::facet)");
366
367    // Regular query - use select fields
368    let parent_prefix = &name_meta.value;
369    let table_name = table.value.as_str();
370    // Resolve the table once up front so a missing/empty schema is reported
371    // against the `from` clause rather than the first selected column.
372    ctx.require_table(table_name, table.span)?;
373
374    if let Some(select_fields) = &select.fields {
375        for (field_name_meta, field_def) in &select_fields.fields {
376            let field_name = field_name_meta.value.as_str();
377            match field_def {
378                None => {
379                    // Simple column
380                    let rust_ty =
381                        ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
382                    st.field(format!("pub {}", field_name), &rust_ty);
383                }
384                Some(FieldDef::Rel(rel)) => {
385                    let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
386                    let ty = if rel.first.is_some() {
387                        format!("Option<{}>", nested_name)
388                    } else {
389                        format!("Vec<{}>", nested_name)
390                    };
391                    st.field(format!("pub {}", field_name), &ty);
392                }
393                Some(FieldDef::Count(_)) => {
394                    st.field(format!("pub {}", field_name), "i64");
395                }
396            }
397        }
398    }
399
400    scope.push_struct(st);
401
402    // Generate nested structs for relations (recursively)
403    if let Some(select_fields) = &select.fields {
404        generate_nested_structs(ctx, parent_prefix, select_fields, scope)?;
405    }
406    Ok(())
407}
408
409/// Recursively generate structs for nested relations.
410///
411/// `parent_prefix` is used to namespace the struct names to avoid collisions
412/// when multiple queries have relations with the same field name.
413fn generate_nested_structs(
414    ctx: &CodegenContext,
415    parent_prefix: &str,
416    select_fields: &SelectFields,
417    scope: &mut Scope,
418) -> Result<(), QError> {
419    for (field_name_meta, field_def) in &select_fields.fields {
420        if let Some(FieldDef::Rel(rel)) = field_def {
421            let field_name = field_name_meta.value.as_str();
422            let nested_name = format!("{}{}", parent_prefix, to_pascal_case(field_name));
423            let rel_table = rel.table_name().unwrap_or(field_name);
424            let rel_span = rel
425                .from
426                .as_ref()
427                .map(|m| m.span)
428                .unwrap_or(field_name_meta.span);
429            ctx.require_table(rel_table, rel_span)?;
430
431            let mut nested_st = Struct::new(&nested_name);
432            nested_st.vis("pub");
433            nested_st.derive("Debug");
434            nested_st.derive("Clone");
435            nested_st.derive("Facet");
436            nested_st.attr("facet(crate = dibs_runtime::facet)");
437
438            if let Some(rel_fields) = &rel.fields {
439                for (rel_field_name_meta, rel_field_def) in &rel_fields.fields {
440                    let rel_field_name = rel_field_name_meta.value.as_str();
441                    match rel_field_def {
442                        None => {
443                            // Simple column
444                            let rust_ty = ctx.column_type_at(
445                                rel_table,
446                                rel_field_name,
447                                rel_field_name_meta.span,
448                            )?;
449                            nested_st.field(format!("pub {}", rel_field_name), &rust_ty);
450                        }
451                        Some(FieldDef::Rel(nested_rel)) => {
452                            // Nested relation field - namespace with current struct name
453                            let nested_rel_name =
454                                format!("{}{}", nested_name, to_pascal_case(rel_field_name));
455                            let ty = if nested_rel.first.is_some() {
456                                format!("Option<{}>", nested_rel_name)
457                            } else {
458                                format!("Vec<{}>", nested_rel_name)
459                            };
460                            nested_st.field(format!("pub {}", rel_field_name), &ty);
461                        }
462                        Some(FieldDef::Count(_)) => {
463                            nested_st.field(format!("pub {}", rel_field_name), "i64");
464                        }
465                    }
466                }
467            }
468
469            scope.push_struct(nested_st);
470
471            // Recursively generate structs for nested relations
472            if let Some(rel_fields) = &rel.fields {
473                generate_nested_structs(ctx, &nested_name, rel_fields, scope)?;
474            }
475        }
476    }
477    Ok(())
478}
479
480fn generate_select_function(
481    ctx: &CodegenContext,
482    name_meta: &Meta<String>,
483    query: &Select,
484    struct_name: &str,
485    scope: &mut Scope,
486) -> Result<(), QError> {
487    let name = &name_meta.value;
488    let fn_name = to_snake_case(name);
489
490    let return_ty = if query.first.is_some() {
491        format!("Result<Option<{}>, QueryError>", struct_name)
492    } else {
493        format!("Result<Vec<{}>, QueryError>", struct_name)
494    };
495
496    let mut func = Function::new(&fn_name);
497    if let Some(doc) = &name_meta.doc {
498        let doc_str = doc.join("\n");
499        func.doc(&doc_str);
500    }
501    func.vis("pub");
502    func.set_async(true);
503    // Generated query fns take one arg per bound param; wide tables legitimately
504    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
505    func.attr("allow(clippy::too_many_arguments)");
506    func.generic("C");
507    func.arg("client", "&C");
508    // Allow clone_on_copy since we generate .clone() calls on parent IDs that might be Copy types
509    func.attr("allow(clippy::clone_on_copy)");
510
511    if let Some(params) = &query.params {
512        for (param_name_meta, param_type) in &params.params {
513            let param_name = &param_name_meta.value;
514            let rust_ty = param_type_to_rust(param_type);
515            func.arg(param_name, format!("&{}", rust_ty));
516        }
517    }
518
519    func.ret(&return_ty);
520    func.bound("C", "tokio_postgres::GenericClient");
521
522    // Generate function body
523    let body = if let Some(raw_sql_meta) = &query.sql {
524        block_to_string(&generate_raw_query_body(query, &raw_sql_meta.value))
525    } else {
526        generate_query_body(ctx, query, struct_name)?
527    };
528    func.line(wrap_with_trace_err(&body, &fn_name));
529
530    scope.push_fn(func);
531    Ok(())
532}
533
534/// Generate query body for all queries (with or without JOINs).
535///
536/// For queries without relations: use `from_row()` directly into the result struct.
537/// For queries with relations: deserialize into flat row struct, then transform.
538fn generate_query_body(
539    ctx: &CodegenContext,
540    query: &Select,
541    struct_name: &str,
542) -> Result<String, QError> {
543    let sqlgen_ctx = ctx.sqlgen_ctx();
544    let generated = match crate::sqlgen::generate_select_sql(&sqlgen_ctx, query) {
545        Ok(g) => g,
546        Err(e) => {
547            panic!("SELECT SQL generation failed: {}", e);
548        }
549    };
550
551    let mut block = Block::new("");
552
553    // SQL constant
554    block.line(format!("const SQL: &str = r#\"{}\"#;", generated.sql));
555    block.line("");
556
557    // Build params array - filter out literal placeholders
558    let params: Vec<_> = generated
559        .param_order
560        .iter()
561        .filter(|p| !p.as_str().starts_with("__literal_"))
562        .collect();
563
564    if params.is_empty() {
565        block.line("let rows = client.query(SQL, &[]).await?;");
566    } else {
567        let params_str = params
568            .iter()
569            .map(|p| p.as_str())
570            .collect::<Vec<_>>()
571            .join(", ");
572        block.line(format!(
573            "let rows = client.query(SQL, &[{}]).await?;",
574            params_str
575        ));
576    }
577
578    // If no relations, use from_row() directly into the result struct
579    if !query.has_relations() {
580        if query.first.is_some() {
581            let mut match_block = Block::new("match rows.into_iter().next()");
582            match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
583            match_block.line("None => Ok(None),");
584            block.push_block(match_block);
585        } else {
586            block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
587        }
588        return Ok(block_to_string(&block));
589    }
590
591    // For queries with relations, deserialize into flat row struct then transform
592    let query_name = struct_name.strip_suffix("Result").unwrap_or(struct_name);
593    let flat_struct_name = format!("{}Row", query_name);
594
595    block.line("");
596    block.line("// Deserialize all rows into flat structs using facet reflection");
597    block.line(format!(
598        "let flat_rows: Vec<{flat_struct_name}> = rows.iter().map(from_row).collect::<Result<Vec<_>, _>>()?;"
599    ));
600    block.line("");
601
602    // Generate the transformation from flat rows to nested result
603    let Some(select_fields) = &query.fields else {
604        // No fields - shouldn't happen for queries with relations
605        block.line("Ok(vec![])".to_string());
606        return Ok(block_to_string(&block));
607    };
608
609    let root_table = query
610        .from
611        .as_ref()
612        .map(|m| m.value.as_str())
613        .unwrap_or("unknown");
614    let is_first = query.is_first();
615
616    block.line(generate_flat_to_nested_transform(
617        ctx,
618        select_fields,
619        struct_name,
620        root_table,
621        is_first,
622    )?);
623
624    Ok(block_to_string(&block))
625}
626
627/// Generate code to transform flat rows into nested result structs.
628fn generate_flat_to_nested_transform(
629    ctx: &CodegenContext,
630    select_fields: &SelectFields,
631    struct_name: &str,
632    root_table: &str,
633    is_first: bool,
634) -> Result<String, QError> {
635    let mut block = Block::new("");
636
637    // Find the ID column for grouping (typically "id")
638    let id_column = select_fields
639        .id_column()
640        .map(|c| c.to_string())
641        .unwrap_or_else(|| "id".to_string());
642
643    let id_type = ctx
644        .column_type(root_table, &id_column)
645        .unwrap_or_else(|| "i64".to_string());
646
647    // Determine if we need grouping (Vec relations) or simple mapping (Option relations only)
648    if select_fields.has_vec_relations() {
649        // Group by parent ID for Vec relations
650        block.line("// Group flat rows by parent ID and assemble nested structs");
651        block.line(format!(
652            "let mut grouped: std::collections::HashMap<{id_type}, {struct_name}> = std::collections::HashMap::new();"
653        ));
654
655        // Track seen relation IDs to avoid duplicates from JOINs
656        generate_seen_id_declarations(&mut block, ctx, select_fields, &id_type, "")?;
657
658        block.line("");
659
660        let mut for_block = Block::new("for flat_row in flat_rows");
661        for_block.line(format!("let parent_id = flat_row.{id_column}.clone();"));
662        for_block.line("");
663
664        // Get or create the entry
665        let mut entry_block = Block::new(format!(
666            "let entry = grouped.entry(parent_id.clone()).or_insert_with(|| {struct_name}"
667        ));
668
669        // Add root columns
670        for (field_name_meta, field_def) in &select_fields.fields {
671            let field_name = field_name_meta.value.as_str();
672            match field_def {
673                None => {
674                    entry_block.line(format!("{field_name}: flat_row.{field_name}.clone(),"));
675                }
676                Some(FieldDef::Rel(rel)) => {
677                    if rel.is_first() {
678                        entry_block.line(format!("{field_name}: None,"));
679                    } else {
680                        entry_block.line(format!("{field_name}: Vec::new(),"));
681                    }
682                }
683                Some(FieldDef::Count(_)) => {
684                    entry_block.line(format!("{field_name}: flat_row.{field_name},"));
685                }
686            }
687        }
688        entry_block.after(");");
689        for_block.push_block(entry_block);
690        for_block.line("");
691
692        // Add relations
693        let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
694        generate_relation_assembly(
695            &mut for_block,
696            ctx,
697            select_fields,
698            parent_prefix,
699            "",
700            &id_type,
701        )?;
702
703        block.push_block(for_block);
704        block.line("");
705
706        if is_first {
707            block.line("Ok(grouped.into_values().next())");
708        } else {
709            block.line("Ok(grouped.into_values().collect())");
710        }
711    } else {
712        // Option-only relations - each row becomes one result
713        block.line("// Transform flat rows into nested structs (Option relations only)");
714
715        let mut map_block = Block::new(
716            "let results: Result<Vec<_>, QueryError> = flat_rows.into_iter().map(|flat_row| {",
717        );
718
719        let mut result_block = Block::new(format!("Ok({struct_name}"));
720        let parent_prefix = struct_name.strip_suffix("Result").unwrap_or(struct_name);
721
722        for (field_name_meta, field_def) in &select_fields.fields {
723            let field_name = field_name_meta.value.as_str();
724            match field_def {
725                None => {
726                    result_block.line(format!("{field_name}: flat_row.{field_name},"));
727                }
728                Some(FieldDef::Rel(rel)) => {
729                    if rel.is_first() {
730                        // Option relation - check if first column is Some
731                        if let Some(rel_fields) = &rel.fields {
732                            let rel_table = rel.table_name().unwrap_or(field_name);
733                            let first_col = rel_fields
734                                .first_column()
735                                .map(|c| c.as_str())
736                                .unwrap_or("id");
737                            let first_alias = format!("{field_name}_{first_col}");
738                            let nested_struct =
739                                format!("{}{}", parent_prefix, to_pascal_case(field_name));
740
741                            let mut map_inner = Block::new(format!(
742                                "{field_name}: flat_row.{first_alias}.as_ref().map(|_| {nested_struct}"
743                            ));
744
745                            for (inner_field_meta, inner_def) in &rel_fields.fields {
746                                let inner_name = inner_field_meta.value.as_str();
747                                if inner_def.is_none() {
748                                    let alias = format!("{field_name}_{inner_name}");
749                                    let rust_ty = ctx.column_type_at(
750                                        rel_table,
751                                        inner_name,
752                                        inner_field_meta.span,
753                                    )?;
754
755                                    // Unwrap the Option from LEFT JOIN
756                                    if rust_ty.starts_with("Option<") {
757                                        map_inner.line(format!(
758                                            "{inner_name}: flat_row.{alias}.clone(),"
759                                        ));
760                                    } else {
761                                        map_inner.line(format!(
762                                            "{inner_name}: flat_row.{alias}.clone().expect(\"non-null column from LEFT JOIN\"),"
763                                        ));
764                                    }
765                                }
766                            }
767
768                            map_inner.after("),");
769                            result_block.push_block(map_inner);
770                        }
771                    } else {
772                        // Vec relation in option-only assembly - shouldn't happen
773                        result_block.line(format!("{field_name}: Vec::new(),"));
774                    }
775                }
776                Some(FieldDef::Count(_)) => {
777                    result_block.line(format!("{field_name}: flat_row.{field_name},"));
778                }
779            }
780        }
781
782        result_block.after(")");
783        map_block.push_block(result_block);
784        map_block.after("}).collect();");
785        block.push_block(map_block);
786        block.line("");
787
788        if is_first {
789            block.line("results.map(|mut v| v.pop())");
790        } else {
791            block.line("results");
792        }
793    }
794
795    Ok(block_to_string(&block))
796}
797
798/// Generate declarations for tracking seen relation IDs (for deduplication).
799fn generate_seen_id_declarations(
800    block: &mut Block,
801    ctx: &CodegenContext,
802    select_fields: &SelectFields,
803    parent_id_type: &str,
804    prefix: &str,
805) -> Result<(), QError> {
806    for (field_name_meta, field_def) in &select_fields.fields {
807        if let Some(FieldDef::Rel(rel)) = field_def {
808            let field_name = field_name_meta.value.as_str();
809            if !rel.is_first() {
810                // Vec relation needs deduplication
811                if let Some(rel_fields) = &rel.fields {
812                    let rel_table = rel.table_name().unwrap_or(field_name);
813                    let id_col = rel_fields.id_column().map(|c| c.as_str()).unwrap_or("id");
814                    let id_type = ctx
815                        .column_type(rel_table, id_col)
816                        .unwrap_or_else(|| "i64".to_string());
817
818                    let set_name = if prefix.is_empty() {
819                        format!("seen_{field_name}")
820                    } else {
821                        format!("seen_{prefix}_{field_name}")
822                    };
823
824                    block.line(format!(
825                        "let mut {set_name}: std::collections::HashSet<({parent_id_type}, {id_type})> = std::collections::HashSet::new();"
826                    ));
827
828                    // Recurse for nested Vec relations
829                    let new_prefix = if prefix.is_empty() {
830                        field_name.to_string()
831                    } else {
832                        format!("{prefix}_{field_name}")
833                    };
834
835                    // For nested relations, the parent ID is now this relation's ID
836                    generate_seen_id_declarations(block, ctx, rel_fields, &id_type, &new_prefix)?;
837                }
838            }
839        }
840    }
841    Ok(())
842}
843
844/// Generate code to assemble relations from flat row data.
845fn generate_relation_assembly(
846    for_block: &mut Block,
847    ctx: &CodegenContext,
848    select_fields: &SelectFields,
849    parent_prefix: &str,
850    flat_prefix: &str,
851    _parent_id_type: &str,
852) -> Result<(), QError> {
853    for (field_name_meta, field_def) in &select_fields.fields {
854        if let Some(FieldDef::Rel(rel)) = field_def {
855            let field_name = field_name_meta.value.as_str();
856            let rel_table = rel.table_name().unwrap_or(field_name);
857            let nested_struct = format!("{}{}", parent_prefix, to_pascal_case(field_name));
858
859            let flat_field_prefix = if flat_prefix.is_empty() {
860                field_name.to_string()
861            } else {
862                format!("{flat_prefix}_{field_name}")
863            };
864
865            if let Some(rel_fields) = &rel.fields {
866                let first_col = rel_fields
867                    .first_column()
868                    .map(|c| c.as_str())
869                    .unwrap_or("id");
870                let id_col = rel_fields
871                    .id_column()
872                    .map(|c| c.as_str())
873                    .unwrap_or(first_col);
874                let id_alias = format!("{flat_field_prefix}_{id_col}");
875
876                if rel.is_first() {
877                    // Option relation
878                    for_block.line(format!("// Populate {field_name} (Option relation)"));
879
880                    let mut if_block = Block::new(format!(
881                        "if entry.{field_name}.is_none() && flat_row.{id_alias}.is_some()"
882                    ));
883
884                    let mut some_block =
885                        Block::new(format!("entry.{field_name} = Some({nested_struct}"));
886                    generate_relation_fields(
887                        &mut some_block,
888                        ctx,
889                        rel_fields,
890                        rel_table,
891                        &flat_field_prefix,
892                    )?;
893                    some_block.after(");");
894                    if_block.push_block(some_block);
895
896                    for_block.push_block(if_block);
897                    for_block.line("");
898                } else {
899                    // Vec relation with deduplication
900                    let set_name = if flat_prefix.is_empty() {
901                        format!("seen_{field_name}")
902                    } else {
903                        format!("seen_{flat_prefix}_{field_name}")
904                    };
905
906                    for_block.line(format!("// Append to {field_name} (Vec relation)"));
907
908                    let mut if_block =
909                        Block::new(format!("if let Some(ref rel_id) = flat_row.{id_alias}"));
910                    if_block.line("let key = (parent_id.clone(), rel_id.clone());".to_string());
911
912                    let mut if_insert = Block::new(format!("if {set_name}.insert(key)"));
913                    let mut push_block =
914                        Block::new(format!("entry.{field_name}.push({nested_struct}"));
915                    generate_relation_fields(
916                        &mut push_block,
917                        ctx,
918                        rel_fields,
919                        rel_table,
920                        &flat_field_prefix,
921                    )?;
922                    push_block.after(");");
923                    if_insert.push_block(push_block);
924
925                    if_block.push_block(if_insert);
926                    for_block.push_block(if_block);
927                    for_block.line("");
928                }
929            }
930        }
931    }
932    Ok(())
933}
934
935/// Generate field assignments for a relation struct.
936fn generate_relation_fields(
937    block: &mut Block,
938    ctx: &CodegenContext,
939    select_fields: &SelectFields,
940    table_name: &str,
941    flat_prefix: &str,
942) -> Result<(), QError> {
943    for (field_name_meta, field_def) in &select_fields.fields {
944        let field_name = field_name_meta.value.as_str();
945        let alias = format!("{flat_prefix}_{field_name}");
946
947        match field_def {
948            None => {
949                let rust_ty = ctx.column_type_at(table_name, field_name, field_name_meta.span)?;
950
951                // Flat struct has Option<T> for relation columns due to LEFT JOIN
952                // Need to unwrap unless the original type was already Option
953                if rust_ty.starts_with("Option<") {
954                    block.line(format!("{field_name}: flat_row.{alias}.clone(),"));
955                } else {
956                    block.line(format!(
957                        "{field_name}: flat_row.{alias}.clone().expect(\"non-null from LEFT JOIN\"),"
958                    ));
959                }
960            }
961            Some(FieldDef::Rel(rel)) => {
962                if rel.is_first() {
963                    block.line(format!(
964                        "{field_name}: None, // TODO: nested Option relation"
965                    ));
966                } else {
967                    block.line(format!(
968                        "{field_name}: Vec::new(), // TODO: nested Vec relation"
969                    ));
970                }
971            }
972            Some(FieldDef::Count(_)) => {
973                block.line(format!("{field_name}: flat_row.{alias},"));
974            }
975        }
976    }
977    Ok(())
978}
979
980fn generate_raw_query_body(query: &Select, raw_sql: &str) -> Block {
981    let cleaned: String = raw_sql
982        .lines()
983        .map(|l| l.trim())
984        .collect::<Vec<_>>()
985        .join("\n");
986
987    let mut block = Block::new("");
988
989    // SQL constant
990    block.line(format!("const SQL: &str = r#\"{}\"#;", cleaned.trim()));
991    block.line("");
992
993    // Query execution
994    if let Some(params) = &query.params {
995        let param_names: Vec<&str> = params.iter().map(|(meta, _)| meta.value.as_str()).collect();
996        if !param_names.is_empty() {
997            let params_str = param_names.join(", ");
998            block.line(format!(
999                "let rows = client.query(SQL, &[{}]).await?;",
1000                params_str
1001            ));
1002        } else {
1003            block.line("let rows = client.query(SQL, &[]).await?;");
1004        }
1005    } else {
1006        block.line("let rows = client.query(SQL, &[]).await?;");
1007    }
1008
1009    // Result processing
1010    if query.first.is_some() {
1011        let mut match_block = Block::new("match rows.into_iter().next()");
1012        match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
1013        match_block.line("None => Ok(None),");
1014        block.push_block(match_block);
1015    } else {
1016        block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
1017    }
1018
1019    block
1020}
1021
1022fn param_type_to_rust(ty: &dibs_query_schema::ParamType) -> String {
1023    use dibs_query_schema::ParamType;
1024    match ty {
1025        ParamType::String => "String".to_string(),
1026        ParamType::Int => "i64".to_string(),
1027        ParamType::Float => "f64".to_string(),
1028        ParamType::Bool => "bool".to_string(),
1029        ParamType::Uuid => "Uuid".to_string(),
1030        ParamType::Decimal => "Decimal".to_string(),
1031        ParamType::Timestamp => "Timestamp".to_string(),
1032        ParamType::Bytes => "Vec<u8>".to_string(),
1033        // JSONB params travel over the wire as JSON-encoded text and
1034        // get cast `::jsonb` at the binding site (see sqlgen). The
1035        // caller-facing type is therefore plain `String` — the same
1036        // shape they'd produce from `facet_json::to_string`, an axum
1037        // body, or a webhook delivery.
1038        ParamType::Jsonb => "String".to_string(),
1039        ParamType::Optional(inner_vec) => {
1040            if let Some(inner) = inner_vec.first() {
1041                format!("Option<{}>", param_type_to_rust(inner))
1042            } else {
1043                "Option<String>".to_string()
1044            }
1045        }
1046    }
1047}
1048
1049/// Helper to format a Block to a String.
1050fn block_to_string(block: &Block) -> String {
1051    let mut output = String::new();
1052    let mut formatter = codegen::Formatter::new(&mut output);
1053    block.fmt(&mut formatter).expect("formatting failed");
1054    output
1055}
1056
/// Convert a `snake_case` identifier to `PascalCase`.
///
/// Underscores are dropped. The first character, and every character that
/// follows an underscore, is upper-cased; all other characters are copied
/// unchanged. Upper-casing is Unicode-aware, so a non-ASCII word start such
/// as `ä` becomes `Ä` (the ASCII-only variant left it lower-case).
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` may yield several chars (e.g. `ß` -> `SS`).
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}
1074
/// Convert a `PascalCase`/`camelCase` identifier to `snake_case`.
///
/// Every upper-case character after the first position is preceded by an
/// underscore and lower-cased. Consecutive capitals are therefore split
/// individually (`HTTPGet` -> `h_t_t_p_get`); that behaviour is kept as-is
/// because changing it would rename existing generated functions.
/// Lower-casing is Unicode-aware. The previous ASCII-only lowering inserted
/// the underscore for `Ä` but left it upper-case, producing non-snake-case
/// names.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars for some code points.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
1091
1092// ============================================================================
1093// Mutation code generation
1094// ============================================================================
1095
1096fn generate_insert_code(
1097    _ctx: &CodegenContext,
1098    name_meta: &Meta<String>,
1099    insert: &Insert,
1100    scope: &mut Scope,
1101) -> Result<(), QError> {
1102    let name = &name_meta.value;
1103    let fn_name = to_snake_case(name);
1104    let generated = crate::sqlgen::generate_insert_sql(insert);
1105
1106    // Generate result struct if RETURNING is used
1107    let has_returning = insert.returning.is_some();
1108    let return_ty = if !has_returning {
1109        "Result<u64, QueryError>".to_string()
1110    } else {
1111        let struct_name = format!("{}Result", name);
1112        if let Some(returning) = &insert.returning {
1113            generate_mutation_result_struct(
1114                _ctx,
1115                &struct_name,
1116                insert.into.value.as_str(),
1117                insert.into.span,
1118                returning,
1119                scope,
1120            )?;
1121        }
1122        format!("Result<Option<{}>, QueryError>", struct_name)
1123    };
1124
1125    let mut func = Function::new(&fn_name);
1126    if let Some(doc) = &name_meta.doc {
1127        let doc_str = doc.join("\n");
1128        func.doc(&doc_str);
1129    }
1130    func.vis("pub");
1131    func.set_async(true);
1132    // Generated query fns take one arg per bound param; wide tables legitimately
1133    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1134    func.attr("allow(clippy::too_many_arguments)");
1135    func.generic("C");
1136    func.arg("client", "&C");
1137
1138    if let Some(params) = &insert.params {
1139        for (param_name_meta, param_type) in &params.params {
1140            let param_name = param_name_meta.value.as_str();
1141            let rust_ty = param_type_to_rust(param_type);
1142            func.arg(param_name, format!("&{}", rust_ty));
1143        }
1144    }
1145
1146    func.ret(&return_ty);
1147    func.bound("C", "tokio_postgres::GenericClient");
1148
1149    let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1150    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1151
1152    scope.push_fn(func);
1153    Ok(())
1154}
1155
1156fn generate_upsert_code(
1157    _ctx: &CodegenContext,
1158    name_meta: &Meta<String>,
1159    upsert: &Upsert,
1160    scope: &mut Scope,
1161) -> Result<(), QError> {
1162    let name = &name_meta.value;
1163    let fn_name = to_snake_case(name);
1164    let generated = crate::sqlgen::generate_upsert_sql(upsert);
1165
1166    let has_returning = upsert.returning.is_some();
1167    let return_ty = if !has_returning {
1168        "Result<u64, QueryError>".to_string()
1169    } else {
1170        let struct_name = format!("{}Result", name);
1171        if let Some(returning) = &upsert.returning {
1172            generate_mutation_result_struct(
1173                _ctx,
1174                &struct_name,
1175                upsert.into.value.as_str(),
1176                upsert.into.span,
1177                returning,
1178                scope,
1179            )?;
1180        }
1181        format!("Result<Option<{}>, QueryError>", struct_name)
1182    };
1183
1184    let mut func = Function::new(&fn_name);
1185    if let Some(doc) = &name_meta.doc {
1186        let doc_str = doc.join("\n");
1187        func.doc(&doc_str);
1188    }
1189    func.vis("pub");
1190    func.set_async(true);
1191    // Generated query fns take one arg per bound param; wide tables legitimately
1192    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1193    func.attr("allow(clippy::too_many_arguments)");
1194    func.generic("C");
1195    func.arg("client", "&C");
1196
1197    if let Some(params) = &upsert.params {
1198        for (param_name_meta, param_type) in &params.params {
1199            let param_name = param_name_meta.value.as_str();
1200            let rust_ty = param_type_to_rust(param_type);
1201            func.arg(param_name, format!("&{}", rust_ty));
1202        }
1203    }
1204
1205    func.ret(&return_ty);
1206    func.bound("C", "tokio_postgres::GenericClient");
1207
1208    let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1209    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1210
1211    scope.push_fn(func);
1212    Ok(())
1213}
1214
1215fn generate_insert_many_code(
1216    ctx: &CodegenContext,
1217    name_meta: &Meta<String>,
1218    insert: &InsertMany,
1219    scope: &mut Scope,
1220) -> Result<(), QError> {
1221    let name = &name_meta.value;
1222    let fn_name = to_snake_case(name);
1223    let generated = crate::sqlgen::generate_insert_many_sql(insert);
1224
1225    // Generate params struct
1226    let params_struct_name = format!("{}Params", name);
1227    if let Some(params) = &insert.params {
1228        generate_bulk_params_struct(
1229            ctx,
1230            &params_struct_name,
1231            insert.into.value.as_str(),
1232            params,
1233            scope,
1234        );
1235    }
1236
1237    // Generate result struct if RETURNING is used
1238    let has_returning = insert.returning.is_some();
1239    let return_ty = if !has_returning {
1240        "Result<u64, QueryError>".to_string()
1241    } else {
1242        let struct_name = format!("{}Result", name);
1243        if let Some(returning) = &insert.returning {
1244            generate_mutation_result_struct(
1245                ctx,
1246                &struct_name,
1247                insert.into.value.as_str(),
1248                insert.into.span,
1249                returning,
1250                scope,
1251            )?;
1252        }
1253        format!("Result<Vec<{}>, QueryError>", struct_name)
1254    };
1255
1256    let mut func = Function::new(&fn_name);
1257    if let Some(doc) = &name_meta.doc {
1258        let doc_str = doc.join("\n");
1259        func.doc(&doc_str);
1260    }
1261    func.vis("pub");
1262    func.set_async(true);
1263    // Generated query fns take one arg per bound param; wide tables legitimately
1264    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1265    func.attr("allow(clippy::too_many_arguments)");
1266    func.generic("C");
1267    func.arg("client", "&C");
1268    func.arg("items", format!("&[{}]", params_struct_name));
1269
1270    func.ret(&return_ty);
1271    func.bound("C", "tokio_postgres::GenericClient");
1272
1273    let body = generate_bulk_mutation_body(&generated.sql, insert.params.as_ref(), !has_returning);
1274    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1275
1276    scope.push_fn(func);
1277    Ok(())
1278}
1279
1280fn generate_upsert_many_code(
1281    ctx: &CodegenContext,
1282    name_meta: &Meta<String>,
1283    upsert: &UpsertMany,
1284    scope: &mut Scope,
1285) -> Result<(), QError> {
1286    let name = &name_meta.value;
1287    let fn_name = to_snake_case(name);
1288    let generated = crate::sqlgen::generate_upsert_many_sql(upsert);
1289
1290    // Generate params struct
1291    let params_struct_name = format!("{}Params", name);
1292    if let Some(params) = &upsert.params {
1293        generate_bulk_params_struct(
1294            ctx,
1295            &params_struct_name,
1296            upsert.into.value.as_str(),
1297            params,
1298            scope,
1299        );
1300    }
1301
1302    // Generate result struct if RETURNING is used
1303    let has_returning = upsert.returning.is_some();
1304    let return_ty = if !has_returning {
1305        "Result<u64, QueryError>".to_string()
1306    } else {
1307        let struct_name = format!("{}Result", name);
1308        if let Some(returning) = &upsert.returning {
1309            generate_mutation_result_struct(
1310                ctx,
1311                &struct_name,
1312                upsert.into.value.as_str(),
1313                upsert.into.span,
1314                returning,
1315                scope,
1316            )?;
1317        }
1318        format!("Result<Vec<{}>, QueryError>", struct_name)
1319    };
1320
1321    let mut func = Function::new(&fn_name);
1322    if let Some(doc) = &name_meta.doc {
1323        let doc_str = doc.join("\n");
1324        func.doc(&doc_str);
1325    }
1326    func.vis("pub");
1327    func.set_async(true);
1328    // Generated query fns take one arg per bound param; wide tables legitimately
1329    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1330    func.attr("allow(clippy::too_many_arguments)");
1331    func.generic("C");
1332    func.arg("client", "&C");
1333    func.arg("items", format!("&[{}]", params_struct_name));
1334
1335    func.ret(&return_ty);
1336    func.bound("C", "tokio_postgres::GenericClient");
1337
1338    let body = generate_bulk_mutation_body(&generated.sql, upsert.params.as_ref(), !has_returning);
1339    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1340
1341    scope.push_fn(func);
1342    Ok(())
1343}
1344
1345/// Generate a params struct for bulk operations.
1346fn generate_bulk_params_struct(
1347    ctx: &CodegenContext,
1348    struct_name: &str,
1349    table: &str,
1350    params: &Params,
1351    scope: &mut Scope,
1352) {
1353    let mut st = Struct::new(struct_name);
1354    st.vis("pub");
1355    st.derive("Debug");
1356    st.derive("Clone");
1357
1358    for (param_name_meta, param_type) in &params.params {
1359        let param_name = param_name_meta.value.as_str();
1360        let rust_ty = ctx
1361            .column_type(table, param_name)
1362            .unwrap_or_else(|| param_type_to_rust(param_type));
1363        st.field(format!("pub {}", param_name), &rust_ty);
1364    }
1365
1366    scope.push_struct(st);
1367}
1368
1369/// Generate body for bulk mutation (INSERT MANY / UPSERT MANY).
1370fn generate_bulk_mutation_body(sql: &str, params: Option<&Params>, execute_only: bool) -> Block {
1371    let mut block = Block::new("");
1372
1373    // SQL constant
1374    block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1375    block.line("");
1376
1377    // Convert slice of structs to parallel arrays
1378    if let Some(params) = params {
1379        block.line("// Convert items to parallel arrays for UNNEST");
1380        for (param_name_meta, param_type) in &params.params {
1381            let param_name = param_name_meta.value.as_str();
1382            let rust_ty = param_type_to_rust(param_type);
1383            block.line(format!(
1384                "let {}_arr: Vec<{}> = items.iter().map(|i| i.{}.clone()).collect();",
1385                param_name, rust_ty, param_name
1386            ));
1387        }
1388        block.line("");
1389
1390        // Build the params reference array
1391        let param_refs: Vec<String> = params
1392            .params
1393            .keys()
1394            .map(|p| format!("&{}_arr", p.value))
1395            .collect();
1396
1397        if execute_only {
1398            // No RETURNING - use execute
1399            block.line(format!(
1400                "let affected = client.execute(SQL, &[{}]).await?;",
1401                param_refs.join(", ")
1402            ));
1403            block.line("Ok(affected)");
1404        } else {
1405            // Has RETURNING - use query
1406            block.line(format!(
1407                "let rows = client.query(SQL, &[{}]).await?;",
1408                param_refs.join(", ")
1409            ));
1410            block.line("rows.iter().map(|row| Ok(from_row(row)?)).collect()");
1411        }
1412    }
1413
1414    block
1415}
1416
1417fn generate_update_code(
1418    ctx: &CodegenContext,
1419    name_meta: &Meta<String>,
1420    update: &Update,
1421    scope: &mut Scope,
1422) -> Result<(), QError> {
1423    let name = &name_meta.value;
1424    let fn_name = to_snake_case(name);
1425    let sqlgen_ctx = ctx.sqlgen_ctx();
1426    let generated = crate::sqlgen::generate_update_sql(&sqlgen_ctx, update)?;
1427
1428    let has_returning = update.returning.is_some();
1429    let return_ty = if !has_returning {
1430        "Result<u64, QueryError>".to_string()
1431    } else {
1432        let struct_name = format!("{}Result", name);
1433        if let Some(returning) = &update.returning {
1434            generate_mutation_result_struct(
1435                ctx,
1436                &struct_name,
1437                update.table.value.as_str(),
1438                update.table.span,
1439                returning,
1440                scope,
1441            )?;
1442        }
1443        format!("Result<Option<{}>, QueryError>", struct_name)
1444    };
1445
1446    let mut func = Function::new(&fn_name);
1447    if let Some(doc) = &name_meta.doc {
1448        let doc_str = doc.join("\n");
1449        func.doc(&doc_str);
1450    }
1451    func.vis("pub");
1452    func.set_async(true);
1453    // Generated query fns take one arg per bound param; wide tables legitimately
1454    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1455    func.attr("allow(clippy::too_many_arguments)");
1456    func.generic("C");
1457    func.arg("client", "&C");
1458
1459    if let Some(params) = &update.params {
1460        for (param_name_meta, param_type) in &params.params {
1461            let param_name = param_name_meta.value.as_str();
1462            let rust_ty = param_type_to_rust(param_type);
1463            func.arg(param_name, format!("&{}", rust_ty));
1464        }
1465    }
1466
1467    func.ret(&return_ty);
1468    func.bound("C", "tokio_postgres::GenericClient");
1469
1470    let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1471    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1472
1473    scope.push_fn(func);
1474    Ok(())
1475}
1476
1477fn generate_delete_code(
1478    ctx: &CodegenContext,
1479    name_meta: &Meta<String>,
1480    delete: &Delete,
1481    scope: &mut Scope,
1482) -> Result<(), QError> {
1483    let name = &name_meta.value;
1484    let fn_name = to_snake_case(name);
1485    let sqlgen_ctx = ctx.sqlgen_ctx();
1486    let generated = crate::sqlgen::generate_delete_sql(&sqlgen_ctx, delete)?;
1487
1488    let has_returning = delete.returning.is_some();
1489    let return_ty = if !has_returning {
1490        "Result<u64, QueryError>".to_string()
1491    } else {
1492        let struct_name = format!("{}Result", name);
1493        if let Some(returning) = &delete.returning {
1494            generate_mutation_result_struct(
1495                ctx,
1496                &struct_name,
1497                delete.from.value.as_str(),
1498                delete.from.span,
1499                returning,
1500                scope,
1501            )?;
1502        }
1503        format!("Result<Option<{}>, QueryError>", struct_name)
1504    };
1505
1506    let mut func = Function::new(&fn_name);
1507    if let Some(doc) = &name_meta.doc {
1508        let doc_str = doc.join("\n");
1509        func.doc(&doc_str);
1510    }
1511    func.vis("pub");
1512    func.set_async(true);
1513    // Generated query fns take one arg per bound param; wide tables legitimately
1514    // exceed clippy's threshold. Harmless (no warning) on narrow queries.
1515    func.attr("allow(clippy::too_many_arguments)");
1516    func.generic("C");
1517    func.arg("client", "&C");
1518
1519    if let Some(params) = &delete.params {
1520        for (param_name_meta, param_type) in &params.params {
1521            let param_name = param_name_meta.value.as_str();
1522            let rust_ty = param_type_to_rust(param_type);
1523            func.arg(param_name, format!("&{}", rust_ty));
1524        }
1525    }
1526
1527    func.ret(&return_ty);
1528    func.bound("C", "tokio_postgres::GenericClient");
1529
1530    let body = generate_mutation_body(&generated.sql, &generated.params, !has_returning);
1531    func.line(wrap_with_trace_err(&block_to_string(&body), &fn_name));
1532
1533    scope.push_fn(func);
1534    Ok(())
1535}
1536
1537fn generate_mutation_result_struct(
1538    ctx: &CodegenContext,
1539    struct_name: &str,
1540    table: &str,
1541    table_span: Span,
1542    returning: &Returning,
1543    scope: &mut Scope,
1544) -> Result<(), QError> {
1545    // Resolve the table once so a missing/empty schema is reported against the
1546    // mutation's table clause rather than the first RETURNING column.
1547    ctx.require_table(table, table_span)?;
1548
1549    let mut st = Struct::new(struct_name);
1550    st.vis("pub");
1551    st.derive("Debug");
1552    st.derive("Clone");
1553    st.derive("Facet");
1554    st.attr("facet(crate = dibs_runtime::facet)");
1555
1556    for (col_name_meta, _) in &returning.columns {
1557        let col_name = col_name_meta.value.as_str();
1558        let rust_ty = ctx.column_type_at(table, col_name, col_name_meta.span)?;
1559        st.field(format!("pub {col_name}"), &rust_ty);
1560    }
1561
1562    scope.push_struct(st);
1563    Ok(())
1564}
1565
1566fn generate_mutation_body(
1567    sql: &str,
1568    param_order: &[dibs_sql::ParamName],
1569    execute_only: bool,
1570) -> Block {
1571    let mut block = Block::new("");
1572
1573    // SQL constant
1574    block.line(format!("const SQL: &str = r#\"{}\"#;", sql));
1575    block.line("");
1576
1577    let params: Vec<_> = param_order
1578        .iter()
1579        .filter(|p| !p.as_str().starts_with("__literal_"))
1580        .collect();
1581
1582    if execute_only {
1583        // No RETURNING - use execute
1584        if params.is_empty() {
1585            block.line("let affected = client.execute(SQL, &[]).await?;");
1586        } else {
1587            let params_str = params
1588                .iter()
1589                .map(|p| p.as_str())
1590                .collect::<Vec<_>>()
1591                .join(", ");
1592            block.line(format!(
1593                "let affected = client.execute(SQL, &[{}]).await?;",
1594                params_str
1595            ));
1596        }
1597        block.line("Ok(affected)");
1598    } else {
1599        // Has RETURNING - use query
1600        if params.is_empty() {
1601            block.line("let rows = client.query(SQL, &[]).await?;");
1602        } else {
1603            let params_str = params
1604                .iter()
1605                .map(|p| p.as_str())
1606                .collect::<Vec<_>>()
1607                .join(", ");
1608            block.line(format!(
1609                "let rows = client.query(SQL, &[{}]).await?;",
1610                params_str
1611            ));
1612        }
1613        let mut match_block = Block::new("match rows.into_iter().next()");
1614        match_block.line("Some(row) => Ok(Some(from_row(&row)?)),");
1615        match_block.line("None => Ok(None),");
1616        block.push_block(match_block);
1617    }
1618
1619    block
1620}
1621
1622#[cfg(test)]
1623mod tests;