Skip to main content

mem_dbg_derive/
lib.rs

/*
 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
 * SPDX-FileCopyrightText: 2023 Inria
 * SPDX-FileCopyrightText: 2023 Sebastiano Vigna
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */

//! Derive procedural macros for the [`mem_dbg`](https://crates.io/crates/mem_dbg) crate.

11use proc_macro::TokenStream;
12use quote::{ToTokens, quote};
13use syn::{
14    Data, DeriveInput, parse_macro_input, parse_quote, parse_quote_spanned, spanned::Spanned,
15};
16
17/// Generate a `mem_dbg::MemSize` implementation for custom types.
18///
19/// Presently we do not support unions.
20///
21/// The attribute `mem_size_flat` can be used on flat types (typically [`Copy`]
22/// + `'static`) that do not contain non-`'static` references to make
23/// `MemSize::mem_size` faster on arrays, vectors, slices, and supported
24/// containers.
25///
26/// When all fields implement `FlatType<Flat=True>` but neither `#[mem_size_flat]`
27/// nor `#[mem_size_rec]` is present, a compile-time error is emitted. Use
28/// `#[mem_size_rec]` to explicitly silence this check when the type is
29/// intentionally not `#[mem_size_flat]`.
30///
31/// See `mem_dbg::FlatType` for more details.
32#[proc_macro_derive(MemSize, attributes(mem_size_flat, mem_size_rec))]
33pub fn mem_dbg_mem_size(input: TokenStream) -> TokenStream {
34    let mut input = parse_macro_input!(input as DeriveInput);
35
36    let input_ident = input.ident;
37    input.generics.make_where_clause();
38    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
39    let mut where_clause = where_clause.unwrap().clone(); // We just created it
40
41    let is_flat = input
42        .attrs
43        .iter()
44        .any(|x| x.meta.path().is_ident("mem_size_flat"));
45
46    let is_rec = input
47        .attrs
48        .iter()
49        .any(|x| x.meta.path().is_ident("mem_size_rec"));
50
51    if is_flat && is_rec {
52        return syn::Error::new_spanned(
53            &input_ident,
54            "cannot use both #[mem_size_flat] and #[mem_size_rec] on the same type",
55        )
56        .to_compile_error()
57        .into();
58    }
59
60    let flat_type: syn::Expr = if is_flat {
61        parse_quote!(::mem_dbg::True)
62    } else {
63        parse_quote!(::mem_dbg::False)
64    };
65
66    match input.data {
67        Data::Struct(s) => {
68            let mut fields_ident = vec![];
69            let mut fields_ty = vec![];
70
71            for (field_idx, field) in s.fields.iter().enumerate() {
72                fields_ident.push(
73                    field
74                        .ident
75                        .to_owned()
76                        .map(|t| t.to_token_stream())
77                        .unwrap_or(syn::Index::from(field_idx).to_token_stream()),
78                );
79                fields_ty.push(field.ty.to_token_stream());
80                let field_ty = &field.ty;
81                // Add MemSize and FlatType bounds to all fields
82                where_clause
83                    .predicates
84                    .push(parse_quote_spanned!(field.span()=> #field_ty: ::mem_dbg::MemSize + ::mem_dbg::FlatType));
85            }
86
87            let const_assert = if !is_flat && !is_rec {
88                let msg = format!(
89                    "Structure {} could be #[mem_size_flat], but it has not been declared as such; use either the #[mem_size_flat] or the #[mem_size_rec] attribute to silence this error",
90                    input_ident
91                );
92                quote! {
93                    const { assert!(
94                        !(true #(&& <<#fields_ty as ::mem_dbg::FlatType>::Flat
95                                      as ::mem_dbg::Boolean>::VALUE)*),
96                        #msg
97                    ); }
98                }
99            } else {
100                quote! {}
101            };
102
103            quote! {
104                #[automatically_derived]
105                impl #impl_generics ::mem_dbg::FlatType for #input_ident #ty_generics #where_clause
106                {
107                    type Flat = #flat_type;
108                }
109
110                #[automatically_derived]
111                impl #impl_generics ::mem_dbg::MemSize for #input_ident #ty_generics #where_clause {
112                    fn mem_size_rec(&self, _memsize_flags: ::mem_dbg::SizeFlags, _memsize_refs: &mut ::mem_dbg::HashMap<usize, usize>) -> usize {
113                        #const_assert
114                        let mut bytes = ::core::mem::size_of::<Self>();
115                        #(bytes += <#fields_ty as ::mem_dbg::MemSize>::mem_size_rec(&self.#fields_ident, _memsize_flags, _memsize_refs) - ::core::mem::size_of::<#fields_ty>();)*
116                        bytes
117                    }
118                }
119            }
120        }
121
122        Data::Enum(e) => {
123            let mut variants = Vec::new();
124            let mut variants_size = Vec::new();
125            let mut all_field_types = Vec::new();
126
127            for variant in e.variants {
128                let mut res = variant.ident.to_owned().to_token_stream();
129                let mut var_args_size = quote! {::core::mem::size_of::<Self>()};
130                match &variant.fields {
131                    syn::Fields::Unit => {}
132                    syn::Fields::Named(fields) => {
133                        let mut args = proc_macro2::TokenStream::new();
134                        for field in &fields.named {
135                            let field_ty = &field.ty;
136                            where_clause
137                                .predicates
138                                .push(parse_quote_spanned!(field.span() => #field_ty: ::mem_dbg::MemSize + ::mem_dbg::FlatType));
139                            if !is_flat && !is_rec {
140                                all_field_types.push(field.ty.to_token_stream());
141                            }
142                            let field_ident = field.ident.as_ref().unwrap();
143                            // Use a prefixed binding to avoid shadowing
144                            // generated locals.
145                            let binding_ident = syn::Ident::new(
146                                &format!("_memsize_{}", field_ident),
147                                field_ident.span(),
148                            );
149                            let field_ty = field.ty.to_token_stream();
150                            var_args_size.extend([quote! {
151                                + <#field_ty as ::mem_dbg::MemSize>::mem_size_rec(#binding_ident, _memsize_flags, _memsize_refs) - ::core::mem::size_of::<#field_ty>()
152                            }]);
153                            args.extend([quote! { #field_ident: #binding_ident, }]);
154                        }
155                        // Extend res with the args surrounded by curly braces
156                        res.extend(quote! {
157                            { #args }
158                        });
159                    }
160                    syn::Fields::Unnamed(fields) => {
161                        let mut args = proc_macro2::TokenStream::new();
162
163                        for (field_idx, field) in fields.unnamed.iter().enumerate() {
164                            let ident = syn::Ident::new(
165                                &format!("v{}", field_idx),
166                                proc_macro2::Span::call_site(),
167                            )
168                            .to_token_stream();
169                            let field_ty = field.ty.to_token_stream();
170                            var_args_size.extend([quote! {
171                                + <#field_ty as ::mem_dbg::MemSize>::mem_size_rec(#ident, _memsize_flags, _memsize_refs) - ::core::mem::size_of::<#field_ty>()
172                            }]);
173                            args.extend([ident]);
174                            args.extend([quote! {,}]);
175
176                            where_clause
177                                .predicates
178                                .push(parse_quote_spanned!(field.span()=> #field_ty: ::mem_dbg::MemSize + ::mem_dbg::FlatType));
179                            if !is_flat && !is_rec {
180                                all_field_types.push(field.ty.to_token_stream());
181                            }
182                        }
183                        // extend res with the args surrounded by curly braces
184                        res.extend(quote! {
185                            ( #args )
186                        });
187                    }
188                }
189                variants.push(res);
190                variants_size.push(var_args_size);
191            }
192
193            let const_assert = if !is_flat && !is_rec {
194                let msg = format!(
195                    "Enum {} could be #[mem_size_flat], but it has not been declared as such; use either the #[mem_size_flat] or the #[mem_size_rec] attribute to silence this error",
196                    input_ident
197                );
198                quote! {
199                    const { assert!(
200                        !(true #(&& <<#all_field_types as ::mem_dbg::FlatType>::Flat
201                                      as ::mem_dbg::Boolean>::VALUE)*),
202                        #msg
203                    ); }
204                }
205            } else {
206                quote! {}
207            };
208
209            quote! {
210                #[automatically_derived]
211                impl #impl_generics ::mem_dbg::FlatType for #input_ident #ty_generics #where_clause
212                {
213                    type Flat = #flat_type;
214                }
215
216                #[automatically_derived]
217                impl #impl_generics ::mem_dbg::MemSize for #input_ident #ty_generics #where_clause {
218                    fn mem_size_rec(&self, _memsize_flags: ::mem_dbg::SizeFlags, _memsize_refs: &mut ::mem_dbg::HashMap<usize, usize>) -> usize {
219                        #const_assert
220                        match self {
221                            #(
222                               #input_ident::#variants => #variants_size,
223                            )*
224                        }
225                    }
226                }
227            }
228        }
229
230        Data::Union(u) => {
231            return syn::Error::new_spanned(u.union_token, "MemSize for unions is not supported; see the Unions section in the README for a manual implementation pattern")
232                .to_compile_error()
233                .into();
234        }
235    }.into()
236}
237
238/// Generate a `mem_dbg::MemDbg` implementation for custom types.
239///
240/// Presently we do not support unions.
241#[proc_macro_derive(MemDbg)]
242pub fn mem_dbg_mem_dbg(input: TokenStream) -> TokenStream {
243    let mut input = parse_macro_input!(input as DeriveInput);
244
245    let input_ident = input.ident;
246    input.generics.make_where_clause();
247    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
248    let mut where_clause = where_clause.unwrap().clone(); // We just created it
249
250    match input.data {
251        Data::Struct(s) => {
252            let mut id_offset_pushes = vec![];
253            let mut match_code = vec![];
254
255            for (field_idx, field) in s.fields.iter().enumerate() {
256                // Use the field name for named structures, and the index
257                // for tuple structures
258                let field_ident = field
259                    .ident
260                    .to_owned()
261                    .map(|t| t.to_token_stream())
262                    .unwrap_or_else(|| syn::Index::from(field_idx).to_token_stream());
263
264                let field_ident_str = field
265                    .ident
266                    .to_owned()
267                    .map(|t| t.to_string().to_token_stream())
268                    .unwrap_or_else(|| field_idx.to_string().to_token_stream());
269
270                let field_ty = &field.ty;
271                where_clause
272                    .predicates
273                    .push(parse_quote_spanned!(field.span() => #field_ty: ::mem_dbg::MemDbgImpl + ::mem_dbg::FlatType));
274
275                // We push the field index and its offset
276                id_offset_pushes.push(quote!{
277                    id_sizes.push((#field_idx, ::core::mem::offset_of!(#input_ident #ty_generics, #field_ident)));
278                });
279                // This is the arm of the match statement that invokes
280                // _mem_dbg_depth_on on the field.
281                match_code.push(quote!{
282                    #field_idx => <#field_ty as ::mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(&self.#field_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags, _memdbg_refs)?,
283                });
284            }
285
286            quote! {
287                #[automatically_derived]
288                impl #impl_generics ::mem_dbg::MemDbgImpl for #input_ident #ty_generics #where_clause {
289                    fn _mem_dbg_rec_on(
290                        &self,
291                        _memdbg_writer: &mut impl ::core::fmt::Write,
292                        _memdbg_total_size: usize,
293                        _memdbg_max_depth: usize,
294                        _memdbg_prefix: &mut String,
295                        _memdbg_is_last: bool,
296                        _memdbg_flags: ::mem_dbg::DbgFlags,
297                        _memdbg_refs: &mut ::mem_dbg::HashSet<usize>,
298                    ) -> ::core::fmt::Result {
299                        let mut id_sizes: Vec<(usize, usize)> = vec![];
300                        #(#id_offset_pushes)*
301                        let n = id_sizes.len();
302                        id_sizes.push((n, ::core::mem::size_of::<Self>()));
303                        // Sort by offset
304                        id_sizes.sort_by_key(|x| x.1);
305                        // Compute padded sizes
306                        for i in 0..n {
307                            id_sizes[i].1 = id_sizes[i + 1].1 - id_sizes[i].1;
308                        };
309                        // Put the candle back unless the user requested otherwise
310                        if ! _memdbg_flags.contains(::mem_dbg::DbgFlags::RUST_LAYOUT) {
311                            id_sizes.sort_by_key(|x| x.0);
312                        }
313
314                        for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
315                            match field_idx {
316                                #(#match_code)*
317                                _ => unreachable!(),
318                            }
319                        }
320                        Ok(())
321                    }
322                }
323            }
324        }
325
326        Data::Enum(e) => {
327            let mut variants = Vec::new();
328            let mut variants_code = Vec::new();
329
330            for variant in &e.variants {
331                let mut res = variant.ident.to_owned().to_token_stream();
332                // Depending on the presence of the feature offset_of_enum, this
333                // will contains field indices and offset_of or field indices
334                // and size_of; in the latter case, we will assume size_of to be
335                // the padded size, resulting in no padding.
336                let mut id_offset_pushes = vec![];
337                let mut match_code = vec![];
338                let mut arrow = '╰';
339                match &variant.fields {
340                    syn::Fields::Unit => {},
341                    syn::Fields::Named(fields) => {
342                        let mut args = proc_macro2::TokenStream::new();
343                        if !fields.named.is_empty() {
344                            arrow = '├';
345                        }
346                        for (field_idx, field) in fields.named.iter().enumerate() {
347                            let field_ty = &field.ty;
348                            let field_ident = field.ident.as_ref().unwrap();
349                            let field_ident_str = format!("{}", field_ident);
350                            // Use a prefixed binding to avoid shadowing
351                            // generated locals such as `n`, `i`, etc.
352                            let binding_ident = syn::Ident::new(
353                                &format!("_memdbg_{}", field_ident),
354                                field_ident.span(),
355                            );
356
357                            #[cfg(feature = "offset_of_enum")]
358                            id_offset_pushes.push({
359                                let variant_ident = &variant.ident;
360                                quote!{
361                                    // We push the offset of the field, which will
362                                    // be used to compute the padded size.
363                                    id_sizes.push((#field_idx, ::core::mem::offset_of!(#input_ident #ty_generics, #variant_ident . #field_ident)));
364                                }
365                            });
366                            #[cfg(not(feature = "offset_of_enum"))]
367                            id_offset_pushes.push(quote!{
368                                // We push the size of the field, which will be
369                                // used as a surrogate of the padded size.
370                                id_sizes.push((#field_idx, ::core::mem::size_of_val(#binding_ident)));
371                            });
372
373                            // This is the arm of the match statement that
374                            // invokes _mem_dbg_depth_on on the field.
375                            match_code.push(quote! {
376                                #field_idx => <#field_ty as ::mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(#binding_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags, _memdbg_refs)?,
377                            });
378                            args.extend([quote! { #field_ident: #binding_ident, }]);
379
380                            let field_ty = &field.ty;
381                            where_clause
382                                .predicates
383                                .push(parse_quote_spanned!(field.span()=> #field_ty: ::mem_dbg::MemDbgImpl + ::mem_dbg::FlatType));
384                        }
385                        // Extend res with the args surrounded by curly braces
386                        res.extend(quote! {
387                            { #args }
388                        });
389                    }
390                    syn::Fields::Unnamed(fields) => {
391                        let mut args = proc_macro2::TokenStream::new();
392                        if !fields.unnamed.is_empty() {
393                            arrow = '├';
394                        }
395                        for (field_idx, field) in fields.unnamed.iter().enumerate() {
396                            let field_ident = syn::Ident::new(
397                                &format!("v{}", field_idx),
398                                proc_macro2::Span::call_site(),
399                            )
400                            .to_token_stream();
401                            let field_ty = &field.ty;
402                            let field_ident_str = format!("{}", field_idx);
403                            let _field_tuple_idx = syn::Index::from(field_idx);
404
405                            #[cfg(feature = "offset_of_enum")]
406                            id_offset_pushes.push({
407                                let variant_ident = &variant.ident;
408                                quote!{
409                                    // We push the offset of the field, which will
410                                    // be used to compute the padded size.
411                                    id_sizes.push((#field_idx, ::core::mem::offset_of!(#input_ident #ty_generics, #variant_ident . #_field_tuple_idx)));
412                                }
413                            });
414
415                            #[cfg(not(feature = "offset_of_enum"))]
416                            id_offset_pushes.push(quote!{
417                                // We push the size of the field, which will be
418                                // used as a surrogate of the padded size.
419                                id_sizes.push((#field_idx, ::core::mem::size_of_val(#field_ident)));
420                            });
421
422                            // This is the arm of the match statement that
423                            // invokes _mem_dbg_depth_on on the field.
424                            match_code.push(quote! {
425                                #field_idx => <#field_ty as ::mem_dbg::MemDbgImpl>::_mem_dbg_depth_on(#field_ident, _memdbg_writer, _memdbg_total_size, _memdbg_max_depth, _memdbg_prefix, Some(#field_ident_str), i == n - 1, padded_size, _memdbg_flags, _memdbg_refs)?,
426                            });
427
428                            args.extend([field_ident]);
429                            args.extend([quote! {,}]);
430
431                            let field_ty = &field.ty;
432                            where_clause
433                                .predicates
434                                .push(parse_quote_spanned!(field.span()=> #field_ty: ::mem_dbg::MemDbgImpl + ::mem_dbg::FlatType));
435                        }
436                        // extend res with the args surrounded by curly braces
437                        res.extend(quote! {
438                            ( #args )
439                        });
440                    }
441                }
442                variants.push(res);
443                let variant_name = format!("Variant: {}\n", variant.ident);
444
445                // There's some code duplication here, but we need to keep the
446                // #[cfg] attributes outside of the quote! macro.
447                // IMPORTANT: We must push exactly ONE item to variants_code per
448                // variant to match the length of the variants Vec.
449
450                #[cfg(feature = "offset_of_enum")]
451                variants_code.push(quote!{{
452                    _memdbg_writer.write_char(#arrow)?;
453                    _memdbg_writer.write_char('╴')?;
454                    _memdbg_writer.write_str(#variant_name)?;
455
456                    let mut id_sizes: Vec<(usize, usize)> = vec![];
457                    #(#id_offset_pushes)*
458                    let n = id_sizes.len();
459
460                    // We use the offset_of information to build the real
461                    // space occupied by a field.
462                    id_sizes.push((n, ::core::mem::size_of::<Self>()));
463                    // Sort by offset
464                    id_sizes.sort_by_key(|x| x.1);
465                    // Compute padded sizes
466                    for i in 0..n {
467                        id_sizes[i].1 = id_sizes[i + 1].1 - id_sizes[i].1;
468                    };
469                    // Put the candle back unless the user requested otherwise
470                    if ! _memdbg_flags.contains(::mem_dbg::DbgFlags::RUST_LAYOUT) {
471                        id_sizes.sort_by_key(|x| x.0);
472                    }
473
474                    for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
475                        match field_idx {
476                            #(#match_code)*
477                            _ => unreachable!(),
478                        }
479                    }
480                }});
481
482                #[cfg(not(feature = "offset_of_enum"))]
483                variants_code.push(quote!{{
484                    _memdbg_writer.write_char(#arrow)?;
485                    _memdbg_writer.write_char('╴')?;
486                    _memdbg_writer.write_str(#variant_name)?;
487
488                    let mut id_sizes: Vec<(usize, usize)> = vec![];
489                    #(#id_offset_pushes)*
490                    let n = id_sizes.len();
491
492                    // Lacking offset_of for enums, id_sizes contains the
493                    // size_of of each field which we use as a surrogate of
494                    // the padded size.
495                    assert!(!_memdbg_flags.contains(::mem_dbg::DbgFlags::RUST_LAYOUT), "DbgFlags::RUST_LAYOUT for enums requires the offset_of_enum feature");
496
497                    for (i, (field_idx, padded_size)) in id_sizes.into_iter().enumerate().take(n) {
498                        match field_idx {
499                            #(#match_code)*
500                            _ => unreachable!(),
501                        }
502                    }
503                }});
504            }
505
506            quote! {
507                #[automatically_derived]
508                impl #impl_generics ::mem_dbg::MemDbgImpl  for #input_ident #ty_generics #where_clause {
509                    fn _mem_dbg_rec_on(
510                        &self,
511                        _memdbg_writer: &mut impl ::core::fmt::Write,
512                        _memdbg_total_size: usize,
513                        _memdbg_max_depth: usize,
514                        _memdbg_prefix: &mut String,
515                        _memdbg_is_last: bool,
516                        _memdbg_flags: ::mem_dbg::DbgFlags,
517                        _memdbg_refs: &mut ::mem_dbg::HashSet<usize>,
518                    ) -> ::core::fmt::Result {
519                        let mut _memdbg_digits_number = ::mem_dbg::n_of_digits(_memdbg_total_size);
520                        if _memdbg_flags.contains(::mem_dbg::DbgFlags::SEPARATOR) {
521                            _memdbg_digits_number += _memdbg_digits_number / 3;
522                        }
523                        if _memdbg_flags.contains(::mem_dbg::DbgFlags::HUMANIZE) {
524                            _memdbg_digits_number = 6;
525                        }
526
527                        if _memdbg_flags.contains(::mem_dbg::DbgFlags::PERCENTAGE) {
528                            _memdbg_digits_number += 8;
529                        }
530
531                        for _ in 0.._memdbg_digits_number + 3 {
532                            _memdbg_writer.write_char(' ')?;
533                        }
534                        if !_memdbg_prefix.is_empty() {
535                            // Find the byte index of the 3rd character (skip first 2 chars)
536                            // to handle multi-byte UTF-8 characters like "│"
537                            let start_byte = _memdbg_prefix
538                                .char_indices()
539                                .nth(2)
540                                .map(|(idx, _)| idx)
541                                .unwrap_or(_memdbg_prefix.len());
542                            _memdbg_writer.write_str(&_memdbg_prefix[start_byte..])?;
543                        }
544                        match self {
545                            #(
546                               #input_ident::#variants => #variants_code,
547                            )*
548                        }
549                        Ok(())
550                   }
551                }
552            }
553        }
554
555        Data::Union(u) => {
556            return syn::Error::new_spanned(u.union_token, "MemDbg for unions is not supported; see the Unions section in the README for a manual implementation pattern")
557                .to_compile_error()
558                .into();
559        }
560    }.into()
561}