Skip to main content

nautilus_codegen/js/
generator.rs

1//! JavaScript/TypeScript code generator for Nautilus models, delegates, and input types.
2
3use heck::{ToLowerCamelCase, ToSnakeCase};
4use nautilus_schema::ir::{
5    CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr,
6};
7use serde::Serialize;
8use std::collections::{HashMap, HashSet};
9use tera::{Context, Tera};
10
11use crate::js::type_mapper::{
12    field_to_ts_type, get_base_ts_type, get_filter_operators_for_field, get_ts_default_value,
13    is_auto_generated, scalar_to_ts_type,
14};
15
16/// JS/TS template registry — loaded once at first use.
17pub static JS_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
18    let mut tera = Tera::default();
19    tera.add_raw_templates(vec![
20        (
21            "model.js.tera",
22            include_str!("../../templates/js/model.js.tera"),
23        ),
24        (
25            "model.d.ts.tera",
26            include_str!("../../templates/js/model.d.ts.tera"),
27        ),
28        (
29            "enums.js.tera",
30            include_str!("../../templates/js/enums.js.tera"),
31        ),
32        (
33            "enums.d.ts.tera",
34            include_str!("../../templates/js/enums.d.ts.tera"),
35        ),
36        (
37            "client.js.tera",
38            include_str!("../../templates/js/client.js.tera"),
39        ),
40        (
41            "client.d.ts.tera",
42            include_str!("../../templates/js/client.d.ts.tera"),
43        ),
44        (
45            "models_index.js.tera",
46            include_str!("../../templates/js/models_index.js.tera"),
47        ),
48        (
49            "models_index.d.ts.tera",
50            include_str!("../../templates/js/models_index.d.ts.tera"),
51        ),
52        (
53            "composite_types.d.ts.tera",
54            include_str!("../../templates/js/composite_types.d.ts.tera"),
55        ),
56    ])
57    .expect("embedded JS templates must parse");
58    tera
59});
60
61fn render(template: &str, ctx: &Context) -> String {
62    JS_TEMPLATES
63        .render(template, ctx)
64        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
65}
66
67#[derive(Debug, Clone, Serialize)]
68struct JsFieldContext {
69    /// Logical JS field name (camelCase, same as schema logical name).
70    name: String,
71    /// Logical name from the schema IR (may differ from `name` after `@map`).
72    logical_name: String,
73    /// Database column name.
74    db_name: String,
75    /// Full TypeScript type, e.g. `string | null`, `number[]`.
76    ts_type: String,
77    /// Inner base type without wrappers, e.g. `string`, `number`, `Date`.
78    base_type: String,
79    is_optional: bool,
80    is_array: bool,
81    is_enum: bool,
82    has_default: bool,
83    default: String,
84    is_pk: bool,
85    index: usize,
86}
87
88#[derive(Debug, Clone, Serialize)]
89struct JsFilterOperatorContext {
90    suffix: String,
91    ts_type: String,
92}
93
94#[derive(Debug, Clone, Serialize)]
95struct JsWhereInputFieldContext {
96    name: String,
97    /// Base TS type used by the template to pick the right filter interface.
98    base_type: String,
99    ts_type: String,
100    operators: Vec<JsFilterOperatorContext>,
101}
102
103#[derive(Debug, Clone, Serialize)]
104struct JsCreateInputFieldContext {
105    name: String,
106    ts_type: String,
107    is_required: bool,
108}
109
110#[derive(Debug, Clone, Serialize)]
111struct JsUpdateInputFieldContext {
112    name: String,
113    ts_type: String,
114}
115
116#[derive(Debug, Clone, Serialize)]
117struct JsOrderByFieldContext {
118    name: String,
119}
120
121#[derive(Debug, Clone, Serialize)]
122struct JsIncludeFieldContext {
123    name: String,
124    target_model: String,
125    target_snake: String,
126    /// camelCase — property name on the generated Nautilus class.
127    target_camel: String,
128    is_array: bool,
129}
130
131#[derive(Debug, Clone, Serialize)]
132struct JsAggregateFieldContext {
133    name: String,
134    ts_type: String,
135}
136
137/// Generate JavaScript + declaration code for a single model.
138///
139/// Returns `((js_filename, js_code), (dts_filename, dts_code))`.
140pub fn generate_js_model(model: &ModelIr, ir: &SchemaIr) -> ((String, String), (String, String)) {
141    let mut context = Context::new();
142
143    context.insert("model_name", &model.logical_name);
144    context.insert("snake_name", &model.logical_name.to_snake_case());
145    context.insert("table_name", &model.db_name);
146    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
147
148    let pk_field_names = model.primary_key.fields();
149    context.insert("primary_key_fields", &pk_field_names);
150
151    let mut enum_imports: HashSet<String> = HashSet::new();
152    let mut composite_type_imports: HashSet<String> = HashSet::new();
153
154    let mut scalar_fields: Vec<JsFieldContext> = Vec::new();
155    let mut where_input_fields: Vec<JsWhereInputFieldContext> = Vec::new();
156    let mut create_input_fields: Vec<JsCreateInputFieldContext> = Vec::new();
157    let mut update_input_fields: Vec<JsUpdateInputFieldContext> = Vec::new();
158    let mut order_by_fields: Vec<JsOrderByFieldContext> = Vec::new();
159    let mut numeric_fields: Vec<JsAggregateFieldContext> = Vec::new();
160    let mut orderable_fields: Vec<JsAggregateFieldContext> = Vec::new();
161
162    for (idx, field) in model.scalar_fields().enumerate() {
163        match &field.field_type {
164            ResolvedFieldType::Enum { enum_name } => {
165                if ir.enums.contains_key(enum_name) {
166                    enum_imports.insert(enum_name.clone());
167                }
168            }
169            ResolvedFieldType::CompositeType { type_name } => {
170                if ir.composite_types.contains_key(type_name) {
171                    composite_type_imports.insert(type_name.clone());
172                }
173            }
174            _ => {}
175        }
176
177        let ts_type = field_to_ts_type(field, &ir.enums);
178        let base_type = get_base_ts_type(field, &ir.enums);
179        let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
180        let auto_generated = is_auto_generated(field);
181        let default_val = get_ts_default_value(field);
182        let is_pk = pk_field_names.contains(&field.logical_name.as_str());
183
184        scalar_fields.push(JsFieldContext {
185            name: field.logical_name.clone(),
186            logical_name: field.logical_name.clone(),
187            db_name: field.db_name.clone(),
188            ts_type: ts_type.clone(),
189            base_type: base_type.clone(),
190            is_optional: !field.is_required,
191            is_array: field.is_array,
192            is_enum,
193            has_default: default_val.is_some(),
194            default: default_val.unwrap_or_default(),
195            is_pk,
196            index: idx,
197        });
198
199        if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
200            let operators = get_filter_operators_for_field(field, &ir.enums);
201            where_input_fields.push(JsWhereInputFieldContext {
202                name: field.logical_name.clone(),
203                base_type: base_type.clone(),
204                ts_type: ts_type.clone(),
205                operators: operators
206                    .into_iter()
207                    .map(|op| JsFilterOperatorContext {
208                        suffix: op.suffix,
209                        ts_type: op.type_name,
210                    })
211                    .collect(),
212            });
213        }
214
215        if !auto_generated {
216            let input_base = base_type.clone();
217            let typed = if field.is_array {
218                format!("{}[]", input_base)
219            } else {
220                input_base
221            };
222            create_input_fields.push(JsCreateInputFieldContext {
223                name: field.logical_name.clone(),
224                ts_type: typed,
225                is_required: field.is_required
226                    && field.default_value.is_none()
227                    && !field.is_updated_at,
228            });
229        }
230
231        let is_auto_pk = auto_generated
232            && pk_field_names.contains(&field.logical_name.as_str())
233            && matches!(
234                field.field_type,
235                ResolvedFieldType::Scalar(ScalarType::Int)
236                    | ResolvedFieldType::Scalar(ScalarType::BigInt)
237            );
238        if !is_auto_pk {
239            let input_base = base_type.clone();
240            let typed = if field.is_array {
241                format!("{}[]", input_base)
242            } else {
243                format!("{} | null", input_base)
244            };
245            update_input_fields.push(JsUpdateInputFieldContext {
246                name: field.logical_name.clone(),
247                ts_type: typed,
248            });
249        }
250
251        order_by_fields.push(JsOrderByFieldContext {
252            name: field.logical_name.clone(),
253        });
254
255        let is_numeric = matches!(
256            &field.field_type,
257            ResolvedFieldType::Scalar(ScalarType::Int)
258                | ResolvedFieldType::Scalar(ScalarType::BigInt)
259                | ResolvedFieldType::Scalar(ScalarType::Float)
260                | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
261        );
262        if is_numeric {
263            let agg_type = if let ResolvedFieldType::Scalar(s) = &field.field_type {
264                scalar_to_ts_type(s).to_string()
265            } else {
266                unreachable!()
267            };
268            numeric_fields.push(JsAggregateFieldContext {
269                name: field.logical_name.clone(),
270                ts_type: agg_type,
271            });
272        }
273
274        let is_non_orderable = matches!(
275            &field.field_type,
276            ResolvedFieldType::Scalar(ScalarType::Boolean)
277                | ResolvedFieldType::Scalar(ScalarType::Json)
278                | ResolvedFieldType::Scalar(ScalarType::Bytes)
279        );
280        if !is_non_orderable {
281            orderable_fields.push(JsAggregateFieldContext {
282                name: field.logical_name.clone(),
283                ts_type: base_type,
284            });
285        }
286    }
287
288    let relation_fields: Vec<JsFieldContext> = model
289        .relation_fields()
290        .enumerate()
291        .map(|(idx, field)| {
292            let ts_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
293                if field.is_array {
294                    format!("{}Model[]", rel.target_model)
295                } else {
296                    format!("{}Model | null", rel.target_model)
297                }
298            } else {
299                "unknown".to_string()
300            };
301            let base_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
302                format!("{}Model", rel.target_model)
303            } else {
304                "unknown".to_string()
305            };
306
307            JsFieldContext {
308                name: field.logical_name.clone(),
309                logical_name: field.logical_name.clone(),
310                db_name: field.db_name.clone(),
311                ts_type,
312                base_type,
313                is_optional: true,
314                is_array: field.is_array,
315                is_enum: false,
316                has_default: true,
317                default: if field.is_array {
318                    "[]".to_string()
319                } else {
320                    "null".to_string()
321                },
322                is_pk: false,
323                index: idx,
324            }
325        })
326        .collect();
327
328    let include_fields: Vec<JsIncludeFieldContext> = model
329        .relation_fields()
330        .filter_map(|field| {
331            if let ResolvedFieldType::Relation(rel) = &field.field_type {
332                Some(JsIncludeFieldContext {
333                    name: field.logical_name.clone(),
334                    target_model: rel.target_model.clone(),
335                    target_snake: rel.target_model.to_snake_case(),
336                    target_camel: rel.target_model.to_lower_camel_case(),
337                    is_array: field.is_array,
338                })
339            } else {
340                None
341            }
342        })
343        .collect();
344
345    let has_numeric_fields = !numeric_fields.is_empty();
346    let has_includes = !include_fields.is_empty();
347    let has_enums = !enum_imports.is_empty();
348
349    context.insert("scalar_fields", &scalar_fields);
350    context.insert("relation_fields", &relation_fields);
351    context.insert("where_input_fields", &where_input_fields);
352    context.insert("create_input_fields", &create_input_fields);
353    context.insert("update_input_fields", &update_input_fields);
354    context.insert("order_by_fields", &order_by_fields);
355    context.insert("include_fields", &include_fields);
356    context.insert("has_includes", &has_includes);
357    context.insert("numeric_fields", &numeric_fields);
358    context.insert("orderable_fields", &orderable_fields);
359    context.insert("has_numeric_fields", &has_numeric_fields);
360    context.insert("has_enums", &has_enums);
361    context.insert(
362        "enum_imports",
363        &enum_imports.into_iter().collect::<Vec<_>>(),
364    );
365    context.insert("has_composite_types", &!composite_type_imports.is_empty());
366    context.insert(
367        "composite_type_imports",
368        &composite_type_imports.into_iter().collect::<Vec<_>>(),
369    );
370
371    let snake = model.logical_name.to_snake_case();
372    let js_code = render("model.js.tera", &context);
373    let dts_code = render("model.d.ts.tera", &context);
374
375    (
376        (format!("{}.js", snake), js_code),
377        (format!("{}.d.ts", snake), dts_code),
378    )
379}
380
381/// Generate JavaScript + declaration code for all models in the schema.
382///
383/// Returns `(js_models, dts_models)`, each sorted by filename.
384#[allow(clippy::type_complexity)]
385pub fn generate_all_js_models(ir: &SchemaIr) -> (Vec<(String, String)>, Vec<(String, String)>) {
386    let pairs: Vec<((String, String), (String, String))> = ir
387        .models
388        .values()
389        .map(|model| generate_js_model(model, ir))
390        .collect();
391
392    let mut js_models: Vec<(String, String)> = pairs.iter().map(|(js, _)| js.clone()).collect();
393    let mut dts_models: Vec<(String, String)> = pairs.iter().map(|(_, dts)| dts.clone()).collect();
394
395    js_models.sort_by(|a, b| a.0.cmp(&b.0));
396    dts_models.sort_by(|a, b| a.0.cmp(&b.0));
397
398    (js_models, dts_models)
399}
400
401/// Generate `types.d.ts` — TypeScript interfaces for all composite types.
402///
403/// Returns `None` when there are no composite types.
404pub fn generate_js_composite_types(
405    composite_types: &HashMap<String, CompositeTypeIr>,
406) -> Option<String> {
407    if composite_types.is_empty() {
408        return None;
409    }
410
411    #[derive(Serialize)]
412    struct CompositeFieldCtx {
413        name: String,
414        ts_type: String,
415    }
416
417    #[derive(Serialize)]
418    struct CompositeTypeCtx {
419        name: String,
420        fields: Vec<CompositeFieldCtx>,
421    }
422
423    let mut type_list: Vec<CompositeTypeCtx> = composite_types
424        .values()
425        .map(|ct| {
426            let fields = ct
427                .fields
428                .iter()
429                .map(|f| {
430                    let base = match &f.field_type {
431                        ResolvedFieldType::Scalar(s) => scalar_to_ts_type(s).to_string(),
432                        ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
433                        ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
434                        ResolvedFieldType::Relation(_) => "unknown".to_string(),
435                    };
436                    let ts_type = if f.is_array {
437                        format!("{}[]", base)
438                    } else if !f.is_required {
439                        format!("{} | null", base)
440                    } else {
441                        base
442                    };
443                    CompositeFieldCtx {
444                        name: f.logical_name.clone(),
445                        ts_type,
446                    }
447                })
448                .collect();
449            CompositeTypeCtx {
450                name: ct.logical_name.clone(),
451                fields,
452            }
453        })
454        .collect();
455    type_list.sort_by(|a, b| a.name.cmp(&b.name));
456
457    let mut context = Context::new();
458    context.insert("composite_types", &type_list);
459
460    Some(render("composite_types.d.ts.tera", &context))
461}
462
463/// Generate `enums.js` + `enums.d.ts` for all enum definitions.
464///
465/// Returns `(js_code, dts_code)`.
466pub fn generate_js_enums(enums: &HashMap<String, EnumIr>) -> (String, String) {
467    #[derive(Serialize)]
468    struct EnumCtx {
469        name: String,
470        variants: Vec<String>,
471    }
472
473    let mut enum_list: Vec<EnumCtx> = enums
474        .values()
475        .map(|e| EnumCtx {
476            name: e.logical_name.clone(),
477            variants: e.variants.clone(),
478        })
479        .collect();
480    enum_list.sort_by(|a, b| a.name.cmp(&b.name));
481
482    let mut context = Context::new();
483    context.insert("enums", &enum_list);
484    let js_code = render("enums.js.tera", &context);
485    let dts_code = render("enums.d.ts.tera", &context);
486    (js_code, dts_code)
487}
488
489/// Generate `index.js` + `index.d.ts` — the typed `Nautilus` class with model delegates.
490///
491/// Returns `(js_code, dts_code)`.
492pub fn generate_js_client(
493    models: &HashMap<String, ModelIr>,
494    schema_path: &str,
495) -> (String, String) {
496    #[derive(Serialize)]
497    struct ModelCtx {
498        /// camelCase — property name on `Nautilus`, e.g. `user`.
499        camel_name: String,
500        /// snake_case — import file name, e.g. `user`.
501        snake_name: String,
502        /// PascalCase + "Delegate", e.g. `UserDelegate`.
503        delegate_name: String,
504    }
505
506    let mut model_list: Vec<ModelCtx> = models
507        .values()
508        .map(|m| ModelCtx {
509            camel_name: m.logical_name.to_lower_camel_case(),
510            snake_name: m.logical_name.to_snake_case(),
511            delegate_name: format!("{}Delegate", m.logical_name),
512        })
513        .collect();
514    model_list.sort_by(|a, b| a.camel_name.cmp(&b.camel_name));
515
516    let mut context = Context::new();
517    context.insert("models", &model_list);
518    context.insert("schema_path", schema_path);
519    let js_code = render("client.js.tera", &context);
520    let dts_code = render("client.d.ts.tera", &context);
521    (js_code, dts_code)
522}
523
524/// Generate `models/index.js` + `models/index.d.ts` — barrel re-exports for all model files.
525///
526/// `js_models` contains the `.js` model filenames. Returns `(js_code, dts_code)`.
527pub fn generate_js_models_index(js_models: &[(String, String)]) -> (String, String) {
528    let mut modules: Vec<String> = js_models
529        .iter()
530        .map(|(file_name, _)| file_name.trim_end_matches(".js").to_string())
531        .collect();
532    modules.sort();
533
534    let mut context = Context::new();
535    context.insert("model_modules", &modules);
536    let js_code = render("models_index.js.tera", &context);
537    let dts_code = render("models_index.d.ts.tera", &context);
538    (js_code, dts_code)
539}
540
541/// Static JavaScript + declaration runtime files embedded at compile time.
542/// Returns `Vec<(filename, content)>` containing both `.js` and `.d.ts` pairs.
543pub fn js_runtime_files() -> Vec<(&'static str, &'static str)> {
544    vec![
545        (
546            "_errors.js",
547            include_str!("../../templates/js/runtime/_errors.js"),
548        ),
549        (
550            "_errors.d.ts",
551            include_str!("../../templates/js/runtime/_errors.d.ts"),
552        ),
553        (
554            "_protocol.js",
555            include_str!("../../templates/js/runtime/_protocol.js"),
556        ),
557        (
558            "_protocol.d.ts",
559            include_str!("../../templates/js/runtime/_protocol.d.ts"),
560        ),
561        (
562            "_engine.js",
563            include_str!("../../templates/js/runtime/_engine.js"),
564        ),
565        (
566            "_engine.d.ts",
567            include_str!("../../templates/js/runtime/_engine.d.ts"),
568        ),
569        (
570            "_client.js",
571            include_str!("../../templates/js/runtime/_client.js"),
572        ),
573        (
574            "_client.d.ts",
575            include_str!("../../templates/js/runtime/_client.d.ts"),
576        ),
577        (
578            "_transaction.js",
579            include_str!("../../templates/js/runtime/_transaction.js"),
580        ),
581        (
582            "_transaction.d.ts",
583            include_str!("../../templates/js/runtime/_transaction.d.ts"),
584        ),
585    ]
586}