Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5/// Returns true if any variant of the enum has data fields.
6/// These enums cannot be represented as flat integer enums in bindings.
7pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8    enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11/// Returns true if any variant of the enum has a sanitized field.
12///
13/// A sanitized field means the extractor could not resolve the field's concrete type
14/// (e.g. a `dyn Stream` or another unsized / non-serde type). When this is true the
15/// core enum does NOT implement `Serialize`, `Deserialize`, or `Default`, so the
16/// binding must not try to forward those impls.
17fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
18    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
19}
20
21/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
22///
23/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
24/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
25/// serializes it to JSON, and deserializes into the core Rust type via serde.
26///
27/// When any variant field is sanitized (its type could not be resolved — e.g. contains
28/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
29/// based constructor and trait impls are omitted. The constructor body uses `todo!()`
30/// so the generated code still compiles while making it clear the conversion is not
31/// available at runtime.
32pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
33    let name = &enum_def.name;
34    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
35    let has_sanitized = enum_has_sanitized_fields(enum_def);
36    let mut out = String::with_capacity(512);
37
38    writeln!(out, "#[derive(Clone)]").ok();
39    writeln!(out, "#[pyclass(frozen)]").ok();
40    writeln!(out, "pub struct {name} {{").ok();
41    writeln!(out, "    pub(crate) inner: {core_path},").ok();
42    writeln!(out, "}}").ok();
43    writeln!(out).ok();
44
45    writeln!(out, "#[pymethods]").ok();
46    writeln!(out, "impl {name} {{").ok();
47    writeln!(out, "    #[new]").ok();
48    if has_sanitized {
49        // The core type cannot be serde round-tripped (contains non-Serialize variants).
50        // Emit a stub constructor that compiles but panics at runtime with a clear message.
51        writeln!(
52            out,
53            "    fn new(_py: Python<'_>, _value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
54        )
55        .ok();
56        writeln!(
57            out,
58            "        Err(pyo3::exceptions::PyNotImplementedError::new_err(\"{name} cannot be constructed from Python: the core type contains non-serializable variants\"))"
59        )
60        .ok();
61    } else {
62        writeln!(
63            out,
64            "    fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
65        )
66        .ok();
67        writeln!(out, "        let json_mod = py.import(\"json\")?;").ok();
68        writeln!(
69            out,
70            "        let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
71        )
72        .ok();
73        writeln!(out, "        let inner: {core_path} = serde_json::from_str(&json_str)").ok();
74        writeln!(
75            out,
76            "            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
77        )
78        .ok();
79        writeln!(out, "        Ok(Self {{ inner }})").ok();
80    }
81    writeln!(out, "    }}").ok();
82    writeln!(out, "}}").ok();
83    writeln!(out).ok();
84
85    // From binding → core
86    writeln!(out, "impl From<{name}> for {core_path} {{").ok();
87    writeln!(out, "    fn from(val: {name}) -> Self {{ val.inner }}").ok();
88    writeln!(out, "}}").ok();
89    writeln!(out).ok();
90
91    // From core → binding
92    writeln!(out, "impl From<{core_path}> for {name} {{").ok();
93    writeln!(out, "    fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
94    writeln!(out, "}}").ok();
95
96    if !has_sanitized {
97        writeln!(out).ok();
98
99        // Serialize: forward to inner so parent structs that derive serde::Serialize compile.
100        writeln!(out, "impl serde::Serialize for {name} {{").ok();
101        writeln!(
102            out,
103            "    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
104        )
105        .ok();
106        writeln!(out, "        self.inner.serialize(serializer)").ok();
107        writeln!(out, "    }}").ok();
108        writeln!(out, "}}").ok();
109        writeln!(out).ok();
110
111        // Default: forward to inner's Default so parent structs that derive Default compile.
112        writeln!(out, "impl Default for {name} {{").ok();
113        writeln!(
114            out,
115            "    fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
116        )
117        .ok();
118        writeln!(out, "}}").ok();
119        writeln!(out).ok();
120
121        // Deserialize: forward to inner so parent structs that derive serde::Deserialize compile.
122        writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
123        writeln!(
124            out,
125            "    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
126        )
127        .ok();
128        writeln!(out, "        let inner = {core_path}::deserialize(deserializer)?;").ok();
129        writeln!(out, "        Ok(Self {{ inner }})").ok();
130        writeln!(out, "    }}").ok();
131        writeln!(out, "}}").ok();
132    }
133
134    out
135}
136
/// Names that a PyO3 enum variant must not be exposed under on the Python side.
///
/// The list holds Python reserved words (including the `None` / `True` / `False`
/// constants) plus the builtin `type`. When a variant name matches an entry, the
/// generator emits a `#[pyo3(name = "...")]` rename attribute. The Rust identifier
/// stays unchanged while Python sees a safe name.
const PYTHON_KEYWORDS: &[&str] = &[
    // Constants
    "None", "True", "False",
    // Imports, definitions, and simple statements
    "from", "import", "class", "def", "return", "yield", "pass", "break", "continue",
    // Boolean / membership / identity operators
    "and", "or", "not", "is", "in",
    // Compound statements and exception handling
    "if", "else", "elif", "for", "while", "with", "as", "try", "except", "finally", "raise",
    // Remaining statements and expressions
    "del", "global", "nonlocal", "lambda", "assert",
    // Builtin
    "type",
];
145
146/// Generate an enum.
147pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
148    // All enums are generated as unit-variant-only in the binding layer.
149    // Data variants are flattened to unit variants; the From/Into conversions
150    // handle the lossy mapping (discarding / providing defaults for field data).
151    let mut out = String::with_capacity(512);
152    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
153    // Binding enums always derive Default, Serialize, and Deserialize.
154    // Default: enables using unwrap_or_default() in constructors.
155    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
156    derives.push("Default");
157    derives.push("serde::Serialize");
158    derives.push("serde::Deserialize");
159    if !derives.is_empty() {
160        writeln!(out, "#[derive({})]", derives.join(", ")).ok();
161    }
162    for attr in cfg.enum_attrs {
163        writeln!(out, "#[{attr}]").ok();
164    }
165    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
166    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
167    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
168    writeln!(out, "pub enum {} {{", enum_def.name).ok();
169    for (idx, variant) in enum_def.variants.iter().enumerate() {
170        if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
171            writeln!(out, "    #[pyo3(name = \"{}_\")]", variant.name).ok();
172        }
173        // Mark the first variant as #[default] so derive(Default) works
174        if idx == 0 {
175            writeln!(out, "    #[default]").ok();
176        }
177        writeln!(out, "    {} = {idx},", variant.name).ok();
178    }
179    writeln!(out, "}}").ok();
180
181    out
182}