1use heck::{ToLowerCamelCase, ToSnakeCase};
4use nautilus_schema::ir::{
5 CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr,
6};
7use serde::Serialize;
8use std::collections::{HashMap, HashSet};
9use tera::{Context, Tera};
10
11use crate::js::type_mapper::{
12 field_to_ts_type, get_base_ts_type, get_filter_operators_for_field, get_ts_default_value,
13 is_auto_generated, scalar_to_ts_type,
14};
15
16pub static JS_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
18 let mut tera = Tera::default();
19 tera.add_raw_templates(vec![
20 (
21 "model.js.tera",
22 include_str!("../../templates/js/model.js.tera"),
23 ),
24 (
25 "model.d.ts.tera",
26 include_str!("../../templates/js/model.d.ts.tera"),
27 ),
28 (
29 "enums.js.tera",
30 include_str!("../../templates/js/enums.js.tera"),
31 ),
32 (
33 "enums.d.ts.tera",
34 include_str!("../../templates/js/enums.d.ts.tera"),
35 ),
36 (
37 "client.js.tera",
38 include_str!("../../templates/js/client.js.tera"),
39 ),
40 (
41 "client.d.ts.tera",
42 include_str!("../../templates/js/client.d.ts.tera"),
43 ),
44 (
45 "models_index.js.tera",
46 include_str!("../../templates/js/models_index.js.tera"),
47 ),
48 (
49 "models_index.d.ts.tera",
50 include_str!("../../templates/js/models_index.d.ts.tera"),
51 ),
52 (
53 "composite_types.d.ts.tera",
54 include_str!("../../templates/js/composite_types.d.ts.tera"),
55 ),
56 ])
57 .expect("embedded JS templates must parse");
58 tera
59});
60
61fn render(template: &str, ctx: &Context) -> String {
62 JS_TEMPLATES
63 .render(template, ctx)
64 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
65}
66
/// Template view of one model field (scalar or relation), serialized into the model templates.
///
/// The field names are the keys the Tera templates read; renaming one breaks the templates.
#[derive(Debug, Clone, Serialize)]
struct JsFieldContext {
    /// Name used in generated code (`generate_js_model` sets it to the logical name).
    name: String,
    /// Logical (schema-level) field name.
    logical_name: String,
    /// Database-side name of the field.
    db_name: String,
    /// Full TypeScript type. For scalars this comes from `field_to_ts_type`; for relations it
    /// is `XModel[]` or `XModel | null`.
    ts_type: String,
    /// TypeScript type without array / nullability decoration.
    base_type: String,
    /// `true` when the field is not required (always `true` for relations).
    is_optional: bool,
    /// Whether the field holds a list.
    is_array: bool,
    /// Whether the field's type is a schema enum.
    is_enum: bool,
    /// Whether `default` carries a meaningful value.
    has_default: bool,
    /// TypeScript source text of the default value; empty when there is none.
    default: String,
    /// Position within the scalar-field or relation-field list it belongs to.
    index: usize,
}
86
/// One filter operator available on a where-input field (a template key set).
#[derive(Debug, Clone, Serialize)]
struct JsFilterOperatorContext {
    /// Operator suffix, as returned by `get_filter_operators_for_field`.
    suffix: String,
    /// TypeScript type of the operator's operand.
    ts_type: String,
}
92
/// A field that can appear in the model's generated where-input type.
#[derive(Debug, Clone, Serialize)]
struct JsWhereInputFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type without array / nullability decoration.
    base_type: String,
    /// Full TypeScript type of the field.
    ts_type: String,
    /// Filter operators supported for this field.
    operators: Vec<JsFilterOperatorContext>,
}
101
/// A field accepted by the model's generated create-input type.
#[derive(Debug, Clone, Serialize)]
struct JsCreateInputFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type accepted on create (composite types are accepted as `object`).
    ts_type: String,
    /// `true` only when the field is required, has no default and is not an updated-at field.
    is_required: bool,
}
108
/// A field accepted by the model's generated update-input type.
#[derive(Debug, Clone, Serialize)]
struct JsUpdateInputFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type accepted on update (non-array fields also accept `null`).
    ts_type: String,
}
114
/// A field the generated order-by input can sort on.
#[derive(Debug, Clone, Serialize)]
struct JsOrderByFieldContext {
    /// Logical field name.
    name: String,
}
119
/// A relation that can be eagerly loaded via the generated include option.
#[derive(Debug, Clone, Serialize)]
struct JsIncludeFieldContext {
    /// Logical name of the relation field.
    name: String,
    /// Name of the related model.
    target_model: String,
    /// lowerCamelCase form of `target_model` (presumably the client accessor name — confirm
    /// against the templates).
    target_camel: String,
    /// Whether the relation yields a list of related records.
    is_array: bool,
}
128
/// A field exposed to aggregate operations, with the TypeScript type of its aggregate result.
#[derive(Debug, Clone, Serialize)]
struct JsAggregateFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type used for the aggregated value.
    ts_type: String,
}
134
135pub fn generate_js_model(model: &ModelIr, ir: &SchemaIr) -> ((String, String), (String, String)) {
139 let mut context = Context::new();
140
141 context.insert("model_name", &model.logical_name);
143 context.insert("snake_name", &model.logical_name.to_snake_case());
144 context.insert("table_name", &model.db_name);
145 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
146
147 let pk_field_names = model.primary_key.fields();
149 context.insert("primary_key_fields", &pk_field_names);
150
151 let mut enum_imports: HashSet<String> = HashSet::new();
153 let mut composite_type_imports: HashSet<String> = HashSet::new();
154
155 let mut scalar_fields: Vec<JsFieldContext> = Vec::new();
156 let mut where_input_fields: Vec<JsWhereInputFieldContext> = Vec::new();
157 let mut create_input_fields: Vec<JsCreateInputFieldContext> = Vec::new();
158 let mut update_input_fields: Vec<JsUpdateInputFieldContext> = Vec::new();
159 let mut order_by_fields: Vec<JsOrderByFieldContext> = Vec::new();
160 let mut numeric_fields: Vec<JsAggregateFieldContext> = Vec::new();
161 let mut orderable_fields: Vec<JsAggregateFieldContext> = Vec::new();
162 let mut updated_at_field_names: Vec<String> = Vec::new();
163
164 for (idx, field) in model.scalar_fields().enumerate() {
165 match &field.field_type {
167 ResolvedFieldType::Enum { enum_name } => {
168 if ir.enums.contains_key(enum_name) {
169 enum_imports.insert(enum_name.clone());
170 }
171 }
172 ResolvedFieldType::CompositeType { type_name } => {
173 if ir.composite_types.contains_key(type_name) {
174 composite_type_imports.insert(type_name.clone());
175 }
176 }
177 _ => {}
178 }
179
180 let ts_type = field_to_ts_type(field, &ir.enums);
182 let base_type = get_base_ts_type(field, &ir.enums);
183 let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
184 let auto_generated = is_auto_generated(field);
185 let default_val = get_ts_default_value(field);
186
187 scalar_fields.push(JsFieldContext {
189 name: field.logical_name.clone(),
190 logical_name: field.logical_name.clone(),
191 db_name: field.db_name.clone(),
192 ts_type: ts_type.clone(),
193 base_type: base_type.clone(),
194 is_optional: !field.is_required,
195 is_array: field.is_array,
196 is_enum,
197 has_default: default_val.is_some(),
198 default: default_val.unwrap_or_default(),
199 index: idx,
200 });
201
202 if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
204 let operators = get_filter_operators_for_field(field, &ir.enums);
205 where_input_fields.push(JsWhereInputFieldContext {
206 name: field.logical_name.clone(),
207 base_type: base_type.clone(),
208 ts_type: ts_type.clone(),
209 operators: operators
210 .into_iter()
211 .map(|op| JsFilterOperatorContext {
212 suffix: op.suffix,
213 ts_type: op.type_name,
214 })
215 .collect(),
216 });
217 }
218
219 if !auto_generated {
221 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
222 {
223 "object".to_string()
224 } else {
225 base_type.clone()
226 };
227 let typed = if field.is_array {
228 format!("{}[]", input_base)
229 } else {
230 input_base
231 };
232 create_input_fields.push(JsCreateInputFieldContext {
233 name: field.logical_name.clone(),
234 ts_type: typed,
235 is_required: field.is_required
236 && field.default_value.is_none()
237 && !field.is_updated_at,
238 });
239 }
240
241 let is_auto_pk = auto_generated
243 && pk_field_names.contains(&field.logical_name.as_str())
244 && matches!(
245 field.field_type,
246 ResolvedFieldType::Scalar(ScalarType::Int)
247 | ResolvedFieldType::Scalar(ScalarType::BigInt)
248 );
249 if !is_auto_pk {
250 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
251 {
252 "object".to_string()
253 } else {
254 base_type.clone()
255 };
256 let typed = if field.is_array {
257 format!("{}[]", input_base)
258 } else {
259 format!("{} | null", input_base)
260 };
261 update_input_fields.push(JsUpdateInputFieldContext {
262 name: field.logical_name.clone(),
263 ts_type: typed,
264 });
265 }
266
267 order_by_fields.push(JsOrderByFieldContext {
269 name: field.logical_name.clone(),
270 });
271
272 let is_numeric = matches!(
274 &field.field_type,
275 ResolvedFieldType::Scalar(ScalarType::Int)
276 | ResolvedFieldType::Scalar(ScalarType::BigInt)
277 | ResolvedFieldType::Scalar(ScalarType::Float)
278 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
279 );
280 if is_numeric {
281 let agg_type = if let ResolvedFieldType::Scalar(s) = &field.field_type {
282 scalar_to_ts_type(s).to_string()
283 } else {
284 unreachable!()
285 };
286 numeric_fields.push(JsAggregateFieldContext {
287 name: field.logical_name.clone(),
288 ts_type: agg_type,
289 });
290 }
291
292 let is_non_orderable = matches!(
293 &field.field_type,
294 ResolvedFieldType::Scalar(ScalarType::Boolean)
295 | ResolvedFieldType::Scalar(ScalarType::Json)
296 | ResolvedFieldType::Scalar(ScalarType::Bytes)
297 );
298 if !is_non_orderable {
299 orderable_fields.push(JsAggregateFieldContext {
300 name: field.logical_name.clone(),
301 ts_type: base_type,
302 });
303 }
304
305 if field.is_updated_at {
307 updated_at_field_names.push(field.logical_name.clone());
308 }
309 }
310
311 let relation_fields: Vec<JsFieldContext> = model
313 .relation_fields()
314 .enumerate()
315 .map(|(idx, field)| {
316 let ts_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
317 if field.is_array {
318 format!("{}Model[]", rel.target_model)
319 } else {
320 format!("{}Model | null", rel.target_model)
321 }
322 } else {
323 "unknown".to_string()
324 };
325 let base_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
326 format!("{}Model", rel.target_model)
327 } else {
328 "unknown".to_string()
329 };
330
331 JsFieldContext {
332 name: field.logical_name.clone(),
333 logical_name: field.logical_name.clone(),
334 db_name: field.db_name.clone(),
335 ts_type,
336 base_type,
337 is_optional: true,
338 is_array: field.is_array,
339 is_enum: false,
340 has_default: true,
341 default: if field.is_array {
342 "[]".to_string()
343 } else {
344 "null".to_string()
345 },
346 index: idx,
347 }
348 })
349 .collect();
350
351 let include_fields: Vec<JsIncludeFieldContext> = model
352 .relation_fields()
353 .filter_map(|field| {
354 if let ResolvedFieldType::Relation(rel) = &field.field_type {
355 Some(JsIncludeFieldContext {
356 name: field.logical_name.clone(),
357 target_model: rel.target_model.clone(),
358 target_camel: rel.target_model.to_lower_camel_case(),
359 is_array: field.is_array,
360 })
361 } else {
362 None
363 }
364 })
365 .collect();
366
367 let has_numeric_fields = !numeric_fields.is_empty();
368 let has_includes = !include_fields.is_empty();
369 let has_enums = !enum_imports.is_empty();
370
371 context.insert("scalar_fields", &scalar_fields);
372 context.insert("relation_fields", &relation_fields);
373 context.insert("where_input_fields", &where_input_fields);
374 context.insert("create_input_fields", &create_input_fields);
375 context.insert("update_input_fields", &update_input_fields);
376 context.insert("updated_at_fields", &updated_at_field_names);
377 context.insert("order_by_fields", &order_by_fields);
378 context.insert("include_fields", &include_fields);
379 context.insert("has_includes", &has_includes);
380 context.insert("numeric_fields", &numeric_fields);
381 context.insert("orderable_fields", &orderable_fields);
382 context.insert("has_numeric_fields", &has_numeric_fields);
383 context.insert("has_enums", &has_enums);
384 context.insert(
385 "enum_imports",
386 &enum_imports.into_iter().collect::<Vec<_>>(),
387 );
388 context.insert("has_composite_types", &!composite_type_imports.is_empty());
389 context.insert(
390 "composite_type_imports",
391 &composite_type_imports.into_iter().collect::<Vec<_>>(),
392 );
393
394 let snake = model.logical_name.to_snake_case();
395 let js_code = render("model.js.tera", &context);
396 let dts_code = render("model.d.ts.tera", &context);
397
398 (
399 (format!("{}.js", snake), js_code),
400 (format!("{}.d.ts", snake), dts_code),
401 )
402}
403
404#[allow(clippy::type_complexity)]
408pub fn generate_all_js_models(ir: &SchemaIr) -> (Vec<(String, String)>, Vec<(String, String)>) {
409 let pairs: Vec<((String, String), (String, String))> = ir
410 .models
411 .values()
412 .map(|model| generate_js_model(model, ir))
413 .collect();
414
415 let mut js_models: Vec<(String, String)> = pairs.iter().map(|(js, _)| js.clone()).collect();
416 let mut dts_models: Vec<(String, String)> = pairs.iter().map(|(_, dts)| dts.clone()).collect();
417
418 js_models.sort_by(|a, b| a.0.cmp(&b.0));
419 dts_models.sort_by(|a, b| a.0.cmp(&b.0));
420
421 (js_models, dts_models)
422}
423
424pub fn generate_js_composite_types(
428 composite_types: &HashMap<String, CompositeTypeIr>,
429) -> Option<String> {
430 if composite_types.is_empty() {
431 return None;
432 }
433
434 #[derive(Serialize)]
435 struct CompositeFieldCtx {
436 name: String,
437 ts_type: String,
438 }
439
440 #[derive(Serialize)]
441 struct CompositeTypeCtx {
442 name: String,
443 fields: Vec<CompositeFieldCtx>,
444 }
445
446 let mut type_list: Vec<CompositeTypeCtx> = composite_types
447 .values()
448 .map(|ct| {
449 let fields = ct
450 .fields
451 .iter()
452 .map(|f| {
453 let base = match &f.field_type {
454 ResolvedFieldType::Scalar(s) => scalar_to_ts_type(s).to_string(),
455 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
456 ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
457 ResolvedFieldType::Relation(_) => "unknown".to_string(),
458 };
459 let ts_type = if f.is_array {
460 format!("{}[]", base)
461 } else if !f.is_required {
462 format!("{} | null", base)
463 } else {
464 base
465 };
466 CompositeFieldCtx {
467 name: f.logical_name.clone(),
468 ts_type,
469 }
470 })
471 .collect();
472 CompositeTypeCtx {
473 name: ct.logical_name.clone(),
474 fields,
475 }
476 })
477 .collect();
478 type_list.sort_by(|a, b| a.name.cmp(&b.name));
479
480 let mut context = Context::new();
481 context.insert("composite_types", &type_list);
482
483 Some(render("composite_types.d.ts.tera", &context))
484}
485
486pub fn generate_js_enums(enums: &HashMap<String, EnumIr>) -> (String, String) {
490 #[derive(Serialize)]
491 struct EnumCtx {
492 name: String,
493 variants: Vec<String>,
494 }
495
496 let mut enum_list: Vec<EnumCtx> = enums
497 .values()
498 .map(|e| EnumCtx {
499 name: e.logical_name.clone(),
500 variants: e.variants.clone(),
501 })
502 .collect();
503 enum_list.sort_by(|a, b| a.name.cmp(&b.name));
504
505 let mut context = Context::new();
506 context.insert("enums", &enum_list);
507 let js_code = render("enums.js.tera", &context);
508 let dts_code = render("enums.d.ts.tera", &context);
509 (js_code, dts_code)
510}
511
512pub fn generate_js_client(
516 models: &HashMap<String, ModelIr>,
517 schema_path: &str,
518) -> (String, String) {
519 #[derive(Serialize)]
520 struct ModelCtx {
521 camel_name: String,
523 snake_name: String,
525 delegate_name: String,
527 }
528
529 let mut model_list: Vec<ModelCtx> = models
530 .values()
531 .map(|m| ModelCtx {
532 camel_name: m.logical_name.to_lower_camel_case(),
533 snake_name: m.logical_name.to_snake_case(),
534 delegate_name: format!("{}Delegate", m.logical_name),
535 })
536 .collect();
537 model_list.sort_by(|a, b| a.camel_name.cmp(&b.camel_name));
538
539 let mut context = Context::new();
540 context.insert("models", &model_list);
541 context.insert("schema_path", schema_path);
542 let js_code = render("client.js.tera", &context);
543 let dts_code = render("client.d.ts.tera", &context);
544 (js_code, dts_code)
545}
546
547pub fn generate_js_models_index(js_models: &[(String, String)]) -> (String, String) {
551 let mut modules: Vec<String> = js_models
552 .iter()
553 .map(|(file_name, _)| file_name.trim_end_matches(".js").to_string())
554 .collect();
555 modules.sort();
556
557 let mut context = Context::new();
558 context.insert("model_modules", &modules);
559 let js_code = render("models_index.js.tera", &context);
560 let dts_code = render("models_index.d.ts.tera", &context);
561 (js_code, dts_code)
562}
563
564pub fn js_runtime_files() -> Vec<(&'static str, &'static str)> {
567 vec![
568 (
569 "_errors.js",
570 include_str!("../../templates/js/runtime/_errors.js"),
571 ),
572 (
573 "_errors.d.ts",
574 include_str!("../../templates/js/runtime/_errors.d.ts"),
575 ),
576 (
577 "_protocol.js",
578 include_str!("../../templates/js/runtime/_protocol.js"),
579 ),
580 (
581 "_protocol.d.ts",
582 include_str!("../../templates/js/runtime/_protocol.d.ts"),
583 ),
584 (
585 "_engine.js",
586 include_str!("../../templates/js/runtime/_engine.js"),
587 ),
588 (
589 "_engine.d.ts",
590 include_str!("../../templates/js/runtime/_engine.d.ts"),
591 ),
592 (
593 "_client.js",
594 include_str!("../../templates/js/runtime/_client.js"),
595 ),
596 (
597 "_client.d.ts",
598 include_str!("../../templates/js/runtime/_client.d.ts"),
599 ),
600 (
601 "_transaction.js",
602 include_str!("../../templates/js/runtime/_transaction.js"),
603 ),
604 (
605 "_transaction.d.ts",
606 include_str!("../../templates/js/runtime/_transaction.d.ts"),
607 ),
608 ]
609}