1use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ast::StorageStrategy;
5use nautilus_schema::ir::{FieldIr, ModelIr, ResolvedFieldType, ScalarType, SchemaIr};
6use serde::Serialize;
7use std::collections::{HashMap, HashSet};
8use tera::{Context, Tera};
9
10use crate::type_helpers::{
11 field_to_rust_avg_type, field_to_rust_base_type, field_to_rust_sum_type, field_to_rust_type,
12};
13
14pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
15 let mut tera = Tera::default();
16 tera.add_raw_templates(vec![
17 (
18 "columns_struct.tera",
19 include_str!("../templates/rust/columns_struct.tera"),
20 ),
21 (
22 "column_impl.tera",
23 include_str!("../templates/rust/column_impl.tera"),
24 ),
25 ("create.tera", include_str!("../templates/rust/create.tera")),
26 (
27 "create_many.tera",
28 include_str!("../templates/rust/create_many.tera"),
29 ),
30 (
31 "delegate.tera",
32 include_str!("../templates/rust/delegate.tera"),
33 ),
34 ("delete.tera", include_str!("../templates/rust/delete.tera")),
35 ("enum.tera", include_str!("../templates/rust/enum.tera")),
36 (
37 "find_many.tera",
38 include_str!("../templates/rust/find_many.tera"),
39 ),
40 (
41 "from_row_impl.tera",
42 include_str!("../templates/rust/from_row_impl.tera"),
43 ),
44 (
45 "model_file.tera",
46 include_str!("../templates/rust/model_file.tera"),
47 ),
48 ("lib_rs.tera", include_str!("../templates/rust/lib_rs.tera")),
49 (
50 "model_struct.tera",
51 include_str!("../templates/rust/model_struct.tera"),
52 ),
53 ("update.tera", include_str!("../templates/rust/update.tera")),
54 (
55 "composite_type.tera",
56 include_str!("../templates/rust/composite_type.tera"),
57 ),
58 ])
59 .expect("embedded Rust templates must parse");
60 tera
61});
62
63fn render(template: &str, ctx: &Context) -> String {
64 TEMPLATES
65 .render(template, ctx)
66 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
67}
68
69#[derive(Debug, Clone, Serialize)]
77struct FieldContext {
78 name: String,
79 logical_name: String,
80 db_name: String,
81 rust_type: String,
82 base_rust_type: String,
83 column_type: String,
84 read_hint_expr: String,
85 variant_name: String,
86 is_array: bool,
87 index: usize,
88 is_pk: bool,
89 is_optional: bool,
92 is_updated_at: bool,
94 is_computed: bool,
96}
97
98#[derive(Debug, Clone, Serialize)]
99struct AggregateFieldContext {
100 name: String,
101 logical_name: String,
102 rust_type: String,
103 avg_rust_type: String,
104 sum_rust_type: String,
105 variant_name: String,
106}
107
108#[derive(Debug, Clone, Serialize)]
111struct PkFieldContext {
112 name: String,
114 db_name: String,
116}
117
118#[derive(Debug, Clone, Serialize)]
119struct RelationContext {
120 field_name: String,
121 target_model: String,
122 target_table: String,
123 is_array: bool,
124 fields: Vec<String>,
125 references: Vec<String>,
126 fields_db: Vec<String>,
127 references_db: Vec<String>,
128 target_scalar_fields: Vec<FieldContext>,
129}
130
131fn resolve_inverse_relation_fields(
132 source_model_name: &str,
133 relation_name: Option<&str>,
134 target_model: &ModelIr,
135) -> (Vec<String>, Vec<String>) {
136 let inverse = target_model.relation_fields().find(|field| {
137 if let ResolvedFieldType::Relation(inv_rel) = &field.field_type {
138 if inv_rel.target_model != source_model_name {
139 return false;
140 }
141
142 match (relation_name, inv_rel.name.as_deref()) {
143 (Some(expected), Some(actual)) => actual == expected,
144 (Some(_), None) => false,
145 (None, Some(_)) => false,
146 (None, None) => true,
147 }
148 } else {
149 false
150 }
151 });
152
153 let Some(inverse_field) = inverse else {
154 return (vec![], vec![]);
155 };
156 let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type else {
157 return (vec![], vec![]);
158 };
159
160 (inv_rel.references.clone(), inv_rel.fields.clone())
161}
162
163fn field_read_hint_expr(field: &FieldIr) -> String {
164 if field.is_array && field.storage_strategy == Some(StorageStrategy::Json) {
165 return "Some(crate::ValueHint::Json)".to_string();
166 }
167
168 match &field.field_type {
169 ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
170 "Some(crate::ValueHint::Decimal)".to_string()
171 }
172 ResolvedFieldType::Scalar(ScalarType::DateTime) => {
173 "Some(crate::ValueHint::DateTime)".to_string()
174 }
175 ResolvedFieldType::Scalar(ScalarType::Json | ScalarType::Jsonb) => {
176 "Some(crate::ValueHint::Json)".to_string()
177 }
178 ResolvedFieldType::Scalar(ScalarType::Uuid) => "Some(crate::ValueHint::Uuid)".to_string(),
179 ResolvedFieldType::CompositeType { .. }
180 if field.storage_strategy == Some(StorageStrategy::Json) =>
181 {
182 "Some(crate::ValueHint::Json)".to_string()
183 }
184 _ => "None".to_string(),
185 }
186}
187
188pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
193 let mut context = Context::new();
194
195 context.insert("model_name", &model.logical_name);
196 context.insert("table_name", &model.db_name);
197 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
198 context.insert("columns_name", &format!("{}Columns", model.logical_name));
199 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
200 context.insert("create_name", &format!("{}Create", model.logical_name));
201 context.insert(
202 "create_many_name",
203 &format!("{}CreateMany", model.logical_name),
204 );
205 context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
206 context.insert("update_name", &format!("{}Update", model.logical_name));
207 context.insert("delete_name", &format!("{}Delete", model.logical_name));
208
209 let pk_field_names = model.primary_key.fields();
210 context.insert("primary_key_fields", &pk_field_names);
211
212 let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
213 .iter()
214 .filter_map(|logical| {
215 model
216 .scalar_fields()
217 .find(|f| f.logical_name.as_str() == *logical)
218 .map(|f| PkFieldContext {
219 name: f.logical_name.to_snake_case(),
220 db_name: f.db_name.clone(),
221 })
222 })
223 .collect();
224 context.insert("pk_fields_with_db", &pk_fields_with_db);
225
226 let mut enum_imports = HashSet::new();
227 let mut composite_type_imports = HashSet::new();
228
229 let mut scalar_fields: Vec<FieldContext> = Vec::new();
230 let mut create_fields: Vec<FieldContext> = Vec::new();
231 let mut updated_at_fields: Vec<FieldContext> = Vec::new();
232 let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
233 let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
234
235 for (idx, field) in model.scalar_fields().enumerate() {
236 match &field.field_type {
237 ResolvedFieldType::Enum { enum_name } => {
238 if ir.enums.contains_key(enum_name) {
239 enum_imports.insert(enum_name.clone());
240 }
241 }
242 ResolvedFieldType::CompositeType { type_name } => {
243 if ir.composite_types.contains_key(type_name) {
244 composite_type_imports.insert(type_name.clone());
245 }
246 }
247 _ => {}
248 }
249
250 let column_type = match &field.field_type {
251 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
252 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
253 _ => String::new(),
254 };
255 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
256 let base_rust_type = field_to_rust_base_type(field);
257
258 let field_ctx = FieldContext {
259 name: field.logical_name.to_snake_case(),
260 logical_name: field.logical_name.clone(),
261 db_name: field.db_name.clone(),
262 rust_type: field_to_rust_type(field),
263 base_rust_type: base_rust_type.clone(),
264 column_type,
265 read_hint_expr: field_read_hint_expr(field),
266 variant_name: field.logical_name.to_pascal_case(),
267 is_array: field.is_array,
268 index: idx,
269 is_pk,
270 is_optional: !field.is_required && !field.is_array,
271 is_updated_at: field.is_updated_at,
272 is_computed: field.computed.is_some(),
273 };
274
275 create_fields.push(field_ctx.clone());
276
277 if field.is_updated_at {
278 updated_at_fields.push(field_ctx.clone());
279 }
280
281 scalar_fields.push(field_ctx);
282
283 let is_numeric = matches!(
284 &field.field_type,
285 ResolvedFieldType::Scalar(ScalarType::Int)
286 | ResolvedFieldType::Scalar(ScalarType::BigInt)
287 | ResolvedFieldType::Scalar(ScalarType::Float)
288 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
289 );
290 if is_numeric {
291 numeric_fields.push(AggregateFieldContext {
292 name: field.logical_name.to_snake_case(),
293 logical_name: field.logical_name.clone(),
294 rust_type: base_rust_type.clone(),
295 avg_rust_type: field_to_rust_avg_type(field),
296 sum_rust_type: field_to_rust_sum_type(field),
297 variant_name: field.logical_name.to_pascal_case(),
298 });
299 }
300
301 let is_non_orderable = matches!(
302 &field.field_type,
303 ResolvedFieldType::Scalar(ScalarType::Boolean)
304 | ResolvedFieldType::Scalar(ScalarType::Json)
305 | ResolvedFieldType::Scalar(ScalarType::Bytes)
306 );
307 if !is_non_orderable {
308 orderable_fields.push(AggregateFieldContext {
309 name: field.logical_name.to_snake_case(),
310 logical_name: field.logical_name.clone(),
311 rust_type: base_rust_type,
312 avg_rust_type: String::new(),
313 sum_rust_type: String::new(),
314 variant_name: field.logical_name.to_pascal_case(),
315 });
316 }
317 }
318
319 let mut relation_imports = HashSet::new();
320 for field in model.relation_fields() {
321 if let ResolvedFieldType::Relation(rel) = &field.field_type {
322 relation_imports.insert(rel.target_model.clone());
323 }
324 }
325
326 context.insert("has_enums", &!enum_imports.is_empty());
327 context.insert(
328 "enum_imports",
329 &enum_imports.into_iter().collect::<Vec<_>>(),
330 );
331 context.insert("has_relations", &!relation_imports.is_empty());
332 context.insert(
333 "relation_imports",
334 &relation_imports.into_iter().collect::<Vec<_>>(),
335 );
336 context.insert("has_composite_types", &!composite_type_imports.is_empty());
337 context.insert(
338 "composite_type_imports",
339 &composite_type_imports.into_iter().collect::<Vec<_>>(),
340 );
341
342 let relation_fields: Vec<FieldContext> = model
343 .relation_fields()
344 .map(|field| FieldContext {
345 name: field.logical_name.to_snake_case(),
346 logical_name: field.logical_name.clone(),
347 db_name: field.db_name.clone(),
348 rust_type: field_to_rust_type(field),
349 base_rust_type: field_to_rust_base_type(field),
350 column_type: String::new(),
351 read_hint_expr: "None".to_string(),
352 variant_name: field.logical_name.to_pascal_case(),
353 is_array: field.is_array,
354 index: 0,
355 is_pk: false,
356 is_optional: true,
357 is_updated_at: false,
358 is_computed: false,
359 })
360 .collect();
361
362 let relations: Vec<RelationContext> = model
363 .relation_fields()
364 .filter_map(|field| {
365 let ResolvedFieldType::Relation(rel) = &field.field_type else {
366 return None;
367 };
368 let target_model = ir.models.get(&rel.target_model)?;
369
370 let target_pk_names = target_model.primary_key.fields();
371 let target_scalar_fields: Vec<FieldContext> = target_model
372 .scalar_fields()
373 .enumerate()
374 .map(|(idx, f)| {
375 let column_type = match &f.field_type {
376 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
377 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
378 _ => String::new(),
379 };
380 let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
381 FieldContext {
382 name: f.logical_name.to_snake_case(),
383 logical_name: f.logical_name.clone(),
384 db_name: f.db_name.clone(),
385 rust_type: field_to_rust_type(f),
386 base_rust_type: field_to_rust_base_type(f),
387 column_type,
388 read_hint_expr: field_read_hint_expr(f),
389 variant_name: f.logical_name.to_pascal_case(),
390 is_array: f.is_array,
391 index: idx,
392 is_pk: f_is_pk,
393 is_optional: !f.is_required && !f.is_array,
394 is_updated_at: f.is_updated_at,
395 is_computed: f.computed.is_some(),
396 }
397 })
398 .collect();
399
400 let (fields, references) = if rel.fields.is_empty() {
401 resolve_inverse_relation_fields(
402 &model.logical_name,
403 rel.name.as_deref(),
404 target_model,
405 )
406 } else {
407 (rel.fields.clone(), rel.references.clone())
408 };
409
410 let fields_db: Vec<String> = fields
411 .iter()
412 .filter_map(|logical_name| {
413 model
414 .fields
415 .iter()
416 .find(|f| &f.logical_name == logical_name)
417 .map(|f| f.db_name.clone())
418 })
419 .collect();
420
421 let references_db: Vec<String> = references
422 .iter()
423 .filter_map(|logical_name| {
424 target_model
425 .fields
426 .iter()
427 .find(|f| &f.logical_name == logical_name)
428 .map(|f| f.db_name.clone())
429 })
430 .collect();
431
432 Some(RelationContext {
433 field_name: field.logical_name.to_snake_case(),
434 target_model: rel.target_model.clone(),
435 target_table: target_model.db_name.clone(),
436 is_array: field.is_array,
437 fields,
438 references,
439 fields_db,
440 references_db,
441 target_scalar_fields,
442 })
443 })
444 .collect();
445
446 context.insert("scalar_fields", &scalar_fields);
447 context.insert("relation_fields", &relation_fields);
448 context.insert("relations", &relations);
449 context.insert("create_fields", &create_fields);
450 context.insert("updated_at_fields", &updated_at_fields);
451 context.insert("all_scalar_fields", &scalar_fields);
452 context.insert("numeric_fields", &numeric_fields);
453 context.insert("orderable_fields", &orderable_fields);
454 context.insert("has_numeric_fields", &!numeric_fields.is_empty());
455 context.insert("has_orderable_fields", &!orderable_fields.is_empty());
456 context.insert("is_async", &is_async);
457
458 render("model_file.tera", &context)
459}
460
461pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
465 let mut generated = HashMap::new();
466
467 for (model_name, model_ir) in &ir.models {
468 let code = generate_model(model_ir, ir, is_async);
469 generated.insert(model_name.clone(), code);
470 }
471
472 generated
473}