1use heck::{ToLowerCamelCase, ToSnakeCase};
4use nautilus_schema::ir::{
5 CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr,
6};
7use serde::Serialize;
8use std::collections::{HashMap, HashSet};
9use tera::{Context, Tera};
10
11use crate::js::type_mapper::{
12 field_to_ts_type, get_base_ts_type, get_filter_operators_for_field, get_ts_default_value,
13 is_auto_generated, scalar_to_ts_type,
14};
15
16pub static JS_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
18 let mut tera = Tera::default();
19 tera.add_raw_templates(vec![
20 (
21 "model.js.tera",
22 include_str!("../../templates/js/model.js.tera"),
23 ),
24 (
25 "model.d.ts.tera",
26 include_str!("../../templates/js/model.d.ts.tera"),
27 ),
28 (
29 "enums.js.tera",
30 include_str!("../../templates/js/enums.js.tera"),
31 ),
32 (
33 "enums.d.ts.tera",
34 include_str!("../../templates/js/enums.d.ts.tera"),
35 ),
36 (
37 "client.js.tera",
38 include_str!("../../templates/js/client.js.tera"),
39 ),
40 (
41 "client.d.ts.tera",
42 include_str!("../../templates/js/client.d.ts.tera"),
43 ),
44 (
45 "models_index.js.tera",
46 include_str!("../../templates/js/models_index.js.tera"),
47 ),
48 (
49 "models_index.d.ts.tera",
50 include_str!("../../templates/js/models_index.d.ts.tera"),
51 ),
52 (
53 "composite_types.d.ts.tera",
54 include_str!("../../templates/js/composite_types.d.ts.tera"),
55 ),
56 ])
57 .expect("embedded JS templates must parse");
58 tera
59});
60
61fn render(template: &str, ctx: &Context) -> String {
62 JS_TEMPLATES
63 .render(template, ctx)
64 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
65}
66
/// Template context describing one model field (scalar or relation).
#[derive(Debug, Clone, Serialize)]
struct JsFieldContext {
    /// Field name; populated from the field's logical name.
    name: String,
    /// Logical (schema-level) field name; currently the same value as `name`.
    logical_name: String,
    /// Name of the field in the database.
    db_name: String,
    /// Full TypeScript type. For relations this is `<Target>Model[]` or
    /// `<Target>Model | null`.
    ts_type: String,
    /// Base TypeScript type; for relations this is `<Target>Model`.
    base_type: String,
    /// `true` when the field is not required; always `true` for relations.
    is_optional: bool,
    /// `true` when the field is an array.
    is_array: bool,
    /// `true` when the field's type is an enum.
    is_enum: bool,
    /// `true` when a default value is available; always `true` for relations.
    has_default: bool,
    /// Default value as source text; empty when `has_default` is `false`.
    /// Relations use `[]` (array) or `null`.
    default: String,
    /// `true` when the field is one of the model's primary-key fields.
    is_pk: bool,
    /// Zero-based position of the field within its scalar or relation list.
    index: usize,
}
87
/// Template context for one filter operator available on a where-input field.
#[derive(Debug, Clone, Serialize)]
struct JsFilterOperatorContext {
    /// Operator suffix as reported by `get_filter_operators_for_field`.
    suffix: String,
    /// TypeScript type of the operator's argument.
    ts_type: String,
}
93
/// Template context for one field of a model's where-input (filter) type.
#[derive(Debug, Clone, Serialize)]
struct JsWhereInputFieldContext {
    /// Logical field name.
    name: String,
    /// Base TypeScript type of the field.
    base_type: String,
    /// Full TypeScript type of the field.
    ts_type: String,
    /// Filter operators supported by this field.
    operators: Vec<JsFilterOperatorContext>,
}
102
/// Template context for one field of a model's create-input type.
#[derive(Debug, Clone, Serialize)]
struct JsCreateInputFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type accepted on create (base type, with `[]` for arrays).
    ts_type: String,
    /// `true` when the caller must supply the field: it is required, has no
    /// default value, and is not an updated-at field.
    is_required: bool,
}
109
/// Template context for one field of a model's update-input type.
#[derive(Debug, Clone, Serialize)]
struct JsUpdateInputFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type accepted on update: `<base>[]` for arrays, otherwise
    /// `<base> | null`.
    ts_type: String,
}
115
/// Template context for one field that can appear in an order-by clause.
#[derive(Debug, Clone, Serialize)]
struct JsOrderByFieldContext {
    /// Logical field name.
    name: String,
}
120
/// Template context for one relation that can be included when querying.
#[derive(Debug, Clone, Serialize)]
struct JsIncludeFieldContext {
    /// Logical name of the relation field.
    name: String,
    /// Name of the related model.
    target_model: String,
    /// Related model name in snake_case.
    target_snake: String,
    /// Related model name in lowerCamelCase.
    target_camel: String,
    /// `true` when the relation field is an array.
    is_array: bool,
}
130
/// Template context for one field usable in aggregate operations.
#[derive(Debug, Clone, Serialize)]
struct JsAggregateFieldContext {
    /// Logical field name.
    name: String,
    /// TypeScript type of the aggregated value.
    ts_type: String,
}
136
137pub fn generate_js_model(model: &ModelIr, ir: &SchemaIr) -> ((String, String), (String, String)) {
141 let mut context = Context::new();
142
143 context.insert("model_name", &model.logical_name);
144 context.insert("snake_name", &model.logical_name.to_snake_case());
145 context.insert("table_name", &model.db_name);
146 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
147
148 let pk_field_names = model.primary_key.fields();
149 context.insert("primary_key_fields", &pk_field_names);
150
151 let mut enum_imports: HashSet<String> = HashSet::new();
152 let mut composite_type_imports: HashSet<String> = HashSet::new();
153
154 let mut scalar_fields: Vec<JsFieldContext> = Vec::new();
155 let mut where_input_fields: Vec<JsWhereInputFieldContext> = Vec::new();
156 let mut create_input_fields: Vec<JsCreateInputFieldContext> = Vec::new();
157 let mut update_input_fields: Vec<JsUpdateInputFieldContext> = Vec::new();
158 let mut order_by_fields: Vec<JsOrderByFieldContext> = Vec::new();
159 let mut numeric_fields: Vec<JsAggregateFieldContext> = Vec::new();
160 let mut orderable_fields: Vec<JsAggregateFieldContext> = Vec::new();
161
162 for (idx, field) in model.scalar_fields().enumerate() {
163 match &field.field_type {
164 ResolvedFieldType::Enum { enum_name } => {
165 if ir.enums.contains_key(enum_name) {
166 enum_imports.insert(enum_name.clone());
167 }
168 }
169 ResolvedFieldType::CompositeType { type_name } => {
170 if ir.composite_types.contains_key(type_name) {
171 composite_type_imports.insert(type_name.clone());
172 }
173 }
174 _ => {}
175 }
176
177 let ts_type = field_to_ts_type(field, &ir.enums);
178 let base_type = get_base_ts_type(field, &ir.enums);
179 let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
180 let auto_generated = is_auto_generated(field);
181 let default_val = get_ts_default_value(field);
182 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
183
184 scalar_fields.push(JsFieldContext {
185 name: field.logical_name.clone(),
186 logical_name: field.logical_name.clone(),
187 db_name: field.db_name.clone(),
188 ts_type: ts_type.clone(),
189 base_type: base_type.clone(),
190 is_optional: !field.is_required,
191 is_array: field.is_array,
192 is_enum,
193 has_default: default_val.is_some(),
194 default: default_val.unwrap_or_default(),
195 is_pk,
196 index: idx,
197 });
198
199 if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
200 let operators = get_filter_operators_for_field(field, &ir.enums);
201 where_input_fields.push(JsWhereInputFieldContext {
202 name: field.logical_name.clone(),
203 base_type: base_type.clone(),
204 ts_type: ts_type.clone(),
205 operators: operators
206 .into_iter()
207 .map(|op| JsFilterOperatorContext {
208 suffix: op.suffix,
209 ts_type: op.type_name,
210 })
211 .collect(),
212 });
213 }
214
215 if !auto_generated {
216 let input_base = base_type.clone();
217 let typed = if field.is_array {
218 format!("{}[]", input_base)
219 } else {
220 input_base
221 };
222 create_input_fields.push(JsCreateInputFieldContext {
223 name: field.logical_name.clone(),
224 ts_type: typed,
225 is_required: field.is_required
226 && field.default_value.is_none()
227 && !field.is_updated_at,
228 });
229 }
230
231 let is_auto_pk = auto_generated
232 && pk_field_names.contains(&field.logical_name.as_str())
233 && matches!(
234 field.field_type,
235 ResolvedFieldType::Scalar(ScalarType::Int)
236 | ResolvedFieldType::Scalar(ScalarType::BigInt)
237 );
238 if !is_auto_pk {
239 let input_base = base_type.clone();
240 let typed = if field.is_array {
241 format!("{}[]", input_base)
242 } else {
243 format!("{} | null", input_base)
244 };
245 update_input_fields.push(JsUpdateInputFieldContext {
246 name: field.logical_name.clone(),
247 ts_type: typed,
248 });
249 }
250
251 order_by_fields.push(JsOrderByFieldContext {
252 name: field.logical_name.clone(),
253 });
254
255 let is_numeric = matches!(
256 &field.field_type,
257 ResolvedFieldType::Scalar(ScalarType::Int)
258 | ResolvedFieldType::Scalar(ScalarType::BigInt)
259 | ResolvedFieldType::Scalar(ScalarType::Float)
260 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
261 );
262 if is_numeric {
263 let agg_type = if let ResolvedFieldType::Scalar(s) = &field.field_type {
264 scalar_to_ts_type(s).to_string()
265 } else {
266 unreachable!()
267 };
268 numeric_fields.push(JsAggregateFieldContext {
269 name: field.logical_name.clone(),
270 ts_type: agg_type,
271 });
272 }
273
274 let is_non_orderable = matches!(
275 &field.field_type,
276 ResolvedFieldType::Scalar(ScalarType::Boolean)
277 | ResolvedFieldType::Scalar(ScalarType::Json)
278 | ResolvedFieldType::Scalar(ScalarType::Bytes)
279 );
280 if !is_non_orderable {
281 orderable_fields.push(JsAggregateFieldContext {
282 name: field.logical_name.clone(),
283 ts_type: base_type,
284 });
285 }
286 }
287
288 let relation_fields: Vec<JsFieldContext> = model
289 .relation_fields()
290 .enumerate()
291 .map(|(idx, field)| {
292 let ts_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
293 if field.is_array {
294 format!("{}Model[]", rel.target_model)
295 } else {
296 format!("{}Model | null", rel.target_model)
297 }
298 } else {
299 "unknown".to_string()
300 };
301 let base_type = if let ResolvedFieldType::Relation(rel) = &field.field_type {
302 format!("{}Model", rel.target_model)
303 } else {
304 "unknown".to_string()
305 };
306
307 JsFieldContext {
308 name: field.logical_name.clone(),
309 logical_name: field.logical_name.clone(),
310 db_name: field.db_name.clone(),
311 ts_type,
312 base_type,
313 is_optional: true,
314 is_array: field.is_array,
315 is_enum: false,
316 has_default: true,
317 default: if field.is_array {
318 "[]".to_string()
319 } else {
320 "null".to_string()
321 },
322 is_pk: false,
323 index: idx,
324 }
325 })
326 .collect();
327
328 let include_fields: Vec<JsIncludeFieldContext> = model
329 .relation_fields()
330 .filter_map(|field| {
331 if let ResolvedFieldType::Relation(rel) = &field.field_type {
332 Some(JsIncludeFieldContext {
333 name: field.logical_name.clone(),
334 target_model: rel.target_model.clone(),
335 target_snake: rel.target_model.to_snake_case(),
336 target_camel: rel.target_model.to_lower_camel_case(),
337 is_array: field.is_array,
338 })
339 } else {
340 None
341 }
342 })
343 .collect();
344
345 let has_numeric_fields = !numeric_fields.is_empty();
346 let has_includes = !include_fields.is_empty();
347 let has_enums = !enum_imports.is_empty();
348
349 context.insert("scalar_fields", &scalar_fields);
350 context.insert("relation_fields", &relation_fields);
351 context.insert("where_input_fields", &where_input_fields);
352 context.insert("create_input_fields", &create_input_fields);
353 context.insert("update_input_fields", &update_input_fields);
354 context.insert("order_by_fields", &order_by_fields);
355 context.insert("include_fields", &include_fields);
356 context.insert("has_includes", &has_includes);
357 context.insert("numeric_fields", &numeric_fields);
358 context.insert("orderable_fields", &orderable_fields);
359 context.insert("has_numeric_fields", &has_numeric_fields);
360 context.insert("has_enums", &has_enums);
361 context.insert(
362 "enum_imports",
363 &enum_imports.into_iter().collect::<Vec<_>>(),
364 );
365 context.insert("has_composite_types", &!composite_type_imports.is_empty());
366 context.insert(
367 "composite_type_imports",
368 &composite_type_imports.into_iter().collect::<Vec<_>>(),
369 );
370
371 let snake = model.logical_name.to_snake_case();
372 let js_code = render("model.js.tera", &context);
373 let dts_code = render("model.d.ts.tera", &context);
374
375 (
376 (format!("{}.js", snake), js_code),
377 (format!("{}.d.ts", snake), dts_code),
378 )
379}
380
381#[allow(clippy::type_complexity)]
385pub fn generate_all_js_models(ir: &SchemaIr) -> (Vec<(String, String)>, Vec<(String, String)>) {
386 let pairs: Vec<((String, String), (String, String))> = ir
387 .models
388 .values()
389 .map(|model| generate_js_model(model, ir))
390 .collect();
391
392 let mut js_models: Vec<(String, String)> = pairs.iter().map(|(js, _)| js.clone()).collect();
393 let mut dts_models: Vec<(String, String)> = pairs.iter().map(|(_, dts)| dts.clone()).collect();
394
395 js_models.sort_by(|a, b| a.0.cmp(&b.0));
396 dts_models.sort_by(|a, b| a.0.cmp(&b.0));
397
398 (js_models, dts_models)
399}
400
401pub fn generate_js_composite_types(
405 composite_types: &HashMap<String, CompositeTypeIr>,
406) -> Option<String> {
407 if composite_types.is_empty() {
408 return None;
409 }
410
411 #[derive(Serialize)]
412 struct CompositeFieldCtx {
413 name: String,
414 ts_type: String,
415 }
416
417 #[derive(Serialize)]
418 struct CompositeTypeCtx {
419 name: String,
420 fields: Vec<CompositeFieldCtx>,
421 }
422
423 let mut type_list: Vec<CompositeTypeCtx> = composite_types
424 .values()
425 .map(|ct| {
426 let fields = ct
427 .fields
428 .iter()
429 .map(|f| {
430 let base = match &f.field_type {
431 ResolvedFieldType::Scalar(s) => scalar_to_ts_type(s).to_string(),
432 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
433 ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
434 ResolvedFieldType::Relation(_) => "unknown".to_string(),
435 };
436 let ts_type = if f.is_array {
437 format!("{}[]", base)
438 } else if !f.is_required {
439 format!("{} | null", base)
440 } else {
441 base
442 };
443 CompositeFieldCtx {
444 name: f.logical_name.clone(),
445 ts_type,
446 }
447 })
448 .collect();
449 CompositeTypeCtx {
450 name: ct.logical_name.clone(),
451 fields,
452 }
453 })
454 .collect();
455 type_list.sort_by(|a, b| a.name.cmp(&b.name));
456
457 let mut context = Context::new();
458 context.insert("composite_types", &type_list);
459
460 Some(render("composite_types.d.ts.tera", &context))
461}
462
463pub fn generate_js_enums(enums: &HashMap<String, EnumIr>) -> (String, String) {
467 #[derive(Serialize)]
468 struct EnumCtx {
469 name: String,
470 variants: Vec<String>,
471 }
472
473 let mut enum_list: Vec<EnumCtx> = enums
474 .values()
475 .map(|e| EnumCtx {
476 name: e.logical_name.clone(),
477 variants: e.variants.clone(),
478 })
479 .collect();
480 enum_list.sort_by(|a, b| a.name.cmp(&b.name));
481
482 let mut context = Context::new();
483 context.insert("enums", &enum_list);
484 let js_code = render("enums.js.tera", &context);
485 let dts_code = render("enums.d.ts.tera", &context);
486 (js_code, dts_code)
487}
488
489pub fn generate_js_client(
493 models: &HashMap<String, ModelIr>,
494 schema_path: &str,
495) -> (String, String) {
496 #[derive(Serialize)]
497 struct ModelCtx {
498 camel_name: String,
500 snake_name: String,
502 delegate_name: String,
504 }
505
506 let mut model_list: Vec<ModelCtx> = models
507 .values()
508 .map(|m| ModelCtx {
509 camel_name: m.logical_name.to_lower_camel_case(),
510 snake_name: m.logical_name.to_snake_case(),
511 delegate_name: format!("{}Delegate", m.logical_name),
512 })
513 .collect();
514 model_list.sort_by(|a, b| a.camel_name.cmp(&b.camel_name));
515
516 let mut context = Context::new();
517 context.insert("models", &model_list);
518 context.insert("schema_path", schema_path);
519 let js_code = render("client.js.tera", &context);
520 let dts_code = render("client.d.ts.tera", &context);
521 (js_code, dts_code)
522}
523
524pub fn generate_js_models_index(js_models: &[(String, String)]) -> (String, String) {
528 let mut modules: Vec<String> = js_models
529 .iter()
530 .map(|(file_name, _)| file_name.trim_end_matches(".js").to_string())
531 .collect();
532 modules.sort();
533
534 let mut context = Context::new();
535 context.insert("model_modules", &modules);
536 let js_code = render("models_index.js.tera", &context);
537 let dts_code = render("models_index.d.ts.tera", &context);
538 (js_code, dts_code)
539}
540
541pub fn js_runtime_files() -> Vec<(&'static str, &'static str)> {
544 vec![
545 (
546 "_errors.js",
547 include_str!("../../templates/js/runtime/_errors.js"),
548 ),
549 (
550 "_errors.d.ts",
551 include_str!("../../templates/js/runtime/_errors.d.ts"),
552 ),
553 (
554 "_protocol.js",
555 include_str!("../../templates/js/runtime/_protocol.js"),
556 ),
557 (
558 "_protocol.d.ts",
559 include_str!("../../templates/js/runtime/_protocol.d.ts"),
560 ),
561 (
562 "_engine.js",
563 include_str!("../../templates/js/runtime/_engine.js"),
564 ),
565 (
566 "_engine.d.ts",
567 include_str!("../../templates/js/runtime/_engine.d.ts"),
568 ),
569 (
570 "_client.js",
571 include_str!("../../templates/js/runtime/_client.js"),
572 ),
573 (
574 "_client.d.ts",
575 include_str!("../../templates/js/runtime/_client.d.ts"),
576 ),
577 (
578 "_transaction.js",
579 include_str!("../../templates/js/runtime/_transaction.js"),
580 ),
581 (
582 "_transaction.d.ts",
583 include_str!("../../templates/js/runtime/_transaction.d.ts"),
584 ),
585 ]
586}