Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3
4/// Returns true if any variant of the enum has data fields.
5/// These enums cannot be represented as flat integer enums in bindings.
6pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
7    enum_def.variants.iter().any(|v| !v.fields.is_empty())
8}
9
10/// Returns true if any variant of the enum has a sanitized field.
11///
12/// A sanitized field means the extractor could not resolve the field's concrete type
13/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
14/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
15/// be generated, because the Python-dict → JSON → core deserialization path would not
16/// produce a valid value for the sanitized field. The forwarding trait impls
17/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
18/// the wrapper struct always delegates to the core type.
19fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
20    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
21}
22
23/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
24///
25/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
26/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
27/// serializes it to JSON, and deserializes into the core Rust type via serde.
28///
29/// When any variant field is sanitized (its type could not be resolved — e.g. contains
30/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
31/// based `#[new]` constructor is omitted. The type is still useful as a return value
32/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
33/// and Deserialize are always generated regardless of sanitized fields, because the
34/// wrapper struct always delegates to the core type which implements those traits.
35pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
36    let name = &enum_def.name;
37    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
38    let has_sanitized = enum_has_sanitized_fields(enum_def);
39    let string_methods_content = crate::template_env::render(
40        "generators/enums/enum_string_methods.jinja",
41        minijinja::context! {
42            name => name,
43            value_expr => "&self.inner",
44        },
45    );
46
47    let mut variant_accessors = String::new();
48    write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
49
50    let mut serde_tag_content = String::new();
51    if let Some(tag_field) = &enum_def.serde_tag {
52        write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
53    }
54
55    crate::template_env::render(
56        "generators/enums/pyo3_data_enum.jinja",
57        minijinja::context! {
58            name => name,
59            core_path => core_path,
60            has_sanitized => has_sanitized,
61            string_methods_content => string_methods_content,
62            variant_accessors_content => variant_accessors,
63            serde_tag_content => serde_tag_content,
64        },
65    )
66}
67
68/// Convert a Rust PascalCase variant name to `UPPER_SNAKE_CASE` for PyO3 `#[pyo3(name = "...")]`.
69///
70/// Handles acronym-style names where 2+ leading uppercase characters are followed only by
71/// lowercase letters (e.g. `RDFa` → `RDFA` instead of `RD_FA`). For Python-keyword variants
72/// whose Rust identifier was appended with `_` (e.g. `None_`), the screaming form preserves
73/// the trailing underscore (`NONE_`) so `setattr`-based aliases in `options.py` continue to
74/// work correctly.
75fn to_pyo3_screaming(name: &str) -> String {
76    use heck::ToShoutySnakeCase;
77    let chars: Vec<char> = name.chars().collect();
78    let upper_prefix_len = chars.iter().take_while(|c| c.is_uppercase()).count();
79    // Acronym: 2+ leading uppercase chars with only lowercase (or empty) remainder
80    if upper_prefix_len >= 2 && chars[upper_prefix_len..].iter().all(|c| c.is_lowercase() || *c == '_') {
81        name.to_ascii_uppercase()
82    } else {
83        name.to_shouty_snake_case()
84    }
85}
86
87/// Generate an enum.
88pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
89    // All enums are generated as unit-variant-only in the binding layer.
90    // Data variants are flattened to unit variants; the From/Into conversions
91    // handle the lossy mapping (discarding / providing defaults for field data).
92    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
93    // Binding enums always derive Default, Serialize, and Deserialize.
94    // Default: enables using unwrap_or_default() in constructors.
95    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
96    derives.push("Default");
97    derives.push("serde::Serialize");
98    derives.push("serde::Deserialize");
99
100    // Detect PyO3 context so we can rename all variants via #[pyo3(name = "UPPER_SNAKE_CASE")].
101    // PEP 8 mandates UPPER_SNAKE_CASE for enum members; pyclass variants must carry this
102    // rename so Python callers see `BatchStatus.VALIDATING` rather than `BatchStatus.Validating`.
103    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
104
105    // Determine which variant carries #[default].
106    // Prefer the variant that has is_default=true in the source (mirrors the Rust core's
107    // #[default] attribute); fall back to the first variant when none is explicitly marked.
108    let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
109
110    let variants: Vec<_> = enum_def
111        .variants
112        .iter()
113        .enumerate()
114        .map(|(idx, v)| {
115            // In pyo3 context every variant gets #[pyo3(name = "UPPER_SNAKE_CASE")] so the
116            // Python-exposed name is PEP 8-compliant. For Python-keyword variants the
117            // Rust identifier is already escaped (None → None_) so we produce "NONE_" as
118            // the screaming form of that escaped name — callers use BatchStatus.NONE.
119            let pyo3_name = if is_pyo3 {
120                to_pyo3_screaming(&v.name)
121            } else {
122                String::new()
123            };
124            minijinja::context! {
125                name => v.name.clone(),
126                idx => idx,
127                is_default => idx == default_idx,
128                has_pyo3_rename => is_pyo3,
129                pyo3_name => pyo3_name,
130                serde_rename => v.serde_rename.clone().unwrap_or_default(),
131            }
132        })
133        .collect();
134
135    let string_methods = if is_pyo3 {
136        crate::template_env::render(
137            "generators/enums/enum_string_methods.jinja",
138            minijinja::context! {
139                name => enum_def.name,
140                value_expr => "self",
141            },
142        )
143    } else {
144        String::new()
145    };
146
147    crate::template_env::render(
148        "generators/enums/enum_definition.jinja",
149        minijinja::context! {
150            enum_name => enum_def.name,
151            derives => derives.join(", "),
152            serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
153            enum_attrs => cfg.enum_attrs.to_vec(),
154            variants => variants,
155            is_pyo3 => is_pyo3,
156            string_methods => string_methods,
157        },
158    )
159}
160
/// Rust keywords that cannot be used as bare identifiers in generated function names.
///
/// Covers strict keywords, reserved keywords, and edition-dependent ones (`async`, `await`,
/// `dyn`, `try` since 2018; `gen` since 2024). Note that `self`, `Self`, `super` and
/// `crate` cannot be written as raw identifiers either (`r#self` is rejected by rustc), so
/// code that escapes with `r#` must special-case them.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
168
169/// Generate variant accessor properties for a data enum.
170/// For single-tuple variants with a Named inner type, returns the typed binding struct directly.
171/// For all other variants, returns the variant data as a Python dict, or None if not active.
172pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
173    use alef_core::ir::TypeRef;
174    
175
176    for variant in &enum_def.variants {
177        let variant_name_lower = crate::naming::pascal_to_snake(&variant.name);
178        let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
179            format!("r#{}", variant_name_lower)
180        } else {
181            variant_name_lower.clone()
182        };
183
184        if variant.fields.len() == 1 {
185            let field = &variant.fields[0];
186            let is_tuple_field = field
187                .name
188                .strip_prefix('_')
189                .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
190            if is_tuple_field {
191                if let TypeRef::Named(inner_type_name) = &field.ty {
192                    let variant_pascal = &variant.name;
193                    let clone_expr = if field.is_boxed {
194                        "(**data).clone().into()".to_string()
195                    } else {
196                        "data.clone().into()".to_string()
197                    };
198                    out.push('\n');
199                    out.push_str("    #[getter]\n");
200                    out.push_str(&crate::template_env::render(
201                        "generators/enums/getter_accessor.jinja",
202                        minijinja::context! {
203                            fn_name => &fn_name,
204                            inner_type_name => inner_type_name,
205                        },
206                    ));
207                    out.push_str("        match &self.inner {\n");
208                    out.push_str(&crate::template_env::render(
209                        "generators/enums/match_variant.jinja",
210                        minijinja::context! {
211                            core_path => &core_path,
212                            variant_pascal => variant_pascal,
213                            clone_expr => &clone_expr,
214                        },
215                    ));
216                    out.push_str("            _ => None,\n");
217                    out.push_str("        }\n");
218                    out.push_str("    }\n");
219                    continue;
220                }
221            }
222        }
223
224        out.push('\n');
225        out.push_str("    #[getter]\n");
226        out.push_str(&crate::template_env::render(
227            "generators/enums/py_dict_getter.jinja",
228            minijinja::context! {
229                fn_name => &fn_name,
230            },
231        ));
232        out.push_str("        let json = serde_json::to_value(&self.inner)\n");
233        out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
234        let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
235        out.push_str(&crate::template_env::render(
236            "generators/enums/tag_field_check.jinja",
237            minijinja::context! {
238                tag_field => tag_field,
239            },
240        ));
241        out.push_str("        let tag_value = json.get(tag_field)\n");
242        out.push_str("            .and_then(|v| v.as_str())\n");
243        out.push_str("            .unwrap_or(\"\");\n");
244        out.push_str(&crate::template_env::render(
245            "generators/enums/variant_tag_match.jinja",
246            minijinja::context! {
247                variant_name_lower => &variant_name_lower,
248            },
249        ));
250        out.push_str("            return Ok(None);\n");
251        out.push_str("        }\n");
252        out.push_str("        let json_str = json.to_string();\n");
253        out.push_str("        let json_mod = py.import(\"json\")?;\n");
254        out.push_str("        let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
255        out.push_str("        Ok(Some(py_dict.unbind()))\n");
256        out.push_str("    }\n");
257    }
258}
259
260pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
261    let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
262        format!("r#{tag_field}")
263    } else {
264        tag_field.to_string()
265    };
266    out.push('\n');
267    out.push_str("    #[getter]\n");
268    out.push_str(&crate::template_env::render(
269        "generators/enums/tag_getter_header.jinja",
270        minijinja::context! {
271            fn_name => &fn_name,
272        },
273    ));
274    out.push_str("        let json = serde_json::to_value(&self.inner)\n");
275    out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
276    out.push_str(&crate::template_env::render(
277        "generators/enums/json_get_field.jinja",
278        minijinja::context! {
279            tag_field => tag_field,
280        },
281    ));
282    out.push_str("            .and_then(|v| v.as_str())\n");
283    out.push_str("            .map(String::from)\n");
284    out.push_str(&crate::template_env::render(
285        "generators/enums/json_get_error.jinja",
286        minijinja::context! {
287            tag_field => tag_field,
288        },
289    ));
290    out.push_str("    }\n");
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a non-default, non-renamed, struct-style variant.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a plain, non-optional `String` field with no wrappers or renames.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Builds a serde-enabled, untagged-by-default enum definition under `crate::`.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Asserts every snippet occurs in `generated`, printing the full output on failure.
    fn assert_contains_all(generated: &str, snippets: &[&str]) {
        for snippet in snippets {
            assert!(generated.contains(snippet), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let def = make_enum("StructureKind", vec![make_variant("Other", vec![string_field("value")])]);
        let generated = gen_pyo3_data_enum(&def, "core");

        assert_contains_all(
            &generated,
            &[
                "fn __str__(&self) -> PyResult<String>",
                "serde_json::to_value(&self.inner)",
                "fn __repr__(&self) -> PyResult<String>",
            ],
        );
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let def = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);
        let generated = gen_enum(&def, &cfg);

        assert_contains_all(
            &generated,
            &["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"],
        );
    }
}