1use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ast::StorageStrategy;
5use nautilus_schema::ir::{FieldIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr};
6use serde::Serialize;
7use std::collections::{HashMap, HashSet};
8use tera::{Context, Tera};
9
10use crate::type_helpers::{
11 field_to_rust_avg_type, field_to_rust_base_type, field_to_rust_sum_type, field_to_rust_type,
12};
13
14pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
15 let mut tera = Tera::default();
16 tera.add_raw_templates(vec![
17 (
18 "columns_struct.tera",
19 include_str!("../templates/rust/columns_struct.tera"),
20 ),
21 (
22 "column_impl.tera",
23 include_str!("../templates/rust/column_impl.tera"),
24 ),
25 ("create.tera", include_str!("../templates/rust/create.tera")),
26 (
27 "create_many.tera",
28 include_str!("../templates/rust/create_many.tera"),
29 ),
30 (
31 "delegate.tera",
32 include_str!("../templates/rust/delegate.tera"),
33 ),
34 ("delete.tera", include_str!("../templates/rust/delete.tera")),
35 ("enum.tera", include_str!("../templates/rust/enum.tera")),
36 (
37 "find_many.tera",
38 include_str!("../templates/rust/find_many.tera"),
39 ),
40 (
41 "from_row_impl.tera",
42 include_str!("../templates/rust/from_row_impl.tera"),
43 ),
44 (
45 "model_file.tera",
46 include_str!("../templates/rust/model_file.tera"),
47 ),
48 (
49 "model_struct.tera",
50 include_str!("../templates/rust/model_struct.tera"),
51 ),
52 ("update.tera", include_str!("../templates/rust/update.tera")),
53 (
54 "composite_type.tera",
55 include_str!("../templates/rust/composite_type.tera"),
56 ),
57 ])
58 .expect("embedded Rust templates must parse");
59 tera
60});
61
62fn render(template: &str, ctx: &Context) -> String {
63 TEMPLATES
64 .render(template, ctx)
65 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
66}
67
/// Per-field data serialized into the Tera context.
///
/// Built for the scalar fields of the model being generated, for the scalar
/// fields of each relation's target model, and (with placeholder values) for
/// relation fields.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Logical field name converted to snake_case.
    name: String,
    /// Field name as written in the schema.
    logical_name: String,
    /// Database column name.
    db_name: String,
    /// Rust type produced by `field_to_rust_type`.
    rust_type: String,
    /// Rust type produced by `field_to_rust_base_type`.
    base_rust_type: String,
    /// The scalar's `rust_type()` or the enum name; empty for other field kinds.
    column_type: String,
    /// Rust source text of the value hint, e.g. `Some(crate::ValueHint::Json)` or `None`.
    read_hint_expr: String,
    /// Logical field name converted to PascalCase.
    variant_name: String,
    /// Whether the field is a list.
    is_array: bool,
    /// Position among the owning model's scalar fields (always 0 for relation fields).
    index: usize,
    /// Whether the field is part of its model's primary key.
    is_pk: bool,
    /// True when the field is neither required nor an array (always true for relation fields).
    is_optional: bool,
    /// Mirrors the IR's `is_updated_at` flag.
    is_updated_at: bool,
    /// True when the IR carries a `computed` definition for the field.
    is_computed: bool,
}
96
/// Template data for a field that can appear in aggregate queries.
///
/// Used for both the numeric and the orderable field lists. Entries in the
/// orderable list leave `avg_rust_type` and `sum_rust_type` empty.
#[derive(Debug, Clone, Serialize)]
struct AggregateFieldContext {
    /// Logical field name converted to snake_case.
    name: String,
    /// Field name as written in the schema.
    logical_name: String,
    /// Rust type from `field_to_rust_base_type`.
    rust_type: String,
    /// Rust type of an AVG result (`field_to_rust_avg_type`); empty for orderable entries.
    avg_rust_type: String,
    /// Rust type of a SUM result (`field_to_rust_sum_type`); empty for orderable entries.
    sum_rust_type: String,
    /// Logical field name converted to PascalCase.
    variant_name: String,
}
106
/// A primary-key field with both its Rust-facing name and its column name.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Logical field name converted to snake_case.
    name: String,
    /// Database column name.
    db_name: String,
}
116
/// Join metadata for one relation field of the model being generated.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Relation field name converted to snake_case.
    field_name: String,
    /// Logical name of the related model.
    target_model: String,
    /// Database table name of the related model.
    target_table: String,
    /// Whether the relation field is a list.
    is_array: bool,
    /// Logical names of the join fields on the source model.
    fields: Vec<String>,
    /// Logical names of the referenced fields on the target model.
    references: Vec<String>,
    /// Column names for `fields`. Unresolvable names are dropped, so the length may differ.
    fields_db: Vec<String>,
    /// Column names for `references`. Unresolvable names are dropped, so the length may differ.
    references_db: Vec<String>,
    /// Template data for every scalar field of the target model.
    target_scalar_fields: Vec<FieldContext>,
}
129
130fn resolve_inverse_relation_fields(
131 source_model_name: &str,
132 relation_name: Option<&str>,
133 target_model: &ModelIr,
134) -> (Vec<String>, Vec<String>) {
135 let inverse = target_model.relation_fields().find(|field| {
136 if let ResolvedFieldType::Relation(inv_rel) = &field.field_type {
137 if inv_rel.target_model != source_model_name {
138 return false;
139 }
140
141 match (relation_name, inv_rel.name.as_deref()) {
142 (Some(expected), Some(actual)) => actual == expected,
143 (Some(_), None) => false,
144 (None, Some(_)) => false,
145 (None, None) => true,
146 }
147 } else {
148 false
149 }
150 });
151
152 if let Some(inverse_field) = inverse {
153 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type {
154 return (inv_rel.references.clone(), inv_rel.fields.clone());
155 }
156 }
157
158 (vec![], vec![])
159}
160
161fn field_read_hint_expr(field: &FieldIr) -> String {
162 if field.is_array && field.storage_strategy == Some(StorageStrategy::Json) {
163 return "Some(crate::ValueHint::Json)".to_string();
164 }
165
166 match &field.field_type {
167 ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
168 "Some(crate::ValueHint::Decimal)".to_string()
169 }
170 ResolvedFieldType::Scalar(ScalarType::DateTime) => {
171 "Some(crate::ValueHint::DateTime)".to_string()
172 }
173 ResolvedFieldType::Scalar(ScalarType::Json | ScalarType::Jsonb) => {
174 "Some(crate::ValueHint::Json)".to_string()
175 }
176 ResolvedFieldType::Scalar(ScalarType::Uuid) => "Some(crate::ValueHint::Uuid)".to_string(),
177 ResolvedFieldType::CompositeType { .. }
178 if field.storage_strategy == Some(StorageStrategy::Json) =>
179 {
180 "Some(crate::ValueHint::Json)".to_string()
181 }
182 _ => "None".to_string(),
183 }
184}
185
186pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
191 let mut context = Context::new();
192
193 context.insert("model_name", &model.logical_name);
194 context.insert("table_name", &model.db_name);
195 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
196 context.insert("columns_name", &format!("{}Columns", model.logical_name));
197 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
198 context.insert("create_name", &format!("{}Create", model.logical_name));
199 context.insert(
200 "create_many_name",
201 &format!("{}CreateMany", model.logical_name),
202 );
203 context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
204 context.insert("update_name", &format!("{}Update", model.logical_name));
205 context.insert("delete_name", &format!("{}Delete", model.logical_name));
206
207 let pk_field_names = model.primary_key.fields();
208 context.insert("primary_key_fields", &pk_field_names);
209
210 let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
211 .iter()
212 .filter_map(|logical| {
213 model
214 .scalar_fields()
215 .find(|f| f.logical_name.as_str() == *logical)
216 .map(|f| PkFieldContext {
217 name: f.logical_name.to_snake_case(),
218 db_name: f.db_name.clone(),
219 })
220 })
221 .collect();
222 context.insert("pk_fields_with_db", &pk_fields_with_db);
223
224 let mut enum_imports = HashSet::new();
225 let mut composite_type_imports = HashSet::new();
226
227 let mut scalar_fields: Vec<FieldContext> = Vec::new();
228 let mut create_fields: Vec<FieldContext> = Vec::new();
229 let mut updated_at_fields: Vec<FieldContext> = Vec::new();
230 let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
231 let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
232
233 for (idx, field) in model.scalar_fields().enumerate() {
234 match &field.field_type {
235 ResolvedFieldType::Enum { enum_name } => {
236 if ir.enums.contains_key(enum_name) {
237 enum_imports.insert(enum_name.clone());
238 }
239 }
240 ResolvedFieldType::CompositeType { type_name } => {
241 if ir.composite_types.contains_key(type_name) {
242 composite_type_imports.insert(type_name.clone());
243 }
244 }
245 _ => {}
246 }
247
248 let column_type = match &field.field_type {
249 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
250 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
251 _ => String::new(),
252 };
253 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
254 let base_rust_type = field_to_rust_base_type(field);
255
256 let field_ctx = FieldContext {
257 name: field.logical_name.to_snake_case(),
258 logical_name: field.logical_name.clone(),
259 db_name: field.db_name.clone(),
260 rust_type: field_to_rust_type(field),
261 base_rust_type: base_rust_type.clone(),
262 column_type,
263 read_hint_expr: field_read_hint_expr(field),
264 variant_name: field.logical_name.to_pascal_case(),
265 is_array: field.is_array,
266 index: idx,
267 is_pk,
268 is_optional: !field.is_required && !field.is_array,
269 is_updated_at: field.is_updated_at,
270 is_computed: field.computed.is_some(),
271 };
272
273 create_fields.push(field_ctx.clone());
274
275 if field.is_updated_at {
276 updated_at_fields.push(field_ctx.clone());
277 }
278
279 scalar_fields.push(field_ctx);
280
281 let is_numeric = matches!(
282 &field.field_type,
283 ResolvedFieldType::Scalar(ScalarType::Int)
284 | ResolvedFieldType::Scalar(ScalarType::BigInt)
285 | ResolvedFieldType::Scalar(ScalarType::Float)
286 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
287 );
288 if is_numeric {
289 numeric_fields.push(AggregateFieldContext {
290 name: field.logical_name.to_snake_case(),
291 logical_name: field.logical_name.clone(),
292 rust_type: base_rust_type.clone(),
293 avg_rust_type: field_to_rust_avg_type(field),
294 sum_rust_type: field_to_rust_sum_type(field),
295 variant_name: field.logical_name.to_pascal_case(),
296 });
297 }
298
299 let is_non_orderable = matches!(
300 &field.field_type,
301 ResolvedFieldType::Scalar(ScalarType::Boolean)
302 | ResolvedFieldType::Scalar(ScalarType::Json)
303 | ResolvedFieldType::Scalar(ScalarType::Bytes)
304 );
305 if !is_non_orderable {
306 orderable_fields.push(AggregateFieldContext {
307 name: field.logical_name.to_snake_case(),
308 logical_name: field.logical_name.clone(),
309 rust_type: base_rust_type,
310 avg_rust_type: String::new(),
311 sum_rust_type: String::new(),
312 variant_name: field.logical_name.to_pascal_case(),
313 });
314 }
315 }
316
317 let mut relation_imports = HashSet::new();
318 for field in model.relation_fields() {
319 if let ResolvedFieldType::Relation(rel) = &field.field_type {
320 relation_imports.insert(rel.target_model.clone());
321 }
322 }
323
324 context.insert("has_enums", &!enum_imports.is_empty());
325 context.insert(
326 "enum_imports",
327 &enum_imports.into_iter().collect::<Vec<_>>(),
328 );
329 context.insert("has_relations", &!relation_imports.is_empty());
330 context.insert(
331 "relation_imports",
332 &relation_imports.into_iter().collect::<Vec<_>>(),
333 );
334 context.insert("has_composite_types", &!composite_type_imports.is_empty());
335 context.insert(
336 "composite_type_imports",
337 &composite_type_imports.into_iter().collect::<Vec<_>>(),
338 );
339
340 let relation_fields: Vec<FieldContext> = model
341 .relation_fields()
342 .map(|field| FieldContext {
343 name: field.logical_name.to_snake_case(),
344 logical_name: field.logical_name.clone(),
345 db_name: field.db_name.clone(),
346 rust_type: field_to_rust_type(field),
347 base_rust_type: field_to_rust_base_type(field),
348 column_type: String::new(),
349 read_hint_expr: "None".to_string(),
350 variant_name: field.logical_name.to_pascal_case(),
351 is_array: field.is_array,
352 index: 0,
353 is_pk: false,
354 is_optional: true,
355 is_updated_at: false,
356 is_computed: false,
357 })
358 .collect();
359
360 let relations: Vec<RelationContext> = model
361 .relation_fields()
362 .filter_map(|field| {
363 if let ResolvedFieldType::Relation(rel) = &field.field_type {
364 if let Some(target_model) = ir.models.get(&rel.target_model) {
365 let target_pk_names = target_model.primary_key.fields();
366 let target_scalar_fields: Vec<FieldContext> = target_model
367 .scalar_fields()
368 .enumerate()
369 .map(|(idx, f)| {
370 let column_type = match &f.field_type {
371 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
372 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
373 _ => String::new(),
374 };
375 let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
376 FieldContext {
377 name: f.logical_name.to_snake_case(),
378 logical_name: f.logical_name.clone(),
379 db_name: f.db_name.clone(),
380 rust_type: field_to_rust_type(f),
381 base_rust_type: field_to_rust_base_type(f),
382 column_type,
383 read_hint_expr: field_read_hint_expr(f),
384 variant_name: f.logical_name.to_pascal_case(),
385 is_array: f.is_array,
386 index: idx,
387 is_pk: f_is_pk,
388 is_optional: !f.is_required && !f.is_array,
389 is_updated_at: f.is_updated_at,
390 is_computed: f.computed.is_some(),
391 }
392 })
393 .collect();
394
395 let (fields, references) = if rel.fields.is_empty() {
396 resolve_inverse_relation_fields(
397 &model.logical_name,
398 rel.name.as_deref(),
399 target_model,
400 )
401 } else {
402 (rel.fields.clone(), rel.references.clone())
403 };
404
405 let fields_db: Vec<String> = fields
406 .iter()
407 .filter_map(|logical_name| {
408 model
409 .fields
410 .iter()
411 .find(|f| &f.logical_name == logical_name)
412 .map(|f| f.db_name.clone())
413 })
414 .collect();
415
416 let references_db: Vec<String> = references
417 .iter()
418 .filter_map(|logical_name| {
419 target_model
420 .fields
421 .iter()
422 .find(|f| &f.logical_name == logical_name)
423 .map(|f| f.db_name.clone())
424 })
425 .collect();
426
427 Some(RelationContext {
428 field_name: field.logical_name.to_snake_case(),
429 target_model: rel.target_model.clone(),
430 target_table: target_model.db_name.clone(),
431 is_array: field.is_array,
432 fields,
433 references,
434 fields_db,
435 references_db,
436 target_scalar_fields,
437 })
438 } else {
439 None
440 }
441 } else {
442 None
443 }
444 })
445 .collect();
446
447 context.insert("scalar_fields", &scalar_fields);
448 context.insert("relation_fields", &relation_fields);
449 context.insert("relations", &relations);
450 context.insert("create_fields", &create_fields);
451 context.insert("updated_at_fields", &updated_at_fields);
452 context.insert("all_scalar_fields", &scalar_fields);
453 context.insert("numeric_fields", &numeric_fields);
454 context.insert("orderable_fields", &orderable_fields);
455 context.insert("has_numeric_fields", &!numeric_fields.is_empty());
456 context.insert("has_orderable_fields", &!orderable_fields.is_empty());
457 context.insert("is_async", &is_async);
458
459 render("model_file.tera", &context)
460}
461
462pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
466 let mut generated = HashMap::new();
467
468 for (model_name, model_ir) in &ir.models {
469 let code = generate_model(model_ir, ir, is_async);
470 generated.insert(model_name.clone(), code);
471 }
472
473 generated
474}