Skip to main content

nautilus_codegen/
generator.rs

1//! Code generator for Nautilus models, delegates, and builders.
2
3use heck::ToSnakeCase;
4use nautilus_schema::ir::{ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::type_helpers::field_to_rust_type;
10
11pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
12    let mut tera = Tera::default();
13    tera.add_raw_templates(vec![
14        (
15            "columns_struct.tera",
16            include_str!("../templates/rust/columns_struct.tera"),
17        ),
18        (
19            "column_impl.tera",
20            include_str!("../templates/rust/column_impl.tera"),
21        ),
22        ("create.tera", include_str!("../templates/rust/create.tera")),
23        (
24            "create_many.tera",
25            include_str!("../templates/rust/create_many.tera"),
26        ),
27        (
28            "delegate.tera",
29            include_str!("../templates/rust/delegate.tera"),
30        ),
31        ("delete.tera", include_str!("../templates/rust/delete.tera")),
32        ("enum.tera", include_str!("../templates/rust/enum.tera")),
33        (
34            "find_many.tera",
35            include_str!("../templates/rust/find_many.tera"),
36        ),
37        (
38            "from_row_impl.tera",
39            include_str!("../templates/rust/from_row_impl.tera"),
40        ),
41        (
42            "model_file.tera",
43            include_str!("../templates/rust/model_file.tera"),
44        ),
45        (
46            "model_struct.tera",
47            include_str!("../templates/rust/model_struct.tera"),
48        ),
49        ("update.tera", include_str!("../templates/rust/update.tera")),
50        (
51            "composite_type.tera",
52            include_str!("../templates/rust/composite_type.tera"),
53        ),
54    ])
55    .expect("embedded Rust templates must parse");
56    tera
57});
58
59fn render(template: &str, ctx: &Context) -> String {
60    TEMPLATES
61        .render(template, ctx)
62        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
63}
64
/// Template context for a single model field in the Rust codegen backend.
///
/// This struct is intentionally separate from `PythonFieldContext` in
/// `python/generator.rs`: the two backends expose different template
/// variables (Rust needs `rust_type` / `column_type`; Python needs
/// `python_type` / `base_type` / `is_enum` / `has_default` / `default`) and
/// are expected to evolve independently.
///
/// Instances are serialised into the Tera context, so each field name below is
/// the key templates use to read it.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Snake-case form of the field's logical (schema) name.
    name: String,
    /// Database column name of the field.
    db_name: String,
    /// Full Rust type of the field, as produced by `field_to_rust_type`.
    rust_type: String,
    /// For scalar fields, the scalar's `rust_type()`; for enum fields, the
    /// enum name; empty for every other field kind (always empty for
    /// relation fields).
    column_type: String,
    /// Copied from the schema field's `is_array` flag.
    is_array: bool,
    /// Position of the field among its model's scalar fields; always `0` for
    /// relation-field contexts.
    index: usize,
    /// `true` when the field is part of the model's primary key (always
    /// `false` for relation-field contexts).
    is_pk: bool,
    /// For scalar fields: `true` when the schema field is neither required nor
    /// an array. Relation-field contexts always set this to `true`.
    is_optional: bool,
    /// `true` when the field has `@updatedAt` (presumably auto-defaulted to
    /// `now()` by the generated code when not provided — see templates).
    is_updated_at: bool,
    /// `true` when the field is a `@computed` generated column (i.e. the
    /// schema field carries a `computed` expression); intended to be
    /// read-only from the client side.
    is_computed: bool,
}
89
/// Serialisable (logical_name, db_name) pair for one primary-key field.
///
/// Exposed to templates as `pk_fields_with_db`; used there to generate cursor
/// predicate slices.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Snake-case logical name — used as the cursor map key in generated code.
    name: String,
    /// Database column name — used to build the `table__db_col` column reference.
    db_name: String,
}
99
/// Template context describing one relation field of a model and how it joins
/// to the related (target) model.
///
/// Serialised into the Tera context under `relations`.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Snake-case name of the relation field on the current model.
    field_name: String,
    /// Logical name of the related model.
    target_model: String,
    /// Database (table) name of the related model.
    target_table: String,
    /// Copied from the relation field's `is_array` flag (a to-many side).
    is_array: bool,
    /// Logical field names on the current model used as join keys. When the
    /// relation field declares no `fields` itself, these are taken from the
    /// inverse relation's `references`; empty if no inverse is found.
    fields: Vec<String>,
    /// Logical field names on the target model, paired positionally with
    /// `fields`.
    references: Vec<String>,
    /// Column names of `fields` on the current model. Names that do not
    /// resolve to a field are skipped, so this may be shorter than `fields`.
    fields_db: Vec<String>,
    /// Column names of `references` on the target model. Names that do not
    /// resolve are skipped, so this may be shorter than `references`.
    references_db: Vec<String>,
    /// Scalar-field contexts of the target model.
    target_scalar_fields: Vec<FieldContext>,
}
107    references: Vec<String>,
108    fields_db: Vec<String>,
109    references_db: Vec<String>,
110    target_scalar_fields: Vec<FieldContext>,
111}
112
113/// Generate complete code for a model (struct, impls, delegate, builders).
114///
115/// `is_async` determines whether the generated delegate methods and internal
116/// builders use `async fn`/`.await` (`true`) or blocking sync wrappers (`false`).
117pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
118    let mut context = Context::new();
119
120    context.insert("model_name", &model.logical_name);
121    context.insert("table_name", &model.db_name);
122    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
123    context.insert("columns_name", &format!("{}Columns", model.logical_name));
124    context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
125    context.insert("create_name", &format!("{}Create", model.logical_name));
126    context.insert(
127        "create_many_name",
128        &format!("{}CreateMany", model.logical_name),
129    );
130    context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
131    context.insert("update_name", &format!("{}Update", model.logical_name));
132    context.insert("delete_name", &format!("{}Delete", model.logical_name));
133
134    let pk_field_names = model.primary_key.fields();
135    context.insert("primary_key_fields", &pk_field_names);
136
137    let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
138        .iter()
139        .filter_map(|logical| {
140            model
141                .scalar_fields()
142                .find(|f| f.logical_name.as_str() == *logical)
143                .map(|f| PkFieldContext {
144                    name: f.logical_name.to_snake_case(),
145                    db_name: f.db_name.clone(),
146                })
147        })
148        .collect();
149    context.insert("pk_fields_with_db", &pk_fields_with_db);
150
151    // --- Single-pass scalar field context building ---
152    let mut enum_imports = HashSet::new();
153    let mut composite_type_imports = HashSet::new();
154
155    let mut scalar_fields: Vec<FieldContext> = Vec::new();
156    let mut create_fields: Vec<FieldContext> = Vec::new();
157    let mut updated_at_fields: Vec<FieldContext> = Vec::new();
158
159    for (idx, field) in model.scalar_fields().enumerate() {
160        // --- Import tracking ---
161        match &field.field_type {
162            ResolvedFieldType::Enum { enum_name } => {
163                if ir.enums.contains_key(enum_name) {
164                    enum_imports.insert(enum_name.clone());
165                }
166            }
167            ResolvedFieldType::CompositeType { type_name } => {
168                if ir.composite_types.contains_key(type_name) {
169                    composite_type_imports.insert(type_name.clone());
170                }
171            }
172            _ => {}
173        }
174
175        // --- Computed values ---
176        let column_type = match &field.field_type {
177            ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
178            ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
179            _ => String::new(),
180        };
181        let is_pk = pk_field_names.contains(&field.logical_name.as_str());
182
183        let field_ctx = FieldContext {
184            name: field.logical_name.to_snake_case(),
185            db_name: field.db_name.clone(),
186            rust_type: field_to_rust_type(field),
187            column_type,
188            is_array: field.is_array,
189            index: idx,
190            is_pk,
191            is_optional: !field.is_required && !field.is_array,
192            is_updated_at: field.is_updated_at,
193            is_computed: field.computed.is_some(),
194        };
195
196        create_fields.push(field_ctx.clone());
197
198        // Updated-at fields
199        if field.is_updated_at {
200            updated_at_fields.push(field_ctx.clone());
201        }
202
203        scalar_fields.push(field_ctx);
204    }
205
206    // --- Relation imports ---
207    let mut relation_imports = HashSet::new();
208    for field in model.relation_fields() {
209        if let ResolvedFieldType::Relation(rel) = &field.field_type {
210            relation_imports.insert(rel.target_model.clone());
211        }
212    }
213
214    context.insert("has_enums", &!enum_imports.is_empty());
215    context.insert(
216        "enum_imports",
217        &enum_imports.into_iter().collect::<Vec<_>>(),
218    );
219    context.insert("has_relations", &!relation_imports.is_empty());
220    context.insert(
221        "relation_imports",
222        &relation_imports.into_iter().collect::<Vec<_>>(),
223    );
224    context.insert("has_composite_types", &!composite_type_imports.is_empty());
225    context.insert(
226        "composite_type_imports",
227        &composite_type_imports.into_iter().collect::<Vec<_>>(),
228    );
229
230    let relation_fields: Vec<FieldContext> = model
231        .relation_fields()
232        .map(|field| FieldContext {
233            name: field.logical_name.to_snake_case(),
234            db_name: field.db_name.clone(),
235            rust_type: field_to_rust_type(field),
236            column_type: String::new(),
237            is_array: field.is_array,
238            index: 0,
239            is_pk: false,
240            is_optional: true,
241            is_updated_at: false,
242            is_computed: false,
243        })
244        .collect();
245
246    let relations: Vec<RelationContext> = model
247        .relation_fields()
248        .filter_map(|field| {
249            if let ResolvedFieldType::Relation(rel) = &field.field_type {
250                if let Some(target_model) = ir.models.get(&rel.target_model) {
251                    let target_pk_names = target_model.primary_key.fields();
252                    let target_scalar_fields: Vec<FieldContext> = target_model
253                        .scalar_fields()
254                        .enumerate()
255                        .map(|(idx, f)| {
256                            let column_type = match &f.field_type {
257                                ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
258                                ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
259                                _ => String::new(),
260                            };
261                            let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
262                            FieldContext {
263                                name: f.logical_name.to_snake_case(),
264                                db_name: f.db_name.clone(),
265                                rust_type: field_to_rust_type(f),
266                                column_type,
267                                is_array: f.is_array,
268                                index: idx,
269                                is_pk: f_is_pk,
270                                is_optional: !f.is_required && !f.is_array,
271                                is_updated_at: f.is_updated_at,
272                                is_computed: f.computed.is_some(),
273                            }
274                        })
275                        .collect();
276
277                    let (fields, references) = if rel.fields.is_empty() {
278                        let inverse = target_model.relation_fields().find(|f| {
279                            if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
280                                inv_rel.target_model == model.logical_name
281                            } else {
282                                false
283                            }
284                        });
285
286                        if let Some(inverse_field) = inverse {
287                            if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
288                            {
289                                // Swap fields and references for the many-side
290                                (inv_rel.references.clone(), inv_rel.fields.clone())
291                            } else {
292                                (vec![], vec![])
293                            }
294                        } else {
295                            (vec![], vec![])
296                        }
297                    } else {
298                        (rel.fields.clone(), rel.references.clone())
299                    };
300
301                    let fields_db: Vec<String> = fields
302                        .iter()
303                        .filter_map(|logical_name| {
304                            model
305                                .fields
306                                .iter()
307                                .find(|f| &f.logical_name == logical_name)
308                                .map(|f| f.db_name.clone())
309                        })
310                        .collect();
311
312                    let references_db: Vec<String> = references
313                        .iter()
314                        .filter_map(|logical_name| {
315                            target_model
316                                .fields
317                                .iter()
318                                .find(|f| &f.logical_name == logical_name)
319                                .map(|f| f.db_name.clone())
320                        })
321                        .collect();
322
323                    Some(RelationContext {
324                        field_name: field.logical_name.to_snake_case(),
325                        target_model: rel.target_model.clone(),
326                        target_table: target_model.db_name.clone(),
327                        is_array: field.is_array,
328                        fields,
329                        references,
330                        fields_db,
331                        references_db,
332                        target_scalar_fields,
333                    })
334                } else {
335                    None
336                }
337            } else {
338                None
339            }
340        })
341        .collect();
342
343    context.insert("scalar_fields", &scalar_fields);
344    context.insert("relation_fields", &relation_fields);
345    context.insert("relations", &relations);
346    context.insert("create_fields", &create_fields);
347    context.insert("updated_at_fields", &updated_at_fields);
348    context.insert("all_scalar_fields", &scalar_fields);
349    context.insert("is_async", &is_async);
350
351    render("model_file.tera", &context)
352}
353
354/// Generate all models from a schema IR.
355///
356/// `is_async` is forwarded to every [`generate_model`] call.
357pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
358    let mut generated = HashMap::new();
359
360    for (model_name, model_ir) in &ir.models {
361        let code = generate_model(model_ir, ir, is_async);
362        generated.insert(model_name.clone(), code);
363    }
364
365    generated
366}