Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5/// Returns true if any variant of the enum has data fields.
6/// These enums cannot be represented as flat integer enums in bindings.
7pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8    enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11/// Returns true if any variant of the enum has a sanitized field.
12///
13/// A sanitized field means the extractor could not resolve the field's concrete type
14/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
15/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
16/// be generated, because the Python-dict → JSON → core deserialization path would not
17/// produce a valid value for the sanitized field. The forwarding trait impls
18/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
19/// the wrapper struct always delegates to the core type.
20fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
25///
26/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
27/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
28/// serializes it to JSON, and deserializes into the core Rust type via serde.
29///
30/// When any variant field is sanitized (its type could not be resolved — e.g. contains
31/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
32/// based `#[new]` constructor is omitted. The type is still useful as a return value
33/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
34/// and Deserialize are always generated regardless of sanitized fields, because the
35/// wrapper struct always delegates to the core type which implements those traits.
36pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37    let name = &enum_def.name;
38    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39    let has_sanitized = enum_has_sanitized_fields(enum_def);
40    let string_methods_content = crate::template_env::render(
41        "generators/enums/enum_string_methods.jinja",
42        minijinja::context! {
43            name => name,
44            value_expr => "&self.inner",
45        },
46    );
47
48    let mut variant_accessors = String::new();
49    write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51    let mut serde_tag_content = String::new();
52    if let Some(tag_field) = &enum_def.serde_tag {
53        write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54    }
55
56    crate::template_env::render(
57        "generators/enums/pyo3_data_enum.jinja",
58        minijinja::context! {
59            name => name,
60            core_path => core_path,
61            has_sanitized => has_sanitized,
62            string_methods_content => string_methods_content,
63            variant_accessors_content => variant_accessors,
64            serde_tag_content => serde_tag_content,
65        },
66    )
67}
68
69/// Generate an enum.
70pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71    // All enums are generated as unit-variant-only in the binding layer.
72    // Data variants are flattened to unit variants; the From/Into conversions
73    // handle the lossy mapping (discarding / providing defaults for field data).
74    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75    // Binding enums always derive Default, Serialize, and Deserialize.
76    // Default: enables using unwrap_or_default() in constructors.
77    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
78    derives.push("Default");
79    derives.push("serde::Serialize");
80    derives.push("serde::Deserialize");
81
82    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
83    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
84    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86    // Determine which variant carries #[default].
87    // Prefer the variant that has is_default=true in the source (mirrors the Rust core's
88    // #[default] attribute); fall back to the first variant when none is explicitly marked.
89    let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91    let variants: Vec<_> = enum_def
92        .variants
93        .iter()
94        .enumerate()
95        .map(|(idx, v)| {
96            minijinja::context! {
97                name => v.name.clone(),
98                idx => idx,
99                is_default => idx == default_idx,
100                has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101            }
102        })
103        .collect();
104
105    let string_methods = if is_pyo3 {
106        crate::template_env::render(
107            "generators/enums/enum_string_methods.jinja",
108            minijinja::context! {
109                name => enum_def.name,
110                value_expr => "self",
111            },
112        )
113    } else {
114        String::new()
115    };
116
117    crate::template_env::render(
118        "generators/enums/enum_definition.jinja",
119        minijinja::context! {
120            enum_name => enum_def.name,
121            derives => derives.join(", "),
122            serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
123            enum_attrs => cfg.enum_attrs.to_vec(),
124            variants => variants,
125            is_pyo3 => is_pyo3,
126            string_methods => string_methods,
127        },
128    )
129}
130
131/// Rust keywords that cannot be used as bare identifiers in function names.
132const RUST_KEYWORDS: &[&str] = &[
133    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
134    "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
135    "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
136    "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
137];
138
139/// Generate variant accessor properties for a data enum.
140/// For single-tuple variants with a Named inner type, returns the typed binding struct directly.
141/// For all other variants, returns the variant data as a Python dict, or None if not active.
142pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
143    use alef_core::ir::TypeRef;
144    use heck::ToSnakeCase;
145
146    for variant in &enum_def.variants {
147        let variant_name_lower = variant.name.to_snake_case();
148        let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
149            format!("r#{}", variant_name_lower)
150        } else {
151            variant_name_lower.clone()
152        };
153
154        if variant.fields.len() == 1 {
155            let field = &variant.fields[0];
156            let is_tuple_field = field
157                .name
158                .strip_prefix('_')
159                .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
160            if is_tuple_field {
161                if let TypeRef::Named(inner_type_name) = &field.ty {
162                    let variant_pascal = &variant.name;
163                    let clone_expr = if field.is_boxed {
164                        "(**data).clone().into()".to_string()
165                    } else {
166                        "data.clone().into()".to_string()
167                    };
168                    out.push('\n');
169                    out.push_str("    #[getter]\n");
170                    out.push_str(&crate::template_env::render(
171                        "generators/enums/getter_accessor.jinja",
172                        minijinja::context! {
173                            fn_name => &fn_name,
174                            inner_type_name => inner_type_name,
175                        },
176                    ));
177                    out.push_str("        match &self.inner {\n");
178                    out.push_str(&crate::template_env::render(
179                        "generators/enums/match_variant.jinja",
180                        minijinja::context! {
181                            core_path => &core_path,
182                            variant_pascal => variant_pascal,
183                            clone_expr => &clone_expr,
184                        },
185                    ));
186                    out.push_str("            _ => None,\n");
187                    out.push_str("        }\n");
188                    out.push_str("    }\n");
189                    continue;
190                }
191            }
192        }
193
194        out.push('\n');
195        out.push_str("    #[getter]\n");
196        out.push_str(&crate::template_env::render(
197            "generators/enums/py_dict_getter.jinja",
198            minijinja::context! {
199                fn_name => &fn_name,
200            },
201        ));
202        out.push_str("        let json = serde_json::to_value(&self.inner)\n");
203        out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
204        let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
205        out.push_str(&crate::template_env::render(
206            "generators/enums/tag_field_check.jinja",
207            minijinja::context! {
208                tag_field => tag_field,
209            },
210        ));
211        out.push_str("        let tag_value = json.get(tag_field)\n");
212        out.push_str("            .and_then(|v| v.as_str())\n");
213        out.push_str("            .unwrap_or(\"\");\n");
214        out.push_str(&crate::template_env::render(
215            "generators/enums/variant_tag_match.jinja",
216            minijinja::context! {
217                variant_name_lower => &variant_name_lower,
218            },
219        ));
220        out.push_str("            return Ok(None);\n");
221        out.push_str("        }\n");
222        out.push_str("        let json_str = json.to_string();\n");
223        out.push_str("        let json_mod = py.import(\"json\")?;\n");
224        out.push_str("        let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
225        out.push_str("        Ok(Some(py_dict.unbind()))\n");
226        out.push_str("    }\n");
227    }
228}
229
230pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
231    let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
232        format!("r#{tag_field}")
233    } else {
234        tag_field.to_string()
235    };
236    out.push('\n');
237    out.push_str("    #[getter]\n");
238    out.push_str(&crate::template_env::render(
239        "generators/enums/tag_getter_header.jinja",
240        minijinja::context! {
241            fn_name => &fn_name,
242        },
243    ));
244    out.push_str("        let json = serde_json::to_value(&self.inner)\n");
245    out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
246    out.push_str(&crate::template_env::render(
247        "generators/enums/json_get_field.jinja",
248        minijinja::context! {
249            tag_field => tag_field,
250        },
251    ));
252    out.push_str("            .and_then(|v| v.as_str())\n");
253    out.push_str("            .map(String::from)\n");
254    out.push_str(&crate::template_env::render(
255        "generators/enums/json_get_error.jinja",
256        minijinja::context! {
257            tag_field => tag_field,
258        },
259    ));
260    out.push_str("    }\n");
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Build a variant with the given fields and every other attribute at its neutral value.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            doc: String::new(),
            fields,
            is_default: false,
            is_tuple: false,
            serde_rename: None,
        }
    }

    /// Build a plain, required, unsanitized `String` field called `name`.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            doc: String::new(),
            optional: false,
            sanitized: false,
            is_boxed: false,
            default: None,
            typed_default: None,
            type_rust_path: None,
            cfg: None,
            newtype_wrapper: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
        }
    }

    /// Build a serde-enabled, untagged enum definition located at `crate::<name>`.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            rust_path: format!("crate::{name}"),
            name: name.into(),
            original_rust_path: String::new(),
            doc: String::new(),
            variants,
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_rename_all: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let def = make_enum(
            "StructureKind",
            vec![make_variant("Other", vec![string_field("value")])],
        );
        let generated = gen_pyo3_data_enum(&def, "core");

        let expected_fragments = [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ];
        for fragment in expected_fragments {
            assert!(generated.contains(fragment), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
        };
        let def = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);
        let generated = gen_enum(&def, &cfg);

        let expected_fragments = [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(self)",
        ];
        for fragment in expected_fragments {
            assert!(generated.contains(fragment), "{generated}");
        }
    }
}