1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3
4pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
7 enum_def.variants.iter().any(|v| !v.fields.is_empty())
8}
9
10fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
20 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
21}
22
23pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
36 let name = &enum_def.name;
37 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
38 let has_sanitized = enum_has_sanitized_fields(enum_def);
39 let string_methods_content = crate::template_env::render(
40 "generators/enums/enum_string_methods.jinja",
41 minijinja::context! {
42 name => name,
43 value_expr => "&self.inner",
44 },
45 );
46
47 let mut variant_accessors = String::new();
48 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
49
50 let mut serde_tag_content = String::new();
51 if let Some(tag_field) = &enum_def.serde_tag {
52 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
53 }
54
55 crate::template_env::render(
56 "generators/enums/pyo3_data_enum.jinja",
57 minijinja::context! {
58 name => name,
59 core_path => core_path,
60 has_sanitized => has_sanitized,
61 string_methods_content => string_methods_content,
62 variant_accessors_content => variant_accessors,
63 serde_tag_content => serde_tag_content,
64 },
65 )
66}
67
68fn to_pyo3_screaming(name: &str) -> String {
76 use heck::ToShoutySnakeCase;
77 let chars: Vec<char> = name.chars().collect();
78 let upper_prefix_len = chars.iter().take_while(|c| c.is_uppercase()).count();
79 if upper_prefix_len >= 2 && chars[upper_prefix_len..].iter().all(|c| c.is_lowercase() || *c == '_') {
81 name.to_ascii_uppercase()
82 } else {
83 name.to_shouty_snake_case()
84 }
85}
86
87pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
89 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
93 derives.push("Default");
97 derives.push("serde::Serialize");
98 derives.push("serde::Deserialize");
99
100 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
104
105 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
109
110 let variants: Vec<_> = enum_def
111 .variants
112 .iter()
113 .enumerate()
114 .map(|(idx, v)| {
115 let pyo3_name = if is_pyo3 {
120 to_pyo3_screaming(&v.name)
121 } else {
122 String::new()
123 };
124 minijinja::context! {
125 name => v.name.clone(),
126 idx => idx,
127 is_default => idx == default_idx,
128 has_pyo3_rename => is_pyo3,
129 pyo3_name => pyo3_name,
130 serde_rename => v.serde_rename.clone().unwrap_or_default(),
131 }
132 })
133 .collect();
134
135 let string_methods = if is_pyo3 {
136 crate::template_env::render(
137 "generators/enums/enum_string_methods.jinja",
138 minijinja::context! {
139 name => enum_def.name,
140 value_expr => "self",
141 },
142 )
143 } else {
144 String::new()
145 };
146
147 crate::template_env::render(
148 "generators/enums/enum_definition.jinja",
149 minijinja::context! {
150 enum_name => enum_def.name,
151 derives => derives.join(", "),
152 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
153 enum_attrs => cfg.enum_attrs.to_vec(),
154 variants => variants,
155 is_pyo3 => is_pyo3,
156 string_methods => string_methods,
157 },
158 )
159}
160
/// Strict and reserved Rust keywords (all editions up to 2024, including the
/// 2024-reserved `gen`).
///
/// A generated identifier that matches one of these must be escaped, normally
/// as a raw identifier (`r#type`). NOTE: `self`, `super`, `crate` and `Self` are
/// also listed, but Rust rejects them even in raw form (`r#crate` does not
/// compile), so code emitting identifiers must special-case those four.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "macro", "match",
    "mod", "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
