1#![recursion_limit = "256"]
2
3extern crate proc_macro;
4
5use self::proc_macro::TokenStream;
6use heck::ToSnakeCase;
7use proc_macro2::Span;
8use quote::quote;
9use syn::ext::IdentExt;
10use syn::parse::{Parse, ParseStream};
11use syn::punctuated::Punctuated;
12use syn::DeriveInput;
13use syn::Error;
14use syn::{parse_macro_input, Ident, Result, Token};
15
16#[proc_macro_attribute]
17pub fn walrus_instr(_attr: TokenStream, input: TokenStream) -> TokenStream {
18 let input = parse_macro_input!(input as DeriveInput);
19
20 let variants = match get_enum_variants(&input) {
21 Ok(v) => v,
22 Err(e) => return e.to_compile_error().into(),
23 };
24
25 assert_eq!(input.ident.to_string(), "Instr");
26
27 let types = create_types(&input.attrs, &variants);
28 let visit = create_visit(&variants);
29 let builder = create_builder(&variants);
30
31 let expanded = quote! {
32 #types
33 #visit
34 #builder
35 };
36
37 TokenStream::from(expanded)
38}
39
40struct WalrusVariant {
41 syn: syn::Variant,
42 fields: Vec<WalrusFieldOpts>,
43 opts: WalrusVariantOpts,
44}
45
46#[derive(Default)]
47struct WalrusVariantOpts {
48 display_name: Option<syn::Ident>,
49 display_extra: Option<syn::Ident>,
50 skip_builder: bool,
51}
52
53#[derive(Default)]
54struct WalrusFieldOpts {
55 skip_visit: bool,
56}
57
58fn get_enum_variants(input: &DeriveInput) -> Result<Vec<WalrusVariant>> {
59 let en = match &input.data {
60 syn::Data::Enum(en) => en,
61 syn::Data::Struct(_) => {
62 panic!("can only put #[walrus_instr] on an enum; found it on a struct")
63 }
64 syn::Data::Union(_) => {
65 panic!("can only put #[walrus_instr] on an enum; found it on a union")
66 }
67 };
68 en.variants
69 .iter()
70 .cloned()
71 .map(|mut variant| {
72 Ok(WalrusVariant {
73 opts: syn::parse(walrus_attrs(&mut variant.attrs))?,
74 fields: variant
75 .fields
76 .iter_mut()
77 .map(|field| syn::parse(walrus_attrs(&mut field.attrs)))
78 .collect::<Result<_>>()?,
79 syn: variant,
80 })
81 })
82 .collect()
83}
84
85impl Parse for WalrusFieldOpts {
86 fn parse(input: ParseStream) -> Result<Self> {
87 enum Attr {
88 SkipVisit,
89 }
90
91 let attrs = Punctuated::<_, syn::token::Comma>::parse_terminated(input)?;
92 let mut ret = WalrusFieldOpts::default();
93 for attr in attrs {
94 match attr {
95 Attr::SkipVisit => ret.skip_visit = true,
96 }
97 }
98 return Ok(ret);
99
100 impl Parse for Attr {
101 fn parse(input: ParseStream) -> Result<Self> {
102 let attr: Ident = input.parse()?;
103 if attr == "skip_visit" {
104 return Ok(Attr::SkipVisit);
105 }
106 Err(Error::new(attr.span(), "unexpected attribute"))
107 }
108 }
109 }
110}
111
112impl Parse for WalrusVariantOpts {
113 fn parse(input: ParseStream) -> Result<Self> {
114 enum Attr {
115 DisplayName(syn::Ident),
116 DisplayExtra(syn::Ident),
117 SkipBuilder,
118 }
119
120 let attrs = Punctuated::<_, syn::token::Comma>::parse_terminated(input)?;
121 let mut ret = WalrusVariantOpts::default();
122 for attr in attrs {
123 match attr {
124 Attr::DisplayName(ident) => ret.display_name = Some(ident),
125 Attr::DisplayExtra(ident) => ret.display_extra = Some(ident),
126 Attr::SkipBuilder => ret.skip_builder = true,
127 }
128 }
129 return Ok(ret);
130
131 impl Parse for Attr {
132 fn parse(input: ParseStream) -> Result<Self> {
133 let attr: Ident = input.parse()?;
134 if attr == "display_name" {
135 input.parse::<Token![=]>()?;
136 let name = input.call(Ident::parse_any)?;
137 return Ok(Attr::DisplayName(name));
138 }
139 if attr == "display_extra" {
140 input.parse::<Token![=]>()?;
141 let name = input.call(Ident::parse_any)?;
142 return Ok(Attr::DisplayExtra(name));
143 }
144 if attr == "skip_builder" {
145 return Ok(Attr::SkipBuilder);
146 }
147 Err(Error::new(attr.span(), "unexpected attribute"))
148 }
149 }
150 }
151}
152
153fn walrus_attrs(attrs: &mut Vec<syn::Attribute>) -> TokenStream {
154 let mut ret = proc_macro2::TokenStream::new();
155 let ident = syn::Path::from(syn::Ident::new("walrus", Span::call_site()));
156 for i in (0..attrs.len()).rev() {
157 if attrs[i].path() != &ident {
158 continue;
159 }
160 let attr = attrs.remove(i);
161 let group = if let syn::Meta::List(syn::MetaList { tokens, .. }) = attr.meta {
162 tokens
163 } else {
164 panic!("#[walrus(...)] expected")
165 };
166 ret.extend(group);
167 ret.extend(quote! { , });
168 }
169 ret.into()
170}
171
172fn create_types(attrs: &[syn::Attribute], variants: &[WalrusVariant]) -> impl quote::ToTokens {
173 let types: Vec<_> = variants
174 .iter()
175 .map(|v| {
176 let name = &v.syn.ident;
177 let attrs = &v.syn.attrs;
178 let fields = v.syn.fields.iter().map(|f| {
179 let name = &f.ident;
180 let attrs = &f.attrs;
181 let ty = &f.ty;
182 quote! {
183 #( #attrs )*
184 pub #name : #ty,
185 }
186 });
187 quote! {
188 #( #attrs )*
189 #[derive(Clone, Debug)]
190 pub struct #name {
191 #( #fields )*
192 }
193
194 impl From<#name> for Instr {
195 #[inline]
196 fn from(x: #name) -> Instr {
197 Instr::#name(x)
198 }
199 }
200 }
201 })
202 .collect();
203
204 let methods: Vec<_> = variants
205 .iter()
206 .map(|v| {
207 let name = &v.syn.ident;
208 let snake_name = name.to_string().to_snake_case();
209
210 let is_name = format!("is_{}", snake_name);
211 let is_name = syn::Ident::new(&is_name, Span::call_site());
212
213 let ref_name = format!("{}_ref", snake_name);
214 let ref_name = syn::Ident::new(&ref_name, Span::call_site());
215
216 let mut_name = format!("{}_mut", snake_name);
217 let mut_name = syn::Ident::new(&mut_name, Span::call_site());
218
219 let unwrap_name = format!("unwrap_{}", snake_name);
220 let unwrap_name = syn::Ident::new(&unwrap_name, Span::call_site());
221
222 let unwrap_mut_name = format!("unwrap_{}_mut", snake_name);
223 let unwrap_mut_name = syn::Ident::new(&unwrap_mut_name, Span::call_site());
224
225 let ref_name_doc = format!(
226 "
227 If this instruction is a `{}`, get a shared reference to it.
228
229 Returns `None` otherwise.
230 ",
231 name
232 );
233
234 let mut_name_doc = format!(
235 "
236 If this instruction is a `{}`, get an exclusive reference to it.
237
238 Returns `None` otherwise.
239 ",
240 name
241 );
242
243 let is_name_doc = format!("Is this instruction a `{}`?", name);
244
245 let unwrap_name_doc = format!(
246 "
247 Get a shared reference to the underlying `{}`.
248
249 Panics if this instruction is not a `{}`.
250 ",
251 name, name
252 );
253
254 let unwrap_mut_name_doc = format!(
255 "
256 Get an exclusive reference to the underlying `{}`.
257
258 Panics if this instruction is not a `{}`.
259 ",
260 name, name
261 );
262
263 quote! {
264 #[doc=#ref_name_doc]
265 #[inline]
266 fn #ref_name(&self) -> Option<&#name> {
267 if let Instr::#name(ref x) = *self {
268 Some(x)
269 } else {
270 None
271 }
272 }
273
274 #[doc=#mut_name_doc]
275 #[inline]
276 pub fn #mut_name(&mut self) -> Option<&mut #name> {
277 if let Instr::#name(ref mut x) = *self {
278 Some(x)
279 } else {
280 None
281 }
282 }
283
284 #[doc=#is_name_doc]
285 #[inline]
286 pub fn #is_name(&self) -> bool {
287 self.#ref_name().is_some()
288 }
289
290 #[doc=#unwrap_name_doc]
291 #[inline]
292 pub fn #unwrap_name(&self) -> &#name {
293 self.#ref_name().unwrap()
294 }
295
296 #[doc=#unwrap_mut_name_doc]
297 #[inline]
298 pub fn #unwrap_mut_name(&mut self) -> &mut #name {
299 self.#mut_name().unwrap()
300 }
301 }
302 })
303 .collect();
304
305 let variants: Vec<_> = variants
306 .iter()
307 .map(|v| {
308 let name = &v.syn.ident;
309 let attrs = &v.syn.attrs;
310 quote! {
311 #( #attrs )*
312 #name(#name)
313 }
314 })
315 .collect();
316
317 quote! {
318 #( #types )*
319
320 #( #attrs )*
321 pub enum Instr {
322 #(#variants),*
323 }
324
325 impl Instr {
326 #( #methods )*
327 }
328 }
329}
330
331fn visit_fields(
332 variant: &WalrusVariant,
333 allow_skip: bool,
334) -> impl Iterator<Item = (syn::Ident, proc_macro2::TokenStream, bool)> + '_ {
335 return variant
336 .syn
337 .fields
338 .iter()
339 .zip(&variant.fields)
340 .enumerate()
341 .filter(move |(_, (_, info))| !allow_skip || !info.skip_visit)
342 .map(move |(i, (field, _info))| {
343 let field_name = match &field.ident {
344 Some(name) => quote! { #name },
345 None => quote! { #i },
346 };
347 let (ty_name, list) = extract_name_and_if_list(&field.ty);
348 let mut method_name = "visit_".to_string();
349 method_name.push_str(&ty_name.to_string().to_snake_case());
350 let method_name = syn::Ident::new(&method_name, Span::call_site());
351 (method_name, field_name, list)
352 });
353
354 fn extract_name_and_if_list(ty: &syn::Type) -> (&syn::Ident, bool) {
355 let path = match ty {
356 syn::Type::Path(p) => &p.path,
357 _ => panic!("field types must be paths"),
358 };
359 let segment = path.segments.last().unwrap();
360 let args = match &segment.arguments {
361 syn::PathArguments::None => return (&segment.ident, false),
362 syn::PathArguments::AngleBracketed(a) => &a.args,
363 _ => panic!("invalid path in #[walrus_instr]"),
364 };
365 let mut ty = match args.first().unwrap() {
366 syn::GenericArgument::Type(ty) => ty,
367 _ => panic!("invalid path in #[walrus_instr]"),
368 };
369 if let syn::Type::Slice(t) = ty {
370 ty = &t.elem;
371 }
372 match ty {
373 syn::Type::Path(p) => {
374 let segment = p.path.segments.last().unwrap();
375 (&segment.ident, true)
376 }
377 _ => panic!("invalid path in #[walrus_instr]"),
378 }
379 }
380}
381
382fn create_visit(variants: &[WalrusVariant]) -> impl quote::ToTokens {
383 let mut visit_impls = Vec::new();
384 let mut visitor_trait_methods = Vec::new();
385 let mut visitor_mut_trait_methods = Vec::new();
386 let mut visit_impl = Vec::new();
387 let mut visit_mut_impl = Vec::new();
388
389 for variant in variants {
390 let name = &variant.syn.ident;
391
392 let mut method_name = "visit_".to_string();
393 method_name.push_str(&name.to_string().to_snake_case());
394 let method_name = syn::Ident::new(&method_name, Span::call_site());
395 let method_name_mut = syn::Ident::new(&format!("{}_mut", method_name), Span::call_site());
396
397 let recurse_fields = visit_fields(variant, true).map(|(method_name, field_name, list)| {
398 if list {
399 quote! {
400 for item in self.#field_name.iter() {
401 visitor.#method_name(item);
402 }
403 }
404 } else {
405 quote! {
406 visitor.#method_name(&self.#field_name);
407 }
408 }
409 });
410 let recurse_fields_mut =
411 visit_fields(variant, true).map(|(method_name, field_name, list)| {
412 let name = format!("{}_mut", method_name);
413 let method_name = syn::Ident::new(&name, Span::call_site());
414 if list {
415 quote! {
416 for item in self.#field_name.iter_mut() {
417 visitor.#method_name(item);
418 }
419 }
420 } else {
421 quote! {
422 visitor.#method_name(&mut self.#field_name);
423 }
424 }
425 });
426
427 visit_impls.push(quote! {
428 impl<'instr> Visit<'instr> for #name {
429 #[inline]
430 fn visit<V: Visitor<'instr>>(&self, visitor: &mut V) {
431 #(#recurse_fields);*
432 }
433 }
434 impl VisitMut for #name {
435 #[inline]
436 fn visit_mut<V: VisitorMut>(&mut self, visitor: &mut V) {
437 #(#recurse_fields_mut);*
438 }
439 }
440 });
441
442 let doc = format!("Visit `{}`.", name);
443 visitor_trait_methods.push(quote! {
444 #[doc=#doc]
445 #[inline]
446 fn #method_name(&mut self, instr: &#name) {
447 }
449 });
450 visitor_mut_trait_methods.push(quote! {
451 #[doc=#doc]
452 #[inline]
453 fn #method_name_mut(&mut self, instr: &mut #name) {
454 instr.visit_mut(self);
455 }
456 });
457
458 let mut method_name = "visit_".to_string();
459 method_name.push_str(&name.to_string().to_snake_case());
460 let method_name = syn::Ident::new(&method_name, Span::call_site());
461 visit_impl.push(quote! {
462 Instr::#name(e) => {
463 visitor.#method_name(e);
464 e.visit(visitor);
465 }
466 });
467 visit_mut_impl.push(quote! {
468 Instr::#name(e) => {
469 visitor.#method_name_mut(e);
470 e.visit_mut(visitor);
471 }
472 });
473 }
474
475 quote! {
476 pub trait Visitor<'instr>: Sized {
494 #[inline]
500 fn start_instr_seq(&mut self, instr_seq: &'instr InstrSeq) {
501 }
503
504 #[inline]
507 fn end_instr_seq(&mut self, instr_seq: &'instr InstrSeq) {
508 }
510
511 #[inline]
513 fn visit_instr(&mut self, instr: &'instr Instr, instr_loc: &'instr InstrLocId) {
514 }
516
517 #[inline]
519 fn visit_instr_seq_id(&mut self, instr_seq_id: &InstrSeqId) {
520 }
522
523 #[inline]
525 fn visit_local_id(&mut self, local: &crate::LocalId) {
526 }
528
529 #[inline]
531 fn visit_memory_id(&mut self, memory: &crate::MemoryId) {
532 }
534
535 #[inline]
537 fn visit_table_id(&mut self, table: &crate::TableId) {
538 }
540
541 #[inline]
543 fn visit_global_id(&mut self, global: &crate::GlobalId) {
544 }
546
547 #[inline]
549 fn visit_function_id(&mut self, function: &crate::FunctionId) {
550 }
552
553 #[inline]
555 fn visit_data_id(&mut self, function: &crate::DataId) {
556 }
558
559 #[inline]
561 fn visit_type_id(&mut self, ty: &crate::TypeId) {
562 }
564
565 #[inline]
567 fn visit_element_id(&mut self, elem: &crate::ElementId) {
568 }
570
571 #[inline]
573 fn visit_tag_id(&mut self, tag: &crate::TagId) {
574 }
576
577 #[inline]
579 fn visit_value(&mut self, value: &crate::ir::Value) {
580 }
582
583 #( #visitor_trait_methods )*
584 }
585
586 pub trait VisitorMut: Sized {
590 #[inline]
596 fn start_instr_seq_mut(&mut self, instr_seq: &mut InstrSeq) {
597 }
599
600 #[inline]
603 fn end_instr_seq_mut(&mut self, instr_seq: &mut InstrSeq) {
604 }
606
607 #[inline]
609 fn visit_instr_mut(&mut self, instr: &mut Instr, instr_loc: &mut InstrLocId) {
610 }
612
613 #[inline]
615 fn visit_instr_seq_id_mut(&mut self, instr_seq_id: &mut InstrSeqId) {
616 }
618
619 #[inline]
621 fn visit_local_id_mut(&mut self, local: &mut crate::LocalId) {
622 }
624
625 #[inline]
627 fn visit_memory_id_mut(&mut self, memory: &mut crate::MemoryId) {
628 }
630
631 #[inline]
633 fn visit_table_id_mut(&mut self, table: &mut crate::TableId) {
634 }
636
637 #[inline]
639 fn visit_global_id_mut(&mut self, global: &mut crate::GlobalId) {
640 }
642
643 #[inline]
645 fn visit_function_id_mut(&mut self, function: &mut crate::FunctionId) {
646 }
648
649 #[inline]
651 fn visit_data_id_mut(&mut self, function: &mut crate::DataId) {
652 }
654
655 #[inline]
657 fn visit_type_id_mut(&mut self, ty: &mut crate::TypeId) {
658 }
660
661 #[inline]
663 fn visit_element_id_mut(&mut self, elem: &mut crate::ElementId) {
664 }
666
667 #[inline]
669 fn visit_tag_id_mut(&mut self, tag: &mut crate::TagId) {
670 }
672
673 #[inline]
675 fn visit_value_mut(&mut self, value: &mut crate::ir::Value) {
676 }
678
679 #( #visitor_mut_trait_methods )*
680 }
681
682 impl<'instr> Visit<'instr> for Instr {
683 #[inline]
684 fn visit<V>(&self, visitor: &mut V) where V: Visitor<'instr> {
685 match self {
686 #( #visit_impl )*
687 }
688 }
689 }
690
691 impl VisitMut for Instr {
692 #[inline]
693 fn visit_mut<V>(&mut self, visitor: &mut V) where V: VisitorMut {
694 match self {
695 #( #visit_mut_impl )*
696 }
697 }
698 }
699
700 #( #visit_impls )*
701 }
702}
703
704fn create_builder(variants: &[WalrusVariant]) -> impl quote::ToTokens {
705 let mut builder_methods = Vec::new();
706 for variant in variants {
707 if variant.opts.skip_builder {
708 continue;
709 }
710
711 let name = &variant.syn.ident;
712
713 let mut method_name = name.to_string().to_snake_case();
714
715 let mut method_name_at = method_name.clone();
716 method_name_at.push_str("_at");
717 let method_name_at = syn::Ident::new(&method_name_at, Span::call_site());
718
719 if method_name == "return" || method_name == "const" {
720 method_name.push('_');
721 } else if method_name == "block" {
722 continue;
723 }
724 let method_name = syn::Ident::new(&method_name, Span::call_site());
725
726 let mut args = Vec::new();
727 let mut arg_names = Vec::new();
728
729 for field in variant.syn.fields.iter() {
730 let name = field.ident.as_ref().expect("can't have unnamed fields");
731 arg_names.push(name);
732 let ty = &field.ty;
733 args.push(quote! { #name: #ty });
734 }
735
736 let doc = format!(
737 "Push a new `{}` instruction onto this builder's block.",
738 name
739 );
740 let at_doc = format!(
741 "Splice a new `{}` instruction into this builder's block at the given index.\n\n\
742 # Panics\n\n\
743 Panics if `position > self.instrs.len()`.",
744 name
745 );
746
747 let arg_names = &arg_names;
748 let args = &args;
749
750 builder_methods.push(quote! {
751 #[inline]
752 #[doc=#doc]
753 pub fn #method_name(&mut self, #(#args),*) -> &mut Self {
754 self.instr(#name { #(#arg_names),* })
755 }
756
757 #[inline]
758 #[doc=#at_doc]
759 pub fn #method_name_at(&mut self, position: usize, #(#args),*) -> &mut Self {
760 self.instr_at(position, #name { #(#arg_names),* })
761 }
762 });
763 }
764 quote! {
765 #[allow(missing_docs)]
766 impl crate::InstrSeqBuilder<'_> {
767 #(#builder_methods)*
768 }
769 }
770}