1use crate::generators::RustBindingConfig;
2use alef_core::ir::EnumDef;
3use std::fmt::Write;
4
5pub fn enum_has_data_variants(enum_def: &EnumDef) -> bool {
8 enum_def.variants.iter().any(|v| !v.fields.is_empty())
9}
10
11fn enum_has_sanitized_fields(enum_def: &EnumDef) -> bool {
18 enum_def.variants.iter().any(|v| v.fields.iter().any(|f| f.sanitized))
19}
20
21pub fn gen_pyo3_data_enum(enum_def: &EnumDef, core_import: &str) -> String {
33 let name = &enum_def.name;
34 let core_path = crate::conversions::core_enum_path(enum_def, core_import);
35 let has_sanitized = enum_has_sanitized_fields(enum_def);
36 let mut out = String::with_capacity(512);
37
38 writeln!(out, "#[derive(Clone)]").ok();
39 writeln!(out, "#[pyclass(frozen)]").ok();
40 writeln!(out, "pub struct {name} {{").ok();
41 writeln!(out, " pub(crate) inner: {core_path},").ok();
42 writeln!(out, "}}").ok();
43 writeln!(out).ok();
44
45 writeln!(out, "#[pymethods]").ok();
46 writeln!(out, "impl {name} {{").ok();
47 if has_sanitized {
48 writeln!(out, "}}").ok();
52 } else {
53 writeln!(out, " #[new]").ok();
54 writeln!(
55 out,
56 " fn new(py: Python<'_>, value: &Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {{"
57 )
58 .ok();
59 writeln!(out, " let json_mod = py.import(\"json\")?;").ok();
60 writeln!(
61 out,
62 " let json_str: String = json_mod.call_method1(\"dumps\", (value,))?.extract()?;"
63 )
64 .ok();
65 writeln!(out, " let inner: {core_path} = serde_json::from_str(&json_str)").ok();
66 writeln!(
67 out,
68 " .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!(\"Invalid {name}: {{e}}\")))?;"
69 )
70 .ok();
71 writeln!(out, " Ok(Self {{ inner }})").ok();
72 writeln!(out, " }}").ok();
73 writeln!(out, "}}").ok();
74 }
75 writeln!(out).ok();
76
77 writeln!(out, "impl From<{name}> for {core_path} {{").ok();
79 writeln!(out, " fn from(val: {name}) -> Self {{ val.inner }}").ok();
80 writeln!(out, "}}").ok();
81 writeln!(out).ok();
82
83 writeln!(out, "impl From<{core_path}> for {name} {{").ok();
85 writeln!(out, " fn from(val: {core_path}) -> Self {{ Self {{ inner: val }} }}").ok();
86 writeln!(out, "}}").ok();
87
88 if !has_sanitized {
89 writeln!(out).ok();
90
91 writeln!(out, "impl serde::Serialize for {name} {{").ok();
93 writeln!(
94 out,
95 " fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {{"
96 )
97 .ok();
98 writeln!(out, " self.inner.serialize(serializer)").ok();
99 writeln!(out, " }}").ok();
100 writeln!(out, "}}").ok();
101 writeln!(out).ok();
102
103 writeln!(out, "impl Default for {name} {{").ok();
105 writeln!(
106 out,
107 " fn default() -> Self {{ Self {{ inner: Default::default() }} }}"
108 )
109 .ok();
110 writeln!(out, "}}").ok();
111 writeln!(out).ok();
112
113 writeln!(out, "impl<'de> serde::Deserialize<'de> for {name} {{").ok();
115 writeln!(
116 out,
117 " fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {{"
118 )
119 .ok();
120 writeln!(out, " let inner = {core_path}::deserialize(deserializer)?;").ok();
121 writeln!(out, " Ok(Self {{ inner }})").ok();
122 writeln!(out, " }}").ok();
123 writeln!(out, "}}").ok();
124 }
125
126 out
127}
128
/// Python names that the generated bindings avoid exposing verbatim.
///
/// Contains every hard keyword of Python 3, including `async`/`await`, which
/// became reserved in Python 3.7. It also contains the soft keyword `type`.
/// `gen_enum` uses this list to expose matching variants of pyclass enums
/// under a name with a trailing underscore.
const PYTHON_KEYWORDS: &[&str] = &[
    "None", "True", "False", "from", "import", "class", "def", "return", "yield", "pass", "break", "continue", "and",
    "or", "not", "is", "in", "if", "else", "elif", "for", "while", "with", "as", "try", "except", "finally", "raise",
    "del", "global", "nonlocal", "lambda", "assert", "type", "async", "await",
];
137
138pub fn gen_enum(enum_def: &EnumDef, cfg: &RustBindingConfig) -> String {
140 let mut out = String::with_capacity(512);
144 let mut derives: Vec<&str> = cfg.enum_derives.to_vec();
145 derives.push("Default");
149 derives.push("serde::Serialize");
150 derives.push("serde::Deserialize");
151 if !derives.is_empty() {
152 writeln!(out, "#[derive({})]", derives.join(", ")).ok();
153 }
154 for attr in cfg.enum_attrs {
155 writeln!(out, "#[{attr}]").ok();
156 }
157 let is_pyo3 = cfg.enum_attrs.iter().any(|a| a.contains("pyclass"));
160 writeln!(out, "pub enum {} {{", enum_def.name).ok();
161 for (idx, variant) in enum_def.variants.iter().enumerate() {
162 if is_pyo3 && PYTHON_KEYWORDS.contains(&variant.name.as_str()) {
163 writeln!(out, " #[pyo3(name = \"{}_\")]", variant.name).ok();
164 }
165 if idx == 0 {
167 writeln!(out, " #[default]").ok();
168 }
169 writeln!(out, " {} = {idx},", variant.name).ok();
170 }
171 writeln!(out, "}}").ok();
172
173 out
174}