derive_async_local/lib.rs
1use proc_macro2::Span;
2use quote::quote;
3use syn::{
4 Data, DeriveInput, GenericArgument, PathArguments, Type, TypePath, parse::Error,
5 parse_macro_input,
6};
7
8mod entry;
9
10fn is_context(type_path: &TypePath) -> bool {
11 let segments: Vec<_> = type_path
12 .path
13 .segments
14 .iter()
15 .map(|segment| segment.ident.to_string())
16 .collect();
17
18 matches!(
19 *segments
20 .iter()
21 .map(String::as_ref)
22 .collect::<Vec<&str>>()
23 .as_slice(),
24 ["async_local", "Context"] | ["Context"]
25 )
26}
27
28/// Derive [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context<T>`](https://docs.rs/async-local/latest/async_local/struct.Context.html)> and [`AsContext`](https://docs.rs/async-local/latest/async_local/trait.AsContext.html) for a struct
29#[proc_macro_derive(AsContext)]
30pub fn derive_as_context(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let ident = &input.ident;
33 let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
34
35 if let Some(err) = input
36 .generics
37 .lifetimes()
38 .map(|lifetime| Error::new_spanned(lifetime, "cannot derive AsContext with lifetimes"))
39 .reduce(|mut err, other| {
40 err.combine(other);
41 err
42 })
43 {
44 return err.into_compile_error().into();
45 }
46
47 let data_struct = if let Data::Struct(data_struct) = &input.data {
48 data_struct
49 } else {
50 return Error::new(Span::call_site(), "can only derive AsContext on structs")
51 .into_compile_error()
52 .into();
53 };
54
55 let path_fields: Vec<_> = data_struct
56 .fields
57 .iter()
58 .filter_map(|field| {
59 if let Type::Path(type_path) = &field.ty {
60 Some((field, type_path))
61 } else {
62 None
63 }
64 })
65 .collect();
66
67 let wrapped_context_error = path_fields
68 .iter()
69 .filter(|(_, type_path)| {
70 if let Some(segment) = type_path.path.segments.last() {
71 if let PathArguments::AngleBracketed(inner) = &segment.arguments {
72 if let Some(GenericArgument::Type(Type::Path(type_path))) = inner.args.first() {
73 return is_context(type_path);
74 }
75 }
76 }
77 false
78 })
79 .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be wrapped in a pointer type nor cell type and must not be invalidated nor repurposed until dropped"))
80 .reduce(|mut err, other| {
81 err.combine(other);
82 err
83 });
84
85 if let Some(err) = wrapped_context_error {
86 return err.into_compile_error().into();
87 }
88
89 let context_paths: Vec<_> = path_fields
90 .iter()
91 .filter(|(_, type_path)| is_context(type_path))
92 .collect();
93
94 if context_paths.len().eq(&0) {
95 return Error::new(Span::call_site(), "struct must use Context exactly once")
96 .into_compile_error()
97 .into();
98 }
99
100 if context_paths.len().gt(&1) {
101 return context_paths
102 .into_iter()
103 .map(|(_, type_path)| Error::new_spanned(type_path, "Context cannot be used more than once"))
104 .reduce(|mut err, other| {
105 err.combine(other);
106 err
107 })
108 .unwrap()
109 .into_compile_error()
110 .into();
111 }
112
113 let (field, type_path) = context_paths.into_iter().next().unwrap();
114
115 let context_ident = &field.ident;
116
117 let ref_type = type_path.path.segments.last().and_then(|segment| {
118 if let PathArguments::AngleBracketed(ref_type) = &segment.arguments {
119 Some(&ref_type.args)
120 } else {
121 None
122 }
123 });
124
125 let expanded = quote!(
126 impl #impl_generics AsRef<#type_path> for #ident #ty_generics #where_clause {
127 fn as_ref(&self) -> &#type_path {
128 &self.#context_ident
129 }
130 }
131
132 unsafe impl #impl_generics async_local::AsContext for #ident #ty_generics #where_clause {
133 type Target = #ref_type;
134 }
135 );
136
137 expanded.into()
138}
139
140/// Configures main to be executed by the selected Tokio runtime
141///
142/// # Borrowing the runtime
143///
144/// To borrow the runtime directly, add as a function argument
145///
146/// ```
147/// #[async_local::main(flavor = "multi_thread", worker_threads = 10)]
148/// fn main(runtime: &tokio::runtime::Runtime) {}
149/// ```
150///
151/// # Non-worker async function
152///
153/// Note that the async function marked with this macro does not run as a
154/// worker. The expectation is that other tasks are spawned by the function here.
155/// Awaiting on other futures from the function provided here will not
156/// perform as fast as those spawned as workers.
157///
158/// # Multi-threaded runtime
159///
160/// To use the multi-threaded runtime, the macro can be configured using
161/// ```
162/// #[async_local::main(flavor = "multi_thread", worker_threads = 10)]
163/// # async fn main() {}
164/// ```
165///
166/// The `worker_threads` option configures the number of worker threads, and
167/// defaults to the number of cpus on the system. This is the default flavor.
168///
169/// Note: The multi-threaded runtime requires the `rt-multi-thread` feature
170/// flag.
171///
172/// # Current thread runtime
173///
174/// To use the single-threaded runtime known as the `current_thread` runtime,
175/// the macro can be configured using
176/// ```
177/// #[async_local::main(flavor = "current_thread")]
178/// # async fn main() {}
179/// ```
180/// ## Usage
181///
182/// ### Using the multi-thread runtime
183/// ```rust
184/// #[async_local::main]
185/// async fn main() {
186/// println!("Hello world");
187/// }
188/// ```
189///
190/// ### Using current thread runtime
191///
192/// The basic scheduler is single-threaded.
193/// ```rust
194/// #[async_local::main(flavor = "current_thread")]
195/// async fn main() {
196/// println!("Hello world");
197/// }
198/// ```
199///
200/// ### Set number of worker threads
201/// ```rust
202/// #[async_local::main(worker_threads = 2)]
203/// async fn main() {
204/// println!("Hello world");
205/// }
206/// ```
207///
208/// ### Configure the runtime to start with time paused
209/// ```rust
210/// #[async_local::main(flavor = "current_thread", start_paused = true)]
211/// async fn main() {
212/// println!("Hello world");
213/// }
214/// ```
215///
216/// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio`.
217#[proc_macro_attribute]
218pub fn main(
219 args: proc_macro::TokenStream,
220 item: proc_macro::TokenStream,
221) -> proc_macro::TokenStream {
222 entry::main(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into()
223}
224
225/// Marks async function to be executed by runtime, suitable to test environment.
226///
227/// # Borrowing the runtime
228///
229/// To borrow the runtime directly, add as a function argument
230///
231/// ```
232/// #[async_local::test(flavor = "multi_thread", worker_threads = 10)]
233/// fn test(runtime: &tokio::runtime::Runtime) {
234/// runtime.block_on(async {
235/// assert!(true);
236/// });
237/// }
238/// ```
239///
240/// # Multi-threaded runtime
241///
242/// To use the multi-threaded runtime, the macro can be configured using
243/// ```no_run
244/// #[async_local::test(flavor = "multi_thread", worker_threads = 1)]
245/// async fn my_test() {
246/// assert!(true);
247/// }
248/// ```
249///
250/// The `worker_threads` option configures the number of worker threads, and
251/// defaults to the number of cpus on the system.
252///
253/// Note: The multi-threaded runtime requires the `rt-multi-thread` feature
254/// flag.
255///
256/// # Current thread runtime
257///
258/// The default test runtime is single-threaded. Each test gets a
259/// separate current-thread runtime.
260/// ```no_run
261/// #[async_local::test]
262/// async fn my_test() {
263/// assert!(true);
264/// }
265/// ```
266///
267/// ## Usage
268///
269/// ### Using the multi-thread runtime
270/// ```no_run
271/// #[async_local::test(flavor = "multi_thread")]
272/// async fn my_test() {
273/// assert!(true);
274/// }
275/// ```
276///
277/// ### Using current thread runtime
278/// ```no_run
279/// #[async_local::test]
280/// async fn my_test() {
281/// assert!(true);
282/// }
283/// ```
284///
285/// ### Set number of worker threads
286/// ```no_run
287/// #[async_local::test(flavor = "multi_thread", worker_threads = 2)]
288/// async fn my_test() {
289/// assert!(true);
290/// }
291/// ```
292///
293/// ### Configure the runtime to start with time paused
294/// ```no_run
295/// #[async_local::test(start_paused = true)]
296/// async fn my_test() {
297/// assert!(true);
298/// }
299/// ```
300///
301/// Note that `start_paused` requires the `test-util` feature to be enabled on `tokio``.
302#[proc_macro_attribute]
303pub fn test(
304 args: proc_macro::TokenStream,
305 item: proc_macro::TokenStream,
306) -> proc_macro::TokenStream {
307 entry::test(args.into(), item.into(), cfg!(feature = "rt-multi-thread")).into()
308}