1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37 let name = &enum_def.name;
38 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39 let has_sanitized = enum_has_sanitized_fields(enum_def);
40 let string_methods_content = crate::template_env::render(
41 "generators/enums/enum_string_methods.jinja",
42 minijinja::context! {
43 name => name,
44 value_expr => "&self.inner",
45 },
46 );
47
48 let mut variant_accessors = String::new();
49 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51 let mut serde_tag_content = String::new();
52 if let Some(tag_field) = &enum_def.serde_tag {
53 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54 }
55
56 crate::template_env::render(
57 "generators/enums/pyo3_data_enum.jinja",
58 minijinja::context! {
59 name => name,
60 core_path => core_path,
61 has_sanitized => has_sanitized,
62 string_methods_content => string_methods_content,
63 variant_accessors_content => variant_accessors,
64 serde_tag_content => serde_tag_content,
65 },
66 )
67}
68
69pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75 derives.push("Default");
79 derives.push("serde::Serialize");
80 derives.push("serde::Deserialize");
81
82 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91 let variants: Vec<_> = enum_def
92 .variants
93 .iter()
94 .enumerate()
95 .map(|(idx, v)| {
96 minijinja::context! {
97 name => v.name.clone(),
98 idx => idx,
99 is_default => idx == default_idx,
100 has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101 }
102 })
103 .collect();
104
105 let string_methods = if is_pyo3 {
106 crate::template_env::render(
107 "generators/enums/enum_string_methods.jinja",
108 minijinja::context! {
109 name => enum_def.name,
110 value_expr => "self",
111 },
112 )
113 } else {
114 String::new()
115 };
116
117 crate::template_env::render(
118 "generators/enums/enum_definition.jinja",
119 minijinja::context! {
120 enum_name => enum_def.name,
121 derives => derives.join(", "),
122 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
123 enum_attrs => cfg.enum_attrs.to_vec(),
124 variants => variants,
125 is_pyo3 => is_pyo3,
126 string_methods => string_methods,
127 },
128 )
129}
130
/// Rust strict and reserved keywords (2018 through 2024 editions) that cannot be
/// emitted as bare identifiers in generated code.
///
/// Generators escape matches with the raw-identifier prefix (e.g. `r#type`).
/// - `self`, `Self`, `super` and `crate` are listed here, but rustc does not accept
///   them as raw identifiers, so they need a different escape (such as a trailing `_`).
/// - `gen` is reserved starting with the 2024 edition; `r#gen` is valid in every edition.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
138
139pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
143 use alef_core::ir::TypeRef;
144 use heck::ToSnakeCase;
145
146 for variant in &enum_def.variants {
147 let variant_name_lower = variant.name.to_snake_case();
148 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
149 format!("r#{}", variant_name_lower)
150 } else {
151 variant_name_lower.clone()
152 };
153
154 if variant.fields.len() == 1 {
155 let field = &variant.fields[0];
156 let is_tuple_field = field
157 .name
158 .strip_prefix('_')
159 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
160 if is_tuple_field {
161 if let TypeRef::Named(inner_type_name) = &field.ty {
162 let variant_pascal = &variant.name;
163 let clone_expr = if field.is_boxed {
164 "(**data).clone().into()".to_string()
165 } else {
166 "data.clone().into()".to_string()
167 };
168 out.push('\n');
169 out.push_str(" #[getter]\n");
170 out.push_str(&crate::template_env::render(
171 "generators/enums/getter_accessor.jinja",
172 minijinja::context! {
173 fn_name => &fn_name,
174 inner_type_name => inner_type_name,
175 },
176 ));
177 out.push_str(" match &self.inner {\n");
178 out.push_str(&crate::template_env::render(
179 "generators/enums/match_variant.jinja",
180 minijinja::context! {
181 core_path => &core_path,
182 variant_pascal => variant_pascal,
183 clone_expr => &clone_expr,
184 },
185 ));
186 out.push_str(" _ => None,\n");
187 out.push_str(" }\n");
188 out.push_str(" }\n");
189 continue;
190 }
191 }
192 }
193
194 out.push('\n');
195 out.push_str(" #[getter]\n");
196 out.push_str(&crate::template_env::render(
197 "generators/enums/py_dict_getter.jinja",
198 minijinja::context! {
199 fn_name => &fn_name,
200 },
201 ));
202 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
203 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
204 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
205 out.push_str(&crate::template_env::render(
206 "generators/enums/tag_field_check.jinja",
207 minijinja::context! {
208 tag_field => tag_field,
209 },
210 ));
211 out.push_str(" let tag_value = json.get(tag_field)\n");
212 out.push_str(" .and_then(|v| v.as_str())\n");
213 out.push_str(" .unwrap_or(\"\");\n");
214 out.push_str(&crate::template_env::render(
215 "generators/enums/variant_tag_match.jinja",
216 minijinja::context! {
217 variant_name_lower => &variant_name_lower,
218 },
219 ));
220 out.push_str(" return Ok(None);\n");
221 out.push_str(" }\n");
222 out.push_str(" let json_str = json.to_string();\n");
223 out.push_str(" let json_mod = py.import(\"json\")?;\n");
224 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
225 out.push_str(" Ok(Some(py_dict.unbind()))\n");
226 out.push_str(" }\n");
227 }
228}
229
230pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
231 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
232 format!("r#{tag_field}")
233 } else {
234 tag_field.to_string()
235 };
236 out.push('\n');
237 out.push_str(" #[getter]\n");
238 out.push_str(&crate::template_env::render(
239 "generators/enums/tag_getter_header.jinja",
240 minijinja::context! {
241 fn_name => &fn_name,
242 },
243 ));
244 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
245 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
246 out.push_str(&crate::template_env::render(
247 "generators/enums/json_get_field.jinja",
248 minijinja::context! {
249 tag_field => tag_field,
250 },
251 ));
252 out.push_str(" .and_then(|v| v.as_str())\n");
253 out.push_str(" .map(String::from)\n");
254 out.push_str(&crate::template_env::render(
255 "generators/enums/json_get_error.jinja",
256 minijinja::context! {
257 tag_field => tag_field,
258 },
259 ));
260 out.push_str(" }\n");
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a non-default, non-renamed, non-tuple variant holding `fields`.
    fn mk_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a required, unboxed `String` field with no defaults or wrappers.
    fn mk_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
        }
    }

    /// Builds an untagged serde enum rooted at `crate::<name>`.
    fn mk_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_rename_all: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let def = mk_enum("StructureKind", vec![mk_variant("Other", vec![mk_field("value")])]);
        let generated = gen_pyo3_data_enum(&def, "core");

        for needle in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(needle), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
        };
        let def = mk_enum("StructureKind", vec![mk_variant("Function", Vec::new())]);
        let generated = gen_enum(&def, &cfg);

        for needle in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(needle), "{generated}");
        }
    }
}