nautilus_codegen/generator.rs

1//! Code generator for Nautilus models, delegates, and builders.
2
3use heck::ToSnakeCase;
4use nautilus_schema::ir::{ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::type_helpers::{field_to_rust_type, is_auto_generated};
10
11pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
12    let mut tera = Tera::default();
13    tera.add_raw_templates(vec![
14        (
15            "columns_struct.tera",
16            include_str!("../templates/rust/columns_struct.tera"),
17        ),
18        (
19            "column_impl.tera",
20            include_str!("../templates/rust/column_impl.tera"),
21        ),
22        ("create.tera", include_str!("../templates/rust/create.tera")),
23        (
24            "create_many.tera",
25            include_str!("../templates/rust/create_many.tera"),
26        ),
27        (
28            "delegate.tera",
29            include_str!("../templates/rust/delegate.tera"),
30        ),
31        ("delete.tera", include_str!("../templates/rust/delete.tera")),
32        ("enum.tera", include_str!("../templates/rust/enum.tera")),
33        (
34            "find_many.tera",
35            include_str!("../templates/rust/find_many.tera"),
36        ),
37        (
38            "from_row_impl.tera",
39            include_str!("../templates/rust/from_row_impl.tera"),
40        ),
41        (
42            "model_file.tera",
43            include_str!("../templates/rust/model_file.tera"),
44        ),
45        (
46            "model_struct.tera",
47            include_str!("../templates/rust/model_struct.tera"),
48        ),
49        ("update.tera", include_str!("../templates/rust/update.tera")),
50        (
51            "composite_type.tera",
52            include_str!("../templates/rust/composite_type.tera"),
53        ),
54    ])
55    .expect("embedded Rust templates must parse");
56    tera
57});
58
59fn render(template: &str, ctx: &Context) -> String {
60    TEMPLATES
61        .render(template, ctx)
62        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
63}
64
65/// Template context for a single model field in the Rust codegen backend.
66///
67/// This struct is intentionally separate from [`PythonFieldContext`] in
68/// `python/generator.rs`: the two backends expose different template
69/// variables (Rust needs `rust_type` / `column_type`; Python needs
70/// `python_type` / `base_type` / `is_enum` / `has_default` / `default`) and
71/// are expected to evolve independently.
72#[derive(Debug, Clone, Serialize)]
73struct FieldContext {
74    name: String,
75    db_name: String,
76    rust_type: String,
77    column_type: String,
78    is_array: bool,
79    index: usize,
80    is_pk: bool,
81    /// `true` when the field maps to an `Option<T>` Rust type
82    /// (i.e. the schema field is not required and is not a relation).
83    is_optional: bool,
84    /// `true` when the field has `@updatedAt` — auto-defaults to `now()` if not provided.
85    is_updated_at: bool,
86    /// `true` when the field is a `@computed` generated column (read-only from client side).
87    is_computed: bool,
88}
89
90/// Serialisable (logical_name, db_name) pair for primary-key fields.
91/// Used in templates to generate cursor predicate slices.
92#[derive(Debug, Clone, Serialize)]
93struct PkFieldContext {
94    /// Snake-case logical name — used as the cursor map key in generated code.
95    name: String,
96    /// Database column name — used to build the `table__db_col` column reference.
97    db_name: String,
98}
99
100#[derive(Debug, Clone, Serialize)]
101struct RelationContext {
102    field_name: String,
103    target_model: String,
104    target_table: String,
105    is_array: bool,
106    fields: Vec<String>,
107    references: Vec<String>,
108    fields_db: Vec<String>,
109    references_db: Vec<String>,
110    target_scalar_fields: Vec<FieldContext>,
111}
112
113/// Generate complete code for a model (struct, impls, delegate, builders).
114///
115/// `is_async` determines whether the generated delegate methods and internal
116/// builders use `async fn`/`.await` (`true`) or blocking sync wrappers (`false`).
117pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
118    let mut context = Context::new();
119
120    context.insert("model_name", &model.logical_name);
121    context.insert("table_name", &model.db_name);
122    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
123    context.insert("columns_name", &format!("{}Columns", model.logical_name));
124    context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
125    context.insert("create_name", &format!("{}Create", model.logical_name));
126    context.insert(
127        "create_many_name",
128        &format!("{}CreateMany", model.logical_name),
129    );
130    context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
131    context.insert("update_name", &format!("{}Update", model.logical_name));
132    context.insert("delete_name", &format!("{}Delete", model.logical_name));
133
134    let pk_field_names = model.primary_key.fields();
135    context.insert("primary_key_fields", &pk_field_names);
136
137    let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
138        .iter()
139        .filter_map(|logical| {
140            model
141                .scalar_fields()
142                .find(|f| f.logical_name.as_str() == *logical)
143                .map(|f| PkFieldContext {
144                    name: f.logical_name.to_snake_case(),
145                    db_name: f.db_name.clone(),
146                })
147        })
148        .collect();
149    context.insert("pk_fields_with_db", &pk_fields_with_db);
150
151    // --- Single-pass scalar field context building ---
152    let mut enum_imports = HashSet::new();
153    let mut composite_type_imports = HashSet::new();
154
155    let mut scalar_fields: Vec<FieldContext> = Vec::new();
156    let mut create_fields: Vec<FieldContext> = Vec::new();
157    let mut updated_at_fields: Vec<FieldContext> = Vec::new();
158
159    for (idx, field) in model.scalar_fields().enumerate() {
160        // --- Import tracking ---
161        match &field.field_type {
162            ResolvedFieldType::Enum { enum_name } => {
163                if ir.enums.contains_key(enum_name) {
164                    enum_imports.insert(enum_name.clone());
165                }
166            }
167            ResolvedFieldType::CompositeType { type_name } => {
168                if ir.composite_types.contains_key(type_name) {
169                    composite_type_imports.insert(type_name.clone());
170                }
171            }
172            _ => {}
173        }
174
175        // --- Computed values ---
176        let column_type = match &field.field_type {
177            ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
178            ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
179            _ => String::new(),
180        };
181        let is_pk = pk_field_names.contains(&field.logical_name.as_str());
182        let auto_generated = is_auto_generated(field);
183
184        let field_ctx = FieldContext {
185            name: field.logical_name.to_snake_case(),
186            db_name: field.db_name.clone(),
187            rust_type: field_to_rust_type(field),
188            column_type,
189            is_array: field.is_array,
190            index: idx,
191            is_pk,
192            is_optional: !field.is_required && !field.is_array,
193            is_updated_at: field.is_updated_at,
194            is_computed: field.computed.is_some(),
195        };
196
197        // Create fields: exclude auto-generated
198        if !auto_generated {
199            create_fields.push(field_ctx.clone());
200        }
201
202        // Updated-at fields
203        if field.is_updated_at {
204            updated_at_fields.push(field_ctx.clone());
205        }
206
207        scalar_fields.push(field_ctx);
208    }
209
210    // --- Relation imports ---
211    let mut relation_imports = HashSet::new();
212    for field in model.relation_fields() {
213        if let ResolvedFieldType::Relation(rel) = &field.field_type {
214            relation_imports.insert(rel.target_model.clone());
215        }
216    }
217
218    context.insert("has_enums", &!enum_imports.is_empty());
219    context.insert(
220        "enum_imports",
221        &enum_imports.into_iter().collect::<Vec<_>>(),
222    );
223    context.insert("has_relations", &!relation_imports.is_empty());
224    context.insert(
225        "relation_imports",
226        &relation_imports.into_iter().collect::<Vec<_>>(),
227    );
228    context.insert("has_composite_types", &!composite_type_imports.is_empty());
229    context.insert(
230        "composite_type_imports",
231        &composite_type_imports.into_iter().collect::<Vec<_>>(),
232    );
233
234    let relation_fields: Vec<FieldContext> = model
235        .relation_fields()
236        .map(|field| FieldContext {
237            name: field.logical_name.to_snake_case(),
238            db_name: field.db_name.clone(),
239            rust_type: field_to_rust_type(field),
240            column_type: String::new(),
241            is_array: field.is_array,
242            index: 0,
243            is_pk: false,
244            is_optional: true,
245            is_updated_at: false,
246            is_computed: false,
247        })
248        .collect();
249
250    let relations: Vec<RelationContext> = model
251        .relation_fields()
252        .filter_map(|field| {
253            if let ResolvedFieldType::Relation(rel) = &field.field_type {
254                if let Some(target_model) = ir.models.get(&rel.target_model) {
255                    let target_pk_names = target_model.primary_key.fields();
256                    let target_scalar_fields: Vec<FieldContext> = target_model
257                        .scalar_fields()
258                        .enumerate()
259                        .map(|(idx, f)| {
260                            let column_type = match &f.field_type {
261                                ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
262                                ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
263                                _ => String::new(),
264                            };
265                            let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
266                            FieldContext {
267                                name: f.logical_name.to_snake_case(),
268                                db_name: f.db_name.clone(),
269                                rust_type: field_to_rust_type(f),
270                                column_type,
271                                is_array: f.is_array,
272                                index: idx,
273                                is_pk: f_is_pk,
274                                is_optional: !f.is_required && !f.is_array,
275                                is_updated_at: f.is_updated_at,
276                                is_computed: f.computed.is_some(),
277                            }
278                        })
279                        .collect();
280
281                    let (fields, references) = if rel.fields.is_empty() {
282                        let inverse = target_model.relation_fields().find(|f| {
283                            if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
284                                inv_rel.target_model == model.logical_name
285                            } else {
286                                false
287                            }
288                        });
289
290                        if let Some(inverse_field) = inverse {
291                            if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
292                            {
293                                // Swap fields and references for the many-side
294                                (inv_rel.references.clone(), inv_rel.fields.clone())
295                            } else {
296                                (vec![], vec![])
297                            }
298                        } else {
299                            (vec![], vec![])
300                        }
301                    } else {
302                        (rel.fields.clone(), rel.references.clone())
303                    };
304
305                    let fields_db: Vec<String> = fields
306                        .iter()
307                        .filter_map(|logical_name| {
308                            model
309                                .fields
310                                .iter()
311                                .find(|f| &f.logical_name == logical_name)
312                                .map(|f| f.db_name.clone())
313                        })
314                        .collect();
315
316                    let references_db: Vec<String> = references
317                        .iter()
318                        .filter_map(|logical_name| {
319                            target_model
320                                .fields
321                                .iter()
322                                .find(|f| &f.logical_name == logical_name)
323                                .map(|f| f.db_name.clone())
324                        })
325                        .collect();
326
327                    Some(RelationContext {
328                        field_name: field.logical_name.to_snake_case(),
329                        target_model: rel.target_model.clone(),
330                        target_table: target_model.db_name.clone(),
331                        is_array: field.is_array,
332                        fields,
333                        references,
334                        fields_db,
335                        references_db,
336                        target_scalar_fields,
337                    })
338                } else {
339                    None
340                }
341            } else {
342                None
343            }
344        })
345        .collect();
346
347    context.insert("scalar_fields", &scalar_fields);
348    context.insert("relation_fields", &relation_fields);
349    context.insert("relations", &relations);
350    context.insert("create_fields", &create_fields);
351    context.insert("updated_at_fields", &updated_at_fields);
352    context.insert("all_scalar_fields", &scalar_fields);
353    context.insert("is_async", &is_async);
354
355    render("model_file.tera", &context)
356}
357
358/// Generate all models from a schema IR.
359///
360/// `is_async` is forwarded to every [`generate_model`] call.
361pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
362    let mut generated = HashMap::new();
363
364    for (model_name, model_ir) in &ir.models {
365        let code = generate_model(model_ir, ir, is_async);
366        generated.insert(model_name.clone(), code);
367    }
368
369    generated
370}