168
169pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
173 use alef_core::ir::TypeRef;
174
175 for variant in &enum_def.variants {
176 let variant_name_lower = crate::naming::pascal_to_snake(&variant.name);
177 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
178 format!("r#{}", variant_name_lower)
179 } else {
180 variant_name_lower.clone()
181 };
182
183 if variant.fields.len() == 1 {
184 let field = &variant.fields[0];
185 let is_tuple_field = field
186 .name
187 .strip_prefix('_')
188 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
189 if is_tuple_field {
190 if let TypeRef::Named(inner_type_name) = &field.ty {
191 let variant_pascal = &variant.name;
192 let clone_expr = if field.is_boxed {
193 "(**data).clone().into()".to_string()
194 } else {
195 "data.clone().into()".to_string()
196 };
197 out.push('\n');
198 out.push_str(" #[getter]\n");
199 out.push_str(&crate::template_env::render(
200 "generators/enums/getter_accessor.jinja",
201 minijinja::context! {
202 fn_name => &fn_name,
203 inner_type_name => inner_type_name,
204 },
205 ));
206 out.push_str(" match &self.inner {\n");
207 out.push_str(&crate::template_env::render(
208 "generators/enums/match_variant.jinja",
209 minijinja::context! {
210 core_path => &core_path,
211 variant_pascal => variant_pascal,
212 clone_expr => &clone_expr,
213 },
214 ));
215 out.push_str(" _ => None,\n");
216 out.push_str(" }\n");
217 out.push_str(" }\n");
218 continue;
219 }
220 }
221 }
222
223 out.push('\n');
224 out.push_str(" #[getter]\n");
225 out.push_str(&crate::template_env::render(
226 "generators/enums/py_dict_getter.jinja",
227 minijinja::context! {
228 fn_name => &fn_name,
229 },
230 ));
231 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
232 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
233 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
234 out.push_str(&crate::template_env::render(
235 "generators/enums/tag_field_check.jinja",
236 minijinja::context! {
237 tag_field => tag_field,
238 },
239 ));
240 out.push_str(" let tag_value = json.get(tag_field)\n");
241 out.push_str(" .and_then(|v| v.as_str())\n");
242 out.push_str(" .unwrap_or(\"\");\n");
243 out.push_str(&crate::template_env::render(
244 "generators/enums/variant_tag_match.jinja",
245 minijinja::context! {
246 variant_name_lower => &variant_name_lower,
247 },
248 ));
249 out.push_str(" return Ok(None);\n");
250 out.push_str(" }\n");
251 out.push_str(" let json_str = json.to_string();\n");
252 out.push_str(" let json_mod = py.import(\"json\")?;\n");
253 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
254 out.push_str(" Ok(Some(py_dict.unbind()))\n");
255 out.push_str(" }\n");
256 }
257}
258
259pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
260 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
261 format!("r#{tag_field}")
262 } else {
263 tag_field.to_string()
264 };
265 out.push('\n');
266 out.push_str(" #[getter]\n");
267 out.push_str(&crate::template_env::render(
268 "generators/enums/tag_getter_header.jinja",
269 minijinja::context! {
270 fn_name => &fn_name,
271 },
272 ));
273 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
274 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
275 out.push_str(&crate::template_env::render(
276 "generators/enums/json_get_field.jinja",
277 minijinja::context! {
278 tag_field => tag_field,
279 },
280 ));
281 out.push_str(" .and_then(|v| v.as_str())\n");
282 out.push_str(" .map(String::from)\n");
283 out.push_str(&crate::template_env::render(
284 "generators/enums/json_get_error.jinja",
285 minijinja::context! {
286 tag_field => tag_field,
287 },
288 ));
289 out.push_str(" }\n");
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a non-default, non-tuple variant with the given fields.
    fn variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a plain, required `String` field with every optional knob unset.
    fn field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
            original_type: None,
        }
    }

    /// Builds an untagged, serde-enabled enum rooted at `crate::<name>`.
    fn enum_def(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let def = enum_def("StructureKind", vec![variant("Other", vec![field("value")])]);
        let generated = gen_pyo3_data_enum(&def, "core");

        for needle in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(needle), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let def = enum_def("StructureKind", vec![variant("Function", Vec::new())]);
        let generated = gen_enum(&def, &cfg);

        for needle in ["fn __str__(&self) -> PyResult<String>", "serde_json::to_value(self)"] {
            assert!(generated.contains(needle), "{generated}");
        }
    }
}