1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
21 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
22}
23
24pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
37 let name = &enum_def.name;
38 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
39 let has_sanitized = enum_has_sanitized_fields(enum_def);
40 let mut out = String::with_capacity(512);
41
42 writeln!(out, "#[derive(Clone)]").ok();
43 writeln!(out, "#[pyclass(frozen)]").ok();
44 writeln!(out, "pub struct {name} {{").ok();
45 writeln!(out, " pub(crate) inner: {core_path},").ok();
46 writeln!(out, "}}").ok();
47 writeln!(out).ok();
48
49 writeln!(out, "#[pymethods]").ok();
50 writeln!(out, "impl {name} {{").ok();
51 if has_sanitized {
52 writeln!(out, "}}").ok();
56 } else {
57 writeln!(out, " #[new]").ok();
58 writeln!(
59 out,
60 " fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
61 )
62 .ok();
63 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
64 writeln!(
65 out,
66 " let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
67 )
68 .ok();
69 writeln!(out, " let inner: {core_path} = serde_json::from_str(&json_str)").ok();
70 writeln!(
71 out,
72 " .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
73 )
74 .ok();
75 writeln!(out, " Ok(Self {{ inner }})").ok();
76 writeln!(out, " }}").ok();
77 writeln!(out, "}}").ok();
78 }
79 writeln!(out).ok();
80
81 writeln!(out, "impl From<{name}> for {core_path} {{").ok();
83 writeln!(out, " fn from(val: {name}) -> Self {{ val.inner }}").ok();
84 writeln!(out, "}}").ok();
85 writeln!(out).ok();
86
87 writeln!(out, "impl From<{core_path}> for {name} {{").ok();
89 writeln!(out, " fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
90 writeln!(out, "}}").ok();
91 writeln!(out).ok();
92
93 writeln!(out, "impl serde::Serialize for {name} {{").ok();
96 writeln!(
97 out,
98 " fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
99 )
100 .ok();
101 writeln!(out, " self.inner.serialize(serializer)").ok();
102 writeln!(out, " }}").ok();
103 writeln!(out, "}}").ok();
104 writeln!(out).ok();
105
106 writeln!(out, "impl Default for {name} {{").ok();
109 writeln!(
110 out,
111 " fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
112 )
113 .ok();
114 writeln!(out, "}}").ok();
115 writeln!(out).ok();
116
117 writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
120 writeln!(
121 out,
122 " fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
123 )
124 .ok();
125 writeln!(out, " let inner = {core_path}::deserialize(deserializer)?;").ok();
126 writeln!(out, " Ok(Self {{ inner }})").ok();
127 writeln!(out, " }}").ok();
128 writeln!(out, "}}").ok();
129
130 out
131}
132
/// Python reserved words (plus the soft keyword `type`).
///
/// `gen_enum` checks variant names against this list when generating a `pyclass` enum.
/// Matching variants get a `#[pyo3(name = "<Name>_")]` attribute, so Python sees a
/// usable attribute name instead of a keyword. The comparison is case-sensitive,
/// matching Python's own rules.
const PYTHON_KEYWORDS: &[&str] = &[
    "None", "True", "False", "from", "import", "class", "def", "return", "yield", "pass", "break", "continue", "and",
    "or", "not", "is", "in", "if", "else", "elif", "for", "while", "with", "as", "try", "except", "finally", "raise",
    "del", "global", "nonlocal", "lambda", "assert", "async", "await", "type",
];
141
142pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
144 let mut out = String::with_capacity(512);
148 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
149 derives.push("Default");
153 derives.push("serde::Serialize");
154 derives.push("serde::Deserialize");
155 if !derives.is_empty() {
156 writeln!(out, "#[derive({})]", derives.join(", ")).ok();
157 }
158 for attr in cfg.enum_attrs {
159 writeln!(out, "#[{attr}]").ok();
160 }
161 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
164 writeln!(out, "pub enum {} {{", enum_def.name).ok();
165 for (idx, variant) in enum_def.variants.iter().enumerate() {
166 if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
167 writeln!(out, " #[pyo3(name = \"{}_\")]", variant.name).ok();
168 }
169 if idx == 0 {
171 writeln!(out, " #[default]").ok();
172 }
173 writeln!(out, " {} = {idx},", variant.name).ok();
174 }
175 writeln!(out, "}}").ok();
176
177 out
178}