1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
18 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
19}
20
21pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
33 let name = &enum_def.name;
34 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
35 let has_sanitized = enum_has_sanitized_fields(enum_def);
36 let mut out = String::with_capacity(512);
37
38 writeln!(out, "#[derive(Clone)]").ok();
39 writeln!(out, "#[pyclass(frozen)]").ok();
40 writeln!(out, "pub struct {name} {{").ok();
41 writeln!(out, " pub(crate) inner: {core_path},").ok();
42 writeln!(out, "}}").ok();
43 writeln!(out).ok();
44
45 writeln!(out, "#[pymethods]").ok();
46 writeln!(out, "impl {name} {{").ok();
47 writeln!(out, " #[new]").ok();
48 if has_sanitized {
49 writeln!(
52 out,
53 " fn new(_py: Python<'_>, _value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
54 )
55 .ok();
56 writeln!(
57 out,
58 " Err(pyo3::exceptions::PyNotImplementedError::new_err(\"{name} cannot be constructed from Python: the core type contains non-serializable variants\"))"
59 )
60 .ok();
61 } else {
62 writeln!(
63 out,
64 " fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
65 )
66 .ok();
67 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
68 writeln!(
69 out,
70 " let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
71 )
72 .ok();
73 writeln!(out, " let inner: {core_path} = serde_json::from_str(&json_str)").ok();
74 writeln!(
75 out,
76 " .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
77 )
78 .ok();
79 writeln!(out, " Ok(Self {{ inner }})").ok();
80 }
81 writeln!(out, " }}").ok();
82 writeln!(out, "}}").ok();
83 writeln!(out).ok();
84
85 writeln!(out, "impl From<{name}> for {core_path} {{").ok();
87 writeln!(out, " fn from(val: {name}) -> Self {{ val.inner }}").ok();
88 writeln!(out, "}}").ok();
89 writeln!(out).ok();
90
91 writeln!(out, "impl From<{core_path}> for {name} {{").ok();
93 writeln!(out, " fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
94 writeln!(out, "}}").ok();
95
96 if !has_sanitized {
97 writeln!(out).ok();
98
99 writeln!(out, "impl serde::Serialize for {name} {{").ok();
101 writeln!(
102 out,
103 " fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
104 )
105 .ok();
106 writeln!(out, " self.inner.serialize(serializer)").ok();
107 writeln!(out, " }}").ok();
108 writeln!(out, "}}").ok();
109 writeln!(out).ok();
110
111 writeln!(out, "impl Default for {name} {{").ok();
113 writeln!(
114 out,
115 " fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
116 )
117 .ok();
118 writeln!(out, "}}").ok();
119 writeln!(out).ok();
120
121 writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
123 writeln!(
124 out,
125 " fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
126 )
127 .ok();
128 writeln!(out, " let inner = {core_path}::deserialize(deserializer)?;").ok();
129 writeln!(out, " Ok(Self {{ inner }})").ok();
130 writeln!(out, " }}").ok();
131 writeln!(out, "}}").ok();
132 }
133
134 out
135}
136
/// Python words that must not be used verbatim as enum member names.
///
/// `gen_enum` checks variant names against this list and, for PyO3 enums,
/// exposes a matching variant to Python with a trailing underscore.
///
/// The list covers Python's hard keywords, including `async` and `await`
/// (reserved since Python 3.7), plus the soft keyword `type`.
const PYTHON_KEYWORDS: &[&str] = &[
    "None", "True", "False", "from", "import", "class", "def", "return", "yield", "pass", "break", "continue", "and",
    "or", "not", "is", "in", "if", "else", "elif", "for", "while", "with", "as", "try", "except", "finally", "raise",
    "del", "global", "nonlocal", "lambda", "assert", "type", "async", "await",
];
145
146pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
148 let mut out = String::with_capacity(512);
152 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
153 derives.push("Default");
157 derives.push("serde::Serialize");
158 derives.push("serde::Deserialize");
159 if !derives.is_empty() {
160 writeln!(out, "#[derive({})]", derives.join(", ")).ok();
161 }
162 for attr in cfg.enum_attrs {
163 writeln!(out, "#[{attr}]").ok();
164 }
165 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
168 writeln!(out, "pub enum {} {{", enum_def.name).ok();
169 for (idx, variant) in enum_def.variants.iter().enumerate() {
170 if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
171 writeln!(out, " #[pyo3(name = \"{}_\")]", variant.name).ok();
172 }
173 if idx == 0 {
175 writeln!(out, " #[default]").ok();
176 }
177 writeln!(out, " {} = {idx},", variant.name).ok();
178 }
179 writeln!(out, "}}").ok();
180
181 out
182}