1use heck::ToSnakeCase;
4use nautilus_schema::ir::{ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::type_helpers::field_to_rust_type;
10
11pub static TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
12 let mut tera = Tera::default();
13 tera.add_raw_templates(vec![
14 (
15 "columns_struct.tera",
16 include_str!("../templates/rust/columns_struct.tera"),
17 ),
18 (
19 "column_impl.tera",
20 include_str!("../templates/rust/column_impl.tera"),
21 ),
22 ("create.tera", include_str!("../templates/rust/create.tera")),
23 (
24 "create_many.tera",
25 include_str!("../templates/rust/create_many.tera"),
26 ),
27 (
28 "delegate.tera",
29 include_str!("../templates/rust/delegate.tera"),
30 ),
31 ("delete.tera", include_str!("../templates/rust/delete.tera")),
32 ("enum.tera", include_str!("../templates/rust/enum.tera")),
33 (
34 "find_many.tera",
35 include_str!("../templates/rust/find_many.tera"),
36 ),
37 (
38 "from_row_impl.tera",
39 include_str!("../templates/rust/from_row_impl.tera"),
40 ),
41 (
42 "model_file.tera",
43 include_str!("../templates/rust/model_file.tera"),
44 ),
45 (
46 "model_struct.tera",
47 include_str!("../templates/rust/model_struct.tera"),
48 ),
49 ("update.tera", include_str!("../templates/rust/update.tera")),
50 (
51 "composite_type.tera",
52 include_str!("../templates/rust/composite_type.tera"),
53 ),
54 ])
55 .expect("embedded Rust templates must parse");
56 tera
57});
58
59fn render(template: &str, ctx: &Context) -> String {
60 TEMPLATES
61 .render(template, ctx)
62 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
63}
64
/// Per-field data handed to the Tera templates.
///
/// Built both for scalar fields (of the generated model and of relation
/// targets) and for relation fields. For relation fields, `column_type` is
/// empty, `index` is 0, `is_optional` is always `true`, and the remaining
/// flags are `false`.
#[derive(Debug, Clone, Serialize)]
struct FieldContext {
    /// Logical field name converted to snake_case.
    name: String,
    /// Column name in the database.
    db_name: String,
    /// Rust type string produced by `field_to_rust_type`.
    rust_type: String,
    /// For scalar fields, the scalar's Rust type; for enum fields, the enum
    /// name; empty for every other field kind.
    column_type: String,
    /// Whether the field is declared as an array.
    is_array: bool,
    /// Position of the field among the model's scalar fields (0 for relation fields).
    index: usize,
    /// Whether the field is part of the model's primary key.
    is_pk: bool,
    /// True when the field is neither required nor an array.
    is_optional: bool,
    /// Mirrors the schema field's `is_updated_at` flag.
    is_updated_at: bool,
    /// True when the schema field has a `computed` definition.
    is_computed: bool,
}
89
/// One primary-key field, as exposed to the templates under `pk_fields_with_db`.
#[derive(Debug, Clone, Serialize)]
struct PkFieldContext {
    /// Logical field name converted to snake_case.
    name: String,
    /// Column name in the database.
    db_name: String,
}
99
/// A resolved relation field, as exposed to the templates under `relations`.
#[derive(Debug, Clone, Serialize)]
struct RelationContext {
    /// Relation field's logical name converted to snake_case.
    field_name: String,
    /// Logical name of the related (target) model.
    target_model: String,
    /// Database table name of the target model.
    target_table: String,
    /// Whether the relation field is declared as an array.
    is_array: bool,
    /// Logical names of the join fields on the model being generated.
    fields: Vec<String>,
    /// Logical names of the join fields on the target model.
    references: Vec<String>,
    /// Database column names for `fields`. Names that do not resolve are
    /// dropped, so this may be shorter than `fields`.
    fields_db: Vec<String>,
    /// Database column names for `references`. Names that do not resolve are
    /// dropped, so this may be shorter than `references`.
    references_db: Vec<String>,
    /// Scalar-field contexts of the target model.
    target_scalar_fields: Vec<FieldContext>,
}
112
113pub fn generate_model(model: &ModelIr, ir: &SchemaIr, is_async: bool) -> String {
118 let mut context = Context::new();
119
120 context.insert("model_name", &model.logical_name);
121 context.insert("table_name", &model.db_name);
122 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
123 context.insert("columns_name", &format!("{}Columns", model.logical_name));
124 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
125 context.insert("create_name", &format!("{}Create", model.logical_name));
126 context.insert(
127 "create_many_name",
128 &format!("{}CreateMany", model.logical_name),
129 );
130 context.insert("entry_name", &format!("{}CreateEntry", model.logical_name));
131 context.insert("update_name", &format!("{}Update", model.logical_name));
132 context.insert("delete_name", &format!("{}Delete", model.logical_name));
133
134 let pk_field_names = model.primary_key.fields();
135 context.insert("primary_key_fields", &pk_field_names);
136
137 let pk_fields_with_db: Vec<PkFieldContext> = pk_field_names
138 .iter()
139 .filter_map(|logical| {
140 model
141 .scalar_fields()
142 .find(|f| f.logical_name.as_str() == *logical)
143 .map(|f| PkFieldContext {
144 name: f.logical_name.to_snake_case(),
145 db_name: f.db_name.clone(),
146 })
147 })
148 .collect();
149 context.insert("pk_fields_with_db", &pk_fields_with_db);
150
151 let mut enum_imports = HashSet::new();
153 let mut composite_type_imports = HashSet::new();
154
155 let mut scalar_fields: Vec<FieldContext> = Vec::new();
156 let mut create_fields: Vec<FieldContext> = Vec::new();
157 let mut updated_at_fields: Vec<FieldContext> = Vec::new();
158
159 for (idx, field) in model.scalar_fields().enumerate() {
160 match &field.field_type {
162 ResolvedFieldType::Enum { enum_name } => {
163 if ir.enums.contains_key(enum_name) {
164 enum_imports.insert(enum_name.clone());
165 }
166 }
167 ResolvedFieldType::CompositeType { type_name } => {
168 if ir.composite_types.contains_key(type_name) {
169 composite_type_imports.insert(type_name.clone());
170 }
171 }
172 _ => {}
173 }
174
175 let column_type = match &field.field_type {
177 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
178 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
179 _ => String::new(),
180 };
181 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
182
183 let field_ctx = FieldContext {
184 name: field.logical_name.to_snake_case(),
185 db_name: field.db_name.clone(),
186 rust_type: field_to_rust_type(field),
187 column_type,
188 is_array: field.is_array,
189 index: idx,
190 is_pk,
191 is_optional: !field.is_required && !field.is_array,
192 is_updated_at: field.is_updated_at,
193 is_computed: field.computed.is_some(),
194 };
195
196 create_fields.push(field_ctx.clone());
197
198 if field.is_updated_at {
200 updated_at_fields.push(field_ctx.clone());
201 }
202
203 scalar_fields.push(field_ctx);
204 }
205
206 let mut relation_imports = HashSet::new();
208 for field in model.relation_fields() {
209 if let ResolvedFieldType::Relation(rel) = &field.field_type {
210 relation_imports.insert(rel.target_model.clone());
211 }
212 }
213
214 context.insert("has_enums", &!enum_imports.is_empty());
215 context.insert(
216 "enum_imports",
217 &enum_imports.into_iter().collect::<Vec<_>>(),
218 );
219 context.insert("has_relations", &!relation_imports.is_empty());
220 context.insert(
221 "relation_imports",
222 &relation_imports.into_iter().collect::<Vec<_>>(),
223 );
224 context.insert("has_composite_types", &!composite_type_imports.is_empty());
225 context.insert(
226 "composite_type_imports",
227 &composite_type_imports.into_iter().collect::<Vec<_>>(),
228 );
229
230 let relation_fields: Vec<FieldContext> = model
231 .relation_fields()
232 .map(|field| FieldContext {
233 name: field.logical_name.to_snake_case(),
234 db_name: field.db_name.clone(),
235 rust_type: field_to_rust_type(field),
236 column_type: String::new(),
237 is_array: field.is_array,
238 index: 0,
239 is_pk: false,
240 is_optional: true,
241 is_updated_at: false,
242 is_computed: false,
243 })
244 .collect();
245
246 let relations: Vec<RelationContext> = model
247 .relation_fields()
248 .filter_map(|field| {
249 if let ResolvedFieldType::Relation(rel) = &field.field_type {
250 if let Some(target_model) = ir.models.get(&rel.target_model) {
251 let target_pk_names = target_model.primary_key.fields();
252 let target_scalar_fields: Vec<FieldContext> = target_model
253 .scalar_fields()
254 .enumerate()
255 .map(|(idx, f)| {
256 let column_type = match &f.field_type {
257 ResolvedFieldType::Scalar(scalar) => scalar.rust_type().to_string(),
258 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
259 _ => String::new(),
260 };
261 let f_is_pk = target_pk_names.contains(&f.logical_name.as_str());
262 FieldContext {
263 name: f.logical_name.to_snake_case(),
264 db_name: f.db_name.clone(),
265 rust_type: field_to_rust_type(f),
266 column_type,
267 is_array: f.is_array,
268 index: idx,
269 is_pk: f_is_pk,
270 is_optional: !f.is_required && !f.is_array,
271 is_updated_at: f.is_updated_at,
272 is_computed: f.computed.is_some(),
273 }
274 })
275 .collect();
276
277 let (fields, references) = if rel.fields.is_empty() {
278 let inverse = target_model.relation_fields().find(|f| {
279 if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
280 inv_rel.target_model == model.logical_name
281 } else {
282 false
283 }
284 });
285
286 if let Some(inverse_field) = inverse {
287 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
288 {
289 (inv_rel.references.clone(), inv_rel.fields.clone())
291 } else {
292 (vec![], vec![])
293 }
294 } else {
295 (vec![], vec![])
296 }
297 } else {
298 (rel.fields.clone(), rel.references.clone())
299 };
300
301 let fields_db: Vec<String> = fields
302 .iter()
303 .filter_map(|logical_name| {
304 model
305 .fields
306 .iter()
307 .find(|f| &f.logical_name == logical_name)
308 .map(|f| f.db_name.clone())
309 })
310 .collect();
311
312 let references_db: Vec<String> = references
313 .iter()
314 .filter_map(|logical_name| {
315 target_model
316 .fields
317 .iter()
318 .find(|f| &f.logical_name == logical_name)
319 .map(|f| f.db_name.clone())
320 })
321 .collect();
322
323 Some(RelationContext {
324 field_name: field.logical_name.to_snake_case(),
325 target_model: rel.target_model.clone(),
326 target_table: target_model.db_name.clone(),
327 is_array: field.is_array,
328 fields,
329 references,
330 fields_db,
331 references_db,
332 target_scalar_fields,
333 })
334 } else {
335 None
336 }
337 } else {
338 None
339 }
340 })
341 .collect();
342
343 context.insert("scalar_fields", &scalar_fields);
344 context.insert("relation_fields", &relation_fields);
345 context.insert("relations", &relations);
346 context.insert("create_fields", &create_fields);
347 context.insert("updated_at_fields", &updated_at_fields);
348 context.insert("all_scalar_fields", &scalar_fields);
349 context.insert("is_async", &is_async);
350
351 render("model_file.tera", &context)
352}
353
354pub fn generate_all_models(ir: &SchemaIr, is_async: bool) -> HashMap<String, String> {
358 let mut generated = HashMap::new();
359
360 for (model_name, model_ir) in &ir.models {
361 let code = generate_model(model_ir, ir, is_async);
362 generated.insert(model_name.clone(), code);
363 }
364
365 generated
366}