1use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ir::{CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::python::type_mapper::{
10 field_to_python_type, get_base_python_type, get_default_value, get_filter_operators_for_field,
11 is_auto_generated,
12};
13
14pub static PYTHON_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
16 let mut tera = Tera::default();
17 tera.add_raw_templates(vec![
18 (
19 "composite_types.py.tera",
20 include_str!("../../templates/python/composite_types.py.tera"),
21 ),
22 (
23 "model_file.py.tera",
24 include_str!("../../templates/python/model_file.py.tera"),
25 ),
26 (
27 "input_types.py.tera",
28 include_str!("../../templates/python/input_types.py.tera"),
29 ),
30 (
31 "enums.py.tera",
32 include_str!("../../templates/python/enums.py.tera"),
33 ),
34 (
35 "client.py.tera",
36 include_str!("../../templates/python/client.py.tera"),
37 ),
38 (
39 "package_init.py.tera",
40 include_str!("../../templates/python/package_init.py.tera"),
41 ),
42 (
43 "models_init.py.tera",
44 include_str!("../../templates/python/models_init.py.tera"),
45 ),
46 (
47 "enums_init.py.tera",
48 include_str!("../../templates/python/enums_init.py.tera"),
49 ),
50 (
51 "errors_init.py.tera",
52 include_str!("../../templates/python/errors_init.py.tera"),
53 ),
54 (
55 "internal_init.py.tera",
56 include_str!("../../templates/python/internal_init.py.tera"),
57 ),
58 (
59 "transaction_init.py.tera",
60 include_str!("../../templates/python/transaction_init.py.tera"),
61 ),
62 ])
63 .expect("embedded Python templates must parse");
64 tera
65});
66
67fn render(template: &str, ctx: &Context) -> String {
68 PYTHON_TEMPLATES
69 .render(template, ctx)
70 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
71}
72
73#[derive(Debug, Clone, Serialize)]
81struct PythonFieldContext {
82 name: String,
83 logical_name: String,
84 db_name: String,
85 python_type: String,
86 base_type: String,
87 is_optional: bool,
88 is_array: bool,
89 is_enum: bool,
90 has_default: bool,
91 default: String,
92 index: usize,
93}
94
95#[derive(Debug, Clone, Serialize)]
96struct PythonRelationContext {
97 field_name: String,
98 target_model: String,
99 target_table: String,
100 is_array: bool,
101 fields: Vec<String>,
102 references: Vec<String>,
103 fields_db: Vec<String>,
104 references_db: Vec<String>,
105}
106
107#[derive(Debug, Clone, Serialize)]
108struct FilterOperatorContext {
109 suffix: String,
110 python_type: String,
111}
112
113#[derive(Debug, Clone, Serialize)]
114struct WhereInputFieldContext {
115 name: String,
116 python_type: String,
117 operators: Vec<FilterOperatorContext>,
118}
119
120#[derive(Debug, Clone, Serialize)]
121struct CreateInputFieldContext {
122 name: String,
123 python_type: String,
124 is_required: bool,
125}
126
127#[derive(Debug, Clone, Serialize)]
128struct UpdateInputFieldContext {
129 name: String,
130 python_type: String,
131}
132
133#[derive(Debug, Clone, Serialize)]
134struct OrderByFieldContext {
135 name: String,
136}
137
138#[derive(Debug, Clone, Serialize)]
139struct IncludeFieldContext {
140 name: String,
141 target_model: String,
142 target_snake: String,
144 is_array: bool,
146}
147
148#[derive(Debug, Clone, Serialize)]
150struct AggregateFieldContext {
151 name: String,
152 python_type: String,
153}
154
155pub fn generate_python_model(
161 model: &ModelIr,
162 ir: &SchemaIr,
163 is_async: bool,
164 recursive_type_depth: usize,
165) -> (String, String) {
166 let mut context = Context::new();
167
168 context.insert("model_name", &model.logical_name);
170 context.insert("snake_name", &model.logical_name.to_snake_case());
171 context.insert("table_name", &model.db_name);
172 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
173 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
174 context.insert("create_name", &format!("{}Create", model.logical_name));
175 context.insert(
176 "create_many_name",
177 &format!("{}CreateMany", model.logical_name),
178 );
179 context.insert("update_name", &format!("{}Update", model.logical_name));
180 context.insert("delete_name", &format!("{}Delete", model.logical_name));
181
182 let pk_field_names = model.primary_key.fields();
184 context.insert("primary_key_fields", &pk_field_names);
185
186 let mut enum_imports = HashSet::new();
190 let mut composite_type_imports = HashSet::new();
191 let mut has_datetime = false;
192 let mut has_uuid = false;
193 let mut has_decimal = false;
194 let mut has_dict = false;
195
196 let mut scalar_fields: Vec<PythonFieldContext> = Vec::new();
197 let mut create_fields: Vec<PythonFieldContext> = Vec::new();
198 let mut where_input_fields: Vec<WhereInputFieldContext> = Vec::new();
199 let mut create_input_fields: Vec<CreateInputFieldContext> = Vec::new();
200 let mut update_input_fields: Vec<UpdateInputFieldContext> = Vec::new();
201 let mut order_by_fields: Vec<OrderByFieldContext> = Vec::new();
202 let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
203 let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
204 let mut updated_at_field_names: Vec<String> = Vec::new();
205
206 for (idx, field) in model.scalar_fields().enumerate() {
207 use nautilus_schema::ir::ScalarType;
208
209 match &field.field_type {
211 ResolvedFieldType::Enum { enum_name } => {
212 if ir.enums.contains_key(enum_name) {
213 enum_imports.insert(enum_name.clone());
214 }
215 }
216 ResolvedFieldType::CompositeType { type_name } => {
217 if ir.composite_types.contains_key(type_name) {
218 composite_type_imports.insert(type_name.clone());
219 }
220 }
221 ResolvedFieldType::Scalar(scalar) => match scalar {
222 ScalarType::DateTime => has_datetime = true,
223 ScalarType::Uuid => has_uuid = true,
224 ScalarType::Decimal { .. } => has_decimal = true,
225 ScalarType::Json => has_dict = true,
226 _ => {}
227 },
228 _ => {}
229 }
230
231 let python_type = field_to_python_type(field, &ir.enums);
233 let base_type = match &field.field_type {
234 ResolvedFieldType::Scalar(s) => {
235 crate::python::type_mapper::scalar_to_python_type(s).to_string()
236 }
237 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
238 _ => "Any".to_string(),
239 };
240 let base_python_type = get_base_python_type(field, &ir.enums);
241 let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
242 let auto_generated = is_auto_generated(field);
243
244 let mut default_val = get_default_value(field);
246 if let Some(ref def) = default_val {
247 if let ResolvedFieldType::Enum { enum_name } = &field.field_type {
248 if !def.contains('.') && !def.contains('(') && def != "None" {
249 default_val = Some(format!("{}.{}", enum_name, def));
250 }
251 }
252 }
253
254 let field_ctx = PythonFieldContext {
256 name: field.logical_name.to_snake_case(),
257 logical_name: field.logical_name.clone(),
258 db_name: field.db_name.clone(),
259 python_type: python_type.clone(),
260 base_type,
261 is_optional: !field.is_required,
262 is_array: field.is_array,
263 is_enum,
264 has_default: default_val.is_some(),
265 default: default_val.unwrap_or_default(),
266 index: idx,
267 };
268
269 if !auto_generated {
271 create_fields.push(field_ctx.clone());
272 }
273
274 scalar_fields.push(field_ctx);
275
276 if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
278 let operators = get_filter_operators_for_field(field, &ir.enums);
279 where_input_fields.push(WhereInputFieldContext {
280 name: field.logical_name.clone(),
281 python_type: base_python_type.clone(),
282 operators: operators
283 .into_iter()
284 .map(|op| FilterOperatorContext {
285 suffix: op.suffix,
286 python_type: op.type_name,
287 })
288 .collect(),
289 });
290 }
291
292 if !auto_generated {
294 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
295 {
296 "dict".to_string()
297 } else {
298 base_python_type.clone()
299 };
300 let typed = if field.is_array {
301 format!("List[{}]", input_base)
302 } else {
303 input_base
304 };
305 create_input_fields.push(CreateInputFieldContext {
306 name: field.logical_name.clone(),
307 python_type: typed,
308 is_required: field.is_required
309 && field.default_value.is_none()
310 && !field.is_updated_at,
311 });
312 }
313
314 let is_auto_pk = auto_generated && pk_field_names.contains(&field.logical_name.as_str());
316 if !is_auto_pk {
317 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
318 {
319 "dict".to_string()
320 } else {
321 base_python_type.clone()
322 };
323 let typed = if field.is_array {
324 format!("List[{}]", input_base)
325 } else {
326 input_base
327 };
328 update_input_fields.push(UpdateInputFieldContext {
329 name: field.logical_name.clone(),
330 python_type: typed,
331 });
332 }
333
334 order_by_fields.push(OrderByFieldContext {
336 name: field.logical_name.clone(),
337 });
338
339 let is_numeric = matches!(
341 &field.field_type,
342 ResolvedFieldType::Scalar(ScalarType::Int)
343 | ResolvedFieldType::Scalar(ScalarType::BigInt)
344 | ResolvedFieldType::Scalar(ScalarType::Float)
345 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
346 );
347 if is_numeric {
348 numeric_fields.push(AggregateFieldContext {
349 name: field.logical_name.clone(),
350 python_type: base_python_type.clone(),
351 });
352 }
353
354 let is_non_orderable = matches!(
355 &field.field_type,
356 ResolvedFieldType::Scalar(ScalarType::Boolean)
357 | ResolvedFieldType::Scalar(ScalarType::Json)
358 | ResolvedFieldType::Scalar(ScalarType::Bytes)
359 );
360 if !is_non_orderable {
361 orderable_fields.push(AggregateFieldContext {
362 name: field.logical_name.clone(),
363 python_type: base_python_type,
364 });
365 }
366
367 if field.is_updated_at {
369 updated_at_field_names.push(field.logical_name.clone());
370 }
371 }
372
373 let mut relation_imports = HashSet::new();
375 for field in model.relation_fields() {
376 if let ResolvedFieldType::Relation(rel) = &field.field_type {
377 relation_imports.insert(rel.target_model.clone());
378 }
379 }
380
381 context.insert("has_datetime", &has_datetime);
382 context.insert("has_uuid", &has_uuid);
383 context.insert("has_decimal", &has_decimal);
384 context.insert("has_dict", &has_dict);
385 context.insert("has_enums", &!enum_imports.is_empty());
386 context.insert(
387 "enum_imports",
388 &enum_imports.into_iter().collect::<Vec<_>>(),
389 );
390 context.insert("has_composite_types", &!composite_type_imports.is_empty());
391 context.insert(
392 "composite_type_imports",
393 &composite_type_imports.into_iter().collect::<Vec<_>>(),
394 );
395 context.insert("has_relations", &!relation_imports.is_empty());
396 context.insert(
397 "relation_imports",
398 &relation_imports.into_iter().collect::<Vec<_>>(),
399 );
400
401 let relation_fields: Vec<PythonFieldContext> = model
403 .relation_fields()
404 .enumerate()
405 .map(|(idx, field)| {
406 let python_type = field_to_python_type(field, &ir.enums);
407 let default_val = if field.is_array {
408 "Field(default_factory=list)".to_string()
409 } else {
410 "None".to_string()
411 };
412
413 PythonFieldContext {
414 name: field.logical_name.to_snake_case(),
415 logical_name: field.logical_name.clone(),
416 db_name: field.db_name.clone(),
417 python_type: python_type.clone(),
418 base_type: String::new(),
419 is_optional: true,
420 is_array: field.is_array,
421 is_enum: false,
422 has_default: true,
423 default: default_val,
424 index: idx,
425 }
426 })
427 .collect();
428
429 let relations: Vec<PythonRelationContext> = model
431 .relation_fields()
432 .filter_map(|field| {
433 if let ResolvedFieldType::Relation(rel) = &field.field_type {
434 if let Some(target_model) = ir.models.get(&rel.target_model) {
435 let (fields, references) = if rel.fields.is_empty() {
436 let inverse = target_model.relation_fields().find(|f| {
438 if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
439 inv_rel.target_model == model.logical_name
440 } else {
441 false
442 }
443 });
444
445 if let Some(inverse_field) = inverse {
446 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
447 {
448 (inv_rel.references.clone(), inv_rel.fields.clone())
449 } else {
450 (vec![], vec![])
451 }
452 } else {
453 (vec![], vec![])
454 }
455 } else {
456 (rel.fields.clone(), rel.references.clone())
457 };
458
459 let fields_db: Vec<String> = fields
460 .iter()
461 .filter_map(|logical_name| {
462 model
463 .fields
464 .iter()
465 .find(|f| &f.logical_name == logical_name)
466 .map(|f| f.db_name.clone())
467 })
468 .collect();
469
470 let references_db: Vec<String> = references
471 .iter()
472 .filter_map(|logical_name| {
473 target_model
474 .fields
475 .iter()
476 .find(|f| &f.logical_name == logical_name)
477 .map(|f| f.db_name.clone())
478 })
479 .collect();
480
481 Some(PythonRelationContext {
482 field_name: field.logical_name.to_snake_case(),
483 target_model: rel.target_model.clone(),
484 target_table: target_model.db_name.clone(),
485 is_array: field.is_array,
486 fields,
487 references,
488 fields_db,
489 references_db,
490 })
491 } else {
492 None
493 }
494 } else {
495 None
496 }
497 })
498 .collect();
499
500 let include_fields: Vec<IncludeFieldContext> = model
502 .relation_fields()
503 .filter_map(|field| {
504 if let ResolvedFieldType::Relation(rel) = &field.field_type {
505 Some(IncludeFieldContext {
506 name: field.logical_name.to_snake_case(),
507 target_model: rel.target_model.clone(),
508 target_snake: rel.target_model.to_snake_case(),
509 is_array: field.is_array,
510 })
511 } else {
512 None
513 }
514 })
515 .collect();
516
517 let has_numeric_fields = !numeric_fields.is_empty();
518 let has_orderable_fields = !orderable_fields.is_empty();
519
520 let needs_typeddict = !where_input_fields.is_empty()
521 || !create_input_fields.is_empty()
522 || !update_input_fields.is_empty();
523
524 context.insert("needs_typeddict", &needs_typeddict);
525 context.insert("where_input_fields", &where_input_fields);
526 context.insert("create_input_fields", &create_input_fields);
527 context.insert("update_input_fields", &update_input_fields);
528 context.insert("updated_at_fields", &updated_at_field_names);
529 context.insert("order_by_fields", &order_by_fields);
530 context.insert("include_fields", &include_fields);
531 context.insert("has_includes", &!include_fields.is_empty());
532 context.insert("numeric_fields", &numeric_fields);
533 context.insert("orderable_fields", &orderable_fields);
534 context.insert("has_numeric_fields", &has_numeric_fields);
535 context.insert("has_orderable_fields", &has_orderable_fields);
536
537 context.insert("scalar_fields", &scalar_fields);
538 context.insert("relation_fields", &relation_fields);
539 context.insert("create_fields", &create_fields);
540 context.insert("relations", &relations);
541 context.insert("is_async", &is_async);
542 context.insert("recursive_type_depth", &recursive_type_depth);
543
544 let model_code = render("model_file.py.tera", &context);
546
547 (
548 format!("{}.py", model.logical_name.to_snake_case()),
549 model_code,
550 )
551}
552
553pub fn generate_all_python_models(
558 ir: &SchemaIr,
559 is_async: bool,
560 recursive_type_depth: usize,
561) -> Vec<(String, String)> {
562 ir.models
563 .values()
564 .map(|model| generate_python_model(model, ir, is_async, recursive_type_depth))
565 .collect()
566}
567
568pub fn generate_python_composite_types(
572 composite_types: &HashMap<String, CompositeTypeIr>,
573) -> Option<String> {
574 if composite_types.is_empty() {
575 return None;
576 }
577
578 #[derive(Serialize)]
579 struct CompositeFieldCtx {
580 name: String,
581 python_type: String,
582 }
583
584 #[derive(Serialize)]
585 struct CompositeTypeCtx {
586 name: String,
587 fields: Vec<CompositeFieldCtx>,
588 }
589
590 let mut type_list: Vec<CompositeTypeCtx> = composite_types
591 .values()
592 .map(|ct| {
593 let fields = ct
594 .fields
595 .iter()
596 .map(|f| {
597 let base = match &f.field_type {
598 ResolvedFieldType::Scalar(s) => {
599 crate::python::type_mapper::scalar_to_python_type(s).to_string()
600 }
601 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
602 ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
603 ResolvedFieldType::Relation(_) => "Any".to_string(),
604 };
605 let python_type = if f.is_array {
606 format!("List[{}]", base)
607 } else if !f.is_required {
608 format!("Optional[{}]", base)
609 } else {
610 base
611 };
612 CompositeFieldCtx {
613 name: f.logical_name.to_snake_case(),
614 python_type,
615 }
616 })
617 .collect();
618 CompositeTypeCtx {
619 name: ct.logical_name.clone(),
620 fields,
621 }
622 })
623 .collect();
624 type_list.sort_by(|a, b| a.name.cmp(&b.name));
625
626 let mut context = Context::new();
627 context.insert("composite_types", &type_list);
628
629 Some(render("composite_types.py.tera", &context))
630}
631
632pub fn generate_python_enums(enums: &HashMap<String, EnumIr>) -> String {
634 let mut context = Context::new();
635
636 #[derive(Serialize)]
637 struct EnumContext {
638 name: String,
639 variants: Vec<String>,
640 }
641
642 let enum_contexts: Vec<EnumContext> = enums
643 .values()
644 .map(|e| EnumContext {
645 name: e.logical_name.clone(),
646 variants: e.variants.clone(),
647 })
648 .collect();
649
650 context.insert("enums", &enum_contexts);
651
652 render("enums.py.tera", &context)
653}
654
655pub fn generate_python_client(
660 models: &HashMap<String, ModelIr>,
661 schema_path: &str,
662 is_async: bool,
663) -> String {
664 let mut context = Context::new();
665
666 #[derive(Serialize)]
667 struct ModelContext {
668 snake_name: String,
669 delegate_name: String,
670 }
671
672 let mut model_contexts: Vec<ModelContext> = models
673 .values()
674 .map(|m| ModelContext {
675 snake_name: m.logical_name.to_snake_case(),
676 delegate_name: format!("{}Delegate", m.logical_name),
677 })
678 .collect();
679 model_contexts.sort_by(|a, b| a.snake_name.cmp(&b.snake_name));
680
681 context.insert("models", &model_contexts);
682 context.insert("schema_path", schema_path);
683 context.insert("is_async", &is_async);
684
685 render("client.py.tera", &context)
686}
687
688pub fn generate_package_init(has_enums: bool) -> String {
690 let mut context = Context::new();
691 context.insert("has_enums", &has_enums);
692
693 render("package_init.py.tera", &context)
694}
695
696pub fn generate_models_init(models: &[(String, String)]) -> String {
698 let mut context = Context::new();
699
700 let mut model_modules: Vec<String> = models
701 .iter()
702 .map(|(file_name, _)| file_name.trim_end_matches(".py").to_string())
703 .collect();
704 model_modules.sort();
705
706 let mut model_classes: Vec<String> = model_modules.iter().map(|m| m.to_pascal_case()).collect();
707 model_classes.sort();
708
709 context.insert("model_modules", &model_modules);
710 context.insert("model_classes", &model_classes);
711
712 render("models_init.py.tera", &context)
713}
714
715pub fn generate_enums_init(has_enums: bool) -> String {
717 let mut context = Context::new();
718 context.insert("has_enums", &has_enums);
719
720 render("enums_init.py.tera", &context)
721}
722
723pub fn generate_errors_init() -> &'static str {
727 include_str!("../../templates/python/errors_init.py.tera")
728}
729
730pub fn generate_internal_init() -> &'static str {
734 include_str!("../../templates/python/internal_init.py.tera")
735}
736
737pub fn generate_transaction_init() -> &'static str {
743 include_str!("../../templates/python/transaction_init.py.tera")
744}
745
746pub fn python_runtime_files() -> Vec<(&'static str, &'static str)> {
749 vec![
750 (
751 "_errors.py",
752 include_str!("../../templates/python/runtime/_errors.py"),
753 ),
754 (
755 "_protocol.py",
756 include_str!("../../templates/python/runtime/_protocol.py"),
757 ),
758 (
759 "_engine.py",
760 include_str!("../../templates/python/runtime/_engine.py"),
761 ),
762 (
763 "_client.py",
764 include_str!("../../templates/python/runtime/_client.py"),
765 ),
766 (
767 "_descriptors.py",
768 include_str!("../../templates/python/runtime/_descriptors.py"),
769 ),
770 (
771 "_transaction.py",
772 include_str!("../../templates/python/runtime/_transaction.py"),
773 ),
774 ]
775}