1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
17 let name = &enum_def.name;
18 let core_path = format!("{core_import}::{name}");
19 let mut out = String::with_capacity(512);
20
21 writeln!(out, "#[derive(Clone)]").ok();
22 writeln!(out, "#[pyclass(frozen)]").ok();
23 writeln!(out, "pub struct {name} {{").ok();
24 writeln!(out, " pub(crate) inner: {core_path},").ok();
25 writeln!(out, "}}").ok();
26 writeln!(out).ok();
27
28 writeln!(out, "#[pymethods]").ok();
29 writeln!(out, "impl {name} {{").ok();
30 writeln!(out, " #[new]").ok();
31 writeln!(
32 out,
33 " fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
34 )
35 .ok();
36 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
37 writeln!(
38 out,
39 " let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
40 )
41 .ok();
42 writeln!(out, " let inner: {core_path} = serde_json::from_str(&json_str)").ok();
43 writeln!(
44 out,
45 " .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
46 )
47 .ok();
48 writeln!(out, " Ok(Self {{ inner }})").ok();
49 writeln!(out, " }}").ok();
50 writeln!(out, "}}").ok();
51 writeln!(out).ok();
52
53 writeln!(out, "impl From<{name}> for {core_path} {{").ok();
55 writeln!(out, " fn from(val: {name}) -> Self {{ val.inner }}").ok();
56 writeln!(out, "}}").ok();
57 writeln!(out).ok();
58
59 writeln!(out, "impl From<{core_path}> for {name} {{").ok();
61 writeln!(out, " fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
62 writeln!(out, "}}").ok();
63 writeln!(out).ok();
64
65 writeln!(out, "impl serde::Serialize for {name} {{").ok();
67 writeln!(
68 out,
69 " fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
70 )
71 .ok();
72 writeln!(out, " self.inner.serialize(serializer)").ok();
73 writeln!(out, " }}").ok();
74 writeln!(out, "}}").ok();
75 writeln!(out).ok();
76
77 writeln!(out, "impl Default for {name} {{").ok();
79 writeln!(
80 out,
81 " fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
82 )
83 .ok();
84 writeln!(out, "}}").ok();
85
86 out
87}
88
/// Names that `gen_enum` must not expose verbatim as pyo3 enum variant names.
///
/// When `gen_enum` generates a pyo3 enum, any variant listed here gets a trailing `_` in
/// its Python name. The list covers every hard keyword in Python's `keyword.kwlist`,
/// including `async` and `await` (reserved since Python 3.7). It also covers the soft
/// keyword `type`.
const PYTHON_KEYWORDS: &[&str] = &[
    "None", "True", "False", "from", "import", "class", "def", "return", "yield", "pass", "break", "continue", "and",
    "or", "not", "is", "in", "if", "else", "elif", "for", "while", "with", "as", "try", "except", "finally", "raise",
    "del", "global", "nonlocal", "lambda", "assert", "type", "async", "await",
];
97
98pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
100 let mut out = String::with_capacity(512);
104 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
105 if cfg.has_serde {
106 derives.push("serde::Serialize");
107 }
108 if !derives.is_empty() {
109 writeln!(out, "#[derive({})]", derives.join(", ")).ok();
110 }
111 for attr in cfg.enum_attrs {
112 writeln!(out, "#[{attr}]").ok();
113 }
114 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
117 writeln!(out, "pub enum {} {{", enum_def.name).ok();
118 for (idx, variant) in enum_def.variants.iter().enumerate() {
119 if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
120 writeln!(out, " #[pyo3(name = \"{}_\")]", variant.name).ok();
121 }
122 writeln!(out, " {} = {idx},", variant.name).ok();
123 }
124 writeln!(out, "}}").ok();
125
126 if let Some(first) = enum_def.variants.first() {
129 writeln!(out).ok();
130 writeln!(out, "#[allow(clippy::derivable_impls)]").ok();
131 writeln!(out, "impl Default for {} {{", enum_def.name).ok();
132 writeln!(out, " fn default() -> Self {{ Self::{} }}", first.name).ok();
133 writeln!(out, "}}").ok();
134 }
135
136 out
137}