1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37 let name = &enum_def.name;
38 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39 let has_sanitized = enum_has_sanitized_fields(enum_def);
40 let string_methods_content = crate::template_env::render(
41 "generators/enums/enum_string_methods.jinja",
42 minijinja::context! {
43 name => name,
44 value_expr => "&self.inner",
45 },
46 );
47
48 let mut variant_accessors = String::new();
49 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51 let mut serde_tag_content = String::new();
52 if let Some(tag_field) = &enum_def.serde_tag {
53 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54 }
55
56 crate::template_env::render(
57 "generators/enums/pyo3_data_enum.jinja",
58 minijinja::context! {
59 name => name,
60 core_path => core_path,
61 has_sanitized => has_sanitized,
62 string_methods_content => string_methods_content,
63 variant_accessors_content => variant_accessors,
64 serde_tag_content => serde_tag_content,
65 },
66 )
67}
68
69pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75 derives.push("Default");
79 derives.push("serde::Serialize");
80 derives.push("serde::Deserialize");
81
82 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91 let variants: Vec<_> = enum_def
92 .variants
93 .iter()
94 .enumerate()
95 .map(|(idx, v)| {
96 minijinja::context! {
97 name => v.name.clone(),
98 idx => idx,
99 is_default => idx == default_idx,
100 has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101 serde_rename => v.serde_rename.clone().unwrap_or_default(),
102 }
103 })
104 .collect();
105
106 let string_methods = if is_pyo3 {
107 crate::template_env::render(
108 "generators/enums/enum_string_methods.jinja",
109 minijinja::context! {
110 name => enum_def.name,
111 value_expr => "self",
112 },
113 )
114 } else {
115 String::new()
116 };
117
118 crate::template_env::render(
119 "generators/enums/enum_definition.jinja",
120 minijinja::context! {
121 enum_name => enum_def.name,
122 derives => derives.join(", "),
123 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
124 enum_attrs => cfg.enum_attrs.to_vec(),
125 variants => variants,
126 is_pyo3 => is_pyo3,
127 string_methods => string_methods,
128 },
129 )
130}
131
/// Rust keywords (strict, reserved and edition-dependent) that cannot appear
/// as plain identifiers in generated code. Generated identifiers found in this
/// list need escaping, normally with the raw-identifier prefix `r#`.
///
/// `gen` is reserved starting with Rust 2024. `r#gen` is accepted in every
/// edition, so escaping it is always safe.
///
/// Note: `crate`, `self`, `Self` and `super` are listed as well, but rustc does
/// not accept them as raw identifiers (`r#crate` is a compile error).
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
    "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
    "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
139
140pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
144 use alef_core::ir::TypeRef;
145 use heck::ToSnakeCase;
146
147 for variant in &enum_def.variants {
148 let variant_name_lower = variant.name.to_snake_case();
149 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
150 format!("r#{}", variant_name_lower)
151 } else {
152 variant_name_lower.clone()
153 };
154
155 if variant.fields.len() == 1 {
156 let field = &variant.fields[0];
157 let is_tuple_field = field
158 .name
159 .strip_prefix('_')
160 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
161 if is_tuple_field {
162 if let TypeRef::Named(inner_type_name) = &field.ty {
163 let variant_pascal = &variant.name;
164 let clone_expr = if field.is_boxed {
165 "(**data).clone().into()".to_string()
166 } else {
167 "data.clone().into()".to_string()
168 };
169 out.push('\n');
170 out.push_str(" #[getter]\n");
171 out.push_str(&crate::template_env::render(
172 "generators/enums/getter_accessor.jinja",
173 minijinja::context! {
174 fn_name => &fn_name,
175 inner_type_name => inner_type_name,
176 },
177 ));
178 out.push_str(" match &self.inner {\n");
179 out.push_str(&crate::template_env::render(
180 "generators/enums/match_variant.jinja",
181 minijinja::context! {
182 core_path => &core_path,
183 variant_pascal => variant_pascal,
184 clone_expr => &clone_expr,
185 },
186 ));
187 out.push_str(" _ => None,\n");
188 out.push_str(" }\n");
189 out.push_str(" }\n");
190 continue;
191 }
192 }
193 }
194
195 out.push('\n');
196 out.push_str(" #[getter]\n");
197 out.push_str(&crate::template_env::render(
198 "generators/enums/py_dict_getter.jinja",
199 minijinja::context! {
200 fn_name => &fn_name,
201 },
202 ));
203 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
204 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
205 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
206 out.push_str(&crate::template_env::render(
207 "generators/enums/tag_field_check.jinja",
208 minijinja::context! {
209 tag_field => tag_field,
210 },
211 ));
212 out.push_str(" let tag_value = json.get(tag_field)\n");
213 out.push_str(" .and_then(|v| v.as_str())\n");
214 out.push_str(" .unwrap_or(\"\");\n");
215 out.push_str(&crate::template_env::render(
216 "generators/enums/variant_tag_match.jinja",
217 minijinja::context! {
218 variant_name_lower => &variant_name_lower,
219 },
220 ));
221 out.push_str(" return Ok(None);\n");
222 out.push_str(" }\n");
223 out.push_str(" let json_str = json.to_string();\n");
224 out.push_str(" let json_mod = py.import(\"json\")?;\n");
225 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
226 out.push_str(" Ok(Some(py_dict.unbind()))\n");
227 out.push_str(" }\n");
228 }
229}
230
231pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
232 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
233 format!("r#{tag_field}")
234 } else {
235 tag_field.to_string()
236 };
237 out.push('\n');
238 out.push_str(" #[getter]\n");
239 out.push_str(&crate::template_env::render(
240 "generators/enums/tag_getter_header.jinja",
241 minijinja::context! {
242 fn_name => &fn_name,
243 },
244 ));
245 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
246 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
247 out.push_str(&crate::template_env::render(
248 "generators/enums/json_get_field.jinja",
249 minijinja::context! {
250 tag_field => tag_field,
251 },
252 ));
253 out.push_str(" .and_then(|v| v.as_str())\n");
254 out.push_str(" .map(String::from)\n");
255 out.push_str(&crate::template_env::render(
256 "generators/enums/json_get_error.jinja",
257 minijinja::context! {
258 tag_field => tag_field,
259 },
260 ));
261 out.push_str(" }\n");
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// A non-default, non-tuple variant named `name` carrying `fields`.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: String::from(name),
            doc: String::new(),
            fields,
            is_default: false,
            is_tuple: false,
            serde_rename: None,
        }
    }

    /// A required `String` field named `name` with every optional attribute off.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: String::from(name),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// A serde-enabled, untagged-by-default enum living at `crate::<name>`.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: String::from(name),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let structure_kind = make_enum("StructureKind", vec![make_variant("Other", vec![string_field("value")])]);
        let generated = gen_pyo3_data_enum(&structure_kind, "core");

        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let unit_enum = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);
        let generated = gen_enum(&unit_enum, &cfg);

        for expected in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(expected), "{generated}");
        }
    }
}