1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37 let name = &enum_def.name;
38 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39 let has_sanitized = enum_has_sanitized_fields(enum_def);
40 let string_methods_content = crate::template_env::render(
41 "generators/enums/enum_string_methods.jinja",
42 minijinja::context! {
43 name => name,
44 value_expr => "&self.inner",
45 },
46 );
47
48 let mut variant_accessors = String::new();
49 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51 let mut serde_tag_content = String::new();
52 if let Some(tag_field) = &enum_def.serde_tag {
53 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54 }
55
56 crate::template_env::render(
57 "generators/enums/pyo3_data_enum.jinja",
58 minijinja::context! {
59 name => name,
60 core_path => core_path,
61 has_sanitized => has_sanitized,
62 string_methods_content => string_methods_content,
63 variant_accessors_content => variant_accessors,
64 serde_tag_content => serde_tag_content,
65 },
66 )
67}
68
69pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75 derives.push("Default");
79 derives.push("serde::Serialize");
80 derives.push("serde::Deserialize");
81
82 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91 let variants: Vec<_> = enum_def
92 .variants
93 .iter()
94 .enumerate()
95 .map(|(idx, v)| {
96 minijinja::context! {
97 name => v.name.clone(),
98 idx => idx,
99 is_default => idx == default_idx,
100 has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101 serde_rename => v.serde_rename.clone().unwrap_or_default(),
102 }
103 })
104 .collect();
105
106 let string_methods = if is_pyo3 {
107 crate::template_env::render(
108 "generators/enums/enum_string_methods.jinja",
109 minijinja::context! {
110 name => enum_def.name,
111 value_expr => "self",
112 },
113 )
114 } else {
115 String::new()
116 };
117
118 crate::template_env::render(
119 "generators/enums/enum_definition.jinja",
120 minijinja::context! {
121 enum_name => enum_def.name,
122 derives => derives.join(", "),
123 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
124 enum_attrs => cfg.enum_attrs.to_vec(),
125 variants => variants,
126 is_pyo3 => is_pyo3,
127 string_methods => string_methods,
128 },
129 )
130}
131
132const RUST_KEYWORDS: &[&str] = &[
134 "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
135 "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
136 "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
137 "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
138];
139
140pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
144 use alef_core::ir::TypeRef;
145 use heck::ToSnakeCase;
146
147 for variant in &enum_def.variants {
148 let variant_name_lower = variant.name.to_snake_case();
149 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
150 format!("r#{}", variant_name_lower)
151 } else {
152 variant_name_lower.clone()
153 };
154
155 if variant.fields.len() == 1 {
156 let field = &variant.fields[0];
157 let is_tuple_field = field
158 .name
159 .strip_prefix('_')
160 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
161 if is_tuple_field {
162 if let TypeRef::Named(inner_type_name) = &field.ty {
163 let variant_pascal = &variant.name;
164 let clone_expr = if field.is_boxed {
165 "(**data).clone().into()".to_string()
166 } else {
167 "data.clone().into()".to_string()
168 };
169 out.push('\n');
170 out.push_str(" #[getter]\n");
171 out.push_str(&crate::template_env::render(
172 "generators/enums/getter_accessor.jinja",
173 minijinja::context! {
174 fn_name => &fn_name,
175 inner_type_name => inner_type_name,
176 },
177 ));
178 out.push_str(" match &self.inner {\n");
179 out.push_str(&crate::template_env::render(
180 "generators/enums/match_variant.jinja",
181 minijinja::context! {
182 core_path => &core_path,
183 variant_pascal => variant_pascal,
184 clone_expr => &clone_expr,
185 },
186 ));
187 out.push_str(" _ => None,\n");
188 out.push_str(" }\n");
189 out.push_str(" }\n");
190 continue;
191 }
192 }
193 }
194
195 out.push('\n');
196 out.push_str(" #[getter]\n");
197 out.push_str(&crate::template_env::render(
198 "generators/enums/py_dict_getter.jinja",
199 minijinja::context! {
200 fn_name => &fn_name,
201 },
202 ));
203 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
204 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
205 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
206 out.push_str(&crate::template_env::render(
207 "generators/enums/tag_field_check.jinja",
208 minijinja::context! {
209 tag_field => tag_field,
210 },
211 ));
212 out.push_str(" let tag_value = json.get(tag_field)\n");
213 out.push_str(" .and_then(|v| v.as_str())\n");
214 out.push_str(" .unwrap_or(\"\");\n");
215 out.push_str(&crate::template_env::render(
216 "generators/enums/variant_tag_match.jinja",
217 minijinja::context! {
218 variant_name_lower => &variant_name_lower,
219 },
220 ));
221 out.push_str(" return Ok(None);\n");
222 out.push_str(" }\n");
223 out.push_str(" let json_str = json.to_string();\n");
224 out.push_str(" let json_mod = py.import(\"json\")?;\n");
225 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
226 out.push_str(" Ok(Some(py_dict.unbind()))\n");
227 out.push_str(" }\n");
228 }
229}
230
231pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
232 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
233 format!("r#{tag_field}")
234 } else {
235 tag_field.to_string()
236 };
237 out.push('\n');
238 out.push_str(" #[getter]\n");
239 out.push_str(&crate::template_env::render(
240 "generators/enums/tag_getter_header.jinja",
241 minijinja::context! {
242 fn_name => &fn_name,
243 },
244 ));
245 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
246 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
247 out.push_str(&crate::template_env::render(
248 "generators/enums/json_get_field.jinja",
249 minijinja::context! {
250 tag_field => tag_field,
251 },
252 ));
253 out.push_str(" .and_then(|v| v.as_str())\n");
254 out.push_str(" .map(String::from)\n");
255 out.push_str(&crate::template_env::render(
256 "generators/enums/json_get_error.jinja",
257 minijinja::context! {
258 tag_field => tag_field,
259 },
260 ));
261 out.push_str(" }\n");
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Variant fixture: `name` with `fields`; every other attribute left at its neutral value.
    fn variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: String::from(name),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Field fixture: a required `String` field called `name`, with no wrappers or renames.
    fn field(name: &str) -> FieldDef {
        FieldDef {
            name: String::from(name),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
        }
    }

    /// Enum fixture rooted at `crate::<name>`, with serde enabled and no `serde_tag` or
    /// `serde_rename_all`.
    fn enum_def(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: String::from(name),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_rename_all: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let data_enum = enum_def("StructureKind", vec![variant("Other", vec![field("value")])]);
        let generated = gen_pyo3_data_enum(&data_enum, "core");

        let expected = [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ];
        for needle in expected {
            assert!(generated.contains(needle), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
        };
        let unit_enum = enum_def("StructureKind", vec![variant("Function", Vec::new())]);
        let generated = gen_enum(&unit_enum, &cfg);

        for needle in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(needle), "{generated}");
        }
    }
}
368}