rnet_macros/
lib.rs

1//! rnet-macros
2//!
3//! Procedural macros for `rnet`
4#![deny(missing_docs)]
5
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::{format_ident, quote, ToTokens, TokenStreamExt};
9use syn::{
10    parse::{Parse, ParseStream, Parser},
11    visit_mut::{self, VisitMut},
12    AttrStyle, Attribute, AttributeArgs, DataStruct, DeriveInput, Error, FnArg, Generics, Lifetime,
13    Pat, Result, Signature, Token, Type, TypeReference, TypeTuple, Visibility,
14};
15
16#[derive(Clone)]
17struct MaybeItemFn {
18    attrs: Vec<Attribute>,
19    vis: Visibility,
20    sig: Signature,
21    block: TokenStream2,
22}
23
24/// This parses a `TokenStream` into a `MaybeItemFn`
25/// (just like `ItemFn`, but skips parsing the body).
26impl Parse for MaybeItemFn {
27    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
28        let attrs = input.call(syn::Attribute::parse_outer)?;
29        let vis: Visibility = input.parse()?;
30        let sig: Signature = input.parse()?;
31        let block: TokenStream2 = input.parse()?;
32        Ok(Self {
33            attrs,
34            vis,
35            sig,
36            block,
37        })
38    }
39}
40
41impl ToTokens for MaybeItemFn {
42    fn to_tokens(&self, tokens: &mut TokenStream2) {
43        tokens.append_all(
44            self.attrs
45                .iter()
46                .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
47        );
48        self.vis.to_tokens(tokens);
49        self.sig.to_tokens(tokens);
50        self.block.to_tokens(tokens);
51    }
52}
53
54fn parse_attribute_args(input: ParseStream) -> Result<AttributeArgs> {
55    let mut metas = Vec::new();
56
57    loop {
58        if input.is_empty() {
59            break;
60        }
61        let value = input.parse()?;
62        metas.push(value);
63        if input.is_empty() {
64            break;
65        }
66        input.parse::<Token![,]>()?;
67    }
68
69    Ok(metas)
70}
71
72struct LifetimeInjector {
73    lifetime: Lifetime,
74}
75
76impl VisitMut for LifetimeInjector {
77    fn visit_type_reference_mut(&mut self, i: &mut TypeReference) {
78        visit_mut::visit_type_reference_mut(self, i);
79        if i.lifetime.is_none() {
80            i.lifetime = Some(self.lifetime.clone())
81        }
82    }
83    fn visit_lifetime_mut(&mut self, i: &mut Lifetime) {
84        visit_mut::visit_lifetime_mut(self, i);
85        if i.ident == "_" {
86            *i = self.lifetime.clone();
87        }
88    }
89}
90
91fn net_impl(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
92    let root = quote! { ::rnet::hidden };
93    let _args = parse_attribute_args.parse(attr)?;
94    let mut inner_fn = MaybeItemFn::parse.parse(item)?;
95
96    let fn_name_str = inner_fn.sig.ident.to_string();
97    let fn_name = format_ident!("rnet_export_{}", fn_name_str);
98    inner_fn.sig.ident = Ident::new("inner", Span::call_site());
99
100    let args = inner_fn
101        .sig
102        .inputs
103        .iter()
104        .enumerate()
105        .map(|(i, arg)| match arg {
106            FnArg::Receiver(_) => Err(Error::new_spanned(arg, "`self` parameter not supported")),
107            FnArg::Typed(t) => Ok((
108                match &*t.pat {
109                    Pat::Ident(x) => x.ident.to_string(),
110                    _ => format!("arg{}", i),
111                },
112                t.ty.clone(),
113            )),
114        })
115        .collect::<Result<Vec<_>>>()?;
116
117    let (arg_names, arg_types): (Vec<_>, Vec<_>) = args.into_iter().unzip();
118
119    let mut lifetime_injector = LifetimeInjector {
120        lifetime: Lifetime::new("'net", Span::call_site()),
121    };
122    let local_arg_types = arg_types.iter().cloned().map(|mut arg| {
123        lifetime_injector.visit_type_mut(&mut arg);
124        arg
125    });
126
127    let ret_type = match &inner_fn.sig.output {
128        syn::ReturnType::Default => Box::new(Type::Tuple(TypeTuple {
129            paren_token: Default::default(),
130            elems: Default::default(),
131        })),
132        syn::ReturnType::Type(_, t) => t.clone(),
133    };
134
135    let arg_idents: Vec<_> = (0..arg_types.len())
136        .map(|i| format_ident!("arg{}", i))
137        .collect();
138
139    Ok(quote! {
140        #[no_mangle]
141        pub unsafe extern "C" fn #fn_name<'net>(#(
142            #arg_idents: <<#local_arg_types as #root::FromNetArg<'net>>::Owned as #root::Net>::Raw
143        ),*) -> <#ret_type as #root::ToNetReturn>::RawReturn {
144            #[#root::linkme::distributed_slice(#root::EXPORTED_FNS)]
145            static EXPORTED_FN: #root::FnDesc = #root::FnDesc {
146                name: #fn_name_str,
147                args: &[#(
148                    #root::ArgDesc {
149                        name: #arg_names,
150                        ty_: <<#arg_types as #root::FromNetArg>::Owned as #root::FromNet>::FROM_DESC,
151                    }
152                ),*],
153                ret_ty: <#ret_type as #root::ToNetReturn>::RETURN_DESC,
154            };
155
156            #inner_fn
157
158            #root::ToNetReturn::to_raw_return(inner(#(
159                #root::FromNetArg::borrow_or_take(&mut Some(#root::FromNet::from_raw(#arg_idents)))
160            ),*))
161        }
162    })
163}
164
165/// This attribute can be applied to a standalone function to allow it to be called
166/// from .net.
167#[proc_macro_attribute]
168pub fn net(attr: TokenStream, item: TokenStream) -> TokenStream {
169    match net_impl(attr, item.clone()) {
170        Ok(res) => res.into(),
171        Err(e) => {
172            let mut res: TokenStream = e.into_compile_error().into();
173            res.extend(item);
174            res
175        }
176    }
177}
178
179fn derive_net_struct_impl(
180    name: &Ident,
181    generics: &Generics,
182    data: &DataStruct,
183) -> Result<TokenStream2> {
184    let root = quote! { ::rnet::hidden };
185    let name_str = name.to_string();
186    let (field_name_ident, field_type): (Vec<_>, Vec<_>) = data
187        .fields
188        .iter()
189        .enumerate()
190        .map(|(i, field)| {
191            let ident = field
192                .ident
193                .clone()
194                .unwrap_or_else(|| format_ident!("elem{}", i));
195            ((ident.to_string(), ident), &field.ty)
196        })
197        .unzip();
198    let (field_name_str, field_name): (Vec<_>, Vec<_>) = field_name_ident.into_iter().unzip();
199    let raw_name = format_ident!("Raw{}", name);
200    let raw_name_str = format!("_Struct{}", name);
201
202    Ok(quote! {
203        const _: () = {
204            #[#root::linkme::distributed_slice(#root::EXPORTED_STRUCTS)]
205            static EXPORTED_STRUCT: #root::StructDesc = #root::StructDesc {
206                name: #name_str,
207                fields: &[#(
208                    #root::FieldDesc {
209                        name: #field_name_str,
210                        ty_: &#root::TypeDesc {
211                            marshal_in: Some(<#field_type as #root::FromNet>::gen_marshal),
212                            marshal_out: Some(<#field_type as #root::ToNet>::gen_marshal),
213                            ..*<#field_type as #root::Net>::DESC
214                        },
215                    }
216                ),*],
217            };
218
219            #[repr(C)]
220            pub struct #raw_name #generics {
221                #(
222                    pub #field_name: <#field_type as #root::Net>::Raw,
223                )*
224            }
225
226            impl Copy for #raw_name {}
227            impl Clone for #raw_name {
228                fn clone(&self) -> Self {
229                    *self
230                }
231            }
232            impl Default for #raw_name {
233                fn default() -> Self {
234                    Self {
235                        #(
236                            #field_name: Default::default(),
237                        )*
238                    }
239                }
240            }
241
242            unsafe impl #root::Net for #name {
243                type Raw = #raw_name;
244
245                fn gen_type(_ctx: &mut #root::GeneratorContext) -> Box<str> {
246                    #name_str.into()
247                }
248
249                fn gen_raw_type(_ctx: &mut #root::GeneratorContext) -> Box<str> {
250                    #raw_name_str.into()
251                }
252
253                fn is_nullable(_ctx: &mut #root::GeneratorContext) -> bool {
254                    true
255                }
256            }
257
258            unsafe impl #root::FromNet for #name {
259                unsafe fn from_raw(arg: Self::Raw) -> Self {
260                    Self {
261                        #(
262                            #field_name: <#field_type as #root::FromNet>::from_raw(arg.#field_name),
263                        )*
264                    }
265                }
266
267                fn gen_marshal(ctx: &mut #root::GeneratorContext, arg: &str) -> Box<str> {
268                    format!("{}.Encode({})", Self::gen_raw_type(ctx), arg).into()
269                }
270            }
271
272            unsafe impl #root::ToNet for #name {
273                fn into_raw(self) -> Self::Raw {
274                    #raw_name {
275                        #(
276                            #field_name: #root::ToNet::into_raw(self.#field_name),
277                        )*
278                    }
279                }
280
281                fn gen_marshal(_ctx: &mut #root::GeneratorContext, arg: &str) -> Box<str> {
282                    format!("({}).Decode()", arg).into()
283                }
284            }
285        };
286    })
287}
288
289fn derive_net_impl(item: TokenStream) -> Result<TokenStream2> {
290    let derive_input = DeriveInput::parse.parse(item)?;
291
292    match &derive_input.data {
293        syn::Data::Struct(s) => {
294            derive_net_struct_impl(&derive_input.ident, &derive_input.generics, s)
295        }
296        _ => Err(Error::new_spanned(
297            derive_input,
298            "Net derive can only be applied to structs",
299        )),
300    }
301}
302
303/// This derive will implement the `Net`, `ToNet`, and `FromNet`
304/// traits for the given struct, allowing it to be passed to or
305/// returned from .net code.
306#[proc_macro_derive(Net)]
307pub fn derive_net(item: TokenStream) -> TokenStream {
308    match derive_net_impl(item) {
309        Ok(res) => res.into(),
310        Err(e) => e.into_compile_error().into(),
311    }
312}