1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use alef_core::keywords::PYTHON_KEYWORDS;
4use std::fmt::Write;
5
6pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
9 enum_def.variants.iter().any(|v| !v.fields.is_empty())
10}
11
12fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
22 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
23}
24
25pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
38 let name = &enum_def.name;
39 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
40 let has_sanitized = enum_has_sanitized_fields(enum_def);
41 let mut out = String::with_capacity(512);
42
43 writeln!(out, "#[derive(Clone)]").ok();
44 writeln!(out, "#[pyclass(frozen)]").ok();
45 writeln!(out, "pub struct {name} {{").ok();
46 writeln!(out, " pub(crate) inner: {core_path},").ok();
47 writeln!(out, "}}").ok();
48 writeln!(out).ok();
49
50 writeln!(out, "#[pymethods]").ok();
51 writeln!(out, "impl {name} {{").ok();
52 if has_sanitized {
53 write_pyo3_enum_string_methods(&mut out, name, "&self.inner");
57 write_pyo3_variant_accessors(&mut out, enum_def, &core_path);
58 if let Some(tag_field) = &enum_def.serde_tag {
59 write_pyo3_serde_tag_getter(&mut out, tag_field);
60 }
61 writeln!(out, "}}").ok();
62 } else {
63 writeln!(out, " #[new]").ok();
64 writeln!(
65 out,
66 " fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyAny>) -> PyResult<Self> {{"
67 )
68 .ok();
69 writeln!(
70 out,
71 " // Accept either a Python dict (full tagged-union shape) or a string"
72 )
73 .ok();
74 writeln!(
75 out,
76 " // (the unit variant name). Strings are wrapped in `\"...\"` so serde_json"
77 )
78 .ok();
79 writeln!(
80 out,
81 " // can deserialize into a unit-variant of the tagged enum."
82 )
83 .ok();
84 writeln!(
85 out,
86 " let json_str: String = if let Ok(s) = value.extract::<String>() {{"
87 )
88 .ok();
89 writeln!(
90 out,
91 " serde_json::to_string(&s).map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?"
92 )
93 .ok();
94 writeln!(out, " }} else {{").ok();
95 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
96 writeln!(
97 out,
98 " json_mod.call_method1(\"dumps\", (value,))?.extract()?"
99 )
100 .ok();
101 writeln!(out, " }};").ok();
102 writeln!(out, " let inner: {core_path} = serde_json::from_str(&json_str)").ok();
103 writeln!(
104 out,
105 " .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
106 )
107 .ok();
108 writeln!(out, " Ok(Self {{ inner }})").ok();
109 writeln!(out, " }}").ok();
110 write_pyo3_enum_string_methods(&mut out, name, "&self.inner");
111 write_pyo3_variant_accessors(&mut out, enum_def, &core_path);
112 if let Some(tag_field) = &enum_def.serde_tag {
113 write_pyo3_serde_tag_getter(&mut out, tag_field);
114 }
115 writeln!(out, "}}").ok();
116 }
117 writeln!(out).ok();
118
119 writeln!(out, "impl From<{name}> for {core_path} {{").ok();
121 writeln!(out, " fn from(val: {name}) -> Self {{ val.inner }}").ok();
122 writeln!(out, "}}").ok();
123 writeln!(out).ok();
124
125 writeln!(out, "impl From<{core_path}> for {name} {{").ok();
127 writeln!(out, " fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
128 writeln!(out, "}}").ok();
129 writeln!(out).ok();
130
131 writeln!(out, "impl serde::Serialize for {name} {{").ok();
134 writeln!(
135 out,
136 " fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
137 )
138 .ok();
139 writeln!(out, " self.inner.serialize(serializer)").ok();
140 writeln!(out, " }}").ok();
141 writeln!(out, "}}").ok();
142 writeln!(out).ok();
143
144 writeln!(out, "impl Default for {name} {{").ok();
147 writeln!(
148 out,
149 " fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
150 )
151 .ok();
152 writeln!(out, "}}").ok();
153 writeln!(out).ok();
154
155 writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
158 writeln!(
159 out,
160 " fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
161 )
162 .ok();
163 writeln!(out, " let inner = {core_path}::deserialize(deserializer)?;").ok();
164 writeln!(out, " Ok(Self {{ inner }})").ok();
165 writeln!(out, " }}").ok();
166 writeln!(out, "}}").ok();
167
168 out
169}
170
171pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
173 let mut out = String::with_capacity(512);
177 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
178 derives.push("Default");
182 derives.push("serde::Serialize");
183 derives.push("serde::Deserialize");
184 if !derives.is_empty() {
185 writeln!(out, "#[derive({})]", derives.join(", ")).ok();
186 }
187 if let Some(rename_all) = &enum_def.serde_rename_all {
188 writeln!(out, "#[serde(rename_all = \"{rename_all}\")]").ok();
189 }
190 for attr in cfg.enum_attrs {
191 writeln!(out, "#[{attr}]").ok();
192 }
193 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
196 writeln!(out, "pub enum {} {{", enum_def.name).ok();
197 let default_idx = enum_def.variants.iter().position(|v| v.is_default).unwrap_or(0);
201 for (idx, variant) in enum_def.variants.iter().enumerate() {
202 if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
203 writeln!(out, " #[pyo3(name = \"{}_\")]", variant.name).ok();
204 }
205 if idx == default_idx {
207 writeln!(out, " #[default]").ok();
208 }
209 writeln!(out, " {} = {idx},", variant.name).ok();
210 }
211 writeln!(out, "}}").ok();
212 if is_pyo3 {
213 writeln!(out).ok();
214 writeln!(out, "#[pymethods]").ok();
215 writeln!(out, "impl {} {{", enum_def.name).ok();
216 write_pyo3_enum_string_methods(&mut out, &enum_def.name, "self");
217 writeln!(out, "}}").ok();
218 }
219
220 out
221}
222
/// Names the generators in this file treat as Rust keywords.
///
/// A generated getter whose name appears in this list is emitted as a raw
/// identifier (`r#name`) instead of the bare name.
///
/// NOTE(review): Rust does not accept `crate`, `self`, `super` or `Self` as
/// raw identifiers, so `r#self` and friends would not compile — confirm that
/// those names can never reach the getter generators.
const RUST_KEYWORDS: &[&str] = &[
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate", "do", "dyn", "else",
    "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop", "macro", "match", "mod",
    "move", "mut", "override", "priv", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
    "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];
230
231fn write_pyo3_variant_accessors(out: &mut String, enum_def: &EnumDef, core_path: &str) {
235 use alef_core::ir::TypeRef;
236 use heck::ToSnakeCase;
237
238 for variant in &enum_def.variants {
239 let variant_name_lower = variant.name.to_snake_case();
240 let fn_name = if RUST_KEYWORDS.contains(&variant_name_lower.as_str()) {
242 format!("r#{}", variant_name_lower)
243 } else {
244 variant_name_lower.clone()
245 };
246
247 if variant.fields.len() == 1 {
249 let field = &variant.fields[0];
250 let is_tuple_field = field
251 .name
252 .strip_prefix('_')
253 .is_some_and(|s| s.chars().all(|c| c.is_ascii_digit()));
254 if is_tuple_field {
255 if let TypeRef::Named(inner_type_name) = &field.ty {
256 let variant_pascal = &variant.name;
257 writeln!(out).ok();
258 writeln!(out, " #[getter]").ok();
259 writeln!(out, " fn {fn_name}(&self) -> Option<{inner_type_name}> {{").ok();
260 writeln!(out, " match &self.inner {{").ok();
261 let clone_expr = if field.is_boxed {
265 "(**data).clone().into()".to_string()
266 } else {
267 "data.clone().into()".to_string()
268 };
269 writeln!(
270 out,
271 " {core_path}::{variant_pascal}(data) => Some({clone_expr}),"
272 )
273 .ok();
274 writeln!(out, " _ => None,").ok();
275 writeln!(out, " }}").ok();
276 writeln!(out, " }}").ok();
277 continue;
278 }
279 }
280 }
281
282 writeln!(out).ok();
284 writeln!(out, " #[getter]").ok();
285 writeln!(
286 out,
287 " fn {fn_name}(&self, py: Python<'_>) -> PyResult<Option<pyo3::Py<pyo3::types::PyDict>>> {{"
288 )
289 .ok();
290 writeln!(out, " // Serialize to JSON first").ok();
291 writeln!(out, " let json = serde_json::to_value(&self.inner)").ok();
292 writeln!(
293 out,
294 " .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;"
295 )
296 .ok();
297 writeln!(out, " // Check the tag field to see if this variant is active").ok();
298 writeln!(
299 out,
300 " let tag_field = \"{}\";",
301 enum_def.serde_tag.as_ref().unwrap_or(&"tag".to_string())
302 )
303 .ok();
304 writeln!(out, " let tag_value = json.get(tag_field)").ok();
305 writeln!(out, " .and_then(|v| v.as_str())").ok();
306 writeln!(out, " .unwrap_or(\"\");").ok();
307 writeln!(out, " if tag_value != \"{}\" {{", variant_name_lower).ok();
308 writeln!(out, " return Ok(None);").ok();
309 writeln!(out, " }}").ok();
310 writeln!(out, " // Create a Python dict from the JSON").ok();
311 writeln!(out, " let json_str = json.to_string();").ok();
312 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
313 writeln!(
314 out,
315 " let py_dict = json_mod.call_method1(\"loads\", (&json_str,))?.downcast_into::<pyo3::types::PyDict>()?;"
316 )
317 .ok();
318 writeln!(out, " Ok(Some(py_dict.unbind()))").ok();
319 writeln!(out, " }}").ok();
320 }
321}
322
323fn write_pyo3_serde_tag_getter(out: &mut String, tag_field: &str) {
324 let fn_name = if RUST_KEYWORDS.contains(&tag_field) {
327 format!("r#{tag_field}")
328 } else {
329 tag_field.to_string()
330 };
331 writeln!(out).ok();
332 writeln!(out, " #[getter]").ok();
333 writeln!(out, " fn {fn_name}(&self) -> pyo3::PyResult<String> {{").ok();
334 writeln!(out, " let json = serde_json::to_value(&self.inner)").ok();
335 writeln!(
336 out,
337 " .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;"
338 )
339 .ok();
340 writeln!(out, " json.get(\"{tag_field}\")").ok();
341 writeln!(out, " .and_then(|v| v.as_str())").ok();
342 writeln!(out, " .map(String::from)").ok();
343 writeln!(
344 out,
345 " .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err(\"{tag_field} not found in serialized enum\"))"
346 )
347 .ok();
348 writeln!(out, " }}").ok();
349}
350
/// Appends `__str__` and `__repr__` methods to the `#[pymethods]` body being
/// built in `out`.
///
/// The generated `__str__` serializes `value_expr` with
/// `serde_json::to_value`; a JSON string is returned as-is, any other JSON
/// value is rendered with `to_string()`, and a serialization failure becomes
/// a `PyRuntimeError` mentioning `name`. The generated `__repr__` forwards to
/// `__str__`.
fn write_pyo3_enum_string_methods(out: &mut String, name: &str, value_expr: &str) {
    // __str__
    out.push('\n');
    out.push_str(" fn __str__(&self) -> PyResult<String> {\n");
    let _ = writeln!(out, " serde_json::to_value({value_expr})");
    out.push_str(" .map(|value| match value {\n");
    out.push_str(" serde_json::Value::String(value) => value,\n");
    out.push_str(" other => other.to_string(),\n");
    out.push_str(" })\n");
    let _ = writeln!(
        out,
        " .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(\"Failed to serialize {name}: {{e}}\")))"
    );
    out.push_str(" }\n");

    // __repr__
    out.push('\n');
    out.push_str(" fn __repr__(&self) -> PyResult<String> {\n");
    out.push_str(" self.__str__()\n");
    out.push_str(" }\n");
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::AsyncPattern;
    use alef_core::ir::{CoreWrapper, EnumVariant, FieldDef, TypeRef};

    /// Builds an `EnumVariant` with the given name and fields; every other
    /// attribute is left empty, `false` or `None`.
    fn variant(name: &str, fields: Vec<FieldDef>) -> EnumVariant {
        EnumVariant {
            name: name.to_string(),
            fields,
            doc: String::new(),
            is_default: false,
            serde_rename: None,
            is_tuple: false,
        }
    }

    /// Builds a non-optional `TypeRef::String` field with the given name that
    /// is neither sanitized nor boxed and carries no wrappers or defaults.
    fn field(name: &str) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty: TypeRef::String,
            optional: false,
            default: None,
            doc: String::new(),
            sanitized: false,
            is_boxed: false,
            type_rust_path: None,
            cfg: None,
            typed_default: None,
            core_wrapper: CoreWrapper::None,
            vec_inner_core_wrapper: CoreWrapper::None,
            newtype_wrapper: None,
        }
    }

    /// Builds an `EnumDef` named `name` with the given variants, a
    /// `crate::{name}` Rust path, serde enabled, and no serde tag or
    /// `rename_all`.
    fn enum_def(name: &str, variants: Vec<EnumVariant>) -> EnumDef {
        EnumDef {
            name: name.to_string(),
            rust_path: format!("crate::{name}"),
            original_rust_path: String::new(),
            variants,
            doc: String::new(),
            cfg: None,
            is_copy: false,
            has_serde: true,
            serde_tag: None,
            serde_rename_all: None,
        }
    }

    /// The data-enum wrapper must expose `__str__` and `__repr__`, and
    /// `__str__` must serialize the wrapped `inner` value.
    #[test]
    fn gen_pyo3_data_enum_emits_string_methods() {
        let generated = gen_pyo3_data_enum(
            &enum_def("StructureKind", vec![variant("Other", vec![field("value")])]),
            "core",
        );

        assert!(
            generated.contains("fn __str__(&self) -> PyResult<String>"),
            "{generated}"
        );
        assert!(generated.contains("serde_json::to_value(&self.inner)"), "{generated}");
        assert!(
            generated.contains("fn __repr__(&self) -> PyResult<String>"),
            "{generated}"
        );
    }

    /// A unit enum generated with a `pyclass` attribute must expose
    /// `__str__`, serializing `self` directly.
    #[test]
    fn gen_pyo3_unit_enum_emits_string_methods() {
        // Only `enum_attrs` (must mention `pyclass`) and `enum_derives` are
        // read by `gen_enum`; the remaining fields are filler.
        let cfg = RustBindingConfig {
            struct_attrs: &[],
            field_attrs: &[],
            struct_derives: &[],
            method_block_attr: None,
            constructor_attr: "",
            static_attr: None,
            function_attr: "",
            enum_attrs: &["pyclass(eq, eq_int, from_py_object)"],
            enum_derives: &["Clone", "PartialEq"],
            needs_signature: false,
            signature_prefix: "",
            signature_suffix: "",
            core_import: "core",
            async_pattern: AsyncPattern::None,
            has_serde: true,
            type_name_prefix: "",
            option_duration_on_defaults: false,
            opaque_type_names: &[],
            skip_impl_constructor: false,
            cast_uints_to_i32: false,
            cast_large_ints_to_f64: false,
            named_non_opaque_params_by_ref: false,
            lossy_skip_types: &[],
        };
        let generated = gen_enum(&enum_def("StructureKind", vec![variant("Function", Vec::new())]), &cfg);

        assert!(
            generated.contains("fn __str__(&self) -> PyResult<String>"),
            "{generated}"
        );
        assert!(generated.contains("serde_json::to_value(self)"), "{generated}");
    }
}