heapsz_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::result;
4
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, quote_spanned, ToTokens};
7use syn::{
8    punctuated::Punctuated, spanned::Spanned, Attribute, Data, DataStruct, DeriveInput, Expr,
9    ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Lit, LitStr, Meta,
10    MetaNameValue, Token, Variant,
11};
12
/// Name of the derive's helper attribute, as in `#[heap_size]` / `#[heap_size(...)]`.
const HEAP_IDENT: &str = "heap_size";
/// Argument that excludes a field (or, on an enum variant, its fields) from the
/// sum: `#[heap_size(skip)]`.
const HEAP_ATTR_SKIP_IDENT: &str = "skip";
/// Field argument naming a path whose `heap_size` function measures the field:
/// `#[heap_size(with = "...")]`.
const HEAP_ATTR_WITH_IDENT: &str = "with";
19
20#[proc_macro_derive(HeapSize, attributes(heap_size))]
21pub fn heap(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22    let input: DeriveInput = match syn::parse(input) {
23        Ok(v) => v,
24        Err(e) => return e.into_compile_error().into(),
25    };
26
27    let tokens = match input.data {
28        Data::Struct(..) => render_struct(input),
29        Data::Enum(..) => render_enum(input),
30        Data::Union(..) => Err(syn::Error::new_spanned(
31            input,
32            "`Heap` can not be derived for a union",
33        )),
34    };
35    tokens.unwrap_or_else(syn::Error::into_compile_error).into()
36}
37
38type Result<T> = result::Result<T, syn::Error>;
39macro_rules! bail {
40    ($token:expr, $($arg:tt)+) => {{
41        return Err(syn::Error::new_spanned($token, format!($($arg)*)))
42    }};
43}
44
45enum HeapAttr {
46    // #[heap_size] on a struct or enum.
47    Container(Meta),
48    // #[heap_size] on a field.
49    Field,
50    // #[heap_size(with = "")] on a field.
51    FieldWith(Meta, LitStr),
52    // #[heap_size(skip)] on a field.
53    FieldSkip(Meta),
54}
55
56impl HeapAttr {
57    fn new<T: ToTokens>(
58        raw_attrs: &[Attribute],
59        is_field: bool,
60        is_variant: bool,
61        origin: T,
62    ) -> Result<Option<Self>> {
63        let mut attrs = vec![];
64        for attr in raw_attrs {
65            match &attr.meta {
66                Meta::List(meta_list) => {
67                    if meta_list.path.is_ident(HEAP_IDENT) {
68                        let heap_attrs = meta_list
69                            .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
70                        if heap_attrs.len() > 1 {
71                            bail!(meta_list, "too many heap_size attributes");
72                        }
73                        attrs.extend(heap_attrs);
74                    }
75                }
76                Meta::Path(path) => {
77                    if path.is_ident(HEAP_IDENT) {
78                        attrs.push(attr.meta.clone());
79                    }
80                }
81                Meta::NameValue(_) => (),
82            }
83        }
84        let meta = if attrs.is_empty() {
85            return Ok(None);
86        } else if attrs.len() == 1 {
87            attrs.pop().unwrap()
88        } else {
89            bail!(origin, "too many heap_size attributes")
90        };
91
92        match meta {
93            Meta::Path(ref name) => {
94                if name.is_ident(HEAP_IDENT) {
95                    if is_field {
96                        Ok(Some(HeapAttr::Field))
97                    } else {
98                        Ok(Some(HeapAttr::Container(meta)))
99                    }
100                } else if name.is_ident(HEAP_ATTR_SKIP_IDENT) {
101                    if is_field || is_variant {
102                        Ok(Some(HeapAttr::FieldSkip(meta)))
103                    } else {
104                        bail!(meta, "`#[heap_size(skip)]` is a field attribute")
105                    }
106                } else if name.is_ident(HEAP_ATTR_WITH_IDENT) {
107                    bail!(
108                        meta,
109                        "heap_size attribute `with` must be followed by \
110                        a module path, `with = \"some::mod\"`"
111                    )
112                } else {
113                    let name = name.to_token_stream().to_string().replace(' ', "");
114                    bail!(meta, "unknown heap_size attribute `{}`", name)
115                }
116            }
117            Meta::NameValue(MetaNameValue {
118                ref path,
119                value:
120                    Expr::Lit(ExprLit {
121                        lit: Lit::Str(ref mod_path),
122                        ..
123                    }),
124                ..
125            }) => {
126                if path.is_ident(HEAP_ATTR_WITH_IDENT) {
127                    Ok(Some(HeapAttr::FieldWith(meta.clone(), mod_path.clone())))
128                } else {
129                    let name = path.to_token_stream().to_string().replace(' ', "");
130                    bail!(meta, "unknown heap_size attribute `{}`", name)
131                }
132            }
133            meta => {
134                let full = meta.to_token_stream().to_string();
135                bail!(meta, "unknown heap attribute `{}`", full)
136            }
137        }
138    }
139}
140
141enum MethodReceiver {
142    FieldIdent,
143    Replace(Ident),
144    PrefixRef(Ident),
145}
146
147struct HeapField {
148    attr: HeapAttr,
149    ident: TokenStream,
150    field: Field,
151}
152
153impl HeapField {
154    fn new(
155        index: usize,
156        field: Field,
157        container_attr: Option<&HeapAttr>,
158        variant_attr: Option<&HeapAttr>,
159    ) -> Result<Option<Self>> {
160        let require_container_attr = |meta| {
161            if let Some(HeapAttr::Container(_)) = container_attr {
162                Ok(None)
163            } else {
164                bail!(
165                    meta,
166                    "`#[heap_size(skip)]` is only allow with a container \
167                    attribute `#[heap_size]`."
168                );
169            }
170        };
171        let attr = match HeapAttr::new(&field.attrs, true, false, &field)? {
172            None => {
173                if let Some(HeapAttr::FieldSkip(meta)) = variant_attr {
174                    return require_container_attr(meta);
175                } else if let Some(HeapAttr::Container(_)) = container_attr {
176                    HeapAttr::Field
177                } else {
178                    return Ok(None);
179                }
180            }
181            Some(HeapAttr::FieldSkip(meta)) => return require_container_attr(&meta),
182            Some(attr) => attr,
183        };
184
185        let ident = field.ident.clone().map_or_else(
186            || {
187                let index = Index {
188                    index: u32::try_from(index).unwrap(),
189                    span: Span::call_site(),
190                };
191                quote!(#index)
192            },
193            |x| quote!(#x),
194        );
195
196        Ok(Some(HeapField { attr, ident, field }))
197    }
198
199    fn method_heap_size(&self, self_: &MethodReceiver) -> Result<TokenStream> {
200        let field_ident = &self.ident;
201        let ident = match self_ {
202            MethodReceiver::FieldIdent => {
203                quote_spanned!(self.field.span()=> #field_ident)
204            }
205            MethodReceiver::Replace(ident) => quote_spanned!(self.field.span()=> #ident),
206            MethodReceiver::PrefixRef(ident) => {
207                quote_spanned!(self.field.span()=> &#ident.#field_ident)
208            }
209        };
210        match self.attr {
211            HeapAttr::Field => Ok(quote_spanned! {self.field.span()=>
212                ::heapsz::HeapSize::heap_size(#ident)
213            }),
214            HeapAttr::FieldWith(ref meta, ref mod_path) => {
215                let path = syn::parse_str::<syn::Path>(&mod_path.value())?;
216                Ok(quote_spanned! {meta.span()=>
217                    #path::heap_size(#ident)
218                })
219            }
220            HeapAttr::FieldSkip(_) => {
221                bail!(
222                    self.field.clone(),
223                    "internal error `#[heap_size(skip)]` field generates `fn heap_size()`",
224                );
225            }
226            HeapAttr::Container(ref meta) => {
227                bail!(
228                    self.field.clone(),
229                    "internal error unexpected container attribute is found on field: {}",
230                    meta.to_token_stream().to_string()
231                );
232            }
233        }
234    }
235}
236
237fn render_struct(input: DeriveInput) -> Result<proc_macro2::TokenStream> {
238    let container_attrs = HeapAttr::new(&input.attrs, false, false, &input)?;
239
240    let ident = input.ident.clone();
241    let Data::Struct(data) = input.data else {
242        bail!(input, "{} should be a struct", ident);
243    };
244    let fields = match data {
245        DataStruct {
246            fields:
247                Fields::Named(FieldsNamed { named: fields, .. })
248                | Fields::Unnamed(FieldsUnnamed {
249                    unnamed: fields, ..
250                }),
251            ..
252        } => fields.into_iter().collect(),
253        DataStruct {
254            fields: Fields::Unit,
255            ..
256        } => Vec::new(),
257    };
258
259    let mut heap_sizes = vec![];
260    let self_ = MethodReceiver::PrefixRef(Ident::new("self", Span::call_site()));
261    for (i, field) in fields.into_iter().enumerate() {
262        if let Some(f) = HeapField::new(i, field.clone(), container_attrs.as_ref(), None)? {
263            heap_sizes.push(f.method_heap_size(&self_)?);
264        }
265    }
266
267    let generics = &input.generics;
268    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269    Ok(quote! {
270        impl #impl_generics ::heapsz::HeapSize for #ident #ty_generics #where_clause {
271            fn heap_size(&self) -> usize {
272                0 #(+ #heap_sizes)*
273            }
274        }
275    })
276}
277
278fn render_enum(input: DeriveInput) -> Result<TokenStream> {
279    let container_attrs = HeapAttr::new(&input.attrs, false, false, &input)?;
280
281    let ident = input.ident.clone();
282    let Data::Enum(data) = input.data else {
283        bail!(input, "{} should be an enum", ident);
284    };
285    let mut rendered_vars = vec![];
286    for var in data.variants {
287        rendered_vars.push(render_enum_variant(var, container_attrs.as_ref())?);
288    }
289    let matches = if rendered_vars.is_empty() {
290        quote!(0)
291    } else {
292        quote! {
293            #[allow(unused_variables)]
294            match self {
295                #(#rendered_vars)*
296            }
297        }
298    };
299
300    let generics = &input.generics;
301    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
302    Ok(quote! {
303        impl #impl_generics ::heapsz::HeapSize for #ident #ty_generics #where_clause {
304            fn heap_size(&self) -> usize {
305                #matches
306            }
307        }
308    })
309}
310
311fn render_enum_variant(var: Variant, container_attr: Option<&HeapAttr>) -> Result<TokenStream> {
312    let var_attrs = HeapAttr::new(&var.attrs, false, true, &var)?;
313    let var_span = var.span();
314    let var_ident = var.ident;
315    let (match_arm, self_receivers, fields) = match var.fields {
316        Fields::Named(FieldsNamed { named: fields, .. }) => {
317            let idents = fields.iter().map(|f| f.ident.clone().unwrap());
318            let match_arm = quote_spanned! {var_span=>
319                Self::#var_ident { #(#idents,)* }
320            };
321            let self_receivers = fields
322                .iter()
323                .map(|_| MethodReceiver::FieldIdent)
324                .collect::<Vec<_>>();
325            (
326                match_arm,
327                self_receivers,
328                fields.into_iter().collect::<Vec<_>>(),
329            )
330        }
331        Fields::Unnamed(FieldsUnnamed {
332            unnamed: fields, ..
333        }) => {
334            let field_idents = fields
335                .iter()
336                .enumerate()
337                .map(|(i, f)| Ident::new(&format!("f_{i}"), f.span()))
338                .collect::<Vec<_>>();
339            let self_receivers = field_idents
340                .iter()
341                .map(|ident| MethodReceiver::Replace(ident.clone()))
342                .collect::<Vec<_>>();
343            let match_arm = quote_spanned! {var_span=>
344                Self::#var_ident(#(#field_idents,)*)
345            };
346            (
347                match_arm,
348                self_receivers,
349                fields.into_iter().collect::<Vec<_>>(),
350            )
351        }
352        Fields::Unit => {
353            let match_arm = quote_spanned! {var_span=>
354                Self::#var_ident
355            };
356            (match_arm, vec![], vec![])
357        }
358    };
359
360    let mut heap_sizes = vec![];
361    for (i, field) in fields.into_iter().enumerate() {
362        if let Some(f) = HeapField::new(i, field.clone(), container_attr, var_attrs.as_ref())? {
363            heap_sizes.push(f.method_heap_size(&self_receivers[i])?);
364        }
365    }
366
367    Ok(quote_spanned! {var_span=>
368        #match_arm => { 0 #(+ #heap_sizes)* }
369    })
370}