Skip to main content

nautilus_codegen/
generator.rs

1//! Code generator for Nautilus models, delegates, and builders.
2
3use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ast::StorageStrategy;
5use nautilus_schema::ir::{FieldIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr};
6use serde::Serialize;
7use std::collections::{HashMap, HashSet};
8use tera::{Context, Tera};
9
10use crate::type_helpers::{
11    field_to_rust_avg_type, field_to_rust_base_type, field_to_rust_sum_type, field_to_rust_type,
12};
13
14pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
15    let mut tera = Tera::default();
16    tera.add_raw_templates(vec![
17        (
18            "columns_struct.tera",
19            include_str!("../templates/rust/columns_struct.tera"),
20        ),
21        (
22            "column_impl.tera",
23            include_str!("../templates/rust/column_impl.tera"),
24        ),
25        ("create.tera", include_str!("../templates/rust/create.tera")),
26        (
27            "create_many.tera",
28            include_str!("../templates/rust/create_many.tera"),
29        ),
30        (
31            "delegate.tera",
32            include_str!("../templates/rust/delegate.tera"),
33        ),
34        ("delete.tera", include_str!("../templates/rust/delete.tera")),
35        ("enum.tera", include_str!("../templates/rust/enum.tera")),
36        (
37            "find_many.tera",
38            include_str!("../templates/rust/find_many.tera"),
39        ),
40        (
41            "from_row_impl.tera",
42            include_str!("../templates/rust/from_row_impl.tera"),
43        ),
44        (
45            "model_file.tera",
46            include_str!("../templates/rust/model_file.tera"),
47        ),
48        (
49            "model_struct.tera",
50            include_str!("../templates/rust/model_struct.tera"),
51        ),
52        ("update.tera", include_str!("../templates/rust/update.tera")),
53        (
54            "composite_type.tera",
55            include_str!("../templates/rust/composite_type.tera"),
56        ),
57    ])
58    .expect("embedded Rust templates must parse");
59    tera
60});
61
62fn render(template: &str, ctx: &Context) -> String {
63    TEMPLATES
64        .render(template, ctx)
65        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
66}
67
/// Template context for a single model field in the Rust codegen backend.
///
/// This struct is intentionally separate from [`PythonFieldContext`] in
/// `python/generator.rs`: the two backends expose different template
/// variables (Rust needs `rust_type` / `column_type`; Python needs
/// `python_type` / `base_type` / `is_enum` / `has_default` / `default`) and
/// are expected to evolve independently.
///
/// `generate_model` builds one of these for every scalar field and, with
/// fixed placeholder values for the scalar-only members, for every relation
/// field.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Snake-case form of the field's logical name.
    name: String,
    /// Logical name of the field, copied unchanged from the IR.
    logical_name: String,
    /// Database-side name of the field, copied unchanged from the IR.
    db_name: String,
    /// Rust type string produced by `field_to_rust_type`.
    rust_type: String,
    /// Rust type string produced by `field_to_rust_base_type`.
    base_rust_type: String,
    /// For scalar fields the scalar's `rust_type()`, for enum fields the enum
    /// name; an empty string for every other field kind.
    column_type: String,
    /// Rust expression text, either `None` or `Some(crate::ValueHint::…)`,
    /// produced by `field_read_hint_expr` (always `None` for relation fields).
    read_hint_expr: String,
    /// Pascal-case form of the field's logical name.
    variant_name: String,
    /// Copied from the IR field's `is_array` flag.
    is_array: bool,
    /// Position of the field in the model's scalar-field enumeration;
    /// relation-field contexts always carry `0`.
    index: usize,
    /// `true` when the field's logical name is listed in the model's primary
    /// key; always `false` for relation-field contexts.
    is_pk: bool,
    /// For scalar fields: `true` when the schema field is neither required
    /// nor an array. Relation-field contexts always set this to `true`.
    is_optional: bool,
    /// `true` when the field has `@updatedAt` — auto-defaults to `now()` if not provided.
    is_updated_at: bool,
    /// `true` when the field is a `@computed` generated column (read-only from client side).
    is_computed: bool,
}
96
/// Template context for a scalar field exposed through the `numeric_fields`
/// and `orderable_fields` template variables built by `generate_model`.
#[derive(Debug, Clone, Serialize)]
struct AggregateFieldContext {
    /// Snake-case form of the field's logical name.
    name: String,
    /// Logical name of the field, copied unchanged from the IR.
    logical_name: String,
    /// Rust type string produced by `field_to_rust_base_type`.
    rust_type: String,
    /// Result of `field_to_rust_avg_type` for `numeric_fields` entries;
    /// left empty for `orderable_fields` entries.
    avg_rust_type: String,
    /// Result of `field_to_rust_sum_type` for `numeric_fields` entries;
    /// left empty for `orderable_fields` entries.
    sum_rust_type: String,
    /// Pascal-case form of the field's logical name.
    variant_name: String,
}
106
/// Serialisable (logical_name, db_name) pair for primary-key fields.
/// Used in templates to generate cursor predicate slices.
///
/// `generate_model` creates one entry per primary-key name that matches a
/// scalar field of the model; names without a matching scalar field are
/// skipped.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Snake-case logical name — used as the cursor map key in generated code.
    name: String,
    /// Database column name — used to build the `table__db_col` column reference.
    db_name: String,
}
116
/// Template context describing one relation field of a model together with
/// the columns that link it to the target model.
///
/// `generate_model` only emits an entry when the relation's target model is
/// present in the schema IR.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Snake-case form of the relation field's logical name.
    field_name: String,
    /// Name of the target model as recorded on the relation.
    target_model: String,
    /// Database name of the target model.
    target_table: String,
    /// Copied from the relation field's `is_array` flag.
    is_array: bool,
    /// Logical names of the linking fields on this model. Taken from the
    /// relation itself, or resolved from the inverse relation on the target
    /// model when the relation declares no fields of its own.
    fields: Vec<String>,
    /// Logical names of the linked fields on the target model, resolved the
    /// same way as `fields`.
    references: Vec<String>,
    /// Database names for the entries of `fields`, looked up on this model;
    /// entries with no matching field are omitted.
    fields_db: Vec<String>,
    /// Database names for the entries of `references`, looked up on the
    /// target model; entries with no matching field are omitted.
    references_db: Vec<String>,
    /// Field contexts for every scalar field of the target model.
    target_scalar_fields: Vec<FieldContext>,
}
129
130fn resolve_inverse_relation_fields(
131    source_model_name: &str,
132    relation_name: Option<&str>,
133    target_model: &ModelIr,
134) -> (Vec<String>, Vec<String>) {
135    let inverse = target_model.relation_fields().find(|field| {
136        if let ResolvedFieldType::Relation(inv_rel) = &field.field_type {
137            if inv_rel.target_model != source_model_name {
138                return false;
139            }
140
141            match (relation_name, inv_rel.name.as_deref()) {
142                (Some(expected), Some(actual)) => actual == expected,
143                (Some(_), None) => false,
144                (None, Some(_)) => false,
145                (None, None) => true,
146            }
147        } else {
148            false
149        }
150    });
151
152    if let Some(inverse_field) = inverse {
153        if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type {
154            return (inv_rel.references.clone(), inv_rel.fields.clone());
155        }
156    }
157
158    (vec![], vec![])
159}
160
161fn field_read_hint_expr(field: &FieldIr) -> String {
162    if field.is_array && field.storage_strategy == Some(StorageStrategy::Json) {
163        return "Some(crate::ValueHint::Json)".to_string();
164    }
165
166    match &field.field_type {
167        ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
168            "Some(crate::ValueHint::Decimal)".to_string()
169        }
170        ResolvedFieldType::Scalar(ScalarType::DateTime) => {
171            "Some(crate::ValueHint::DateTime)".to_string()
172        }
173        ResolvedFieldType::Scalar(ScalarType::Json | ScalarType::Jsonb) => {
174            "Some(crate::ValueHint::Json)".to_string()
175        }
176        ResolvedFieldType::Scalar(ScalarType::Uuid) => "Some(crate::ValueHint::Uuid)".to_string(),
177        ResolvedFieldType::CompositeType { .. }
178            if field.storage_strategy == Some(StorageStrategy::Json) =>
179        {
180            "Some(crate::ValueHint::Json)".to_string()
181        }
182        _ => "None".to_string(),
183    }
184}
185
186/// Generate complete code for a model (struct, impls, delegate, builders).
187///
188/// `is_async` determines whether the generated delegate methods and internal
189/// builders use `async fn`/`.await` (`true`) or blocking sync wrappers (`false`).
190pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
191    let mut context = Context::new();
192
193    context.insert("model_name", &model.logical_name);
194    context.insert("table_name", &model.db_name);
195    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
196    context.insert("columns_name", &format!("{}Columns", model.logical_name));
197    context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
198    context.insert("create_name", &format!("{}Create", model.logical_name));
199    context.insert(
200        "create_many_name",
201        &format!("{}CreateMany", model.logical_name),
202    );
203    context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
204    context.insert("update_name", &format!("{}Update", model.logical_name));
205    context.insert("delete_name", &format!("{}Delete", model.logical_name));
206
207    let pk_field_names = model.primary_key.fields();
208    context.insert("primary_key_fields", &pk_field_names);
209
210    let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
211        .iter()
212        .filter_map(|logical| {
213            model
214                .scalar_fields()
215                .find(|f| f.logical_name.as_str() == *logical)
216                .map(|f| PkFieldContext {
217                    name: f.logical_name.to_snake_case(),
218                    db_name: f.db_name.clone(),
219                })
220        })
221        .collect();
222    context.insert("pk_fields_with_db", &pk_fields_with_db);
223
224    let mut enum_imports = HashSet::new();
225    let mut composite_type_imports = HashSet::new();
226
227    let mut scalar_fields: Vec<FieldContext> = Vec::new();
228    let mut create_fields: Vec<FieldContext> = Vec::new();
229    let mut updated_at_fields: Vec<FieldContext> = Vec::new();
230    let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
231    let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
232
233    for (idx, field) in model.scalar_fields().enumerate() {
234        match &field.field_type {
235            ResolvedFieldType::Enum { enum_name } => {
236                if ir.enums.contains_key(enum_name) {
237                    enum_imports.insert(enum_name.clone());
238                }
239            }
240            ResolvedFieldType::CompositeType { type_name } => {
241                if ir.composite_types.contains_key(type_name) {
242                    composite_type_imports.insert(type_name.clone());
243                }
244            }
245            _ => {}
246        }
247
248        let column_type = match &field.field_type {
249            ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
250            ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
251            _ => String::new(),
252        };
253        let is_pk = pk_field_names.contains(&field.logical_name.as_str());
254        let base_rust_type = field_to_rust_base_type(field);
255
256        let field_ctx = FieldContext {
257            name: field.logical_name.to_snake_case(),
258            logical_name: field.logical_name.clone(),
259            db_name: field.db_name.clone(),
260            rust_type: field_to_rust_type(field),
261            base_rust_type: base_rust_type.clone(),
262            column_type,
263            read_hint_expr: field_read_hint_expr(field),
264            variant_name: field.logical_name.to_pascal_case(),
265            is_array: field.is_array,
266            index: idx,
267            is_pk,
268            is_optional: !field.is_required && !field.is_array,
269            is_updated_at: field.is_updated_at,
270            is_computed: field.computed.is_some(),
271        };
272
273        create_fields.push(field_ctx.clone());
274
275        if field.is_updated_at {
276            updated_at_fields.push(field_ctx.clone());
277        }
278
279        scalar_fields.push(field_ctx);
280
281        let is_numeric = matches!(
282            &field.field_type,
283            ResolvedFieldType::Scalar(ScalarType::Int)
284                | ResolvedFieldType::Scalar(ScalarType::BigInt)
285                | ResolvedFieldType::Scalar(ScalarType::Float)
286                | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
287        );
288        if is_numeric {
289            numeric_fields.push(AggregateFieldContext {
290                name: field.logical_name.to_snake_case(),
291                logical_name: field.logical_name.clone(),
292                rust_type: base_rust_type.clone(),
293                avg_rust_type: field_to_rust_avg_type(field),
294                sum_rust_type: field_to_rust_sum_type(field),
295                variant_name: field.logical_name.to_pascal_case(),
296            });
297        }
298
299        let is_non_orderable = matches!(
300            &field.field_type,
301            ResolvedFieldType::Scalar(ScalarType::Boolean)
302                | ResolvedFieldType::Scalar(ScalarType::Json)
303                | ResolvedFieldType::Scalar(ScalarType::Bytes)
304        );
305        if !is_non_orderable {
306            orderable_fields.push(AggregateFieldContext {
307                name: field.logical_name.to_snake_case(),
308                logical_name: field.logical_name.clone(),
309                rust_type: base_rust_type,
310                avg_rust_type: String::new(),
311                sum_rust_type: String::new(),
312                variant_name: field.logical_name.to_pascal_case(),
313            });
314        }
315    }
316
317    let mut relation_imports = HashSet::new();
318    for field in model.relation_fields() {
319        if let ResolvedFieldType::Relation(rel) = &field.field_type {
320            relation_imports.insert(rel.target_model.clone());
321        }
322    }
323
324    context.insert("has_enums", &!enum_imports.is_empty());
325    context.insert(
326        "enum_imports",
327        &enum_imports.into_iter().collect::<Vec<_>>(),
328    );
329    context.insert("has_relations", &!relation_imports.is_empty());
330    context.insert(
331        "relation_imports",
332        &relation_imports.into_iter().collect::<Vec<_>>(),
333    );
334    context.insert("has_composite_types", &!composite_type_imports.is_empty());
335    context.insert(
336        "composite_type_imports",
337        &composite_type_imports.into_iter().collect::<Vec<_>>(),
338    );
339
340    let relation_fields: Vec<FieldContext> = model
341        .relation_fields()
342        .map(|field| FieldContext {
343            name: field.logical_name.to_snake_case(),
344            logical_name: field.logical_name.clone(),
345            db_name: field.db_name.clone(),
346            rust_type: field_to_rust_type(field),
347            base_rust_type: field_to_rust_base_type(field),
348            column_type: String::new(),
349            read_hint_expr: "None".to_string(),
350            variant_name: field.logical_name.to_pascal_case(),
351            is_array: field.is_array,
352            index: 0,
353            is_pk: false,
354            is_optional: true,
355            is_updated_at: false,
356            is_computed: false,
357        })
358        .collect();
359
360    let relations: Vec<RelationContext> = model
361        .relation_fields()
362        .filter_map(|field| {
363            if let ResolvedFieldType::Relation(rel) = &field.field_type {
364                if let Some(target_model) = ir.models.get(&rel.target_model) {
365                    let target_pk_names = target_model.primary_key.fields();
366                    let target_scalar_fields: Vec<FieldContext> = target_model
367                        .scalar_fields()
368                        .enumerate()
369                        .map(|(idx, f)| {
370                            let column_type = match &f.field_type {
371                                ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
372                                ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
373                                _ => String::new(),
374                            };
375                            let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
376                            FieldContext {
377                                name: f.logical_name.to_snake_case(),
378                                logical_name: f.logical_name.clone(),
379                                db_name: f.db_name.clone(),
380                                rust_type: field_to_rust_type(f),
381                                base_rust_type: field_to_rust_base_type(f),
382                                column_type,
383                                read_hint_expr: field_read_hint_expr(f),
384                                variant_name: f.logical_name.to_pascal_case(),
385                                is_array: f.is_array,
386                                index: idx,
387                                is_pk: f_is_pk,
388                                is_optional: !f.is_required && !f.is_array,
389                                is_updated_at: f.is_updated_at,
390                                is_computed: f.computed.is_some(),
391                            }
392                        })
393                        .collect();
394
395                    let (fields, references) = if rel.fields.is_empty() {
396                        resolve_inverse_relation_fields(
397                            &model.logical_name,
398                            rel.name.as_deref(),
399                            target_model,
400                        )
401                    } else {
402                        (rel.fields.clone(), rel.references.clone())
403                    };
404
405                    let fields_db: Vec<String> = fields
406                        .iter()
407                        .filter_map(|logical_name| {
408                            model
409                                .fields
410                                .iter()
411                                .find(|f| &f.logical_name == logical_name)
412                                .map(|f| f.db_name.clone())
413                        })
414                        .collect();
415
416                    let references_db: Vec<String> = references
417                        .iter()
418                        .filter_map(|logical_name| {
419                            target_model
420                                .fields
421                                .iter()
422                                .find(|f| &f.logical_name == logical_name)
423                                .map(|f| f.db_name.clone())
424                        })
425                        .collect();
426
427                    Some(RelationContext {
428                        field_name: field.logical_name.to_snake_case(),
429                        target_model: rel.target_model.clone(),
430                        target_table: target_model.db_name.clone(),
431                        is_array: field.is_array,
432                        fields,
433                        references,
434                        fields_db,
435                        references_db,
436                        target_scalar_fields,
437                    })
438                } else {
439                    None
440                }
441            } else {
442                None
443            }
444        })
445        .collect();
446
447    context.insert("scalar_fields", &scalar_fields);
448    context.insert("relation_fields", &relation_fields);
449    context.insert("relations", &relations);
450    context.insert("create_fields", &create_fields);
451    context.insert("updated_at_fields", &updated_at_fields);
452    context.insert("all_scalar_fields", &scalar_fields);
453    context.insert("numeric_fields", &numeric_fields);
454    context.insert("orderable_fields", &orderable_fields);
455    context.insert("has_numeric_fields", &!numeric_fields.is_empty());
456    context.insert("has_orderable_fields", &!orderable_fields.is_empty());
457    context.insert("is_async", &is_async);
458
459    render("model_file.tera", &context)
460}
461
462/// Generate all models from a schema IR.
463///
464/// `is_async` is forwarded to every [`generate_model`] call.
465pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
466    let mut generated = HashMap::new();
467
468    for (model_name, model_ir) in &ir.models {
469        let code = generate_model(model_ir, ir, is_async);
470        generated.insert(model_name.clone(), code);
471    }
472
473    generated
474}