Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4use std::fmt::Write;
5
6/// Returns true if any variant of the enum has data fields.
7/// These enums cannot be represented as flat integer enums in bindings.
8pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
9    enum_def.variants.iter().any(|v| !v.fields.is_empty())
10}
11
12/// Returns true if any variant of the enum has a sanitized field.
13///
14/// A sanitized field means the extractor could not resolve the field's concrete type
15/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
16/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
17/// be generated, because the Python-dict → JSON → core deserialization path would not
18/// produce a valid value for the sanitized field. The forwarding trait impls
19/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
20/// the wrapper struct always delegates to the core type.
21fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
22    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
23}
24
25/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
26///
27/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
28/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
29/// serializes it to JSON, and deserializes into the core Rust type via serde.
30///
31/// When any variant field is sanitized (its type could not be resolved — e.g. contains
32/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
33/// based `#[new]` constructor is omitted. The type is still useful as a return value
34/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
35/// and Deserialize are always generated regardless of sanitized fields, because the
36/// wrapper struct always delegates to the core type which implements those traits.
37pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
38    let name = &enum_def.name;
39    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
40    let has_sanitized = enum_has_sanitized_fields(enum_def);
41    let mut out = String::with_capacity(512);
42
43    writeln!(out, "#[derive(Clone)]").ok();
44    writeln!(out, "#[pyclass(frozen)]").ok();
45    writeln!(out, "pub struct {name} {{").ok();
46    writeln!(out, "    pub(crate) inner: {core_path},").ok();
47    writeln!(out, "}}").ok();
48    writeln!(out).ok();
49
50    writeln!(out, "#[pymethods]").ok();
51    writeln!(out, "impl {name} {{").ok();
52    if has_sanitized {
53        // The core type cannot be serde round-tripped from a Python dict (contains
54        // non-representable variant fields). Omit the #[new] constructor — the type
55        // is still useful as a return value from Rust passed back via From impls.
56        write_pyo3_enum_string_methods(&mut out, name, "&self.inner");
57        write_pyo3_variant_accessors(&mut out, enum_def, &core_path);
58        if let Some(tag_field) = &enum_def.serde_tag {
59            write_pyo3_serde_tag_getter(&mut out, tag_field);
60        }
61        writeln!(out, "}}").ok();
62    } else {
63        writeln!(out, "    #[new]").ok();
64        writeln!(
65            out,
66            "    fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyAny>) -> PyResult<Self> {{"
67        )
68        .ok();
69        writeln!(
70            out,
71            "        // Accept either a Python dict (full tagged-union shape) or a string"
72        )
73        .ok();
74        writeln!(
75            out,
76            "        // (the unit variant name). Strings are wrapped in `\"...\"` so serde_json"
77        )
78        .ok();
79        writeln!(
80            out,
81            "        // can deserialize into a unit-variant of the tagged enum."
82        )
83        .ok();
84        writeln!(
85            out,
86            "        let json_str: String = if let Ok(s) = value.extract::<String>() {{"
87        )
88        .ok();
89        writeln!(
90            out,
91            "            serde_json::to_string(&s).map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?"
92        )
93        .ok();
94        writeln!(out, "        }} else {{").ok();
95        writeln!(out, "            let json_mod = py.import(\"json\")?;").ok();
96        writeln!(
97            out,
98            "            json_mod.call_method1(\"dumps\", (value,))?.extract()?"
99        )
100        .ok();
101        writeln!(out, "        }};").ok();
102        writeln!(out, "        let inner: {core_path} = serde_json::from_str(&json_str)").ok();
103        writeln!(
104            out,
105            "            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
106        )
107        .ok();
108        writeln!(out, "        Ok(Self {{ inner }})").ok();
109        writeln!(out, "    }}").ok();
110        write_pyo3_enum_string_methods(&mut out, name, "&self.inner");
111        write_pyo3_variant_accessors(&mut out, enum_def, &core_path);
112        if let Some(tag_field) = &enum_def.serde_tag {
113            write_pyo3_serde_tag_getter(&mut out, tag_field);
114        }
115        writeln!(out, "}}").ok();
116    }
117    writeln!(out).ok();
118
119    // From binding → core
120    writeln!(out, "impl From<{name}> for {core_path} {{").ok();
121    writeln!(out, "    fn from(val: {name}) -> Self {{ val.inner }}").ok();
122    writeln!(out, "}}").ok();
123    writeln!(out).ok();
124
125    // From core → binding
126    writeln!(out, "impl From<{core_path}> for {name} {{").ok();
127    writeln!(out, "    fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
128    writeln!(out, "}}").ok();
129    writeln!(out).ok();
130
131    // Serialize: forward to inner so parent structs that derive serde::Serialize compile.
132    // Always generated — the wrapper delegates to the core type which always implements Serialize.
133    writeln!(out, "impl serde::Serialize for {name} {{").ok();
134    writeln!(
135        out,
136        "    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
137    )
138    .ok();
139    writeln!(out, "        self.inner.serialize(serializer)").ok();
140    writeln!(out, "    }}").ok();
141    writeln!(out, "}}").ok();
142    writeln!(out).ok();
143
144    // Default: forward to inner's Default so parent structs that derive Default compile.
145    // Always generated — the wrapper delegates to the core type which always implements Default.
146    writeln!(out, "impl Default for {name} {{").ok();
147    writeln!(
148        out,
149        "    fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
150    )
151    .ok();
152    writeln!(out, "}}").ok();
153    writeln!(out).ok();
154
155    // Deserialize: forward to inner so parent structs that derive serde::Deserialize compile.
156    // Always generated — the wrapper delegates to the core type which always implements Deserialize.
157    writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
158    writeln!(
159        out,
160        "    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
161    )
162    .ok();
163    writeln!(out, "        let inner = {core_path}::deserialize(deserializer)?;").ok();
164    writeln!(out, "        Ok(Self {{ inner }})").ok();
165    writeln!(out, "    }}").ok();
166    writeln!(out, "}}").ok();
167
168    out
169}
170
171/// Generate an enum.
172pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
173    // All enums are generated as unit-variant-only in the binding layer.
174    // Data variants are flattened to unit variants; the From/Into conversions
175    // handle the lossy mapping (discarding / providing defaults for field data).
176    let mut out = String::with_capacity(512);
177    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
178    // Binding enums always derive Default, Serialize, and Deserialize.
179    // Default: enables using unwrap_or_default() in constructors.
180    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
181    derives.push("Default");
182    derives.push("serde::Serialize");
183    derives.push("serde::Deserialize");
184    if !derives.is_empty() {
185        writeln!(out, "#[derive({})]", derives.join(", ")).ok();
186    }
187    if let Some(rename_all) = &enum_def.serde_rename_all {
188        writeln!(out, "#[serde(rename_all = \"{rename_all}\")]").ok();
189    }
190    for attr in cfg.enum_attrs {
191        writeln!(out, "#[{attr}]").ok();
192    }
193    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
194    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
195    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
196    writeln!(out, "pub enum {} {{", enum_def.name).ok();
197    // Determine which variant carries #[default].
198    // Prefer the variant that has is_default=true in the source (mirrors the Rust core's
199    // #[default] attribute); fall back to the first variant when none is explicitly marked.
200    let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
201    for (idx, variant) in enum_def.variants.iter().enumerate() {
202        if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
203            writeln!(out, "    #[pyo3(name = \"{}_\")]", variant.name).ok();
204        }
205        // Mark the default variant as #[default] so derive(Default) works
206        if idx == default_idx {
207            writeln!(out, "    #[default]").ok();
208        }
209        writeln!(out, "    {} = {idx},", variant.name).ok();
210    }
211    writeln!(out, "}}").ok();
212    if is_pyo3 {
213        writeln!(out).ok();
214        writeln!(out, "#[pymethods]").ok();
215        writeln!(out, "impl {} {{", enum_def.name).ok();
216        write_pyo3_enum_string_methods(&mut out, &enum_def.name, "self");
217        writeln!(out, "}}").ok();
218    }
219
220    out
221}
222
/// Rust keywords that cannot be used as bare identifiers in function names.
///
/// Covers strict and reserved keywords across editions, including `gen`, which is
/// reserved starting with the 2024 edition. Names found in this list are emitted as
/// raw identifiers (`r#name`) by the generators in this file.
///
/// NOTE(review): `crate`, `self`, `Self` and `super` are not accepted as raw
/// identifiers by the language, so a generated name equal to one of them would need
/// a different spelling than `r#...` — confirm whether such names can reach here.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
230
231/// Generate variant accessor properties for a data enum.
232/// For single-tuple variants with a Named inner type, returns the typed binding struct directly.
233/// For all other variants, returns the variant data as a Python dict, or None if not active.
234fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
235    use alef_core::ir::TypeRef;
236    use heck::ToSnakeCase;
237
238    for variant in &enum_def.variants {
239        let variant_name_lower = variant.name.to_snake_case();
240        // Use raw identifier syntax if variant name is a Rust keyword
241        let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
242            format!("r#{}", variant_name_lower)
243        } else {
244            variant_name_lower.clone()
245        };
246
247        // For single-tuple variants with a Named inner type, return the typed binding struct.
248        if variant.fields.len() == 1 {
249            let field = &variant.fields[0];
250            let is_tuple_field = field
251                .name
252                .strip_prefix('_')
253                .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
254            if is_tuple_field {
255                if let TypeRef::Named(inner_type_name) = &field.ty {
256                    let variant_pascal = &variant.name;
257                    writeln!(out).ok();
258                    writeln!(out, "    #[getter]").ok();
259                    writeln!(out, "    fn {fn_name}(&self) -> Option<{inner_type_name}> {{").ok();
260                    writeln!(out, "        match &self.inner {{").ok();
261                    // Use `.into()` to avoid ambiguity when the binding type has an inherent
262                    // `from()` method that would shadow the `From` trait impl.
263                    // For boxed variants, double-deref: &Box<T> → Box<T> → T, then clone.
264                    let clone_expr = if field.is_boxed {
265                        "(**data).clone().into()".to_string()
266                    } else {
267                        "data.clone().into()".to_string()
268                    };
269                    writeln!(
270                        out,
271                        "            {core_path}::{variant_pascal}(data) => Some({clone_expr}),"
272                    )
273                    .ok();
274                    writeln!(out, "            _ => None,").ok();
275                    writeln!(out, "        }}").ok();
276                    writeln!(out, "    }}").ok();
277                    continue;
278                }
279            }
280        }
281
282        // Fall back to dict-based approach for multi-field or non-Named variants.
283        writeln!(out).ok();
284        writeln!(out, "    #[getter]").ok();
285        writeln!(
286            out,
287            "    fn {fn_name}(&self, py: Python<'_>) -> PyResult<Option<pyo3::Py<pyo3::types::PyDict>>> {{"
288        )
289        .ok();
290        writeln!(out, "        // Serialize to JSON first").ok();
291        writeln!(out, "        let json = serde_json::to_value(&self.inner)").ok();
292        writeln!(
293            out,
294            "            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;"
295        )
296        .ok();
297        writeln!(out, "        // Check the tag field to see if this variant is active").ok();
298        writeln!(
299            out,
300            "        let tag_field = \"{}\";",
301            enum_def.serde_tag.as_ref().unwrap_or(&"tag".to_string())
302        )
303        .ok();
304        writeln!(out, "        let tag_value = json.get(tag_field)").ok();
305        writeln!(out, "            .and_then(|v| v.as_str())").ok();
306        writeln!(out, "            .unwrap_or(\"\");").ok();
307        writeln!(out, "        if tag_value != \"{}\" {{", variant_name_lower).ok();
308        writeln!(out, "            return Ok(None);").ok();
309        writeln!(out, "        }}").ok();
310        writeln!(out, "        // Create a Python dict from the JSON").ok();
311        writeln!(out, "        let json_str = json.to_string();").ok();
312        writeln!(out, "        let json_mod = py.import(\"json\")?;").ok();
313        writeln!(
314            out,
315            "        let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;"
316        )
317        .ok();
318        writeln!(out, "        Ok(Some(py_dict.unbind()))").ok();
319        writeln!(out, "    }}").ok();
320    }
321}
322
323fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
324    // Use raw identifier syntax if tag_field is a Rust keyword (e.g. "type" → r#type).
325    // pyo3 exposes the getter without the r# prefix, so the Python attribute name stays correct.
326    let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
327        format!("r#{tag_field}")
328    } else {
329        tag_field.to_string()
330    };
331    writeln!(out).ok();
332    writeln!(out, "    #[getter]").ok();
333    writeln!(out, "    fn {fn_name}(&self) -> pyo3::PyResult<String> {{").ok();
334    writeln!(out, "        let json = serde_json::to_value(&self.inner)").ok();
335    writeln!(
336        out,
337        "            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;"
338    )
339    .ok();
340    writeln!(out, "        json.get(\"{tag_field}\")").ok();
341    writeln!(out, "            .and_then(|v| v.as_str())").ok();
342    writeln!(out, "            .map(String::from)").ok();
343    writeln!(
344        out,
345        "            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err(\"{tag_field} not found in serialized enum\"))"
346    )
347    .ok();
348    writeln!(out, "    }}").ok();
349}
350
/// Appends `__str__` and `__repr__` method definitions to `out`.
///
/// The generated `__str__` serializes `value_expr` with `serde_json::to_value`: a
/// JSON string is returned as-is, any other JSON value is rendered with
/// `to_string()`, and a serialization failure becomes a `PyRuntimeError` whose
/// message mentions `name`. The generated `__repr__` delegates to `__str__`.
fn write_pyo3_enum_string_methods(out: &mut String, name: &str, value_expr: &str) {
    let methods = [
        String::new(),
        "    fn __str__(&self) -> PyResult<String> {".to_string(),
        format!("        serde_json::to_value({value_expr})"),
        "            .map(|value| match value {".to_string(),
        "                serde_json::Value::String(value) => value,".to_string(),
        "                other => other.to_string(),".to_string(),
        "            })".to_string(),
        format!(
            "            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(\"Failed to serialize {name}: {{e}}\")))"
        ),
        "    }".to_string(),
        String::new(),
        "    fn __repr__(&self) -> PyResult<String> {".to_string(),
        "        self.__str__()".to_string(),
        "    }".to_string(),
    ];
    for line in &methods {
        writeln!(out, "{line}").ok();
    }
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a non-default, non-tuple variant without docs or a serde rename.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.to_string(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a required, unsanitized, unboxed `String` field named `name`.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
        }
    }

    /// Builds a serde-enabled enum at `crate::<name>` with no tag and no rename rule.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.to_string(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_rename_all: None,
        }
    }

    /// Binding configuration whose enum attributes mark the output as a PyO3 class.
    fn pyo3_unit_enum_cfg() -> RustBindingConfig<'static> {
        RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
        }
    }

    /// Fails with the full generated source when `needle` is missing from it.
    #[track_caller]
    fn assert_generated_contains(generated: &str, needle: &str) {
        assert!(generated.contains(needle), "{generated}");
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let data_enum = make_enum(
            "StructureKind",
            vec![make_variant("Other", vec![string_field("value")])],
        );

        let generated = gen_pyo3_data_enum(&data_enum, "core");

        assert_generated_contains(&generated, "fn __str__(&self) -> PyResult<String>");
        assert_generated_contains(&generated, "serde_json::to_value(&self.inner)");
        assert_generated_contains(&generated, "fn __repr__(&self) -> PyResult<String>");
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let unit_enum = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);

        let generated = gen_enum(&unit_enum, &pyo3_unit_enum_cfg());

        assert_generated_contains(&generated, "fn __str__(&self) -> PyResult<String>");
        assert_generated_contains(&generated, "serde_json::to_value(self)");
    }
}