get_size_derive/
lib.rs

1#![doc = include_str!("./lib.md")]
2
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn;
7use attribute_derive::Attribute;
8
9
10
/// Field-level options parsed from a `#[get_size(...)]` attribute placed on a
/// struct field.
///
/// Each option is declared as conflicting with the other two, so a field is
/// expected to use at most one of them.
#[derive(Attribute, Default, Debug)]
#[attribute(ident = get_size)]
struct StructFieldAttribute {
    /// `#[get_size(size = N)]`: count the constant `N` for this field instead
    /// of measuring the field itself.
    #[attribute(conflicts = [size_fn, ignore])]
    size: Option<usize>,
    /// `#[get_size(size_fn = f)]`: compute this field's contribution by
    /// calling the function `f` with a reference to the field.
    #[attribute(conflicts = [size, ignore])]
    size_fn: Option<syn::Ident>,
    /// `#[get_size(ignore)]`: the field contributes nothing to the total.
    #[attribute(conflicts = [size, size_fn])]
    ignore: bool,
}
21
22
23
24fn extract_ignored_generics_list(list: &Vec<syn::Attribute>) -> Vec<syn::PathSegment> {
25    let mut collection = Vec::new();
26
27    for attr in list.iter() {
28        let mut list = extract_ignored_generics(attr);
29
30        collection.append(&mut list);
31    }
32
33    collection
34}
35
36fn extract_ignored_generics(attr: &syn::Attribute) -> Vec<syn::PathSegment> {
37    let mut collection = Vec::new();
38
39    // Skip all attributes which do not belong to us.
40    if !attr.meta.path().is_ident("get_size") {
41        return collection;
42    }
43
44    // Make sure it is a list.
45    let list = attr.meta.require_list().unwrap();
46
47    // Parse the nested meta.
48    // #[get_size(ignore(A, B))]
49    list.parse_nested_meta(|meta| {
50        // We only parse the ignore attributes.
51        if !meta.path.is_ident("ignore") {
52            return Ok(()); // Just skip.
53        }
54
55        meta.parse_nested_meta(|meta| {
56            for segment in meta.path.segments {
57                collection.push(segment);
58            }
59
60            Ok(())
61        })?;
62
63        Ok(())
64    }).unwrap();
65
66    collection
67}
68
69// Add a bound `T: GetSize` to every type parameter T, unless we ignore it.
70fn add_trait_bounds(
71    mut generics: syn::Generics,
72    ignored: &Vec<syn::PathSegment>,
73) -> syn::Generics {
74    for param in &mut generics.params {
75        if let syn::GenericParam::Type(type_param) = param {
76            let mut found = false;
77            for ignored in ignored.iter() {
78                if ignored.ident==type_param.ident {
79                    found = true;
80                    break;
81                }
82            }
83
84            if found {
85                continue;
86            }
87
88            type_param.bounds.push(syn::parse_quote!(GetSize));
89        }
90    }
91    generics
92}
93
94
95
96#[proc_macro_derive(GetSize, attributes(get_size))]
97pub fn derive_get_size(input: TokenStream) -> TokenStream {
98    // Construct a representation of Rust code as a syntax tree
99    // that we can manipulate
100    let ast: syn::DeriveInput = syn::parse(input).unwrap();
101
102     // The name of the sruct.
103    let name = &ast.ident;
104
105    // Extract all generics we shall ignore.
106    let ignored = extract_ignored_generics_list(&ast.attrs);
107
108    // Add a bound `T: GetSize` to every type parameter T.
109    let generics = add_trait_bounds(ast.generics, &ignored);
110
111    // Extract the generics of the struct/enum.
112    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
113
114    // Traverse the parsed data to generate the individual parts of the function.
115    match ast.data {
116        syn::Data::Enum(data_enum) => {
117            if data_enum.variants.is_empty() {
118                // Empty enums are easy to implement.
119                let gen = quote! {
120                    impl GetSize for #name {}
121                };
122                return gen.into()
123            }
124
125            let mut cmds = Vec::with_capacity(data_enum.variants.len());
126
127            for variant in data_enum.variants.iter() {
128                let ident = &variant.ident;
129
130                match &variant.fields {
131                    syn::Fields::Unnamed(unnamed_fields) => {
132                        let num_fields = unnamed_fields.unnamed.len();
133
134                        let mut field_idents = Vec::with_capacity(num_fields);
135                        for i in 0..num_fields {
136                            let field_ident = String::from("v")+&i.to_string();
137                            let field_ident = syn::parse_str::<syn::Ident>(&field_ident).unwrap();
138
139                            field_idents.push(field_ident);
140                        }
141
142                        let mut field_cmds = Vec::with_capacity(num_fields);
143
144                        for (i, _field) in unnamed_fields.unnamed.iter().enumerate() {
145                            let field_ident = String::from("v")+&i.to_string();
146                            let field_ident = syn::parse_str::<syn::Ident>(&field_ident).unwrap();
147
148                            field_cmds.push(quote! {
149                                total += GetSize::get_heap_size(#field_ident);
150                            })
151                        }
152
153                        cmds.push(quote! {
154                            Self::#ident(#(#field_idents,)*) => {
155                                let mut total = 0;
156
157                                #(#field_cmds)*;
158
159                                total
160                            }
161                        });
162                    }
163                    syn::Fields::Named(named_fields) => {
164                        let num_fields = named_fields.named.len();
165
166                        let mut field_idents = Vec::with_capacity(num_fields);
167
168                        let mut field_cmds = Vec::with_capacity(num_fields);
169
170                        for field in named_fields.named.iter() {
171                            let field_ident = field.ident.as_ref().unwrap();
172
173                            field_idents.push(field_ident);
174
175                            field_cmds.push(quote! {
176                                total += GetSize::get_heap_size(#field_ident);
177                            })
178                        }
179
180                        cmds.push(quote! {
181                            Self::#ident{#(#field_idents,)*} => {
182                                let mut total = 0;
183
184                                #(#field_cmds)*;
185
186                                total
187                            }
188                        });
189                    }
190                    syn::Fields::Unit => {
191                        cmds.push(quote! {
192                            Self::#ident => 0,
193                        });
194                    }
195                }
196            }
197
198            // Build the trait implementation
199            let gen = quote! {
200                impl #impl_generics GetSize for #name #ty_generics #where_clause {
201                    fn get_heap_size(&self) -> usize {
202                        match self {
203                            #(#cmds)*
204                        }
205                    }
206                }
207            };
208            return gen.into();
209        }
210        syn::Data::Union(_data_union) => panic!("Deriving GetSize for unions is currently not supported."),
211        syn::Data::Struct(data_struct) => {
212            if data_struct.fields.is_empty() {
213                // Empty structs are easy to implement.
214                let gen = quote! {
215                    impl GetSize for #name {}
216                };
217                return gen.into();
218            }
219
220            let mut cmds = Vec::with_capacity(data_struct.fields.len());
221
222            let mut unidentified_fields_count = 0; // For newtypes
223
224            for field in data_struct.fields.iter() {
225
226                // Parse all relevant attributes.
227                let attr = StructFieldAttribute::from_attributes(&field.attrs).unwrap();
228
229                // NOTE There will be no attributes if this is a tuple struct.
230                if let Some(size) = attr.size {
231                    cmds.push(quote! {
232                        total += #size;
233                    });
234
235                    continue;
236                } else if let Some(size_fn) = attr.size_fn {
237                    let ident = field.ident.as_ref().unwrap();
238
239                    cmds.push(quote! {
240                        total += #size_fn(&self.#ident);
241                    });
242
243                    continue;
244                } else if attr.ignore {
245                    continue;
246                }
247
248                if let Some(ident) = field.ident.as_ref() {
249                    cmds.push(quote! {
250                        total += GetSize::get_heap_size(&self.#ident);
251                    });
252                } else {
253                    let current_index = syn::Index::from(unidentified_fields_count);
254                    cmds.push(quote! {
255                        total += GetSize::get_heap_size(&self.#current_index);
256                    });
257
258                    unidentified_fields_count += 1;
259                }
260            }
261
262            // Build the trait implementation
263            let gen = quote! {
264                impl #impl_generics GetSize for #name #ty_generics #where_clause {
265                    fn get_heap_size(&self) -> usize {
266                        let mut total = 0;
267
268                        #(#cmds)*;
269
270                        total
271                    }
272                }
273            };
274            return gen.into();
275        },
276    }
277}