crossmist_derive/
lib.rs

1#[macro_use]
2extern crate quote;
3
4use proc_macro::TokenStream;
5use quote::ToTokens;
6use syn::parse_macro_input;
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{DeriveInput, Meta, MetaList};
10
11#[proc_macro_attribute]
12pub fn func(meta: TokenStream, input: TokenStream) -> TokenStream {
13    let mut tokio_argument = None;
14    let mut smol_argument = None;
15
16    let args = parse_macro_input!(meta with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
17    for arg in args {
18        if arg.path().is_ident("tokio") {
19            tokio_argument = Some(arg);
20        } else if arg.path().is_ident("smol") {
21            smol_argument = Some(arg);
22        } else {
23            return quote_spanned! { arg.span() => compile_error!("Unknown attribute argument"); }
24                .into();
25        }
26    }
27
28    let mut input = parse_macro_input!(input as syn::ItemFn);
29
30    let return_type = match input.sig.output {
31        syn::ReturnType::Default => quote! { () },
32        syn::ReturnType::Type(_, ref ty) => quote! { #ty },
33    };
34
35    let generic_params = &input.sig.generics;
36    let generics = {
37        let params: Vec<_> = input
38            .sig
39            .generics
40            .params
41            .iter()
42            .map(|param| match param {
43                syn::GenericParam::Type(ref ty) => ty.ident.to_token_stream(),
44                syn::GenericParam::Lifetime(ref lt) => lt.lifetime.to_token_stream(),
45                syn::GenericParam::Const(ref con) => con.ident.to_token_stream(),
46            })
47            .collect();
48        quote! { <#(#params,)*> }
49    };
50    let generic_phantom: Vec<_> = input
51        .sig
52        .generics
53        .params
54        .iter()
55        .enumerate()
56        .map(|(i, param)| {
57            let field = format_ident!("f{}", i);
58            match param {
59                syn::GenericParam::Type(ref ty) => {
60                    let ident = &ty.ident;
61                    quote! { #field: std::marker::PhantomData<fn(#ident) -> #ident> }
62                }
63                syn::GenericParam::Lifetime(ref lt) => {
64                    let lt = &lt.lifetime;
65                    quote! { #field: std::marker::PhantomData<& #lt ()> }
66                }
67                syn::GenericParam::Const(ref _con) => {
68                    unimplemented!()
69                }
70            }
71        })
72        .collect();
73    let generic_phantom_build: Vec<_> = (0..input.sig.generics.params.len())
74        .map(|i| {
75            let field = format_ident!("f{}", i);
76            quote! { #field: std::marker::PhantomData }
77        })
78        .collect();
79
80    // Pray all &input are distinct
81    let link_name = format!(
82        "crossmist_{}_{:?}",
83        input.sig.ident, &input as *const syn::ItemFn,
84    );
85
86    let type_ident = format_ident!("T_{}", link_name);
87    let entry_ident = format_ident!("E_{}", link_name);
88
89    let ident = input.sig.ident;
90    input.sig.ident = format_ident!("invoke");
91
92    let vis = input.vis;
93    input.vis = syn::Visibility::Public(syn::VisPublic {
94        pub_token: <syn::Token![pub] as std::default::Default>::default(),
95    });
96
97    let args = &input.sig.inputs;
98
99    let mut fn_args = Vec::new();
100    let mut fn_types = Vec::new();
101    let mut extracted_args = Vec::new();
102    let mut arg_names = Vec::new();
103    let mut args_from_tuple = Vec::new();
104    let mut binding = Vec::new();
105    let mut has_references = false;
106    for (i, arg) in args.iter().enumerate() {
107        let i = syn::Index::from(i);
108        if let syn::FnArg::Typed(pattype) = arg {
109            if let syn::Pat::Ident(ref patident) = *pattype.pat {
110                let ident = &patident.ident;
111                let colon_token = &pattype.colon_token;
112                let ty = &pattype.ty;
113                fn_args.push(quote! { #ident #colon_token #ty });
114                fn_types.push(quote! { #ty });
115                extracted_args.push(quote! { crossmist_args.#ident });
116                arg_names.push(quote! { #ident });
117                args_from_tuple.push(quote! { args.#i });
118                binding.push(quote! { .bind_value(#ident) });
119                has_references = has_references
120                    || matches!(**ty, syn::Type::Reference(_))
121                    || matches!(
122                        **ty,
123                        syn::Type::Group(syn::TypeGroup { ref elem, .. })
124                            if matches!(**elem, syn::Type::Reference(_)),
125                    );
126            } else {
127                unreachable!();
128            }
129        } else {
130            unreachable!();
131        }
132    }
133
134    let bound = if args.is_empty() {
135        quote! { #ident }
136    } else {
137        let head_ty = &fn_types[0];
138        let tail_ty = &fn_types[1..];
139        let head_arg = &arg_names[0];
140        let tail_binding = &binding[1..];
141        quote! {
142            BindValue::<#head_ty, (#(#tail_ty,)*)>::bind_value(::std::boxed::Box::new(#ident), #head_arg) #(#tail_binding)*
143        }
144    };
145
146    let return_type_wrapped;
147    let pin;
148    if tokio_argument.is_some() || smol_argument.is_some() {
149        return_type_wrapped = quote! { ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = #return_type>>> };
150        pin = quote! { ::std::boxed::Box::pin };
151    } else {
152        return_type_wrapped = return_type.clone();
153        pin = quote! {};
154    }
155
156    let body;
157    if let Some(arg) = tokio_argument {
158        let async_attribute = match arg {
159            Meta::Path(_) => quote! { #[tokio::main] },
160            Meta::List(MetaList { nested, .. }) => quote! { #[tokio::main(#nested)] },
161            Meta::NameValue(..) => {
162                return quote_spanned! { arg.span() => compile_error!("Invalid syntax for 'tokio' argument"); }.into();
163            }
164        };
165        body = quote! {
166            #async_attribute
167            async fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
168                entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(()).await
169            }
170        };
171    } else if let Some(arg) = smol_argument {
172        match arg {
173            Meta::Path(_) => {}
174            _ => {
175                return quote_spanned! { arg.span() => compile_error!("Invalid syntax for 'smol' argument"); }.into();
176            }
177        }
178        body = quote! {
179            fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
180                ::crossmist::imp::async_io::block_on(entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(()))
181            }
182        };
183    } else {
184        body = quote! {
185            fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
186                entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(())
187            }
188        };
189    }
190
191    let impl_code = if has_references {
192        quote! {}
193    } else {
194        quote! {
195            pub fn spawn #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::Child<#return_type>> {
196                use ::crossmist::BindValue;
197                unsafe { ::crossmist::blocking::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))) }
198            }
199            pub fn run #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
200                self.spawn(#(#arg_names,)*)?.join()
201            }
202
203            ::crossmist::if_tokio! {
204                pub async fn spawn_tokio #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::tokio::Child<#return_type>> {
205                    use ::crossmist::BindValue;
206                    unsafe { ::crossmist::tokio::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))).await }
207                }
208                pub async fn run_tokio #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
209                    self.spawn_tokio(#(#arg_names,)*).await?.join().await
210                }
211            }
212
213            ::crossmist::if_smol! {
214                pub async fn spawn_smol #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::smol::Child<#return_type>> {
215                    use ::crossmist::BindValue;
216                    unsafe { ::crossmist::smol::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))).await }
217                }
218                pub async fn run_smol #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
219                    self.spawn_smol(#(#arg_names,)*).await?.join().await
220                }
221            }
222        }
223    };
224
225    let expanded = quote! {
226        #[derive(::crossmist::Object)]
227        struct #entry_ident #generic_params {
228            func: ::crossmist::Delayed<::std::boxed::Box<dyn ::crossmist::FnOnceObject<(), Output = #return_type_wrapped>>>,
229            #(#generic_phantom,)*
230        }
231
232        impl #generic_params #entry_ident #generics {
233            fn new(func: ::std::boxed::Box<dyn ::crossmist::FnOnceObject<(), Output = #return_type_wrapped>>) -> Self {
234                Self {
235                    func: ::crossmist::Delayed::new(func),
236                    #(#generic_phantom_build,)*
237                }
238            }
239        }
240
241        impl #generic_params ::crossmist::InternalFnOnce<(::crossmist::handles::RawHandle,)> for #entry_ident #generics {
242            type Output = i32;
243            #[allow(unreachable_code, clippy::diverging_sub_expression)] // If func returns !
244            fn call_object_once(self, args: (::crossmist::handles::RawHandle,)) -> Self::Output {
245                #body
246                let return_value = body(self);
247                // Avoid explicitly sending a () result
248                if ::crossmist::imp::if_void::<#return_type>().is_none() {
249                    use ::crossmist::handles::FromRawHandle;
250                    // If this function is async, there shouldn't be any task running at this
251                    // moment, so it is fine (and more efficient) to use a sync sender
252                    let output_tx_handle = args.0;
253                    let mut output_tx = unsafe {
254                        ::crossmist::Sender::<#return_type>::from_raw_handle(output_tx_handle)
255                    };
256                    output_tx.send(&return_value)
257                        .expect("Failed to send subprocess output");
258                }
259                0
260            }
261        }
262
263        impl #generic_params ::crossmist::InternalFnOnce<(#(#fn_types,)*)> for #type_ident {
264            type Output = #return_type_wrapped;
265            fn call_object_once(self, args: (#(#fn_types,)*)) -> Self::Output {
266                #pin(#type_ident::invoke(#(#args_from_tuple,)*))
267            }
268        }
269        impl #generic_params ::crossmist::InternalFnMut<(#(#fn_types,)*)> for #type_ident {
270            fn call_object_mut(&mut self, args: (#(#fn_types,)*)) -> Self::Output {
271                #pin(#type_ident::invoke(#(#args_from_tuple,)*))
272            }
273        }
274        impl #generic_params ::crossmist::InternalFn<(#(#fn_types,)*)> for #type_ident {
275            fn call_object(&self, args: (#(#fn_types,)*)) -> Self::Output {
276                #pin(#type_ident::invoke(#(#args_from_tuple,)*))
277            }
278        }
279
280        #[allow(non_camel_case_types)]
281        #[derive(::crossmist::Object)]
282        #vis struct #type_ident;
283
284        impl #type_ident {
285            #[link_name = #link_name]
286            #input
287
288            #impl_code
289        }
290
291        #[allow(non_upper_case_globals)]
292        #vis const #ident: ::crossmist::CallWrapper<#type_ident> = ::crossmist::CallWrapper(#type_ident);
293    };
294
295    TokenStream::from(expanded)
296}
297
298#[proc_macro_attribute]
299pub fn main(_meta: TokenStream, input: TokenStream) -> TokenStream {
300    let mut input = parse_macro_input!(input as syn::ItemFn);
301
302    input.sig.ident = syn::Ident::new("crossmist_old_main", input.sig.ident.span());
303
304    let expanded = quote! {
305        #input
306
307        fn main() {
308            ::crossmist::init();
309            ::std::process::exit(::crossmist::imp::Report::report(crossmist_old_main()));
310        }
311    };
312
313    TokenStream::from(expanded)
314}
315
316#[proc_macro_derive(Object)]
317pub fn derive_object(input: TokenStream) -> TokenStream {
318    let input = parse_macro_input!(input as DeriveInput);
319
320    let ident = &input.ident;
321
322    let generics = {
323        let params: Vec<_> = input
324            .generics
325            .params
326            .iter()
327            .map(|param| match param {
328                syn::GenericParam::Type(ref ty) => ty.ident.to_token_stream(),
329                syn::GenericParam::Lifetime(ref lt) => lt.lifetime.to_token_stream(),
330                syn::GenericParam::Const(ref con) => con.ident.to_token_stream(),
331            })
332            .collect();
333        quote! { <#(#params,)*> }
334    };
335
336    let generic_params = &input.generics.params;
337    let generics_impl = quote! { <#generic_params> };
338
339    let generics_where = input.generics.where_clause;
340
341    let expanded = match input.data {
342        syn::Data::Struct(struct_) => {
343            let field_types: Vec<_> = struct_.fields.iter().map(|field| &field.ty).collect();
344
345            let serialize_fields = match struct_.fields {
346                syn::Fields::Named(ref fields) => fields
347                    .named
348                    .iter()
349                    .map(|field| {
350                        let ident = &field.ident;
351                        quote! {
352                            s.serialize(&self.#ident);
353                        }
354                    })
355                    .collect(),
356                syn::Fields::Unnamed(ref fields) => fields
357                    .unnamed
358                    .iter()
359                    .enumerate()
360                    .map(|(i, _)| {
361                        let i = syn::Index::from(i);
362                        quote! {
363                            s.serialize(&self.#i);
364                        }
365                    })
366                    .collect(),
367                syn::Fields::Unit => Vec::new(),
368            };
369
370            let deserialize_fields = match struct_.fields {
371                syn::Fields::Named(ref fields) => {
372                    let deserialize_fields = fields.named.iter().map(|field| {
373                        let ident = &field.ident;
374                        quote! {
375                            #ident: unsafe { d.deserialize() }?,
376                        }
377                    });
378                    quote! { Ok(Self { #(#deserialize_fields)* }) }
379                }
380                syn::Fields::Unnamed(ref fields) => {
381                    let deserialize_fields = fields.unnamed.iter().map(|_| {
382                        quote! {
383                            unsafe { d.deserialize() }?,
384                        }
385                    });
386                    quote! { Ok(Self (#(#deserialize_fields)*)) }
387                }
388                syn::Fields::Unit => {
389                    quote! { Ok(Self) }
390                }
391            };
392
393            let generics_where_pod: Vec<_> = match generics_where {
394                Some(ref w) => w.predicates.iter().collect(),
395                None => Vec::new(),
396            };
397            let generics_where_pod = quote! {
398                where
399                    #(#generics_where_pod,)*
400                    #(for<'serde> ::crossmist::imp::Identity<'serde, #field_types>: ::crossmist::imp::PlainOldData,)*
401            };
402
403            quote! {
404                unsafe impl #generics_impl ::crossmist:: NonTrivialObject for #ident #generics #generics_where {
405                    fn serialize_self_non_trivial(&self, s: &mut ::crossmist::Serializer) {
406                        #(#serialize_fields)*
407                    }
408                    unsafe fn deserialize_self_non_trivial(d: &mut ::crossmist::Deserializer) -> ::std::io::Result<Self> {
409                        #deserialize_fields
410                    }
411                }
412                impl #generics_impl ::crossmist::imp::PlainOldData for #ident #generics #generics_where_pod {}
413            }
414        }
415        syn::Data::Enum(enum_) => {
416            let field_types: Vec<_> = enum_
417                .variants
418                .iter()
419                .flat_map(|variant| variant.fields.iter().map(|field| &field.ty))
420                .collect();
421
422            let serialize_variants = enum_.variants.iter().enumerate().map(|(i, variant)| {
423                let ident = &variant.ident;
424                match &variant.fields {
425                    syn::Fields::Named(fields) => {
426                        let (refs, sers): (Vec<_>, Vec<_>) = fields
427                            .named
428                            .iter()
429                            .map(|field| {
430                                let ident = &field.ident;
431                                (quote! { ref #ident }, quote! { s.serialize(#ident); })
432                            })
433                            .unzip();
434                        quote! {
435                            Self::#ident{ #(#refs,)* } => {
436                                s.serialize(&(#i as usize));
437                                #(#sers)*
438                            }
439                        }
440                    }
441                    syn::Fields::Unnamed(fields) => {
442                        let (refs, sers): (Vec<_>, Vec<_>) = (0..fields.unnamed.len())
443                            .map(|i| {
444                                let ident = format_ident!("a{}", i);
445                                (quote! { ref #ident }, quote! { s.serialize(#ident); })
446                            })
447                            .unzip();
448                        quote! {
449                            Self::#ident(#(#refs,)*) => {
450                                s.serialize(&(#i as usize));
451                                #(#sers)*
452                            }
453                        }
454                    }
455                    syn::Fields::Unit => {
456                        quote! {
457                            Self::#ident => {
458                                s.serialize(&(#i as usize));
459                            }
460                        }
461                    }
462                }
463            });
464
465            let deserialize_variants = enum_.variants.iter().enumerate().map(|(i, variant)| {
466                let ident = &variant.ident;
467
468                match &variant.fields {
469                    syn::Fields::Named(fields) => {
470                        let des: Vec<_> = fields
471                            .named
472                            .iter()
473                            .map(|field| {
474                                let ident = &field.ident;
475                                quote! { #ident: unsafe { d.deserialize() }? }
476                            })
477                            .collect();
478                        quote! { #i => Ok(Self::#ident{ #(#des,)* }) }
479                    }
480                    syn::Fields::Unnamed(fields) => {
481                        let des: Vec<_> = (0..fields.unnamed.len())
482                            .map(|_| quote! { unsafe { d.deserialize() }? })
483                            .collect();
484                        quote! { #i => Ok(Self::#ident(#(#des,)*)) }
485                    }
486                    syn::Fields::Unit => {
487                        quote! { #i => Ok(Self::#ident) }
488                    }
489                }
490            });
491
492            let generics_where_pod: Vec<_> = match generics_where {
493                Some(ref w) => w.predicates.iter().collect(),
494                None => Vec::new(),
495            };
496            let generics_where_pod = quote! {
497                where
498                    #(#generics_where_pod,)*
499                    #(for<'serde> ::crossmist::imp::Identity<'serde, #field_types>: ::crossmist::imp::PlainOldData,)*
500            };
501
502            quote! {
503                unsafe impl #generics_impl ::crossmist::NonTrivialObject for #ident #generics #generics_where {
504                    fn serialize_self_non_trivial(&self, s: &mut ::crossmist::Serializer) {
505                        match self {
506                            #(#serialize_variants,)*
507                        }
508                    }
509                    unsafe fn deserialize_self_non_trivial(d: &mut ::crossmist::Deserializer) -> ::std::io::Result<Self> {
510                        match d.deserialize::<usize>()? {
511                            #(#deserialize_variants,)*
512                            _ => panic!("Unexpected enum variant"),
513                        }
514                    }
515                }
516                impl #generics_impl ::crossmist::imp::PlainOldData for #ident #generics #generics_where_pod {}
517            }
518        }
519        syn::Data::Union(_) => unimplemented!(),
520    };
521
522    TokenStream::from(expanded)
523}