1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3
4pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
7 enum_def.variants.iter().any(|v| !v.fields.is_empty())
8}
9
10fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
20 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
21}
22
23pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
36 let name = &enum_def.name;
37 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
38 let has_sanitized = enum_has_sanitized_fields(enum_def);
39 let string_methods_content = crate::template_env::render(
40 "generators/enums/enum_string_methods.jinja",
41 minijinja::context! {
42 name => name,
43 value_expr => "&self.inner",
44 },
45 );
46
47 let mut variant_accessors = String::new();
48 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
49
50 let mut serde_tag_content = String::new();
51 if let Some(tag_field) = &enum_def.serde_tag {
52 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
53 }
54
55 crate::template_env::render(
56 "generators/enums/pyo3_data_enum.jinja",
57 minijinja::context! {
58 name => name,
59 core_path => core_path,
60 has_sanitized => has_sanitized,
61 string_methods_content => string_methods_content,
62 variant_accessors_content => variant_accessors,
63 serde_tag_content => serde_tag_content,
64 },
65 )
66}
67
68fn to_pyo3_screaming(name: &str) -> String {
76 use heck::ToShoutySnakeCase;
77 let chars: Vec<char> = name.chars().collect();
78 let upper_prefix_len = chars.iter().take_while(|c| c.is_uppercase()).count();
79 if upper_prefix_len >= 2 && chars[upper_prefix_len..].iter().all(|c| c.is_lowercase() || *c == '_') {
81 name.to_ascii_uppercase()
82 } else {
83 name.to_shouty_snake_case()
84 }
85}
86
87pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
89 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
93 derives.push("Default");
97 derives.push("serde::Serialize");
98 derives.push("serde::Deserialize");
99
100 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
104
105 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
109
110 let variants: Vec<_> = enum_def
111 .variants
112 .iter()
113 .enumerate()
114 .map(|(idx, v)| {
115 let pyo3_name = if is_pyo3 {
120 to_pyo3_screaming(&v.name)
121 } else {
122 String::new()
123 };
124 minijinja::context! {
125 name => v.name.clone(),
126 idx => idx,
127 is_default => idx == default_idx,
128 has_pyo3_rename => is_pyo3,
129 pyo3_name => pyo3_name,
130 serde_rename => v.serde_rename.clone().unwrap_or_default(),
131 }
132 })
133 .collect();
134
135 let string_methods = if is_pyo3 {
136 crate::template_env::render(
137 "generators/enums/enum_string_methods.jinja",
138 minijinja::context! {
139 name => enum_def.name,
140 value_expr => "self",
141 },
142 )
143 } else {
144 String::new()
145 };
146
147 crate::template_env::render(
148 "generators/enums/enum_definition.jinja",
149 minijinja::context! {
150 enum_name => enum_def.name,
151 derives => derives.join(", "),
152 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
153 enum_attrs => cfg.enum_attrs.to_vec(),
154 variants => variants,
155 is_pyo3 => is_pyo3,
156 string_methods => string_methods,
157 },
158 )
159}
160
/// Words that cannot be used verbatim as Rust identifiers in generated code.
///
/// Lists the strict and reserved keywords. This includes `gen`, which is
/// reserved as of the 2024 edition; `r#gen` is accepted by every edition, so
/// escaping it is always safe. Code in this module escapes a match as a raw
/// identifier (`r#name`). Note that `self`, `Self`, `super` and `crate` appear
/// here too but can NOT be written as raw identifiers, so users of this list
/// need to special-case them.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
168
169pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
173 use alef_core::ir::TypeRef;
174
175
176 for variant in &enum_def.variants {
177 let variant_name_lower = crate::naming::pascal_to_snake(&variant.name);
178 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
179 format!("r#{}", variant_name_lower)
180 } else {
181 variant_name_lower.clone()
182 };
183
184 if variant.fields.len() == 1 {
185 let field = &variant.fields[0];
186 let is_tuple_field = field
187 .name
188 .strip_prefix('_')
189 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
190 if is_tuple_field {
191 if let TypeRef::Named(inner_type_name) = &field.ty {
192 let variant_pascal = &variant.name;
193 let clone_expr = if field.is_boxed {
194 "(**data).clone().into()".to_string()
195 } else {
196 "data.clone().into()".to_string()
197 };
198 out.push('\n');
199 out.push_str(" #[getter]\n");
200 out.push_str(&crate::template_env::render(
201 "generators/enums/getter_accessor.jinja",
202 minijinja::context! {
203 fn_name => &fn_name,
204 inner_type_name => inner_type_name,
205 },
206 ));
207 out.push_str(" match &self.inner {\n");
208 out.push_str(&crate::template_env::render(
209 "generators/enums/match_variant.jinja",
210 minijinja::context! {
211 core_path => &core_path,
212 variant_pascal => variant_pascal,
213 clone_expr => &clone_expr,
214 },
215 ));
216 out.push_str(" _ => None,\n");
217 out.push_str(" }\n");
218 out.push_str(" }\n");
219 continue;
220 }
221 }
222 }
223
224 out.push('\n');
225 out.push_str(" #[getter]\n");
226 out.push_str(&crate::template_env::render(
227 "generators/enums/py_dict_getter.jinja",
228 minijinja::context! {
229 fn_name => &fn_name,
230 },
231 ));
232 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
233 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
234 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
235 out.push_str(&crate::template_env::render(
236 "generators/enums/tag_field_check.jinja",
237 minijinja::context! {
238 tag_field => tag_field,
239 },
240 ));
241 out.push_str(" let tag_value = json.get(tag_field)\n");
242 out.push_str(" .and_then(|v| v.as_str())\n");
243 out.push_str(" .unwrap_or(\"\");\n");
244 out.push_str(&crate::template_env::render(
245 "generators/enums/variant_tag_match.jinja",
246 minijinja::context! {
247 variant_name_lower => &variant_name_lower,
248 },
249 ));
250 out.push_str(" return Ok(None);\n");
251 out.push_str(" }\n");
252 out.push_str(" let json_str = json.to_string();\n");
253 out.push_str(" let json_mod = py.import(\"json\")?;\n");
254 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
255 out.push_str(" Ok(Some(py_dict.unbind()))\n");
256 out.push_str(" }\n");
257 }
258}
259
260pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
261 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
262 format!("r#{tag_field}")
263 } else {
264 tag_field.to_string()
265 };
266 out.push('\n');
267 out.push_str(" #[getter]\n");
268 out.push_str(&crate::template_env::render(
269 "generators/enums/tag_getter_header.jinja",
270 minijinja::context! {
271 fn_name => &fn_name,
272 },
273 ));
274 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
275 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
276 out.push_str(&crate::template_env::render(
277 "generators/enums/json_get_field.jinja",
278 minijinja::context! {
279 tag_field => tag_field,
280 },
281 ));
282 out.push_str(" .and_then(|v| v.as_str())\n");
283 out.push_str(" .map(String::from)\n");
284 out.push_str(&crate::template_env::render(
285 "generators/enums/json_get_error.jinja",
286 minijinja::context! {
287 tag_field => tag_field,
288 },
289 ));
290 out.push_str(" }\n");
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a plain (non-default, non-tuple, undocumented) variant with the given fields.
    fn make_variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a required `String` field with every optional attribute left unset.
    fn string_field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Builds a serde-enabled, untagged-by-default enum rooted at `crate::<name>`.
    fn make_enum(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let data_enum = make_enum("StructureKind", vec![make_variant("Other", vec![string_field("value")])]);
        let generated = gen_pyo3_data_enum(&data_enum, "core");

        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let unit_enum = make_enum("StructureKind", vec![make_variant("Function", Vec::new())]);
        let generated = gen_enum(&unit_enum, &cfg);

        for expected in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(expected), "{generated}");
        }
    }
}
405}