mem_dbg_derive/
lib.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
3 * SPDX-FileCopyrightText: 2023 Inria
4 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
5 *
6 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
7 */
8
9//! Derive procedural macros for the [`mem_dbg`](https://crates.io/crates/mem_dbg) crate.
10
11use proc_macro::TokenStream;
12use quote::{quote, ToTokens};
13use syn::{
14    parse_macro_input, parse_quote, parse_quote_spanned, spanned::Spanned, Data, DeriveInput,
15};
16
17/**
18
19Generate a `mem_dbg::MemSize` implementation for custom types.
20
21Presently we do not support unions.
22
23The attribute `copy_type` can be used on [`Copy`] types that do not contain non-`'static` references
24to make `MemSize::mem_size` faster on arrays, vectors and slices. Note that specifying
25`copy_type` will add the bound that the type is `Copy + 'static`.
26
27See `mem_dbg::CopyType` for more details.
28
29*/
30#[proc_macro_derive(MemSize, attributes(copy_type))]
31pub fn mem_dbg_mem_size(input: TokenStream) -> TokenStream {
32    let mut input = parse_macro_input!(input as DeriveInput);
33
34    let input_ident = input.ident;
35    input.generics.make_where_clause();
36    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
37    let mut where_clause = where_clause.unwrap().clone(); // We just created it
38
39    let is_copy_type = input
40        .attrs
41        .iter()
42        .any(|x| x.meta.path().is_ident("copy_type"));
43
44    // If copy_type, add the Copy + 'static bound
45    let copy_type: syn::Expr = if is_copy_type {
46        where_clause
47            .predicates
48            .push(parse_quote_spanned!(input_ident.span()=> Self: Copy + 'static));
49        parse_quote!(mem_dbg::True)
50    } else {
51        parse_quote!(mem_dbg::False)
52    };
53
54    match input.data {
55        Data::Struct(s) => {
56            let mut fields_ident = vec![];
57            let mut fields_ty = vec![];
58
59            for (field_idx, field) in s.fields.iter().enumerate() {
60                fields_ident.push(
61                    field
62                        .ident
63                        .to_owned()
64                        .map(|t| t.to_token_stream())
65                        .unwrap_or(syn::Index::from(field_idx).to_token_stream()),
66                );
67                fields_ty.push(field.ty.to_token_stream());
68                let field_ty = &field.ty;
69                // Add MemSize bound to all fields
70                where_clause
71                    .predicates
72                    .push(parse_quote_spanned!(field.span()=> #field_ty: mem_dbg::MemSize));
73            }
74            quote! {
75                #[automatically_derived]
76                impl #impl_generics mem_dbg::CopyType for #input_ident #ty_generics #where_clause
77                {
78                    type Copy = #copy_type;
79                }
80
81                #[automatically_derived]
82                impl #impl_generics mem_dbg::MemSize for #input_ident #ty_generics #where_clause {
83                    fn mem_size(&self, _memsize_flags: mem_dbg::SizeFlags) -> usize {
84                        let mut bytes = core::mem::size_of::<Self>();
85                        #(bytes += <#fields_ty as mem_dbg::MemSize>::mem_size(&self.#fields_ident, _memsize_flags) - core::mem::size_of::<#fields_ty>();)*
86                        bytes
87                    }
88                }
89            }
90        }
91
92        Data::Enum(e) => {
93            let mut variants = Vec::new();
94            let mut variants_size = Vec::new();
95
96            for variant in e.variants {
97                let mut res = variant.ident.to_owned().to_token_stream();
98                let mut var_args_size = quote! {core::mem::size_of::<Self>()};
99                match &variant.fields {
100                    syn::Fields::Unit => {}
101                    syn::Fields::Named(fields) => {
102                        let mut args = proc_macro2::TokenStream::new();
103                        for field in &fields.named {
104                            let field_ty = &field.ty;
105                            where_clause
106                                .predicates
107                                .push(parse_quote_spanned!(field.span() => #field_ty: mem_dbg::MemSize));
108                                let field_ident = &field.ident;
109                                let field_ty = field.ty.to_token_stream();
110                                var_args_size.extend([quote! {
111                                    + <#field_ty as mem_dbg::MemSize>::mem_size(#field_ident, _memsize_flags) - core::mem::size_of::<#field_ty>()
112                                }]);
113                                args.extend([field_ident.to_token_stream()]);
114                                args.extend([quote! {,}]);
115                            }
116                        // extend res with the args sourrounded by curly braces
117                        res.extend(quote! {
118                            { #args }
119                        });
120                    }
121                    syn::Fields::Unnamed(fields) => {
122                        let mut args = proc_macro2::TokenStream::new();
123
124                        for (field_idx, field) in fields.unnamed.iter().enumerate() {
125                            let ident = syn::Ident::new(
126                                &format!("v{}", field_idx),
127                                proc_macro2::Span::call_site(),
128                            )
129                            .to_token_stream();
130                            let field_ty = field.ty.to_token_stream();
131                            var_args_size.extend([quote! {
132                                + <#field_ty as mem_dbg::MemSize>::mem_size(#ident, _memsize_flags) - core::mem::size_of::<#field_ty>()
133                            }]);
134                            args.extend([ident]);
135                            args.extend([quote! {,}]);
136
137                            where_clause
138                                .predicates
139                                .push(parse_quote_spanned!(field.span()=> #field_ty: mem_dbg::MemSize));
140                        }
141                        // extend res with the args sourrounded by curly braces
142                        res.extend(quote! {
143                            ( #args )
144                        });
145                    }
146                }
147                variants.push(res);
148                variants_size.push(var_args_size);
149            }
150
151            quote! {
152                #[automatically_derived]
153                impl #impl_generics mem_dbg::CopyType for #input_ident #ty_generics #where_clause
154                {
155                    type Copy = #copy_type;
156                }
157
158                #[automatically_derived]
159                impl #impl_generics mem_dbg::MemSize for #input_ident #ty_generics #where_clause {
160                    fn mem_size(&self, _memsize_flags: mem_dbg::SizeFlags) -> usize {
161                        match self {
162                            #(
163                               #input_ident::#variants => #variants_size,
164                            )*
165                        }
166                    }
167                }
168            }
169        }
170
171        Data::Union(u) => {
172            // We only support single-field unions completely or call to mem_size on unions
173            // without the FOLLOW_REFS flag, as we cannot know programmatically
174            // which field is initialized.
175
176            let fields = u.fields.named.iter().collect::<Vec<_>>();
177
178            match fields.len() {
179                0 => unreachable!("Empty unions are not supported by the Rust programming language."),
180                1 => {
181                    let field = fields[0];
182                    let field_ty = &field.ty;
183                    let ident = field.ident.as_ref().unwrap();
184                    where_clause
185                        .predicates
186                        .push(parse_quote_spanned!(field.span() => #field_ty: mem_dbg::MemSize));
187                    quote! {
188                        #[automatically_derived]
189                        impl #impl_generics mem_dbg::CopyType for #input_ident #ty_generics #where_clause
190                        {
191                            type Copy = #copy_type;
192                        }
193
194                        #[automatically_derived]
195                        impl #impl_generics mem_dbg::MemSize for #input_ident #ty_generics #where_clause {
196                            fn mem_size(&self, _memsize_flags: mem_dbg::SizeFlags) -> usize {
197                                unsafe{<#field_ty as mem_dbg::MemSize>::mem_size(&self.#ident, _memsize_flags)}
198                            }
199                        }
200                    }
201                }
202                number_of_fields => unimplemented!(
203                    "mem_dbg::MemSize for unions with more than one field ({}) is not supported.", number_of_fields
204                )
205            }
206        }
207    }.into()
208}
209
210/**
211
212Generate a `mem_dbg::MemDbg` implementation for custom types.
213
214Presently we do not support unions.
215
216*/
217#[proc_macro_derive(MemDbg)]
218pub fn mem_dbg_mem_dbg(input: TokenStream) -> TokenStream {
219    let mut input = parse_macro_input!(input as DeriveInput);
220
221    let input_ident = input.ident;
222    input.generics.make_where_clause();
223    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
224    let mut where_clause = where_clause.unwrap().clone(); // We just created it
225
226    match input.data {
227        Data::Struct(s) => {
228            let mut id_offset_pushes = vec![];
229            let mut match_code = vec![];
230
231            for (field_idx, field) in s.fields.iter().enumerate() {
232                // Use the field name for named structures, and the index
233                // for tuple structures
234                let field_ident = field
235                    .ident
236                    .to_owned()
237                    .map(|t| t.to_token_stream())
238                    .unwrap_or_else(|| syn::Index::from(field_idx).to_token_stream());
239
240                let field_ident_str = field
241                    .ident
242                    .to_owned()
243                    .map(|t| t.to_string().to_token_stream())
244                    .unwrap_or_else(|| field_idx.to_string().to_token_stream());
245
246                let field_ty = &field.ty;
247                where_clause
248                    .predicates
249                    .push(parse_quote_spanned!(field.span() => #field_ty: mem_dbg::MemDbgImpl));
250
251                // We push the field index and its offset
252                id_offset_pushes.push(quote!{
253                    id_sizes.push((#field_idx, core::mem::offset_of!(#input_ident #ty_generics, #field_ident)));
254                });
255                // This is the arm of the match statement that invokes
256                // _mem_dbg_depth_on on the field.
257                match_code.push(quote!{
258                    #field_idx => <#field_ty as mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(&self.#field_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags)?,
259                });
260            }
261
262            quote! {
263                #[automatically_derived]
264                impl #impl_generics mem_dbg::MemDbgImpl for #input_ident #ty_generics #where_clause {
265                    #[inline(always)]
266                    fn _mem_dbg_rec_on(
267                        &self,
268                        _memdbg_writer: &mut impl core::fmt::Write,
269                        _memdbg_total_size: usize,
270                        _memdbg_max_depth: usize,
271                        _memdbg_prefix: &mut String,
272                        _memdbg_is_last: bool,
273                        _memdbg_flags: mem_dbg::DbgFlags,
274                    ) -> core::fmt::Result {
275                        let mut id_sizes: Vec<(usize, usize)> = vec![];
276                        #(#id_offset_pushes)*
277                        let n = id_sizes.len();
278                        id_sizes.push((n, core::mem::size_of::<Self>()));
279                        // Sort by offset
280                        id_sizes.sort_by_key(|x| x.1);
281                        // Compute padded sizes
282                        for i in 0..n {
283                            id_sizes[i].1 = id_sizes[i + 1].1 - id_sizes[i].1;
284                        };
285                        // Put the candle back unless the user requested otherwise
286                        if ! _memdbg_flags.contains(mem_dbg::DbgFlags::RUST_LAYOUT) {
287                            id_sizes.sort_by_key(|x| x.0);
288                        }
289
290                        for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
291                            match field_idx {
292                                #(#match_code)*
293                                _ => unreachable!(),
294                            }
295                        }
296                        Ok(())
297                    }
298                }
299            }
300        }
301
302        Data::Enum(e) => {
303            let mut variants = Vec::new();
304            let mut variants_code = Vec::new();
305
306            for variant in &e.variants {
307                let variant_ident = &variant.ident;
308                let mut res = variant.ident.to_owned().to_token_stream();
309                // Depending on the presence of the feature offset_of_enum, this
310                // will contains field indices and offset_of or field indices
311                // and size_of; in the latter case, we will assume size_of to be
312                // the padded size, resulting in no padding.
313                let mut id_offset_pushes = vec![];
314                let mut match_code = vec![];
315                let mut arrow = '╰';
316                match &variant.fields {
317                    syn::Fields::Unit => {},
318                    syn::Fields::Named(fields) => {
319                        let mut args = proc_macro2::TokenStream::new();
320                        if !fields.named.is_empty() {
321                            arrow = '├';
322                        }
323                        for (field_idx, field) in fields.named.iter().enumerate() {
324                            let field_ty = &field.ty;
325                            let field_ident = field.ident.as_ref().unwrap();
326                            let field_ident_str = format!("{}", field_ident);
327
328                            #[cfg(feature = "offset_of_enum")]
329                            id_offset_pushes.push(quote!{
330                                // We push the offset of the field, which will
331                                // be used to compute the padded size.
332                                id_sizes.push((#field_idx, core::mem::offset_of!(#input_ident #ty_generics, #variant_ident . #field_ident)));
333                            });
334                            #[cfg(not(feature = "offset_of_enum"))]
335                            id_offset_pushes.push(quote!{
336                                id_sizes.push((#field_idx, core::mem::offset_of!(#input_ident #ty_generics, #variant_ident . #field_ident)));
337                                // We push the size of the field, which will be
338                                // used as a surrogate of the padded size.
339                                id_sizes.push((#field_idx, std::mem::size_of_val(#field_ident)));
340                            });
341
342                            // This is the arm of the match statement that
343                            // invokes _mem_dbg_depth_on on the field.
344                            match_code.push(quote! {
345                                #field_idx => <#field_ty as mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(#field_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags)?,
346                            });
347                            args.extend([field_ident.to_token_stream()]);
348                            args.extend([quote! {,}]);
349
350                            let field_ty = &field.ty;
351                            where_clause
352                                .predicates
353                                .push(parse_quote_spanned!(field.span()=> #field_ty: mem_dbg::MemDbgImpl));
354                        }
355                        // extend res with the args sourrounded by curly braces
356                        res.extend(quote! {
357                            // TODO: sanitize somehow the names or it'll be
358                            // havoc.
359                            { #args }
360                        });
361                    }
362                    syn::Fields::Unnamed(fields) => {
363                        let mut args = proc_macro2::TokenStream::new();
364                        if !fields.unnamed.is_empty() {
365                            arrow = '├';
366                        }
367                        for (field_idx, field) in fields.unnamed.iter().enumerate() {
368                            let field_ident = syn::Ident::new(
369                                &format!("v{}", field_idx),
370                                proc_macro2::Span::call_site(),
371                            )
372                            .to_token_stream();
373                            let field_ty = &field.ty;
374                            let field_ident_str = format!("{}", field_idx);
375                            let _field_tuple_idx = syn::Index::from(field_idx);
376
377                            #[cfg(feature = "offset_of_enum")]
378                            id_offset_pushes.push(quote!{
379                                // We push the offset of the field, which will
380                                // be used to compute the padded size.
381                                id_sizes.push((#field_idx, core::mem::offset_of!(#input_ident #ty_generics, #variant_ident . #_field_tuple_idx)));
382                            });
383
384                            #[cfg(not(feature = "offset_of_enum"))]
385                            id_offset_pushes.push(quote!{
386                                // We push the size of the field, which will be
387                                // used as a surrogate of the padded size.
388                                id_sizes.push((#field_idx, std::mem::size_of_val(#field_ident)));
389                            });
390
391                            // This is the arm of the match statement that
392                            // invokes _mem_dbg_depth_on on the field.
393                            match_code.push(quote! {
394                                #field_idx => <#field_ty as mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(#field_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags)?,
395                            });
396
397                            args.extend([field_ident]);
398                            args.extend([quote! {,}]);
399
400                            let field_ty = &field.ty;
401                            where_clause
402                                .predicates
403                                .push(parse_quote_spanned!(field.span()=> #field_ty: mem_dbg::MemDbgImpl));
404                        }
405                        // extend res with the args sourrounded by curly braces
406                        res.extend(quote! {
407                            ( #args )
408                        });
409                    }
410                }
411                variants.push(res);
412                let variant_name = format!("Variant: {}\n", variant.ident);
413                variants_code.push(quote!{{
414                    _memdbg_writer.write_char(#arrow)?;
415                    _memdbg_writer.write_char('╴')?;
416                    _memdbg_writer.write_str(#variant_name)?;
417                }});
418
419                // There's some abundant code duplication here, but we need to
420                // keep the #[cfg] attributes outside of the quote! macro.
421
422                #[cfg(feature = "offset_of_enum")]
423                variants_code.push(quote!{{
424                    let mut id_sizes: Vec<(usize, usize)> = vec![];
425                    #(#id_offset_pushes)*
426                    let n = id_sizes.len();
427
428                    // We use the offset_of information to build the real
429                    // space occupied by a field.
430                    id_sizes.push((n, core::mem::size_of::<Self>()));
431                    // Sort by offset
432                    id_sizes.sort_by_key(|x| x.1);
433                    // Compute padded sizes
434                    for i in 0..n {
435                        id_sizes[i].1 = id_sizes[i + 1].1 - id_sizes[i].1;
436                    };
437                    // Put the candle back unless the user requested otherwise
438                    if ! _memdbg_flags.contains(mem_dbg::DbgFlags::RUST_LAYOUT) {
439                        id_sizes.sort_by_key(|x| x.0);
440                    }
441
442                    for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
443                        match field_idx {
444                            #(#match_code)*
445                            _ => unreachable!(),
446                        }
447                    }
448                }});
449
450                #[cfg(not(feature = "offset_of_enum"))]
451                variants_code.push(quote!{{
452                    let mut id_sizes: Vec<(usize, usize)> = vec![];
453                    #(#id_offset_pushes)*
454                    let n = id_sizes.len();
455
456                    // Lacking offset_of for enums, id_sizes contains the
457                    // size_of of each field which we use as a surrogate of
458                    // the padded size.
459                    assert!(!_memdbg_flags.contains(mem_dbg::DbgFlags::RUST_LAYOUT), "DbgFlags::RUST_LAYOUT for enums requires the offset_of_enum feature");
460
461                    for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
462                        match field_idx {
463                            #(#match_code)*
464                            _ => unreachable!(),
465                        }
466                    }
467                }});
468            }
469
470            quote! {
471                #[automatically_derived]
472                impl #impl_generics mem_dbg::MemDbgImpl  for #input_ident #ty_generics #where_clause {
473                    #[inline(always)]
474                    fn _mem_dbg_rec_on(
475                        &self,
476                        _memdbg_writer: &mut impl core::fmt::Write,
477                        _memdbg_total_size: usize,
478                        _memdbg_max_depth: usize,
479                        _memdbg_prefix: &mut String,
480                        _memdbg_is_last: bool,
481                        _memdbg_flags: mem_dbg::DbgFlags,
482                    ) -> core::fmt::Result {
483                        let mut _memdbg_digits_number = mem_dbg::n_of_digits(_memdbg_total_size);
484                        if _memdbg_flags.contains(mem_dbg::DbgFlags::SEPARATOR) {
485                            _memdbg_digits_number += _memdbg_digits_number / 3;
486                        }
487                        if _memdbg_flags.contains(mem_dbg::DbgFlags::HUMANIZE) {
488                            _memdbg_digits_number = 6;
489                        }
490
491                        if _memdbg_flags.contains(mem_dbg::DbgFlags::PERCENTAGE) {
492                            _memdbg_digits_number += 8;
493                        }
494
495                        for _ in 0.._memdbg_digits_number + 3 {
496                            _memdbg_writer.write_char(' ')?;
497                        }
498                        if !_memdbg_prefix.is_empty() {
499                            _memdbg_writer.write_str(&_memdbg_prefix[2..])?;
500                        }
501                        match self {
502                            #(
503                               #input_ident::#variants => #variants_code,
504                            )*
505                        }
506                        Ok(())
507                   }
508                }
509            }
510        }
511
512        Data::Union(u) => {
513            // We only support single-field unions for the MemDbg.
514
515            let fields = u.fields.named.iter().collect::<Vec<_>>();
516
517            match fields.len() {
518                0 => unreachable!("Empty unions are not supported by the Rust programming language."),
519                1 => {
520                    let field = fields[0];
521                    let field_ty = &field.ty;
522                    let ident = field.ident.as_ref().unwrap();
523                    where_clause
524                        .predicates
525                        .push(parse_quote_spanned!(field.span() => #field_ty: mem_dbg::MemDbgImpl));
526                    quote! {
527                        #[automatically_derived]
528                        impl #impl_generics mem_dbg::MemDbgImpl for #input_ident #ty_generics #where_clause {
529                            #[inline(always)]
530                            fn _mem_dbg_rec_on(
531                                &self,
532                                _memdbg_writer: &mut impl core::fmt::Write,
533                                _memdbg_total_size: usize,
534                                _memdbg_max_depth: usize,
535                                _memdbg_prefix: &mut String,
536                                _memdbg_is_last: bool,
537                                _memdbg_flags: mem_dbg::DbgFlags,
538                            ) -> core::fmt::Result {
539                                unsafe{<#field_ty as mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(&self.#ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, None, _memdbg_is_last, core::mem::size_of::<#field_ty>(), _memdbg_flags)}
540                            }
541                        }
542                    }
543                }
544                _ => unimplemented!(
545                    "mem_dbg::MemDbg for unions with more than one field is not supported."
546                )
547            }
548        }
549    }.into()
550}