Skip to main content

aiscript_derived/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{ToTokens, quote, quote_spanned};
3use syn::{
4    parse::{Parse, ParseStream},
5    spanned::Spanned,
6    visit_mut::VisitMut,
7};
8use synstructure::{AddBounds, decl_derive};
9
10fn find_collect_meta(attrs: &[syn::Attribute]) -> syn::Result<Option<&syn::Attribute>> {
11    let mut found = None;
12    for attr in attrs {
13        if attr.path().is_ident("collect") && found.replace(attr).is_some() {
14            return Err(syn::parse::Error::new_spanned(
15                attr.path(),
16                "Cannot specify multiple `#[collect]` attributes! Consider merging them.",
17            ));
18        }
19    }
20
21    Ok(found)
22}
23
24fn usage_error(meta: &syn::meta::ParseNestedMeta, msg: &str) -> syn::parse::Error {
25    meta.error(format_args!("{msg}. `#[collect(...)]` requires one mode (`require_static`, `no_drop`, or `unsafe_drop`) and optionally `bound = \"...\"`."))
26}
27
/// Derive strategy chosen by the container-level `#[collect(...)]` attribute.
///
/// Deriving `Collect` must be done with care, because an implementation of `Drop` is not
/// necessarily safe for `Collect` types. Each mode ensures this is safe in a different way.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    /// `#[collect(require_static)]`: require that the type be `'static`.
    RequireStatic,
    /// `#[collect(no_drop)]`: prohibit a `Drop` impl on the type.
    NoDrop,
    /// `#[collect(unsafe_drop)]`: allow a custom `Drop` impl that might be unsafe. Such `Drop`
    /// impls must *not* access garbage collected pointers during `Drop::drop`.
    UnsafeDrop,
}
41
42fn collect_derive(mut s: synstructure::Structure) -> TokenStream {
43    let mut mode = None;
44    let mut override_bound = None;
45
46    let result = match find_collect_meta(&s.ast().attrs) {
47        Ok(Some(attr)) => attr.parse_nested_meta(|meta| {
48            if meta.path.is_ident("bound") {
49                if override_bound.is_some() {
50                    return Err(usage_error(&meta, "multiple bounds specified"));
51                }
52
53                let lit: syn::LitStr = meta.value()?.parse()?;
54                override_bound = Some(lit);
55                return Ok(());
56            }
57
58            meta.input.parse::<syn::parse::Nothing>()?;
59
60            if mode.is_some() {
61                return Err(usage_error(&meta, "multiple modes specified"));
62            } else if meta.path.is_ident("require_static") {
63                mode = Some(Mode::RequireStatic);
64            } else if meta.path.is_ident("no_drop") {
65                mode = Some(Mode::NoDrop);
66            } else if meta.path.is_ident("unsafe_drop") {
67                mode = Some(Mode::UnsafeDrop);
68            } else {
69                return Err(usage_error(&meta, "unknown option"));
70            }
71            Ok(())
72        }),
73        Ok(None) => Ok(()),
74        Err(err) => Err(err),
75    };
76
77    if let Err(err) = result {
78        return err.to_compile_error();
79    }
80
81    let Some(mode) = mode else {
82        panic!(
83            "{}",
84            "deriving `Collect` requires a `#[collect(...)]` attribute"
85        );
86    };
87
88    let where_clause = if mode == Mode::RequireStatic {
89        quote!(where Self: 'static)
90    } else {
91        override_bound
92            .as_ref()
93            .map(|x| {
94                x.parse()
95                    .expect("`#[collect]` failed to parse explicit trait bound expression")
96            })
97            .unwrap_or_else(|| quote!())
98    };
99
100    let mut errors = vec![];
101
102    let collect_impl = if mode == Mode::RequireStatic {
103        s.clone().add_bounds(AddBounds::None).gen_impl(quote! {
104            gen unsafe impl ::aiscript_arena::Collect for @Self #where_clause {
105                #[inline]
106                fn needs_trace() -> bool {
107                    false
108                }
109            }
110        })
111    } else {
112        let mut needs_trace_body = TokenStream::new();
113        quote!(false).to_tokens(&mut needs_trace_body);
114
115        let mut static_bindings = vec![];
116
117        // Ignore all bindings that have `#[collect(require_static)]` For each binding with
118        // `#[collect(require_static)]`, we push a bound of the form `FieldType: 'static` to
119        // `static_bindings`, which will be added to the genererated `Collect` impl. The presence of
120        // the bound guarantees that the field cannot hold any `Gc` pointers, so it's safe to ignore
121        // that field in `needs_trace` and `trace`
122        s.filter(|b| match find_collect_meta(&b.ast().attrs) {
123            Ok(Some(attr)) => {
124                let mut static_binding = false;
125                let result = attr.parse_nested_meta(|meta| {
126                    if meta.input.is_empty() && meta.path.is_ident("require_static") {
127                        static_binding = true;
128                        static_bindings.push(b.ast().ty.clone());
129                        Ok(())
130                    } else {
131                        Err(meta.error("Only `#[collect(require_static)]` is supported on a field"))
132                    }
133                });
134                errors.extend(result.err());
135                !static_binding
136            }
137            Ok(None) => true,
138            Err(err) => {
139                errors.push(err);
140                true
141            }
142        });
143
144        for static_binding in static_bindings {
145            s.add_where_predicate(syn::parse_quote! { #static_binding: 'static });
146        }
147
148        // `#[collect(require_static)]` only makes sense on fields, not enum variants. Emit an error
149        // if it is used in the wrong place
150        if let syn::Data::Enum(..) = s.ast().data {
151            for v in s.variants() {
152                for attr in v.ast().attrs {
153                    if attr.path().is_ident("collect") {
154                        errors.push(syn::parse::Error::new_spanned(
155                            attr.path(),
156                            "`#[collect]` is not suppported on enum variants",
157                        ));
158                    }
159                }
160            }
161        }
162
163        // We've already called `s.filter`, so we we won't try to call `needs_trace` for the types
164        // of fields that have `#[collect(require_static)]`
165        for v in s.variants() {
166            for b in v.bindings() {
167                let ty = &b.ast().ty;
168                // Resolving the span at the call site makes rustc emit a 'the error originates a
169                // derive macro note' We only use this span on tokens that need to resolve to items
170                // (e.g. `aiscript_arena::Collect`), so this won't cause any hygiene issues
171                let call_span = b.ast().span().resolved_at(Span::call_site());
172                quote_spanned!(call_span=>
173                    || <#ty as ::aiscript_arena::Collect>::needs_trace()
174                )
175                .to_tokens(&mut needs_trace_body);
176            }
177        }
178        // Likewise, this will skip any fields that have `#[collect(require_static)]`
179        let trace_body = s.each(|bi| {
180            // See the above call to `needs_trace` for an explanation of this
181            let call_span = bi.ast().span().resolved_at(Span::call_site());
182            quote_spanned!(call_span=>
183                {
184                    // Use a temporary variable to ensure that all tokens in the call to
185                    // `aiscript_arena::Collect::trace` have the same hygiene information. If we used
186                    // #bi directly, then we would have a mix of hygiene contexts, which would
187                    // cause rustc to produce sub-optimal error messagse due to its inability to
188                    // merge the spans. This is purely for diagnostic purposes, and has no effect
189                    // on correctness
190                    let bi = #bi;
191                    ::aiscript_arena::Collect::trace(bi, cc)
192                }
193            )
194        });
195
196        let bounds_type = if override_bound.is_some() {
197            AddBounds::None
198        } else {
199            AddBounds::Generics
200        };
201        s.clone().add_bounds(bounds_type).gen_impl(quote! {
202            gen unsafe impl ::aiscript_arena::Collect for @Self #where_clause {
203                #[inline]
204                fn needs_trace() -> bool {
205                    #needs_trace_body
206                }
207
208                #[inline]
209                fn trace(&self, cc: &::aiscript_arena::Collection) {
210                    match *self { #trace_body }
211                }
212            }
213        })
214    };
215
216    let drop_impl = if mode == Mode::NoDrop {
217        let mut s = s;
218        s.add_bounds(AddBounds::None).gen_impl(quote! {
219            gen impl ::aiscript_arena::__MustNotImplDrop for @Self {}
220        })
221    } else {
222        quote!()
223    };
224
225    let errors = errors.into_iter().map(|e| e.to_compile_error());
226    quote! {
227        #collect_impl
228        #drop_impl
229        #(#errors)*
230    }
231}
232
decl_derive! {
    [Collect, attributes(collect)] =>
    /// Derives the `Collect` trait needed to trace a garbage-collected type.
    ///
    /// To derive `Collect`, an additional attribute called `collect` is required on the
    /// struct/enum. It has several optional arguments, but the only required argument is the
    /// derive strategy, which must be one of:
    ///
    /// - `#[collect(require_static)]` - Adds a `'static` bound, which allows for a no-op trace
    ///   implementation. This is the ideal choice where possible.
    /// - `#[collect(no_drop)]` - The typical safe tracing derive strategy. It only adds the
    ///   requirement that your struct/enum does not have a custom implementation of `Drop`.
    ///   This requirement is enforced by also generating an impl of the hidden
    ///   `aiscript_arena::__MustNotImplDrop` trait for the type.
    /// - `#[collect(unsafe_drop)]` - The most versatile tracing derive strategy, which allows a
    ///   custom `Drop` implementation. However, this strategy can lead to unsoundness if care is
    ///   not taken: the custom `Drop` impl must *not* access garbage-collected pointers during
    ///   `Drop::drop`.
    ///
    /// The `collect` attribute also accepts the following optional configuration setting:
    ///
    /// - `#[collect(bound = "<code>")]` - Replaces the default generated `where` clause with the
    ///   given code. This can be an empty string to add no `where` clause; otherwise it must
    ///   start with `"where"`, e.g., `#[collect(bound = "where T: Collect")]`. This option is
    ///   ignored in `require_static` mode, since the only bound that mode produces is
    ///   `Self: 'static`. Providing an explicit bound in this way is safe: it only changes the
    ///   trait bounds used to enable the implementation of `Collect`.
    ///
    /// Options may be passed to the `collect` attribute together, e.g.,
    /// `#[collect(no_drop, bound = "")]`.
    ///
    /// The `collect` attribute may also be used on any field of a struct or enum (but not on an
    /// enum variant itself). The only allowed field-level usage is `require_static`: no other
    /// strategies are allowed, and no optional settings can be specified. This adds a
    /// `'static` bound on the field's type (regardless of an explicit `bound` setting). In
    /// exchange, the field is not traced, which is the ideal choice where possible. If the
    /// entire struct/enum is marked `require_static`, this is unnecessary.
    ///
    /// ```ignore
    /// #[derive(Collect)]
    /// #[collect(no_drop)]
    /// struct Pair<T> {
    ///     value: T,
    ///     #[collect(require_static)]
    ///     name: String,
    /// }
    /// ```
    collect_derive
}
269
270// Not public API; implementation detail of `aiscript_arena::Rootable!`.
271// Replaces all `'_` lifetimes in a type by the specified named lifetime.
272// Syntax: `__unelide_lifetimes!('lt; SomeType)`.
273#[doc(hidden)]
274#[proc_macro]
275pub fn __unelide_lifetimes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
276    struct Input {
277        lt: syn::Lifetime,
278        ty: syn::Type,
279    }
280
281    impl Parse for Input {
282        fn parse(input: ParseStream) -> syn::Result<Self> {
283            let lt: syn::Lifetime = input.parse()?;
284            let _: syn::Token!(;) = input.parse()?;
285            let ty: syn::Type = input.parse()?;
286            Ok(Self { lt, ty })
287        }
288    }
289
290    struct UnelideLifetimes(syn::Lifetime);
291
292    impl VisitMut for UnelideLifetimes {
293        fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
294            if i.ident == "_" {
295                *i = self.0.clone();
296            }
297        }
298    }
299
300    let mut input = syn::parse_macro_input!(input as Input);
301    UnelideLifetimes(input.lt).visit_type_mut(&mut input.ty);
302    input.ty.to_token_stream().into()
303}