1use darling::FromMeta;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::parse_quote;
5use syn::Expr;
6use syn::{spanned::Spanned, FnArg, GenericParam, ItemFn, Pat};
7
8#[derive(FromMeta, Default)]
9#[darling(default)]
10struct WithSimdOpts {
11 #[darling(default)]
12 arch: Option<Expr>,
13}
14
15#[proc_macro_attribute]
16pub fn with_simd(
17 attr: proc_macro::TokenStream,
18 item: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20 match with_simd_impl(attr.into(), item.into()) {
21 Ok(out) => out.into(),
22 Err(e) => e.into_compile_error().into(),
23 }
24}
25
26fn with_simd_impl(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
27 let opts = match attr.is_empty() {
28 true => WithSimdOpts::default(),
29 false => {
30 let meta = syn::parse2::<syn::Meta>(attr)?;
31 WithSimdOpts::from_meta(&meta)?
32 }
33 };
34
35 let arch = opts.arch.unwrap_or(parse_quote!(macerator::Arch::new()));
36 let func = syn::parse2::<syn::ItemFn>(item)?;
37
38 let ItemFn {
39 attrs,
40 vis,
41 sig,
42 block,
43 } = func.clone();
44
45 let name = &sig.ident;
46
47 let lifetimes = sig.generics.lifetimes();
48 let type_params = sig.generics.type_params();
49 let const_params = sig.generics.const_params();
50
51 let mut outer_fn_sig = sig.clone();
52 outer_fn_sig.generics.params = lifetimes
53 .map(|l| GenericParam::Lifetime(l.clone()))
54 .chain(type_params.skip(1).map(|t| GenericParam::Type(t.clone())))
55 .chain(const_params.map(|c| GenericParam::Const(c.clone())))
56 .collect();
57 let mut inner_fn_sig = sig.clone();
58 inner_fn_sig.ident = format_ident!("{}_impl", name);
59 let struct_name = format_ident!("{}_struct", name);
60
61 let fields = sig
62 .inputs
63 .iter()
64 .map(|arg| match arg {
65 FnArg::Receiver(_) => Err(syn::Error::new(arg.span(), "Can't use macro on methods")),
66 FnArg::Typed(pat_type) => {
67 let ident = match &*pat_type.pat {
68 Pat::Ident(pat_ident) => &pat_ident.ident,
69 _ => todo!(),
70 };
71 let ty = &*pat_type.ty;
72 Ok((ident, ty))
73 }
74 })
75 .collect::<Result<Vec<_>, _>>()?;
76
77 let output_ty = match sig.output.clone() {
78 syn::ReturnType::Default => quote! { () },
79 syn::ReturnType::Type(_, ty) => quote! { #ty },
80 };
81
82 let inner_name = &inner_fn_sig.ident;
83 let (impl_generics, type_generics, where_clause) = outer_fn_sig.generics.split_for_impl();
84 let field_decl = fields.iter().map(|(ident, ty)| quote![#ident: #ty]);
85 let field_names = fields.iter().map(|it| it.0).collect::<Vec<_>>();
86
87 let simd_generic_name = sig.generics.type_params().next().unwrap().ident.clone();
88 let (_, inner_generics, _) = inner_fn_sig.generics.split_for_impl();
89 let turbofish = inner_generics.as_turbofish();
90 let struct_turbofish = type_generics.as_turbofish();
91
92 Ok(quote! {
93 #(#attrs)*
94 #vis #outer_fn_sig {
95 #[allow(non_camel_case_types)]
96 struct #struct_name #impl_generics #where_clause {
97 #(#field_decl,)*
98 };
99
100 impl #impl_generics macerator::WithSimd for #struct_name #type_generics #where_clause {
101 type Output = #output_ty;
102
103 #[inline(always)]
104 fn with_simd<#simd_generic_name: macerator::Simd>(self) -> <Self as macerator::WithSimd>::Output {
105 let Self {
106 #(#field_names,)*
107 } = self;
108 #[allow(unused_unsafe)]
109 unsafe {
110 #inner_name #turbofish(#(#field_names,)*)
111 }
112 }
113 }
114
115 (#arch).dispatch( #struct_name #struct_turbofish { #(#field_names,)* } )
116 }
117
118 #(#attrs)*
119 #inner_fn_sig #block
120 })
121}