get_size_derive2/
lib.rs

1#![doc = include_str!("./lib.md")]
2
3use attribute_derive::{Attribute, FromAttr};
4use proc_macro::TokenStream;
5use quote::quote;
6
/// Options accepted by a field-level `#[get_size(...)]` attribute.
///
/// The three options are declared as mutually conflicting, so at most one of
/// them may be given for a single field.
#[derive(FromAttr, Default, Debug)]
#[attribute(ident = get_size)]
struct StructFieldAttribute {
    /// `size = N`: a fixed size to add for this field instead of measuring it.
    #[attribute(conflicts = [size_fn, ignore])]
    size: Option<usize>,
    /// `size_fn = f`: a function that is called with a reference to the field
    /// and whose result is added instead of measuring the field.
    #[attribute(conflicts = [size, ignore])]
    size_fn: Option<syn::Ident>,
    /// `ignore`: the field is skipped and contributes nothing.
    #[attribute(conflicts = [size, size_fn])]
    ignore: bool,
}
17
18fn extract_ignored_generics_list(list: &Vec<syn::Attribute>) -> Vec<syn::PathSegment> {
19    let mut collection = Vec::new();
20
21    for attr in list {
22        let mut list = extract_ignored_generics(attr);
23
24        collection.append(&mut list);
25    }
26
27    collection
28}
29
30fn extract_ignored_generics(attr: &syn::Attribute) -> Vec<syn::PathSegment> {
31    let mut collection = Vec::new();
32
33    // Skip all attributes which do not belong to us.
34    if !attr.meta.path().is_ident("get_size") {
35        return collection;
36    }
37
38    // Make sure it is a list: #[get_size(...)]
39    let Ok(list) = attr.meta.require_list() else {
40        return collection;
41    };
42
43    // Parse the nested meta: #[get_size(ignore(...))] or #[get_size(ignore)]
44    let _ = list.parse_nested_meta(|meta| {
45        // Only handle `ignore`
46        if !meta.path.is_ident("ignore") {
47            return Ok(()); // Skip unrelated
48        }
49
50        // Handle the flag case: #[get_size(ignore)]
51        if meta.input.is_empty() {
52            // Do nothing – valid empty ignore
53            return Ok(());
54        }
55
56        // Handle the list case: #[get_size(ignore(A, B))]
57        meta.parse_nested_meta(|meta| {
58            for segment in meta.path.segments {
59                collection.push(segment);
60            }
61            Ok(())
62        })?;
63
64        Ok(())
65    });
66
67    collection
68}
69
70fn collect_all_ignored_generics(ast: &syn::DeriveInput) -> Vec<syn::PathSegment> {
71    let mut ignored = extract_ignored_generics_list(&ast.attrs);
72
73    match &ast.data {
74        syn::Data::Struct(data_struct) => {
75            for field in &data_struct.fields {
76                ignored.extend(extract_ignored_generics_list(&field.attrs));
77            }
78        }
79        syn::Data::Enum(data_enum) => {
80            for variant in &data_enum.variants {
81                ignored.extend(extract_ignored_generics_list(&variant.attrs));
82                for field in &variant.fields {
83                    ignored.extend(extract_ignored_generics_list(&field.attrs));
84                }
85            }
86        }
87        syn::Data::Union(_) => {}
88    }
89
90    ignored
91}
92
93// Add a bound `T: GetSize` to every type parameter T, unless we ignore it.
94fn add_trait_bounds(mut generics: syn::Generics, ignored: &Vec<syn::PathSegment>) -> syn::Generics {
95    for param in &mut generics.params {
96        if let syn::GenericParam::Type(type_param) = param {
97            let mut found = false;
98            for ignored in ignored {
99                if ignored.ident == type_param.ident {
100                    found = true;
101                    break;
102                }
103            }
104
105            if found {
106                continue;
107            }
108
109            type_param
110                .bounds
111                .push(syn::parse_quote!(::get_size2::GetSize));
112        }
113    }
114    generics
115}
116
117#[expect(
118    clippy::too_many_lines,
119    clippy::missing_panics_doc,
120    reason = "Needs refactoring"
121)]
122#[proc_macro_derive(GetSize, attributes(get_size))]
123pub fn derive_get_size(input: TokenStream) -> TokenStream {
124    // Construct a representation of Rust code as a syntax tree
125    // that we can manipulate
126    let ast: syn::DeriveInput = syn::parse(input).expect("Could not parse tokens");
127
128    // The name of the struct.
129    let name = &ast.ident;
130
131    // Extract all generics we shall ignore.
132    // let ignored = extract_ignored_generics_list(&ast.attrs);
133    let ignored = collect_all_ignored_generics(&ast);
134
135    // Add a bound `T: GetSize` to every type parameter T.
136    let generics = add_trait_bounds(ast.generics, &ignored);
137
138    // Extract the generics of the struct/enum.
139    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
140
141    // Traverse the parsed data to generate the individual parts of the function.
142    match ast.data {
143        syn::Data::Enum(data_enum) => {
144            if data_enum.variants.is_empty() {
145                // Empty enums are easy to implement.
146                let generated = quote! {
147                    impl ::get_size2::GetSize for #name {}
148                };
149                return generated.into();
150            }
151
152            let mut cmds = Vec::with_capacity(data_enum.variants.len());
153
154            for variant in data_enum.variants {
155                let ident = &variant.ident;
156
157                match &variant.fields {
158                    syn::Fields::Unnamed(unnamed_fields) => {
159                        let num_fields = unnamed_fields.unnamed.len();
160
161                        let mut field_idents = Vec::with_capacity(num_fields);
162                        for i in 0..num_fields {
163                            let field_ident = String::from("v") + &i.to_string();
164                            let field_ident = syn::parse_str::<syn::Ident>(&field_ident)
165                                .expect("Could not parse string to ident.");
166
167                            field_idents.push(field_ident);
168                        }
169
170                        let mut field_cmds = Vec::with_capacity(num_fields);
171
172                        for (i, _field) in unnamed_fields.unnamed.iter().enumerate() {
173                            let field_ident = String::from("v") + &i.to_string();
174                            let field_ident = syn::parse_str::<syn::Ident>(&field_ident)
175                                .expect("Could not parse string to ident.");
176
177                            field_cmds.push(quote! {
178                                let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
179                                total += total_add;
180                            });
181                        }
182
183                        cmds.push(quote! {
184                            Self::#ident(#(#field_idents,)*) => {
185                                let mut total = 0;
186
187                                #(#field_cmds)*;
188
189                                (total, tracker)
190                            }
191                        });
192                    }
193                    syn::Fields::Named(named_fields) => {
194                        let mut field_idents = Vec::new();
195                        let mut field_cmds = Vec::new();
196                        let mut skipped_field = false;
197
198                        for field in &named_fields.named {
199                            let field_ident =
200                                field.ident.as_ref().expect("Could not get field ident.");
201
202                            let attr = StructFieldAttribute::from_attributes(&field.attrs)
203                                .expect("Could not parse field attributes.");
204
205                            if attr.ignore {
206                                skipped_field = true;
207                                continue;
208                            }
209
210                            field_idents.push(field_ident);
211
212                            field_cmds.push(quote! {
213                                let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
214                                total += total_add;
215                            });
216                        }
217
218                        let pattern = if skipped_field {
219                            quote! { Self::#ident { #(#field_idents,)* .. } }
220                        } else {
221                            quote! { Self::#ident { #(#field_idents,)* } }
222                        };
223
224                        cmds.push(quote! {
225                            #pattern => {
226                                let mut total = 0;
227                                #(#field_cmds)*
228                                (total, tracker)
229                            }
230                        });
231                    }
232
233                    syn::Fields::Unit => {
234                        cmds.push(quote! {
235                            Self::#ident => (0, tracker),
236                        });
237                    }
238                }
239            }
240
241            // Build the trait implementation
242            let generated = quote! {
243                impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
244                    fn get_heap_size(&self) -> usize {
245                        let tracker = get_size2::StandardTracker::default();
246
247                        let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
248
249                        total
250                    }
251
252                    fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
253                        &self,
254                        tracker: TRACKER,
255                    ) -> (usize, TRACKER) {
256                        match self {
257                            #(#cmds)*
258                        }
259                    }
260                }
261            };
262            generated.into()
263        }
264        syn::Data::Union(_data_union) => {
265            panic!("Deriving GetSize for unions is currently not supported.")
266        }
267        syn::Data::Struct(data_struct) => {
268            if data_struct.fields.is_empty() {
269                // Empty structs are easy to implement.
270                let generated = quote! {
271                    impl ::get_size2::GetSize for #name {}
272                };
273                return generated.into();
274            }
275
276            let mut cmds = Vec::with_capacity(data_struct.fields.len());
277
278            let mut unidentified_fields_count = 0; // For newtypes
279
280            for field in &data_struct.fields {
281                // Parse all relevant attributes.
282                let attr = StructFieldAttribute::from_attributes(&field.attrs)
283                    .expect("Could not parse attributes.");
284
285                // NOTE There will be no attributes if this is a tuple struct.
286                if let Some(size) = attr.size {
287                    cmds.push(quote! {
288                        total += #size;
289                    });
290
291                    continue;
292                } else if let Some(size_fn) = attr.size_fn {
293                    let ident = field.ident.as_ref().expect("Could not get field ident.");
294
295                    cmds.push(quote! {
296                        total += #size_fn(&self.#ident);
297                    });
298
299                    continue;
300                } else if attr.ignore {
301                    continue;
302                }
303
304                if let Some(ident) = field.ident.as_ref() {
305                    cmds.push(quote! {
306                        let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#ident, tracker);
307                        total += total_add;
308                    });
309                } else {
310                    let current_index = syn::Index::from(unidentified_fields_count);
311                    cmds.push(quote! {
312                        let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#current_index, tracker);
313                        total += total_add;
314                    });
315
316                    unidentified_fields_count += 1;
317                }
318            }
319
320            // Build the trait implementation
321            let generated = quote! {
322                impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
323                    fn get_heap_size(&self) -> usize {
324                        let tracker = get_size2::StandardTracker::default();
325
326                        let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
327
328                        total
329                    }
330
331                    fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
332                        &self,
333                        tracker: TRACKER,
334                    ) -> (usize, TRACKER) {
335                        let mut total = 0;
336
337                        #(#cmds)*;
338
339                        (total, tracker)
340                    }
341                }
342            };
343            generated.into()
344        }
345    }
346}