Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5/// Returns true if any variant of the enum has data fields.
6/// These enums cannot be represented as flat integer enums in bindings.
7pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8    enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
12///
13/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
14/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
15/// serializes it to JSON, and deserializes into the core Rust type via serde.
16pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
17    let name = &enum_def.name;
18    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
19    let mut out = String::with_capacity(512);
20
21    writeln!(out, "#[derive(Clone)]").ok();
22    writeln!(out, "#[pyclass(frozen)]").ok();
23    writeln!(out, "pub struct {name} {{").ok();
24    writeln!(out, "    pub(crate) inner: {core_path},").ok();
25    writeln!(out, "}}").ok();
26    writeln!(out).ok();
27
28    writeln!(out, "#[pymethods]").ok();
29    writeln!(out, "impl {name} {{").ok();
30    writeln!(out, "    #[new]").ok();
31    writeln!(
32        out,
33        "    fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
34    )
35    .ok();
36    writeln!(out, "        let json_mod = py.import(\"json\")?;").ok();
37    writeln!(
38        out,
39        "        let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
40    )
41    .ok();
42    writeln!(out, "        let inner: {core_path} = serde_json::from_str(&json_str)").ok();
43    writeln!(
44        out,
45        "            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
46    )
47    .ok();
48    writeln!(out, "        Ok(Self {{ inner }})").ok();
49    writeln!(out, "    }}").ok();
50    writeln!(out, "}}").ok();
51    writeln!(out).ok();
52
53    // From binding → core
54    writeln!(out, "impl From<{name}> for {core_path} {{").ok();
55    writeln!(out, "    fn from(val: {name}) -> Self {{ val.inner }}").ok();
56    writeln!(out, "}}").ok();
57    writeln!(out).ok();
58
59    // From core → binding
60    writeln!(out, "impl From<{core_path}> for {name} {{").ok();
61    writeln!(out, "    fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
62    writeln!(out, "}}").ok();
63    writeln!(out).ok();
64
65    // Serialize: forward to inner so parent structs that derive serde::Serialize compile.
66    writeln!(out, "impl serde::Serialize for {name} {{").ok();
67    writeln!(
68        out,
69        "    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
70    )
71    .ok();
72    writeln!(out, "        self.inner.serialize(serializer)").ok();
73    writeln!(out, "    }}").ok();
74    writeln!(out, "}}").ok();
75    writeln!(out).ok();
76
77    // Default: forward to inner's Default so parent structs that derive Default compile.
78    writeln!(out, "impl Default for {name} {{").ok();
79    writeln!(
80        out,
81        "    fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
82    )
83    .ok();
84    writeln!(out, "}}").ok();
85    writeln!(out).ok();
86
87    // Deserialize: forward to inner so parent structs that derive serde::Deserialize compile.
88    writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
89    writeln!(
90        out,
91        "    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
92    )
93    .ok();
94    writeln!(out, "        let inner = {core_path}::deserialize(deserializer)?;").ok();
95    writeln!(out, "        Ok(Self {{ inner }})").ok();
96    writeln!(out, "    }}").ok();
97    writeln!(out, "}}").ok();
98
99    out
100}
101
/// Python keywords and builtins that cannot be used as variant identifiers in PyO3 enums.
/// When a variant name matches one of these, a `#[pyo3(name = "...")]` rename attribute
/// is emitted so the Rust identifier remains unchanged while Python sees a safe name.
///
/// Covers every hard keyword in Python 3's `keyword.kwlist`, including `async` and
/// `await` (reserved since Python 3.7), plus `type`, which is a builtin and, since
/// Python 3.12, a soft keyword. Entries are exact, case-sensitive spellings.
const PYTHON_KEYWORDS: &[&str] = &[
    // Keyword constants.
    "None", "True", "False",
    // Hard keywords (alphabetical).
    "and", "as", "assert", "async", "await", "break", "class", "continue", "def", "del",
    "elif", "else", "except", "finally", "for", "from", "global", "if", "import", "in",
    "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while",
    "with", "yield",
    // Builtin / soft keyword.
    "type",
];
110
111/// Generate an enum.
112pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
113    // All enums are generated as unit-variant-only in the binding layer.
114    // Data variants are flattened to unit variants; the From/Into conversions
115    // handle the lossy mapping (discarding / providing defaults for field data).
116    let mut out = String::with_capacity(512);
117    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
118    // Binding enums always derive Default, Serialize, and Deserialize.
119    // Default: enables using unwrap_or_default() in constructors.
120    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
121    derives.push("Default");
122    derives.push("serde::Serialize");
123    derives.push("serde::Deserialize");
124    if !derives.is_empty() {
125        writeln!(out, "#[derive({})]", derives.join(", ")).ok();
126    }
127    for attr in cfg.enum_attrs {
128        writeln!(out, "#[{attr}]").ok();
129    }
130    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
131    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
132    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
133    writeln!(out, "pub enum {} {{", enum_def.name).ok();
134    for (idx, variant) in enum_def.variants.iter().enumerate() {
135        if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
136            writeln!(out, "    #[pyo3(name = \"{}_\")]", variant.name).ok();
137        }
138        // Mark the first variant as #[default] so derive(Default) works
139        if idx == 0 {
140            writeln!(out, "    #[default]").ok();
141        }
142        writeln!(out, "    {} = {idx},", variant.name).ok();
143    }
144    writeln!(out, "}}").ok();
145
146    out
147}