Skip to main content

nautilus_codegen/
generator.rs

1//! Code generator for Nautilus models, delegates, and builders.
2
3use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ast::StorageStrategy;
5use nautilus_schema::ir::{FieldIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr};
6use serde::Serialize;
7use std::collections::{HashMap, HashSet};
8use tera::{Context, Tera};
9
10use crate::type_helpers::{
11    field_to_rust_avg_type, field_to_rust_base_type, field_to_rust_sum_type, field_to_rust_type,
12};
13
14pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
15    let mut tera = Tera::default();
16    tera.add_raw_templates(vec![
17        (
18            "columns_struct.tera",
19            include_str!("../templates/rust/columns_struct.tera"),
20        ),
21        (
22            "column_impl.tera",
23            include_str!("../templates/rust/column_impl.tera"),
24        ),
25        ("create.tera", include_str!("../templates/rust/create.tera")),
26        (
27            "create_many.tera",
28            include_str!("../templates/rust/create_many.tera"),
29        ),
30        (
31            "delegate.tera",
32            include_str!("../templates/rust/delegate.tera"),
33        ),
34        ("delete.tera", include_str!("../templates/rust/delete.tera")),
35        ("enum.tera", include_str!("../templates/rust/enum.tera")),
36        (
37            "find_many.tera",
38            include_str!("../templates/rust/find_many.tera"),
39        ),
40        (
41            "from_row_impl.tera",
42            include_str!("../templates/rust/from_row_impl.tera"),
43        ),
44        (
45            "model_file.tera",
46            include_str!("../templates/rust/model_file.tera"),
47        ),
48        ("lib_rs.tera", include_str!("../templates/rust/lib_rs.tera")),
49        (
50            "model_struct.tera",
51            include_str!("../templates/rust/model_struct.tera"),
52        ),
53        ("update.tera", include_str!("../templates/rust/update.tera")),
54        (
55            "composite_type.tera",
56            include_str!("../templates/rust/composite_type.tera"),
57        ),
58    ])
59    .expect("embedded Rust templates must parse");
60    tera
61});
62
63fn render(template: &str, ctx: &Context) -> String {
64    TEMPLATES
65        .render(template, ctx)
66        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
67}
68
/// Template context for a single model field in the Rust codegen backend.
///
/// This struct is intentionally separate from `PythonFieldContext` in
/// `python/generator.rs`: the two backends expose different template
/// variables (Rust needs `rust_type` / `column_type`; Python needs
/// `python_type` / `base_type` / `is_enum` / `has_default` / `default`) and
/// are expected to evolve independently.
///
/// Instances are serialised into the Tera context, so every field name below
/// is a template variable name; renaming one requires updating the templates.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Snake-case form of the schema field name.
    name: String,
    /// Field name exactly as declared in the schema.
    logical_name: String,
    /// Database column name.
    db_name: String,
    /// Rust type string returned by `field_to_rust_type` for this field.
    rust_type: String,
    /// Rust type string returned by `field_to_rust_base_type` for this field.
    base_rust_type: String,
    /// `rust_type()` of the scalar, or the enum name for enum fields; empty
    /// for every other field kind (e.g. composite types, relations).
    column_type: String,
    /// Rust source expression, either `None` or `Some(crate::ValueHint::…)`;
    /// presumably emitted verbatim by the templates.
    read_hint_expr: String,
    /// PascalCase form of the schema field name.
    variant_name: String,
    /// Whether the schema field is a list.
    is_array: bool,
    /// 0-based position among the model's scalar fields; always `0` for
    /// relation-field contexts.
    index: usize,
    /// Whether the field is part of the model's primary key; always `false`
    /// for relation-field contexts.
    is_pk: bool,
    /// For scalar fields: `true` when the field is neither required nor an
    /// array. Relation-field contexts always set this to `true`.
    is_optional: bool,
    /// `true` when the field has `@updatedAt` — auto-defaults to `now()` if not provided.
    is_updated_at: bool,
    /// `true` when the field is a `@computed` generated column (read-only from client side).
    is_computed: bool,
}
97
/// Template context for a field listed in the aggregate helpers.
///
/// The same shape is used for the `numeric_fields` list, where all members
/// are populated, and for the `orderable_fields` list, where `avg_rust_type`
/// and `sum_rust_type` are left empty.
#[derive(Debug, Clone, Serialize)]
struct AggregateFieldContext {
    /// Snake-case form of the schema field name.
    name: String,
    /// Field name exactly as declared in the schema.
    logical_name: String,
    /// Base Rust type of the field (from `field_to_rust_base_type`).
    rust_type: String,
    /// Result type of an average over this field (from `field_to_rust_avg_type`);
    /// empty for orderable-only entries.
    avg_rust_type: String,
    /// Result type of a sum over this field (from `field_to_rust_sum_type`);
    /// empty for orderable-only entries.
    sum_rust_type: String,
    /// PascalCase form of the schema field name.
    variant_name: String,
}
107
/// Serialisable (logical_name, db_name) pair for primary-key fields.
/// Used in templates to generate cursor predicate slices.
///
/// Only primary-key names that resolve to one of the model's scalar fields
/// produce an entry; unresolved names are skipped by the caller.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Snake-case logical name — used as the cursor map key in generated code.
    name: String,
    /// Database column name — used to build the `table__db_col` column reference.
    db_name: String,
}
117
/// Template context for one relation field of a model.
///
/// `fields` / `references` are logical names of the join columns on this
/// model and on the target model respectively. When the relation field
/// declares no `fields` of its own, they are taken (swapped) from the inverse
/// relation on the target model.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Snake-case name of the relation field.
    field_name: String,
    /// Logical name of the related model.
    target_model: String,
    /// Database table name of the related model.
    target_table: String,
    /// Whether the relation field is a list.
    is_array: bool,
    /// Logical names of the join columns on this model.
    fields: Vec<String>,
    /// Logical names of the join columns on the target model.
    references: Vec<String>,
    /// Database column names for `fields`. Names that do not resolve to a
    /// field are skipped, so this may be shorter than `fields`.
    fields_db: Vec<String>,
    /// Database column names for `references`. Names that do not resolve to a
    /// field are skipped, so this may be shorter than `references`.
    references_db: Vec<String>,
    /// Contexts for every scalar field of the target model, in declaration order.
    target_scalar_fields: Vec<FieldContext>,
}
130
131fn resolve_inverse_relation_fields(
132    source_model_name: &str,
133    relation_name: Option<&str>,
134    target_model: &ModelIr,
135) -> (Vec<String>, Vec<String>) {
136    let inverse = target_model.relation_fields().find(|field| {
137        if let ResolvedFieldType::Relation(inv_rel) = &field.field_type {
138            if inv_rel.target_model != source_model_name {
139                return false;
140            }
141
142            match (relation_name, inv_rel.name.as_deref()) {
143                (Some(expected), Some(actual)) => actual == expected,
144                (Some(_), None) => false,
145                (None, Some(_)) => false,
146                (None, None) => true,
147            }
148        } else {
149            false
150        }
151    });
152
153    let Some(inverse_field) = inverse else {
154        return (vec![], vec![]);
155    };
156    let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type else {
157        return (vec![], vec![]);
158    };
159
160    (inv_rel.references.clone(), inv_rel.fields.clone())
161}
162
163fn field_read_hint_expr(field: &FieldIr) -> String {
164    if field.is_array && field.storage_strategy == Some(StorageStrategy::Json) {
165        return "Some(crate::ValueHint::Json)".to_string();
166    }
167
168    match &field.field_type {
169        ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
170            "Some(crate::ValueHint::Decimal)".to_string()
171        }
172        ResolvedFieldType::Scalar(ScalarType::DateTime) => {
173            "Some(crate::ValueHint::DateTime)".to_string()
174        }
175        ResolvedFieldType::Scalar(ScalarType::Json | ScalarType::Jsonb) => {
176            "Some(crate::ValueHint::Json)".to_string()
177        }
178        ResolvedFieldType::Scalar(ScalarType::Uuid) => "Some(crate::ValueHint::Uuid)".to_string(),
179        ResolvedFieldType::CompositeType { .. }
180            if field.storage_strategy == Some(StorageStrategy::Json) =>
181        {
182            "Some(crate::ValueHint::Json)".to_string()
183        }
184        _ => "None".to_string(),
185    }
186}
187
188/// Generate complete code for a model (struct, impls, delegate, builders).
189///
190/// `is_async` determines whether the generated delegate methods and internal
191/// builders use `async fn`/`.await` (`true`) or blocking sync wrappers (`false`).
192pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
193    let mut context = Context::new();
194
195    context.insert("model_name", &model.logical_name);
196    context.insert("table_name", &model.db_name);
197    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
198    context.insert("columns_name", &format!("{}Columns", model.logical_name));
199    context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
200    context.insert("create_name", &format!("{}Create", model.logical_name));
201    context.insert(
202        "create_many_name",
203        &format!("{}CreateMany", model.logical_name),
204    );
205    context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
206    context.insert("update_name", &format!("{}Update", model.logical_name));
207    context.insert("delete_name", &format!("{}Delete", model.logical_name));
208
209    let pk_field_names = model.primary_key.fields();
210    context.insert("primary_key_fields", &pk_field_names);
211
212    let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
213        .iter()
214        .filter_map(|logical| {
215            model
216                .scalar_fields()
217                .find(|f| f.logical_name.as_str() == *logical)
218                .map(|f| PkFieldContext {
219                    name: f.logical_name.to_snake_case(),
220                    db_name: f.db_name.clone(),
221                })
222        })
223        .collect();
224    context.insert("pk_fields_with_db", &pk_fields_with_db);
225
226    let mut enum_imports = HashSet::new();
227    let mut composite_type_imports = HashSet::new();
228
229    let mut scalar_fields: Vec<FieldContext> = Vec::new();
230    let mut create_fields: Vec<FieldContext> = Vec::new();
231    let mut updated_at_fields: Vec<FieldContext> = Vec::new();
232    let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
233    let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
234
235    for (idx, field) in model.scalar_fields().enumerate() {
236        match &field.field_type {
237            ResolvedFieldType::Enum { enum_name } => {
238                if ir.enums.contains_key(enum_name) {
239                    enum_imports.insert(enum_name.clone());
240                }
241            }
242            ResolvedFieldType::CompositeType { type_name } => {
243                if ir.composite_types.contains_key(type_name) {
244                    composite_type_imports.insert(type_name.clone());
245                }
246            }
247            _ => {}
248        }
249
250        let column_type = match &field.field_type {
251            ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
252            ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
253            _ => String::new(),
254        };
255        let is_pk = pk_field_names.contains(&field.logical_name.as_str());
256        let base_rust_type = field_to_rust_base_type(field);
257
258        let field_ctx = FieldContext {
259            name: field.logical_name.to_snake_case(),
260            logical_name: field.logical_name.clone(),
261            db_name: field.db_name.clone(),
262            rust_type: field_to_rust_type(field),
263            base_rust_type: base_rust_type.clone(),
264            column_type,
265            read_hint_expr: field_read_hint_expr(field),
266            variant_name: field.logical_name.to_pascal_case(),
267            is_array: field.is_array,
268            index: idx,
269            is_pk,
270            is_optional: !field.is_required && !field.is_array,
271            is_updated_at: field.is_updated_at,
272            is_computed: field.computed.is_some(),
273        };
274
275        create_fields.push(field_ctx.clone());
276
277        if field.is_updated_at {
278            updated_at_fields.push(field_ctx.clone());
279        }
280
281        scalar_fields.push(field_ctx);
282
283        let is_numeric = matches!(
284            &field.field_type,
285            ResolvedFieldType::Scalar(ScalarType::Int)
286                | ResolvedFieldType::Scalar(ScalarType::BigInt)
287                | ResolvedFieldType::Scalar(ScalarType::Float)
288                | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
289        );
290        if is_numeric {
291            numeric_fields.push(AggregateFieldContext {
292                name: field.logical_name.to_snake_case(),
293                logical_name: field.logical_name.clone(),
294                rust_type: base_rust_type.clone(),
295                avg_rust_type: field_to_rust_avg_type(field),
296                sum_rust_type: field_to_rust_sum_type(field),
297                variant_name: field.logical_name.to_pascal_case(),
298            });
299        }
300
301        let is_non_orderable = matches!(
302            &field.field_type,
303            ResolvedFieldType::Scalar(ScalarType::Boolean)
304                | ResolvedFieldType::Scalar(ScalarType::Json)
305                | ResolvedFieldType::Scalar(ScalarType::Bytes)
306        );
307        if !is_non_orderable {
308            orderable_fields.push(AggregateFieldContext {
309                name: field.logical_name.to_snake_case(),
310                logical_name: field.logical_name.clone(),
311                rust_type: base_rust_type,
312                avg_rust_type: String::new(),
313                sum_rust_type: String::new(),
314                variant_name: field.logical_name.to_pascal_case(),
315            });
316        }
317    }
318
319    let mut relation_imports = HashSet::new();
320    for field in model.relation_fields() {
321        if let ResolvedFieldType::Relation(rel) = &field.field_type {
322            relation_imports.insert(rel.target_model.clone());
323        }
324    }
325
326    context.insert("has_enums", &!enum_imports.is_empty());
327    context.insert(
328        "enum_imports",
329        &enum_imports.into_iter().collect::<Vec<_>>(),
330    );
331    context.insert("has_relations", &!relation_imports.is_empty());
332    context.insert(
333        "relation_imports",
334        &relation_imports.into_iter().collect::<Vec<_>>(),
335    );
336    context.insert("has_composite_types", &!composite_type_imports.is_empty());
337    context.insert(
338        "composite_type_imports",
339        &composite_type_imports.into_iter().collect::<Vec<_>>(),
340    );
341
342    let relation_fields: Vec<FieldContext> = model
343        .relation_fields()
344        .map(|field| FieldContext {
345            name: field.logical_name.to_snake_case(),
346            logical_name: field.logical_name.clone(),
347            db_name: field.db_name.clone(),
348            rust_type: field_to_rust_type(field),
349            base_rust_type: field_to_rust_base_type(field),
350            column_type: String::new(),
351            read_hint_expr: "None".to_string(),
352            variant_name: field.logical_name.to_pascal_case(),
353            is_array: field.is_array,
354            index: 0,
355            is_pk: false,
356            is_optional: true,
357            is_updated_at: false,
358            is_computed: false,
359        })
360        .collect();
361
362    let relations: Vec<RelationContext> = model
363        .relation_fields()
364        .filter_map(|field| {
365            let ResolvedFieldType::Relation(rel) = &field.field_type else {
366                return None;
367            };
368            let target_model = ir.models.get(&rel.target_model)?;
369
370            let target_pk_names = target_model.primary_key.fields();
371            let target_scalar_fields: Vec<FieldContext> = target_model
372                .scalar_fields()
373                .enumerate()
374                .map(|(idx, f)| {
375                    let column_type = match &f.field_type {
376                        ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
377                        ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
378                        _ => String::new(),
379                    };
380                    let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
381                    FieldContext {
382                        name: f.logical_name.to_snake_case(),
383                        logical_name: f.logical_name.clone(),
384                        db_name: f.db_name.clone(),
385                        rust_type: field_to_rust_type(f),
386                        base_rust_type: field_to_rust_base_type(f),
387                        column_type,
388                        read_hint_expr: field_read_hint_expr(f),
389                        variant_name: f.logical_name.to_pascal_case(),
390                        is_array: f.is_array,
391                        index: idx,
392                        is_pk: f_is_pk,
393                        is_optional: !f.is_required && !f.is_array,
394                        is_updated_at: f.is_updated_at,
395                        is_computed: f.computed.is_some(),
396                    }
397                })
398                .collect();
399
400            let (fields, references) = if rel.fields.is_empty() {
401                resolve_inverse_relation_fields(
402                    &model.logical_name,
403                    rel.name.as_deref(),
404                    target_model,
405                )
406            } else {
407                (rel.fields.clone(), rel.references.clone())
408            };
409
410            let fields_db: Vec<String> = fields
411                .iter()
412                .filter_map(|logical_name| {
413                    model
414                        .fields
415                        .iter()
416                        .find(|f| &f.logical_name == logical_name)
417                        .map(|f| f.db_name.clone())
418                })
419                .collect();
420
421            let references_db: Vec<String> = references
422                .iter()
423                .filter_map(|logical_name| {
424                    target_model
425                        .fields
426                        .iter()
427                        .find(|f| &f.logical_name == logical_name)
428                        .map(|f| f.db_name.clone())
429                })
430                .collect();
431
432            Some(RelationContext {
433                field_name: field.logical_name.to_snake_case(),
434                target_model: rel.target_model.clone(),
435                target_table: target_model.db_name.clone(),
436                is_array: field.is_array,
437                fields,
438                references,
439                fields_db,
440                references_db,
441                target_scalar_fields,
442            })
443        })
444        .collect();
445
446    context.insert("scalar_fields", &scalar_fields);
447    context.insert("relation_fields", &relation_fields);
448    context.insert("relations", &relations);
449    context.insert("create_fields", &create_fields);
450    context.insert("updated_at_fields", &updated_at_fields);
451    context.insert("all_scalar_fields", &scalar_fields);
452    context.insert("numeric_fields", &numeric_fields);
453    context.insert("orderable_fields", &orderable_fields);
454    context.insert("has_numeric_fields", &!numeric_fields.is_empty());
455    context.insert("has_orderable_fields", &!orderable_fields.is_empty());
456    context.insert("is_async", &is_async);
457
458    render("model_file.tera", &context)
459}
460
461/// Generate all models from a schema IR.
462///
463/// `is_async` is forwarded to every [`generate_model`] call.
464pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
465    let mut generated = HashMap::new();
466
467    for (model_name, model_ir) in &ir.models {
468        let code = generate_model(model_ir, ir, is_async);
469        generated.insert(model_name.clone(), code);
470    }
471
472    generated
473}