1use heck::ToSnakeCase;
4use nautilus_schema::ir::{ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::type_helpers::{field_to_rust_type, is_auto_generated};
10
11pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
12 let mut tera = Tera::default();
13 tera.add_raw_templates(vec![
14 (
15 "columns_struct.tera",
16 include_str!("../templates/rust/columns_struct.tera"),
17 ),
18 (
19 "column_impl.tera",
20 include_str!("../templates/rust/column_impl.tera"),
21 ),
22 ("create.tera", include_str!("../templates/rust/create.tera")),
23 (
24 "create_many.tera",
25 include_str!("../templates/rust/create_many.tera"),
26 ),
27 (
28 "delegate.tera",
29 include_str!("../templates/rust/delegate.tera"),
30 ),
31 ("delete.tera", include_str!("../templates/rust/delete.tera")),
32 ("enum.tera", include_str!("../templates/rust/enum.tera")),
33 (
34 "find_many.tera",
35 include_str!("../templates/rust/find_many.tera"),
36 ),
37 (
38 "from_row_impl.tera",
39 include_str!("../templates/rust/from_row_impl.tera"),
40 ),
41 (
42 "model_file.tera",
43 include_str!("../templates/rust/model_file.tera"),
44 ),
45 (
46 "model_struct.tera",
47 include_str!("../templates/rust/model_struct.tera"),
48 ),
49 ("update.tera", include_str!("../templates/rust/update.tera")),
50 (
51 "composite_type.tera",
52 include_str!("../templates/rust/composite_type.tera"),
53 ),
54 ])
55 .expect("embedded Rust templates must parse");
56 tera
57});
58
59fn render(template: &str, ctx: &Context) -> String {
60 TEMPLATES
61 .render(template, ctx)
62 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
63}
64
/// Template-facing description of a single model field.
///
/// Serialized into the Tera context; used both for scalar fields and (with
/// placeholder values for several members) for relation fields.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Snake-case form of the field's logical name.
    name: String,
    /// The field's database name, copied from the schema IR.
    db_name: String,
    /// Rust type string produced by `field_to_rust_type`.
    rust_type: String,
    /// For scalar fields, the scalar's Rust type; for enum fields, the enum
    /// name; an empty string for every other field type.
    column_type: String,
    /// Whether the field is declared as an array.
    is_array: bool,
    /// Position of the field among the model's scalar fields (always 0 for
    /// relation fields).
    index: usize,
    /// Whether the field is listed in the model's primary key.
    is_pk: bool,
    /// For scalar fields: neither required nor an array. Always `true` for
    /// relation fields.
    is_optional: bool,
    /// Copied from the field's `is_updated_at` flag (always `false` for
    /// relation fields).
    is_updated_at: bool,
    /// Whether the field has a `computed` definition (always `false` for
    /// relation fields).
    is_computed: bool,
}
89
/// Template-facing description of one primary-key field, pairing its
/// generated Rust name with its database name.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Snake-case form of the field's logical name.
    name: String,
    /// The field's database name, copied from the schema IR.
    db_name: String,
}
99
/// Template-facing description of one relation field whose target model was
/// found in the schema IR.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Snake-case form of the relation field's logical name.
    field_name: String,
    /// Name of the model the relation points at.
    target_model: String,
    /// Database name of the target model.
    target_table: String,
    /// Whether the relation field is declared as an array.
    is_array: bool,
    /// Logical names of the fields on the owning model that take part in the
    /// relation. Taken from the relation itself, or from the inverse
    /// relation's `references` when this side declares no fields.
    fields: Vec<String>,
    /// Logical names of the fields on the target model that take part in the
    /// relation. Taken from the relation itself, or from the inverse
    /// relation's `fields` when this side declares no fields.
    references: Vec<String>,
    /// Database names for `fields`, looked up on the owning model; names that
    /// cannot be resolved are omitted.
    fields_db: Vec<String>,
    /// Database names for `references`, looked up on the target model; names
    /// that cannot be resolved are omitted.
    references_db: Vec<String>,
    /// Contexts for all scalar fields of the target model.
    target_scalar_fields: Vec<FieldContext>,
}
112
113pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
118 let mut context = Context::new();
119
120 context.insert("model_name", &model.logical_name);
121 context.insert("table_name", &model.db_name);
122 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
123 context.insert("columns_name", &format!("{}Columns", model.logical_name));
124 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
125 context.insert("create_name", &format!("{}Create", model.logical_name));
126 context.insert(
127 "create_many_name",
128 &format!("{}CreateMany", model.logical_name),
129 );
130 context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
131 context.insert("update_name", &format!("{}Update", model.logical_name));
132 context.insert("delete_name", &format!("{}Delete", model.logical_name));
133
134 let pk_field_names = model.primary_key.fields();
135 context.insert("primary_key_fields", &pk_field_names);
136
137 let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
138 .iter()
139 .filter_map(|logical| {
140 model
141 .scalar_fields()
142 .find(|f| f.logical_name.as_str() == *logical)
143 .map(|f| PkFieldContext {
144 name: f.logical_name.to_snake_case(),
145 db_name: f.db_name.clone(),
146 })
147 })
148 .collect();
149 context.insert("pk_fields_with_db", &pk_fields_with_db);
150
151 let mut enum_imports = HashSet::new();
153 let mut composite_type_imports = HashSet::new();
154
155 let mut scalar_fields: Vec<FieldContext> = Vec::new();
156 let mut create_fields: Vec<FieldContext> = Vec::new();
157 let mut updated_at_fields: Vec<FieldContext> = Vec::new();
158
159 for (idx, field) in model.scalar_fields().enumerate() {
160 match &field.field_type {
162 ResolvedFieldType::Enum { enum_name } => {
163 if ir.enums.contains_key(enum_name) {
164 enum_imports.insert(enum_name.clone());
165 }
166 }
167 ResolvedFieldType::CompositeType { type_name } => {
168 if ir.composite_types.contains_key(type_name) {
169 composite_type_imports.insert(type_name.clone());
170 }
171 }
172 _ => {}
173 }
174
175 let column_type = match &field.field_type {
177 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
178 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
179 _ => String::new(),
180 };
181 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
182 let auto_generated = is_auto_generated(field);
183
184 let field_ctx = FieldContext {
185 name: field.logical_name.to_snake_case(),
186 db_name: field.db_name.clone(),
187 rust_type: field_to_rust_type(field),
188 column_type,
189 is_array: field.is_array,
190 index: idx,
191 is_pk,
192 is_optional: !field.is_required && !field.is_array,
193 is_updated_at: field.is_updated_at,
194 is_computed: field.computed.is_some(),
195 };
196
197 if !auto_generated {
199 create_fields.push(field_ctx.clone());
200 }
201
202 if field.is_updated_at {
204 updated_at_fields.push(field_ctx.clone());
205 }
206
207 scalar_fields.push(field_ctx);
208 }
209
210 let mut relation_imports = HashSet::new();
212 for field in model.relation_fields() {
213 if let ResolvedFieldType::Relation(rel) = &field.field_type {
214 relation_imports.insert(rel.target_model.clone());
215 }
216 }
217
218 context.insert("has_enums", &!enum_imports.is_empty());
219 context.insert(
220 "enum_imports",
221 &enum_imports.into_iter().collect::<Vec<_>>(),
222 );
223 context.insert("has_relations", &!relation_imports.is_empty());
224 context.insert(
225 "relation_imports",
226 &relation_imports.into_iter().collect::<Vec<_>>(),
227 );
228 context.insert("has_composite_types", &!composite_type_imports.is_empty());
229 context.insert(
230 "composite_type_imports",
231 &composite_type_imports.into_iter().collect::<Vec<_>>(),
232 );
233
234 let relation_fields: Vec<FieldContext> = model
235 .relation_fields()
236 .map(|field| FieldContext {
237 name: field.logical_name.to_snake_case(),
238 db_name: field.db_name.clone(),
239 rust_type: field_to_rust_type(field),
240 column_type: String::new(),
241 is_array: field.is_array,
242 index: 0,
243 is_pk: false,
244 is_optional: true,
245 is_updated_at: false,
246 is_computed: false,
247 })
248 .collect();
249
250 let relations: Vec<RelationContext> = model
251 .relation_fields()
252 .filter_map(|field| {
253 if let ResolvedFieldType::Relation(rel) = &field.field_type {
254 if let Some(target_model) = ir.models.get(&rel.target_model) {
255 let target_pk_names = target_model.primary_key.fields();
256 let target_scalar_fields: Vec<FieldContext> = target_model
257 .scalar_fields()
258 .enumerate()
259 .map(|(idx, f)| {
260 let column_type = match &f.field_type {
261 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
262 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
263 _ => String::new(),
264 };
265 let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
266 FieldContext {
267 name: f.logical_name.to_snake_case(),
268 db_name: f.db_name.clone(),
269 rust_type: field_to_rust_type(f),
270 column_type,
271 is_array: f.is_array,
272 index: idx,
273 is_pk: f_is_pk,
274 is_optional: !f.is_required && !f.is_array,
275 is_updated_at: f.is_updated_at,
276 is_computed: f.computed.is_some(),
277 }
278 })
279 .collect();
280
281 let (fields, references) = if rel.fields.is_empty() {
282 let inverse = target_model.relation_fields().find(|f| {
283 if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
284 inv_rel.target_model == model.logical_name
285 } else {
286 false
287 }
288 });
289
290 if let Some(inverse_field) = inverse {
291 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
292 {
293 (inv_rel.references.clone(), inv_rel.fields.clone())
295 } else {
296 (vec![], vec![])
297 }
298 } else {
299 (vec![], vec![])
300 }
301 } else {
302 (rel.fields.clone(), rel.references.clone())
303 };
304
305 let fields_db: Vec<String> = fields
306 .iter()
307 .filter_map(|logical_name| {
308 model
309 .fields
310 .iter()
311 .find(|f| &f.logical_name == logical_name)
312 .map(|f| f.db_name.clone())
313 })
314 .collect();
315
316 let references_db: Vec<String> = references
317 .iter()
318 .filter_map(|logical_name| {
319 target_model
320 .fields
321 .iter()
322 .find(|f| &f.logical_name == logical_name)
323 .map(|f| f.db_name.clone())
324 })
325 .collect();
326
327 Some(RelationContext {
328 field_name: field.logical_name.to_snake_case(),
329 target_model: rel.target_model.clone(),
330 target_table: target_model.db_name.clone(),
331 is_array: field.is_array,
332 fields,
333 references,
334 fields_db,
335 references_db,
336 target_scalar_fields,
337 })
338 } else {
339 None
340 }
341 } else {
342 None
343 }
344 })
345 .collect();
346
347 context.insert("scalar_fields", &scalar_fields);
348 context.insert("relation_fields", &relation_fields);
349 context.insert("relations", &relations);
350 context.insert("create_fields", &create_fields);
351 context.insert("updated_at_fields", &updated_at_fields);
352 context.insert("all_scalar_fields", &scalar_fields);
353 context.insert("is_async", &is_async);
354
355 render("model_file.tera", &context)
356}
357
358pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
362 let mut generated = HashMap::new();
363
364 for (model_name, model_ir) in &ir.models {
365 let code = generate_model(model_ir, ir, is_async);
366 generated.insert(model_name.clone(), code);
367 }
368
369 generated
370}