1use heck::{ToPascalCase, ToSnakeCase};
4use nautilus_schema::ir::{CompositeTypeIr, EnumIr, ModelIr, ResolvedFieldType, SchemaIr};
5use serde::Serialize;
6use std::collections::{HashMap, HashSet};
7use tera::{Context, Tera};
8
9use crate::python::type_mapper::{
10 field_to_python_type, get_base_python_type, get_default_value, get_filter_operators_for_field,
11 is_auto_generated,
12};
13
14pub static PYTHON_TEMPLATES: std::sync::LazyLock<Tera> = std::sync::LazyLock::new(|| {
16 let mut tera = Tera::default();
17 tera.add_raw_templates(vec![
18 (
19 "composite_types.py.tera",
20 include_str!("../../templates/python/composite_types.py.tera"),
21 ),
22 (
23 "model_file.py.tera",
24 include_str!("../../templates/python/model_file.py.tera"),
25 ),
26 (
27 "input_types.py.tera",
28 include_str!("../../templates/python/input_types.py.tera"),
29 ),
30 (
31 "enums.py.tera",
32 include_str!("../../templates/python/enums.py.tera"),
33 ),
34 (
35 "client.py.tera",
36 include_str!("../../templates/python/client.py.tera"),
37 ),
38 (
39 "package_init.py.tera",
40 include_str!("../../templates/python/package_init.py.tera"),
41 ),
42 (
43 "models_init.py.tera",
44 include_str!("../../templates/python/models_init.py.tera"),
45 ),
46 (
47 "enums_init.py.tera",
48 include_str!("../../templates/python/enums_init.py.tera"),
49 ),
50 (
51 "errors_init.py.tera",
52 include_str!("../../templates/python/errors_init.py.tera"),
53 ),
54 (
55 "internal_init.py.tera",
56 include_str!("../../templates/python/internal_init.py.tera"),
57 ),
58 (
59 "transaction_init.py.tera",
60 include_str!("../../templates/python/transaction_init.py.tera"),
61 ),
62 ])
63 .expect("embedded Python templates must parse");
64 tera
65});
66
67fn render(template: &str, ctx: &Context) -> String {
68 PYTHON_TEMPLATES
69 .render(template, ctx)
70 .unwrap_or_else(|e| panic!("template rendering failed for '{}': {:?}", template, e))
71}
72
/// Template context for one field of a generated Python model class.
///
/// Used for both scalar fields and relation fields. For relation fields,
/// `base_type` is empty, `is_optional` is always `true`, and a default is
/// always present.
#[derive(Debug, Clone, Serialize)]
struct PythonFieldContext {
    /// Field name converted to snake_case.
    name: String,
    /// Field name exactly as declared in the schema.
    logical_name: String,
    /// Database column name of the field.
    db_name: String,
    /// Full Python annotation produced by `field_to_python_type`.
    python_type: String,
    /// Element type: the scalar's Python type, the enum name, or `"Any"`
    /// (empty for relation fields).
    base_type: String,
    /// `true` when the schema field is not required.
    is_optional: bool,
    /// `true` when the schema field is a list.
    is_array: bool,
    /// `true` when the field's type is a schema enum.
    is_enum: bool,
    /// Whether `default` holds a Python default expression.
    has_default: bool,
    /// Python default expression; empty when `has_default` is `false`.
    default: String,
    /// Position of the field within its scalar or relation field list.
    index: usize,
}
94
/// Template context describing how a relation field joins to its target model.
#[derive(Debug, Clone, Serialize)]
struct PythonRelationContext {
    /// Relation field name converted to snake_case.
    field_name: String,
    /// Logical name of the related model.
    target_model: String,
    /// Database table name of the related model.
    target_table: String,
    /// `true` for to-many relations.
    is_array: bool,
    /// Logical names of the key fields on this model. For a back-reference
    /// side, these are taken from the inverse relation's `references`.
    fields: Vec<String>,
    /// Logical names of the referenced fields on the target model.
    references: Vec<String>,
    /// Column names for `fields`; names without a matching field are skipped.
    fields_db: Vec<String>,
    /// Column names for `references`; names without a matching field are skipped.
    references_db: Vec<String>,
}
106
/// One filter operator offered for a field in the generated `WhereInput`.
#[derive(Debug, Clone, Serialize)]
struct FilterOperatorContext {
    /// Operator suffix as reported by `get_filter_operators_for_field`.
    suffix: String,
    /// Python type accepted by the operator.
    python_type: String,
}
112
/// A scalar field that can appear in the generated `WhereInput` filter type.
#[derive(Debug, Clone, Serialize)]
struct WhereInputFieldContext {
    /// Field name as declared in the schema (not snake_cased).
    name: String,
    /// Base Python type of the field, from `get_base_python_type`.
    python_type: String,
    /// Filter operators available for this field.
    operators: Vec<FilterOperatorContext>,
}
119
/// A field of the generated create-input type.
#[derive(Debug, Clone, Serialize)]
struct CreateInputFieldContext {
    /// Field name as declared in the schema (not snake_cased).
    name: String,
    /// Accepted Python type. Composite types are accepted as `dict`, and
    /// list fields are wrapped in `List[...]`.
    python_type: String,
    /// `true` when the field is required and has no default, is not an
    /// `updatedAt` field, and is not computed.
    is_required: bool,
}
126
/// A field of the generated update-input type.
///
/// Auto-generated primary-key fields are never given one of these.
#[derive(Debug, Clone, Serialize)]
struct UpdateInputFieldContext {
    /// Field name as declared in the schema (not snake_cased).
    name: String,
    /// Accepted Python type, computed the same way as for create inputs.
    python_type: String,
}
132
/// A scalar field offered in the generated order-by input.
#[derive(Debug, Clone, Serialize)]
struct OrderByFieldContext {
    /// Field name as declared in the schema (not snake_cased).
    name: String,
}
137
/// A relation that can be eagerly loaded through the generated include input.
#[derive(Debug, Clone, Serialize)]
struct IncludeFieldContext {
    /// Relation field name converted to snake_case.
    name: String,
    /// Logical name of the related model.
    target_model: String,
    /// Related model name converted to snake_case.
    target_snake: String,
    /// `true` for to-many relations.
    is_array: bool,
}
147
/// A field usable in aggregate queries.
///
/// Used both for the numeric-field list (Int, BigInt, Float, Decimal) and for
/// the orderable-field list (all scalars except Boolean, Json and Bytes).
#[derive(Debug, Clone, Serialize)]
struct AggregateFieldContext {
    /// Field name as declared in the schema (not snake_cased).
    name: String,
    /// Base Python type of the field, from `get_base_python_type`.
    python_type: String,
}
154
155pub fn generate_python_model(
161 model: &ModelIr,
162 ir: &SchemaIr,
163 is_async: bool,
164 recursive_type_depth: usize,
165) -> (String, String) {
166 let mut context = Context::new();
167
168 context.insert("model_name", &model.logical_name);
170 context.insert("snake_name", &model.logical_name.to_snake_case());
171 context.insert("table_name", &model.db_name);
172 context.insert("delegate_name", &format!("{}Delegate", model.logical_name));
173 context.insert("find_many_name", &format!("{}FindMany", model.logical_name));
174 context.insert("create_name", &format!("{}Create", model.logical_name));
175 context.insert(
176 "create_many_name",
177 &format!("{}CreateMany", model.logical_name),
178 );
179 context.insert("update_name", &format!("{}Update", model.logical_name));
180 context.insert("delete_name", &format!("{}Delete", model.logical_name));
181
182 let pk_field_names = model.primary_key.fields();
184 context.insert("primary_key_fields", &pk_field_names);
185
186 let mut enum_imports = HashSet::new();
190 let mut composite_type_imports = HashSet::new();
191 let mut has_datetime = false;
192 let mut has_uuid = false;
193 let mut has_decimal = false;
194 let mut has_dict = false;
195
196 let mut scalar_fields: Vec<PythonFieldContext> = Vec::new();
197 let mut create_fields: Vec<PythonFieldContext> = Vec::new();
198 let mut where_input_fields: Vec<WhereInputFieldContext> = Vec::new();
199 let mut create_input_fields: Vec<CreateInputFieldContext> = Vec::new();
200 let mut update_input_fields: Vec<UpdateInputFieldContext> = Vec::new();
201 let mut order_by_fields: Vec<OrderByFieldContext> = Vec::new();
202 let mut numeric_fields: Vec<AggregateFieldContext> = Vec::new();
203 let mut orderable_fields: Vec<AggregateFieldContext> = Vec::new();
204 let mut updated_at_field_names: Vec<String> = Vec::new();
205
206 for (idx, field) in model.scalar_fields().enumerate() {
207 use nautilus_schema::ir::ScalarType;
208
209 match &field.field_type {
211 ResolvedFieldType::Enum { enum_name } => {
212 if ir.enums.contains_key(enum_name) {
213 enum_imports.insert(enum_name.clone());
214 }
215 }
216 ResolvedFieldType::CompositeType { type_name } => {
217 if ir.composite_types.contains_key(type_name) {
218 composite_type_imports.insert(type_name.clone());
219 }
220 }
221 ResolvedFieldType::Scalar(scalar) => match scalar {
222 ScalarType::DateTime => has_datetime = true,
223 ScalarType::Uuid => has_uuid = true,
224 ScalarType::Decimal { .. } => has_decimal = true,
225 ScalarType::Json => has_dict = true,
226 _ => {}
227 },
228 _ => {}
229 }
230
231 let python_type = field_to_python_type(field, &ir.enums);
233 let base_type = match &field.field_type {
234 ResolvedFieldType::Scalar(s) => {
235 crate::python::type_mapper::scalar_to_python_type(s).to_string()
236 }
237 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
238 _ => "Any".to_string(),
239 };
240 let base_python_type = get_base_python_type(field, &ir.enums);
241 let is_enum = matches!(field.field_type, ResolvedFieldType::Enum { .. });
242 let auto_generated = is_auto_generated(field);
243
244 let mut default_val = get_default_value(field);
246 if let Some(ref def) = default_val {
247 if let ResolvedFieldType::Enum { enum_name } = &field.field_type {
248 if !def.contains('.') && !def.contains('(') && def != "None" {
249 default_val = Some(format!("{}.{}", enum_name, def));
250 }
251 }
252 }
253
254 let field_ctx = PythonFieldContext {
256 name: field.logical_name.to_snake_case(),
257 logical_name: field.logical_name.clone(),
258 db_name: field.db_name.clone(),
259 python_type: python_type.clone(),
260 base_type,
261 is_optional: !field.is_required,
262 is_array: field.is_array,
263 is_enum,
264 has_default: default_val.is_some(),
265 default: default_val.unwrap_or_default(),
266 index: idx,
267 };
268
269 create_fields.push(field_ctx.clone());
270
271 scalar_fields.push(field_ctx);
272
273 if !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
275 let operators = get_filter_operators_for_field(field, &ir.enums);
276 where_input_fields.push(WhereInputFieldContext {
277 name: field.logical_name.clone(),
278 python_type: base_python_type.clone(),
279 operators: operators
280 .into_iter()
281 .map(|op| FilterOperatorContext {
282 suffix: op.suffix,
283 python_type: op.type_name,
284 })
285 .collect(),
286 });
287 }
288
289 {
291 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
292 {
293 "dict".to_string()
294 } else {
295 base_python_type.clone()
296 };
297 let typed = if field.is_array {
298 format!("List[{}]", input_base)
299 } else {
300 input_base
301 };
302 create_input_fields.push(CreateInputFieldContext {
303 name: field.logical_name.clone(),
304 python_type: typed,
305 is_required: field.is_required
306 && field.default_value.is_none()
307 && !field.is_updated_at
308 && field.computed.is_none(),
309 });
310 }
311
312 let is_auto_pk = auto_generated && pk_field_names.contains(&field.logical_name.as_str());
314 if !is_auto_pk {
315 let input_base = if matches!(field.field_type, ResolvedFieldType::CompositeType { .. })
316 {
317 "dict".to_string()
318 } else {
319 base_python_type.clone()
320 };
321 let typed = if field.is_array {
322 format!("List[{}]", input_base)
323 } else {
324 input_base
325 };
326 update_input_fields.push(UpdateInputFieldContext {
327 name: field.logical_name.clone(),
328 python_type: typed,
329 });
330 }
331
332 order_by_fields.push(OrderByFieldContext {
334 name: field.logical_name.clone(),
335 });
336
337 let is_numeric = matches!(
339 &field.field_type,
340 ResolvedFieldType::Scalar(ScalarType::Int)
341 | ResolvedFieldType::Scalar(ScalarType::BigInt)
342 | ResolvedFieldType::Scalar(ScalarType::Float)
343 | ResolvedFieldType::Scalar(ScalarType::Decimal { .. })
344 );
345 if is_numeric {
346 numeric_fields.push(AggregateFieldContext {
347 name: field.logical_name.clone(),
348 python_type: base_python_type.clone(),
349 });
350 }
351
352 let is_non_orderable = matches!(
353 &field.field_type,
354 ResolvedFieldType::Scalar(ScalarType::Boolean)
355 | ResolvedFieldType::Scalar(ScalarType::Json)
356 | ResolvedFieldType::Scalar(ScalarType::Bytes)
357 );
358 if !is_non_orderable {
359 orderable_fields.push(AggregateFieldContext {
360 name: field.logical_name.clone(),
361 python_type: base_python_type,
362 });
363 }
364
365 if field.is_updated_at {
367 updated_at_field_names.push(field.logical_name.clone());
368 }
369 }
370
371 let mut relation_imports = HashSet::new();
373 for field in model.relation_fields() {
374 if let ResolvedFieldType::Relation(rel) = &field.field_type {
375 relation_imports.insert(rel.target_model.clone());
376 }
377 }
378
379 context.insert("has_datetime", &has_datetime);
380 context.insert("has_uuid", &has_uuid);
381 context.insert("has_decimal", &has_decimal);
382 context.insert("has_dict", &has_dict);
383 context.insert("has_enums", &!enum_imports.is_empty());
384 context.insert(
385 "enum_imports",
386 &enum_imports.into_iter().collect::<Vec<_>>(),
387 );
388 context.insert("has_composite_types", &!composite_type_imports.is_empty());
389 context.insert(
390 "composite_type_imports",
391 &composite_type_imports.into_iter().collect::<Vec<_>>(),
392 );
393 context.insert("has_relations", &!relation_imports.is_empty());
394 context.insert(
395 "relation_imports",
396 &relation_imports.into_iter().collect::<Vec<_>>(),
397 );
398
399 let relation_fields: Vec<PythonFieldContext> = model
401 .relation_fields()
402 .enumerate()
403 .map(|(idx, field)| {
404 let python_type = field_to_python_type(field, &ir.enums);
405 let default_val = if field.is_array {
406 "Field(default_factory=list)".to_string()
407 } else {
408 "None".to_string()
409 };
410
411 PythonFieldContext {
412 name: field.logical_name.to_snake_case(),
413 logical_name: field.logical_name.clone(),
414 db_name: field.db_name.clone(),
415 python_type: python_type.clone(),
416 base_type: String::new(),
417 is_optional: true,
418 is_array: field.is_array,
419 is_enum: false,
420 has_default: true,
421 default: default_val,
422 index: idx,
423 }
424 })
425 .collect();
426
427 let relations: Vec<PythonRelationContext> = model
429 .relation_fields()
430 .filter_map(|field| {
431 if let ResolvedFieldType::Relation(rel) = &field.field_type {
432 if let Some(target_model) = ir.models.get(&rel.target_model) {
433 let (fields, references) = if rel.fields.is_empty() {
434 let inverse = target_model.relation_fields().find(|f| {
436 if let ResolvedFieldType::Relation(inv_rel) = &f.field_type {
437 inv_rel.target_model == model.logical_name
438 } else {
439 false
440 }
441 });
442
443 if let Some(inverse_field) = inverse {
444 if let ResolvedFieldType::Relation(inv_rel) = &inverse_field.field_type
445 {
446 (inv_rel.references.clone(), inv_rel.fields.clone())
447 } else {
448 (vec![], vec![])
449 }
450 } else {
451 (vec![], vec![])
452 }
453 } else {
454 (rel.fields.clone(), rel.references.clone())
455 };
456
457 let fields_db: Vec<String> = fields
458 .iter()
459 .filter_map(|logical_name| {
460 model
461 .fields
462 .iter()
463 .find(|f| &f.logical_name == logical_name)
464 .map(|f| f.db_name.clone())
465 })
466 .collect();
467
468 let references_db: Vec<String> = references
469 .iter()
470 .filter_map(|logical_name| {
471 target_model
472 .fields
473 .iter()
474 .find(|f| &f.logical_name == logical_name)
475 .map(|f| f.db_name.clone())
476 })
477 .collect();
478
479 Some(PythonRelationContext {
480 field_name: field.logical_name.to_snake_case(),
481 target_model: rel.target_model.clone(),
482 target_table: target_model.db_name.clone(),
483 is_array: field.is_array,
484 fields,
485 references,
486 fields_db,
487 references_db,
488 })
489 } else {
490 None
491 }
492 } else {
493 None
494 }
495 })
496 .collect();
497
498 let include_fields: Vec<IncludeFieldContext> = model
500 .relation_fields()
501 .filter_map(|field| {
502 if let ResolvedFieldType::Relation(rel) = &field.field_type {
503 Some(IncludeFieldContext {
504 name: field.logical_name.to_snake_case(),
505 target_model: rel.target_model.clone(),
506 target_snake: rel.target_model.to_snake_case(),
507 is_array: field.is_array,
508 })
509 } else {
510 None
511 }
512 })
513 .collect();
514
515 let has_numeric_fields = !numeric_fields.is_empty();
516 let has_orderable_fields = !orderable_fields.is_empty();
517
518 let needs_typeddict = !where_input_fields.is_empty()
519 || !create_input_fields.is_empty()
520 || !update_input_fields.is_empty();
521
522 context.insert("needs_typeddict", &needs_typeddict);
523 context.insert("where_input_fields", &where_input_fields);
524 context.insert("create_input_fields", &create_input_fields);
525 context.insert("update_input_fields", &update_input_fields);
526 context.insert("updated_at_fields", &updated_at_field_names);
527 context.insert("order_by_fields", &order_by_fields);
528 context.insert("include_fields", &include_fields);
529 context.insert("has_includes", &!include_fields.is_empty());
530 context.insert("numeric_fields", &numeric_fields);
531 context.insert("orderable_fields", &orderable_fields);
532 context.insert("has_numeric_fields", &has_numeric_fields);
533 context.insert("has_orderable_fields", &has_orderable_fields);
534
535 context.insert("scalar_fields", &scalar_fields);
536 context.insert("relation_fields", &relation_fields);
537 context.insert("create_fields", &create_fields);
538 context.insert("relations", &relations);
539 context.insert("is_async", &is_async);
540 context.insert("recursive_type_depth", &recursive_type_depth);
541
542 let model_code = render("model_file.py.tera", &context);
544
545 (
546 format!("{}.py", model.logical_name.to_snake_case()),
547 model_code,
548 )
549}
550
551pub fn generate_all_python_models(
556 ir: &SchemaIr,
557 is_async: bool,
558 recursive_type_depth: usize,
559) -> Vec<(String, String)> {
560 ir.models
561 .values()
562 .map(|model| generate_python_model(model, ir, is_async, recursive_type_depth))
563 .collect()
564}
565
566pub fn generate_python_composite_types(
570 composite_types: &HashMap<String, CompositeTypeIr>,
571) -> Option<String> {
572 if composite_types.is_empty() {
573 return None;
574 }
575
576 #[derive(Serialize)]
577 struct CompositeFieldCtx {
578 name: String,
579 python_type: String,
580 }
581
582 #[derive(Serialize)]
583 struct CompositeTypeCtx {
584 name: String,
585 fields: Vec<CompositeFieldCtx>,
586 }
587
588 let mut type_list: Vec<CompositeTypeCtx> = composite_types
589 .values()
590 .map(|ct| {
591 let fields = ct
592 .fields
593 .iter()
594 .map(|f| {
595 let base = match &f.field_type {
596 ResolvedFieldType::Scalar(s) => {
597 crate::python::type_mapper::scalar_to_python_type(s).to_string()
598 }
599 ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
600 ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
601 ResolvedFieldType::Relation(_) => "Any".to_string(),
602 };
603 let python_type = if f.is_array {
604 format!("List[{}]", base)
605 } else if !f.is_required {
606 format!("Optional[{}]", base)
607 } else {
608 base
609 };
610 CompositeFieldCtx {
611 name: f.logical_name.to_snake_case(),
612 python_type,
613 }
614 })
615 .collect();
616 CompositeTypeCtx {
617 name: ct.logical_name.clone(),
618 fields,
619 }
620 })
621 .collect();
622 type_list.sort_by(|a, b| a.name.cmp(&b.name));
623
624 let mut context = Context::new();
625 context.insert("composite_types", &type_list);
626
627 Some(render("composite_types.py.tera", &context))
628}
629
630pub fn generate_python_enums(enums: &HashMap<String, EnumIr>) -> String {
632 let mut context = Context::new();
633
634 #[derive(Serialize)]
635 struct EnumContext {
636 name: String,
637 variants: Vec<String>,
638 }
639
640 let enum_contexts: Vec<EnumContext> = enums
641 .values()
642 .map(|e| EnumContext {
643 name: e.logical_name.clone(),
644 variants: e.variants.clone(),
645 })
646 .collect();
647
648 context.insert("enums", &enum_contexts);
649
650 render("enums.py.tera", &context)
651}
652
653pub fn generate_python_client(
658 models: &HashMap<String, ModelIr>,
659 schema_path: &str,
660 is_async: bool,
661) -> String {
662 let mut context = Context::new();
663
664 #[derive(Serialize)]
665 struct ModelContext {
666 snake_name: String,
667 delegate_name: String,
668 }
669
670 let mut model_contexts: Vec<ModelContext> = models
671 .values()
672 .map(|m| ModelContext {
673 snake_name: m.logical_name.to_snake_case(),
674 delegate_name: format!("{}Delegate", m.logical_name),
675 })
676 .collect();
677 model_contexts.sort_by(|a, b| a.snake_name.cmp(&b.snake_name));
678
679 context.insert("models", &model_contexts);
680 context.insert("schema_path", schema_path);
681 context.insert("is_async", &is_async);
682
683 render("client.py.tera", &context)
684}
685
686pub fn generate_package_init(has_enums: bool) -> String {
688 let mut context = Context::new();
689 context.insert("has_enums", &has_enums);
690
691 render("package_init.py.tera", &context)
692}
693
694pub fn generate_models_init(models: &[(String, String)]) -> String {
696 let mut context = Context::new();
697
698 let mut model_modules: Vec<String> = models
699 .iter()
700 .map(|(file_name, _)| file_name.trim_end_matches(".py").to_string())
701 .collect();
702 model_modules.sort();
703
704 let mut model_classes: Vec<String> = model_modules.iter().map(|m| m.to_pascal_case()).collect();
705 model_classes.sort();
706
707 context.insert("model_modules", &model_modules);
708 context.insert("model_classes", &model_classes);
709
710 render("models_init.py.tera", &context)
711}
712
713pub fn generate_enums_init(has_enums: bool) -> String {
715 let mut context = Context::new();
716 context.insert("has_enums", &has_enums);
717
718 render("enums_init.py.tera", &context)
719}
720
721pub fn generate_errors_init() -> &'static str {
725 include_str!("../../templates/python/errors_init.py.tera")
726}
727
728pub fn generate_internal_init() -> &'static str {
732 include_str!("../../templates/python/internal_init.py.tera")
733}
734
735pub fn generate_transaction_init() -> &'static str {
741 include_str!("../../templates/python/transaction_init.py.tera")
742}
743
744pub fn python_runtime_files() -> Vec<(&'static str, &'static str)> {
747 vec![
748 (
749 "_errors.py",
750 include_str!("../../templates/python/runtime/_errors.py"),
751 ),
752 (
753 "_protocol.py",
754 include_str!("../../templates/python/runtime/_protocol.py"),
755 ),
756 (
757 "_engine.py",
758 include_str!("../../templates/python/runtime/_engine.py"),
759 ),
760 (
761 "_client.py",
762 include_str!("../../templates/python/runtime/_client.py"),
763 ),
764 (
765 "_descriptors.py",
766 include_str!("../../templates/python/runtime/_descriptors.py"),
767 ),
768 (
769 "_transaction.py",
770 include_str!("../../templates/python/runtime/_transaction.py"),
771 ),
772 ]
773}