polling_async_trait/lib.rs
1/*!
2`polling-async-trait` is a library that creates async methods associated with
3polling methods on your traits. It is similar to [`async-trait`], but where
4`async-trait` works on `async` methods, `polling-async-trait` works on `poll_`
5methods.
6
7# Usage
8
9The entry point to this library is the [`async_poll_trait`][macro@async_poll_trait]
10attribute. When applied to a trait, it scans the trait for each method tagged
11with `async_method`. It treats each of these methods as an async polling
12method, and for each one, it adds an equivalent async method to the trait.
13
14```
15# use std::task::{Context, Poll};
16# use std::pin::Pin;
17use polling_async_trait::async_poll_trait;
18use std::io;
19
20#[async_poll_trait]
21trait ExampleTrait {
22 // This will create an async method called `basic` on this trait
23 #[async_method]
24 fn poll_basic(&mut self, cx: &mut Context<'_>) -> Poll<i32>;
25
26 // polling methods can also accept &self or Pin<&mut Self>
27 #[async_method]
28 fn poll_ref_method(&self, cx: &mut Context<'_>) -> Poll<i32>;
29
30 #[async_method]
31 fn poll_pin_method(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i32>;
32
33 // If `owned` is given, the generated async method will take `self` by move.
34 // This means that the returned future will take ownership of this instance.
35 // Owning futures can still be used with any of `&self`, `&mut self`, or
36 // `Pin<&mut Self>`
37 #[async_method(owned)]
38 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
39
40 #[async_method(owned)]
41 fn poll_close_ref(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
42
43 #[async_method(owned)]
44 fn poll_close_pinned(self: Pin<&mut Self>, cx: &mut Context<'_>)
45 -> Poll<io::Result<()>>;
46
47 // you can use method_name and future_name to control the names of the
48 // generated async method and associated future. This will generate an
49 // async method called do_work, and an associated `Future` called `DoWork`
50 #[async_method(method_name = "do_work", future_name = "DoWork")]
51 fn poll_work(&mut self, cx: &mut Context<'_>) -> Poll<()>;
52}
53
54#[derive(Default)]
55struct ExampleStruct {
56 closed: bool,
57}
58
59impl ExampleTrait for ExampleStruct {
60 fn poll_basic(&mut self, cx: &mut Context<'_>) -> Poll<i32> {
61 Poll::Ready(10)
62 }
63
64 fn poll_ref_method(&self, cx: &mut Context<'_>) -> Poll<i32> {
65 Poll::Ready(20)
66 }
67
68 fn poll_pin_method(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i32> {
69 Poll::Ready(30)
70 }
71
72 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73 if !self.closed {
74 println!("closing...");
75 self.closed = true;
76 cx.waker().wake_by_ref();
77 Poll::Pending
78 } else {
79 println!("closed!");
80 Poll::Ready(Ok(()))
81 }
82 }
83
84 fn poll_close_ref(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85 if !self.closed {
86 println!("Error, couldn't close...");
87 Poll::Ready(Err(io::ErrorKind::Other.into()))
88 } else {
89 println!("closed!");
90 Poll::Ready(Ok(()))
91 }
92 }
93
94 fn poll_close_pinned(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
95 let this = self.get_mut();
96 if !this.closed {
97 println!("closing...");
98 this.closed = true;
99 cx.waker().wake_by_ref();
100 Poll::Pending
101 } else {
102 println!("closed!");
103 Poll::Ready(Ok(()))
104 }
105 }
106
107 fn poll_work(&mut self, cx: &mut Context<'_>) -> Poll<()> {
108 Poll::Ready(())
109 }
110}
111
112#[tokio::main]
113async fn main() -> io::Result<()> {
114 let mut data1 = ExampleStruct::default();
115
116 assert_eq!(data1.basic().await, 10);
117 assert_eq!(data1.ref_method().await, 20);
118 data1.do_work().await;
119 data1.close().await?;
120
121 let data2 = ExampleStruct::default();
122 assert!(data2.close_ref().await.is_err());
123
124 let mut data3 = Box::pin(ExampleStruct::default());
125 assert_eq!(data3.as_mut().pin_method().await, 30);
126
127 let data4 = ExampleStruct::default();
128
129 // Soundness: we can await this method directly because it takes
130 // ownership of `data4`.
131 data4.close_pinned().await?;
132
133 Ok(())
134}
135```
136
137The generated future types will share visibility with the trait (that is, they
138will be `pub` if the trait is `pub`, `pub(crate)` if the trait is `pub(crate)`,
139etc).
140
141# Tradeoffs with [`async-trait`]
142
143Consider carefully which library is best for your use case; polling methods are
144often much more difficult to write (because they require manual state management
145& dealing with `Pin`). If your control flow is complex, it's probably
146preferable to use an `async fn` and [`async-trait`]. The advantage of
147`polling-async-trait` is that the async methods it creates are 0-overhead,
148because the returned futures call the poll methods directly. This means there's
149no need to use a type-erased `Box<dyn Future ... >`.
150
151[`async-trait`]: https://docs.rs/async-trait
152*/
153
154extern crate proc_macro;
155use inflector::Inflector;
156use proc_macro::TokenStream as RawTokenStream;
157use proc_macro2::{Ident, Span};
158use quote::{format_ident, quote, ToTokens};
159use syn::{
160 parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Attribute, Lifetime, Meta,
161 MetaList, MetaNameValue, NestedMeta, PatType, Path, ReturnType, Signature, TraitItem,
162 TraitItemMethod, Type, TypePath,
163};
164
/// How the generated async method receives the trait object.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AsyncMethodType {
    /// Default: the async method takes the same kind of receiver as the
    /// polling method, and the future holds that borrow.
    Ref,
    /// Selected by `#[async_method(owned)]`: the async method takes `self`
    /// by value and the future stores the instance itself.
    Owned,
}
170
171#[derive(Debug, Clone)]
172struct MethodMeta {
173 ty: AsyncMethodType,
174 future_name: Option<String>,
175 async_method_name: Option<String>,
176}
177
/// The receiver form declared by a polling method.
#[derive(Clone, Copy, Debug)]
enum PollMethodReceiverType {
    /// `&self`
    Ref,
    /// `&mut self`
    MutRef,
    /// `self: Pin<&mut Self>`
    Pinned,
}
184
185/// Given a return type matching `task::Poll<Type>`, extract `Type` (or return
186/// an error)
187fn extract_output_type(ret: &ReturnType) -> Result<&Type, RawTokenStream> {
188 match *ret {
189 syn::ReturnType::Type(_, ref ty) => match **ty {
190 syn::Type::Path(ref path) => {
191 let tail_segment = path.path.segments.last().unwrap();
192
193 if tail_segment.ident.to_string() != "Poll" {
194 return Err(syn::Error::new(
195 ret.span(),
196 "polling method must return a Poll value",
197 )
198 .to_compile_error()
199 .into());
200 }
201
202 let args = &tail_segment.arguments;
203
204 match *args {
205 syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
206 args: ref generics,
207 ..
208 }) if generics.len() != 1 => Err(syn::Error::new(
209 args.span(),
210 "Poll return type should have exactly 1 generic parameter",
211 )
212 .to_compile_error()
213 .into()),
214
215 syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
216 args: ref generics,
217 ..
218 }) => match *generics.first().unwrap() {
219 syn::GenericArgument::Type(ref ty) => Ok(ty),
220 _ => Err(syn::Error::new(
221 args.span(),
222 "Error parsing generics of Poll type",
223 )
224 .to_compile_error()
225 .into()),
226 },
227
228 _ => Err(syn::Error::new(
229 ret.span(),
230 "Poll return type must include the <Output> type",
231 )
232 .to_compile_error()
233 .into()),
234 }
235 }
236 _ => Err(
237 syn::Error::new(ret.span(), "polling method must return a Poll value")
238 .to_compile_error()
239 .into(),
240 ),
241 },
242 _ => Err(
243 syn::Error::new(ret.span(), "polling method must return a Poll value")
244 .to_compile_error()
245 .into(),
246 ),
247 }
248}
249
250/// Given a function signature, determine the receiver type. Accepts &self,
251/// &mut self, and self: Pin<&mut Self>.
252fn extract_poll_self_type(sig: &Signature) -> Option<PollMethodReceiverType> {
253 match *sig.inputs.first()? {
254 syn::FnArg::Receiver(ref recv) => {
255 if recv.reference.is_none() {
256 None
257 } else if recv.mutability.is_some() {
258 Some(PollMethodReceiverType::MutRef)
259 } else {
260 Some(PollMethodReceiverType::Ref)
261 }
262 }
263 syn::FnArg::Typed(PatType {
264 ref pat, ref ty, ..
265 }) => {
266 // Check that pattern is `self`
267 let pat_ident = match **pat {
268 syn::Pat::Ident(ref pat_ident) => pat_ident,
269 _ => return None,
270 };
271
272 if pat_ident.by_ref.is_some() || pat_ident.subpat.is_some() {
273 return None;
274 }
275
276 if pat_ident.ident != "self" {
277 return None;
278 }
279
280 // Check that the type is Pin<&mut Self>
281 let ty = match **ty {
282 Type::Path(TypePath {
283 qself: None,
284 path: Path { ref segments, .. },
285 }) => segments.last()?,
286 _ => return None,
287 };
288
289 if ty.ident != "Pin" {
290 return None;
291 }
292
293 let generics = match ty.arguments {
294 syn::PathArguments::AngleBracketed(ref generics) => &generics.args,
295 _ => return None,
296 };
297
298 if generics.len() != 1 {
299 return None;
300 }
301
302 let ty = match generics.first()? {
303 syn::GenericArgument::Type(Type::Reference(ty)) => ty,
304 _ => return None,
305 };
306
307 if ty.mutability.is_none() {
308 return None;
309 }
310
311 let self_ident = match *ty.elem {
312 Type::Path(TypePath {
313 qself: None,
314 ref path,
315 }) => path.get_ident()?,
316 _ => return None,
317 };
318
319 if self_ident != "Self" {
320 return None;
321 }
322
323 Some(PollMethodReceiverType::Pinned)
324 }
325 }
326}
327
328/// Given a list of attributes on a method, if it has an async_method, parse
329/// and remove it
330fn extract_meta<'a>(attrs: &'a mut Vec<Attribute>) -> Option<Result<MethodMeta, RawTokenStream>> {
331 for (index, attr) in attrs.iter_mut().enumerate() {
332 let meta = match attr.parse_meta() {
333 Ok(meta) => meta,
334 Err(..) => continue,
335 };
336
337 let (path, nested) = match meta {
338 syn::Meta::Path(path) => (path, None),
339 syn::Meta::List(MetaList { path, nested, .. }) => (path, Some(nested)),
340 _ => continue,
341 };
342
343 match path.get_ident() {
344 Some(ident) if ident == "async_method" => {}
345 _ => continue,
346 }
347
348 // At this point, we know we have an async_method. Anything wrong past this
349 // point should result in an error.
350
351 attrs.remove(index);
352
353 let mut result = MethodMeta {
354 ty: AsyncMethodType::Ref,
355 async_method_name: None,
356 future_name: None,
357 };
358
359 if let Some(meta_args) = nested {
360 for arg in meta_args.iter() {
361 match arg {
362 NestedMeta::Meta(Meta::NameValue(MetaNameValue {
363 path,
364 lit: syn::Lit::Str(name),
365 ..
366 })) => {
367 let ident = match path.get_ident() {
368 Some(ident) => ident,
369 None => {
370 return Some(Err(syn::Error::new(
371 path.span(),
372 "Unrecognized meta argument",
373 )
374 .to_compile_error()
375 .into()))
376 }
377 };
378
379 if ident == "method_name" {
380 result.async_method_name = Some(name.value())
381 } else if ident == "future_name" {
382 result.future_name = Some(name.value())
383 } else {
384 return Some(Err(syn::Error::new(
385 path.span(),
386 "Unrecognized meta argument",
387 )
388 .to_compile_error()
389 .into()));
390 }
391 }
392 NestedMeta::Meta(Meta::Path(path)) => {
393 let ident = match path.get_ident() {
394 Some(ident) => ident,
395 None => {
396 return Some(Err(syn::Error::new(
397 path.span(),
398 "Unrecognized meta argument",
399 )
400 .to_compile_error()
401 .into()))
402 }
403 };
404
405 if ident == "owned" {
406 result.ty = AsyncMethodType::Owned;
407 } else {
408 return Some(Err(syn::Error::new(
409 path.span(),
410 "Unrecognized meta argument",
411 )
412 .to_compile_error()
413 .into()));
414 }
415 }
416 _ => {
417 return Some(Err(syn::Error::new(
418 arg.span(),
419 "Unrecognized meta argument",
420 )
421 .to_compile_error()
422 .into()))
423 }
424 }
425 }
426 }
427
428 return Some(Ok(result));
429 }
430
431 None
432}
433
434#[proc_macro_attribute]
435pub fn async_poll_trait(_attr: RawTokenStream, item: RawTokenStream) -> RawTokenStream {
436 let mut parsed = parse_macro_input!(item as syn::ItemTrait);
437
438 let trait_ident = &parsed.ident;
439 let trait_name = trait_ident.to_string();
440 let vis = &parsed.vis;
441
442 let mut new_methods = Vec::new();
443 let mut new_structs = Vec::new();
444
445 for item in &mut parsed.items {
446 // Is this a method?
447 let method = match item {
448 TraitItem::Method(method) => method,
449 _ => continue,
450 };
451
452 // Check if this method should be async'd
453 let meta = match extract_meta(&mut method.attrs) {
454 None => continue,
455 Some(Err(err)) => return err,
456 Some(Ok(meta)) => meta,
457 };
458
459 // We have a meta, so we know this method has been designated to
460 // by processed by this library. Anything that fails at this point
461 // is an error.
462
463 // Get the return type our future will use
464 let output_type = match extract_output_type(&method.sig.output) {
465 Ok(ty) => ty,
466 Err(err) => return err,
467 };
468
469 // Check what kind of receiver this method uses (&self, &mut self, self: Pin<&mut Self>)
470 let receiver_type =
471 match extract_poll_self_type(&method.sig) {
472 Some(receiver_type) => receiver_type,
473 None => return syn::Error::new(
474 method.sig.span(),
475 "poll function must be a method that takes &self, &mut self, or Pin<&mut Self>",
476 )
477 .to_compile_error()
478 .into(),
479 };
480
481 let poll_method_ident = &method.sig.ident;
482 let poll_method_name = poll_method_ident.to_string();
483
484 // poll_base_name => base_name
485 let base_name = poll_method_name.strip_prefix("poll_");
486
487 let async_method_name = match meta.async_method_name.as_deref().or(base_name) {
488 Some(name) => name,
489 None => {
490 return syn::Error::new(
491 poll_method_ident.span(),
492 "poll method must start with poll_",
493 )
494 .to_compile_error()
495 .into()
496 }
497 };
498 let async_method_ident = Ident::new(
499 async_method_name,
500 Span::call_site().resolved_at(poll_method_ident.span()),
501 );
502
503 let future_name = match meta
504 .future_name
505 .or_else(|| base_name.map(|name| format!("{}{}", trait_name, name.to_class_case())))
506 {
507 Some(name) => name,
508 None => {
509 return syn::Error::new(
510 poll_method_ident.span(),
511 "poll method must start with poll_",
512 )
513 .to_compile_error()
514 .into()
515 }
516 };
517
518 let future_ident = Ident::new(
519 future_name.as_str(),
520 Span::call_site().resolved_at(trait_ident.span()),
521 );
522
523 // That's everything we need; now it's just a matter of constructing
524 // the new methods and new future structs and inserting them in the
525 // right places.
526
527 // These will come in handy later. They allow us to stitch together
528 // several quotes!() and make sure the identifier hygiene lines up.
529 let self_ident = format_ident!("self");
530 let cx_ident = format_ident!("cx");
531 let inner_ident = format_ident!("inner");
532 let generic_ident = format_ident!("T");
533 let generic_lt = Lifetime::new("'a", Span::call_site());
534
535 let (async_def, future_def) = match meta.ty {
536 AsyncMethodType::Owned => {
537 let async_method_definition = quote! {
538 fn #async_method_ident(self) -> #future_ident<Self>
539 where Self: Sized
540 {
541 #future_ident { #inner_ident: self }
542 }
543 };
544
545 // Safety of this definition:
546 // - if receiver type is ref or mut ref, we can ignore the
547 // pin entirely (project to unpin)
548 // - if receiver type is pin, we know that self is pinned, so
549 // it's safe to project to an inner pin
550 // We could do the same thing with pin_project, and avoid
551 // unsafe, but we'd rather avoid the dependency for something
552 // so simple
553
554 let future_poll_definition = match receiver_type {
555 PollMethodReceiverType::MutRef => quote! {
556 unsafe { #self_ident.get_unchecked_mut() }.#inner_ident.#poll_method_ident(#cx_ident)
557 },
558 PollMethodReceiverType::Ref => quote! {
559 #self_ident.into_ref().get_ref().#inner_ident.#poll_method_ident(#cx_ident)
560 },
561 PollMethodReceiverType::Pinned => quote! {
562 unsafe { Pin::new_unchecked(&mut #self_ident.get_unchecked_mut().#inner_ident) }.#poll_method_ident(#cx_ident)
563 },
564 };
565
566 let future_definition = quote! {
567 #[derive(Debug)]
568 #vis struct #future_ident<T: #trait_ident> {
569 #inner_ident: T,
570 }
571
572 impl<T: #trait_ident> ::core::future::Future for #future_ident<T> {
573 type Output = #output_type;
574
575 fn poll(
576 #self_ident: ::core::pin::Pin<&mut Self>,
577 #cx_ident: &mut ::core::task::Context<'_>,
578 ) -> ::core::task::Poll<Self::Output>
579 {
580 #future_poll_definition
581 }
582 }
583 };
584
585 (async_method_definition, future_definition)
586 }
587 AsyncMethodType::Ref => {
588 let async_method_receiver = match receiver_type {
589 PollMethodReceiverType::Ref => quote! { &#self_ident },
590 PollMethodReceiverType::MutRef => quote! { &mut #self_ident },
591 PollMethodReceiverType::Pinned => {
592 quote! { #self_ident: ::core::pin::Pin<&mut Self> }
593 }
594 };
595
596 let async_method_definition = quote! {
597 fn #async_method_ident(#async_method_receiver) -> #future_ident<Self> {
598 #future_ident { #inner_ident: #self_ident }
599 }
600 };
601
602 let future_inner_type = match receiver_type {
603 PollMethodReceiverType::Ref => quote! {& #generic_lt #generic_ident },
604 PollMethodReceiverType::MutRef => quote! { & #generic_lt mut #generic_ident },
605 PollMethodReceiverType::Pinned => {
606 quote! { Pin<& #generic_lt mut #generic_ident> }
607 }
608 };
609
610 let future_poll_definition = match receiver_type {
611 PollMethodReceiverType::Ref | PollMethodReceiverType::MutRef => quote! {
612 #self_ident.get_mut().#inner_ident.#poll_method_ident(#cx_ident)
613 },
614 PollMethodReceiverType::Pinned => quote! {
615 #self_ident.get_mut().#inner_ident.as_mut().#poll_method_ident(#cx_ident)
616 },
617 };
618
619 let future_definition = quote! {
620 #[derive(Debug)]
621 #vis struct #future_ident<#generic_lt, #generic_ident: #trait_ident + ?Sized> {
622 #inner_ident: #future_inner_type,
623 }
624
625 impl<'a, T: #trait_ident + ?Sized> ::core::future::Future for #future_ident<'a, T> {
626 type Output = #output_type;
627
628 fn poll(
629 #self_ident: ::core::pin::Pin<&mut Self>,
630 #cx_ident: &mut ::core::task::Context<'_>,
631 ) -> ::core::task::Poll<Self::Output>
632 {
633 #future_poll_definition
634 }
635 }
636 };
637
638 (async_method_definition, future_definition)
639 }
640 };
641
642 let async_def = async_def.into();
643 let async_def = parse_macro_input!(async_def as TraitItemMethod);
644
645 new_methods.push(async_def);
646 new_structs.push(future_def);
647 }
648
649 // Add the new methods to the trait
650 parsed
651 .items
652 .extend(new_methods.into_iter().map(TraitItem::Method));
653
654 let mut output = parsed.into_token_stream();
655
656 // Add the new future definitions to the output
657 output.extend(new_structs);
658
659 output.into()
660}