Skip to main content

nautilus_codegen/js/
generator.rs

1//! JavaScript/TypeScript code generator for Nautilus models, delegates, and input types.
2
3use heck::{ToLowerCamelCase, ToSnakeCase};
4use nautilus_schema::ir::{
5    CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr,
6};
7use serde::Serialize;
8use std::collections::{HashMap, HashSet};
9use tera::{Context, Tera};
10
11use crate::js::type_mapper::{
12    field_to_ts_type, get_base_ts_type, get_filter_operators_for_field, get_ts_default_value,
13    is_auto_generated, scalar_to_ts_type,
14};
15
16/// JS/TS template registry — loaded once at first use.
17pub static JS_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
18    let mut tera = Tera::default();
19    tera.add_raw_templates(vec![
20        (
21            "model.js.tera",
22            include_str!("../../templates/js/model.js.tera"),
23        ),
24        (
25            "model.d.ts.tera",
26            include_str!("../../templates/js/model.d.ts.tera"),
27        ),
28        (
29            "enums.js.tera",
30            include_str!("../../templates/js/enums.js.tera"),
31        ),
32        (
33            "enums.d.ts.tera",
34            include_str!("../../templates/js/enums.d.ts.tera"),
35        ),
36        (
37            "client.js.tera",
38            include_str!("../../templates/js/client.js.tera"),
39        ),
40        (
41            "client.d.ts.tera",
42            include_str!("../../templates/js/client.d.ts.tera"),
43        ),
44        (
45            "models_index.js.tera",
46            include_str!("../../templates/js/models_index.js.tera"),
47        ),
48        (
49            "models_index.d.ts.tera",
50            include_str!("../../templates/js/models_index.d.ts.tera"),
51        ),
52        (
53            "composite_types.d.ts.tera",
54            include_str!("../../templates/js/composite_types.d.ts.tera"),
55        ),
56    ])
57    .expect("embedded JS templates must parse");
58    tera
59});
60
61fn render(template: &str, ctx: &Context) -> String {
62    JS_TEMPLATES
63        .render(template, ctx)
64        .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
65}
66
/// Template context for one model field (scalar or relation) in the model's
/// `.js` / `.d.ts` output.
#[derive(Debug, Clone, Serialize)]
struct JsFieldContext {
    /// JS property name; `generate_js_model` sets this to the field's schema logical name.
    name: String,
    /// Logical name from the schema IR (currently always equal to `name`).
    logical_name: String,
    /// Database column name (`db_name` from the IR).
    db_name: String,
    /// Full TypeScript type, e.g. `string | null`, `number[]`.
    ts_type: String,
    /// Inner base type without wrappers, e.g. `string`, `number`, `Date`.
    base_type: String,
    /// `true` when the field is not required; relation fields are always optional.
    is_optional: bool,
    /// `true` for list-typed fields.
    is_array: bool,
    /// `true` when the field's type is a schema enum.
    is_enum: bool,
    /// Whether `default` holds a default-value expression.
    has_default: bool,
    /// Default-value expression; empty when `has_default` is `false`.
    default: String,
    /// Zero-based position within its own list (scalar and relation fields are numbered separately).
    index: usize,
}
86
/// One filter operator offered on a where-input field.
#[derive(Debug, Clone, Serialize)]
struct JsFilterOperatorContext {
    /// Operator suffix, as returned by `get_filter_operators_for_field`.
    suffix: String,
    /// TypeScript type of the operator's operand (the operator's `type_name`).
    ts_type: String,
}
92
/// A filterable field in the model's generated where-input type.
#[derive(Debug, Clone, Serialize)]
struct JsWhereInputFieldContext {
    /// Field logical name.
    name: String,
    /// Base TS type used by the template to pick the right filter interface.
    base_type: String,
    /// Full TS type of the field, including array/null wrappers.
    ts_type: String,
    /// Operators available for this field.
    operators: Vec<JsFilterOperatorContext>,
}
101
/// A field accepted by the model's generated create-input type.
/// Auto-generated fields are not included.
#[derive(Debug, Clone, Serialize)]
struct JsCreateInputFieldContext {
    /// Field logical name.
    name: String,
    /// Input type; composite types are typed as `object`, and list fields get a `[]` suffix.
    ts_type: String,
    /// `true` when the field is required, has no schema default, and is not an updated-at field.
    is_required: bool,
}
108
/// A field accepted by the model's generated update-input type.
/// Auto-generated `Int`/`BigInt` primary-key fields are not included.
#[derive(Debug, Clone, Serialize)]
struct JsUpdateInputFieldContext {
    /// Field logical name.
    name: String,
    /// Input type: `T[]` for list fields, otherwise `T | null`.
    ts_type: String,
}
114
/// A scalar field that can appear in the generated order-by input.
#[derive(Debug, Clone, Serialize)]
struct JsOrderByFieldContext {
    /// Field logical name.
    name: String,
}
119
/// A relation that can be requested through the generated include option.
#[derive(Debug, Clone, Serialize)]
struct JsIncludeFieldContext {
    /// Relation field logical name.
    name: String,
    /// Name of the related model.
    target_model: String,
    /// camelCase — property name on the generated Nautilus class.
    target_camel: String,
    /// `true` when the relation field is a list.
    is_array: bool,
}
128
/// A field used by the aggregate typings; it populates both the `numeric_fields`
/// and `orderable_fields` lists in the model template context.
#[derive(Debug, Clone, Serialize)]
struct JsAggregateFieldContext {
    /// Field logical name.
    name: String,
    /// TS type of the aggregated value.
    ts_type: String,
}
134
135/// Generate JavaScript + declaration code for a single model.
136///
137/// Returns `((js_filename, js_code), (dts_filename, dts_code))`.
138pub fn generate_js_model(model: &ModelIr, ir: &SchemaIr) -> ((String, String), (String, String)) {
139    let mut context = Context::new();
140
141    // Basic model info.
142    context.insert("model_name", &model.logical_name);
143    context.insert("snake_name", &model.logical_name.to_snake_case());
144    context.insert("table_name", &model.db_name);
145    context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
146
147    // Primary key fields.
148    let pk_field_names = model.primary_key.fields();
149    context.insert("primary_key_fields", &pk_field_names);
150
151    // --- Single-pass scalar field context building ---
152    let mut enum_imports: HashSet<String> = HashSet::new();
153    let mut composite_type_imports: HashSet<String> = HashSet::new();
154
155    let mut scalar_fields: Vec<JsFieldContext> = Vec::new();
156    let mut where_input_fields: Vec<JsWhereInputFieldContext> = Vec::new();
157    let mut create_input_fields: Vec<JsCreateInputFieldContext> = Vec::new();
158    let mut update_input_fields: Vec<JsUpdateInputFieldContext> = Vec::new();
159    let mut order_by_fields: Vec<JsOrderByFieldContext> = Vec::new();
160    let mut numeric_fields: Vec<JsAggregateFieldContext> = Vec::new();
161    let mut orderable_fields: Vec<JsAggregateFieldContext> = Vec::new();
162    let mut updated_at_field_names: Vec<String> = Vec::new();
163
164    for (idx, field) in model.scalar_fields().enumerate() {
165        // --- Import tracking ---
166        match &field.field_type {
167            ResolvedFieldType::Enum { enum_name } => {
168                if ir.enums.contains_key(enum_name) {
169                    enum_imports.insert(enum_name.clone());
170                }
171            }
172            ResolvedFieldType::CompositeType { type_name } => {
173                if ir.composite_types.contains_key(type_name) {
174                    composite_type_imports.insert(type_name.clone());
175                }
176            }
177            _ => {}
178        }
179
180        // --- Computed values (shared across contexts) ---
181        let ts_type = field_to_ts_type(field, &ir.enums);
182        let base_type = get_base_ts_type(field, &ir.enums);
183        let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
184        let auto_generated = is_auto_generated(field);
185        let default_val = get_ts_default_value(field);
186
187        // --- Build JsFieldContext ---
188        scalar_fields.push(JsFieldContext {
189            name: field.logical_name.clone(),
190            logical_name: field.logical_name.clone(),
191            db_name: field.db_name.clone(),
192            ts_type: ts_type.clone(),
193            base_type: base_type.clone(),
194            is_optional: !field.is_required,
195            is_array: field.is_array,
196            is_enum,
197            has_default: default_val.is_some(),
198            default: default_val.unwrap_or_default(),
199            index: idx,
200        });
201
202        // --- WhereInput ---
203        if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
204            let operators = get_filter_operators_for_field(field, &ir.enums);
205            where_input_fields.push(JsWhereInputFieldContext {
206                name: field.logical_name.clone(),
207                base_type: base_type.clone(),
208                ts_type: ts_type.clone(),
209                operators: operators
210                    .into_iter()
211                    .map(|op| JsFilterOperatorContext {
212                        suffix: op.suffix,
213                        ts_type: op.type_name,
214                    })
215                    .collect(),
216            });
217        }
218
219        // --- CreateInput ---
220        if !auto_generated {
221            let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
222            {
223                "object".to_string()
224            } else {
225                base_type.clone()
226            };
227            let typed = if field.is_array {
228                format!("{}[]", input_base)
229            } else {
230                input_base
231            };
232            create_input_fields.push(JsCreateInputFieldContext {
233                name: field.logical_name.clone(),
234                ts_type: typed,
235                is_required: field.is_required
236                    && field.default_value.is_none()
237                    && !field.is_updated_at,
238            });
239        }
240
241        // --- UpdateInput ---
242        let is_auto_pk = auto_generated
243            && pk_field_names.contains(&field.logical_name.as_str())
244            && matches!(
245                field.field_type,
246                ResolvedFieldType::Scalar(ScalarType::Int)
247                    | ResolvedFieldType::Scalar(ScalarType::BigInt)
248            );
249        if !is_auto_pk {
250            let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
251            {
252                "object".to_string()
253            } else {
254                base_type.clone()
255            };
256            let typed = if field.is_array {
257                format!("{}[]", input_base)
258            } else {
259                format!("{} | null", input_base)
260            };
261            update_input_fields.push(JsUpdateInputFieldContext {
262                name: field.logical_name.clone(),
263                ts_type: typed,
264            });
265        }
266
267        // --- OrderBy ---
268        order_by_fields.push(JsOrderByFieldContext {
269            name: field.logical_name.clone(),
270        });
271
272        // --- Aggregate fields ---
273        let is_numeric = matches!(
274            &field.field_type,
275            ResolvedFieldType::Scalar(ScalarType::Int)
276                | ResolvedFieldType::Scalar(ScalarType::BigInt)
277                | ResolvedFieldType::Scalar(ScalarType::Float)
278                | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
279        );
280        if is_numeric {
281            let agg_type = if let ResolvedFieldType::Scalar(s) = &field.field_type {
282                scalar_to_ts_type(s).to_string()
283            } else {
284                unreachable!()
285            };
286            numeric_fields.push(JsAggregateFieldContext {
287                name: field.logical_name.clone(),
288                ts_type: agg_type,
289            });
290        }
291
292        let is_non_orderable = matches!(
293            &field.field_type,
294            ResolvedFieldType::Scalar(ScalarType::Boolean)
295                | ResolvedFieldType::Scalar(ScalarType::Json)
296                | ResolvedFieldType::Scalar(ScalarType::Bytes)
297        );
298        if !is_non_orderable {
299            orderable_fields.push(JsAggregateFieldContext {
300                name: field.logical_name.clone(),
301                ts_type: base_type,
302            });
303        }
304
305        // --- Updated-at ---
306        if field.is_updated_at {
307            updated_at_field_names.push(field.logical_name.clone());
308        }
309    }
310
311    // --- Relation fields (separate field set) ---
312    let relation_fields: Vec<JsFieldContext> = model
313        .relation_fields()
314        .enumerate()
315        .map(|(idx, field)| {
316            let ts_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
317                if field.is_array {
318                    format!("{}Model[]", rel.target_model)
319                } else {
320                    format!("{}Model | null", rel.target_model)
321                }
322            } else {
323                "unknown".to_string()
324            };
325            let base_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
326                format!("{}Model", rel.target_model)
327            } else {
328                "unknown".to_string()
329            };
330
331            JsFieldContext {
332                name: field.logical_name.clone(),
333                logical_name: field.logical_name.clone(),
334                db_name: field.db_name.clone(),
335                ts_type,
336                base_type,
337                is_optional: true,
338                is_array: field.is_array,
339                is_enum: false,
340                has_default: true,
341                default: if field.is_array {
342                    "[]".to_string()
343                } else {
344                    "null".to_string()
345                },
346                index: idx,
347            }
348        })
349        .collect();
350
351    let include_fields: Vec<JsIncludeFieldContext> = model
352        .relation_fields()
353        .filter_map(|field| {
354            if let ResolvedFieldType::Relation(rel) = &field.field_type {
355                Some(JsIncludeFieldContext {
356                    name: field.logical_name.clone(),
357                    target_model: rel.target_model.clone(),
358                    target_camel: rel.target_model.to_lower_camel_case(),
359                    is_array: field.is_array,
360                })
361            } else {
362                None
363            }
364        })
365        .collect();
366
367    let has_numeric_fields = !numeric_fields.is_empty();
368    let has_includes = !include_fields.is_empty();
369    let has_enums = !enum_imports.is_empty();
370
371    context.insert("scalar_fields", &scalar_fields);
372    context.insert("relation_fields", &relation_fields);
373    context.insert("where_input_fields", &where_input_fields);
374    context.insert("create_input_fields", &create_input_fields);
375    context.insert("update_input_fields", &update_input_fields);
376    context.insert("updated_at_fields", &updated_at_field_names);
377    context.insert("order_by_fields", &order_by_fields);
378    context.insert("include_fields", &include_fields);
379    context.insert("has_includes", &has_includes);
380    context.insert("numeric_fields", &numeric_fields);
381    context.insert("orderable_fields", &orderable_fields);
382    context.insert("has_numeric_fields", &has_numeric_fields);
383    context.insert("has_enums", &has_enums);
384    context.insert(
385        "enum_imports",
386        &enum_imports.into_iter().collect::<Vec<_>>(),
387    );
388    context.insert("has_composite_types", &!composite_type_imports.is_empty());
389    context.insert(
390        "composite_type_imports",
391        &composite_type_imports.into_iter().collect::<Vec<_>>(),
392    );
393
394    let snake = model.logical_name.to_snake_case();
395    let js_code = render("model.js.tera", &context);
396    let dts_code = render("model.d.ts.tera", &context);
397
398    (
399        (format!("{}.js", snake), js_code),
400        (format!("{}.d.ts", snake), dts_code),
401    )
402}
403
404/// Generate JavaScript + declaration code for all models in the schema.
405///
406/// Returns `(js_models, dts_models)`, each sorted by filename.
407#[allow(clippy::type_complexity)]
408pub fn generate_all_js_models(ir: &SchemaIr) -> (Vec<(String, String)>, Vec<(String, String)>) {
409    let pairs: Vec<((String, String), (String, String))> = ir
410        .models
411        .values()
412        .map(|model| generate_js_model(model, ir))
413        .collect();
414
415    let mut js_models: Vec<(String, String)> = pairs.iter().map(|(js, _)| js.clone()).collect();
416    let mut dts_models: Vec<(String, String)> = pairs.iter().map(|(_, dts)| dts.clone()).collect();
417
418    js_models.sort_by(|a, b| a.0.cmp(&b.0));
419    dts_models.sort_by(|a, b| a.0.cmp(&b.0));
420
421    (js_models, dts_models)
422}
423
424/// Generate `types.d.ts` — TypeScript interfaces for all composite types.
425///
426/// Returns `None` when there are no composite types.
427pub fn generate_js_composite_types(
428    composite_types: &HashMap<String, CompositeTypeIr>,
429) -> Option<String> {
430    if composite_types.is_empty() {
431        return None;
432    }
433
434    #[derive(Serialize)]
435    struct CompositeFieldCtx {
436        name: String,
437        ts_type: String,
438    }
439
440    #[derive(Serialize)]
441    struct CompositeTypeCtx {
442        name: String,
443        fields: Vec<CompositeFieldCtx>,
444    }
445
446    let mut type_list: Vec<CompositeTypeCtx> = composite_types
447        .values()
448        .map(|ct| {
449            let fields = ct
450                .fields
451                .iter()
452                .map(|f| {
453                    let base = match &f.field_type {
454                        ResolvedFieldType::Scalar(s) => scalar_to_ts_type(s).to_string(),
455                        ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
456                        ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
457                        ResolvedFieldType::Relation(_) => "unknown".to_string(),
458                    };
459                    let ts_type = if f.is_array {
460                        format!("{}[]", base)
461                    } else if !f.is_required {
462                        format!("{} | null", base)
463                    } else {
464                        base
465                    };
466                    CompositeFieldCtx {
467                        name: f.logical_name.clone(),
468                        ts_type,
469                    }
470                })
471                .collect();
472            CompositeTypeCtx {
473                name: ct.logical_name.clone(),
474                fields,
475            }
476        })
477        .collect();
478    type_list.sort_by(|a, b| a.name.cmp(&b.name));
479
480    let mut context = Context::new();
481    context.insert("composite_types", &type_list);
482
483    Some(render("composite_types.d.ts.tera", &context))
484}
485
486/// Generate `enums.js` + `enums.d.ts` for all enum definitions.
487///
488/// Returns `(js_code, dts_code)`.
489pub fn generate_js_enums(enums: &HashMap<String, EnumIr>) -> (String, String) {
490    #[derive(Serialize)]
491    struct EnumCtx {
492        name: String,
493        variants: Vec<String>,
494    }
495
496    let mut enum_list: Vec<EnumCtx> = enums
497        .values()
498        .map(|e| EnumCtx {
499            name: e.logical_name.clone(),
500            variants: e.variants.clone(),
501        })
502        .collect();
503    enum_list.sort_by(|a, b| a.name.cmp(&b.name));
504
505    let mut context = Context::new();
506    context.insert("enums", &enum_list);
507    let js_code = render("enums.js.tera", &context);
508    let dts_code = render("enums.d.ts.tera", &context);
509    (js_code, dts_code)
510}
511
512/// Generate `index.js` + `index.d.ts` — the typed `Nautilus` class with model delegates.
513///
514/// Returns `(js_code, dts_code)`.
515pub fn generate_js_client(
516    models: &HashMap<String, ModelIr>,
517    schema_path: &str,
518) -> (String, String) {
519    #[derive(Serialize)]
520    struct ModelCtx {
521        /// camelCase — property name on `Nautilus`, e.g. `user`.
522        camel_name: String,
523        /// snake_case — import file name, e.g. `user`.
524        snake_name: String,
525        /// PascalCase + "Delegate", e.g. `UserDelegate`.
526        delegate_name: String,
527    }
528
529    let mut model_list: Vec<ModelCtx> = models
530        .values()
531        .map(|m| ModelCtx {
532            camel_name: m.logical_name.to_lower_camel_case(),
533            snake_name: m.logical_name.to_snake_case(),
534            delegate_name: format!("{}Delegate", m.logical_name),
535        })
536        .collect();
537    model_list.sort_by(|a, b| a.camel_name.cmp(&b.camel_name));
538
539    let mut context = Context::new();
540    context.insert("models", &model_list);
541    context.insert("schema_path", schema_path);
542    let js_code = render("client.js.tera", &context);
543    let dts_code = render("client.d.ts.tera", &context);
544    (js_code, dts_code)
545}
546
547/// Generate `models/index.js` + `models/index.d.ts` — barrel re-exports for all model files.
548///
549/// `js_models` contains the `.js` model filenames. Returns `(js_code, dts_code)`.
550pub fn generate_js_models_index(js_models: &[(String, String)]) -> (String, String) {
551    let mut modules: Vec<String> = js_models
552        .iter()
553        .map(|(file_name, _)| file_name.trim_end_matches(".js").to_string())
554        .collect();
555    modules.sort();
556
557    let mut context = Context::new();
558    context.insert("model_modules", &modules);
559    let js_code = render("models_index.js.tera", &context);
560    let dts_code = render("models_index.d.ts.tera", &context);
561    (js_code, dts_code)
562}
563
564/// Static JavaScript + declaration runtime files embedded at compile time.
565/// Returns `Vec<(filename, content)>` containing both `.js` and `.d.ts` pairs.
566pub fn js_runtime_files() -> Vec<(&'static str, &'static str)> {
567    vec![
568        (
569            "_errors.js",
570            include_str!("../../templates/js/runtime/_errors.js"),
571        ),
572        (
573            "_errors.d.ts",
574            include_str!("../../templates/js/runtime/_errors.d.ts"),
575        ),
576        (
577            "_protocol.js",
578            include_str!("../../templates/js/runtime/_protocol.js"),
579        ),
580        (
581            "_protocol.d.ts",
582            include_str!("../../templates/js/runtime/_protocol.d.ts"),
583        ),
584        (
585            "_engine.js",
586            include_str!("../../templates/js/runtime/_engine.js"),
587        ),
588        (
589            "_engine.d.ts",
590            include_str!("../../templates/js/runtime/_engine.d.ts"),
591        ),
592        (
593            "_client.js",
594            include_str!("../../templates/js/runtime/_client.js"),
595        ),
596        (
597            "_client.d.ts",
598            include_str!("../../templates/js/runtime/_client.d.ts"),
599        ),
600        (
601            "_transaction.js",
602            include_str!("../../templates/js/runtime/_transaction.js"),
603        ),
604        (
605            "_transaction.d.ts",
606            include_str!("../../templates/js/runtime/_transaction.d.ts"),
607        ),
608    ]
609}