async_dropper_derive/
lib.rs

1use proc_macro2::TokenStream;
2use syn::{DataEnum, DataStruct, DataUnion, DeriveInput, Fields, FieldsNamed};
3
4#[proc_macro_derive(AsyncDrop)]
5pub fn derive_async_drop(items: proc_macro::TokenStream) -> proc_macro::TokenStream {
6    match syn::parse2::<DeriveInput>(items.into()) {
7        Ok(derive_input) => proc_macro2::TokenStream::from_iter([
8            gen_preamble(&derive_input),
9            gen_impl(&derive_input),
10        ])
11        .into(),
12        Err(e) => e.to_compile_error().into(),
13    }
14}
15
16fn make_shared_default_name(ident: &proc_macro2::Ident) -> proc_macro2::Ident {
17    quote::format_ident!("_shared_default_{}", ident)
18}
19
20/// Default implementation of deriving async drop that does nothing
21/// you're expected to use either the 'tokio' feature or 'async-std'
22fn gen_preamble(di: &DeriveInput) -> proc_macro2::TokenStream {
23    let ident = &di.ident;
24    let shared_default_name = make_shared_default_name(ident);
25
26    // Retrieve the struct data fields from the derive input
27    let mut df_setters: Vec<TokenStream> = Vec::new();
28    match &di.data {
29        syn::Data::Struct(DataStruct { fields, .. }) => {
30            if let Fields::Unit = fields {
31                df_setters.push(
32                    syn::Error::new(ident.span(), "unit sturcts cannot be async dropped")
33                        .to_compile_error(),
34                );
35            }
36            for f in fields.iter() {
37                df_setters.push(f.ident.as_ref().map_or_else(
38                    || {
39                        syn::parse_str(
40                            format!("self.{} = Default::default()", df_setters.len()).as_str(),
41                        )
42                        .unwrap_or_else(|_| {
43                            syn::Error::new(
44                                ident.span(),
45                                "failed to generate default setter for field",
46                            )
47                            .to_compile_error()
48                        })
49                    },
50                    |id| quote::quote! { self.#id = Default::default(); },
51                ));
52            }
53        }
54        syn::Data::Enum(DataEnum { variants, .. }) => {
55            for v in variants.iter() {
56                for vf in v.fields.iter() {
57                    df_setters.push(vf.ident.as_ref().map_or_else(
58                        || {
59                            syn::parse_str(
60                                format!("self.{} = Default::default()", df_setters.len()).as_str(),
61                            )
62                            .unwrap_or_else(|_| {
63                                syn::Error::new(
64                                    ident.span(),
65                                    "failed to generate default setter for field",
66                                )
67                                .to_compile_error()
68                            })
69                        },
70                        |id| quote::quote! { self.#id = Default::default(); },
71                    ))
72                }
73            }
74        }
75        syn::Data::Union(DataUnion {
76            fields: FieldsNamed { named, .. },
77            ..
78        }) => {
79            for f in named.iter() {
80                if let Some(id) = &f.ident {
81                    df_setters.push(quote::quote! { self.#id = Default::default(); });
82                }
83            }
84        }
85    };
86
87    quote::quote!(
88        /// Automatically generated implementation of reset to default for #ident
89        #[automatically_derived]
90        impl ::async_dropper::ResetDefault for #ident {
91            fn reset_to_default(&mut self) {
92                #(
93                    #df_setters;
94                )*
95            }
96        }
97
98        /// Utility function unique to #ident which retrieves a shared mutable single default instance of it
99        /// that single default instance is compared to other instances and indicates whether async drop
100        /// should be called
101        #[allow(non_snake_case)]
102        fn #shared_default_name() -> &'static std::sync::Mutex<#ident> {
103            #[allow(non_upper_case_globals)]
104            static #shared_default_name: std::sync::OnceLock<std::sync::Mutex<#ident>> = std::sync::OnceLock::new();
105            #shared_default_name.get_or_init(|| std::sync::Mutex::new(#ident::default()))
106        }
107
108    )
109}
110
111#[cfg(all(not(feature = "async-std"), not(feature = "tokio")))]
112fn gen_impl(_: &DeriveInput) -> proc_macro::TokenStream {
113    compile_error!(
114        "either 'async-std' or 'tokio' features must be enabled for the async-dropper crate"
115    );
116}
117
/// Placeholder `gen_impl` selected when both the 'tokio' and the 'async-std' features are
/// enabled at once. It exists only to fail the build with an explanatory message.
///
/// The signature mirrors the real implementations (returning `proc_macro2::TokenStream`)
/// so the call site type-checks the same way under every feature combination.
#[cfg(all(feature = "async-std", feature = "tokio"))]
fn gen_impl(_: &DeriveInput) -> proc_macro2::TokenStream {
    compile_error!(
        "both 'async-std' and 'tokio' features must not be enabled for the async-dropper crate"
    )
}
124
/// Tokio implementation of the `Drop` glue emitted by `#[derive(AsyncDrop)]`.
///
/// The generated `Drop::drop`:
/// 1. returns immediately when `self` equals the shared default instance;
/// 2. moves the live value out of `self` with `std::mem::take`;
/// 3. spawns a tokio task that runs `AsyncDrop::async_drop` under `AsyncDrop::drop_timeout`;
/// 4. waits for that task synchronously via `block_in_place` + `Handle::block_on`;
/// 5. calls `AsyncDrop::reset` and panics if the value still differs from the default;
/// 6. panics with the task's message if the async drop reported a failure.
///
/// On timeout the outcome depends on `AsyncDrop::drop_fail_action`: `Continue` is treated
/// as success, `Panic` as failure.
///
/// All `async_dropper` items are referenced by absolute path so the derive does not rely
/// on imports at the use site. No `#[async_trait]` attribute is emitted because the impl
/// contains no `async fn`.
///
/// NOTE(review): tokio documents `block_in_place` as panicking on a current-thread
/// runtime, and `Handle::current()` as panicking outside a runtime — confirm this is an
/// accepted constraint for users of the derive.
#[cfg(all(feature = "tokio", not(feature = "async-std")))]
fn gen_impl(DeriveInput { ident, .. }: &DeriveInput) -> proc_macro2::TokenStream {
    let shared_default_name = make_shared_default_name(ident);
    quote::quote!(
        #[automatically_derived]
        impl Drop for #ident {
            fn drop(&mut self) {
                // A value equal to the shared default is treated as already dropped
                let thing = #shared_default_name();
                if *thing.lock().unwrap() == *self {
                    return;
                }

                // Move the live value out of `self` (leaving a default behind) so it can be
                // sent into the spawned task
                let mut original = std::mem::take(self);

                // Spawn a task to do the drop
                let task = ::tokio::spawn(async move {
                    let drop_fail_action = <#ident as ::async_dropper::AsyncDrop>::drop_fail_action(&original);
                    let task_res = match ::tokio::time::timeout(
                        <#ident as ::async_dropper::AsyncDrop>::drop_timeout(&original),
                        <#ident as ::async_dropper::AsyncDrop>::async_drop(&mut original),
                    ).await {
                        // Task timed out
                        Err(_) | Ok(Err(::async_dropper::AsyncDropError::Timeout)) => {
                            match drop_fail_action {
                                ::async_dropper::DropFailAction::Continue => Ok(()),
                                ::async_dropper::DropFailAction::Panic => Err("async drop timed out".to_string()),
                            }
                        },
                        // Internal task error
                        Ok(Err(::async_dropper::AsyncDropError::UnexpectedError(e))) => Err(format!("async drop failed: {e}")),
                        // Task completed successfully
                        Ok(_) => Ok(()),
                    };
                    (original, task_res)
                });

                // Perform a synchronous wait
                let (mut original, task_res) = ::tokio::task::block_in_place(|| ::tokio::runtime::Handle::current().block_on(task).unwrap());

                // `original` is itself dropped when this function returns, which re-enters
                // this `drop`; it must equal the default by then so the early return fires
                <#ident as ::async_dropper::AsyncDrop>::reset(&mut original);
                if *thing.lock().unwrap() != original {
                    panic!("after calling AsyncDrop::reset(), the object does *not* equal T::default()");
                }

                if let Err(e) = task_res {
                    panic!("{e}");
                }
            }
        }
    )
}
182
/// async-std implementation of the `Drop` glue emitted by `#[derive(AsyncDrop)]`.
///
/// The generated `Drop::drop`:
/// 1. returns immediately when `self` equals the shared default instance;
/// 2. moves the live value out of `self` with `std::mem::take`;
/// 3. spawns an async-std task that runs `AsyncDrop::async_drop` under
///    `AsyncDrop::drop_timeout`;
/// 4. waits for that task synchronously via `::futures::executor::block_on`;
/// 5. calls `AsyncDrop::reset` and panics if the value still differs from the default;
/// 6. panics with the task's message if the async drop reported a failure.
///
/// On timeout the outcome depends on `AsyncDrop::drop_fail_action`: `Continue` is treated
/// as success, `Panic` as failure.
///
/// All `async_dropper` items are referenced by absolute path so the derive does not rely
/// on imports at the use site. No `#[async_trait]` attribute is emitted because the impl
/// contains no `async fn`.
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
fn gen_impl(DeriveInput { ident, .. }: &DeriveInput) -> proc_macro2::TokenStream {
    let shared_default_name = make_shared_default_name(ident);
    quote::quote!(
        #[automatically_derived]
        impl Drop for #ident {
            fn drop(&mut self) {
                // A value equal to the shared default is treated as already dropped
                let thing = #shared_default_name();
                if *thing.lock().unwrap() == *self {
                    return;
                }

                // Move the live value out of `self` (leaving a default behind) so it can be
                // sent into the spawned task
                let mut original = std::mem::take(self);

                // Spawn a task to do the drop
                let task = ::async_std::task::spawn(async move {
                    let drop_fail_action = <#ident as ::async_dropper::AsyncDrop>::drop_fail_action(&original);
                    let task_res = match ::async_std::future::timeout(
                        <#ident as ::async_dropper::AsyncDrop>::drop_timeout(&original),
                        <#ident as ::async_dropper::AsyncDrop>::async_drop(&mut original),
                    ).await {
                        // Task timed out
                        Err(_) | Ok(Err(::async_dropper::AsyncDropError::Timeout)) => {
                            match drop_fail_action {
                                ::async_dropper::DropFailAction::Continue => Ok(()),
                                ::async_dropper::DropFailAction::Panic => Err("async drop timed out".to_string()),
                            }
                        },
                        // Internal task error
                        Ok(Err(::async_dropper::AsyncDropError::UnexpectedError(e))) => Err(format!("async drop failed: {e}")),
                        // Task completed successfully
                        Ok(_) => Ok(()),
                    };
                    (original, task_res)
                });

                // Perform synchronous wait
                let (mut original, task_res) = ::futures::executor::block_on(task);

                // `original` is itself dropped when this function returns, which re-enters
                // this `drop`; it must equal the default by then so the early return fires
                <#ident as ::async_dropper::AsyncDrop>::reset(&mut original);
                if *thing.lock().unwrap() != original {
                    panic!("after calling AsyncDrop::reset(), the object does *not* equal T::default()");
                }

                if let Err(e) = task_res {
                    panic!("{e}");
                }
            }
        }
    )
}