derive_async_local/
lib.rs

1use proc_macro2::Span;
2use quote::quote;
3use syn::{
4  Data, DeriveInput, GenericArgument, PathArguments, Type, TypePath, parse::Error,
5  parse_macro_input,
6};
7
8mod entry;
9
10fn is_context(type_path: &TypePath) -> bool {
11  let segments: Vec<_> = type_path
12    .path
13    .segments
14    .iter()
15    .map(|segment| segment.ident.to_string())
16    .collect();
17
18  matches!(
19    *segments
20      .iter()
21      .map(String::as_ref)
22      .collect::<Vec<&str>>()
23      .as_slice(),
24    ["async_local", "Context"] | ["Context"]
25  )
26}
27
28/// Derive [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context<T>`](https://docs.rs/async-local/latest/async_local/struct.Context.html)> and [`AsContext`](https://docs.rs/async-local/latest/async_local/trait.AsContext.html) for a struct
29#[proc_macro_derive(AsContext)]
30pub fn derive_as_context(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31  let input = parse_macro_input!(input as DeriveInput);
32  let ident = &input.ident;
33  let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
34
35  if let Some(err) = input
36    .generics
37    .lifetimes()
38    .map(|lifetime| Error::new_spanned(lifetime, "cannot derive AsContext with lifetimes"))
39    .reduce(|mut err, other| {
40      err.combine(other);
41      err
42    })
43  {
44    return err.into_compile_error().into();
45  }
46
47  let data_struct = if let Data::Struct(data_struct) = &input.data {
48    data_struct
49  } else {
50    return Error::new(Span::call_site(), "can only derive AsContext on structs")
51      .into_compile_error()
52      .into();
53  };
54
55  let path_fields: Vec<_> = data_struct
56    .fields
57    .iter()
58    .filter_map(|field| {
59      if let Type::Path(type_path) = &field.ty {
60        Some((field, type_path))
61      } else {
62        None
63      }
64    })
65    .collect();
66
67  let wrapped_context_error = path_fields
68    .iter()
69    .filter(|(_, type_path)| {
70      if let Some(segment) = type_path.path.segments.last() {
71        if let PathArguments::AngleBracketed(inner) = &segment.arguments {
72          if let Some(GenericArgument::Type(Type::Path(type_path))) = inner.args.first() {
73            return is_context(type_path);
74          }
75        }
76      }
77      false
78    })
79    .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be wrapped in a pointer type nor cell type and must not be invalidated nor repurposed until dropped"))
80    .reduce(|mut err, other| {
81      err.combine(other);
82      err
83    });
84
85  if let Some(err) = wrapped_context_error {
86    return err.into_compile_error().into();
87  }
88
89  let context_paths: Vec<_> = path_fields
90    .iter()
91    .filter(|(_, type_path)| is_context(type_path))
92    .collect();
93
94  if context_paths.len().eq(&0) {
95    return Error::new(Span::call_site(), "struct must use Context exactly once")
96      .into_compile_error()
97      .into();
98  }
99
100  if context_paths.len().gt(&1) {
101    return context_paths
102      .into_iter()
103      .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be used more than once"))
104      .reduce(|mut err, other| {
105        err.combine(other);
106        err
107      })
108      .unwrap()
109      .into_compile_error()
110      .into();
111  }
112
113  let (field, type_path) = context_paths.into_iter().next().unwrap();
114
115  let context_ident = &field.ident;
116
117  let ref_type = type_path.path.segments.last().and_then(|segment| {
118    if let PathArguments::AngleBracketed(ref_type) = &segment.arguments {
119      Some(&ref_type.args)
120    } else {
121      None
122    }
123  });
124
125  let expanded = quote!(
126    impl #impl_generics AsRef<#type_path> for #ident #ty_generics #where_clause {
127      fn as_ref(&self) -> &#type_path {
128        &self.#context_ident
129      }
130    }
131
132    unsafe impl #impl_generics async_local::AsContext for #ident #ty_generics #where_clause {
133      type Target = #ref_type;
134    }
135  );
136
137  expanded.into()
138}
139
140/// Configures main to be executed by the selected Tokio runtime
141///
142/// # Borrowing the runtime
143///
144/// To borrow the runtime directly, add as a function argument
145///
146/// ```
147/// #[async_local::main(flavor = "multi_thread", worker_threads = 10)]
148/// fn main(runtime: &tokio::runtime::Runtime) {}
149/// ```
150///
151/// # Non-worker async function
152///
153/// Note that the async function marked with this macro does not run as a
154/// worker. The expectation is that other tasks are spawned by the function here.
155/// Awaiting on other futures from the function provided here will not
156/// perform as fast as those spawned as workers.
157///
158/// # Multi-threaded runtime
159///
160/// To use the multi-threaded runtime, the macro can be configured using
161/// ```
162/// #[async_local::main(flavor = "multi_thread", worker_threads = 10)]
163/// # async fn main() {}
164/// ```
165///
166/// The `worker_threads` option configures the number of worker threads, and
167/// defaults to the number of cpus on the system. This is the default flavor.
168///
169/// Note: The multi-threaded runtime requires the `rt-multi-thread` feature
170/// flag.
171///
172/// # Current thread runtime
173///
174/// To use the single-threaded runtime known as the `current_thread` runtime,
175/// the macro can be configured using
176/// ```
177/// #[async_local::main(flavor = "current_thread")]
178/// # async fn main() {}
179/// ```
180/// ## Usage
181///
182/// ### Using the multi-thread runtime
183/// ```rust
184/// #[async_local::main]
185/// async fn main() {
186///   println!("Hello world");
187/// }
188/// ```
189///
190/// ### Using current thread runtime
191///
192/// The basic scheduler is single-threaded.
193/// ```rust
194/// #[async_local::main(flavor = "current_thread")]
195/// async fn main() {
196///   println!("Hello world");
197/// }
198/// ```
199///
200/// ### Set number of worker threads
201/// ```rust
202/// #[async_local::main(worker_threads = 2)]
203/// async fn main() {
204///   println!("Hello world");
205/// }
206/// ```
207///
208/// ### Configure the runtime to start with time paused
209/// ```rust
210/// #[async_local::main(flavor = "current_thread", start_paused = true)]
211/// async fn main() {
212///   println!("Hello world");
213/// }
214/// ```
215///
216/// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio`.
217#[proc_macro_attribute]
218pub fn main(
219  args: proc_macro::TokenStream,
220  item: proc_macro::TokenStream,
221) -> proc_macro::TokenStream {
222  entry::main(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into()
223}
224
225/// Marks async function to be executed by runtime, suitable to test environment.
226///
227/// # Borrowing the runtime
228///
229/// To borrow the runtime directly, add as a function argument
230///
231/// ```
232/// #[async_local::test(flavor = "multi_thread", worker_threads = 10)]
233/// fn test(runtime: &tokio::runtime::Runtime) {
234///   runtime.block_on(async {
235///     assert!(true);
236///   });
237/// }
238/// ```
239///
240/// # Multi-threaded runtime
241///
242/// To use the multi-threaded runtime, the macro can be configured using
243/// ```no_run
244/// #[async_local::test(flavor = "multi_thread", worker_threads = 1)]
245/// async fn my_test() {
246///   assert!(true);
247/// }
248/// ```
249///
250/// The `worker_threads` option configures the number of worker threads, and
251/// defaults to the number of cpus on the system.
252///
253/// Note: The multi-threaded runtime requires the `rt-multi-thread` feature
254/// flag.
255///
256/// # Current thread runtime
257///
258/// The default test runtime is single-threaded. Each test gets a
259/// separate current-thread runtime.
260/// ```no_run
261/// #[async_local::test]
262/// async fn my_test() {
263///   assert!(true);
264/// }
265/// ```
266///
267/// ## Usage
268///
269/// ### Using the multi-thread runtime
270/// ```no_run
271/// #[async_local::test(flavor = "multi_thread")]
272/// async fn my_test() {
273///   assert!(true);
274/// }
275/// ```
276///
277/// ### Using current thread runtime
278/// ```no_run
279/// #[async_local::test]
280/// async fn my_test() {
281///   assert!(true);
282/// }
283/// ```
284///
285/// ### Set number of worker threads
286/// ```no_run
287/// #[async_local::test(flavor = "multi_thread", worker_threads = 2)]
288/// async fn my_test() {
289///   assert!(true);
290/// }
291/// ```
292///
293/// ### Configure the runtime to start with time paused
294/// ```no_run
295/// #[async_local::test(start_paused = true)]
296/// async fn my_test() {
297///   assert!(true);
298/// }
299/// ```
300///
301/// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio``.
302#[proc_macro_attribute]
303pub fn test(
304  args: proc_macro::TokenStream,
305  item: proc_macro::TokenStream,
306) -> proc_macro::TokenStream {
307  entry::test(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into()
308}