1#[macro_use]
2extern crate quote;
3
4use proc_macro::TokenStream;
5use quote::ToTokens;
6use syn::parse_macro_input;
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{DeriveInput, Meta, MetaList};
10
11#[proc_macro_attribute]
12pub fn func(meta: TokenStream, input: TokenStream) -> TokenStream {
13 let mut tokio_argument = None;
14 let mut smol_argument = None;
15
16 let args = parse_macro_input!(meta with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
17 for arg in args {
18 if arg.path().is_ident("tokio") {
19 tokio_argument = Some(arg);
20 } else if arg.path().is_ident("smol") {
21 smol_argument = Some(arg);
22 } else {
23 return quote_spanned! { arg.span() => compile_error!("Unknown attribute argument"); }
24 .into();
25 }
26 }
27
28 let mut input = parse_macro_input!(input as syn::ItemFn);
29
30 let return_type = match input.sig.output {
31 syn::ReturnType::Default => quote! { () },
32 syn::ReturnType::Type(_, ref ty) => quote! { #ty },
33 };
34
35 let generic_params = &input.sig.generics;
36 let generics = {
37 let params: Vec<_> = input
38 .sig
39 .generics
40 .params
41 .iter()
42 .map(|param| match param {
43 syn::GenericParam::Type(ref ty) => ty.ident.to_token_stream(),
44 syn::GenericParam::Lifetime(ref lt) => lt.lifetime.to_token_stream(),
45 syn::GenericParam::Const(ref con) => con.ident.to_token_stream(),
46 })
47 .collect();
48 quote! { <#(#params,)*> }
49 };
50 let generic_phantom: Vec<_> = input
51 .sig
52 .generics
53 .params
54 .iter()
55 .enumerate()
56 .map(|(i, param)| {
57 let field = format_ident!("f{}", i);
58 match param {
59 syn::GenericParam::Type(ref ty) => {
60 let ident = &ty.ident;
61 quote! { #field: std::marker::PhantomData<fn(#ident) -> #ident> }
62 }
63 syn::GenericParam::Lifetime(ref lt) => {
64 let lt = <.lifetime;
65 quote! { #field: std::marker::PhantomData<& #lt ()> }
66 }
67 syn::GenericParam::Const(ref _con) => {
68 unimplemented!()
69 }
70 }
71 })
72 .collect();
73 let generic_phantom_build: Vec<_> = (0..input.sig.generics.params.len())
74 .map(|i| {
75 let field = format_ident!("f{}", i);
76 quote! { #field: std::marker::PhantomData }
77 })
78 .collect();
79
80 let link_name = format!(
82 "crossmist_{}_{:?}",
83 input.sig.ident, &input as *const syn::ItemFn,
84 );
85
86 let type_ident = format_ident!("T_{}", link_name);
87 let entry_ident = format_ident!("E_{}", link_name);
88
89 let ident = input.sig.ident;
90 input.sig.ident = format_ident!("invoke");
91
92 let vis = input.vis;
93 input.vis = syn::Visibility::Public(syn::VisPublic {
94 pub_token: <syn::Token![pub] as std::default::Default>::default(),
95 });
96
97 let args = &input.sig.inputs;
98
99 let mut fn_args = Vec::new();
100 let mut fn_types = Vec::new();
101 let mut extracted_args = Vec::new();
102 let mut arg_names = Vec::new();
103 let mut args_from_tuple = Vec::new();
104 let mut binding = Vec::new();
105 let mut has_references = false;
106 for (i, arg) in args.iter().enumerate() {
107 let i = syn::Index::from(i);
108 if let syn::FnArg::Typed(pattype) = arg {
109 if let syn::Pat::Ident(ref patident) = *pattype.pat {
110 let ident = &patident.ident;
111 let colon_token = &pattype.colon_token;
112 let ty = &pattype.ty;
113 fn_args.push(quote! { #ident #colon_token #ty });
114 fn_types.push(quote! { #ty });
115 extracted_args.push(quote! { crossmist_args.#ident });
116 arg_names.push(quote! { #ident });
117 args_from_tuple.push(quote! { args.#i });
118 binding.push(quote! { .bind_value(#ident) });
119 has_references = has_references
120 || matches!(**ty, syn::Type::Reference(_))
121 || matches!(
122 **ty,
123 syn::Type::Group(syn::TypeGroup { ref elem, .. })
124 if matches!(**elem, syn::Type::Reference(_)),
125 );
126 } else {
127 unreachable!();
128 }
129 } else {
130 unreachable!();
131 }
132 }
133
134 let bound = if args.is_empty() {
135 quote! { #ident }
136 } else {
137 let head_ty = &fn_types[0];
138 let tail_ty = &fn_types[1..];
139 let head_arg = &arg_names[0];
140 let tail_binding = &binding[1..];
141 quote! {
142 BindValue::<#head_ty, (#(#tail_ty,)*)>::bind_value(::std::boxed::Box::new(#ident), #head_arg) #(#tail_binding)*
143 }
144 };
145
146 let return_type_wrapped;
147 let pin;
148 if tokio_argument.is_some() || smol_argument.is_some() {
149 return_type_wrapped = quote! { ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = #return_type>>> };
150 pin = quote! { ::std::boxed::Box::pin };
151 } else {
152 return_type_wrapped = return_type.clone();
153 pin = quote! {};
154 }
155
156 let body;
157 if let Some(arg) = tokio_argument {
158 let async_attribute = match arg {
159 Meta::Path(_) => quote! { #[tokio::main] },
160 Meta::List(MetaList { nested, .. }) => quote! { #[tokio::main(#nested)] },
161 Meta::NameValue(..) => {
162 return quote_spanned! { arg.span() => compile_error!("Invalid syntax for 'tokio' argument"); }.into();
163 }
164 };
165 body = quote! {
166 #async_attribute
167 async fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
168 entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(()).await
169 }
170 };
171 } else if let Some(arg) = smol_argument {
172 match arg {
173 Meta::Path(_) => {}
174 _ => {
175 return quote_spanned! { arg.span() => compile_error!("Invalid syntax for 'smol' argument"); }.into();
176 }
177 }
178 body = quote! {
179 fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
180 ::crossmist::imp::async_io::block_on(entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(()))
181 }
182 };
183 } else {
184 body = quote! {
185 fn body #generic_params (entry: #entry_ident #generics) -> #return_type {
186 entry.func.deserialize().expect("Failed to deserialize entry").call_object_box(())
187 }
188 };
189 }
190
191 let impl_code = if has_references {
192 quote! {}
193 } else {
194 quote! {
195 pub fn spawn #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::Child<#return_type>> {
196 use ::crossmist::BindValue;
197 unsafe { ::crossmist::blocking::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))) }
198 }
199 pub fn run #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
200 self.spawn(#(#arg_names,)*)?.join()
201 }
202
203 ::crossmist::if_tokio! {
204 pub async fn spawn_tokio #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::tokio::Child<#return_type>> {
205 use ::crossmist::BindValue;
206 unsafe { ::crossmist::tokio::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))).await }
207 }
208 pub async fn run_tokio #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
209 self.spawn_tokio(#(#arg_names,)*).await?.join().await
210 }
211 }
212
213 ::crossmist::if_smol! {
214 pub async fn spawn_smol #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<::crossmist::smol::Child<#return_type>> {
215 use ::crossmist::BindValue;
216 unsafe { ::crossmist::smol::spawn(::std::boxed::Box::new(::crossmist::CallWrapper(#entry_ident:: #generics ::new(::std::boxed::Box::new(#bound))))).await }
217 }
218 pub async fn run_smol #generic_params(&self, #(#fn_args,)*) -> ::std::io::Result<#return_type> {
219 self.spawn_smol(#(#arg_names,)*).await?.join().await
220 }
221 }
222 }
223 };
224
225 let expanded = quote! {
226 #[derive(::crossmist::Object)]
227 struct #entry_ident #generic_params {
228 func: ::crossmist::Delayed<::std::boxed::Box<dyn ::crossmist::FnOnceObject<(), Output = #return_type_wrapped>>>,
229 #(#generic_phantom,)*
230 }
231
232 impl #generic_params #entry_ident #generics {
233 fn new(func: ::std::boxed::Box<dyn ::crossmist::FnOnceObject<(), Output = #return_type_wrapped>>) -> Self {
234 Self {
235 func: ::crossmist::Delayed::new(func),
236 #(#generic_phantom_build,)*
237 }
238 }
239 }
240
241 impl #generic_params ::crossmist::InternalFnOnce<(::crossmist::handles::RawHandle,)> for #entry_ident #generics {
242 type Output = i32;
243 #[allow(unreachable_code, clippy::diverging_sub_expression)] fn call_object_once(self, args: (::crossmist::handles::RawHandle,)) -> Self::Output {
245 #body
246 let return_value = body(self);
247 if ::crossmist::imp::if_void::<#return_type>().is_none() {
249 use ::crossmist::handles::FromRawHandle;
250 let output_tx_handle = args.0;
253 let mut output_tx = unsafe {
254 ::crossmist::Sender::<#return_type>::from_raw_handle(output_tx_handle)
255 };
256 output_tx.send(&return_value)
257 .expect("Failed to send subprocess output");
258 }
259 0
260 }
261 }
262
263 impl #generic_params ::crossmist::InternalFnOnce<(#(#fn_types,)*)> for #type_ident {
264 type Output = #return_type_wrapped;
265 fn call_object_once(self, args: (#(#fn_types,)*)) -> Self::Output {
266 #pin(#type_ident::invoke(#(#args_from_tuple,)*))
267 }
268 }
269 impl #generic_params ::crossmist::InternalFnMut<(#(#fn_types,)*)> for #type_ident {
270 fn call_object_mut(&mut self, args: (#(#fn_types,)*)) -> Self::Output {
271 #pin(#type_ident::invoke(#(#args_from_tuple,)*))
272 }
273 }
274 impl #generic_params ::crossmist::InternalFn<(#(#fn_types,)*)> for #type_ident {
275 fn call_object(&self, args: (#(#fn_types,)*)) -> Self::Output {
276 #pin(#type_ident::invoke(#(#args_from_tuple,)*))
277 }
278 }
279
280 #[allow(non_camel_case_types)]
281 #[derive(::crossmist::Object)]
282 #vis struct #type_ident;
283
284 impl #type_ident {
285 #[link_name = #link_name]
286 #input
287
288 #impl_code
289 }
290
291 #[allow(non_upper_case_globals)]
292 #vis const #ident: ::crossmist::CallWrapper<#type_ident> = ::crossmist::CallWrapper(#type_ident);
293 };
294
295 TokenStream::from(expanded)
296}
297
298#[proc_macro_attribute]
299pub fn main(_meta: TokenStream, input: TokenStream) -> TokenStream {
300 let mut input = parse_macro_input!(input as syn::ItemFn);
301
302 input.sig.ident = syn::Ident::new("crossmist_old_main", input.sig.ident.span());
303
304 let expanded = quote! {
305 #input
306
307 fn main() {
308 ::crossmist::init();
309 ::std::process::exit(::crossmist::imp::Report::report(crossmist_old_main()));
310 }
311 };
312
313 TokenStream::from(expanded)
314}
315
316#[proc_macro_derive(Object)]
317pub fn derive_object(input: TokenStream) -> TokenStream {
318 let input = parse_macro_input!(input as DeriveInput);
319
320 let ident = &input.ident;
321
322 let generics = {
323 let params: Vec<_> = input
324 .generics
325 .params
326 .iter()
327 .map(|param| match param {
328 syn::GenericParam::Type(ref ty) => ty.ident.to_token_stream(),
329 syn::GenericParam::Lifetime(ref lt) => lt.lifetime.to_token_stream(),
330 syn::GenericParam::Const(ref con) => con.ident.to_token_stream(),
331 })
332 .collect();
333 quote! { <#(#params,)*> }
334 };
335
336 let generic_params = &input.generics.params;
337 let generics_impl = quote! { <#generic_params> };
338
339 let generics_where = input.generics.where_clause;
340
341 let expanded = match input.data {
342 syn::Data::Struct(struct_) => {
343 let field_types: Vec<_> = struct_.fields.iter().map(|field| &field.ty).collect();
344
345 let serialize_fields = match struct_.fields {
346 syn::Fields::Named(ref fields) => fields
347 .named
348 .iter()
349 .map(|field| {
350 let ident = &field.ident;
351 quote! {
352 s.serialize(&self.#ident);
353 }
354 })
355 .collect(),
356 syn::Fields::Unnamed(ref fields) => fields
357 .unnamed
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let i = syn::Index::from(i);
362 quote! {
363 s.serialize(&self.#i);
364 }
365 })
366 .collect(),
367 syn::Fields::Unit => Vec::new(),
368 };
369
370 let deserialize_fields = match struct_.fields {
371 syn::Fields::Named(ref fields) => {
372 let deserialize_fields = fields.named.iter().map(|field| {
373 let ident = &field.ident;
374 quote! {
375 #ident: unsafe { d.deserialize() }?,
376 }
377 });
378 quote! { Ok(Self { #(#deserialize_fields)* }) }
379 }
380 syn::Fields::Unnamed(ref fields) => {
381 let deserialize_fields = fields.unnamed.iter().map(|_| {
382 quote! {
383 unsafe { d.deserialize() }?,
384 }
385 });
386 quote! { Ok(Self (#(#deserialize_fields)*)) }
387 }
388 syn::Fields::Unit => {
389 quote! { Ok(Self) }
390 }
391 };
392
393 let generics_where_pod: Vec<_> = match generics_where {
394 Some(ref w) => w.predicates.iter().collect(),
395 None => Vec::new(),
396 };
397 let generics_where_pod = quote! {
398 where
399 #(#generics_where_pod,)*
400 #(for<'serde> ::crossmist::imp::Identity<'serde, #field_types>: ::crossmist::imp::PlainOldData,)*
401 };
402
403 quote! {
404 unsafe impl #generics_impl ::crossmist:: NonTrivialObject for #ident #generics #generics_where {
405 fn serialize_self_non_trivial(&self, s: &mut ::crossmist::Serializer) {
406 #(#serialize_fields)*
407 }
408 unsafe fn deserialize_self_non_trivial(d: &mut ::crossmist::Deserializer) -> ::std::io::Result<Self> {
409 #deserialize_fields
410 }
411 }
412 impl #generics_impl ::crossmist::imp::PlainOldData for #ident #generics #generics_where_pod {}
413 }
414 }
415 syn::Data::Enum(enum_) => {
416 let field_types: Vec<_> = enum_
417 .variants
418 .iter()
419 .flat_map(|variant| variant.fields.iter().map(|field| &field.ty))
420 .collect();
421
422 let serialize_variants = enum_.variants.iter().enumerate().map(|(i, variant)| {
423 let ident = &variant.ident;
424 match &variant.fields {
425 syn::Fields::Named(fields) => {
426 let (refs, sers): (Vec<_>, Vec<_>) = fields
427 .named
428 .iter()
429 .map(|field| {
430 let ident = &field.ident;
431 (quote! { ref #ident }, quote! { s.serialize(#ident); })
432 })
433 .unzip();
434 quote! {
435 Self::#ident{ #(#refs,)* } => {
436 s.serialize(&(#i as usize));
437 #(#sers)*
438 }
439 }
440 }
441 syn::Fields::Unnamed(fields) => {
442 let (refs, sers): (Vec<_>, Vec<_>) = (0..fields.unnamed.len())
443 .map(|i| {
444 let ident = format_ident!("a{}", i);
445 (quote! { ref #ident }, quote! { s.serialize(#ident); })
446 })
447 .unzip();
448 quote! {
449 Self::#ident(#(#refs,)*) => {
450 s.serialize(&(#i as usize));
451 #(#sers)*
452 }
453 }
454 }
455 syn::Fields::Unit => {
456 quote! {
457 Self::#ident => {
458 s.serialize(&(#i as usize));
459 }
460 }
461 }
462 }
463 });
464
465 let deserialize_variants = enum_.variants.iter().enumerate().map(|(i, variant)| {
466 let ident = &variant.ident;
467
468 match &variant.fields {
469 syn::Fields::Named(fields) => {
470 let des: Vec<_> = fields
471 .named
472 .iter()
473 .map(|field| {
474 let ident = &field.ident;
475 quote! { #ident: unsafe { d.deserialize() }? }
476 })
477 .collect();
478 quote! { #i => Ok(Self::#ident{ #(#des,)* }) }
479 }
480 syn::Fields::Unnamed(fields) => {
481 let des: Vec<_> = (0..fields.unnamed.len())
482 .map(|_| quote! { unsafe { d.deserialize() }? })
483 .collect();
484 quote! { #i => Ok(Self::#ident(#(#des,)*)) }
485 }
486 syn::Fields::Unit => {
487 quote! { #i => Ok(Self::#ident) }
488 }
489 }
490 });
491
492 let generics_where_pod: Vec<_> = match generics_where {
493 Some(ref w) => w.predicates.iter().collect(),
494 None => Vec::new(),
495 };
496 let generics_where_pod = quote! {
497 where
498 #(#generics_where_pod,)*
499 #(for<'serde> ::crossmist::imp::Identity<'serde, #field_types>: ::crossmist::imp::PlainOldData,)*
500 };
501
502 quote! {
503 unsafe impl #generics_impl ::crossmist::NonTrivialObject for #ident #generics #generics_where {
504 fn serialize_self_non_trivial(&self, s: &mut ::crossmist::Serializer) {
505 match self {
506 #(#serialize_variants,)*
507 }
508 }
509 unsafe fn deserialize_self_non_trivial(d: &mut ::crossmist::Deserializer) -> ::std::io::Result<Self> {
510 match d.deserialize::<usize>()? {
511 #(#deserialize_variants,)*
512 _ => panic!("Unexpected enum variant"),
513 }
514 }
515 }
516 impl #generics_impl ::crossmist::imp::PlainOldData for #ident #generics #generics_where_pod {}
517 }
518 }
519 syn::Data::Union(_) => unimplemented!(),
520 };
521
522 TokenStream::from(expanded)
523}