macerator_macros/
lib.rs

1use darling::FromMeta;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::parse_quote;
5use syn::Expr;
6use syn::{spanned::Spanned, FnArg, GenericParam, ItemFn, Pat};
7
8#[derive(FromMeta, Default)]
9#[darling(default)]
10struct WithSimdOpts {
11    #[darling(default)]
12    arch: Option<Expr>,
13}
14
15#[proc_macro_attribute]
16pub fn with_simd(
17    attr: proc_macro::TokenStream,
18    item: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20    match with_simd_impl(attr.into(), item.into()) {
21        Ok(out) => out.into(),
22        Err(e) => e.into_compile_error().into(),
23    }
24}
25
26fn with_simd_impl(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
27    let opts = match attr.is_empty() {
28        true => WithSimdOpts::default(),
29        false => {
30            let meta = syn::parse2::<syn::Meta>(attr)?;
31            WithSimdOpts::from_meta(&meta)?
32        }
33    };
34
35    let arch = opts.arch.unwrap_or(parse_quote!(macerator::Arch::new()));
36    let func = syn::parse2::<syn::ItemFn>(item)?;
37
38    let ItemFn {
39        attrs,
40        vis,
41        sig,
42        block,
43    } = func.clone();
44
45    let name = &sig.ident;
46
47    let lifetimes = sig.generics.lifetimes();
48    let type_params = sig.generics.type_params();
49    let const_params = sig.generics.const_params();
50
51    let mut outer_fn_sig = sig.clone();
52    outer_fn_sig.generics.params = lifetimes
53        .map(|l| GenericParam::Lifetime(l.clone()))
54        .chain(type_params.skip(1).map(|t| GenericParam::Type(t.clone())))
55        .chain(const_params.map(|c| GenericParam::Const(c.clone())))
56        .collect();
57    let mut inner_fn_sig = sig.clone();
58    inner_fn_sig.ident = format_ident!("{}_impl", name);
59    let struct_name = format_ident!("{}_struct", name);
60
61    let fields = sig
62        .inputs
63        .iter()
64        .map(|arg| match arg {
65            FnArg::Receiver(_) => Err(syn::Error::new(arg.span(), "Can't use macro on methods")),
66            FnArg::Typed(pat_type) => {
67                let ident = match &*pat_type.pat {
68                    Pat::Ident(pat_ident) => &pat_ident.ident,
69                    _ => todo!(),
70                };
71                let ty = &*pat_type.ty;
72                Ok((ident, ty))
73            }
74        })
75        .collect::<Result<Vec<_>, _>>()?;
76
77    let output_ty = match sig.output.clone() {
78        syn::ReturnType::Default => quote! { () },
79        syn::ReturnType::Type(_, ty) => quote! { #ty },
80    };
81
82    let inner_name = &inner_fn_sig.ident;
83    let (impl_generics, type_generics, where_clause) = outer_fn_sig.generics.split_for_impl();
84    let field_decl = fields.iter().map(|(ident, ty)| quote![#ident: #ty]);
85    let field_names = fields.iter().map(|it| it.0).collect::<Vec<_>>();
86
87    let simd_generic_name = sig.generics.type_params().next().unwrap().ident.clone();
88    let (_, inner_generics, _) = inner_fn_sig.generics.split_for_impl();
89    let turbofish = inner_generics.as_turbofish();
90    let struct_turbofish = type_generics.as_turbofish();
91
92    Ok(quote! {
93        #(#attrs)*
94        #vis #outer_fn_sig {
95            #[allow(non_camel_case_types)]
96            struct #struct_name #impl_generics #where_clause {
97                #(#field_decl,)*
98            };
99
100            impl #impl_generics macerator::WithSimd for #struct_name #type_generics #where_clause {
101                type Output = #output_ty;
102
103                #[inline(always)]
104                fn with_simd<#simd_generic_name: macerator::Simd>(self) -> <Self as macerator::WithSimd>::Output {
105                    let Self {
106                        #(#field_names,)*
107                    } = self;
108                    #[allow(unused_unsafe)]
109                    unsafe {
110                        #inner_name #turbofish(#(#field_names,)*)
111                    }
112                }
113            }
114
115            (#arch).dispatch( #struct_name #struct_turbofish { #(#field_names,)* } )
116        }
117
118        #(#attrs)*
119        #inner_fn_sig #block
120    })
121}