Skip to main content

nautilus_codegen/
type_helpers.rs

//! Type-mapping helpers for code generation: schema IR field types to Rust type strings.

3use nautilus_schema::ir::DefaultValue;
4use nautilus_schema::ir::{FieldIr, ResolvedFieldType, ScalarType};
5
6/// Convert a `ScalarType` to its Rust type string.
7pub(crate) fn scalar_to_rust_type(scalar: &ScalarType) -> String {
8    scalar.rust_type().to_string()
9}
10
11/// Get the base Rust type for a field without optional wrappers.
12pub(crate) fn field_to_rust_base_type(field: &FieldIr) -> String {
13    let base_type = match &field.field_type {
14        ResolvedFieldType::Scalar(scalar) => scalar_to_rust_type(scalar),
15        ResolvedFieldType::Enum { enum_name } => enum_name.clone(),
16        ResolvedFieldType::CompositeType { type_name } => type_name.clone(),
17        ResolvedFieldType::Relation(rel) => rel.target_model.clone(),
18    };
19
20    if field.is_array && !matches!(field.field_type, ResolvedFieldType::Relation(_)) {
21        format!("Vec<{}>", base_type)
22    } else {
23        base_type
24    }
25}
26
27/// Get the Rust type for a field, including Option wrapper if nullable.
28pub fn field_to_rust_type(field: &FieldIr) -> String {
29    let base_type = field_to_rust_base_type(field);
30
31    if matches!(field.field_type, ResolvedFieldType::Relation(_)) {
32        return if field.is_array {
33            format!("Vec<{}>", base_type)
34        } else {
35            format!("Option<Box<{}>>", base_type)
36        };
37    }
38
39    if !field.is_required && !field.is_array {
40        format!("Option<{}>", base_type)
41    } else {
42        base_type
43    }
44}
45
46/// Get the Rust type used by `SUM()` outputs for a numeric field.
47pub(crate) fn field_to_rust_sum_type(field: &FieldIr) -> String {
48    match &field.field_type {
49        ResolvedFieldType::Scalar(ScalarType::Int | ScalarType::BigInt) => "i64".to_string(),
50        ResolvedFieldType::Scalar(ScalarType::Float) => "f64".to_string(),
51        ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
52            "rust_decimal::Decimal".to_string()
53        }
54        _ => field_to_rust_base_type(field),
55    }
56}
57
58/// Get the Rust type used by `AVG()` outputs for a numeric field.
59pub(crate) fn field_to_rust_avg_type(field: &FieldIr) -> String {
60    match &field.field_type {
61        ResolvedFieldType::Scalar(ScalarType::Decimal { .. }) => {
62            "rust_decimal::Decimal".to_string()
63        }
64        _ => "f64".to_string(),
65    }
66}
67
68/// Check if a field should be auto-generated (excluded from create builders).
69pub fn is_auto_generated(field: &FieldIr) -> bool {
70    if field.computed.is_some() {
71        return true;
72    }
73    if let Some(default) = &field.default_value {
74        matches!(
75            default,
76            DefaultValue::Function(func) if func.name == "autoincrement" || func.name == "uuid"
77        )
78    } else {
79        false
80    }
81}