1use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ir::{CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::python::type_mapper::{
10 field_to_python_type, get_base_python_type, get_default_value, get_filter_operators_for_field,
11 is_auto_generated,
12};
13
14pub static PYTHON_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
16 let mut tera = Tera::default();
17 tera.add_raw_templates(vec![
18 (
19 "composite_types.py.tera",
20 include_str!("../../templates/python/composite_types.py.tera"),
21 ),
22 (
23 "model_file.py.tera",
24 include_str!("../../templates/python/model_file.py.tera"),
25 ),
26 (
27 "input_types.py.tera",
28 include_str!("../../templates/python/input_types.py.tera"),
29 ),
30 (
31 "enums.py.tera",
32 include_str!("../../templates/python/enums.py.tera"),
33 ),
34 (
35 "client.py.tera",
36 include_str!("../../templates/python/client.py.tera"),
37 ),
38 (
39 "package_init.py.tera",
40 include_str!("../../templates/python/package_init.py.tera"),
41 ),
42 (
43 "models_init.py.tera",
44 include_str!("../../templates/python/models_init.py.tera"),
45 ),
46 (
47 "enums_init.py.tera",
48 include_str!("../../templates/python/enums_init.py.tera"),
49 ),
50 (
51 "errors_init.py.tera",
52 include_str!("../../templates/python/errors_init.py.tera"),
53 ),
54 (
55 "internal_init.py.tera",
56 include_str!("../../templates/python/internal_init.py.tera"),
57 ),
58 (
59 "transaction_init.py.tera",
60 include_str!("../../templates/python/transaction_init.py.tera"),
61 ),
62 ])
63 .expect("embedded Python templates must parse");
64 tera
65});
66
67fn render(template: &str, ctx: &Context) -> String {
68 PYTHON_TEMPLATES
69 .render(template, ctx)
70 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
71}
72
73#[derive(Debug, Clone, Serialize)]
81struct PythonFieldContext {
82 name: String,
83 logical_name: String,
84 db_name: String,
85 python_type: String,
86 model_python_type: String,
87 base_type: String,
88 is_optional: bool,
89 is_array: bool,
90 is_enum: bool,
91 has_default: bool,
92 default: String,
93 model_has_default: bool,
94 model_default: String,
95 is_pk: bool,
96 index: usize,
97}
98
99#[derive(Debug, Clone, Serialize)]
100struct PythonRelationContext {
101 field_name: String,
102 target_model: String,
103 target_table: String,
104 is_array: bool,
105 fields: Vec<String>,
106 references: Vec<String>,
107 fields_db: Vec<String>,
108 references_db: Vec<String>,
109}
110
111fn resolve_inverse_relation_fields(
112 source_model_name: &str,
113 relation_name: Option<&str>,
114 target_model: &ModelIr,
115) -> (Vec<String>, Vec<String>) {
116 let inverse = target_model.relation_fields().find(|field| {
117 if let ResolvedFieldType::Relation(inv_rel) = &field.field_type {
118 if inv_rel.target_model != source_model_name {
119 return false;
120 }
121
122 match (relation_name, inv_rel.name.as_deref()) {
123 (Some(expected), Some(actual)) => actual == expected,
124 (Some(_), None) => false,
125 (None, Some(_)) => false,
126 (None, None) => true,
127 }
128 } else {
129 false
130 }
131 });
132
133 if let Some(inverse_field) = inverse {
134 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type {
135 return (inv_rel.references.clone(), inv_rel.fields.clone());
136 }
137 }
138
139 (vec![], vec![])
140}
141
142#[derive(Debug, Clone, Serialize)]
143struct FilterOperatorContext {
144 suffix: String,
145 python_type: String,
146}
147
148#[derive(Debug, Clone, Serialize)]
149struct WhereInputFieldContext {
150 name: String,
151 python_type: String,
152 operators: Vec<FilterOperatorContext>,
153}
154
155#[derive(Debug, Clone, Serialize)]
156struct CreateInputFieldContext {
157 name: String,
158 python_type: String,
159 is_required: bool,
160}
161
162#[derive(Debug, Clone, Serialize)]
163struct UpdateInputFieldContext {
164 name: String,
165 python_type: String,
166}
167
168#[derive(Debug, Clone, Serialize)]
169struct OrderByFieldContext {
170 name: String,
171}
172
173#[derive(Debug, Clone, Serialize)]
174struct IncludeFieldContext {
175 name: String,
176 logical_name: String,
177 target_model: String,
178 target_snake: String,
180 is_array: bool,
182}
183
184#[derive(Debug, Clone, Serialize)]
186struct AggregateFieldContext {
187 name: String,
188 python_type: String,
189}
190
/// Wraps a Python type annotation in `Optional[...]` for use on output models.
///
/// Annotations that already start with `Optional[` are returned unchanged, so a type is
/// never wrapped twice.
fn optional_output_python_type(python_type: &str) -> String {
    const OPTIONAL_PREFIX: &str = "Optional[";
    if !python_type.starts_with(OPTIONAL_PREFIX) {
        return format!("{OPTIONAL_PREFIX}{python_type}]");
    }
    python_type.to_owned()
}
198
199pub fn generate_python_model(
205 model: &ModelIr,
206 ir: &SchemaIr,
207 is_async: bool,
208 recursive_type_depth: usize,
209) -> (String, String) {
210 let mut context = Context::new();
211
212 context.insert("model_name", &model.logical_name);
213 context.insert("snake_name", &model.logical_name.to_snake_case());
214 context.insert("table_name", &model.db_name);
215 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
216 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
217 context.insert("create_name", &format!("{}Create", model.logical_name));
218 context.insert(
219 "create_many_name",
220 &format!("{}CreateMany", model.logical_name),
221 );
222 context.insert("update_name", &format!("{}Update", model.logical_name));
223 context.insert("delete_name", &format!("{}Delete", model.logical_name));
224
225 let pk_field_names = model.primary_key.fields();
226 context.insert("primary_key_fields", &pk_field_names);
227
228 let mut enum_imports = HashSet::new();
229 let mut composite_type_imports = HashSet::new();
230 let mut has_datetime = false;
231 let mut has_uuid = false;
232 let mut has_decimal = false;
233 let mut has_dict = false;
234
235 let mut scalar_fields: Vec<PythonFieldContext> = Vec::new();
236 let mut create_fields: Vec<PythonFieldContext> = Vec::new();
237 let mut where_input_fields: Vec<WhereInputFieldContext> = Vec::new();
238 let mut create_input_fields: Vec<CreateInputFieldContext> = Vec::new();
239 let mut update_input_fields: Vec<UpdateInputFieldContext> = Vec::new();
240 let mut order_by_fields: Vec<OrderByFieldContext> = Vec::new();
241 let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
242 let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
243
244 for (idx, field) in model.scalar_fields().enumerate() {
245 use nautilus_schema::ir::ScalarType;
246
247 match &field.field_type {
248 ResolvedFieldType::Enum { enum_name } => {
249 if ir.enums.contains_key(enum_name) {
250 enum_imports.insert(enum_name.clone());
251 }
252 }
253 ResolvedFieldType::CompositeType { type_name } => {
254 if ir.composite_types.contains_key(type_name) {
255 composite_type_imports.insert(type_name.clone());
256 }
257 }
258 ResolvedFieldType::Scalar(scalar) => match scalar {
259 ScalarType::DateTime => has_datetime = true,
260 ScalarType::Uuid => has_uuid = true,
261 ScalarType::Decimal { .. } => has_decimal = true,
262 ScalarType::Json => has_dict = true,
263 _ => {}
264 },
265 _ => {}
266 }
267
268 let python_type = field_to_python_type(field, &ir.enums);
269 let base_type = match &field.field_type {
270 ResolvedFieldType::Scalar(s) => {
271 crate::python::type_mapper::scalar_to_python_type(s).to_string()
272 }
273 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
274 _ => "Any".to_string(),
275 };
276 let base_python_type = get_base_python_type(field, &ir.enums);
277 let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
278 let auto_generated = is_auto_generated(field);
279 let is_pk = pk_field_names.contains(&field.logical_name.as_str());
280
281 let mut default_val = get_default_value(field);
283 if let Some(ref def) = default_val {
284 if let ResolvedFieldType::Enum { enum_name } = &field.field_type {
285 if !def.contains('.') && !def.contains('(') && def != "None" {
286 default_val = Some(format!("{}.{}", enum_name, def));
287 }
288 }
289 }
290
291 let model_python_type = if is_pk {
292 python_type.clone()
293 } else {
294 optional_output_python_type(&python_type)
295 };
296 let model_has_default = if is_pk { default_val.is_some() } else { true };
297 let model_default = if is_pk {
298 default_val.clone().unwrap_or_default()
299 } else {
300 "None".to_string()
301 };
302
303 let field_ctx = PythonFieldContext {
304 name: field.logical_name.to_snake_case(),
305 logical_name: field.logical_name.clone(),
306 db_name: field.db_name.clone(),
307 python_type: python_type.clone(),
308 model_python_type,
309 base_type,
310 is_optional: !field.is_required,
311 is_array: field.is_array,
312 is_enum,
313 has_default: default_val.is_some(),
314 default: default_val.unwrap_or_default(),
315 model_has_default,
316 model_default,
317 is_pk,
318 index: idx,
319 };
320
321 create_fields.push(field_ctx.clone());
322
323 scalar_fields.push(field_ctx);
324
325 if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
326 let operators = get_filter_operators_for_field(field, &ir.enums);
327 where_input_fields.push(WhereInputFieldContext {
328 name: field.logical_name.clone(),
329 python_type: base_python_type.clone(),
330 operators: operators
331 .into_iter()
332 .map(|op| FilterOperatorContext {
333 suffix: op.suffix,
334 python_type: op.type_name,
335 })
336 .collect(),
337 });
338 }
339
340 {
341 let input_base = base_python_type.clone();
342 let typed = if field.is_array {
343 format!("List[{}]", input_base)
344 } else {
345 input_base
346 };
347 create_input_fields.push(CreateInputFieldContext {
348 name: field.logical_name.clone(),
349 python_type: typed,
350 is_required: field.is_required
351 && field.default_value.is_none()
352 && !field.is_updated_at
353 && field.computed.is_none(),
354 });
355 }
356
357 let is_auto_pk = auto_generated && pk_field_names.contains(&field.logical_name.as_str());
358 if !is_auto_pk {
359 let input_base = base_python_type.clone();
360 let typed = if field.is_array {
361 format!("List[{}]", input_base)
362 } else {
363 input_base
364 };
365 update_input_fields.push(UpdateInputFieldContext {
366 name: field.logical_name.clone(),
367 python_type: typed,
368 });
369 }
370
371 order_by_fields.push(OrderByFieldContext {
372 name: field.logical_name.clone(),
373 });
374
375 let is_numeric = matches!(
376 &field.field_type,
377 ResolvedFieldType::Scalar(ScalarType::Int)
378 | ResolvedFieldType::Scalar(ScalarType::BigInt)
379 | ResolvedFieldType::Scalar(ScalarType::Float)
380 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
381 );
382 if is_numeric {
383 numeric_fields.push(AggregateFieldContext {
384 name: field.logical_name.clone(),
385 python_type: base_python_type.clone(),
386 });
387 }
388
389 let is_non_orderable = matches!(
390 &field.field_type,
391 ResolvedFieldType::Scalar(ScalarType::Boolean)
392 | ResolvedFieldType::Scalar(ScalarType::Json)
393 | ResolvedFieldType::Scalar(ScalarType::Bytes)
394 );
395 if !is_non_orderable {
396 orderable_fields.push(AggregateFieldContext {
397 name: field.logical_name.clone(),
398 python_type: base_python_type,
399 });
400 }
401 }
402
403 let mut relation_imports = HashSet::new();
404 for field in model.relation_fields() {
405 if let ResolvedFieldType::Relation(rel) = &field.field_type {
406 relation_imports.insert(rel.target_model.clone());
407 }
408 }
409
410 context.insert("has_datetime", &has_datetime);
411 context.insert("has_uuid", &has_uuid);
412 context.insert("has_decimal", &has_decimal);
413 context.insert("has_dict", &has_dict);
414 context.insert("has_enums", &!enum_imports.is_empty());
415 context.insert(
416 "enum_imports",
417 &enum_imports.into_iter().collect::<Vec<_>>(),
418 );
419 context.insert("has_composite_types", &!composite_type_imports.is_empty());
420 context.insert(
421 "composite_type_imports",
422 &composite_type_imports.into_iter().collect::<Vec<_>>(),
423 );
424 context.insert("has_relations", &!relation_imports.is_empty());
425 context.insert(
426 "relation_imports",
427 &relation_imports.into_iter().collect::<Vec<_>>(),
428 );
429
430 let relation_fields: Vec<PythonFieldContext> = model
431 .relation_fields()
432 .enumerate()
433 .map(|(idx, field)| {
434 let python_type = field_to_python_type(field, &ir.enums);
435 let default_val = if field.is_array {
436 "Field(default_factory=list)".to_string()
437 } else {
438 "None".to_string()
439 };
440
441 PythonFieldContext {
442 name: field.logical_name.to_snake_case(),
443 logical_name: field.logical_name.clone(),
444 db_name: field.db_name.clone(),
445 python_type: python_type.clone(),
446 model_python_type: python_type.clone(),
447 base_type: String::new(),
448 is_optional: true,
449 is_array: field.is_array,
450 is_enum: false,
451 has_default: true,
452 default: default_val,
453 model_has_default: true,
454 model_default: "None".to_string(),
455 is_pk: false,
456 index: idx,
457 }
458 })
459 .collect();
460
461 let relations: Vec<PythonRelationContext> = model
462 .relation_fields()
463 .filter_map(|field| {
464 if let ResolvedFieldType::Relation(rel) = &field.field_type {
465 if let Some(target_model) = ir.models.get(&rel.target_model) {
466 let (fields, references) = if rel.fields.is_empty() {
467 resolve_inverse_relation_fields(
468 &model.logical_name,
469 rel.name.as_deref(),
470 target_model,
471 )
472 } else {
473 (rel.fields.clone(), rel.references.clone())
474 };
475
476 let fields_db: Vec<String> = fields
477 .iter()
478 .filter_map(|logical_name| {
479 model
480 .fields
481 .iter()
482 .find(|f| &f.logical_name == logical_name)
483 .map(|f| f.db_name.clone())
484 })
485 .collect();
486
487 let references_db: Vec<String> = references
488 .iter()
489 .filter_map(|logical_name| {
490 target_model
491 .fields
492 .iter()
493 .find(|f| &f.logical_name == logical_name)
494 .map(|f| f.db_name.clone())
495 })
496 .collect();
497
498 Some(PythonRelationContext {
499 field_name: field.logical_name.to_snake_case(),
500 target_model: rel.target_model.clone(),
501 target_table: target_model.db_name.clone(),
502 is_array: field.is_array,
503 fields,
504 references,
505 fields_db,
506 references_db,
507 })
508 } else {
509 None
510 }
511 } else {
512 None
513 }
514 })
515 .collect();
516
517 let include_fields: Vec<IncludeFieldContext> = model
518 .relation_fields()
519 .filter_map(|field| {
520 if let ResolvedFieldType::Relation(rel) = &field.field_type {
521 Some(IncludeFieldContext {
522 name: field.logical_name.to_snake_case(),
523 logical_name: field.logical_name.clone(),
524 target_model: rel.target_model.clone(),
525 target_snake: rel.target_model.to_snake_case(),
526 is_array: field.is_array,
527 })
528 } else {
529 None
530 }
531 })
532 .collect();
533
534 let has_numeric_fields = !numeric_fields.is_empty();
535 let has_orderable_fields = !orderable_fields.is_empty();
536
537 let needs_typeddict = !where_input_fields.is_empty()
538 || !create_input_fields.is_empty()
539 || !update_input_fields.is_empty();
540
541 context.insert("needs_typeddict", &needs_typeddict);
542 context.insert("where_input_fields", &where_input_fields);
543 context.insert("create_input_fields", &create_input_fields);
544 context.insert("update_input_fields", &update_input_fields);
545 context.insert("order_by_fields", &order_by_fields);
546 context.insert("include_fields", &include_fields);
547 context.insert("has_includes", &!include_fields.is_empty());
548 context.insert("numeric_fields", &numeric_fields);
549 context.insert("orderable_fields", &orderable_fields);
550 context.insert("has_numeric_fields", &has_numeric_fields);
551 context.insert("has_orderable_fields", &has_orderable_fields);
552
553 context.insert("scalar_fields", &scalar_fields);
554 context.insert("relation_fields", &relation_fields);
555 context.insert("create_fields", &create_fields);
556 context.insert("relations", &relations);
557 context.insert("is_async", &is_async);
558 context.insert("recursive_type_depth", &recursive_type_depth);
559
560 let model_code = render("model_file.py.tera", &context);
561
562 (
563 format!("{}.py", model.logical_name.to_snake_case()),
564 model_code,
565 )
566}
567
568pub fn generate_all_python_models(
573 ir: &SchemaIr,
574 is_async: bool,
575 recursive_type_depth: usize,
576) -> Vec<(String, String)> {
577 ir.models
578 .values()
579 .map(|model| generate_python_model(model, ir, is_async, recursive_type_depth))
580 .collect()
581}
582
583pub fn generate_python_composite_types(
587 composite_types: &HashMap<String, CompositeTypeIr>,
588) -> Option<String> {
589 if composite_types.is_empty() {
590 return None;
591 }
592
593 #[derive(Serialize)]
594 struct CompositeFieldCtx {
595 name: String,
596 python_type: String,
597 }
598
599 #[derive(Serialize)]
600 struct CompositeTypeCtx {
601 name: String,
602 fields: Vec<CompositeFieldCtx>,
603 }
604
605 let mut type_list: Vec<CompositeTypeCtx> = composite_types
606 .values()
607 .map(|ct| {
608 let fields = ct
609 .fields
610 .iter()
611 .map(|f| {
612 let base = match &f.field_type {
613 ResolvedFieldType::Scalar(s) => {
614 crate::python::type_mapper::scalar_to_python_type(s).to_string()
615 }
616 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
617 ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
618 ResolvedFieldType::Relation(_) => "Any".to_string(),
619 };
620 let python_type = if f.is_array {
621 format!("List[{}]", base)
622 } else if !f.is_required {
623 format!("Optional[{}]", base)
624 } else {
625 base
626 };
627 CompositeFieldCtx {
628 name: f.logical_name.to_snake_case(),
629 python_type,
630 }
631 })
632 .collect();
633 CompositeTypeCtx {
634 name: ct.logical_name.clone(),
635 fields,
636 }
637 })
638 .collect();
639 type_list.sort_by(|a, b| a.name.cmp(&b.name));
640
641 let mut context = Context::new();
642 context.insert("composite_types", &type_list);
643
644 Some(render("composite_types.py.tera", &context))
645}
646
647pub fn generate_python_enums(enums: &HashMap<String, EnumIr>) -> String {
649 let mut context = Context::new();
650
651 #[derive(Serialize)]
652 struct EnumContext {
653 name: String,
654 variants: Vec<String>,
655 }
656
657 let enum_contexts: Vec<EnumContext> = enums
658 .values()
659 .map(|e| EnumContext {
660 name: e.logical_name.clone(),
661 variants: e.variants.clone(),
662 })
663 .collect();
664
665 context.insert("enums", &enum_contexts);
666
667 render("enums.py.tera", &context)
668}
669
670pub fn generate_python_client(
675 models: &HashMap<String, ModelIr>,
676 schema_path: &str,
677 is_async: bool,
678) -> String {
679 let mut context = Context::new();
680
681 #[derive(Serialize)]
682 struct ModelContext {
683 snake_name: String,
684 delegate_name: String,
685 }
686
687 let mut model_contexts: Vec<ModelContext> = models
688 .values()
689 .map(|m| ModelContext {
690 snake_name: m.logical_name.to_snake_case(),
691 delegate_name: format!("{}Delegate", m.logical_name),
692 })
693 .collect();
694 model_contexts.sort_by(|a, b| a.snake_name.cmp(&b.snake_name));
695
696 context.insert("models", &model_contexts);
697 context.insert("schema_path", schema_path);
698 context.insert("is_async", &is_async);
699
700 render("client.py.tera", &context)
701}
702
703pub fn generate_package_init(has_enums: bool) -> String {
705 let mut context = Context::new();
706 context.insert("has_enums", &has_enums);
707
708 render("package_init.py.tera", &context)
709}
710
711pub fn generate_models_init(models: &[(String, String)]) -> String {
713 let mut context = Context::new();
714
715 let mut model_modules: Vec<String> = models
716 .iter()
717 .map(|(file_name, _)| file_name.trim_end_matches(".py").to_string())
718 .collect();
719 model_modules.sort();
720
721 let mut model_classes: Vec<String> = model_modules.iter().map(|m| m.to_pascal_case()).collect();
722 model_classes.sort();
723
724 context.insert("model_modules", &model_modules);
725 context.insert("model_classes", &model_classes);
726
727 render("models_init.py.tera", &context)
728}
729
730pub fn generate_enums_init(has_enums: bool) -> String {
732 let mut context = Context::new();
733 context.insert("has_enums", &has_enums);
734
735 render("enums_init.py.tera", &context)
736}
737
738pub fn generate_errors_init() -> &'static str {
742 include_str!("../../templates/python/errors_init.py.tera")
743}
744
745pub fn generate_internal_init() -> &'static str {
749 include_str!("../../templates/python/internal_init.py.tera")
750}
751
752pub fn generate_transaction_init() -> &'static str {
758 include_str!("../../templates/python/transaction_init.py.tera")
759}
760
761pub fn python_runtime_files() -> Vec<(&'static str, &'static str)> {
764 vec![
765 (
766 "_errors.py",
767 include_str!("../../templates/python/runtime/_errors.py"),
768 ),
769 (
770 "_protocol.py",
771 include_str!("../../templates/python/runtime/_protocol.py"),
772 ),
773 (
774 "_engine.py",
775 include_str!("../../templates/python/runtime/_engine.py"),
776 ),
777 (
778 "_client.py",
779 include_str!("../../templates/python/runtime/_client.py"),
780 ),
781 (
782 "_descriptors.py",
783 include_str!("../../templates/python/runtime/_descriptors.py"),
784 ),
785 (
786 "_transaction.py",
787 include_str!("../../templates/python/runtime/_transaction.py"),
788 ),
789 ]
790}