1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3
4pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
7 enum_def.variants.iter().any(|v| !v.fields.is_empty())
8}
9
10fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
20 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
21}
22
23pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
36 let name = &enum_def.name;
37 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
38 let has_sanitized = enum_has_sanitized_fields(enum_def);
39 let string_methods_content = crate::template_env::render(
40 "generators/enums/enum_string_methods.jinja",
41 minijinja::context! {
42 name => name,
43 value_expr => "&self.inner",
44 },
45 );
46
47 let mut variant_accessors = String::new();
48 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
49
50 let mut serde_tag_content = String::new();
51 if let Some(tag_field) = &enum_def.serde_tag {
52 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
53 }
54
55 crate::template_env::render(
56 "generators/enums/pyo3_data_enum.jinja",
57 minijinja::context! {
58 name => name,
59 core_path => core_path,
60 has_sanitized => has_sanitized,
61 string_methods_content => string_methods_content,
62 variant_accessors_content => variant_accessors,
63 serde_tag_content => serde_tag_content,
64 },
65 )
66}
67
68fn to_pyo3_screaming(name: &str) -> String {
76 use heck::ToShoutySnakeCase;
77 let chars: Vec<char> = name.chars().collect();
78 let upper_prefix_len = chars.iter().take_while(|c| c.is_uppercase()).count();
79 if upper_prefix_len >= 2 && chars[upper_prefix_len..].iter().all(|c| c.is_lowercase() || *c == '_') {
81 name.to_ascii_uppercase()
82 } else {
83 name.to_shouty_snake_case()
84 }
85}
86
87pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
89 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
93 derives.push("Default");
97 derives.push("serde::Serialize");
98 derives.push("serde::Deserialize");
99
100 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
104
105 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
109
110 let variants: Vec<_> = enum_def
111 .variants
112 .iter()
113 .enumerate()
114 .map(|(idx, v)| {
115 let pyo3_name = if is_pyo3 {
120 to_pyo3_screaming(&v.name)
121 } else {
122 String::new()
123 };
124 minijinja::context! {
125 name => v.name.clone(),
126 idx => idx,
127 is_default => idx == default_idx,
128 has_pyo3_rename => is_pyo3,
129 pyo3_name => pyo3_name,
130 serde_rename => v.serde_rename.clone().unwrap_or_default(),
131 }
132 })
133 .collect();
134
135 let string_methods = if is_pyo3 {
136 crate::template_env::render(
137 "generators/enums/enum_string_methods.jinja",
138 minijinja::context! {
139 name => enum_def.name,
140 value_expr => "self",
141 },
142 )
143 } else {
144 String::new()
145 };
146
147 crate::template_env::render(
148 "generators/enums/enum_definition.jinja",
149 minijinja::context! {
150 enum_name => enum_def.name,
151 derives => derives.join(", "),
152 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
153 enum_attrs => cfg.enum_attrs.to_vec(),
154 variants => variants,
155 is_pyo3 => is_pyo3,
156 string_methods => string_methods,
157 },
158 )
159}
160
161const RUST_KEYWORDS: &[&str] = &[
163 "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
164 "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
165 "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
166 "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
167];
168
169pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
173 use alef_core::ir::TypeRef;
174
175 for variant in &enum_def.variants {
176 let variant_name_lower = crate::naming::pascal_to_snake(&variant.name);
177 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
178 format!("r#{}", variant_name_lower)
179 } else {
180 variant_name_lower.clone()
181 };
182
183 if variant.fields.len() == 1 {
184 let field = &variant.fields[0];
185 let is_tuple_field = field
186 .name
187 .strip_prefix('_')
188 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
189 if is_tuple_field {
190 if let TypeRef::Named(inner_type_name) = &field.ty {
191 let variant_pascal = &variant.name;
192 let clone_expr = if field.is_boxed {
193 "(**data).clone().into()".to_string()
194 } else {
195 "data.clone().into()".to_string()
196 };
197 out.push('\n');
198 out.push_str(" #[getter]\n");
199 out.push_str(&crate::template_env::render(
200 "generators/enums/getter_accessor.jinja",
201 minijinja::context! {
202 fn_name => &fn_name,
203 inner_type_name => inner_type_name,
204 },
205 ));
206 out.push_str(" match &self.inner {\n");
207 out.push_str(&crate::template_env::render(
208 "generators/enums/match_variant.jinja",
209 minijinja::context! {
210 core_path => &core_path,
211 variant_pascal => variant_pascal,
212 clone_expr => &clone_expr,
213 },
214 ));
215 out.push_str(" _ => None,\n");
216 out.push_str(" }\n");
217 out.push_str(" }\n");
218 continue;
219 }
220 }
221 }
222
223 out.push('\n');
224 out.push_str(" #[getter]\n");
225 out.push_str(&crate::template_env::render(
226 "generators/enums/py_dict_getter.jinja",
227 minijinja::context! {
228 fn_name => &fn_name,
229 },
230 ));
231 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
232 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
233 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
234 out.push_str(&crate::template_env::render(
235 "generators/enums/tag_field_check.jinja",
236 minijinja::context! {
237 tag_field => tag_field,
238 },
239 ));
240 out.push_str(" let tag_value = json.get(tag_field)\n");
241 out.push_str(" .and_then(|v| v.as_str())\n");
242 out.push_str(" .unwrap_or(\"\");\n");
243 out.push_str(&crate::template_env::render(
244 "generators/enums/variant_tag_match.jinja",
245 minijinja::context! {
246 variant_name_lower => &variant_name_lower,
247 },
248 ));
249 out.push_str(" return Ok(None);\n");
250 out.push_str(" }\n");
251 out.push_str(" let json_str = json.to_string();\n");
252 out.push_str(" let json_mod = py.import(\"json\")?;\n");
253 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
254 out.push_str(" Ok(Some(py_dict.unbind()))\n");
255 out.push_str(" }\n");
256 }
257}
258
259pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
260 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
261 format!("r#{tag_field}")
262 } else {
263 tag_field.to_string()
264 };
265 out.push('\n');
266 out.push_str(" #[getter]\n");
267 out.push_str(&crate::template_env::render(
268 "generators/enums/tag_getter_header.jinja",
269 minijinja::context! {
270 fn_name => &fn_name,
271 },
272 ));
273 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
274 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
275 out.push_str(&crate::template_env::render(
276 "generators/enums/json_get_field.jinja",
277 minijinja::context! {
278 tag_field => tag_field,
279 },
280 ));
281 out.push_str(" .and_then(|v| v.as_str())\n");
282 out.push_str(" .map(String::from)\n");
283 out.push_str(&crate::template_env::render(
284 "generators/enums/json_get_error.jinja",
285 minijinja::context! {
286 tag_field => tag_field,
287 },
288 ));
289 out.push_str(" }\n");
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Build a non-default, non-tuple variant with the given fields.
    fn variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.into(),
            fields,
            doc: String::default(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Build a plain `String` field with every optional attribute unset.
    fn field(name: &str) -> FieldDef {
        FieldDef {
            name: name.into(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::default(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Build a serde-enabled, untagged-by-default enum rooted at `crate::<name>`.
    fn enum_def(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.into(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::default(),
            variants,
            doc: String::default(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let input = enum_def("StructureKind", vec![variant("Other", vec![field("value")])]);
        let generated = gen_pyo3_data_enum(&input, "core");

        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let config = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
            never_skip_cfg_field_names: &[],
        };
        let input = enum_def("StructureKind", vec![variant("Function", Vec::new())]);
        let generated = gen_enum(&input, &config);

        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(self)",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }
}
404}