Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5/// Returns true if any variant of the enum has data fields.
6/// These enums cannot be represented as flat integer enums in bindings.
7pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8    enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11/// Returns true if any variant of the enum has a sanitized field.
12///
13/// A sanitized field means the extractor could not resolve the field's concrete type
14/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
15/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
16/// be generated, because the Python-dict → JSON → core deserialization path would not
17/// produce a valid value for the sanitized field. The forwarding trait impls
18/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
19/// the wrapper struct always delegates to the core type.
20fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
25///
26/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
27/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
28/// serializes it to JSON, and deserializes into the core Rust type via serde.
29///
30/// When any variant field is sanitized (its type could not be resolved — e.g. contains
31/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
32/// based `#[new]` constructor is omitted. The type is still useful as a return value
33/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
34/// and Deserialize are always generated regardless of sanitized fields, because the
35/// wrapper struct always delegates to the core type which implements those traits.
36pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37    let name = &enum_def.name;
38    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39    let has_sanitized = enum_has_sanitized_fields(enum_def);
40    let string_methods_content = crate::template_env::render(
41        "generators/enums/enum_string_methods.jinja",
42        minijinja::context! {
43            name => name,
44            value_expr => "&self.inner",
45        },
46    );
47
48    let mut variant_accessors = String::new();
49    write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51    let mut serde_tag_content = String::new();
52    if let Some(tag_field) = &enum_def.serde_tag {
53        write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54    }
55
56    crate::template_env::render(
57        "generators/enums/pyo3_data_enum.jinja",
58        minijinja::context! {
59            name => name,
60            core_path => core_path,
61            has_sanitized => has_sanitized,
62            string_methods_content => string_methods_content,
63            variant_accessors_content => variant_accessors,
64            serde_tag_content => serde_tag_content,
65        },
66    )
67}
68
69/// Generate an enum.
70pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71    // All enums are generated as unit-variant-only in the binding layer.
72    // Data variants are flattened to unit variants; the From/Into conversions
73    // handle the lossy mapping (discarding / providing defaults for field data).
74    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75    // Binding enums always derive Default, Serialize, and Deserialize.
76    // Default: enables using unwrap_or_default() in constructors.
77    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
78    derives.push("Default");
79    derives.push("serde::Serialize");
80    derives.push("serde::Deserialize");
81
82    // Detect PyO3 context so we can rename Python keyword variants via #[pyo3(name = "...")].
83    // The Rust identifier stays unchanged; only the Python-exposed attribute name gets the suffix.
84    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86    // Determine which variant carries #[default].
87    // Prefer the variant that has is_default=true in the source (mirrors the Rust core's
88    // #[default] attribute); fall back to the first variant when none is explicitly marked.
89    let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91    let variants: Vec<_> = enum_def
92        .variants
93        .iter()
94        .enumerate()
95        .map(|(idx, v)| {
96            minijinja::context! {
97                name => v.name.clone(),
98                idx => idx,
99                is_default => idx == default_idx,
100                has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101                serde_rename => v.serde_rename.clone().unwrap_or_default(),
102            }
103        })
104        .collect();
105
106    let string_methods = if is_pyo3 {
107        crate::template_env::render(
108            "generators/enums/enum_string_methods.jinja",
109            minijinja::context! {
110                name => enum_def.name,
111                value_expr => "self",
112            },
113        )
114    } else {
115        String::new()
116    };
117
118    crate::template_env::render(
119        "generators/enums/enum_definition.jinja",
120        minijinja::context! {
121            enum_name => enum_def.name,
122            derives => derives.join(", "),
123            serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
124            enum_attrs => cfg.enum_attrs.to_vec(),
125            variants => variants,
126            is_pyo3 => is_pyo3,
127            string_methods => string_methods,
128        },
129    )
130}
131
/// Rust keywords (strict, reserved and edition-specific) that cannot be used as bare
/// identifiers in generated function names.
///
/// Callers in this module escape matches as raw identifiers (`r#type`). `gen` is reserved
/// starting with the Rust 2024 edition. `r#gen` is also accepted by earlier editions, so
/// escaping it unconditionally is safe.
///
/// Caveat: `crate`, `self`, `super` and `Self` are listed here, but rustc rejects them even
/// in raw form (`r#self` is an error). They need a different escape than `r#`.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
139
140/// Generate variant accessor properties for a data enum.
141/// For single-tuple variants with a Named inner type, returns the typed binding struct directly.
142/// For all other variants, returns the variant data as a Python dict, or None if not active.
143pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
144    use alef_core::ir::TypeRef;
145    use heck::ToSnakeCase;
146
147    for variant in &enum_def.variants {
148        let variant_name_lower = variant.name.to_snake_case();
149        let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
150            format!("r#{}", variant_name_lower)
151        } else {
152            variant_name_lower.clone()
153        };
154
155        if variant.fields.len() == 1 {
156            let field = &variant.fields[0];
157            let is_tuple_field = field
158                .name
159                .strip_prefix('_')
160                .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
161            if is_tuple_field {
162                if let TypeRef::Named(inner_type_name) = &field.ty {
163                    let variant_pascal = &variant.name;
164                    let clone_expr = if field.is_boxed {
165                        "(**data).clone().into()".to_string()
166                    } else {
167                        "data.clone().into()".to_string()
168                    };
169                    out.push('\n');
170                    out.push_str("    #[getter]\n");
171                    out.push_str(&crate::template_env::render(
172                        "generators/enums/getter_accessor.jinja",
173                        minijinja::context! {
174                            fn_name => &fn_name,
175                            inner_type_name => inner_type_name,
176                        },
177                    ));
178                    out.push_str("        match &self.inner {\n");
179                    out.push_str(&crate::template_env::render(
180                        "generators/enums/match_variant.jinja",
181                        minijinja::context! {
182                            core_path => &core_path,
183                            variant_pascal => variant_pascal,
184                            clone_expr => &clone_expr,
185                        },
186                    ));
187                    out.push_str("            _ => None,\n");
188                    out.push_str("        }\n");
189                    out.push_str("    }\n");
190                    continue;
191                }
192            }
193        }
194
195        out.push('\n');
196        out.push_str("    #[getter]\n");
197        out.push_str(&crate::template_env::render(
198            "generators/enums/py_dict_getter.jinja",
199            minijinja::context! {
200                fn_name => &fn_name,
201            },
202        ));
203        out.push_str("        let json = serde_json::to_value(&self.inner)\n");
204        out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
205        let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
206        out.push_str(&crate::template_env::render(
207            "generators/enums/tag_field_check.jinja",
208            minijinja::context! {
209                tag_field => tag_field,
210            },
211        ));
212        out.push_str("        let tag_value = json.get(tag_field)\n");
213        out.push_str("            .and_then(|v| v.as_str())\n");
214        out.push_str("            .unwrap_or(\"\");\n");
215        out.push_str(&crate::template_env::render(
216            "generators/enums/variant_tag_match.jinja",
217            minijinja::context! {
218                variant_name_lower => &variant_name_lower,
219            },
220        ));
221        out.push_str("            return Ok(None);\n");
222        out.push_str("        }\n");
223        out.push_str("        let json_str = json.to_string();\n");
224        out.push_str("        let json_mod = py.import(\"json\")?;\n");
225        out.push_str("        let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
226        out.push_str("        Ok(Some(py_dict.unbind()))\n");
227        out.push_str("    }\n");
228    }
229}
230
231pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
232    let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
233        format!("r#{tag_field}")
234    } else {
235        tag_field.to_string()
236    };
237    out.push('\n');
238    out.push_str("    #[getter]\n");
239    out.push_str(&crate::template_env::render(
240        "generators/enums/tag_getter_header.jinja",
241        minijinja::context! {
242            fn_name => &fn_name,
243        },
244    ));
245    out.push_str("        let json = serde_json::to_value(&self.inner)\n");
246    out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
247    out.push_str(&crate::template_env::render(
248        "generators/enums/json_get_field.jinja",
249        minijinja::context! {
250            tag_field => tag_field,
251        },
252    ));
253    out.push_str("            .and_then(|v| v.as_str())\n");
254    out.push_str("            .map(String::from)\n");
255    out.push_str(&crate::template_env::render(
256        "generators/enums/json_get_error.jinja",
257        minijinja::context! {
258            tag_field => tag_field,
259        },
260    ));
261    out.push_str("    }\n");
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Variant fixture: not the default, not a tuple variant, no serde rename.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Plain required `String` field fixture with no wrappers, defaults or renames.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
        }
    }

    /// Serde-enabled, untagged-by-default enum fixture living at `crate::<name>`.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
        }
    }

    /// Asserts that every snippet occurs in `generated`, dumping the output on failure.
    fn assert_emits(generated: &str, snippets: &[&str]) {
        for &snippet in snippets {
            assert!(generated.contains(snippet), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let data_enum = make_enum(
            "StructureKind",
            vec![make_variant("Other", vec![string_field("value")])],
        );
        let generated = gen_pyo3_data_enum(&data_enum, "core");

        assert_emits(
            &generated,
            &[
                "fn __str__(&self) -> PyResult<String>",
                "serde_json::to_value(&self.inner)",
                "fn __repr__(&self) -> PyResult<String>",
            ],
        );
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let pyo3_cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
        };
        let unit_enum = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);
        let generated = gen_enum(&unit_enum, &pyo3_cfg);

        assert_emits(
            &generated,
            &["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"],
        );
    }
}