Skip to main content

alef_codegen/generators/
enums.rs

1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3
4/// Returns true if any variant of the enum has data fields.
5/// These enums cannot be represented as flat integer enums in bindings.
6pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
7    enum_def.variants.iter().any(|v| !v.fields.is_empty())
8}
9
10/// Returns true if any variant of the enum has a sanitized field.
11///
12/// A sanitized field means the extractor could not resolve the field's concrete type
13/// (e.g. a tuple like `Vec<(String, String)>` that has no direct IR representation).
14/// When this is true the `#[new]` constructor that round-trips via serde/JSON cannot
15/// be generated, because the Python-dict → JSON → core deserialization path would not
16/// produce a valid value for the sanitized field. The forwarding trait impls
17/// (`Default`, `Serialize`, `Deserialize`) are still generated unconditionally since
18/// the wrapper struct always delegates to the core type.
19fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
20    enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
21}
22
23/// Generate a PyO3 data enum as a `#[pyclass]` struct wrapping the core type.
24///
25/// Data enums (tagged unions like `AuthConfig`) can't be flat int enums in PyO3.
26/// Instead, generate a frozen struct with `inner` that accepts a Python dict,
27/// serializes it to JSON, and deserializes into the core Rust type via serde.
28///
29/// When any variant field is sanitized (its type could not be resolved — e.g. contains
30/// `dyn Stream + Send` which is not `Serialize`/`Deserialize`/`Default`), the serde-
31/// based `#[new]` constructor is omitted. The type is still useful as a return value
32/// from Rust (passed back via From impls). The forwarding impls for Default, Serialize,
33/// and Deserialize are always generated regardless of sanitized fields, because the
34/// wrapper struct always delegates to the core type which implements those traits.
35pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
36    let name = &enum_def.name;
37    let core_path = crate::conversions::core_enum_path(enum_def, core_import);
38    let has_sanitized = enum_has_sanitized_fields(enum_def);
39    let string_methods_content = crate::template_env::render(
40        "generators/enums/enum_string_methods.jinja",
41        minijinja::context! {
42            name => name,
43            value_expr => "&self.inner",
44        },
45    );
46
47    let mut variant_accessors = String::new();
48    write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
49
50    let mut serde_tag_content = String::new();
51    if let Some(tag_field) = &enum_def.serde_tag {
52        write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
53    }
54
55    crate::template_env::render(
56        "generators/enums/pyo3_data_enum.jinja",
57        minijinja::context! {
58            name => name,
59            core_path => core_path,
60            has_sanitized => has_sanitized,
61            string_methods_content => string_methods_content,
62            variant_accessors_content => variant_accessors,
63            serde_tag_content => serde_tag_content,
64        },
65    )
66}
67
68/// Convert a Rust PascalCase variant name to `UPPER_SNAKE_CASE` for PyO3 `#[pyo3(name = "...")]`.
69///
70/// Handles acronym-style names where 2+ leading uppercase characters are followed only by
71/// lowercase letters (e.g. `RDFa` → `RDFA` instead of `RD_FA`). For Python-keyword variants
72/// whose Rust identifier was appended with `_` (e.g. `None_`), the screaming form preserves
73/// the trailing underscore (`NONE_`) so `setattr`-based aliases in `options.py` continue to
74/// work correctly.
75fn to_pyo3_screaming(name: &str) -> String {
76    use heck::ToShoutySnakeCase;
77    let chars: Vec<char> = name.chars().collect();
78    let upper_prefix_len = chars.iter().take_while(|c| c.is_uppercase()).count();
79    // Acronym: 2+ leading uppercase chars with only lowercase (or empty) remainder
80    if upper_prefix_len >= 2 && chars[upper_prefix_len..].iter().all(|c| c.is_lowercase() || *c == '_') {
81        name.to_ascii_uppercase()
82    } else {
83        name.to_shouty_snake_case()
84    }
85}
86
87/// Generate an enum.
88pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
89    // All enums are generated as unit-variant-only in the binding layer.
90    // Data variants are flattened to unit variants; the From/Into conversions
91    // handle the lossy mapping (discarding / providing defaults for field data).
92    let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
93    // Binding enums always derive Default, Serialize, and Deserialize.
94    // Default: enables using unwrap_or_default() in constructors.
95    // Serialize/Deserialize: required for FFI/type conversion across binding boundaries.
96    derives.push("Default");
97    derives.push("serde::Serialize");
98    derives.push("serde::Deserialize");
99
100    // Detect PyO3 context so we can rename all variants via #[pyo3(name = "UPPER_SNAKE_CASE")].
101    // PEP 8 mandates UPPER_SNAKE_CASE for enum members; pyclass variants must carry this
102    // rename so Python callers see `BatchStatus.VALIDATING` rather than `BatchStatus.Validating`.
103    let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
104
105    // Determine which variant carries #[default].
106    // Prefer the variant that has is_default=true in the source (mirrors the Rust core's
107    // #[default] attribute); fall back to the first variant when none is explicitly marked.
108    let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
109
110    let variants: Vec<_> = enum_def
111        .variants
112        .iter()
113        .enumerate()
114        .map(|(idx, v)| {
115            // In pyo3 context every variant gets #[pyo3(name = "UPPER_SNAKE_CASE")] so the
116            // Python-exposed name is PEP 8-compliant. For Python-keyword variants the
117            // Rust identifier is already escaped (None → None_) so we produce "NONE_" as
118            // the screaming form of that escaped name — callers use BatchStatus.NONE.
119            let pyo3_name = if is_pyo3 {
120                to_pyo3_screaming(&v.name)
121            } else {
122                String::new()
123            };
124            minijinja::context! {
125                name => v.name.clone(),
126                idx => idx,
127                is_default => idx == default_idx,
128                has_pyo3_rename => is_pyo3,
129                pyo3_name => pyo3_name,
130                serde_rename => v.serde_rename.clone().unwrap_or_default(),
131            }
132        })
133        .collect();
134
135    let string_methods = if is_pyo3 {
136        crate::template_env::render(
137            "generators/enums/enum_string_methods.jinja",
138            minijinja::context! {
139                name => enum_def.name,
140                value_expr => "self",
141            },
142        )
143    } else {
144        String::new()
145    };
146
147    crate::template_env::render(
148        "generators/enums/enum_definition.jinja",
149        minijinja::context! {
150            enum_name => enum_def.name,
151            derives => derives.join(", "),
152            serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
153            enum_attrs => cfg.enum_attrs.to_vec(),
154            variants => variants,
155            is_pyo3 => is_pyo3,
156            string_methods => string_methods,
157        },
158    )
159}
160
161/// Rust keywords that cannot be used as bare identifiers in function names.
162const RUST_KEYWORDS: &[&str] = &[
163    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
164    "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
165    "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
166    "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
167];
168
169/// Generate variant accessor properties for a data enum.
170/// For single-tuple variants with a Named inner type, returns the typed binding struct directly.
171/// For all other variants, returns the variant data as a Python dict, or None if not active.
172pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
173    use alef_core::ir::TypeRef;
174
175    for variant in &enum_def.variants {
176        let variant_name_lower = crate::naming::pascal_to_snake(&variant.name);
177        let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
178            format!("r#{}", variant_name_lower)
179        } else {
180            variant_name_lower.clone()
181        };
182
183        if variant.fields.len() == 1 {
184            let field = &variant.fields[0];
185            let is_tuple_field = field
186                .name
187                .strip_prefix('_')
188                .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
189            if is_tuple_field {
190                if let TypeRef::Named(inner_type_name) = &field.ty {
191                    let variant_pascal = &variant.name;
192                    let clone_expr = if field.is_boxed {
193                        "(**data).clone().into()".to_string()
194                    } else {
195                        "data.clone().into()".to_string()
196                    };
197                    out.push('\n');
198                    out.push_str("    #[getter]\n");
199                    out.push_str(&crate::template_env::render(
200                        "generators/enums/getter_accessor.jinja",
201                        minijinja::context! {
202                            fn_name => &fn_name,
203                            inner_type_name => inner_type_name,
204                        },
205                    ));
206                    out.push_str("        match &self.inner {\n");
207                    out.push_str(&crate::template_env::render(
208                        "generators/enums/match_variant.jinja",
209                        minijinja::context! {
210                            core_path => &core_path,
211                            variant_pascal => variant_pascal,
212                            clone_expr => &clone_expr,
213                        },
214                    ));
215                    out.push_str("            _ => None,\n");
216                    out.push_str("        }\n");
217                    out.push_str("    }\n");
218                    continue;
219                }
220            }
221        }
222
223        out.push('\n');
224        out.push_str("    #[getter]\n");
225        out.push_str(&crate::template_env::render(
226            "generators/enums/py_dict_getter.jinja",
227            minijinja::context! {
228                fn_name => &fn_name,
229            },
230        ));
231        out.push_str("        let json = serde_json::to_value(&self.inner)\n");
232        out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
233        let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
234        out.push_str(&crate::template_env::render(
235            "generators/enums/tag_field_check.jinja",
236            minijinja::context! {
237                tag_field => tag_field,
238            },
239        ));
240        out.push_str("        let tag_value = json.get(tag_field)\n");
241        out.push_str("            .and_then(|v| v.as_str())\n");
242        out.push_str("            .unwrap_or(\"\");\n");
243        out.push_str(&crate::template_env::render(
244            "generators/enums/variant_tag_match.jinja",
245            minijinja::context! {
246                variant_name_lower => &variant_name_lower,
247            },
248        ));
249        out.push_str("            return Ok(None);\n");
250        out.push_str("        }\n");
251        out.push_str("        let json_str = json.to_string();\n");
252        out.push_str("        let json_mod = py.import(\"json\")?;\n");
253        out.push_str("        let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
254        out.push_str("        Ok(Some(py_dict.unbind()))\n");
255        out.push_str("    }\n");
256    }
257}
258
259pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
260    let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
261        format!("r#{tag_field}")
262    } else {
263        tag_field.to_string()
264    };
265    out.push('\n');
266    out.push_str("    #[getter]\n");
267    out.push_str(&crate::template_env::render(
268        "generators/enums/tag_getter_header.jinja",
269        minijinja::context! {
270            fn_name => &fn_name,
271        },
272    ));
273    out.push_str("        let json = serde_json::to_value(&self.inner)\n");
274    out.push_str("            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
275    out.push_str(&crate::template_env::render(
276        "generators/enums/json_get_field.jinja",
277        minijinja::context! {
278            tag_field => tag_field,
279        },
280    ));
281    out.push_str("            .and_then(|v| v.as_str())\n");
282    out.push_str("            .map(String::from)\n");
283    out.push_str(&crate::template_env::render(
284        "generators/enums/json_get_error.jinja",
285        minijinja::context! {
286            tag_field => tag_field,
287        },
288    ));
289    out.push_str("    }\n");
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Fixture: a non-default, non-tuple variant with the given fields.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.to_owned(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Fixture: a plain, unsanitized `String` field.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.to_owned(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Fixture: an untagged-by-attribute serde enum located at `crate::<name>`.
    fn serde_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.to_owned(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let def = serde_enum(
            "StructureKind",
            vec![make_variant("Other", vec![string_field("value")])],
        );
        let generated = gen_pyo3_data_enum(&def, "core");

        for needle in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(needle), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let def = serde_enum("StructureKind", vec![make_variant("Function", vec![])]);
        let generated = gen_enum(&def, &cfg);

        for needle in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(needle), "{generated}");
        }
    }
}