Skip to main content

midenc_hir_macros/
lib.rs

1#![deny(warnings)]
2
3extern crate proc_macro;
4
5mod dialect;
6mod operation;
7mod operations;
8mod spanned;
9#[cfg(test)]
10mod tests;
11
12use inflector::cases::kebabcase::to_kebab_case;
13use quote::{ToTokens, format_ident, quote};
14use syn::{Data, DeriveInput, Error, Ident, Token, parse_macro_input, spanned::Spanned};
15
16#[proc_macro_derive(Spanned, attributes(span))]
17pub fn derive_spanned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
18    // Parse into syntax tree
19    let derive = parse_macro_input!(input as DeriveInput);
20    // Structure name
21    let name = derive.ident;
22    let result = match derive.data {
23        Data::Struct(data) => spanned::derive_spanned_struct(name, data, derive.generics),
24        Data::Enum(data) => spanned::derive_spanned_enum(name, data, derive.generics),
25        Data::Union(_) => {
26            Err(Error::new(name.span(), "deriving Spanned on unions is not currently supported"))
27        }
28    };
29    match result {
30        Ok(ts) => ts,
31        Err(err) => err.into_compile_error().into(),
32    }
33}
34
35#[proc_macro_derive(Dialect, attributes(dialect))]
36pub fn derive_dialect(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
37    // Parse into syntax tree
38    let derive = parse_macro_input!(input as DeriveInput);
39    // Structure name
40    let result = dialect::derive_dialect(&derive);
41    match result {
42        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
43        Err(err) => err.write_errors().into(),
44    }
45}
46
47#[proc_macro_derive(DialectRegistration, attributes(dialect))]
48pub fn derive_dialect_registration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
49    // Parse into syntax tree
50    let derive = parse_macro_input!(input as DeriveInput);
51    // Structure name
52    let result = dialect::derive_dialect_registration(&derive);
53    match result {
54        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
55        Err(err) => err.write_errors().into(),
56    }
57}
58
59#[proc_macro_derive(DialectAttribute, attributes(attribute))]
60pub fn derive_dialect_attribute(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
61    // Parse into syntax tree
62    let derive = parse_macro_input!(input as DeriveInput);
63    // Structure name
64    let result = dialect::derive_attribute(&derive);
65    match result {
66        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
67        Err(err) => err.write_errors().into(),
68    }
69}
70
71#[proc_macro_derive(EffectOpInterface, attributes(effects))]
72pub fn derive_effect_op_interface(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
73    // Parse into syntax tree
74    let derive = parse_macro_input!(input as DeriveInput);
75    // Structure name
76    let result = operations::derive_effect_op_interface(&derive);
77    match result {
78        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
79        Err(err) => err.write_errors().into(),
80    }
81}
82
83#[proc_macro_derive(OpPrinter)]
84pub fn derive_op_printer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
85    // Parse into syntax tree
86    let derive = parse_macro_input!(input as DeriveInput);
87    // Structure name
88    let result = operations::derive_op_printer(&derive);
89    match result {
90        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
91        Err(err) => err.write_errors().into(),
92    }
93}
94
95#[proc_macro_derive(OpParser)]
96pub fn derive_op_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
97    // Parse into syntax tree
98    let derive = parse_macro_input!(input as DeriveInput);
99    // Structure name
100    let result = operations::derive_op_parser(&derive);
101    match result {
102        Ok(ts) => proc_macro::TokenStream::from(ts.into_token_stream()),
103        Err(err) => err.write_errors().into(),
104    }
105}
106
107#[proc_macro_attribute]
108pub fn operation_trait(
109    attr: proc_macro::TokenStream,
110    item: proc_macro::TokenStream,
111) -> proc_macro::TokenStream {
112    let attr = proc_macro2::TokenStream::from(attr);
113    let input = syn::parse_macro_input!(item as syn::ItemTrait);
114
115    let meta = match darling::ast::NestedMeta::parse_meta_list(attr) {
116        Ok(meta) => meta,
117        Err(err) => return err.into_compile_error().into(),
118    };
119
120    match self::operations::derive_operation_trait(meta, input) {
121        Ok(ts) => proc_macro::TokenStream::from(ts),
122        Err(err) => err.write_errors().into(),
123    }
124}
125
126/// Define an operation.
127///
128/// ## Examples
129///
130/// ```text
131/// #[operation(
132///     dialect = HirDialect,
133///     traits(Terminator),
134///     implements(BranchOpInterface),
135/// )]
136/// pub struct Switch {
137///     #[operand]
138///     selector: UInt32,
139///     #[successors(keyed)]
140///     cases: SwitchArm,
141///     #[successor]
142///     fallback: Successor,
143/// }
144///
145/// pub struct Call {
146///     #[attr]
147///     callee: Symbol,
148///     #[operands]
149///     arguments: Vec<AnyType>,
150///     #[results]
151///     results: Vec<AnyType>,
152/// }
153///
154/// #[operation]
155/// pub struct If {
156///     #[operand]
157///     condition: Bool,
158///     #[region]
159///     then_region: RegionRef,
160///     #[region]
161///     else_region: RegionRef,
162/// }
163/// ```
164#[proc_macro_attribute]
165pub fn operation(
166    attr: proc_macro::TokenStream,
167    item: proc_macro::TokenStream,
168) -> proc_macro::TokenStream {
169    let attr = proc_macro2::TokenStream::from(attr);
170    let mut input = syn::parse_macro_input!(item as syn::ItemStruct);
171    let span = input.span();
172
173    // Reconstruct the input so we can treat this like a derive macro
174    //
175    // We can't _actually_ use derive, because we need to modify the item itself.
176    input.attrs.push(syn::Attribute {
177        pound_token: syn::token::Pound(span),
178        style: syn::AttrStyle::Outer,
179        bracket_token: syn::token::Bracket(span),
180        meta: syn::Meta::List(syn::MetaList {
181            path: syn::parse_str("operation").unwrap(),
182            delimiter: syn::MacroDelimiter::Paren(syn::token::Paren(span)),
183            tokens: attr,
184        }),
185    });
186
187    let input = syn::parse_quote! {
188        #input
189    };
190
191    match operation::derive_operation(input) {
192        Ok(token_stream) => proc_macro::TokenStream::from(token_stream),
193        Err(err) => err.write_errors().into(),
194    }
195}
196
197#[proc_macro_derive(PassInfo)]
198pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
199    let derive_input = parse_macro_input!(item as DeriveInput);
200    let derive_span = derive_input.span();
201    let id = derive_input.ident.clone();
202    let generics = derive_input.generics;
203    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
204    let name = derive_input.ident.to_string();
205    let pass_name = to_kebab_case(&name);
206    let pass_name_lit = syn::Lit::Str(syn::LitStr::new(&pass_name, id.span()));
207
208    let doc_ident = format_ident!("doc", span = derive_span);
209    let docs = derive_input
210        .attrs
211        .iter()
212        .filter_map(|attr| match attr.meta {
213            syn::Meta::NameValue(ref nv) => {
214                if nv.path.get_ident()? == &doc_ident {
215                    match nv.value {
216                        syn::Expr::Lit(syn::ExprLit {
217                            lit: syn::Lit::Str(ref lit),
218                            ..
219                        }) => Some(lit.value()),
220                        _ => None,
221                    }
222                } else {
223                    None
224                }
225            }
226            syn::Meta::Path(_) | syn::Meta::List(_) => None,
227        })
228        .collect::<Vec<_>>();
229    let pass_summary = match docs.first() {
230        Some(line) => syn::Lit::Str(syn::LitStr::new(line, derive_span)),
231        None => syn::Lit::Str(syn::LitStr::new("", derive_span)),
232    };
233    let description = docs.into_iter().collect::<String>();
234    let pass_description = syn::Lit::Str(syn::LitStr::new(&description, derive_span));
235
236    let quoted = quote! {
237        impl #impl_generics PassInfo for #id #ty_generics #where_clause {
238            const FLAG: &'static str = #pass_name_lit;
239            const SUMMARY: &'static str = #pass_summary;
240            const DESCRIPTION: &'static str = #pass_description;
241        }
242    };
243
244    proc_macro::TokenStream::from(quoted)
245}
246
247#[proc_macro_derive(AnalysisKey, attributes(analysis_key))]
248pub fn derive_analysis_key(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
249    let derive_input = parse_macro_input!(item as DeriveInput);
250    let derive_span = derive_input.span();
251    let id = derive_input.ident.clone();
252    let generics = derive_input.generics;
253    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
254
255    let found = match &derive_input.data {
256        syn::Data::Struct(data) => match &data.fields {
257            syn::Fields::Named(fields) => {
258                let mut found = None;
259                for field in fields.named.iter() {
260                    if field.attrs.iter().any(is_analysis_key_attr) {
261                        if found.is_some() {
262                            return syn::Error::new(
263                                field.span(),
264                                "duplicate #[analysis_key] field",
265                            )
266                            .into_compile_error()
267                            .into();
268                        }
269                        found = Some((field.ident.as_ref().cloned().unwrap(), field.ty.clone()));
270                    }
271                }
272                found
273            }
274            syn::Fields::Unnamed(fields) => {
275                let mut found = None;
276                for (i, field) in fields.unnamed.iter().enumerate() {
277                    if field.attrs.iter().any(is_analysis_key_attr) {
278                        if found.is_some() {
279                            return syn::Error::new(
280                                field.span(),
281                                "duplicate #[analysis_key] field",
282                            )
283                            .into_compile_error()
284                            .into();
285                        }
286                        found = Some((Ident::new(&i.to_string(), field.span()), field.ty.clone()));
287                    }
288                }
289                found
290            }
291            syn::Fields::Unit => {
292                return syn::Error::new(
293                    derive_span,
294                    "structs with unit fields cannot derive AnalysisKey",
295                )
296                .into_compile_error()
297                .into();
298            }
299        },
300        syn::Data::Enum(_) => {
301            return syn::Error::new(derive_span, "enums cannot derive AnalysisKey")
302                .into_compile_error()
303                .into();
304        }
305        syn::Data::Union(_) => {
306            return syn::Error::new(derive_span, "unions cannot derive AnalysisKey")
307                .into_compile_error()
308                .into();
309        }
310    };
311
312    let (field_id, field_ty) = match found {
313        Some(found) => found,
314        None => {
315            return syn::Error::new(derive_span, "missing #[analysis_key] attribute")
316                .into_compile_error()
317                .into();
318        }
319    };
320
321    let quoted = quote! {
322        impl #impl_generics AnalysisKey for #id #ty_generics #where_clause {
323            type Key = #field_ty;
324
325            fn key(&self) -> Self::Key { self.#field_id }
326        }
327    };
328
329    proc_macro::TokenStream::from(quoted)
330}
331
332#[proc_macro_derive(RewritePassRegistration)]
333pub fn derive_rewrite_pass_registration(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
334    let derive_input = parse_macro_input!(item as DeriveInput);
335    let id = derive_input.ident.clone();
336    let generics = derive_input.generics;
337    let mut params = syn::punctuated::Punctuated::<_, Token![,]>::new();
338    for gp in generics.params.iter() {
339        match gp {
340            syn::GenericParam::Lifetime(lt) => {
341                if !lt.bounds.empty_or_trailing() {
342                    return syn::Error::new(
343                        gp.span(),
344                        "cannot derive RewritePassRegistration on a type with lifetime bounds",
345                    )
346                    .into_compile_error()
347                    .into();
348                }
349                params.push(syn::GenericArgument::Lifetime(syn::Lifetime {
350                    apostrophe: lt.span(),
351                    ident: Ident::new("_", lt.span()),
352                }));
353            }
354            syn::GenericParam::Type(ty) => {
355                if !ty.bounds.empty_or_trailing() {
356                    return syn::Error::new(
357                        gp.span(),
358                        "cannot derive RewritePassRegistration on a generic type with type bounds",
359                    )
360                    .into_compile_error()
361                    .into();
362                }
363                let param_ty: syn::Type = syn::parse_quote_spanned! { ty.span() => () };
364                params.push(syn::GenericArgument::Type(param_ty));
365            }
366            syn::GenericParam::Const(_) => {
367                return syn::Error::new(
368                    gp.span(),
369                    "cannot derive RewritePassRegistration on a generic type with const arguments",
370                )
371                .into_compile_error()
372                .into();
373            }
374        }
375    }
376
377    let quoted = if params.empty_or_trailing() {
378        quote! {
379            inventory::submit!(midenc_hir::pass::RewritePassRegistration::new::<#id>());
380            inventory::submit! {
381                midenc_session::CompileFlag::new(<#id as PassInfo>::FLAG)
382                    .long(<#id as PassInfo>::FLAG)
383                    .help(<#id as PassInfo>::SUMMARY)
384                    .help_heading("Rewrites")
385                    .action(midenc_session::FlagAction::SetTrue)
386                    .hide(true)
387            }
388        }
389    } else {
390        quote! {
391            inventory::submit!(midenc_hir::pass::RewritePassRegistration::new::<#id<#params>>());
392            inventory::submit! {
393                midenc_session::CompileFlag::new(<#id<#params> as PassInfo>::FLAG)
394                    .long(<#id<#params> as PassInfo>::FLAG)
395                    .help(<#id<#params> as PassInfo>::SUMMARY)
396                    .help_heading("Rewrites")
397                    .action(midenc_session::FlagAction::SetTrue)
398                    .hide(true)
399            }
400        }
401    };
402
403    proc_macro::TokenStream::from(quoted)
404}
405
406#[proc_macro_derive(ModuleRewritePassAdapter)]
407pub fn derive_module_rewrite_pass_adapter(
408    item: proc_macro::TokenStream,
409) -> proc_macro::TokenStream {
410    let derive_input = parse_macro_input!(item as DeriveInput);
411    let id = derive_input.ident.clone();
412
413    let quoted = quote! {
414        inventory::submit!(midenc_hir::pass::RewritePassRegistration::new::<midenc_hir::pass::ModuleRewritePassAdapter::<#id>>());
415    };
416
417    proc_macro::TokenStream::from(quoted)
418}
419
420#[proc_macro_derive(ConversionPassRegistration)]
421pub fn derive_conversion_pass_registration(
422    item: proc_macro::TokenStream,
423) -> proc_macro::TokenStream {
424    let derive_input = parse_macro_input!(item as DeriveInput);
425    let id = derive_input.ident.clone();
426    let generics = derive_input.generics;
427    let mut params = syn::punctuated::Punctuated::<_, Token![,]>::new();
428    for gp in generics.params.iter() {
429        match gp {
430            syn::GenericParam::Lifetime(lt) => {
431                if !lt.bounds.empty_or_trailing() {
432                    return syn::Error::new(
433                        gp.span(),
434                        "cannot derive ConversionPassRegistration on a type with lifetime bounds",
435                    )
436                    .into_compile_error()
437                    .into();
438                }
439                params.push(syn::GenericArgument::Lifetime(syn::Lifetime {
440                    apostrophe: lt.span(),
441                    ident: Ident::new("_", lt.span()),
442                }));
443            }
444            syn::GenericParam::Type(ty) => {
445                if !ty.bounds.empty_or_trailing() {
446                    return syn::Error::new(
447                        gp.span(),
448                        "cannot derive ConversionPassRegistration on a generic type with type \
449                         bounds",
450                    )
451                    .into_compile_error()
452                    .into();
453                }
454                let param_ty: syn::Type = syn::parse_quote_spanned! { ty.span() => () };
455                params.push(syn::GenericArgument::Type(param_ty));
456            }
457            syn::GenericParam::Const(_) => {
458                return syn::Error::new(
459                    gp.span(),
460                    "cannot derive ConversionPassRegistration on a generic type with const \
461                     arguments",
462                )
463                .into_compile_error()
464                .into();
465            }
466        }
467    }
468
469    let quoted = if params.empty_or_trailing() {
470        quote! {
471            inventory::submit! {
472                midenc_session::CompileFlag::new(<#id as PassInfo>::FLAG)
473                    .long(<#id as PassInfo>::FLAG)
474                    .help(<#id as PassInfo>::SUMMARY)
475                    .help_heading("Conversions")
476                    .action(midenc_session::FlagAction::SetTrue)
477                    .hide(true)
478            }
479        }
480    } else {
481        quote! {
482            inventory::submit! {
483                midenc_session::CompileFlag::new(<#id<#params> as PassInfo>::FLAG)
484                    .long(<#id<#params> as PassInfo>::FLAG)
485                    .help(<#id<#params> as PassInfo>::SUMMARY)
486                    .help_heading("Conversions")
487                    .action(midenc_session::FlagAction::SetTrue)
488                    .hide(true)
489            }
490        }
491    };
492
493    proc_macro::TokenStream::from(quoted)
494}
495
496fn is_analysis_key_attr(attr: &syn::Attribute) -> bool {
497    if let syn::Meta::Path(ref path) = attr.meta {
498        path.is_ident("analysis_key")
499    } else {
500        false
501    }
502}