aws_smt_ir_derive/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use std::collections::HashSet;
6use syn::{parse_quote, DeriveInput};
7use synstructure::decl_derive;
8
// `#[derive(Fold)]` is expanded by `derive_fold`, which emits `Fold` and `SuperFold` impls.
decl_derive!([Fold] => derive_fold);
// `#[derive(Visit)]` is expanded by `derive_visit`, which emits a `SuperVisit` impl.
decl_derive!([Visit] => derive_visit);
// `#[derive(Operation)]` is expanded by `derive_operation_all`. It also registers the
// `#[core_op]` and `#[symbol(..)]` helper attributes, read by `has_core_op_attr` and
// `variant_symbol` respectively.
decl_derive!([Operation, attributes(core_op, symbol)] => derive_operation_all);
12
13/// Gets the path to the `aws-smt-ir` crate, returning `crate` in case it is the current crate.
14fn smt_ir_crate_path() -> syn::Path {
15    match proc_macro_crate::crate_name("aws-smt-ir").expect("must depend on aws-smt-ir") {
16        proc_macro_crate::FoundCrate::Itself => parse_quote!(aws_smt_ir),
17        proc_macro_crate::FoundCrate::Name(name) => {
18            let name = syn::Ident::new(&name, proc_macro2::Span::call_site());
19            parse_quote!(#name)
20        }
21    }
22}
23
24/// Determines whether `input` is annotated with the `#[core_op]` attribute.
25fn has_core_op_attr(input: &DeriveInput) -> bool {
26    input.attrs.iter().any(|a| a.path().is_ident("core_op"))
27}
28
29// Determine a variant's symbol (its representation in SMT-LIB) -- either the lowercased name of
30// the variant or the contents of an attached `symbol` attribute, if present e.g. `#[symbol("=>")]`.
31fn variant_symbol(variant: &synstructure::VariantInfo) -> syn::LitStr {
32    let ast = variant.ast();
33    (ast.attrs.iter())
34        .find(|attr| attr.path().is_ident("symbol"))
35        .and_then(|attr| attr.parse_args().ok())
36        .unwrap_or_else(|| {
37            let name = ast.ident.to_string().to_lowercase();
38            parse_quote!(#name)
39        })
40}
41
42/// Derives `Debug`, `Display`, and `Operation` for an operator enum e.g. `Foo<Term>`
43/// and `From` for `Op`.
44fn derive_operation_all(mut s: synstructure::Structure) -> TokenStream {
45    s.add_bounds(synstructure::AddBounds::None)
46        .bind_with(|_| synstructure::BindStyle::Move);
47    let debug = derive_fmt_any(&s, parse_quote!(std::fmt::Debug));
48    let display = derive_fmt_any(&s, parse_quote!(std::fmt::Display));
49    let from = derive_from(&s);
50    let operation = derive_operation(&s);
51    let iterate = derive_iterate(&s);
52    quote! {
53        #debug
54        #display
55        #from
56        #operation
57        #iterate
58    }
59}
60
61fn index_array(ty: &syn::Type) -> Option<&syn::Expr> {
62    match ty {
63        syn::Type::Array(syn::TypeArray { len, .. }) => Some(len),
64        _ => None,
65    }
66}
67
68/// Derives an `Operation` implementation for an operator.
69fn derive_operation(s: &synstructure::Structure) -> TokenStream {
70    let smt_ir = smt_ir_crate_path();
71
72    #[allow(non_snake_case)]
73    let (Term, Logic, Operation, Parse, NumArgs, InvalidOp, QualIdentifier, IIndex, TryFrom, Vec) = (
74        quote!(#smt_ir::Term),
75        quote!(#smt_ir::Logic),
76        quote!(#smt_ir::term::Operation),
77        quote!(#smt_ir::term::args::Parse),
78        quote!(#smt_ir::term::args::NumArgs),
79        quote!(#smt_ir::term::InvalidOp),
80        quote!(#smt_ir::QualIdentifier),
81        quote!(#smt_ir::IIndex),
82        quote!(std::convert::TryFrom),
83        quote!(std::vec::Vec),
84    );
85
86    let mut bindings = vec![];
87
88    let parse_match_arms: Vec<_> = (s.variants().iter())
89        .enumerate()
90        .map(|(idx, variant)| {
91            let symbol = variant_symbol(variant);
92            let mut num_indices = None;
93            let mut min_args = vec![];
94            let mut max_args = vec![];
95
96            // For each field in the variant, try to parse it from the iterator of arguments
97            let constructed = variant.construct(|field, _| {
98                let ty = &field.ty;
99
100                // Check for array of indices
101                if let Some(len) = index_array(ty) {
102                    num_indices = Some(len.clone());
103                    quote! {{
104                        let indices: std::vec::Vec<_> = func.indices().iter().map(#IIndex::from).collect();
105                        #TryFrom::try_from(indices).unwrap()
106                    }}
107                } else {
108                    min_args.push(quote!(<#ty as #NumArgs>::MIN_ARGS));
109                    max_args.push(quote!(<#ty as #NumArgs>::MAX_ARGS));
110                    quote!(#Parse::from_iter(&mut iter).unwrap())
111                }
112            });
113
114            let min_args = quote!((0 #(+ #min_args)*));
115            let max_args = quote!((0 #(+ #max_args)*));
116            let num_indices = num_indices.unwrap_or_else(|| parse_quote!(0));
117            let min_args_ident = syn::Ident::new(&format!("MIN_ARGS_{}", idx), Span::call_site());
118            let max_args_ident = syn::Ident::new(&format!("MAX_ARGS_{}", idx), Span::call_site());
119            let num_indices_ident = syn::Ident::new(&format!("INDICES_{}", idx), Span::call_site());
120            bindings.push(quote!(let #min_args_ident = #min_args;));
121            bindings.push(quote!(let #max_args_ident = #max_args;));
122            bindings.push(quote!(let #num_indices_ident = #num_indices;));
123
124            // Construct a match arm e.g. `"and" => { ... }` where `...` constructs the variant from
125            // a slice of arguments.
126            quote! {
127                (#symbol, num_args) if func.indices().len() == #num_indices_ident && (#min_args_ident..=#max_args_ident).contains(&num_args) => {
128                    let mut iter = args.into_iter();
129                    #constructed
130                }
131            }
132        })
133        .collect();
134
135    let parse_fn = quote! {
136        fn parse(func: #QualIdentifier, args: #Vec<#Term<L>>) -> std::result::Result<Self, #InvalidOp<L>> {
137            #(#bindings)*
138            #[deny(unreachable_patterns)]
139            Ok(match (func.sym_str(), args.len()) {
140                #(#parse_match_arms)*
141                _ => return Err(#InvalidOp { func, args })
142            })
143        }
144    };
145
146    let func_match_arms = s.each_variant(|variant| {
147        let symbol = variant_symbol(variant);
148        quote!(#symbol.into())
149    });
150
151    let func_fn = quote! {
152        fn func(&self) -> #smt_ir::ISymbol {
153            match self {
154                #func_match_arms
155            }
156        }
157    };
158
159    let mut where_clause = None;
160    s.add_trait_bounds(
161        &parse_quote!(#Parse<L>),
162        &mut where_clause,
163        synstructure::AddBounds::Fields,
164    );
165    if has_core_op_attr(s.ast()) {
166        s.gen_impl(quote! {
167            gen impl<L: #Logic> #Operation<L> for @Self
168            #where_clause,
169                <L as #Logic>::Op: #Operation<L>,
170            {
171                #parse_fn
172                #func_fn
173            }
174        })
175    } else {
176        s.gen_impl(quote! {
177            gen impl<L: #Logic> #Operation<L> for @Self #where_clause {
178                #parse_fn
179                #func_fn
180            }
181        })
182    }
183}
184
185fn bound_argument_fields(
186    s: &synstructure::Structure,
187    clause: &mut syn::WhereClause,
188    mut bound: impl FnMut(&syn::Type) -> syn::WherePredicate,
189) {
190    let mut seen = HashSet::new();
191
192    for variant in s.variants() {
193        for binding in variant.bindings() {
194            let ty = &binding.ast().ty;
195            if seen.insert(ty) && index_array(ty).is_none() {
196                clause.predicates.push(bound(ty));
197            }
198        }
199    }
200}
201
202/// Derives an `Iterate` implementation for an operator.
203fn derive_iterate(s: &synstructure::Structure) -> TokenStream {
204    let smt_ir = smt_ir_crate_path();
205
206    #[allow(non_snake_case)]
207    let (Term, Logic, Iterate, Args) = (
208        quote!(#smt_ir::Term),
209        quote!(#smt_ir::Logic),
210        quote!(#smt_ir::term::args::Iterate),
211        quote!(#smt_ir::term::args::Arguments),
212    );
213
214    fn argument_iter_branches(
215        s: &synstructure::Structure,
216        mut iterate: impl FnMut(&synstructure::BindingInfo) -> TokenStream,
217    ) -> TokenStream {
218        s.each_variant(|v| {
219            let mut bindings = (v.bindings().iter())
220                .skip_while(|field| index_array(&field.ast().ty).is_some())
221                .map(&mut iterate);
222            let mut iter = bindings
223                .next()
224                .unwrap_or_else(|| quote!(std::iter::empty()));
225            for new in bindings {
226                iter = quote!(#iter.chain(#new))
227            }
228            // TODO: instead of boxing, could also make an enum -- might be worth it
229            quote!(std::boxed::Box::new(#iter))
230        })
231    }
232
233    let mut where_clause = syn::WhereClause {
234        where_token: Default::default(),
235        predicates: Default::default(),
236    };
237
238    bound_argument_fields(
239        s,
240        &mut where_clause,
241        |ty| parse_quote!(#ty: #Iterate<'a, L>),
242    );
243
244    let args_branches = argument_iter_branches(s, |field| quote!(#Iterate::<L>::terms(#field)));
245    let into_args_branches =
246        argument_iter_branches(s, |field| quote!(#Iterate::<L>::into_terms(#field)));
247
248    s.gen_impl(quote! {
249        gen impl<'a, L: #Logic> #Iterate<'a, L> for @Self
250        #where_clause
251        {
252            type Terms = std::boxed::Box<dyn std::iter::Iterator<Item = &'a #Term<L>> + 'a>;
253            type IntoTerms = std::boxed::Box<dyn std::iter::Iterator<Item = #Term<L>> + 'a>;
254
255            fn terms(&'a self) -> Self::Terms {
256                match self {
257                    #args_branches
258                }
259            }
260
261            fn into_terms(self) -> Self::IntoTerms {
262                match self {
263                    #into_args_branches
264                }
265            }
266        }
267
268        gen impl<'a, L: #Logic> #Args<'a, L> for @Self #where_clause {}
269    })
270}
271
272/// Derives a `Debug` or `Display` implementation for an operator enum depending on the trait path
273/// passed as `trait_path`.
274fn derive_fmt_any(s: &synstructure::Structure, trait_path: syn::Path) -> TokenStream {
275    let smt_ir = smt_ir_crate_path();
276
277    #[allow(non_snake_case)]
278    let Format = quote!(#smt_ir::term::args::Format);
279
280    let fmt_body = s.each_variant(|variant| {
281        let symbol = variant_symbol(variant);
282        let bindings = variant.bindings();
283        if bindings.is_empty() {
284            quote!(std::write!(f, #symbol))
285        } else {
286            let mut fmt_indices = None;
287            let fmt_fields: Vec<_> = bindings
288                .iter()
289                .filter_map(|field| {
290                    if index_array(&field.ast().ty).is_some() {
291                        fmt_indices = Some(quote! {
292                            for index in #field {
293                                std::write!(f, " {}", index)?;
294                            }
295                        });
296                        None
297                    } else {
298                        Some(quote! {
299                            std::write!(f, " ")?;
300                            #Format::fmt(#field, f, #trait_path::fmt)
301                        })
302                    }
303                })
304                .collect();
305            let fmt_func = if let Some(fmt_indices) = fmt_indices {
306                quote! {
307                    std::write!(f, "(_ {}", #symbol)?;
308                    #fmt_indices
309                    std::write!(f, ")")
310                }
311            } else {
312                quote!(std::write!(f, #symbol))
313            };
314            quote! {
315                std::write!(f, "(")?;
316                #fmt_func?;
317                #(#fmt_fields?;)*
318                std::write!(f, ")")
319            }
320        }
321    });
322
323    let mut where_clause = None;
324    s.add_trait_bounds(
325        &parse_quote!(std::fmt::Debug),
326        &mut where_clause,
327        synstructure::AddBounds::Generics,
328    );
329    s.add_trait_bounds(
330        &parse_quote!(std::fmt::Display),
331        &mut where_clause,
332        synstructure::AddBounds::Generics,
333    );
334    s.add_trait_bounds(
335        &parse_quote!(#Format),
336        &mut where_clause,
337        synstructure::AddBounds::Generics,
338    );
339    s.gen_impl(quote! {
340        extern crate std;
341        gen impl #trait_path for @Self #where_clause {
342            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343                match self {
344                    #fmt_body
345                }
346            }
347        }
348    })
349}
350
351/// Derives `Fold` and `SuperFold` for an operator enum e.g. `Foo<Term>`.
352fn derive_fold(mut s: synstructure::Structure) -> TokenStream {
353    let smt_ir = smt_ir_crate_path();
354    let name = &s.ast().ident; // E.g. `Foo`
355
356    #[allow(non_snake_case)]
357    let (Logic, Fold, SuperFold, Folder) = (
358        quote!(#smt_ir::Logic),
359        quote!(#smt_ir::fold::Fold),
360        quote!(#smt_ir::fold::SuperFold),
361        quote!(#smt_ir::fold::Folder),
362    );
363
364    s.add_bounds(synstructure::AddBounds::None)
365        .bind_with(|_| synstructure::BindStyle::Move);
366
367    let impl_fold = s.gen_impl(quote! {
368        extern crate std;
369        gen impl<L: #Logic<Op = Self>, Out> #Fold<L, Out> for @Self {
370            type Output = Out;
371
372            fn fold_with<F, M>(
373                self,
374                folder: &mut F,
375            ) -> std::result::Result<Self::Output, F::Error>
376            where
377                F: #Folder<L, M, Output = Out>,
378            {
379                folder.fold_theory_op(self.into())
380            }
381        }
382    });
383
384    let impl_super_fold = {
385        // Bound each generic parameter to implement `Fold`
386        let mut where_clause = None;
387        s.add_trait_bounds(
388            &parse_quote!(#Fold<L, Out>),
389            &mut where_clause,
390            synstructure::AddBounds::Generics,
391        );
392
393        // For each variant, construct a new version by folding each of the fields
394        let body = s.each_variant(|vi| {
395            vi.construct(|_, idx| {
396                let field = &vi.bindings()[idx];
397                quote!(#Fold::fold_with(#field, folder)?)
398            })
399        });
400
401        // If input type is `Foo<A, B>`, output `Foo<<A as Fold>::Output, <B as Fold>::Output>`
402        let out_params = s
403            .referenced_ty_params()
404            .into_iter()
405            .map(|ty| quote!(<#ty as #Fold<L, Out>>::Output));
406
407        s.gen_impl(quote! {
408            extern crate std;
409            gen impl<L: #Logic, Out> #SuperFold<L, Out> for @Self #where_clause {
410                type Output = #name<#(#out_params),*>;
411
412                fn super_fold_with<F, M>(
413                    self,
414                    folder: &mut F,
415                ) -> std::result::Result<Self::Output, F::Error>
416                where
417                    F: #Folder<L, M, Output = Out>,
418                {
419                    Ok(match self { #body })
420                }
421            }
422        })
423    };
424
425    quote! {
426        #impl_fold
427        #impl_super_fold
428    }
429}
430
431/// For a given operator `O`, derives `From<O>` for `Op`.
432fn derive_from(s: &synstructure::Structure) -> TokenStream {
433    let smt_ir = smt_ir_crate_path();
434    let name = &s.ast().ident;
435    let params = s.referenced_ty_params();
436    let ty = quote!(#name<#(#params),*>);
437
438    #[allow(non_snake_case)]
439    let (From, Into, Logic, IOp, Term) = (
440        quote!(std::convert::From),
441        quote!(std::convert::Into),
442        quote!(#smt_ir::Logic),
443        quote!(#smt_ir::IOp),
444        quote!(#smt_ir::Term),
445    );
446
447    if has_core_op_attr(s.ast()) {
448        quote! {
449            // impl<#(#params,)* L: #Logic> #From<#ty> for #Term<L> {
450            //     fn from(op: #ty) -> Self {
451            //         Self::CoreOp(op)
452            //     }
453            // }
454        }
455    } else {
456        quote! {
457            impl<#(#params,)* L: #Logic> #From<#ty> for #IOp<L>
458            where
459                #ty: #Into<L::Op>,
460            {
461                fn from(op: #ty) -> Self {
462                    #IOp::new(op.into())
463                }
464            }
465            impl<#(#params,)* L: #Logic> #From<#ty> for #Term<L>
466            where
467                #ty: #Into<L::Op>,
468            {
469                fn from(op: #ty) -> Self {
470                    let op: L::Op = op.into();
471                    Self::OtherOp(op.into())
472                }
473            }
474        }
475    }
476}
477
478/// Derives `Visit` and `SuperVisit` for an operator enum e.g. `Foo<Term>`.
479fn derive_visit(mut s: synstructure::Structure) -> TokenStream {
480    let smt_ir = smt_ir_crate_path();
481
482    #[allow(non_snake_case)]
483    let (Logic, Visit, SuperVisit, Visitor, ControlFlow) = (
484        quote!(#smt_ir::Logic),
485        quote!(#smt_ir::visit::Visit),
486        quote!(#smt_ir::visit::SuperVisit),
487        quote!(#smt_ir::visit::Visitor),
488        quote!(#smt_ir::visit::ControlFlow),
489    );
490
491    s.add_bounds(synstructure::AddBounds::None);
492
493    let impl_super_visit = {
494        // Bound each field to implement `Visit`
495        let mut where_clause = None;
496        s.add_trait_bounds(
497            &parse_quote!(#Visit<L>),
498            &mut where_clause,
499            synstructure::AddBounds::Fields,
500        );
501
502        // For each variant, visit each of the fields
503        let body = s.each(|field| quote!(#smt_ir::try_break!(#Visit::visit_with(#field, visitor))));
504
505        s.gen_impl(quote! {
506            gen impl<L: #Logic> #SuperVisit<L> for @Self #where_clause {
507                fn super_visit_with<V: #Visitor<L>>(
508                    &self,
509                    visitor: &mut V,
510                ) -> #ControlFlow<V::BreakTy> {
511                    match self { #body }
512                    #ControlFlow::Continue(())
513                }
514            }
515        })
516    };
517
518    quote! {
519        #impl_super_visit
520    }
521}