walrus_macro/
lib.rs

1#![recursion_limit = "256"]
2
3extern crate proc_macro;
4
5use self::proc_macro::TokenStream;
6use heck::ToSnakeCase;
7use proc_macro2::Span;
8use quote::quote;
9use syn::ext::IdentExt;
10use syn::parse::{Parse, ParseStream};
11use syn::punctuated::Punctuated;
12use syn::DeriveInput;
13use syn::Error;
14use syn::{parse_macro_input, Ident, Result, Token};
15
16#[proc_macro_attribute]
17pub fn walrus_instr(_attr: TokenStream, input: TokenStream) -> TokenStream {
18    let input = parse_macro_input!(input as DeriveInput);
19
20    let variants = match get_enum_variants(&input) {
21        Ok(v) => v,
22        Err(e) => return e.to_compile_error().into(),
23    };
24
25    assert_eq!(input.ident.to_string(), "Instr");
26
27    let types = create_types(&input.attrs, &variants);
28    let visit = create_visit(&variants);
29    let builder = create_builder(&variants);
30
31    let expanded = quote! {
32        #types
33        #visit
34        #builder
35    };
36
37    TokenStream::from(expanded)
38}
39
40struct WalrusVariant {
41    syn: syn::Variant,
42    fields: Vec<WalrusFieldOpts>,
43    opts: WalrusVariantOpts,
44}
45
46#[derive(Default)]
47struct WalrusVariantOpts {
48    display_name: Option<syn::Ident>,
49    display_extra: Option<syn::Ident>,
50    skip_builder: bool,
51}
52
53#[derive(Default)]
54struct WalrusFieldOpts {
55    skip_visit: bool,
56}
57
58fn get_enum_variants(input: &DeriveInput) -> Result<Vec<WalrusVariant>> {
59    let en = match &input.data {
60        syn::Data::Enum(en) => en,
61        syn::Data::Struct(_) => {
62            panic!("can only put #[walrus_instr] on an enum; found it on a struct")
63        }
64        syn::Data::Union(_) => {
65            panic!("can only put #[walrus_instr] on an enum; found it on a union")
66        }
67    };
68    en.variants
69        .iter()
70        .cloned()
71        .map(|mut variant| {
72            Ok(WalrusVariant {
73                opts: syn::parse(walrus_attrs(&mut variant.attrs))?,
74                fields: variant
75                    .fields
76                    .iter_mut()
77                    .map(|field| syn::parse(walrus_attrs(&mut field.attrs)))
78                    .collect::<Result<_>>()?,
79                syn: variant,
80            })
81        })
82        .collect()
83}
84
85impl Parse for WalrusFieldOpts {
86    fn parse(input: ParseStream) -> Result<Self> {
87        enum Attr {
88            SkipVisit,
89        }
90
91        let attrs = Punctuated::<_, syn::token::Comma>::parse_terminated(input)?;
92        let mut ret = WalrusFieldOpts::default();
93        for attr in attrs {
94            match attr {
95                Attr::SkipVisit => ret.skip_visit = true,
96            }
97        }
98        return Ok(ret);
99
100        impl Parse for Attr {
101            fn parse(input: ParseStream) -> Result<Self> {
102                let attr: Ident = input.parse()?;
103                if attr == "skip_visit" {
104                    return Ok(Attr::SkipVisit);
105                }
106                Err(Error::new(attr.span(), "unexpected attribute"))
107            }
108        }
109    }
110}
111
112impl Parse for WalrusVariantOpts {
113    fn parse(input: ParseStream) -> Result<Self> {
114        enum Attr {
115            DisplayName(syn::Ident),
116            DisplayExtra(syn::Ident),
117            SkipBuilder,
118        }
119
120        let attrs = Punctuated::<_, syn::token::Comma>::parse_terminated(input)?;
121        let mut ret = WalrusVariantOpts::default();
122        for attr in attrs {
123            match attr {
124                Attr::DisplayName(ident) => ret.display_name = Some(ident),
125                Attr::DisplayExtra(ident) => ret.display_extra = Some(ident),
126                Attr::SkipBuilder => ret.skip_builder = true,
127            }
128        }
129        return Ok(ret);
130
131        impl Parse for Attr {
132            fn parse(input: ParseStream) -> Result<Self> {
133                let attr: Ident = input.parse()?;
134                if attr == "display_name" {
135                    input.parse::<Token![=]>()?;
136                    let name = input.call(Ident::parse_any)?;
137                    return Ok(Attr::DisplayName(name));
138                }
139                if attr == "display_extra" {
140                    input.parse::<Token![=]>()?;
141                    let name = input.call(Ident::parse_any)?;
142                    return Ok(Attr::DisplayExtra(name));
143                }
144                if attr == "skip_builder" {
145                    return Ok(Attr::SkipBuilder);
146                }
147                Err(Error::new(attr.span(), "unexpected attribute"))
148            }
149        }
150    }
151}
152
153fn walrus_attrs(attrs: &mut Vec<syn::Attribute>) -> TokenStream {
154    let mut ret = proc_macro2::TokenStream::new();
155    let ident = syn::Path::from(syn::Ident::new("walrus", Span::call_site()));
156    for i in (0..attrs.len()).rev() {
157        if attrs[i].path() != &ident {
158            continue;
159        }
160        let attr = attrs.remove(i);
161        let group = if let syn::Meta::List(syn::MetaList { tokens, .. }) = attr.meta {
162            tokens
163        } else {
164            panic!("#[walrus(...)] expected")
165        };
166        ret.extend(group);
167        ret.extend(quote! { , });
168    }
169    ret.into()
170}
171
172fn create_types(attrs: &[syn::Attribute], variants: &[WalrusVariant]) -> impl quote::ToTokens {
173    let types: Vec<_> = variants
174        .iter()
175        .map(|v| {
176            let name = &v.syn.ident;
177            let attrs = &v.syn.attrs;
178            let fields = v.syn.fields.iter().map(|f| {
179                let name = &f.ident;
180                let attrs = &f.attrs;
181                let ty = &f.ty;
182                quote! {
183                    #( #attrs )*
184                    pub #name : #ty,
185                }
186            });
187            quote! {
188                #( #attrs )*
189                #[derive(Clone, Debug)]
190                pub struct #name {
191                    #( #fields )*
192                }
193
194                impl From<#name> for Instr {
195                    #[inline]
196                    fn from(x: #name) -> Instr {
197                        Instr::#name(x)
198                    }
199                }
200            }
201        })
202        .collect();
203
204    let methods: Vec<_> = variants
205        .iter()
206        .map(|v| {
207            let name = &v.syn.ident;
208            let snake_name = name.to_string().to_snake_case();
209
210            let is_name = format!("is_{}", snake_name);
211            let is_name = syn::Ident::new(&is_name, Span::call_site());
212
213            let ref_name = format!("{}_ref", snake_name);
214            let ref_name = syn::Ident::new(&ref_name, Span::call_site());
215
216            let mut_name = format!("{}_mut", snake_name);
217            let mut_name = syn::Ident::new(&mut_name, Span::call_site());
218
219            let unwrap_name = format!("unwrap_{}", snake_name);
220            let unwrap_name = syn::Ident::new(&unwrap_name, Span::call_site());
221
222            let unwrap_mut_name = format!("unwrap_{}_mut", snake_name);
223            let unwrap_mut_name = syn::Ident::new(&unwrap_mut_name, Span::call_site());
224
225            let ref_name_doc = format!(
226                "
227                If this instruction is a `{}`, get a shared reference to it.
228
229                Returns `None` otherwise.
230            ",
231                name
232            );
233
234            let mut_name_doc = format!(
235                "
236                If this instruction is a `{}`, get an exclusive reference to it.
237
238                Returns `None` otherwise.
239            ",
240                name
241            );
242
243            let is_name_doc = format!("Is this instruction a `{}`?", name);
244
245            let unwrap_name_doc = format!(
246                "
247                Get a shared reference to the underlying `{}`.
248
249                Panics if this instruction is not a `{}`.
250            ",
251                name, name
252            );
253
254            let unwrap_mut_name_doc = format!(
255                "
256                Get an exclusive reference to the underlying `{}`.
257
258                Panics if this instruction is not a `{}`.
259            ",
260                name, name
261            );
262
263            quote! {
264                #[doc=#ref_name_doc]
265                #[inline]
266                fn #ref_name(&self) -> Option<&#name> {
267                    if let Instr::#name(ref x) = *self {
268                        Some(x)
269                    } else {
270                        None
271                    }
272                }
273
274                #[doc=#mut_name_doc]
275                #[inline]
276                pub fn #mut_name(&mut self) -> Option<&mut #name> {
277                    if let Instr::#name(ref mut x) = *self {
278                        Some(x)
279                    } else {
280                        None
281                    }
282                }
283
284                #[doc=#is_name_doc]
285                #[inline]
286                pub fn #is_name(&self) -> bool {
287                    self.#ref_name().is_some()
288                }
289
290                #[doc=#unwrap_name_doc]
291                #[inline]
292                pub fn #unwrap_name(&self) -> &#name {
293                    self.#ref_name().unwrap()
294                }
295
296                #[doc=#unwrap_mut_name_doc]
297                #[inline]
298                pub fn #unwrap_mut_name(&mut self) -> &mut #name {
299                    self.#mut_name().unwrap()
300                }
301            }
302        })
303        .collect();
304
305    let variants: Vec<_> = variants
306        .iter()
307        .map(|v| {
308            let name = &v.syn.ident;
309            let attrs = &v.syn.attrs;
310            quote! {
311                #( #attrs )*
312                #name(#name)
313            }
314        })
315        .collect();
316
317    quote! {
318        #( #types )*
319
320        #( #attrs )*
321        pub enum Instr {
322            #(#variants),*
323        }
324
325        impl Instr {
326            #( #methods )*
327        }
328    }
329}
330
331fn visit_fields(
332    variant: &WalrusVariant,
333    allow_skip: bool,
334) -> impl Iterator<Item = (syn::Ident, proc_macro2::TokenStream, bool)> + '_ {
335    return variant
336        .syn
337        .fields
338        .iter()
339        .zip(&variant.fields)
340        .enumerate()
341        .filter(move |(_, (_, info))| !allow_skip || !info.skip_visit)
342        .map(move |(i, (field, _info))| {
343            let field_name = match &field.ident {
344                Some(name) => quote! { #name },
345                None => quote! { #i },
346            };
347            let (ty_name, list) = extract_name_and_if_list(&field.ty);
348            let mut method_name = "visit_".to_string();
349            method_name.push_str(&ty_name.to_string().to_snake_case());
350            let method_name = syn::Ident::new(&method_name, Span::call_site());
351            (method_name, field_name, list)
352        });
353
354    fn extract_name_and_if_list(ty: &syn::Type) -> (&syn::Ident, bool) {
355        let path = match ty {
356            syn::Type::Path(p) => &p.path,
357            _ => panic!("field types must be paths"),
358        };
359        let segment = path.segments.last().unwrap();
360        let args = match &segment.arguments {
361            syn::PathArguments::None => return (&segment.ident, false),
362            syn::PathArguments::AngleBracketed(a) => &a.args,
363            _ => panic!("invalid path in #[walrus_instr]"),
364        };
365        let mut ty = match args.first().unwrap() {
366            syn::GenericArgument::Type(ty) => ty,
367            _ => panic!("invalid path in #[walrus_instr]"),
368        };
369        if let syn::Type::Slice(t) = ty {
370            ty = &t.elem;
371        }
372        match ty {
373            syn::Type::Path(p) => {
374                let segment = p.path.segments.last().unwrap();
375                (&segment.ident, true)
376            }
377            _ => panic!("invalid path in #[walrus_instr]"),
378        }
379    }
380}
381
382fn create_visit(variants: &[WalrusVariant]) -> impl quote::ToTokens {
383    let mut visit_impls = Vec::new();
384    let mut visitor_trait_methods = Vec::new();
385    let mut visitor_mut_trait_methods = Vec::new();
386    let mut visit_impl = Vec::new();
387    let mut visit_mut_impl = Vec::new();
388
389    for variant in variants {
390        let name = &variant.syn.ident;
391
392        let mut method_name = "visit_".to_string();
393        method_name.push_str(&name.to_string().to_snake_case());
394        let method_name = syn::Ident::new(&method_name, Span::call_site());
395        let method_name_mut = syn::Ident::new(&format!("{}_mut", method_name), Span::call_site());
396
397        let recurse_fields = visit_fields(variant, true).map(|(method_name, field_name, list)| {
398            if list {
399                quote! {
400                    for item in self.#field_name.iter() {
401                        visitor.#method_name(item);
402                    }
403                }
404            } else {
405                quote! {
406                    visitor.#method_name(&self.#field_name);
407                }
408            }
409        });
410        let recurse_fields_mut =
411            visit_fields(variant, true).map(|(method_name, field_name, list)| {
412                let name = format!("{}_mut", method_name);
413                let method_name = syn::Ident::new(&name, Span::call_site());
414                if list {
415                    quote! {
416                        for item in self.#field_name.iter_mut() {
417                            visitor.#method_name(item);
418                        }
419                    }
420                } else {
421                    quote! {
422                        visitor.#method_name(&mut self.#field_name);
423                    }
424                }
425            });
426
427        visit_impls.push(quote! {
428            impl<'instr> Visit<'instr> for #name {
429                #[inline]
430                fn visit<V: Visitor<'instr>>(&self, visitor: &mut V) {
431                    #(#recurse_fields);*
432                }
433            }
434            impl VisitMut for #name {
435                #[inline]
436                fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) {
437                    #(#recurse_fields_mut);*
438                }
439            }
440        });
441
442        let doc = format!("Visit `{}`.", name);
443        visitor_trait_methods.push(quote! {
444            #[doc=#doc]
445            #[inline]
446            fn #method_name(&mut self, instr: &#name) {
447                // ...
448            }
449        });
450        visitor_mut_trait_methods.push(quote! {
451            #[doc=#doc]
452            #[inline]
453            fn #method_name_mut(&mut self, instr: &mut #name) {
454                instr.visit_mut(self);
455            }
456        });
457
458        let mut method_name = "visit_".to_string();
459        method_name.push_str(&name.to_string().to_snake_case());
460        let method_name = syn::Ident::new(&method_name, Span::call_site());
461        visit_impl.push(quote! {
462            Instr::#name(e) => {
463                visitor.#method_name(e);
464                e.visit(visitor);
465            }
466        });
467        visit_mut_impl.push(quote! {
468            Instr::#name(e) => {
469                visitor.#method_name_mut(e);
470                e.visit_mut(visitor);
471            }
472        });
473    }
474
475    quote! {
476        /// A visitor is a set of callbacks that are called when a traversal
477        /// (such as `dfs_in_order`) is walking an instruction tree.
478        ///
479        /// ## Recursion
480        ///
481        /// Do *not* recursively get nested `InstrSeq`s for any `InstrSeqId` you
482        /// visit! You *will* blow the stack when processing large Wasm
483        /// files. `Visitor`s are _just_ heterogenously-typed callbacks, _not_
484        /// traversals themselves!
485        ///
486        /// Instead, use `walrus::ir::dfs_in_order` and other traversal drivers
487        /// that will walk the tree in a non-recursive, iterative fashion
488        ///
489        /// # Provided Methods
490        ///
491        /// Every `Visitor` trait method has a default, provided implementation
492        /// that does nothing.
493        pub trait Visitor<'instr>: Sized {
494            /// Called before the traversal will start visiting each of the
495            /// instructions an instruction sequence.
496            ///
497            /// The order in which instruction sequences are visited is defined
498            /// by the traversal function, e.g. `walrus::ir::dfs_in_order`.
499            #[inline]
500            fn start_instr_seq(&mut self, instr_seq: &'instr InstrSeq) {
501                // ...
502            }
503
504            /// Called after the traversal finishes visiting each of the
505            /// instructions in an instruction sequence.
506            #[inline]
507            fn end_instr_seq(&mut self, instr_seq: &'instr InstrSeq) {
508                // ...
509            }
510
511            /// Visit `Instr`.
512            #[inline]
513            fn visit_instr(&mut self, instr: &'instr Instr, instr_loc: &'instr InstrLocId) {
514                // ...
515            }
516
517            /// Visit `InstrSeqId`.
518            #[inline]
519            fn visit_instr_seq_id(&mut self, instr_seq_id: &InstrSeqId) {
520                // ...
521            }
522
523            /// Visit `LocalId`.
524            #[inline]
525            fn visit_local_id(&mut self, local: &crate::LocalId) {
526                // ...
527            }
528
529            /// Visit `MemoryId`.
530            #[inline]
531            fn visit_memory_id(&mut self, memory: &crate::MemoryId) {
532                // ...
533            }
534
535            /// Visit `TableId`.
536            #[inline]
537            fn visit_table_id(&mut self, table: &crate::TableId) {
538                // ...
539            }
540
541            /// Visit `GlobalId`.
542            #[inline]
543            fn visit_global_id(&mut self, global: &crate::GlobalId) {
544                // ...
545            }
546
547            /// Visit `FunctionId`.
548            #[inline]
549            fn visit_function_id(&mut self, function: &crate::FunctionId) {
550                // ...
551            }
552
553            /// Visit `DataId`.
554            #[inline]
555            fn visit_data_id(&mut self, function: &crate::DataId) {
556                // ...
557            }
558
559            /// Visit `TypeId`
560            #[inline]
561            fn visit_type_id(&mut self, ty: &crate::TypeId) {
562                // ...
563            }
564
565            /// Visit `ElementId`
566            #[inline]
567            fn visit_element_id(&mut self, elem: &crate::ElementId) {
568                // ...
569            }
570
571            /// Visit `TagId`
572            #[inline]
573            fn visit_tag_id(&mut self, tag: &crate::TagId) {
574                // ...
575            }
576
577            /// Visit `Value`.
578            #[inline]
579            fn visit_value(&mut self, value: &crate::ir::Value) {
580                // ...
581            }
582
583            #( #visitor_trait_methods )*
584        }
585
586        /// A mutable version of `Visitor`.
587        ///
588        /// See `Visitor`'s documentation for details.
589        pub trait VisitorMut: Sized {
590            /// Called before the traversal will start visiting each of the
591            /// instructions an instruction sequence.
592            ///
593            /// The order in which instruction sequences are visited is defined
594            /// by the traversal function, e.g. `walrus::ir::dfs_pre_order_mut`.
595            #[inline]
596            fn start_instr_seq_mut(&mut self, instr_seq: &mut InstrSeq) {
597                // ...
598            }
599
600            /// Called after the traversal finishes visiting each of the
601            /// instructions in an instruction sequence.
602            #[inline]
603            fn end_instr_seq_mut(&mut self, instr_seq: &mut InstrSeq) {
604                // ...
605            }
606
607            /// Visit `Instr`.
608            #[inline]
609            fn visit_instr_mut(&mut self, instr: &mut Instr, instr_loc: &mut InstrLocId) {
610                // ...
611            }
612
613            /// Visit `InstrSeqId`.
614            #[inline]
615            fn visit_instr_seq_id_mut(&mut self, instr_seq_id: &mut InstrSeqId) {
616                // ...
617            }
618
619            /// Visit `Local`.
620            #[inline]
621            fn visit_local_id_mut(&mut self, local: &mut crate::LocalId) {
622                // ...
623            }
624
625            /// Visit `Memory`.
626            #[inline]
627            fn visit_memory_id_mut(&mut self, memory: &mut crate::MemoryId) {
628                // ...
629            }
630
631            /// Visit `Table`.
632            #[inline]
633            fn visit_table_id_mut(&mut self, table: &mut crate::TableId) {
634                // ...
635            }
636
637            /// Visit `GlobalId`.
638            #[inline]
639            fn visit_global_id_mut(&mut self, global: &mut crate::GlobalId) {
640                // ...
641            }
642
643            /// Visit `FunctionId`.
644            #[inline]
645            fn visit_function_id_mut(&mut self, function: &mut crate::FunctionId) {
646                // ...
647            }
648
649            /// Visit `DataId`.
650            #[inline]
651            fn visit_data_id_mut(&mut self, function: &mut crate::DataId) {
652                // ...
653            }
654
655            /// Visit `TypeId`
656            #[inline]
657            fn visit_type_id_mut(&mut self, ty: &mut crate::TypeId) {
658                // ...
659            }
660
661            /// Visit `ElementId`
662            #[inline]
663            fn visit_element_id_mut(&mut self, elem: &mut crate::ElementId) {
664                // ...
665            }
666
667            /// Visit `TagId`
668            #[inline]
669            fn visit_tag_id_mut(&mut self, tag: &mut crate::TagId) {
670                // ...
671            }
672
673            /// Visit `Value`.
674            #[inline]
675            fn visit_value_mut(&mut self, value: &mut crate::ir::Value) {
676                // ...
677            }
678
679            #( #visitor_mut_trait_methods )*
680        }
681
682        impl<'instr> Visit<'instr> for Instr {
683            #[inline]
684            fn visit<V>(&self, visitor: &mut V) where V: Visitor<'instr> {
685                match self {
686                    #( #visit_impl )*
687                }
688            }
689        }
690
691        impl VisitMut for Instr {
692            #[inline]
693            fn visit_mut<V>(&mut self, visitor: &mut V) where V: VisitorMut {
694                match self {
695                    #( #visit_mut_impl )*
696                }
697            }
698        }
699
700        #( #visit_impls )*
701    }
702}
703
704fn create_builder(variants: &[WalrusVariant]) -> impl quote::ToTokens {
705    let mut builder_methods = Vec::new();
706    for variant in variants {
707        if variant.opts.skip_builder {
708            continue;
709        }
710
711        let name = &variant.syn.ident;
712
713        let mut method_name = name.to_string().to_snake_case();
714
715        let mut method_name_at = method_name.clone();
716        method_name_at.push_str("_at");
717        let method_name_at = syn::Ident::new(&method_name_at, Span::call_site());
718
719        if method_name == "return" || method_name == "const" {
720            method_name.push('_');
721        } else if method_name == "block" {
722            continue;
723        }
724        let method_name = syn::Ident::new(&method_name, Span::call_site());
725
726        let mut args = Vec::new();
727        let mut arg_names = Vec::new();
728
729        for field in variant.syn.fields.iter() {
730            let name = field.ident.as_ref().expect("can't have unnamed fields");
731            arg_names.push(name);
732            let ty = &field.ty;
733            args.push(quote! { #name: #ty });
734        }
735
736        let doc = format!(
737            "Push a new `{}` instruction onto this builder's block.",
738            name
739        );
740        let at_doc = format!(
741            "Splice a new `{}` instruction into this builder's block at the given index.\n\n\
742             # Panics\n\n\
743             Panics if `position > self.instrs.len()`.",
744            name
745        );
746
747        let arg_names = &arg_names;
748        let args = &args;
749
750        builder_methods.push(quote! {
751            #[inline]
752            #[doc=#doc]
753            pub fn #method_name(&mut self, #(#args),*) -> &mut Self {
754                self.instr(#name { #(#arg_names),* })
755            }
756
757            #[inline]
758            #[doc=#at_doc]
759            pub fn #method_name_at(&mut self, position: usize, #(#args),*) -> &mut Self {
760                self.instr_at(position, #name { #(#arg_names),* })
761            }
762        });
763    }
764    quote! {
765        #[allow(missing_docs)]
766        impl crate::InstrSeqBuilder<'_> {
767            #(#builder_methods)*
768        }
769    }
770}