Skip to main content

aspire_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
/// Convert a PascalCase identifier to snake_case.
///
/// Every uppercase character starts a new word:
/// - An underscore is inserted before it (unless it is the first character).
/// - It is replaced by its *full* lowercase mapping.
///
/// Consecutive capitals are therefore split letter by letter (`"HTTPServer"` becomes
/// `"h_t_t_p_server"`). All other characters are copied through unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them rather than silently dropping the tail.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
20
21#[proc_macro_derive(Symbolic)]
22pub fn derive_symbolic(input: TokenStream) -> TokenStream {
23    let input = parse_macro_input!(input as DeriveInput);
24    let name = &input.ident;
25
26    let expanded = match &input.data {
27        Data::Struct(data) => derive_struct(name, &data.fields),
28        Data::Enum(data) => derive_enum(name, data),
29        Data::Union(_) => {
30            return syn::Error::new_spanned(name, "Symbolic cannot be derived for unions")
31                .to_compile_error()
32                .into();
33        }
34    };
35
36    expanded.into()
37}
38
39fn derive_struct(name: &syn::Ident, fields: &Fields) -> proc_macro2::TokenStream {
40    let func_name = to_snake_case(&name.to_string());
41
42    let field_count = match fields {
43        Fields::Unit => 0,
44        Fields::Unnamed(f) => f.unnamed.len(),
45        Fields::Named(f) => f.named.len(),
46    };
47
48    let symbolic_impl = match fields {
49        Fields::Unit => {
50            quote! {
51                impl Symbolic for #name {
52                    fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
53                        if sym.symbol_type() != aspire::SymbolType::Function { return None; }
54                        if sym.is_positive() != Some(true) { return None; }
55                        if sym.name()? != #func_name { return None; }
56                        let args = sym.arguments()?;
57                        if !args.is_empty() { return None; }
58                        Some(#name)
59                    }
60                    fn to_symbol(&self) -> aspire::Symbol {
61                        aspire::Symbol::id(#func_name, true).unwrap()
62                    }
63                }
64            }
65        }
66        Fields::Unnamed(fields) => {
67            let field_indices: Vec<syn::Index> =
68                (0..fields.unnamed.len()).map(syn::Index::from).collect();
69            let field_vars: Vec<syn::Ident> = (0..fields.unnamed.len())
70                .map(|i| syn::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site()))
71                .collect();
72
73            quote! {
74                impl Symbolic for #name {
75                    fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
76                        if sym.symbol_type() != aspire::SymbolType::Function { return None; }
77                        if sym.is_positive() != Some(true) { return None; }
78                        if sym.name()? != #func_name { return None; }
79                        let args = sym.arguments()?;
80                        if args.len() != #field_count { return None; }
81                        Some(#name(
82                            #(Symbolic::from_symbol(args[#field_indices])?,)*
83                        ))
84                    }
85                    fn to_symbol(&self) -> aspire::Symbol {
86                        let #name(#(#field_vars),*) = self;
87                        aspire::Symbol::function(#func_name, &[
88                            #(#field_vars.to_symbol(),)*
89                        ], true).unwrap()
90                    }
91                }
92            }
93        }
94        Fields::Named(fields) => {
95            let field_names: Vec<&syn::Ident> = fields
96                .named
97                .iter()
98                .map(|f| f.ident.as_ref().unwrap())
99                .collect();
100            let field_indices: Vec<syn::Index> =
101                (0..fields.named.len()).map(syn::Index::from).collect();
102
103            quote! {
104                impl Symbolic for #name {
105                    fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
106                        if sym.symbol_type() != aspire::SymbolType::Function { return None; }
107                        if sym.is_positive() != Some(true) { return None; }
108                        if sym.name()? != #func_name { return None; }
109                        let args = sym.arguments()?;
110                        if args.len() != #field_count { return None; }
111                        Some(#name {
112                            #(#field_names: Symbolic::from_symbol(args[#field_indices])?,)*
113                        })
114                    }
115                    fn to_symbol(&self) -> aspire::Symbol {
116                        aspire::Symbol::function(#func_name, &[
117                            #(self.#field_names.to_symbol(),)*
118                        ], true).unwrap()
119                    }
120                }
121            }
122        }
123    };
124
125    quote! {
126        #symbolic_impl
127
128        impl aspire::SymbolicFun for #name {
129            fn signature() -> (&'static str, usize) {
130                (#func_name, #field_count)
131            }
132        }
133
134        impl std::fmt::Display for #name {
135            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136                std::fmt::Display::fmt(&self.to_symbol(), f)
137            }
138        }
139    }
140}
141
142fn derive_enum(name: &syn::Ident, data: &syn::DataEnum) -> proc_macro2::TokenStream {
143    let mut from_arms = Vec::new();
144    let mut to_arms = Vec::new();
145
146    for variant in &data.variants {
147        let variant_name = &variant.ident;
148        let func_name = to_snake_case(&variant_name.to_string());
149
150        match &variant.fields {
151            Fields::Unit => {
152                from_arms.push(quote! {
153                    (#func_name, 0) => Some(#name::#variant_name),
154                });
155                to_arms.push(quote! {
156                    #name::#variant_name => aspire::Symbol::id(#func_name, true).unwrap(),
157                });
158            }
159            Fields::Unnamed(fields) => {
160                let field_count = fields.unnamed.len();
161                let field_indices: Vec<syn::Index> =
162                    (0..field_count).map(syn::Index::from).collect();
163                let field_vars: Vec<syn::Ident> = (0..field_count)
164                    .map(|i| syn::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site()))
165                    .collect();
166
167                from_arms.push(quote! {
168                    (#func_name, #field_count) => Some(#name::#variant_name(
169                        #(Symbolic::from_symbol(args[#field_indices])?,)*
170                    )),
171                });
172                to_arms.push(quote! {
173                    #name::#variant_name(#(#field_vars),*) => {
174                        aspire::Symbol::function(#func_name, &[
175                            #(#field_vars.to_symbol(),)*
176                        ], true).unwrap()
177                    }
178                });
179            }
180            Fields::Named(fields) => {
181                let field_count = fields.named.len();
182                let field_names: Vec<&syn::Ident> = fields
183                    .named
184                    .iter()
185                    .map(|f| f.ident.as_ref().unwrap())
186                    .collect();
187                let field_indices: Vec<syn::Index> =
188                    (0..field_count).map(syn::Index::from).collect();
189
190                from_arms.push(quote! {
191                    (#func_name, #field_count) => Some(#name::#variant_name {
192                        #(#field_names: Symbolic::from_symbol(args[#field_indices])?,)*
193                    }),
194                });
195                to_arms.push(quote! {
196                    #name::#variant_name { #(#field_names),* } => {
197                        aspire::Symbol::function(#func_name, &[
198                            #(#field_names.to_symbol(),)*
199                        ], true).unwrap()
200                    }
201                });
202            }
203        }
204    }
205
206    quote! {
207        impl Symbolic for #name {
208            fn from_symbol(sym: aspire::Symbol) -> Option<Self> {
209                if sym.symbol_type() != aspire::SymbolType::Function { return None; }
210                if sym.is_positive() != Some(true) { return None; }
211                let name = sym.name()?;
212                let args = sym.arguments()?;
213                match (name, args.len()) {
214                    #(#from_arms)*
215                    _ => None,
216                }
217            }
218            fn to_symbol(&self) -> aspire::Symbol {
219                match self {
220                    #(#to_arms)*
221                }
222            }
223        }
224
225        impl std::fmt::Display for #name {
226            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227                std::fmt::Display::fmt(&self.to_symbol(), f)
228            }
229        }
230    }
231}