polling_async_trait/
lib.rs

1/*!
2`polling-async-trait` is a library that creates async methods associated with
3polling methods on your traits. It is similar to [`async-trait`], but where
4`async-trait` works on `async` methods, `polling-async-trait` works on `poll_`
5methods.
6
7# Usage
8
9The entry point to this library is the [`async_poll_trait`][macro@async_poll_trait]
10attribute. When applied to a trait, it scans the trait for each method tagged
11with `async_method`. It treats each of these methods as an async polling
12method, and for each one, it adds an equivalent async method to the trait.
13
14```
15# use std::task::{Context, Poll};
16# use std::pin::Pin;
17use polling_async_trait::async_poll_trait;
18use std::io;
19
20#[async_poll_trait]
21trait ExampleTrait {
22    // This will create an async method called `basic` on this trait
23    #[async_method]
24    fn poll_basic(&mut self, cx: &mut Context<'_>) -> Poll<i32>;
25
26    // polling methods can also accept &self or Pin<&mut Self>
27    #[async_method]
28    fn poll_ref_method(&self, cx: &mut Context<'_>) -> Poll<i32>;
29
30    #[async_method]
31    fn poll_pin_method(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i32>;
32
33    // If `owned` is given, the generated async method will take `self` by move.
34    // This means that the returned future will take ownership of this instance.
35    // Owning futures can still be used with any of `&self`, `&mut self`, or
36    // `Pin<&mut Self>`
37    #[async_method(owned)]
38    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
39
40    #[async_method(owned)]
41    fn poll_close_ref(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
42
43    #[async_method(owned)]
44    fn poll_close_pinned(self: Pin<&mut Self>, cx: &mut Context<'_>)
45        -> Poll<io::Result<()>>;
46
    // You can use `method_name` and `future_name` to control the names of the
    // generated async method and associated future. This will generate an
    // async method called `do_work`, and an associated `Future` called `DoWork`.
50    #[async_method(method_name = "do_work", future_name = "DoWork")]
51    fn poll_work(&mut self, cx: &mut Context<'_>) -> Poll<()>;
52}
53
54#[derive(Default)]
55struct ExampleStruct {
56    closed: bool,
57}
58
59impl ExampleTrait for ExampleStruct {
60    fn poll_basic(&mut self, cx: &mut Context<'_>) -> Poll<i32> {
61        Poll::Ready(10)
62    }
63
64    fn poll_ref_method(&self, cx: &mut Context<'_>) -> Poll<i32> {
65        Poll::Ready(20)
66    }
67
68    fn poll_pin_method(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i32> {
69        Poll::Ready(30)
70    }
71
72    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73        if !self.closed {
74            println!("closing...");
75            self.closed = true;
76            cx.waker().wake_by_ref();
77            Poll::Pending
78        } else {
79            println!("closed!");
80            Poll::Ready(Ok(()))
81        }
82    }
83
84    fn poll_close_ref(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85        if !self.closed {
86            println!("Error, couldn't close...");
87            Poll::Ready(Err(io::ErrorKind::Other.into()))
88        } else {
89            println!("closed!");
90            Poll::Ready(Ok(()))
91        }
92    }
93
94    fn poll_close_pinned(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
95        let this = self.get_mut();
96        if !this.closed {
97            println!("closing...");
98            this.closed = true;
99            cx.waker().wake_by_ref();
100            Poll::Pending
101        } else {
102            println!("closed!");
103            Poll::Ready(Ok(()))
104        }
105    }
106
107    fn poll_work(&mut self, cx: &mut Context<'_>) -> Poll<()> {
108        Poll::Ready(())
109    }
110}
111
112#[tokio::main]
113async fn main() -> io::Result<()> {
114    let mut data1 = ExampleStruct::default();
115
116    assert_eq!(data1.basic().await, 10);
117    assert_eq!(data1.ref_method().await, 20);
118    data1.do_work().await;
119    data1.close().await?;
120
121    let data2 = ExampleStruct::default();
122    assert!(data2.close_ref().await.is_err());
123
124    let mut data3 = Box::pin(ExampleStruct::default());
125    assert_eq!(data3.as_mut().pin_method().await, 30);
126
127    let data4 = ExampleStruct::default();
128
    // Soundness: we can await this method directly because it takes
130    // ownership of `data4`.
131    data4.close_pinned().await?;
132
133    Ok(())
134}
135```
136
137The generated future types will share visibility with the trait (that is, they
138will be `pub` if the trait is `pub`, `pub(crate)` if the trait is `pub(crate)`,
139etc).
140
141# Tradeoffs with [`async-trait`]
142
143Consider carefully which library is best for your use case; polling methods are
144often much more difficult to write (because they require manual state management
145& dealing with `Pin`). If your control flow is complex, it's probably
146preferable to use an `async fn` and [`async-trait`]. The advantage of
147`polling-async-trait` is that the async methods it creates are 0-overhead,
148because the returned futures call the poll methods directly. This means there's
149no need to use a type-erased `Box<dyn Future ... >`.
150
151[`async-trait`]: https://docs.rs/async-trait
152*/
153
154extern crate proc_macro;
155use inflector::Inflector;
156use proc_macro::TokenStream as RawTokenStream;
157use proc_macro2::{Ident, Span};
158use quote::{format_ident, quote, ToTokens};
159use syn::{
160    parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Attribute, Lifetime, Meta,
161    MetaList, MetaNameValue, NestedMeta, PatType, Path, ReturnType, Signature, TraitItem,
162    TraitItemMethod, Type, TypePath,
163};
164
165#[derive(Debug, Copy, Clone, PartialEq, Eq)]
166enum AsyncMethodType {
167    Ref,
168    Owned,
169}
170
171#[derive(Debug, Clone)]
172struct MethodMeta {
173    ty: AsyncMethodType,
174    future_name: Option<String>,
175    async_method_name: Option<String>,
176}
177
178#[derive(Debug, Copy, Clone)]
179enum PollMethodReceiverType {
180    Ref,
181    MutRef,
182    Pinned,
183}
184
185/// Given a return type matching `task::Poll<Type>`, extract `Type` (or return
186/// an error)
187fn extract_output_type(ret: &ReturnType) -> Result<&Type, RawTokenStream> {
188    match *ret {
189        syn::ReturnType::Type(_, ref ty) => match **ty {
190            syn::Type::Path(ref path) => {
191                let tail_segment = path.path.segments.last().unwrap();
192
193                if tail_segment.ident.to_string() != "Poll" {
194                    return Err(syn::Error::new(
195                        ret.span(),
196                        "polling method must return a Poll value",
197                    )
198                    .to_compile_error()
199                    .into());
200                }
201
202                let args = &tail_segment.arguments;
203
204                match *args {
205                    syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
206                        args: ref generics,
207                        ..
208                    }) if generics.len() != 1 => Err(syn::Error::new(
209                        args.span(),
210                        "Poll return type should have exactly 1 generic parameter",
211                    )
212                    .to_compile_error()
213                    .into()),
214
215                    syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
216                        args: ref generics,
217                        ..
218                    }) => match *generics.first().unwrap() {
219                        syn::GenericArgument::Type(ref ty) => Ok(ty),
220                        _ => Err(syn::Error::new(
221                            args.span(),
222                            "Error parsing generics of Poll type",
223                        )
224                        .to_compile_error()
225                        .into()),
226                    },
227
228                    _ => Err(syn::Error::new(
229                        ret.span(),
230                        "Poll return type must include the <Output> type",
231                    )
232                    .to_compile_error()
233                    .into()),
234                }
235            }
236            _ => Err(
237                syn::Error::new(ret.span(), "polling method must return a Poll value")
238                    .to_compile_error()
239                    .into(),
240            ),
241        },
242        _ => Err(
243            syn::Error::new(ret.span(), "polling method must return a Poll value")
244                .to_compile_error()
245                .into(),
246        ),
247    }
248}
249
250/// Given a function signature, determine the receiver type. Accepts &self,
251/// &mut self, and self: Pin<&mut Self>.
252fn extract_poll_self_type(sig: &Signature) -> Option<PollMethodReceiverType> {
253    match *sig.inputs.first()? {
254        syn::FnArg::Receiver(ref recv) => {
255            if recv.reference.is_none() {
256                None
257            } else if recv.mutability.is_some() {
258                Some(PollMethodReceiverType::MutRef)
259            } else {
260                Some(PollMethodReceiverType::Ref)
261            }
262        }
263        syn::FnArg::Typed(PatType {
264            ref pat, ref ty, ..
265        }) => {
266            // Check that pattern is `self`
267            let pat_ident = match **pat {
268                syn::Pat::Ident(ref pat_ident) => pat_ident,
269                _ => return None,
270            };
271
272            if pat_ident.by_ref.is_some() || pat_ident.subpat.is_some() {
273                return None;
274            }
275
276            if pat_ident.ident != "self" {
277                return None;
278            }
279
280            // Check that the type is Pin<&mut Self>
281            let ty = match **ty {
282                Type::Path(TypePath {
283                    qself: None,
284                    path: Path { ref segments, .. },
285                }) => segments.last()?,
286                _ => return None,
287            };
288
289            if ty.ident != "Pin" {
290                return None;
291            }
292
293            let generics = match ty.arguments {
294                syn::PathArguments::AngleBracketed(ref generics) => &generics.args,
295                _ => return None,
296            };
297
298            if generics.len() != 1 {
299                return None;
300            }
301
302            let ty = match generics.first()? {
303                syn::GenericArgument::Type(Type::Reference(ty)) => ty,
304                _ => return None,
305            };
306
307            if ty.mutability.is_none() {
308                return None;
309            }
310
311            let self_ident = match *ty.elem {
312                Type::Path(TypePath {
313                    qself: None,
314                    ref path,
315                }) => path.get_ident()?,
316                _ => return None,
317            };
318
319            if self_ident != "Self" {
320                return None;
321            }
322
323            Some(PollMethodReceiverType::Pinned)
324        }
325    }
326}
327
328/// Given a list of attributes on a method, if it has an async_method, parse
329/// and remove it
330fn extract_meta<'a>(attrs: &'a mut Vec<Attribute>) -> Option<Result<MethodMeta, RawTokenStream>> {
331    for (index, attr) in attrs.iter_mut().enumerate() {
332        let meta = match attr.parse_meta() {
333            Ok(meta) => meta,
334            Err(..) => continue,
335        };
336
337        let (path, nested) = match meta {
338            syn::Meta::Path(path) => (path, None),
339            syn::Meta::List(MetaList { path, nested, .. }) => (path, Some(nested)),
340            _ => continue,
341        };
342
343        match path.get_ident() {
344            Some(ident) if ident == "async_method" => {}
345            _ => continue,
346        }
347
348        // At this point, we know we have an async_method. Anything wrong past this
349        // point should result in an error.
350
351        attrs.remove(index);
352
353        let mut result = MethodMeta {
354            ty: AsyncMethodType::Ref,
355            async_method_name: None,
356            future_name: None,
357        };
358
359        if let Some(meta_args) = nested {
360            for arg in meta_args.iter() {
361                match arg {
362                    NestedMeta::Meta(Meta::NameValue(MetaNameValue {
363                        path,
364                        lit: syn::Lit::Str(name),
365                        ..
366                    })) => {
367                        let ident = match path.get_ident() {
368                            Some(ident) => ident,
369                            None => {
370                                return Some(Err(syn::Error::new(
371                                    path.span(),
372                                    "Unrecognized meta argument",
373                                )
374                                .to_compile_error()
375                                .into()))
376                            }
377                        };
378
379                        if ident == "method_name" {
380                            result.async_method_name = Some(name.value())
381                        } else if ident == "future_name" {
382                            result.future_name = Some(name.value())
383                        } else {
384                            return Some(Err(syn::Error::new(
385                                path.span(),
386                                "Unrecognized meta argument",
387                            )
388                            .to_compile_error()
389                            .into()));
390                        }
391                    }
392                    NestedMeta::Meta(Meta::Path(path)) => {
393                        let ident = match path.get_ident() {
394                            Some(ident) => ident,
395                            None => {
396                                return Some(Err(syn::Error::new(
397                                    path.span(),
398                                    "Unrecognized meta argument",
399                                )
400                                .to_compile_error()
401                                .into()))
402                            }
403                        };
404
405                        if ident == "owned" {
406                            result.ty = AsyncMethodType::Owned;
407                        } else {
408                            return Some(Err(syn::Error::new(
409                                path.span(),
410                                "Unrecognized meta argument",
411                            )
412                            .to_compile_error()
413                            .into()));
414                        }
415                    }
416                    _ => {
417                        return Some(Err(syn::Error::new(
418                            arg.span(),
419                            "Unrecognized meta argument",
420                        )
421                        .to_compile_error()
422                        .into()))
423                    }
424                }
425            }
426        }
427
428        return Some(Ok(result));
429    }
430
431    None
432}
433
434#[proc_macro_attribute]
435pub fn async_poll_trait(_attr: RawTokenStream, item: RawTokenStream) -> RawTokenStream {
436    let mut parsed = parse_macro_input!(item as syn::ItemTrait);
437
438    let trait_ident = &parsed.ident;
439    let trait_name = trait_ident.to_string();
440    let vis = &parsed.vis;
441
442    let mut new_methods = Vec::new();
443    let mut new_structs = Vec::new();
444
445    for item in &mut parsed.items {
446        // Is this a method?
447        let method = match item {
448            TraitItem::Method(method) => method,
449            _ => continue,
450        };
451
452        // Check if this method should be async'd
453        let meta = match extract_meta(&mut method.attrs) {
454            None => continue,
455            Some(Err(err)) => return err,
456            Some(Ok(meta)) => meta,
457        };
458
459        // We have a meta, so we know this method has been designated to
460        // by processed by this library. Anything that fails at this point
461        // is an error.
462
463        // Get the return type our future will use
464        let output_type = match extract_output_type(&method.sig.output) {
465            Ok(ty) => ty,
466            Err(err) => return err,
467        };
468
469        // Check what kind of receiver this method uses (&self, &mut self, self: Pin<&mut Self>)
470        let receiver_type =
471            match extract_poll_self_type(&method.sig) {
472                Some(receiver_type) => receiver_type,
473                None => return syn::Error::new(
474                    method.sig.span(),
475                    "poll function must be a method that takes &self, &mut self, or Pin<&mut Self>",
476                )
477                .to_compile_error()
478                .into(),
479            };
480
481        let poll_method_ident = &method.sig.ident;
482        let poll_method_name = poll_method_ident.to_string();
483
484        // poll_base_name => base_name
485        let base_name = poll_method_name.strip_prefix("poll_");
486
487        let async_method_name = match meta.async_method_name.as_deref().or(base_name) {
488            Some(name) => name,
489            None => {
490                return syn::Error::new(
491                    poll_method_ident.span(),
492                    "poll method must start with poll_",
493                )
494                .to_compile_error()
495                .into()
496            }
497        };
498        let async_method_ident = Ident::new(
499            async_method_name,
500            Span::call_site().resolved_at(poll_method_ident.span()),
501        );
502
503        let future_name = match meta
504            .future_name
505            .or_else(|| base_name.map(|name| format!("{}{}", trait_name, name.to_class_case())))
506        {
507            Some(name) => name,
508            None => {
509                return syn::Error::new(
510                    poll_method_ident.span(),
511                    "poll method must start with poll_",
512                )
513                .to_compile_error()
514                .into()
515            }
516        };
517
518        let future_ident = Ident::new(
519            future_name.as_str(),
520            Span::call_site().resolved_at(trait_ident.span()),
521        );
522
523        // That's everything we need; now it's just a matter of constructing
524        // the new methods and new future structs and inserting them in the
525        // right places.
526
527        // These will come in handy later. They allow us to stitch together
528        // several quotes!() and make sure the identifier hygiene lines up.
529        let self_ident = format_ident!("self");
530        let cx_ident = format_ident!("cx");
531        let inner_ident = format_ident!("inner");
532        let generic_ident = format_ident!("T");
533        let generic_lt = Lifetime::new("'a", Span::call_site());
534
535        let (async_def, future_def) = match meta.ty {
536            AsyncMethodType::Owned => {
537                let async_method_definition = quote! {
538                    fn #async_method_ident(self) -> #future_ident<Self>
539                        where Self: Sized
540                    {
541                        #future_ident { #inner_ident: self }
542                    }
543                };
544
545                // Safety of this definition:
546                // - if receiver type is ref or mut ref, we can ignore the
547                //   pin entirely (project to unpin)
548                // - if receiver type is pin, we know that self is pinned, so
549                //   it's safe to project to an inner pin
550                // We could do the same thing with pin_project, and avoid
551                // unsafe, but we'd rather avoid the dependency for something
552                // so simple
553
554                let future_poll_definition = match receiver_type {
555                    PollMethodReceiverType::MutRef => quote! {
556                        unsafe { #self_ident.get_unchecked_mut() }.#inner_ident.#poll_method_ident(#cx_ident)
557                    },
558                    PollMethodReceiverType::Ref => quote! {
559                        #self_ident.into_ref().get_ref().#inner_ident.#poll_method_ident(#cx_ident)
560                    },
561                    PollMethodReceiverType::Pinned => quote! {
562                        unsafe { Pin::new_unchecked(&mut #self_ident.get_unchecked_mut().#inner_ident) }.#poll_method_ident(#cx_ident)
563                    },
564                };
565
566                let future_definition = quote! {
567                    #[derive(Debug)]
568                    #vis struct #future_ident<T: #trait_ident> {
569                        #inner_ident: T,
570                    }
571
572                    impl<T: #trait_ident> ::core::future::Future for #future_ident<T> {
573                        type Output = #output_type;
574
575                        fn poll(
576                            #self_ident: ::core::pin::Pin<&mut Self>,
577                            #cx_ident: &mut ::core::task::Context<'_>,
578                        ) -> ::core::task::Poll<Self::Output>
579                        {
580                            #future_poll_definition
581                        }
582                    }
583                };
584
585                (async_method_definition, future_definition)
586            }
587            AsyncMethodType::Ref => {
588                let async_method_receiver = match receiver_type {
589                    PollMethodReceiverType::Ref => quote! { &#self_ident },
590                    PollMethodReceiverType::MutRef => quote! { &mut #self_ident },
591                    PollMethodReceiverType::Pinned => {
592                        quote! { #self_ident: ::core::pin::Pin<&mut Self> }
593                    }
594                };
595
596                let async_method_definition = quote! {
597                    fn #async_method_ident(#async_method_receiver) -> #future_ident<Self> {
598                        #future_ident { #inner_ident: #self_ident }
599                    }
600                };
601
602                let future_inner_type = match receiver_type {
603                    PollMethodReceiverType::Ref => quote! {& #generic_lt #generic_ident },
604                    PollMethodReceiverType::MutRef => quote! { & #generic_lt mut #generic_ident },
605                    PollMethodReceiverType::Pinned => {
606                        quote! { Pin<& #generic_lt mut #generic_ident> }
607                    }
608                };
609
610                let future_poll_definition = match receiver_type {
611                    PollMethodReceiverType::Ref | PollMethodReceiverType::MutRef => quote! {
612                        #self_ident.get_mut().#inner_ident.#poll_method_ident(#cx_ident)
613                    },
614                    PollMethodReceiverType::Pinned => quote! {
615                        #self_ident.get_mut().#inner_ident.as_mut().#poll_method_ident(#cx_ident)
616                    },
617                };
618
619                let future_definition = quote! {
620                    #[derive(Debug)]
621                    #vis struct #future_ident<#generic_lt, #generic_ident: #trait_ident + ?Sized> {
622                        #inner_ident: #future_inner_type,
623                    }
624
625                    impl<'a, T: #trait_ident + ?Sized> ::core::future::Future for #future_ident<'a, T> {
626                        type Output = #output_type;
627
628                        fn poll(
629                            #self_ident: ::core::pin::Pin<&mut Self>,
630                            #cx_ident: &mut ::core::task::Context<'_>,
631                        ) -> ::core::task::Poll<Self::Output>
632                        {
633                            #future_poll_definition
634                        }
635                    }
636                };
637
638                (async_method_definition, future_definition)
639            }
640        };
641
642        let async_def = async_def.into();
643        let async_def = parse_macro_input!(async_def as TraitItemMethod);
644
645        new_methods.push(async_def);
646        new_structs.push(future_def);
647    }
648
649    // Add the new methods to the trait
650    parsed
651        .items
652        .extend(new_methods.into_iter().map(TraitItem::Method));
653
654    let mut output = parsed.into_token_stream();
655
656    // Add the new future definitions to the output
657    output.extend(new_structs);
658
659    output.into()
660}