slotted_egraphs_derive/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::*;
5
// Enum variants may contain tuples, `Slot`, `Bind<_>`, `AppliedId`, and user-defined types.
// User-defined types are treated as slot-independent constants and are ignored by the slot machinery.
8
9#[proc_macro]
10pub fn define_language(input: TokenStream1) -> TokenStream1 {
11    let mut ie: ItemEnum = parse(input).unwrap();
12
13    let name = ie.ident.clone();
14    let str_names: Vec<Option<Expr>> = ie
15        .variants
16        .iter_mut()
17        .map(|x| x.discriminant.take().map(|(_, e)| e))
18        .collect();
19
20    let all_slot_occurrences_mut_arms: Vec<TokenStream2> = ie
21        .variants
22        .iter()
23        .map(|x| produce_all_slot_occurrences_mut(&name, x))
24        .collect();
25    let public_slot_occurrences_mut_arms: Vec<TokenStream2> = ie
26        .variants
27        .iter()
28        .map(|x| produce_public_slot_occurrences_mut(&name, x))
29        .collect();
30    let applied_id_occurrences_mut_arms: Vec<TokenStream2> = ie
31        .variants
32        .iter()
33        .map(|x| produce_applied_id_occurrences_mut(&name, x))
34        .collect();
35
36    let all_slot_occurrences_arms: Vec<TokenStream2> = ie
37        .variants
38        .iter()
39        .map(|x| produce_all_slot_occurrences(&name, x))
40        .collect();
41    let public_slot_occurrences_arms: Vec<TokenStream2> = ie
42        .variants
43        .iter()
44        .map(|x| produce_public_slot_occurrences(&name, x))
45        .collect();
46    let applied_id_occurrences_arms: Vec<TokenStream2> = ie
47        .variants
48        .iter()
49        .map(|x| produce_applied_id_occurrences(&name, x))
50        .collect();
51
52    let to_syntax_arms: Vec<TokenStream2> = ie
53        .variants
54        .iter()
55        .zip(&str_names)
56        .map(|(x, n)| produce_to_syntax(&name, &n, x))
57        .collect();
58    let from_syntax_arms1: Vec<TokenStream2> = ie
59        .variants
60        .iter()
61        .zip(&str_names)
62        .filter_map(|(x, n)| produce_from_syntax1(&name, &n, x))
63        .collect();
64    let from_syntax_arms2: Vec<TokenStream2> = ie
65        .variants
66        .iter()
67        .zip(&str_names)
68        .filter_map(|(x, n)| produce_from_syntax2(&name, &n, x))
69        .collect();
70
71    let slots_arms: Vec<TokenStream2> = ie
72        .variants
73        .iter()
74        .map(|x| produce_slots(&name, x))
75        .collect();
76    let weak_shape_inplace_arms: Vec<TokenStream2> = ie
77        .variants
78        .iter()
79        .map(|x| produce_weak_shape_inplace(&name, x))
80        .collect();
81
82    quote! {
83        #[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
84        #ie
85
86        impl Language for #name {
87            // mut:
88            fn all_slot_occurrences_mut(&mut self) -> Vec<&mut Slot> {
89                match self {
90                    #(#all_slot_occurrences_mut_arms),*
91                }
92            }
93
94            fn public_slot_occurrences_mut(&mut self) -> Vec<&mut Slot> {
95                match self {
96                    #(#public_slot_occurrences_mut_arms),*
97                }
98            }
99
100            fn applied_id_occurrences_mut(&mut self) -> Vec<&mut AppliedId> {
101                match self {
102                    #(#applied_id_occurrences_mut_arms),*
103                }
104            }
105
106
107            // immut:
108            fn all_slot_occurrences(&self) -> Vec<Slot> {
109                match self {
110                    #(#all_slot_occurrences_arms),*
111                }
112            }
113
114            fn public_slot_occurrences(&self) -> Vec<Slot> {
115                match self {
116                    #(#public_slot_occurrences_arms),*
117                }
118            }
119
120            fn applied_id_occurrences(&self) -> Vec<&AppliedId> {
121                match self {
122                    #(#applied_id_occurrences_arms),*
123                }
124            }
125
126            // syntax:
127            fn to_syntax(&self) -> Vec<SyntaxElem> {
128                match self {
129                    #(#to_syntax_arms),*
130                }
131            }
132
133            fn from_syntax(elems: &[SyntaxElem]) -> Option<Self> {
134                let SyntaxElem::String(op) = elems.get(0)? else { return None };
135                match &**op {
136                    #(#from_syntax_arms1),*
137                    _ => {
138                        #(#from_syntax_arms2)*
139
140                        None
141                    },
142                }
143            }
144
145            fn slots(&self) -> slotted_egraphs::SmallHashSet<Slot> {
146                match self {
147                    #(#slots_arms),*
148                }
149            }
150
151            fn weak_shape_inplace(&mut self) -> slotted_egraphs::SlotMap {
152                let m = &mut (slotted_egraphs::SlotMap::new(), 0);
153                match self {
154                    #(#weak_shape_inplace_arms),*
155                }
156
157                m.0.inverse()
158            }
159        }
160    }
161    .to_token_stream()
162    .into()
163}
164
165fn produce_all_slot_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
166    let variant_name = &v.ident;
167    let n = v.fields.len();
168    let fields: Vec<Ident> = (0..n)
169        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
170        .collect();
171    quote! {
172        #name::#variant_name(#(#fields),*) => {
173            let out = std::iter::empty();
174            #(
175                let out = out.chain(#fields .all_slot_occurrences_iter_mut());
176            )*
177            out.collect()
178        }
179    }
180}
181
182fn produce_public_slot_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
183    let variant_name = &v.ident;
184    let n = v.fields.len();
185    let fields: Vec<Ident> = (0..n)
186        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
187        .collect();
188    quote! {
189        #name::#variant_name(#(#fields),*) => {
190            let out = std::iter::empty();
191            #(
192                let out = out.chain(#fields .public_slot_occurrences_iter_mut());
193            )*
194            out.collect()
195        }
196    }
197}
198
199fn produce_applied_id_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
200    let variant_name = &v.ident;
201    let n = v.fields.len();
202    let fields: Vec<Ident> = (0..n)
203        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
204        .collect();
205    quote! {
206        #name::#variant_name(#(#fields),*) => {
207            let out = std::iter::empty();
208            #(
209                let out = out.chain(#fields .applied_id_occurrences_iter_mut());
210            )*
211            out.collect()
212        }
213    }
214}
215
216// immut:
217fn produce_all_slot_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
218    let variant_name = &v.ident;
219    let n = v.fields.len();
220    let fields: Vec<Ident> = (0..n)
221        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
222        .collect();
223    quote! {
224        #name::#variant_name(#(#fields),*) => {
225            let out = std::iter::empty();
226            #(
227                let out = out.chain(#fields .all_slot_occurrences_iter().copied());
228            )*
229            out.collect()
230        }
231    }
232}
233
234fn produce_public_slot_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
235    let variant_name = &v.ident;
236    let n = v.fields.len();
237    let fields: Vec<Ident> = (0..n)
238        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
239        .collect();
240    quote! {
241        #name::#variant_name(#(#fields),*) => {
242            let out = std::iter::empty();
243            #(
244                let out = out.chain(#fields .public_slot_occurrences_iter().copied());
245            )*
246            out.collect()
247        }
248    }
249}
250
251fn produce_applied_id_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
252    let variant_name = &v.ident;
253    let n = v.fields.len();
254    let fields: Vec<Ident> = (0..n)
255        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
256        .collect();
257    quote! {
258        #name::#variant_name(#(#fields),*) => {
259            let out = std::iter::empty();
260            #(
261                let out = out.chain(#fields .applied_id_occurrences_iter());
262            )*
263            out.collect()
264        }
265    }
266}
267
268// syntax:
269fn produce_to_syntax(name: &Ident, e: &Option<Expr>, v: &Variant) -> TokenStream2 {
270    let variant_name = &v.ident;
271
272    if e.is_none() {
273        return quote! {
274            #name::#variant_name(a0) => {
275                a0.to_syntax()
276            }
277        };
278    }
279
280    let e = e.as_ref().unwrap();
281    let n = v.fields.len();
282    let fields: Vec<Ident> = (0..n)
283        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
284        .collect();
285    quote! {
286        #name::#variant_name(#(#fields),*) => {
287            let mut out: Vec<SyntaxElem> = vec![SyntaxElem::String(String::from(#e))];
288            #(
289                out.extend(#fields.to_syntax());
290            )*
291            out
292        }
293    }
294}
295
296fn produce_from_syntax1(name: &Ident, e: &Option<Expr>, v: &Variant) -> Option<TokenStream2> {
297    let variant_name = &v.ident;
298
299    let e = e.as_ref()?;
300    let n = v.fields.len();
301    let fields: Vec<Ident> = (0..n)
302        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
303        .collect();
304
305    let types: Vec<Type> = v.fields.iter().map(|x| x.ty.clone()).collect();
306
307    Some(quote! {
308        #e => {
309            let mut children = &elems[1..];
310            let mut rest = children;
311            #(
312                let #fields = (0..=children.len()).filter_map(|n| {
313                    let a = &children[..n];
314                    rest = &children[n..];
315
316                    <#types>::from_syntax(a)
317                }).next()?;
318                children = rest;
319            )*
320            Some(#name::#variant_name(#(#fields),*))
321        }
322    })
323}
324
325fn produce_from_syntax2(name: &Ident, e: &Option<Expr>, v: &Variant) -> Option<TokenStream2> {
326    if e.is_some() {
327        return None;
328    }
329    let variant_name = &v.ident;
330
331    let ty = v.fields.iter().map(|x| x.ty.clone()).next().unwrap();
332    Some(quote! {
333        if let Some(a) = <#ty>::from_syntax(elems) {
334            return Some(#name::#variant_name(a));
335        }
336    })
337}
338
339fn produce_slots(name: &Ident, v: &Variant) -> TokenStream2 {
340    let variant_name = &v.ident;
341    let n = v.fields.len();
342    let fields: Vec<Ident> = (0..n)
343        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
344        .collect();
345    quote! {
346        #name::#variant_name(#(#fields),*) => {
347            let out = std::iter::empty();
348            #(
349                let out = out.chain(#fields .public_slot_occurrences_iter().copied());
350            )*
351            out.collect()
352        }
353    }
354}
355
356fn produce_weak_shape_inplace(name: &Ident, v: &Variant) -> TokenStream2 {
357    let variant_name = &v.ident;
358    let n = v.fields.len();
359    let fields: Vec<Ident> = (0..n)
360        .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
361        .collect();
362    quote! {
363        #name::#variant_name(#(#fields),*) => {
364            #(
365                #fields .weak_shape_impl(m);
366            )*
367        }
368    }
369}