1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37 let name = &enum_def.name;
38 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39 let has_sanitized = enum_has_sanitized_fields(enum_def);
40 let string_methods_content = crate::template_env::render(
41 "generators/enums/enum_string_methods.jinja",
42 minijinja::context! {
43 name => name,
44 value_expr => "&self.inner",
45 },
46 );
47
48 let mut variant_accessors = String::new();
49 write_pyo3_variant_accessors(&mut variant_accessors, enum_def, &core_path);
50
51 let mut serde_tag_content = String::new();
52 if let Some(tag_field) = &enum_def.serde_tag {
53 write_pyo3_serde_tag_getter(&mut serde_tag_content, tag_field);
54 }
55
56 crate::template_env::render(
57 "generators/enums/pyo3_data_enum.jinja",
58 minijinja::context! {
59 name => name,
60 core_path => core_path,
61 has_sanitized => has_sanitized,
62 string_methods_content => string_methods_content,
63 variant_accessors_content => variant_accessors,
64 serde_tag_content => serde_tag_content,
65 },
66 )
67}
68
69pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
71 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
75 derives.push("Default");
79 derives.push("serde::Serialize");
80 derives.push("serde::Deserialize");
81
82 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
85
86 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
90
91 let variants: Vec<_> = enum_def
92 .variants
93 .iter()
94 .enumerate()
95 .map(|(idx, v)| {
96 minijinja::context! {
97 name => v.name.clone(),
98 idx => idx,
99 is_default => idx == default_idx,
100 has_pyo3_rename => is_pyo3 && PYTHON_KEYWORDS.contains(&v.name.as_str()),
101 serde_rename => v.serde_rename.clone().unwrap_or_default(),
102 }
103 })
104 .collect();
105
106 let string_methods = if is_pyo3 {
107 crate::template_env::render(
108 "generators/enums/enum_string_methods.jinja",
109 minijinja::context! {
110 name => enum_def.name,
111 value_expr => "self",
112 },
113 )
114 } else {
115 String::new()
116 };
117
118 crate::template_env::render(
119 "generators/enums/enum_definition.jinja",
120 minijinja::context! {
121 enum_name => enum_def.name,
122 derives => derives.join(", "),
123 serde_rename_all => enum_def.serde_rename_all.as_deref().unwrap_or(""),
124 enum_attrs => cfg.enum_attrs.to_vec(),
125 variants => variants,
126 is_pyo3 => is_pyo3,
127 string_methods => string_methods,
128 },
129 )
130}
131
/// Rust keywords (strict and reserved) that cannot be used as plain identifiers
/// in generated code.
///
/// NOTE(review): `crate`, `self`, `Self` and `super` are listed here but cannot
/// be written as raw identifiers (`r#crate` is rejected by rustc), so callers
/// that escape with `r#` need a different strategy for those four.
#[rustfmt::skip]
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await",
    "become", "box", "break",
    "const", "continue", "crate",
    "do", "dyn",
    "else", "enum", "extern",
    "false", "final", "fn", "for",
    "if", "impl", "in",
    "let", "loop",
    "macro", "match", "mod", "move", "mut",
    "override",
    "priv", "pub",
    "ref", "return",
    "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof",
    "unsafe", "unsized", "use",
    "virtual",
    "where", "while",
    "yield",
];
139
140pub(crate) fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
144 use alef_core::ir::TypeRef;
145 use heck::ToSnakeCase;
146
147 for variant in &enum_def.variants {
148 let variant_name_lower = variant.name.to_snake_case();
149 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
150 format!("r#{}", variant_name_lower)
151 } else {
152 variant_name_lower.clone()
153 };
154
155 if variant.fields.len() == 1 {
156 let field = &variant.fields[0];
157 let is_tuple_field = field
158 .name
159 .strip_prefix('_')
160 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
161 if is_tuple_field {
162 if let TypeRef::Named(inner_type_name) = &field.ty {
163 let variant_pascal = &variant.name;
164 let clone_expr = if field.is_boxed {
165 "(**data).clone().into()".to_string()
166 } else {
167 "data.clone().into()".to_string()
168 };
169 out.push('\n');
170 out.push_str(" #[getter]\n");
171 out.push_str(&crate::template_env::render(
172 "generators/enums/getter_accessor.jinja",
173 minijinja::context! {
174 fn_name => &fn_name,
175 inner_type_name => inner_type_name,
176 },
177 ));
178 out.push_str(" match &self.inner {\n");
179 out.push_str(&crate::template_env::render(
180 "generators/enums/match_variant.jinja",
181 minijinja::context! {
182 core_path => &core_path,
183 variant_pascal => variant_pascal,
184 clone_expr => &clone_expr,
185 },
186 ));
187 out.push_str(" _ => None,\n");
188 out.push_str(" }\n");
189 out.push_str(" }\n");
190 continue;
191 }
192 }
193 }
194
195 out.push('\n');
196 out.push_str(" #[getter]\n");
197 out.push_str(&crate::template_env::render(
198 "generators/enums/py_dict_getter.jinja",
199 minijinja::context! {
200 fn_name => &fn_name,
201 },
202 ));
203 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
204 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
205 let tag_field = enum_def.serde_tag.as_deref().unwrap_or("tag");
206 out.push_str(&crate::template_env::render(
207 "generators/enums/tag_field_check.jinja",
208 minijinja::context! {
209 tag_field => tag_field,
210 },
211 ));
212 out.push_str(" let tag_value = json.get(tag_field)\n");
213 out.push_str(" .and_then(|v| v.as_str())\n");
214 out.push_str(" .unwrap_or(\"\");\n");
215 out.push_str(&crate::template_env::render(
216 "generators/enums/variant_tag_match.jinja",
217 minijinja::context! {
218 variant_name_lower => &variant_name_lower,
219 },
220 ));
221 out.push_str(" return Ok(None);\n");
222 out.push_str(" }\n");
223 out.push_str(" let json_str = json.to_string();\n");
224 out.push_str(" let json_mod = py.import(\"json\")?;\n");
225 out.push_str(" let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;\n");
226 out.push_str(" Ok(Some(py_dict.unbind()))\n");
227 out.push_str(" }\n");
228 }
229}
230
231pub(crate) fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
232 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
233 format!("r#{tag_field}")
234 } else {
235 tag_field.to_string()
236 };
237 out.push('\n');
238 out.push_str(" #[getter]\n");
239 out.push_str(&crate::template_env::render(
240 "generators/enums/tag_getter_header.jinja",
241 minijinja::context! {
242 fn_name => &fn_name,
243 },
244 ));
245 out.push_str(" let json = serde_json::to_value(&self.inner)\n");
246 out.push_str(" .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;\n");
247 out.push_str(&crate::template_env::render(
248 "generators/enums/json_get_field.jinja",
249 minijinja::context! {
250 tag_field => tag_field,
251 },
252 ));
253 out.push_str(" .and_then(|v| v.as_str())\n");
254 out.push_str(" .map(String::from)\n");
255 out.push_str(&crate::template_env::render(
256 "generators/enums/json_get_error.jinja",
257 minijinja::context! {
258 tag_field => tag_field,
259 },
260 ));
261 out.push_str(" }\n");
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds a non-default, non-tuple variant with the given name and fields.
    fn variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.to_string(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a plain, required `String` field with every optional knob switched off.
    fn field(name: &str) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
            serde_rename: None,
            serde_flatten: false,
        }
    }

    /// Builds a serde-enabled enum at `crate::<name>` with no tag and no rename rule.
    fn enum_def(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.to_string(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_untagged: false,
            serde_rename_all: None,
        }
    }

    /// Binding config whose enum attributes mark the enum as a PyO3 class.
    fn pyo3_cfg() -> RustBindingConfig {
        RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
            serializable_opaque_type_names: &[],
        }
    }

    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let data_enum = enum_def("StructureKind", vec![variant("Other", vec![field("value")])]);
        let generated = gen_pyo3_data_enum(&data_enum, "core");

        // Data enums stringify through the wrapped `inner` value.
        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(&self.inner)",
            "fn __repr__(&self) -> PyResult<String>",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }

    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        let cfg = pyo3_cfg();
        let unit_enum = enum_def("StructureKind", vec![variant("Function", Vec::new())]);
        let generated = gen_enum(&unit_enum, &cfg);

        // Unit enums stringify `self` directly.
        for expected in [
            "fn __str__(&self) -> PyResult<String>",
            "serde_json::to_value(self)",
        ] {
            assert!(generated.contains(expected), "{generated}");
        }
    }
}