1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
/// Converts an `UpperCamelCase` identifier into `snake_case`.
///
/// Every uppercase character is replaced by its lowercase mapping and, unless
/// it is the very first character of the input, preceded by an underscore.
/// All other characters are copied through unchanged.
///
/// Each uppercase character starts a new word, so runs of capitals are split
/// letter by letter: `"HTTPServer"` becomes `"h_t_t_p_server"`.
fn to_snake_case(s: &str) -> String {
    // `s.len()` is only a lower bound: every inserted underscore grows the output.
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // A lowercase mapping can span several chars (e.g. 'İ' -> "i\u{307}"),
            // so append the whole mapping instead of only its first char.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
20
21#[proc_macro_derive(Symbolic)]
22pub fn derive_symbolic(input: TokenStream) -> TokenStream {
23 let input = parse_macro_input!(input as DeriveInput);
24 let name = &input.ident;
25
26 let expanded = match &input.data {
27 Data::Struct(data) => derive_struct(name, &data.fields),
28 Data::Enum(data) => derive_enum(name, data),
29 Data::Union(_) => {
30 return syn::Error::new_spanned(name, "Symbolic cannot be derived for unions")
31 .to_compile_error()
32 .into();
33 }
34 };
35
36 expanded.into()
37}
38
39fn derive_struct(name: &syn::Ident, fields: &Fields) -> proc_macro2::TokenStream {
40 let func_name = to_snake_case(&name.to_string());
41
42 let field_count = match fields {
43 Fields::Unit => 0,
44 Fields::Unnamed(f) => f.unnamed.len(),
45 Fields::Named(f) => f.named.len(),
46 };
47
48 let symbolic_impl = match fields {
49 Fields::Unit => {
50 quote! {
51 impl Symbolic for #name {
52 fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
53 if sym.symbol_type() != aspire::SymbolType::Function { return None; }
54 if sym.is_positive() != Some(true) { return None; }
55 if sym.name()? != #func_name { return None; }
56 let args = sym.arguments()?;
57 if !args.is_empty() { return None; }
58 Some(#name)
59 }
60 fn to_symbol(&self) -> aspire::Symbol {
61 aspire::Symbol::id(#func_name, true).unwrap()
62 }
63 }
64 }
65 }
66 Fields::Unnamed(fields) => {
67 let field_indices: Vec<syn::Index> =
68 (0..fields.unnamed.len()).map(syn::Index::from).collect();
69 let field_vars: Vec<syn::Ident> = (0..fields.unnamed.len())
70 .map(|i| syn::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site()))
71 .collect();
72
73 quote! {
74 impl Symbolic for #name {
75 fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
76 if sym.symbol_type() != aspire::SymbolType::Function { return None; }
77 if sym.is_positive() != Some(true) { return None; }
78 if sym.name()? != #func_name { return None; }
79 let args = sym.arguments()?;
80 if args.len() != #field_count { return None; }
81 Some(#name(
82 #(Symbolic::from_symbol(args[#field_indices])?,)*
83 ))
84 }
85 fn to_symbol(&self) -> aspire::Symbol {
86 let #name(#(#field_vars),*) = self;
87 aspire::Symbol::function(#func_name, &[
88 #(#field_vars.to_symbol(),)*
89 ], true).unwrap()
90 }
91 }
92 }
93 }
94 Fields::Named(fields) => {
95 let field_names: Vec<&syn::Ident> = fields
96 .named
97 .iter()
98 .map(|f| f.ident.as_ref().unwrap())
99 .collect();
100 let field_indices: Vec<syn::Index> =
101 (0..fields.named.len()).map(syn::Index::from).collect();
102
103 quote! {
104 impl Symbolic for #name {
105 fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
106 if sym.symbol_type() != aspire::SymbolType::Function { return None; }
107 if sym.is_positive() != Some(true) { return None; }
108 if sym.name()? != #func_name { return None; }
109 let args = sym.arguments()?;
110 if args.len() != #field_count { return None; }
111 Some(#name {
112 #(#field_names: Symbolic::from_symbol(args[#field_indices])?,)*
113 })
114 }
115 fn to_symbol(&self) -> aspire::Symbol {
116 aspire::Symbol::function(#func_name, &[
117 #(self.#field_names.to_symbol(),)*
118 ], true).unwrap()
119 }
120 }
121 }
122 }
123 };
124
125 quote! {
126 #symbolic_impl
127
128 impl aspire::SymbolicFun for #name {
129 fn signature() -> (&'static str, usize) {
130 (#func_name, #field_count)
131 }
132 }
133
134 impl std::fmt::Display for #name {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 std::fmt::Display::fmt(&self.to_symbol(), f)
137 }
138 }
139 }
140}
141
142fn derive_enum(name: &syn::Ident, data: &syn::DataEnum) -> proc_macro2::TokenStream {
143 let mut from_arms = Vec::new();
144 let mut to_arms = Vec::new();
145
146 for variant in &data.variants {
147 let variant_name = &variant.ident;
148 let func_name = to_snake_case(&variant_name.to_string());
149
150 match &variant.fields {
151 Fields::Unit => {
152 from_arms.push(quote! {
153 (#func_name, 0) => Some(#name::#variant_name),
154 });
155 to_arms.push(quote! {
156 #name::#variant_name => aspire::Symbol::id(#func_name, true).unwrap(),
157 });
158 }
159 Fields::Unnamed(fields) => {
160 let field_count = fields.unnamed.len();
161 let field_indices: Vec<syn::Index> =
162 (0..field_count).map(syn::Index::from).collect();
163 let field_vars: Vec<syn::Ident> = (0..field_count)
164 .map(|i| syn::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site()))
165 .collect();
166
167 from_arms.push(quote! {
168 (#func_name, #field_count) => Some(#name::#variant_name(
169 #(Symbolic::from_symbol(args[#field_indices])?,)*
170 )),
171 });
172 to_arms.push(quote! {
173 #name::#variant_name(#(#field_vars),*) => {
174 aspire::Symbol::function(#func_name, &[
175 #(#field_vars.to_symbol(),)*
176 ], true).unwrap()
177 }
178 });
179 }
180 Fields::Named(fields) => {
181 let field_count = fields.named.len();
182 let field_names: Vec<&syn::Ident> = fields
183 .named
184 .iter()
185 .map(|f| f.ident.as_ref().unwrap())
186 .collect();
187 let field_indices: Vec<syn::Index> =
188 (0..field_count).map(syn::Index::from).collect();
189
190 from_arms.push(quote! {
191 (#func_name, #field_count) => Some(#name::#variant_name {
192 #(#field_names: Symbolic::from_symbol(args[#field_indices])?,)*
193 }),
194 });
195 to_arms.push(quote! {
196 #name::#variant_name { #(#field_names),* } => {
197 aspire::Symbol::function(#func_name, &[
198 #(#field_names.to_symbol(),)*
199 ], true).unwrap()
200 }
201 });
202 }
203 }
204 }
205
206 quote! {
207 impl Symbolic for #name {
208 fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
209 if sym.symbol_type() != aspire::SymbolType::Function { return None; }
210 if sym.is_positive() != Some(true) { return None; }
211 let name = sym.name()?;
212 let args = sym.arguments()?;
213 match (name, args.len()) {
214 #(#from_arms)*
215 _ => None,
216 }
217 }
218 fn to_symbol(&self) -> aspire::Symbol {
219 match self {
220 #(#to_arms)*
221 }
222 }
223 }
224
225 impl std::fmt::Display for #name {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 std::fmt::Display::fmt(&self.to_symbol(), f)
228 }
229 }
230 }
231}