loupe_derive/
lib.rs

//! Companion of the [`loupe`](../loupe/index.html) crate.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6    parse, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, Index,
7};
8
9/// Procedural macro to implement the `loupe::MemoryUsage` trait
10/// automatically for structs and enums.
11///
12/// All struct fields and enum variants must implement `MemoryUsage`
13/// trait. If it's not possible, the `#[loupe(skip)]` attribute can be
14/// used on a field or a variant to instruct the derive procedural
15/// macro to skip that item.
16///
17/// # Example
18///
19/// ```rust,ignore
20/// #[derive(MemoryUsage)]
21/// struct Point {
22///     x: i32,
23///     y: i32,
24/// }
25///
26/// struct Mystery { ptr: *const i32 }
27///
28/// #[derive(MemoryUsage)]
29/// struct S {
30///     points: Vec<Point>,
31///
32///     #[loupe(skip)]
33///     other: Mystery,
34/// }
35/// ```
36#[proc_macro_derive(MemoryUsage, attributes(loupe))]
37pub fn derive_memory_usage(input: TokenStream) -> TokenStream {
38    let derive_input: DeriveInput = parse(input).unwrap();
39
40    match derive_input.data {
41        Data::Struct(ref struct_data) => {
42            derive_memory_usage_for_struct(&derive_input.ident, struct_data, &derive_input.generics)
43        }
44
45        Data::Enum(ref enum_data) => {
46            derive_memory_usage_for_enum(&derive_input.ident, enum_data, &derive_input.generics)
47        }
48
49        Data::Union(_) => panic!("unions are not yet implemented"),
50        /*
51        // TODO: unions.
52        // We have no way of knowing which union member is active, so we should
53        // refuse to derive an impl except for unions where all members are
54        // primitive types or arrays of them.
55        Data::Union(ref union_data) => {
56            derive_memory_usage_union(union_data)
57        },
58        */
59    }
60}
61
/// Left-folds `iter` with `function`, seeding the fold with the first
/// element, and returns `empty` when `iter` yields nothing.
///
/// This is used to join token streams with a separator (e.g. `a + b + c`
/// or `a , b`) without a leading or trailing separator. It is exactly
/// `Iterator::reduce` (the stabilized, renamed `fold_first`) with a
/// fallback value for the empty case.
fn join_fold<I, F, B>(iter: I, function: F, empty: B) -> B
where
    I: Iterator<Item = B>,
    F: FnMut(B, I::Item) -> B,
{
    iter.reduce(function).unwrap_or(empty)
}
74
75fn derive_memory_usage_for_struct(
76    struct_name: &Ident,
77    data: &DataStruct,
78    generics: &Generics,
79) -> TokenStream {
80    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81
82    let sum = join_fold(
83        // Check all fields of the `struct`.
84        match &data.fields {
85            // Field has the form:
86            //
87            //     F { x, y }
88            Fields::Named(ref fields) => fields
89                .named
90                .iter()
91                .filter_map(|field| {
92                    if must_skip(&field.attrs) {
93                        return None;
94                    }
95
96                    let ident = field.ident.as_ref().unwrap();
97                    let span = ident.span();
98
99                    Some(quote_spanned!(
100                        span => loupe::MemoryUsage::size_of_val(&self.#ident, visited) - std::mem::size_of_val(&self.#ident)
101                    ))
102                })
103                .collect(),
104
105            // Field has the form:
106            //
107            //     F
108            Fields::Unit => vec![],
109
110            // Field has the form:
111            //
112            //     F(x, y)
113            Fields::Unnamed(ref fields) => fields
114                .unnamed
115                .iter()
116                .enumerate()
117                .filter_map(|(nth, field)| {
118                    if must_skip(&field.attrs) {
119                        return None;
120                    }
121
122                    let ident = Index::from(nth);
123
124                    Some(quote! { loupe::MemoryUsage::size_of_val(&self.#ident, visited) - std::mem::size_of_val(&self.#ident) })
125                })
126                .collect(),
127        }
128        .into_iter(),
129        |x, y| quote! { #x + #y },
130        quote! { 0 },
131    );
132
133    // Implement the `MemoryUsage` trait for `struct_name`.
134    (quote! {
135        #[allow(dead_code)]
136        impl #impl_generics loupe::MemoryUsage for #struct_name #ty_generics
137        #where_clause
138        {
139            fn size_of_val(&self, visited: &mut loupe::MemoryUsageTracker) -> usize {
140                std::mem::size_of_val(self) + #sum
141            }
142        }
143    })
144    .into()
145}
146
147fn derive_memory_usage_for_enum(
148    enum_name: &Ident,
149    data: &DataEnum,
150    generics: &Generics,
151) -> TokenStream {
152    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
153
154    let match_arms = join_fold(
155        data.variants
156            .iter()
157            .map(|variant| {
158                let ident = &variant.ident;
159                let span = ident.span();
160
161                // Check all the variants of the `enum`.
162                //
163                // We want to generate something like this:
164                //
165                //     Self::Variant ... => { ... }
166                //           ^^^^^^^ ^^^      ^^^
167                //           |       |        |
168                //           |       |        given by the `sum` variable
169                //           |       given by the `pattern` variable
170                //           given by the `ident` variable
171                //
172                // Let's compute the `pattern` and `sum` parts.
173                let (pattern, mut sum) = match variant.fields {
174                    // Variant has the form:
175                    //
176                    //     V { x, y }
177                    //
178                    // We want to generate:
179                    //
180                    //     Self::V { x, y } => { /* memory usage of x + y */ }
181                    Fields::Named(ref fields) => {
182                        // Collect the identifiers.
183                        let identifiers = fields.named.iter().map(|field| {
184                            let ident = field.ident.as_ref().unwrap();
185                            let span = ident.span();
186
187                            quote_spanned!(span => #ident)
188                        });
189
190                        // Generate the `pattern` part.
191                        let pattern = {
192                            let pattern = join_fold(
193                                identifiers.clone(),
194                                |x, y| quote! { #x , #y },
195                                quote! {}
196                            );
197
198                            quote! { { #pattern } }
199                        };
200
201                        // Generate the `sum` part.
202                        let sum = {
203                            let sum = join_fold(
204                                identifiers.map(|ident| quote! {
205                                    loupe::MemoryUsage::size_of_val(#ident, visited) - std::mem::size_of_val(#ident)
206                                }),
207                                |x, y| quote! { #x + #y },
208                                quote! { 0 },
209                            );
210
211                            quote! { #sum }
212                        };
213
214                        (pattern, sum)
215                    }
216
217                    // Variant has the form:
218                    //
219                    //     V
220                    //
221                    // We want to generate:
222                    //
223                    //     Self::V => { 0 }
224                    Fields::Unit => {
225                        let pattern = quote! {};
226                        let sum = quote! { 0 };
227
228                        (pattern, sum)
229                    },
230
231                    // Variant has the form:
232                    //
233                    //     V(x, y)
234                    //
235                    // We want to generate:
236                    //
237                    //     Self::V(x, y) => { /* memory usage of x + y */ }
238                    Fields::Unnamed(ref fields) => {
239                        // Collect the identifiers. They are unnamed,
240                        // so let's use the `xi` convention where `i`
241                        // is the identifier index.
242                        let identifiers = fields
243                            .unnamed
244                            .iter()
245                            .enumerate()
246                            .map(|(nth, _field)| {
247                                let ident = format_ident!("x{}", Index::from(nth));
248
249                                quote! { #ident }
250                            });
251
252                        // Generate the `pattern` part.
253                        let pattern = {
254                            let pattern = join_fold(
255                                identifiers.clone(),
256                                |x, y| quote! { #x , #y },
257                                quote! {}
258                            );
259
260                            quote! { ( #pattern ) }
261                        };
262
263                        // Generate the `sum` part.
264                        let sum = {
265                            let sum = join_fold(
266                                identifiers.map(|ident| quote! {
267                                    loupe::MemoryUsage::size_of_val(#ident, visited) - std::mem::size_of_val(#ident)
268                                }),
269                                |x, y| quote! { #x + #y },
270                                quote! { 0 },
271                            );
272
273                            quote! { #sum }
274                        };
275
276                        (pattern, sum)
277                    }
278                };
279
280                if must_skip(&variant.attrs) {
281                    sum = quote! { 0 };
282                }
283
284                // At this step, `pattern` and `sum` are well
285                // defined. Let's generate the full arm for the
286                // `match` statement.
287                quote_spanned! { span => Self::#ident#pattern => #sum }
288            }
289        ),
290        |x, y| quote! { #x , #y },
291        quote! {},
292    );
293
294    // Implement the `MemoryUsage` trait for `enum_name`.
295    (quote! {
296        #[allow(dead_code)]
297        impl #impl_generics loupe::MemoryUsage for #enum_name #ty_generics
298        #where_clause
299        {
300            fn size_of_val(&self, visited: &mut loupe::MemoryUsageTracker) -> usize {
301                std::mem::size_of_val(self) + match self {
302                    #match_arms
303                }
304            }
305        }
306    })
307    .into()
308}
309
310fn must_skip(attrs: &[Attribute]) -> bool {
311    attrs.iter().any(|attr| {
312        attr.path.is_ident("loupe") && matches!(attr.parse_args::<Ident>(), Ok(a) if a == "skip")
313    })
314}