1#![doc = include_str!("./lib.md")]
2
3use attribute_derive::{Attribute, FromAttr};
4use proc_macro::TokenStream;
5use quote::quote;
6
// Options accepted in a field-level `#[get_size(...)]` attribute, parsed via the
// `FromAttr` derive from `attribute_derive`. The `conflicts` lists declare the
// three options mutually exclusive, so at most one of them may be given.
#[derive(FromAttr, Default, Debug)]
#[attribute(ident = get_size)]
struct StructFieldAttribute {
    // `size = N`: the derive adds the constant `N` as this field's heap size.
    #[attribute(conflicts = [size_fn, ignore])]
    size: Option<usize>,
    // `size_fn = f`: the derive adds `f(&field)` as this field's heap size.
    #[attribute(conflicts = [size, ignore])]
    size_fn: Option<syn::Ident>,
    // `ignore`: the field contributes nothing to the heap size.
    #[attribute(conflicts = [size, size_fn])]
    ignore: bool,
}
17
18fn extract_ignored_generics_list(list: &Vec<syn::Attribute>) -> Vec<syn::PathSegment> {
19 let mut collection = Vec::new();
20
21 for attr in list {
22 let mut list = extract_ignored_generics(attr);
23
24 collection.append(&mut list);
25 }
26
27 collection
28}
29
30fn extract_ignored_generics(attr: &syn::Attribute) -> Vec<syn::PathSegment> {
31 let mut collection = Vec::new();
32
33 if !attr.meta.path().is_ident("get_size") {
35 return collection;
36 }
37
38 let Ok(list) = attr.meta.require_list() else {
40 return collection;
41 };
42
43 let _ = list.parse_nested_meta(|meta| {
45 if !meta.path.is_ident("ignore") {
47 return Ok(()); }
49
50 if meta.input.is_empty() {
52 return Ok(());
54 }
55
56 meta.parse_nested_meta(|meta| {
58 for segment in meta.path.segments {
59 collection.push(segment);
60 }
61 Ok(())
62 })?;
63
64 Ok(())
65 });
66
67 collection
68}
69
70fn collect_all_ignored_generics(ast: &syn::DeriveInput) -> Vec<syn::PathSegment> {
71 let mut ignored = extract_ignored_generics_list(&ast.attrs);
72
73 match &ast.data {
74 syn::Data::Struct(data_struct) => {
75 for field in &data_struct.fields {
76 ignored.extend(extract_ignored_generics_list(&field.attrs));
77 }
78 }
79 syn::Data::Enum(data_enum) => {
80 for variant in &data_enum.variants {
81 ignored.extend(extract_ignored_generics_list(&variant.attrs));
82 for field in &variant.fields {
83 ignored.extend(extract_ignored_generics_list(&field.attrs));
84 }
85 }
86 }
87 syn::Data::Union(_) => {}
88 }
89
90 ignored
91}
92
93fn add_trait_bounds(mut generics: syn::Generics, ignored: &Vec<syn::PathSegment>) -> syn::Generics {
95 for param in &mut generics.params {
96 if let syn::GenericParam::Type(type_param) = param {
97 let mut found = false;
98 for ignored in ignored {
99 if ignored.ident == type_param.ident {
100 found = true;
101 break;
102 }
103 }
104
105 if found {
106 continue;
107 }
108
109 type_param
110 .bounds
111 .push(syn::parse_quote!(::get_size2::GetSize));
112 }
113 }
114 generics
115}
116
117#[expect(
118 clippy::too_many_lines,
119 clippy::missing_panics_doc,
120 reason = "Needs refactoring"
121)]
122#[proc_macro_derive(GetSize, attributes(get_size))]
123pub fn derive_get_size(input: TokenStream) -> TokenStream {
124 let ast: syn::DeriveInput = syn::parse(input).expect("Could not parse tokens");
127
128 let name = &ast.ident;
130
131 let ignored = collect_all_ignored_generics(&ast);
134
135 let generics = add_trait_bounds(ast.generics, &ignored);
137
138 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
140
141 match ast.data {
143 syn::Data::Enum(data_enum) => {
144 if data_enum.variants.is_empty() {
145 let generated = quote! {
147 impl ::get_size2::GetSize for #name {}
148 };
149 return generated.into();
150 }
151
152 let mut cmds = Vec::with_capacity(data_enum.variants.len());
153
154 for variant in data_enum.variants {
155 let ident = &variant.ident;
156
157 match &variant.fields {
158 syn::Fields::Unnamed(unnamed_fields) => {
159 let num_fields = unnamed_fields.unnamed.len();
160
161 let mut field_idents = Vec::with_capacity(num_fields);
162 for i in 0..num_fields {
163 let field_ident = String::from("v") + &i.to_string();
164 let field_ident = syn::parse_str::<syn::Ident>(&field_ident)
165 .expect("Could not parse string to ident.");
166
167 field_idents.push(field_ident);
168 }
169
170 let mut field_cmds = Vec::with_capacity(num_fields);
171
172 for (i, _field) in unnamed_fields.unnamed.iter().enumerate() {
173 let field_ident = String::from("v") + &i.to_string();
174 let field_ident = syn::parse_str::<syn::Ident>(&field_ident)
175 .expect("Could not parse string to ident.");
176
177 field_cmds.push(quote! {
178 let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
179 total += total_add;
180 });
181 }
182
183 cmds.push(quote! {
184 Self::#ident(#(#field_idents,)*) => {
185 let mut total = 0;
186
187 #(#field_cmds)*;
188
189 (total, tracker)
190 }
191 });
192 }
193 syn::Fields::Named(named_fields) => {
194 let mut field_idents = Vec::new();
195 let mut field_cmds = Vec::new();
196 let mut skipped_field = false;
197
198 for field in &named_fields.named {
199 let field_ident =
200 field.ident.as_ref().expect("Could not get field ident.");
201
202 let attr = StructFieldAttribute::from_attributes(&field.attrs)
203 .expect("Could not parse field attributes.");
204
205 if attr.ignore {
206 skipped_field = true;
207 continue;
208 }
209
210 field_idents.push(field_ident);
211
212 field_cmds.push(quote! {
213 let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
214 total += total_add;
215 });
216 }
217
218 let pattern = if skipped_field {
219 quote! { Self::#ident { #(#field_idents,)* .. } }
220 } else {
221 quote! { Self::#ident { #(#field_idents,)* } }
222 };
223
224 cmds.push(quote! {
225 #pattern => {
226 let mut total = 0;
227 #(#field_cmds)*
228 (total, tracker)
229 }
230 });
231 }
232
233 syn::Fields::Unit => {
234 cmds.push(quote! {
235 Self::#ident => (0, tracker),
236 });
237 }
238 }
239 }
240
241 let generated = quote! {
243 impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
244 fn get_heap_size(&self) -> usize {
245 let tracker = get_size2::StandardTracker::default();
246
247 let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
248
249 total
250 }
251
252 fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
253 &self,
254 tracker: TRACKER,
255 ) -> (usize, TRACKER) {
256 match self {
257 #(#cmds)*
258 }
259 }
260 }
261 };
262 generated.into()
263 }
264 syn::Data::Union(_data_union) => {
265 panic!("Deriving GetSize for unions is currently not supported.")
266 }
267 syn::Data::Struct(data_struct) => {
268 if data_struct.fields.is_empty() {
269 let generated = quote! {
271 impl ::get_size2::GetSize for #name {}
272 };
273 return generated.into();
274 }
275
276 let mut cmds = Vec::with_capacity(data_struct.fields.len());
277
278 let mut unidentified_fields_count = 0; for field in &data_struct.fields {
281 let attr = StructFieldAttribute::from_attributes(&field.attrs)
283 .expect("Could not parse attributes.");
284
285 if let Some(size) = attr.size {
287 cmds.push(quote! {
288 total += #size;
289 });
290
291 continue;
292 } else if let Some(size_fn) = attr.size_fn {
293 let ident = field.ident.as_ref().expect("Could not get field ident.");
294
295 cmds.push(quote! {
296 total += #size_fn(&self.#ident);
297 });
298
299 continue;
300 } else if attr.ignore {
301 continue;
302 }
303
304 if let Some(ident) = field.ident.as_ref() {
305 cmds.push(quote! {
306 let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#ident, tracker);
307 total += total_add;
308 });
309 } else {
310 let current_index = syn::Index::from(unidentified_fields_count);
311 cmds.push(quote! {
312 let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#current_index, tracker);
313 total += total_add;
314 });
315
316 unidentified_fields_count += 1;
317 }
318 }
319
320 let generated = quote! {
322 impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
323 fn get_heap_size(&self) -> usize {
324 let tracker = get_size2::StandardTracker::default();
325
326 let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
327
328 total
329 }
330
331 fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
332 &self,
333 tracker: TRACKER,
334 ) -> (usize, TRACKER) {
335 let mut total = 0;
336
337 #(#cmds)*;
338
339 (total, tracker)
340 }
341 }
342 };
343 generated.into()
344 }
345 }
346}