real_async_trait/
lib.rs

1//!
2//! # `#[real_async_trait]`
3//! [![travis]](https://travis-ci.org/4lDO2/real-async-trait-rs)
4//! [![cratesio]](https://crates.io/crates/real-async-trait)
5//! [![docsrs]](https://docs.rs/real-async-trait/)
6//!
7//! [travis]: https://travis-ci.org/4lDO2/real-async-trait-rs.svg?branch=master
8//! [cratesio]: https://img.shields.io/crates/v/real-async-trait.svg
9//! [docsrs]: https://docs.rs/real-async-trait/badge.svg
10//!
//! This crate provides a procedural macro that works around the current limitation of not being
12//! able to put `async fn`s in a trait, _without type erasure_, by using experimental
13//! nightly-features, namely [generic associated types
14//! (GATs)](https://github.com/rust-lang/rfcs/blob/master/text/1598-generic_associated_types.md)
15//! and [existential
16//! types](https://github.com/rust-lang/rfcs/blob/master/text/2515-type_alias_impl_trait.md).
17//!
18//! ## Caveats
19//!
20//! While this proc macro will allow you to write non-type-erased allocation-free async fns within
21//! traits, there are a few caveats to this (non-exhaustive):
22//!
//! * at the moment, all references used in the async fn must have their lifetimes explicitly
//! specified, either at the top level of the trait, or in the function declaration;
25//! * there can only be a single lifetime in use simultaneously. I have no idea why, but it could
26//! be due to buggy interaction between existential types and generic associated types;
27//! * since GATs are an "incomplete" feature in rust, it may not be sound or just not compile
28//! correctly or at all. __Don't use this in production code!__
29//!
30//! ## Example
31//! ```ignore
32//! #[async_std::main]
33//! # async fn main() {
34//! /// An error code, similar to `errno` in C.
35//! pub type Errno = usize;
36//!
37//! /// A UNIX-like file descriptor.
38//! pub type FileDescriptor = usize;
39//!
40//! /// "No such file or directory"
41//! pub const ENOENT: usize = 1;
42//!
43//! /// "Bad file descriptor"
44//! pub const EBADF: usize = 2;
45//!
46//! /// A filesystem-like primitive, used in the Redox Operating System.
47//! #[real_async_trait]
48//! pub trait RedoxScheme {
49//!     async fn open<'a>(&'a self, path: &'a [u8], flags: usize) -> Result<FileDescriptor, Errno>;
50//!     async fn read<'a>(&'a self, fd: FileDescriptor, buffer: &'a mut [u8]) -> Result<usize, Errno>;
51//!     async fn write<'a>(&'a self, fd: FileDescriptor, buffer: &'a [u8]) -> Result<usize, Errno>;
52//!     async fn close<'a>(&'a self, fd: FileDescriptor) -> Result<(), Errno>;
53//! }
54//!
55//! /// A scheme that does absolutely nothing.
56//! struct MyNothingScheme;
57//!
58//! #[real_async_trait]
59//! impl RedoxScheme for MyNothingScheme {
60//!     async fn open<'a>(&'a self, path: &'a [u8], flags: usize) -> Result<FileDescriptor, Errno> {
61//!         // I can write async code in here!
62//!         Err(ENOENT)
63//!     }
//!     async fn read<'a>(&'a self, fd: FileDescriptor, buffer: &'a mut [u8]) -> Result<usize, Errno> {
//!         Err(EBADF)
//!     }
//!     async fn write<'a>(&'a self, fd: FileDescriptor, buffer: &'a [u8]) -> Result<usize, Errno> {
//!         Err(EBADF)
//!     }
//!     async fn close<'a>(&'a self, fd: FileDescriptor) -> Result<(), Errno> {
//!         Err(EBADF)
//!     }
73//! }
74//!
75//! let my_nothing_scheme = MyNothingScheme;
76//!
77//! assert_eq!(my_nothing_scheme.open(b"nothing exists here", 0).await, Err(ENOENT), "why would anything exist here?");
78//! assert_eq!(my_nothing_scheme.read(1337, &mut []).await, Err(EBADF));
79//! assert_eq!(my_nothing_scheme.write(1337, &[]).await, Err(EBADF));
80//! assert_eq!(my_nothing_scheme.close(1337).await, Err(EBADF));
81//!
82//! # }
83//!
84//! ```
85//! ## How it works
86//!
//! Under the hood, this proc macro will insert generic associated types (GATs) for the futures
88//! that are the return types of the async fns in the trait definition. The macro will generate the
89//! following for the `RedoxScheme` trait (simplified generated names):
90//!
91//! ```ignore
92//! pub trait RedoxScheme {
93//!     // Downgraded functions, from async fn to fn. Their types have changed into a generic
94//!     // associated type.
95//!     fn open<'a>(&'a self, path: &'a [u8], flags: usize) -> Self::OpenFuture<'a>;
96//!     fn read<'a>(&'a self, fd: usize, buf: &'a mut [u8]) -> Self::ReadFuture<'a>;
97//!     fn write<'a>(&'a self, fd: usize, buf: &'a [u8]) -> Self::WriteFuture<'a>;
98//!     fn close<'a>(&'a self, fd: usize) -> Self::CloseFuture<'a>;
99//!
//!     // Generic associated types; the return types are moved here.
101//!     type OpenFuture<'a>: ::core::future::Future<Output = Result<FileDescriptor, Errno>> + 'a;
102//!     type ReadFuture<'a>: ::core::future::Future<Output = Result<usize, Errno>> + 'a;
103//!     type WriteFuture<'a>: ::core::future::Future<Output = Result<usize, Errno>> + 'a;
104//!     type CloseFuture<'a>: ::core::future::Future<Output = Result<(), Errno>> + 'a;
105//! }
106//! ```
107//!
108//! Meanwhile, the impls will get the following generated code (simplified here as well):
109//!
110//! ```ignore
111//!
112//! // Wrap everything in a private module to prevent the existential types from leaking.
113//! mod __private {
114//!     impl RedoxScheme for MyNothingScheme {
115//!         // Async fns are downgraded here as well, and the same thing goes with the return
116//!         // values.
117//!         fn open<'a>(&'a self, path: &'a [u8], flags: usize) -> Self::OpenFuture<'a> {
//!             // The body of each async fn is wrapped in an `async move` block. The compiler will
//!             // automagically figure out the actual types of the existential type aliases, even
//!             // though they are anonymous.
121//!             async move { Err(ENOENT) }
122//!         }
123//!         fn read<'a>(&'a self, fd: usize, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
124//!             async move { Err(EBADF) }
125//!         }
126//!         fn write<'a>(&'a self, fd: usize, buf: &'a [u8]) -> Self::WriteFuture<'a> {
127//!             async move { Err(EBADF) }
128//!         }
129//!         fn close<'a>(&'a self, fd: usize) -> Self::CloseFuture<'a> {
130//!             async move { Err(EBADF) }
131//!         }
132//!
//!         // This is the part where the existential types come in. Currently, there is no way
//!         // to name the anonymous future type of an `async` block in an associated type. To
//!         // avoid rewriting our futures as custom state machines, or type-erasing them behind
//!         // pointers, we use existential types instead.
137//!         type OpenFuture<'a> = OpenFutureExistentialType<'a>;
138//!         type ReadFuture<'a> = ReadFutureExistentialType<'a>;
139//!         type WriteFuture<'a> = WriteFutureExistentialType<'a>;
140//!         type CloseFuture<'a> = CloseFutureExistentialType<'a>;
141//!     }
//!     // This is where the return types are actually defined. At the moment, these type aliases
//!     // with `impl Trait` can only occur outside of the trait impl itself, unfortunately. Each
//!     // alias refers to exactly one concrete type, which the compiler infers.
145//!     type OpenFutureExistentialType<'a> = impl Future<Output = Result<FileDescriptor, Errno>> +
146//!     'a;
147//!     type ReadFutureExistentialType<'a> = impl Future<Output = Result<usize, Errno>> + 'a;
148//!     type WriteFutureExistentialType<'a> = impl Future<Output = Result<usize, Errno>> + 'a;
149//!     type CloseFutureExistentialType<'a> = impl Future<Output = Result<(), Errno>> + 'a;
150//! }
151//! ```
152//!
153
154extern crate proc_macro;
155
156use std::str::FromStr;
157use std::{iter, mem};
158
159use proc_macro2::{Span, TokenStream};
160use quote::quote;
161use syn::punctuated::Punctuated;
162use syn::token;
163use syn::{
164    AngleBracketedGenericArguments, Binding, Block, Expr, ExprAsync, FnArg, GenericArgument,
165    GenericParam, Generics, Ident, ImplItem, ImplItemType, ItemImpl, ItemTrait, ItemType, Lifetime,
166    LifetimeDef, PatType, Path, PathArguments, PathSegment, ReturnType, Signature, Stmt, Token,
167    TraitBound, TraitBoundModifier, TraitItem, TraitItemType, Type, TypeImplTrait, TypeParamBound,
168    TypePath, TypeReference, TypeTuple, Visibility,
169};
170
171mod tests;
172
173struct LifetimeVisitor;
174
175impl<'ast> syn::visit::Visit<'ast> for LifetimeVisitor {
176    fn visit_type_reference(&mut self, i: &'ast TypeReference) {
177        if i.lifetime.is_none() {
178            panic!("Reference at {:?} lacked an explicit lifetime, which is required by this proc macro", i.and_token.span);
179        }
180    }
181}
182
183fn handle_item_impl(mut item: ItemImpl) -> TokenStream {
184    let mut existential_type_defs = Vec::new();
185    let mut gat_defs = Vec::new();
186
187    for method in item
188        .items
189        .iter_mut()
190        .filter_map(|item| {
191            if let ImplItem::Method(method) = item {
192                Some(method)
193            } else {
194                None
195            }
196        })
197        .filter(|method| method.sig.asyncness.is_some())
198    {
199        method.sig.asyncness = None;
200
201        validate_that_function_always_has_lifetimes(&method.sig);
202
203        let (toplevel_lifetimes, function_lifetimes) =
204            already_defined_lifetimes(&item.generics, &method.sig.generics);
205
206        let existential_type_name = format!(
207            "__real_async_trait_impl_ExistentialTypeFor_{}",
208            method.sig.ident
209        );
210        let existential_type_ident = Ident::new(&existential_type_name, Span::call_site());
211
212        existential_type_defs.push(ItemType {
213            attrs: Vec::new(),
214            eq_token: Token!(=)(Span::call_site()),
215            generics: Generics {
216                gt_token: Some(Token!(>)(Span::call_site())),
217                lt_token: Some(Token!(<)(Span::call_site())),
218                params: toplevel_lifetimes
219                    .iter()
220                    .cloned()
221                    .map(GenericParam::Lifetime)
222                    .collect(),
223                where_clause: None,
224            },
225            ident: existential_type_ident,
226            semi_token: Token!(;)(Span::call_site()),
227            vis: Visibility::Inherited,
228            ty: Box::new(Type::ImplTrait(TypeImplTrait {
229                bounds: iter::once(TypeParamBound::Trait(future_trait_bound(return_type(
230                    method.sig.output.clone(),
231                ))))
232                .chain(
233                    toplevel_lifetimes
234                        .iter()
235                        .cloned()
236                        .map(|lifetime_def| TypeParamBound::Lifetime(lifetime_def.lifetime)),
237                )
238                .collect(),
239                impl_token: Token!(impl)(Span::call_site()),
240            })),
241            type_token: Token!(type)(Span::call_site()),
242        });
243
244        let existential_type_path_for_impl = Path {
245            // self::__real_async_trait_impl_ExistentialTypeFor_FUNCTIONNAME
246            leading_colon: None,
247            segments: vec![
248                PathSegment {
249                    arguments: PathArguments::None,
250                    ident: Ident::new("self", Span::call_site()),
251                },
252                PathSegment {
253                    arguments: PathArguments::AngleBracketed(lifetime_angle_bracketed_bounds(
254                        toplevel_lifetimes
255                            .into_iter()
256                            .map(|lifetime_def| lifetime_def.lifetime),
257                    )),
258                    ident: Ident::new(&existential_type_name, Span::call_site()),
259                },
260            ]
261            .into_iter()
262            .collect(),
263        };
264        let existential_path_type = Type::Path(TypePath {
265            path: existential_type_path_for_impl,
266            qself: None,
267        });
268
269        let gat_ident = gat_ident_for_sig(&method.sig);
270
271        gat_defs.push(ImplItemType {
272            attrs: Vec::new(),
273            defaultness: None,
274            eq_token: Token!(=)(Span::call_site()),
275            generics: Generics {
276                lt_token: Some(Token!(<)(Span::call_site())),
277                gt_token: Some(Token!(>)(Span::call_site())),
278                where_clause: None,
279                params: function_lifetimes
280                    .iter()
281                    .cloned()
282                    .map(GenericParam::Lifetime)
283                    .collect(),
284            },
285            ident: gat_ident.clone(),
286            semi_token: Token!(;)(Span::call_site()),
287            ty: existential_path_type.clone(),
288            type_token: Token!(type)(Span::call_site()),
289            vis: Visibility::Inherited,
290        });
291
292        let gat_self_type = self_gat_type(
293            gat_ident,
294            function_lifetimes
295                .into_iter()
296                .map(|lifetime_def| lifetime_def.lifetime),
297        );
298
299        method.sig.output = ReturnType::Type(
300            Token!(->)(Span::call_site()),
301            Box::new(gat_self_type.into()),
302        );
303
304        let method_stmts = mem::replace(&mut method.block.stmts, Vec::new());
305
306        method.block.stmts = vec![Stmt::Expr(Expr::Async(ExprAsync {
307            async_token: Token!(async)(Span::call_site()),
308            attrs: Vec::new(),
309            block: Block {
310                brace_token: token::Brace {
311                    span: Span::call_site(),
312                },
313                stmts: method_stmts,
314            },
315            capture: Some(Token!(move)(Span::call_site())),
316        }))];
317    }
318
319    item.items.extend(gat_defs.into_iter().map(Into::into));
320
321    quote! {
322
323        mod __real_async_trait_impl {
324            use super::*;
325
326            #item
327
328            #(#existential_type_defs)*
329        }
330    }
331}
332
333fn return_type(retval: ReturnType) -> Type {
334    match retval {
335        ReturnType::Default => Type::Tuple(TypeTuple {
336            elems: Punctuated::new(),
337            paren_token: token::Paren {
338                span: Span::call_site(),
339            },
340        }),
341        ReturnType::Type(_, ty) => *ty,
342    }
343}
344
345fn future_trait_bound(fn_output_ty: Type) -> TraitBound {
346    const FUTURE_TRAIT_PATH_STR: &str = "::core::future::Future";
347    const FUTURE_TRAIT_OUTPUT_IDENT_STR: &str = "Output";
348
349    let mut future_trait_path =
350        syn::parse2::<Path>(TokenStream::from_str(FUTURE_TRAIT_PATH_STR).unwrap())
351            .expect("failed to parse `::core::future::Future` as a syn `Path`");
352
353    let future_angle_bracketed_args = AngleBracketedGenericArguments {
354        colon2_token: None, // FIXME
355        lt_token: Token!(<)(Span::call_site()),
356        gt_token: Token!(>)(Span::call_site()),
357        args: iter::once(GenericArgument::Binding(Binding {
358            ident: Ident::new(FUTURE_TRAIT_OUTPUT_IDENT_STR, Span::call_site()),
359            eq_token: Token!(=)(Span::call_site()),
360            ty: fn_output_ty,
361        }))
362        .collect(),
363    };
364
365    future_trait_path
366        .segments
367        .last_mut()
368        .expect("Expected ::core::future::Future to have `Future` as the last segment")
369        .arguments = PathArguments::AngleBracketed(future_angle_bracketed_args);
370
371    TraitBound {
372        // for TraitBounds, these are HRTBs, which are useless since there are already GATs present
373        lifetimes: None,
374        // This is not ?Sized or something like that
375        modifier: TraitBoundModifier::None,
376        paren_token: None,
377        path: future_trait_path,
378    }
379}
380
381fn validate_that_function_always_has_lifetimes(signature: &Signature) {
382    for input in signature.inputs.iter() {
383        match input {
384            FnArg::Receiver(ref recv) => {
385                if let Some((_ampersand, _lifetime @ None)) = &recv.reference {
386                    panic!("{}self parameter lacked an explicit lifetime, which is required by this proc macro", if recv.mutability.is_some() { "&mut " } else { "&" });
387                }
388            }
389            FnArg::Typed(PatType { ref ty, .. }) => {
390                syn::visit::visit_type(&mut LifetimeVisitor, ty)
391            }
392        }
393    }
394    if let ReturnType::Type(_, ref ty) = signature.output {
395        syn::visit::visit_type(&mut LifetimeVisitor, ty);
396    };
397}
398fn already_defined_lifetimes(
399    toplevel_generics: &Generics,
400    method_generics: &Generics,
401) -> (Vec<LifetimeDef>, Vec<LifetimeDef>) {
402    //Global scope
403    //let mut lifetimes = vec! [LifetimeDef::new(Lifetime::new("'static", Span::call_site()))];
404
405    let mut lifetimes = Vec::new();
406    // Trait definition scope
407    lifetimes.extend(toplevel_generics.lifetimes().cloned());
408    // Function definition scope
409    let function_lifetimes = method_generics.lifetimes().cloned().collect::<Vec<_>>();
410    lifetimes.extend(function_lifetimes.iter().cloned());
411    (lifetimes, function_lifetimes)
412}
413fn lifetime_angle_bracketed_bounds(
414    lifetimes: impl IntoIterator<Item = Lifetime>,
415) -> AngleBracketedGenericArguments {
416    AngleBracketedGenericArguments {
417        colon2_token: None,
418        lt_token: Token!(<)(Span::call_site()),
419        gt_token: Token!(>)(Span::call_site()),
420        args: lifetimes
421            .into_iter()
422            .map(|lifetime_def| GenericArgument::Lifetime(lifetime_def))
423            .collect(),
424    }
425}
426fn gat_ident_for_sig(sig: &Signature) -> Ident {
427    let gat_name = format!("__real_async_trait_impl_TypeFor_{}", sig.ident);
428    Ident::new(&gat_name, Span::call_site())
429}
430fn self_gat_type(
431    gat_ident: Ident,
432    function_lifetimes: impl IntoIterator<Item = Lifetime>,
433) -> TypePath {
434    TypePath {
435        path: Path {
436            // represents the pattern Self::GAT_NAME...
437            leading_colon: None,
438            segments: vec![
439                PathSegment {
440                    ident: Ident::new("Self", Span::call_site()),
441                    arguments: PathArguments::None,
442                },
443                PathSegment {
444                    ident: gat_ident,
445                    arguments: PathArguments::AngleBracketed(lifetime_angle_bracketed_bounds(
446                        function_lifetimes,
447                    )),
448                },
449            ]
450            .into_iter()
451            .collect(),
452        },
453        qself: None,
454    }
455}
456fn handle_item_trait(mut item: ItemTrait) -> TokenStream {
457    let mut new_gat_items = Vec::new();
458
459    // Loop through every single async fn declared in the trait.
460    for method in item
461        .items
462        .iter_mut()
463        .filter_map(|item| {
464            if let TraitItem::Method(func) = item {
465                Some(func)
466            } else {
467                None
468            }
469        })
470        .filter(|method| method.sig.asyncness.is_some())
471    {
472        // For each async fn, remove the async part, replace the return value with a generic
473        // associated type, and add that generic associated type to the trait item.
474
475        // Check that all types have a lifetime that is either specific to the trait item, or
476        // to the current function (or 'static). Any other lifetime will and must produce a
477        // compiler error.
478        let gat_ident = gat_ident_for_sig(&method.sig);
479
480        let method_return_ty = return_type(method.sig.output.clone());
481
482        validate_that_function_always_has_lifetimes(&method.sig);
483
484        method.sig.asyncness = None;
485
486        let (toplevel_lifetimes, function_lifetimes) =
487            already_defined_lifetimes(&item.generics, &method.sig.generics);
488
489        new_gat_items.push(TraitItemType {
490            attrs: Vec::new(),
491            type_token: Token!(type)(Span::call_site()),
492            bounds: iter::once(TypeParamBound::Trait(future_trait_bound(method_return_ty)))
493                .chain(
494                    toplevel_lifetimes
495                        .into_iter()
496                        .map(|lifetime_def| lifetime_def.lifetime)
497                        .map(TypeParamBound::Lifetime),
498                )
499                .collect(),
500            colon_token: Some(Token!(:)(Span::call_site())),
501            default: None,
502            generics: Generics {
503                lt_token: Some(Token!(<)(Span::call_site())),
504                gt_token: Some(Token!(>)(Span::call_site())),
505                where_clause: None,
506                params: function_lifetimes
507                    .iter()
508                    .cloned()
509                    .map(GenericParam::Lifetime)
510                    .collect(),
511            },
512            ident: gat_ident.clone(),
513            semi_token: Token!(;)(Span::call_site()),
514        });
515
516        let self_gat_type = self_gat_type(
517            gat_ident,
518            function_lifetimes
519                .into_iter()
520                .map(|lifetime_def| lifetime_def.lifetime),
521        );
522
523        method.sig.output = ReturnType::Type(
524            Token!(->)(Span::call_site()),
525            Box::new(self_gat_type.into()),
526        );
527    }
528    item.items
529        .extend(new_gat_items.into_iter().map(TraitItem::Type));
530
531    quote! {
532        #item
533    }
534}
535fn real_async_trait2(_args_stream: TokenStream, token_stream: TokenStream) -> TokenStream {
536    // The #[real_async_trait] attribute macro, is applicable to both trait blocks, and to impl
537    // blocks that operate on that trait.
538
539    if let Ok(item_trait) = syn::parse2::<ItemTrait>(token_stream.clone()) {
540        handle_item_trait(item_trait)
541    } else if let Ok(item_impl) = syn::parse2::<ItemImpl>(token_stream) {
542        handle_item_impl(item_impl)
543    } else {
544        panic!("expected either a trait or an impl item")
545    }
546    .into()
547}
548
549/// A proc macro that supports using async fn in traits and trait impls. Refer to the top-level
550/// crate documentation for more information.
551#[proc_macro_attribute]
552pub fn real_async_trait(
553    args_stream: proc_macro::TokenStream,
554    token_stream: proc_macro::TokenStream,
555) -> proc_macro::TokenStream {
556    real_async_trait2(args_stream.into(), token_stream.into()).into()
557}