Skip to main content

get_size_derive2/
lib.rs

1#![doc = include_str!("./lib.md")]
2
3use attribute_derive::{Attribute, FromAttr};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6
// Options accepted in a field-level `#[get_size(...)]` attribute. Parsed by
// `attribute_derive`; `ident = get_size` names the attribute it reads.
//
// The three options are mutually exclusive: each one lists the other two under
// `conflicts`, so a field carries at most one of them.
#[derive(FromAttr, Default, Debug)]
#[attribute(ident = get_size)]
struct StructFieldAttribute {
    // `#[get_size(size = N)]`: count the constant `N` as the field's heap size
    // instead of measuring the field.
    #[attribute(conflicts = [size_fn, ignore])]
    size: Option<usize>,
    // `#[get_size(size_fn = f)]`: count the result of calling `f` with a reference
    // to the field. `f` is a plain identifier (stored as `syn::Ident`), not a path.
    #[attribute(conflicts = [size, ignore])]
    size_fn: Option<syn::Ident>,
    // `#[get_size(ignore)]`: the field contributes nothing and is not measured.
    #[attribute(conflicts = [size, size_fn])]
    ignore: bool,
}
17
18fn extract_ignored_generics_list(list: &Vec<syn::Attribute>) -> Vec<syn::PathSegment> {
19    let mut collection = Vec::new();
20
21    for attr in list {
22        let mut list = extract_ignored_generics(attr);
23
24        collection.append(&mut list);
25    }
26
27    collection
28}
29
30fn extract_ignored_generics(attr: &syn::Attribute) -> Vec<syn::PathSegment> {
31    let mut collection = Vec::new();
32
33    // Skip all attributes which do not belong to us.
34    if !attr.meta.path().is_ident("get_size") {
35        return collection;
36    }
37
38    // Make sure it is a list: #[get_size(...)]
39    let Ok(list) = attr.meta.require_list() else {
40        return collection;
41    };
42
43    // Parse the nested meta: #[get_size(ignore(...))] or #[get_size(ignore)]
44    let _ = list.parse_nested_meta(|meta| {
45        // Only handle `ignore`
46        if !meta.path.is_ident("ignore") {
47            return Ok(()); // Skip unrelated
48        }
49
50        // Handle the flag case: #[get_size(ignore)]
51        if meta.input.is_empty() {
52            // Do nothing – valid empty ignore
53            return Ok(());
54        }
55
56        // Handle the list case: #[get_size(ignore(A, B))]
57        meta.parse_nested_meta(|meta| {
58            for segment in meta.path.segments {
59                collection.push(segment);
60            }
61            Ok(())
62        })?;
63
64        Ok(())
65    });
66
67    collection
68}
69
70fn collect_all_ignored_generics(ast: &syn::DeriveInput) -> Vec<syn::PathSegment> {
71    let mut ignored = extract_ignored_generics_list(&ast.attrs);
72
73    match &ast.data {
74        syn::Data::Struct(data_struct) => {
75            for field in &data_struct.fields {
76                ignored.extend(extract_ignored_generics_list(&field.attrs));
77            }
78        }
79        syn::Data::Enum(data_enum) => {
80            for variant in &data_enum.variants {
81                ignored.extend(extract_ignored_generics_list(&variant.attrs));
82                for field in &variant.fields {
83                    ignored.extend(extract_ignored_generics_list(&field.attrs));
84                }
85            }
86        }
87        syn::Data::Union(_) => {}
88    }
89
90    ignored
91}
92
93// Add a bound `T: GetSize` to every type parameter T, unless we ignore it.
94fn add_trait_bounds(mut generics: syn::Generics, ignored: &Vec<syn::PathSegment>) -> syn::Generics {
95    for param in &mut generics.params {
96        if let syn::GenericParam::Type(type_param) = param {
97            let mut found = false;
98            for ignored in ignored {
99                if ignored.ident == type_param.ident {
100                    found = true;
101                    break;
102                }
103            }
104
105            if found {
106                continue;
107            }
108
109            type_param
110                .bounds
111                .push(syn::parse_quote!(::get_size2::GetSize));
112        }
113    }
114    generics
115}
116
117#[proc_macro_derive(GetSize, attributes(get_size))]
118pub fn derive_get_size(input: TokenStream) -> TokenStream {
119    match derive_get_size_impl(input) {
120        Ok(tokens) => tokens,
121        Err(err) => err.to_compile_error().into(),
122    }
123}
124
125#[expect(clippy::too_many_lines, reason = "Needs refactoring")]
126fn derive_get_size_impl(input: TokenStream) -> syn::Result<TokenStream> {
127    // Construct a representation of Rust code as a syntax tree that we can manipulate.
128    let ast: syn::DeriveInput = syn::parse(input)?;
129
130    // The name of the struct.
131    let name = &ast.ident;
132
133    // Extract all generics we shall ignore.
134    // let ignored = extract_ignored_generics_list(&ast.attrs);
135    let ignored = collect_all_ignored_generics(&ast);
136
137    // Add a bound `T: GetSize` to every type parameter T.
138    let generics = add_trait_bounds(ast.generics, &ignored);
139
140    // Extract the generics of the struct/enum.
141    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
142
143    // Traverse the parsed data to generate the individual parts of the function.
144    match ast.data {
145        syn::Data::Enum(data_enum) => {
146            if data_enum.variants.is_empty() {
147                // Empty enums are easy to implement.
148                let generated = quote! {
149                    impl ::get_size2::GetSize for #name {}
150                };
151                return Ok(generated.into());
152            }
153
154            let mut cmds = Vec::with_capacity(data_enum.variants.len());
155
156            for variant in data_enum.variants {
157                let ident = &variant.ident;
158
159                match &variant.fields {
160                    syn::Fields::Unnamed(unnamed_fields) => {
161                        let num_fields = unnamed_fields.unnamed.len();
162
163                        let mut field_idents = Vec::with_capacity(num_fields);
164                        for i in 0..num_fields {
165                            field_idents.push(format_ident!("v{i}"));
166                        }
167
168                        let mut field_cmds = Vec::with_capacity(num_fields);
169
170                        for (i, _field) in unnamed_fields.unnamed.iter().enumerate() {
171                            let field_ident = format_ident!("v{i}");
172
173                            field_cmds.push(quote! {
174                                    let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
175                                    total += total_add;
176                                });
177                        }
178
179                        cmds.push(quote! {
180                            Self::#ident(#(#field_idents,)*) => {
181                                let mut total = 0;
182
183                                #(#field_cmds)*;
184
185                                (total, tracker)
186                            }
187                        });
188                    }
189                    syn::Fields::Named(named_fields) => {
190                        let mut field_idents = Vec::new();
191                        let mut field_cmds = Vec::new();
192                        let mut skipped_field = false;
193
194                        for field in &named_fields.named {
195                            let field_ident = field.ident.as_ref().ok_or_else(|| {
196                                syn::Error::new_spanned(field, "Expected named field")
197                            })?;
198
199                            let attr = StructFieldAttribute::from_attributes(&field.attrs)
200                                .map_err(|err| syn::Error::new_spanned(field, err.to_string()))?;
201
202                            if attr.ignore {
203                                skipped_field = true;
204                                continue;
205                            }
206
207                            field_idents.push(field_ident);
208
209                            field_cmds.push(quote! {
210                                let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
211                                total += total_add;
212                            });
213                        }
214
215                        let pattern = if skipped_field {
216                            quote! { Self::#ident { #(#field_idents,)* .. } }
217                        } else {
218                            quote! { Self::#ident { #(#field_idents,)* } }
219                        };
220
221                        cmds.push(quote! {
222                            #pattern => {
223                                let mut total = 0;
224                                #(#field_cmds)*
225                                (total, tracker)
226                            }
227                        });
228                    }
229
230                    syn::Fields::Unit => {
231                        cmds.push(quote! {
232                            Self::#ident => (0, tracker),
233                        });
234                    }
235                }
236            }
237
238            // Build the trait implementation
239            let generated = quote! {
240                impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
241                    fn get_heap_size(&self) -> usize {
242                        let tracker = get_size2::StandardTracker::default();
243
244                        let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
245
246                        total
247                    }
248
249                    fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
250                        &self,
251                        tracker: TRACKER,
252                    ) -> (usize, TRACKER) {
253                        match self {
254                            #(#cmds)*
255                        }
256                    }
257                }
258            };
259            Ok(generated.into())
260        }
261        syn::Data::Union(_data_union) => Err(syn::Error::new_spanned(
262            name,
263            "Deriving GetSize for unions is currently not supported.",
264        )),
265        syn::Data::Struct(data_struct) => {
266            if data_struct.fields.is_empty() {
267                // Empty structs are easy to implement.
268                let generated = quote! {
269                    impl ::get_size2::GetSize for #name {}
270                };
271                return Ok(generated.into());
272            }
273
274            let mut cmds = Vec::with_capacity(data_struct.fields.len());
275
276            let mut unidentified_fields_count = 0; // For newtypes
277
278            for field in &data_struct.fields {
279                // Parse all relevant attributes.
280                let attr = StructFieldAttribute::from_attributes(&field.attrs)
281                    .map_err(|err| syn::Error::new_spanned(field, err.to_string()))?;
282
283                // NOTE There will be no attributes if this is a tuple struct.
284                if let Some(size) = attr.size {
285                    cmds.push(quote! {
286                        total += #size;
287                    });
288
289                    continue;
290                } else if let Some(size_fn) = attr.size_fn {
291                    let ident = field.ident.as_ref().ok_or_else(|| {
292                        syn::Error::new_spanned(
293                            field,
294                            "get_size(size_fn = ...) is only supported on named fields",
295                        )
296                    })?;
297
298                    cmds.push(quote! {
299                        total += #size_fn(&self.#ident);
300                    });
301
302                    continue;
303                } else if attr.ignore {
304                    continue;
305                }
306
307                if let Some(ident) = field.ident.as_ref() {
308                    cmds.push(quote! {
309                        let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#ident, tracker);
310                        total += total_add;
311                    });
312                } else {
313                    let current_index = syn::Index::from(unidentified_fields_count);
314                    cmds.push(quote! {
315                        let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#current_index, tracker);
316                        total += total_add;
317                    });
318
319                    unidentified_fields_count += 1;
320                }
321            }
322
323            // Build the trait implementation
324            let generated = quote! {
325                impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
326                    fn get_heap_size(&self) -> usize {
327                        let tracker = get_size2::StandardTracker::default();
328
329                        let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
330
331                        total
332                    }
333
334                    fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
335                        &self,
336                        tracker: TRACKER,
337                    ) -> (usize, TRACKER) {
338                        let mut total = 0;
339
340                        #(#cmds)*;
341
342                        (total, tracker)
343                    }
344                }
345            };
346            Ok(generated.into())
347        }
348    }
349}