Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5/// Returns true if any variant of the enum has data fields.
6/// These enums cannot be represented as flat integer enums in bindings.
7pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8    enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11/// Returns true if any variant of the enum has a sanitized field.
12///
13/// A sanitized field means the extractor could not resolve the field's concrete type
14/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
15/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
16/// be generated, because the Python-dict → JSON → core deserialization path would not
17/// produce a valid value for the sanitized field. The forwarding trait impls
18/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
19/// the wrapper struct always delegates to the core type.
20fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
25///
26/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
27/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
28/// serializes it to JSON, and deserializes into the core Rust type via serde.
29///
30/// When any variant field is sanitized (its type could not be resolved — e.g. contains
31/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
32/// based `#[new]` constructor is omitted. The type is still useful as a return value
33/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
34/// and Deserialize are always generated regardless of sanitized fields, because the
35/// wrapper struct always delegates to the core type which implements those traits.
36pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37    let name = &enum_def.name;
38    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39    let has_sanitized = enum_has_sanitized_fields(enum_def);
40    let mut out = String::with_capacity(512);
41
42    writeln!(out, "#[derive(Clone)]").ok();
43    writeln!(out, "#[pyclass(frozen)]").ok();
44    writeln!(out, "pub struct {name} {{").ok();
45    writeln!(out, "    pub(crate) inner: {core_path},").ok();
46    writeln!(out, "}}").ok();
47    writeln!(out).ok();
48
49    writeln!(out, "#[pymethods]").ok();
50    writeln!(out, "impl {name} {{").ok();
51    if has_sanitized {
52        // The core type cannot be serde round-tripped from a Python dict (contains
53        // non-representable variant fields). Omit the #[new] constructor — the type
54        // is still useful as a return value from Rust passed back via From impls.
55        writeln!(out, "}}").ok();
56    } else {
57        writeln!(out, "    #[new]").ok();
58        writeln!(
59            out,
60            "    fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
61        )
62        .ok();
63        writeln!(out, "        let json_mod = py.import(\"json\")?;").ok();
64        writeln!(
65            out,
66            "        let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
67        )
68        .ok();
69        writeln!(out, "        let inner: {core_path} = serde_json::from_str(&json_str)").ok();
70        writeln!(
71            out,
72            "            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
73        )
74        .ok();
75        writeln!(out, "        Ok(Self {{ inner }})").ok();
76        writeln!(out, "    }}").ok();
77        writeln!(out, "}}").ok();
78    }
79    writeln!(out).ok();
80
81    // From binding → core
82    writeln!(out, "impl From<{name}> for {core_path} {{").ok();
83    writeln!(out, "    fn from(val: {name}) -> Self {{ val.inner }}").ok();
84    writeln!(out, "}}").ok();
85    writeln!(out).ok();
86
87    // From core → binding
88    writeln!(out, "impl From<{core_path}> for {name} {{").ok();
89    writeln!(out, "    fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
90    writeln!(out, "}}").ok();
91    writeln!(out).ok();
92
93    // Serialize: forward to inner so parent structs that derive serde::Serialize compile.
94    // Always generated — the wrapper delegates to the core type which always implements Serialize.
95    writeln!(out, "impl serde::Serialize for {name} {{").ok();
96    writeln!(
97        out,
98        "    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
99    )
100    .ok();
101    writeln!(out, "        self.inner.serialize(serializer)").ok();
102    writeln!(out, "    }}").ok();
103    writeln!(out, "}}").ok();
104    writeln!(out).ok();
105
106    // Default: forward to inner's Default so parent structs that derive Default compile.
107    // Always generated — the wrapper delegates to the core type which always implements Default.
108    writeln!(out, "impl Default for {name} {{").ok();
109    writeln!(
110        out,
111        "    fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
112    )
113    .ok();
114    writeln!(out, "}}").ok();
115    writeln!(out).ok();
116
117    // Deserialize: forward to inner so parent structs that derive serde::Deserialize compile.
118    // Always generated — the wrapper delegates to the core type which always implements Deserialize.
119    writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
120    writeln!(
121        out,
122        "    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
123    )
124    .ok();
125    writeln!(out, "        let inner = {core_path}::deserialize(deserializer)?;").ok();
126    writeln!(out, "        Ok(Self {{ inner }})").ok();
127    writeln!(out, "    }}").ok();
128    writeln!(out, "}}").ok();
129
130    out
131}
132
/// Names that are unsafe to expose verbatim as PyO3 enum variant identifiers because
/// they collide with Python keywords or builtins.
///
/// A variant whose name appears here gets a `#[pyo3(name = "...")]` rename attribute:
/// the Rust identifier is left untouched and only the Python-facing name changes.
const PYTHON_KEYWORDS: &[&str] = &[
    // Constants
    "None",
    "True",
    "False",
    // Imports
    "from",
    "import",
    // Definitions and simple statements
    "class",
    "def",
    "return",
    "yield",
    "pass",
    "break",
    "continue",
    // Operators
    "and",
    "or",
    "not",
    "is",
    "in",
    // Compound statements
    "if",
    "else",
    "elif",
    "for",
    "while",
    "with",
    "as",
    // Exception handling
    "try",
    "except",
    "finally",
    "raise",
    // Miscellaneous
    "del",
    "global",
    "nonlocal",
    "lambda",
    "assert",
    "type",
];
141
142/// Generate an enum.
143pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
144    // All enums are generated as unit-variant-only in the binding layer.
145    // Data variants are flattened to unit variants; the From/Into conversions
146    // handle the lossy mapping (discarding / providing defaults for field data).
147    let mut out = String::with_capacity(512);
148    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
149    // Binding enums always derive Default, Serialize, and Deserialize.
150    // Default: enables using unwrap_or_default() in constructors.
151    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
152    derives.push("Default");
153    derives.push("serde::Serialize");
154    derives.push("serde::Deserialize");
155    if !derives.is_empty() {
156        writeln!(out, "#[derive({})]", derives.join(", ")).ok();
157    }
158    for attr in cfg.enum_attrs {
159        writeln!(out, "#[{attr}]").ok();
160    }
161    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
162    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
163    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
164    writeln!(out, "pub enum {} {{", enum_def.name).ok();
165    for (idx, variant) in enum_def.variants.iter().enumerate() {
166        if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
167            writeln!(out, "    #[pyo3(name = \"{}_\")]", variant.name).ok();
168        }
169        // Mark the first variant as #[default] so derive(Default) works
170        if idx == 0 {
171            writeln!(out, "    #[default]").ok();
172        }
173        writeln!(out, "    {} = {idx},", variant.name).ok();
174    }
175    writeln!(out, "}}").ok();
176
177    out
178}