datasize_derive/
lib.rs

1//! Proc-macro `DataSize` derive for use with the `datasize` crate.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use std::collections::HashSet;
6use syn::{
7    parse, parse_macro_input, AngleBracketedGenericArguments, AttrStyle, Attribute, Binding,
8    DataEnum, DataStruct, DeriveInput, Generics, Ident, Index, ParenthesizedGenericArguments, Path,
9    PathArguments, ReturnType, TraitBound, Type, TypeArray, TypeBareFn, TypeGroup, TypeImplTrait,
10    TypeParam, TypeParamBound, TypeParen, TypePath, TypePtr, TypeReference, TypeSlice,
11    TypeTraitObject, TypeTuple, WhereClause,
12};
13
14/// Automatically derive the `DataSize` trait for a type.
15///
16/// Supports a single option, `#[data_size(skip)]`. If set on a field, it will be ignored entirely
17/// when deriving the implementation.
18#[proc_macro_derive(DataSize, attributes(data_size))]
19pub fn derive_data_size(input: TokenStream) -> TokenStream {
20    let input = parse_macro_input!(input as DeriveInput);
21    let input = remove_default_generic_values(input);
22
23    match input.data {
24        syn::Data::Struct(ds) => derive_for_struct(input.ident, input.generics, ds),
25        syn::Data::Enum(de) => derive_for_enum(input.ident, input.generics, de),
26        syn::Data::Union(_) => panic!("unions not supported"),
27    }
28}
29
30fn remove_default_generic_values(mut input: DeriveInput) -> DeriveInput {
31    for param in input.generics.params.iter_mut() {
32        if let syn::GenericParam::Type(ty) = param {
33            ty.eq_token = None;
34            ty.default = None;
35        }
36    }
37
38    input
39}
40
41/// Returns whether any of the `generics` show up in `ty`.
42///
43/// Used to determine whether or not a specific type needs to be listed in where clauses because it
44/// contains generic types. Note that the function is not entirely accurate and may produce false
45/// postives in some cases.
46fn contains_generic(generics: &Generics, ty: &Type) -> bool {
47    match ty {
48        Type::Array(TypeArray { elem, .. }) => contains_generic(generics, elem),
49        Type::BareFn(TypeBareFn { inputs, output, .. }) => {
50            for arg in inputs {
51                if contains_generic(generics, &arg.ty) {
52                    return true;
53                }
54            }
55
56            match output {
57                ReturnType::Default => false,
58                ReturnType::Type(_, ty) => contains_generic(generics, ty),
59            }
60        }
61        Type::Group(TypeGroup { elem, .. }) => contains_generic(generics, elem),
62        Type::ImplTrait(TypeImplTrait { bounds, .. }) => bounds
63            .iter()
64            .any(|b| param_bound_contains_generic(generics, b)),
65        Type::Infer(_) => false,
66        Type::Macro(_) => true,
67        Type::Never(_) => false,
68        Type::Paren(TypeParen { elem, .. }) => contains_generic(generics, elem),
69        Type::Path(TypePath { path, .. }) => path_contains_generic(generics, path),
70        Type::Ptr(TypePtr { elem, .. }) => contains_generic(generics, elem),
71        Type::Reference(TypeReference { elem, .. }) => contains_generic(generics, elem),
72        Type::Slice(TypeSlice { elem, .. }) => contains_generic(generics, elem),
73        Type::TraitObject(TypeTraitObject { bounds, .. }) => bounds
74            .iter()
75            .any(|b| param_bound_contains_generic(generics, b)),
76        Type::Tuple(TypeTuple { elems, .. }) => {
77            elems.iter().any(|ty| contains_generic(generics, ty))
78        }
79        // Nothing we can do here, err on the side of too many `where` clauses.
80        Type::Verbatim(_) => true,
81        // TODO: This may be problematic, double check if we did not miss anything.
82        _ => true,
83    }
84}
85
86/// Returns whether any of the `generics` show up in a given path.
87///
88/// May yield false positives.
89fn path_contains_generic(generics: &Generics, path: &Path) -> bool {
90    let mut candidates = HashSet::new();
91
92    for segment in &path.segments {
93        candidates.insert(segment.ident.clone());
94
95        match &segment.arguments {
96            PathArguments::None => {}
97            PathArguments::AngleBracketed(AngleBracketedGenericArguments { ref args, .. }) => {
98                for arg in args {
99                    match arg {
100                        syn::GenericArgument::Lifetime(_) => {
101                            // Ignore lifetime args.
102                        }
103                        syn::GenericArgument::Type(ty) => {
104                            // Simply recurse and check directly.
105                            if contains_generic(generics, ty) {
106                                return true;
107                            }
108                        }
109                        syn::GenericArgument::Binding(Binding {
110                            // We can ignore the ident here, as it is local.
111                            ty,
112                            ..
113                        }) => {
114                            // Again, exit early with `true` if we find a match.
115                            if contains_generic(generics, ty) {
116                                return true;
117                            }
118                        }
119                        syn::GenericArgument::Constraint(_) => {
120                            // Additional constraints are fine?
121                        }
122                        syn::GenericArgument::Const(_) => {
123                            // Constants do not require `DataSize` impls.
124                        }
125                    }
126                }
127            }
128            syn::PathArguments::Parenthesized(ParenthesizedGenericArguments {
129                inputs,
130                output,
131                ..
132            }) => {
133                if inputs.iter().any(|ty| contains_generic(generics, ty)) {
134                    return true;
135                }
136
137                match output {
138                    ReturnType::Default => {}
139                    ReturnType::Type(_, ref ty) => {
140                        if contains_generic(generics, ty) {
141                            return true;
142                        }
143                    }
144                }
145            }
146        }
147    }
148
149    let generic_idents: HashSet<_> = generics
150        .params
151        .iter()
152        .filter_map(|gen| match gen {
153            syn::GenericParam::Type(TypeParam { ident, .. }) => Some(ident.clone()),
154            syn::GenericParam::Lifetime(_) => None,
155            syn::GenericParam::Const(_) => None,
156        })
157        .collect();
158
159    // If we find at least one generic in all of the types, we return `true` here.
160    candidates.intersection(&generic_idents).next().is_some()
161}
162
163/// Returns whether any of the `generics` show up in a type parameter binding.
164///
165/// May return false positives.
166fn param_bound_contains_generic(generics: &Generics, bound: &TypeParamBound) -> bool {
167    match bound {
168        syn::TypeParamBound::Trait(TraitBound { path, .. }) => {
169            path_contains_generic(generics, path)
170        }
171        syn::TypeParamBound::Lifetime(_) => false,
172    }
173}
174
#[derive(Debug)]
/// A single option parsed from the inside of one `#[data_size(...)]` attribute.
enum DataAttribute {
    /// The `data_size(skip)` attribute.
    Skip,
    /// The `data_size(with = ...)` attribute, carrying the path that follows the `=`.
    With(syn::Path),
}
183
184impl parse::Parse for DataAttribute {
185    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
186        let ident = input.parse::<Ident>().expect("IDENT??").to_string();
187
188        match ident.as_str() {
189            "skip" => Ok(DataAttribute::Skip),
190            "with" => {
191                let punct: proc_macro2::Punct = input.parse().expect("PUNCT??");
192                if punct.as_char() != '=' {
193                    return Err(syn::parse::Error::new(
194                        input.span(),
195                        "expected `=` after `with`",
196                    ));
197                }
198
199                let path: syn::Path = input.parse()?;
200                Ok(DataAttribute::With(path))
201            }
202            kw => panic!("unsupported attribute keyword: {}", kw),
203        }
204    }
205}
206
/// The combined `data_size` options found in one attribute list (a field's or an enum variant's).
#[derive(Debug)]
struct DataSizeAttributes {
    /// Whether or not to skip the item entirely (`data_size(skip)`).
    pub skip: bool,
    /// Path given through `data_size(with = ...)`; the struct derive calls it with a reference to
    /// the field instead of deriving the data size.
    pub with: Option<syn::Path>,
}
215
216impl DataSizeAttributes {
217    /// Parses a set of attributes from untyped [`Attribute`]s.
218    fn parse(attrs: &Vec<Attribute>) -> Self {
219        let mut skip = None;
220        let mut with = None;
221
222        for attr in attrs {
223            if attr.style != AttrStyle::Outer {
224                // We ignore outer attributes.
225                continue;
226            }
227
228            // Ensure it is a `data_size` attribute.
229            if attr.path.segments.len() != 1 || attr.path.segments[0].ident != "data_size" {
230                continue;
231            }
232
233            let parsed: DataAttribute = attr
234                .parse_args()
235                .expect("could not parse datasize attribute");
236
237            match parsed {
238                DataAttribute::Skip => {
239                    if skip.is_some() {
240                        panic!("duplicated `skip` attribute");
241                    } else {
242                        skip = Some(true);
243                    }
244                }
245                DataAttribute::With(fragment) => {
246                    if with.is_some() {
247                        panic!("duplicated `with` attribute");
248                    } else {
249                        with = Some(fragment)
250                    }
251                }
252            }
253        }
254
255        DataSizeAttributes {
256            skip: skip.unwrap_or(false),
257            with,
258        }
259    }
260}
261
262/// Derives `DataSize` for a `struct`
263fn derive_for_struct(name: Ident, generics: Generics, ds: DataStruct) -> TokenStream {
264    let fields = ds.fields;
265
266    let mut where_clauses = proc_macro2::TokenStream::new();
267    let mut is_dynamic = proc_macro2::TokenStream::new();
268    let mut static_heap_size = proc_macro2::TokenStream::new();
269    let mut dynamic_size = proc_macro2::TokenStream::new();
270    let mut detail_calls = proc_macro2::TokenStream::new();
271
272    let mut has_manual_field = false;
273
274    for (idx, field) in fields.iter().enumerate() {
275        let field_attrs = DataSizeAttributes::parse(&field.attrs);
276        if field_attrs.skip {
277            continue;
278        }
279
280        if field_attrs.with.is_some() {
281            has_manual_field = true;
282        }
283
284        let ty = &field.ty;
285        // We need a where clause for every non-skipped, non-with field. We try our best to filter
286        // out bounds here that are not needed (e.g. `u8: DataSize`), as they can be problematic
287        // when mixing `pub(super)` and `pub` visiblity restrictions.
288        if field_attrs.with.is_none() && contains_generic(&generics, ty) {
289            if where_clauses.is_empty() {
290                where_clauses.extend(quote!(where));
291            }
292
293            where_clauses.extend(quote!(
294                #ty : datasize::DataSize,
295            ));
296        }
297
298        if !is_dynamic.is_empty() {
299            is_dynamic.extend(quote!(|));
300        }
301
302        if !static_heap_size.is_empty() {
303            static_heap_size.extend(quote!(+));
304        }
305
306        if !dynamic_size.is_empty() {
307            dynamic_size.extend(quote!(+));
308        }
309
310        is_dynamic.extend(quote!(<#ty as datasize::DataSize>));
311        is_dynamic.extend(quote!(::IS_DYNAMIC));
312
313        if field_attrs.with.is_none() {
314            static_heap_size.extend(quote!(<#ty as datasize::DataSize>));
315            static_heap_size.extend(quote!(::STATIC_HEAP_SIZE));
316        } else {
317            static_heap_size.extend(quote!(0));
318        };
319
320        let handle = if let Some(ref ident) = &field.ident {
321            quote!(#ident)
322        } else {
323            let idx = Index::from(idx);
324            quote!(#idx)
325        };
326
327        let name = if let Some(ref ident) = &field.ident {
328            ident.to_string()
329        } else {
330            "idx".to_string()
331        };
332
333        match field_attrs.with {
334            Some(manual) => {
335                dynamic_size.extend(quote!(
336                    #manual(&self.#handle)
337                ));
338
339                detail_calls.extend(quote!(
340                    members.insert(#name, datasize::MemUsageNode::Size(#manual(&self.#handle)));
341                ));
342            }
343            None => {
344                dynamic_size.extend(quote!(
345                    datasize::data_size::<#ty>(&self.#handle)
346                ));
347
348                detail_calls.extend(quote!(
349                    members.insert(#name, self.#handle.estimate_detailed_heap_size());
350                ));
351            }
352        }
353    }
354
355    // Handle structs with no fields.
356    if is_dynamic.is_empty() {
357        is_dynamic.extend(quote!(false));
358    }
359    if static_heap_size.is_empty() {
360        static_heap_size.extend(quote!(0));
361    }
362    if dynamic_size.is_empty() {
363        dynamic_size.extend(quote!(0));
364    }
365
366    // Ensure that any `where` clause on the struct itself is preserved, otherwise the impl is
367    // invalid.
368    if let Some(WhereClause { ref predicates, .. }) = generics.where_clause {
369        where_clauses.extend(quote!(#predicates));
370    }
371
372    let detailed_impl = if cfg!(feature = "detailed") {
373        quote!(
374            fn estimate_detailed_heap_size(&self) -> datasize::MemUsageNode {
375                let mut members = ::std::collections::HashMap::new();
376                #detail_calls
377                datasize::MemUsageNode::Detailed(members)
378            }
379        )
380    } else {
381        quote!()
382    };
383
384    // If we found at least one manual field, ensure we recalculate heap size always.
385    if has_manual_field {
386        is_dynamic = proc_macro2::TokenStream::new();
387        is_dynamic.extend(quote!(true));
388    }
389
390    TokenStream::from(quote! {
391        impl #generics datasize::DataSize for #name #generics #where_clauses {
392            const IS_DYNAMIC: bool = #is_dynamic;
393            const STATIC_HEAP_SIZE: usize = #static_heap_size;
394
395            fn estimate_heap_size(&self) -> usize {
396                #dynamic_size
397            }
398
399            #detailed_impl
400        }
401    })
402}
403
404/// Derives `DataSize` for an `enum`
405fn derive_for_enum(name: Ident, generics: Generics, de: DataEnum) -> TokenStream {
406    let mut match_arms = proc_macro2::TokenStream::new();
407    let mut where_types = proc_macro2::TokenStream::new();
408
409    let mut skipped = false;
410    for variant in de.variants.into_iter() {
411        let ds_attrs = DataSizeAttributes::parse(&variant.attrs);
412
413        if ds_attrs.skip {
414            skipped = true;
415            continue;
416        }
417
418        let variant_ident = variant.ident;
419
420        let mut field_match = proc_macro2::TokenStream::new();
421        let mut field_calc = proc_macro2::TokenStream::new();
422
423        match variant.fields {
424            syn::Fields::Named(fields) => {
425                let mut left = proc_macro2::TokenStream::new();
426
427                for field in fields.named.into_iter() {
428                    let ident = field.ident.expect("named fields must have idents");
429                    let ds_attrs = DataSizeAttributes::parse(&field.attrs);
430
431                    if ds_attrs.skip {
432                        left.extend(quote!(#ident:_));
433                    } else {
434                        left.extend(quote!(#ident ,));
435
436                        let ty = field.ty;
437                        if contains_generic(&generics, &ty) {
438                            where_types.extend(quote!(#ty : datasize::DataSize,));
439                        }
440                    }
441
442                    if !ds_attrs.skip {
443                        if !field_calc.is_empty() {
444                            field_calc.extend(quote!(+));
445                        }
446                        field_calc.extend(quote!(DataSize::estimate_heap_size(#ident)));
447                    }
448                }
449
450                field_match.extend(quote! {
451                    {#left}
452                });
453            }
454            syn::Fields::Unnamed(fields) => {
455                let mut left = proc_macro2::TokenStream::new();
456
457                for (idx, field) in fields.unnamed.into_iter().enumerate() {
458                    let field_ds_attrs = DataSizeAttributes::parse(&field.attrs);
459
460                    let ident = Ident::new(
461                        &format!("{}f{}", if field_ds_attrs.skip { "_" } else { "" }, idx),
462                        proc_macro2::Span::call_site(),
463                    );
464
465                    left.extend(quote!(#ident ,));
466
467                    if !field_ds_attrs.skip {
468                        if !field_calc.is_empty() {
469                            field_calc.extend(quote!(+));
470                        }
471                        field_calc.extend(quote!(DataSize::estimate_heap_size(#ident)));
472
473                        let ty = field.ty;
474                        where_types.extend(quote!(#ty : datasize::DataSize,));
475                    }
476                }
477
478                field_match.extend(quote! {
479                    (#left)
480                });
481            }
482            syn::Fields::Unit => {
483                field_calc.extend(quote!(0));
484            }
485        }
486
487        if field_calc.is_empty() {
488            field_calc.extend(quote!(0));
489        }
490
491        match_arms.extend(quote!(
492            #name::#variant_ident #field_match => { #field_calc }
493        ));
494    }
495
496    // If we skipped any variant, add a fallback.
497    if skipped {
498        match_arms.extend(quote! {
499            _ => 0,
500        })
501    }
502
503    let mut where_clause = proc_macro2::TokenStream::new();
504    if !where_types.is_empty() {
505        where_clause.extend(quote!(where #where_types));
506    }
507
508    // TODO: Accurately determine `IS_DYNAMIC` and `STATIC_HEAP_SIZE`.
509    //
510    //       It is possible to accurately pre-calculate these, but it takes a bit of extra
511    //       effort. `IS_DYNAMIC` depends on none of the variants (and their fields) being
512    //       being dynamic, as well as all having the same `STATIC_HEAP_SIZE`.
513    //
514    //       `STATIC_HEAP_SIZE` in turn is the minimum of `STATIC_HEAP_SIZE` of every
515    //       variant (which are the sum of their fields). `min` can be determined by the
516    //       `datasize::min` function, which is a `const fn` variant of `min`.
517    let mut is_dynamic = true;
518    let static_heap_size = 0usize;
519
520    // Handle enums with no fields.
521    if match_arms.is_empty() {
522        match_arms.extend(quote!(_ => 0));
523        is_dynamic = false;
524    }
525
526    // Ensure that any `where` clause on the struct enum is preserved.
527    if let Some(WhereClause { ref predicates, .. }) = generics.where_clause {
528        where_clause.extend(quote!(#predicates));
529    }
530
531    TokenStream::from(quote! {
532        impl #generics DataSize for #name #generics #where_clause {
533
534            const IS_DYNAMIC: bool = #is_dynamic;
535            const STATIC_HEAP_SIZE: usize = #static_heap_size;
536
537            #[inline]
538            fn estimate_heap_size(&self) -> usize {
539                match self {
540                    #match_arms
541                }
542            }
543        }
544    })
545}