1use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use std::collections::HashSet;
6use syn::{parse_quote, DeriveInput};
7use synstructure::decl_derive;
8
9decl_derive!([Fold] => derive_fold);
10decl_derive!([Visit] => derive_visit);
11decl_derive!([Operation, attributes(core_op, symbol)] => derive_operation_all);
12
13fn smt_ir_crate_path() -> syn::Path {
15 match proc_macro_crate::crate_name("aws-smt-ir").expect("must depend on aws-smt-ir") {
16 proc_macro_crate::FoundCrate::Itself => parse_quote!(aws_smt_ir),
17 proc_macro_crate::FoundCrate::Name(name) => {
18 let name = syn::Ident::new(&name, proc_macro2::Span::call_site());
19 parse_quote!(#name)
20 }
21 }
22}
23
24fn has_core_op_attr(input: &DeriveInput) -> bool {
26 input.attrs.iter().any(|a| a.path().is_ident("core_op"))
27}
28
29fn variant_symbol(variant: &synstructure::VariantInfo) -> syn::LitStr {
32 let ast = variant.ast();
33 (ast.attrs.iter())
34 .find(|attr| attr.path().is_ident("symbol"))
35 .and_then(|attr| attr.parse_args().ok())
36 .unwrap_or_else(|| {
37 let name = ast.ident.to_string().to_lowercase();
38 parse_quote!(#name)
39 })
40}
41
42fn derive_operation_all(mut s: synstructure::Structure) -> TokenStream {
45 s.add_bounds(synstructure::AddBounds::None)
46 .bind_with(|_| synstructure::BindStyle::Move);
47 let debug = derive_fmt_any(&s, parse_quote!(std::fmt::Debug));
48 let display = derive_fmt_any(&s, parse_quote!(std::fmt::Display));
49 let from = derive_from(&s);
50 let operation = derive_operation(&s);
51 let iterate = derive_iterate(&s);
52 quote! {
53 #debug
54 #display
55 #from
56 #operation
57 #iterate
58 }
59}
60
61fn index_array(ty: &syn::Type) -> Option<&syn::Expr> {
62 match ty {
63 syn::Type::Array(syn::TypeArray { len, .. }) => Some(len),
64 _ => None,
65 }
66}
67
68fn derive_operation(s: &synstructure::Structure) -> TokenStream {
70 let smt_ir = smt_ir_crate_path();
71
72 #[allow(non_snake_case)]
73 let (Term, Logic, Operation, Parse, NumArgs, InvalidOp, QualIdentifier, IIndex, TryFrom, Vec) = (
74 quote!(#smt_ir::Term),
75 quote!(#smt_ir::Logic),
76 quote!(#smt_ir::term::Operation),
77 quote!(#smt_ir::term::args::Parse),
78 quote!(#smt_ir::term::args::NumArgs),
79 quote!(#smt_ir::term::InvalidOp),
80 quote!(#smt_ir::QualIdentifier),
81 quote!(#smt_ir::IIndex),
82 quote!(std::convert::TryFrom),
83 quote!(std::vec::Vec),
84 );
85
86 let mut bindings = vec![];
87
88 let parse_match_arms: Vec<_> = (s.variants().iter())
89 .enumerate()
90 .map(|(idx, variant)| {
91 let symbol = variant_symbol(variant);
92 let mut num_indices = None;
93 let mut min_args = vec![];
94 let mut max_args = vec![];
95
96 let constructed = variant.construct(|field, _| {
98 let ty = &field.ty;
99
100 if let Some(len) = index_array(ty) {
102 num_indices = Some(len.clone());
103 quote! {{
104 let indices: std::vec::Vec<_> = func.indices().iter().map(#IIndex::from).collect();
105 #TryFrom::try_from(indices).unwrap()
106 }}
107 } else {
108 min_args.push(quote!(<#ty as #NumArgs>::MIN_ARGS));
109 max_args.push(quote!(<#ty as #NumArgs>::MAX_ARGS));
110 quote!(#Parse::from_iter(&mut iter).unwrap())
111 }
112 });
113
114 let min_args = quote!((0 #(+ #min_args)*));
115 let max_args = quote!((0 #(+ #max_args)*));
116 let num_indices = num_indices.unwrap_or_else(|| parse_quote!(0));
117 let min_args_ident = syn::Ident::new(&format!("MIN_ARGS_{}", idx), Span::call_site());
118 let max_args_ident = syn::Ident::new(&format!("MAX_ARGS_{}", idx), Span::call_site());
119 let num_indices_ident = syn::Ident::new(&format!("INDICES_{}", idx), Span::call_site());
120 bindings.push(quote!(let #min_args_ident = #min_args;));
121 bindings.push(quote!(let #max_args_ident = #max_args;));
122 bindings.push(quote!(let #num_indices_ident = #num_indices;));
123
124 quote! {
127 (#symbol, num_args) if func.indices().len() == #num_indices_ident && (#min_args_ident..=#max_args_ident).contains(&num_args) => {
128 let mut iter = args.into_iter();
129 #constructed
130 }
131 }
132 })
133 .collect();
134
135 let parse_fn = quote! {
136 fn parse(func: #QualIdentifier, args: #Vec<#Term<L>>) -> std::result::Result<Self, #InvalidOp<L>> {
137 #(#bindings)*
138 #[deny(unreachable_patterns)]
139 Ok(match (func.sym_str(), args.len()) {
140 #(#parse_match_arms)*
141 _ => return Err(#InvalidOp { func, args })
142 })
143 }
144 };
145
146 let func_match_arms = s.each_variant(|variant| {
147 let symbol = variant_symbol(variant);
148 quote!(#symbol.into())
149 });
150
151 let func_fn = quote! {
152 fn func(&self) -> #smt_ir::ISymbol {
153 match self {
154 #func_match_arms
155 }
156 }
157 };
158
159 let mut where_clause = None;
160 s.add_trait_bounds(
161 &parse_quote!(#Parse<L>),
162 &mut where_clause,
163 synstructure::AddBounds::Fields,
164 );
165 if has_core_op_attr(s.ast()) {
166 s.gen_impl(quote! {
167 gen impl<L: #Logic> #Operation<L> for @Self
168 #where_clause,
169 <L as #Logic>::Op: #Operation<L>,
170 {
171 #parse_fn
172 #func_fn
173 }
174 })
175 } else {
176 s.gen_impl(quote! {
177 gen impl<L: #Logic> #Operation<L> for @Self #where_clause {
178 #parse_fn
179 #func_fn
180 }
181 })
182 }
183}
184
185fn bound_argument_fields(
186 s: &synstructure::Structure,
187 clause: &mut syn::WhereClause,
188 mut bound: impl FnMut(&syn::Type) -> syn::WherePredicate,
189) {
190 let mut seen = HashSet::new();
191
192 for variant in s.variants() {
193 for binding in variant.bindings() {
194 let ty = &binding.ast().ty;
195 if seen.insert(ty) && index_array(ty).is_none() {
196 clause.predicates.push(bound(ty));
197 }
198 }
199 }
200}
201
202fn derive_iterate(s: &synstructure::Structure) -> TokenStream {
204 let smt_ir = smt_ir_crate_path();
205
206 #[allow(non_snake_case)]
207 let (Term, Logic, Iterate, Args) = (
208 quote!(#smt_ir::Term),
209 quote!(#smt_ir::Logic),
210 quote!(#smt_ir::term::args::Iterate),
211 quote!(#smt_ir::term::args::Arguments),
212 );
213
214 fn argument_iter_branches(
215 s: &synstructure::Structure,
216 mut iterate: impl FnMut(&synstructure::BindingInfo) -> TokenStream,
217 ) -> TokenStream {
218 s.each_variant(|v| {
219 let mut bindings = (v.bindings().iter())
220 .skip_while(|field| index_array(&field.ast().ty).is_some())
221 .map(&mut iterate);
222 let mut iter = bindings
223 .next()
224 .unwrap_or_else(|| quote!(std::iter::empty()));
225 for new in bindings {
226 iter = quote!(#iter.chain(#new))
227 }
228 quote!(std::boxed::Box::new(#iter))
230 })
231 }
232
233 let mut where_clause = syn::WhereClause {
234 where_token: Default::default(),
235 predicates: Default::default(),
236 };
237
238 bound_argument_fields(
239 s,
240 &mut where_clause,
241 |ty| parse_quote!(#ty: #Iterate<'a, L>),
242 );
243
244 let args_branches = argument_iter_branches(s, |field| quote!(#Iterate::<L>::terms(#field)));
245 let into_args_branches =
246 argument_iter_branches(s, |field| quote!(#Iterate::<L>::into_terms(#field)));
247
248 s.gen_impl(quote! {
249 gen impl<'a, L: #Logic> #Iterate<'a, L> for @Self
250 #where_clause
251 {
252 type Terms = std::boxed::Box<dyn std::iter::Iterator<Item = &'a #Term<L>> + 'a>;
253 type IntoTerms = std::boxed::Box<dyn std::iter::Iterator<Item = #Term<L>> + 'a>;
254
255 fn terms(&'a self) -> Self::Terms {
256 match self {
257 #args_branches
258 }
259 }
260
261 fn into_terms(self) -> Self::IntoTerms {
262 match self {
263 #into_args_branches
264 }
265 }
266 }
267
268 gen impl<'a, L: #Logic> #Args<'a, L> for @Self #where_clause {}
269 })
270}
271
272fn derive_fmt_any(s: &synstructure::Structure, trait_path: syn::Path) -> TokenStream {
275 let smt_ir = smt_ir_crate_path();
276
277 #[allow(non_snake_case)]
278 let Format = quote!(#smt_ir::term::args::Format);
279
280 let fmt_body = s.each_variant(|variant| {
281 let symbol = variant_symbol(variant);
282 let bindings = variant.bindings();
283 if bindings.is_empty() {
284 quote!(std::write!(f, #symbol))
285 } else {
286 let mut fmt_indices = None;
287 let fmt_fields: Vec<_> = bindings
288 .iter()
289 .filter_map(|field| {
290 if index_array(&field.ast().ty).is_some() {
291 fmt_indices = Some(quote! {
292 for index in #field {
293 std::write!(f, " {}", index)?;
294 }
295 });
296 None
297 } else {
298 Some(quote! {
299 std::write!(f, " ")?;
300 #Format::fmt(#field, f, #trait_path::fmt)
301 })
302 }
303 })
304 .collect();
305 let fmt_func = if let Some(fmt_indices) = fmt_indices {
306 quote! {
307 std::write!(f, "(_ {}", #symbol)?;
308 #fmt_indices
309 std::write!(f, ")")
310 }
311 } else {
312 quote!(std::write!(f, #symbol))
313 };
314 quote! {
315 std::write!(f, "(")?;
316 #fmt_func?;
317 #(#fmt_fields?;)*
318 std::write!(f, ")")
319 }
320 }
321 });
322
323 let mut where_clause = None;
324 s.add_trait_bounds(
325 &parse_quote!(std::fmt::Debug),
326 &mut where_clause,
327 synstructure::AddBounds::Generics,
328 );
329 s.add_trait_bounds(
330 &parse_quote!(std::fmt::Display),
331 &mut where_clause,
332 synstructure::AddBounds::Generics,
333 );
334 s.add_trait_bounds(
335 &parse_quote!(#Format),
336 &mut where_clause,
337 synstructure::AddBounds::Generics,
338 );
339 s.gen_impl(quote! {
340 extern crate std;
341 gen impl #trait_path for @Self #where_clause {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 match self {
344 #fmt_body
345 }
346 }
347 }
348 })
349}
350
351fn derive_fold(mut s: synstructure::Structure) -> TokenStream {
353 let smt_ir = smt_ir_crate_path();
354 let name = &s.ast().ident; #[allow(non_snake_case)]
357 let (Logic, Fold, SuperFold, Folder) = (
358 quote!(#smt_ir::Logic),
359 quote!(#smt_ir::fold::Fold),
360 quote!(#smt_ir::fold::SuperFold),
361 quote!(#smt_ir::fold::Folder),
362 );
363
364 s.add_bounds(synstructure::AddBounds::None)
365 .bind_with(|_| synstructure::BindStyle::Move);
366
367 let impl_fold = s.gen_impl(quote! {
368 extern crate std;
369 gen impl<L: #Logic<Op = Self>, Out> #Fold<L, Out> for @Self {
370 type Output = Out;
371
372 fn fold_with<F, M>(
373 self,
374 folder: &mut F,
375 ) -> std::result::Result<Self::Output, F::Error>
376 where
377 F: #Folder<L, M, Output = Out>,
378 {
379 folder.fold_theory_op(self.into())
380 }
381 }
382 });
383
384 let impl_super_fold = {
385 let mut where_clause = None;
387 s.add_trait_bounds(
388 &parse_quote!(#Fold<L, Out>),
389 &mut where_clause,
390 synstructure::AddBounds::Generics,
391 );
392
393 let body = s.each_variant(|vi| {
395 vi.construct(|_, idx| {
396 let field = &vi.bindings()[idx];
397 quote!(#Fold::fold_with(#field, folder)?)
398 })
399 });
400
401 let out_params = s
403 .referenced_ty_params()
404 .into_iter()
405 .map(|ty| quote!(<#ty as #Fold<L, Out>>::Output));
406
407 s.gen_impl(quote! {
408 extern crate std;
409 gen impl<L: #Logic, Out> #SuperFold<L, Out> for @Self #where_clause {
410 type Output = #name<#(#out_params),*>;
411
412 fn super_fold_with<F, M>(
413 self,
414 folder: &mut F,
415 ) -> std::result::Result<Self::Output, F::Error>
416 where
417 F: #Folder<L, M, Output = Out>,
418 {
419 Ok(match self { #body })
420 }
421 }
422 })
423 };
424
425 quote! {
426 #impl_fold
427 #impl_super_fold
428 }
429}
430
431fn derive_from(s: &synstructure::Structure) -> TokenStream {
433 let smt_ir = smt_ir_crate_path();
434 let name = &s.ast().ident;
435 let params = s.referenced_ty_params();
436 let ty = quote!(#name<#(#params),*>);
437
438 #[allow(non_snake_case)]
439 let (From, Into, Logic, IOp, Term) = (
440 quote!(std::convert::From),
441 quote!(std::convert::Into),
442 quote!(#smt_ir::Logic),
443 quote!(#smt_ir::IOp),
444 quote!(#smt_ir::Term),
445 );
446
447 if has_core_op_attr(s.ast()) {
448 quote! {
449 }
455 } else {
456 quote! {
457 impl<#(#params,)* L: #Logic> #From<#ty> for #IOp<L>
458 where
459 #ty: #Into<L::Op>,
460 {
461 fn from(op: #ty) -> Self {
462 #IOp::new(op.into())
463 }
464 }
465 impl<#(#params,)* L: #Logic> #From<#ty> for #Term<L>
466 where
467 #ty: #Into<L::Op>,
468 {
469 fn from(op: #ty) -> Self {
470 let op: L::Op = op.into();
471 Self::OtherOp(op.into())
472 }
473 }
474 }
475 }
476}
477
478fn derive_visit(mut s: synstructure::Structure) -> TokenStream {
480 let smt_ir = smt_ir_crate_path();
481
482 #[allow(non_snake_case)]
483 let (Logic, Visit, SuperVisit, Visitor, ControlFlow) = (
484 quote!(#smt_ir::Logic),
485 quote!(#smt_ir::visit::Visit),
486 quote!(#smt_ir::visit::SuperVisit),
487 quote!(#smt_ir::visit::Visitor),
488 quote!(#smt_ir::visit::ControlFlow),
489 );
490
491 s.add_bounds(synstructure::AddBounds::None);
492
493 let impl_super_visit = {
494 let mut where_clause = None;
496 s.add_trait_bounds(
497 &parse_quote!(#Visit<L>),
498 &mut where_clause,
499 synstructure::AddBounds::Fields,
500 );
501
502 let body = s.each(|field| quote!(#smt_ir::try_break!(#Visit::visit_with(#field, visitor))));
504
505 s.gen_impl(quote! {
506 gen impl<L: #Logic> #SuperVisit<L> for @Self #where_clause {
507 fn super_visit_with<V: #Visitor<L>>(
508 &self,
509 visitor: &mut V,
510 ) -> #ControlFlow<V::BreakTy> {
511 match self { #body }
512 #ControlFlow::Continue(())
513 }
514 }
515 })
516 };
517
518 quote! {
519 #impl_super_visit
520 }
521}