1use proc_macro::TokenStream;
4use quote::quote;
5use std::collections::HashSet;
6use syn::{
7 parse, parse_macro_input, AngleBracketedGenericArguments, AttrStyle, Attribute, Binding,
8 DataEnum, DataStruct, DeriveInput, Generics, Ident, Index, ParenthesizedGenericArguments, Path,
9 PathArguments, ReturnType, TraitBound, Type, TypeArray, TypeBareFn, TypeGroup, TypeImplTrait,
10 TypeParam, TypeParamBound, TypeParen, TypePath, TypePtr, TypeReference, TypeSlice,
11 TypeTraitObject, TypeTuple, WhereClause,
12};
13
14#[proc_macro_derive(DataSize, attributes(data_size))]
19pub fn derive_data_size(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as DeriveInput);
21 let input = remove_default_generic_values(input);
22
23 match input.data {
24 syn::Data::Struct(ds) => derive_for_struct(input.ident, input.generics, ds),
25 syn::Data::Enum(de) => derive_for_enum(input.ident, input.generics, de),
26 syn::Data::Union(_) => panic!("unions not supported"),
27 }
28}
29
30fn remove_default_generic_values(mut input: DeriveInput) -> DeriveInput {
31 for param in input.generics.params.iter_mut() {
32 if let syn::GenericParam::Type(ty) = param {
33 ty.eq_token = None;
34 ty.default = None;
35 }
36 }
37
38 input
39}
40
41fn contains_generic(generics: &Generics, ty: &Type) -> bool {
47 match ty {
48 Type::Array(TypeArray { elem, .. }) => contains_generic(generics, elem),
49 Type::BareFn(TypeBareFn { inputs, output, .. }) => {
50 for arg in inputs {
51 if contains_generic(generics, &arg.ty) {
52 return true;
53 }
54 }
55
56 match output {
57 ReturnType::Default => false,
58 ReturnType::Type(_, ty) => contains_generic(generics, ty),
59 }
60 }
61 Type::Group(TypeGroup { elem, .. }) => contains_generic(generics, elem),
62 Type::ImplTrait(TypeImplTrait { bounds, .. }) => bounds
63 .iter()
64 .any(|b| param_bound_contains_generic(generics, b)),
65 Type::Infer(_) => false,
66 Type::Macro(_) => true,
67 Type::Never(_) => false,
68 Type::Paren(TypeParen { elem, .. }) => contains_generic(generics, elem),
69 Type::Path(TypePath { path, .. }) => path_contains_generic(generics, path),
70 Type::Ptr(TypePtr { elem, .. }) => contains_generic(generics, elem),
71 Type::Reference(TypeReference { elem, .. }) => contains_generic(generics, elem),
72 Type::Slice(TypeSlice { elem, .. }) => contains_generic(generics, elem),
73 Type::TraitObject(TypeTraitObject { bounds, .. }) => bounds
74 .iter()
75 .any(|b| param_bound_contains_generic(generics, b)),
76 Type::Tuple(TypeTuple { elems, .. }) => {
77 elems.iter().any(|ty| contains_generic(generics, ty))
78 }
79 Type::Verbatim(_) => true,
81 _ => true,
83 }
84}
85
86fn path_contains_generic(generics: &Generics, path: &Path) -> bool {
90 let mut candidates = HashSet::new();
91
92 for segment in &path.segments {
93 candidates.insert(segment.ident.clone());
94
95 match &segment.arguments {
96 PathArguments::None => {}
97 PathArguments::AngleBracketed(AngleBracketedGenericArguments { ref args, .. }) => {
98 for arg in args {
99 match arg {
100 syn::GenericArgument::Lifetime(_) => {
101 }
103 syn::GenericArgument::Type(ty) => {
104 if contains_generic(generics, ty) {
106 return true;
107 }
108 }
109 syn::GenericArgument::Binding(Binding {
110 ty,
112 ..
113 }) => {
114 if contains_generic(generics, ty) {
116 return true;
117 }
118 }
119 syn::GenericArgument::Constraint(_) => {
120 }
122 syn::GenericArgument::Const(_) => {
123 }
125 }
126 }
127 }
128 syn::PathArguments::Parenthesized(ParenthesizedGenericArguments {
129 inputs,
130 output,
131 ..
132 }) => {
133 if inputs.iter().any(|ty| contains_generic(generics, ty)) {
134 return true;
135 }
136
137 match output {
138 ReturnType::Default => {}
139 ReturnType::Type(_, ref ty) => {
140 if contains_generic(generics, ty) {
141 return true;
142 }
143 }
144 }
145 }
146 }
147 }
148
149 let generic_idents: HashSet<_> = generics
150 .params
151 .iter()
152 .filter_map(|gen| match gen {
153 syn::GenericParam::Type(TypeParam { ident, .. }) => Some(ident.clone()),
154 syn::GenericParam::Lifetime(_) => None,
155 syn::GenericParam::Const(_) => None,
156 })
157 .collect();
158
159 candidates.intersection(&generic_idents).next().is_some()
161}
162
163fn param_bound_contains_generic(generics: &Generics, bound: &TypeParamBound) -> bool {
167 match bound {
168 syn::TypeParamBound::Trait(TraitBound { path, .. }) => {
169 path_contains_generic(generics, path)
170 }
171 syn::TypeParamBound::Lifetime(_) => false,
172 }
173}
174
175#[derive(Debug)]
176enum DataAttribute {
178 Skip,
180 With(syn::Path),
182}
183
184impl parse::Parse for DataAttribute {
185 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
186 let ident = input.parse::<Ident>().expect("IDENT??").to_string();
187
188 match ident.as_str() {
189 "skip" => Ok(DataAttribute::Skip),
190 "with" => {
191 let punct: proc_macro2::Punct = input.parse().expect("PUNCT??");
192 if punct.as_char() != '=' {
193 return Err(syn::parse::Error::new(
194 input.span(),
195 "expected `=` after `with`",
196 ));
197 }
198
199 let path: syn::Path = input.parse()?;
200 Ok(DataAttribute::With(path))
201 }
202 kw => panic!("unsupported attribute keyword: {}", kw),
203 }
204 }
205}
206
207#[derive(Debug)]
209struct DataSizeAttributes {
210 pub skip: bool,
212 pub with: Option<syn::Path>,
214}
215
216impl DataSizeAttributes {
217 fn parse(attrs: &Vec<Attribute>) -> Self {
219 let mut skip = None;
220 let mut with = None;
221
222 for attr in attrs {
223 if attr.style != AttrStyle::Outer {
224 continue;
226 }
227
228 if attr.path.segments.len() != 1 || attr.path.segments[0].ident != "data_size" {
230 continue;
231 }
232
233 let parsed: DataAttribute = attr
234 .parse_args()
235 .expect("could not parse datasize attribute");
236
237 match parsed {
238 DataAttribute::Skip => {
239 if skip.is_some() {
240 panic!("duplicated `skip` attribute");
241 } else {
242 skip = Some(true);
243 }
244 }
245 DataAttribute::With(fragment) => {
246 if with.is_some() {
247 panic!("duplicated `with` attribute");
248 } else {
249 with = Some(fragment)
250 }
251 }
252 }
253 }
254
255 DataSizeAttributes {
256 skip: skip.unwrap_or(false),
257 with,
258 }
259 }
260}
261
262fn derive_for_struct(name: Ident, generics: Generics, ds: DataStruct) -> TokenStream {
264 let fields = ds.fields;
265
266 let mut where_clauses = proc_macro2::TokenStream::new();
267 let mut is_dynamic = proc_macro2::TokenStream::new();
268 let mut static_heap_size = proc_macro2::TokenStream::new();
269 let mut dynamic_size = proc_macro2::TokenStream::new();
270 let mut detail_calls = proc_macro2::TokenStream::new();
271
272 let mut has_manual_field = false;
273
274 for (idx, field) in fields.iter().enumerate() {
275 let field_attrs = DataSizeAttributes::parse(&field.attrs);
276 if field_attrs.skip {
277 continue;
278 }
279
280 if field_attrs.with.is_some() {
281 has_manual_field = true;
282 }
283
284 let ty = &field.ty;
285 if field_attrs.with.is_none() && contains_generic(&generics, ty) {
289 if where_clauses.is_empty() {
290 where_clauses.extend(quote!(where));
291 }
292
293 where_clauses.extend(quote!(
294 #ty : datasize::DataSize,
295 ));
296 }
297
298 if !is_dynamic.is_empty() {
299 is_dynamic.extend(quote!(|));
300 }
301
302 if !static_heap_size.is_empty() {
303 static_heap_size.extend(quote!(+));
304 }
305
306 if !dynamic_size.is_empty() {
307 dynamic_size.extend(quote!(+));
308 }
309
310 is_dynamic.extend(quote!(<#ty as datasize::DataSize>));
311 is_dynamic.extend(quote!(::IS_DYNAMIC));
312
313 if field_attrs.with.is_none() {
314 static_heap_size.extend(quote!(<#ty as datasize::DataSize>));
315 static_heap_size.extend(quote!(::STATIC_HEAP_SIZE));
316 } else {
317 static_heap_size.extend(quote!(0));
318 };
319
320 let handle = if let Some(ref ident) = &field.ident {
321 quote!(#ident)
322 } else {
323 let idx = Index::from(idx);
324 quote!(#idx)
325 };
326
327 let name = if let Some(ref ident) = &field.ident {
328 ident.to_string()
329 } else {
330 "idx".to_string()
331 };
332
333 match field_attrs.with {
334 Some(manual) => {
335 dynamic_size.extend(quote!(
336 #manual(&self.#handle)
337 ));
338
339 detail_calls.extend(quote!(
340 members.insert(#name, datasize::MemUsageNode::Size(#manual(&self.#handle)));
341 ));
342 }
343 None => {
344 dynamic_size.extend(quote!(
345 datasize::data_size::<#ty>(&self.#handle)
346 ));
347
348 detail_calls.extend(quote!(
349 members.insert(#name, self.#handle.estimate_detailed_heap_size());
350 ));
351 }
352 }
353 }
354
355 if is_dynamic.is_empty() {
357 is_dynamic.extend(quote!(false));
358 }
359 if static_heap_size.is_empty() {
360 static_heap_size.extend(quote!(0));
361 }
362 if dynamic_size.is_empty() {
363 dynamic_size.extend(quote!(0));
364 }
365
366 if let Some(WhereClause { ref predicates, .. }) = generics.where_clause {
369 where_clauses.extend(quote!(#predicates));
370 }
371
372 let detailed_impl = if cfg!(feature = "detailed") {
373 quote!(
374 fn estimate_detailed_heap_size(&self) -> datasize::MemUsageNode {
375 let mut members = ::std::collections::HashMap::new();
376 #detail_calls
377 datasize::MemUsageNode::Detailed(members)
378 }
379 )
380 } else {
381 quote!()
382 };
383
384 if has_manual_field {
386 is_dynamic = proc_macro2::TokenStream::new();
387 is_dynamic.extend(quote!(true));
388 }
389
390 TokenStream::from(quote! {
391 impl #generics datasize::DataSize for #name #generics #where_clauses {
392 const IS_DYNAMIC: bool = #is_dynamic;
393 const STATIC_HEAP_SIZE: usize = #static_heap_size;
394
395 fn estimate_heap_size(&self) -> usize {
396 #dynamic_size
397 }
398
399 #detailed_impl
400 }
401 })
402}
403
404fn derive_for_enum(name: Ident, generics: Generics, de: DataEnum) -> TokenStream {
406 let mut match_arms = proc_macro2::TokenStream::new();
407 let mut where_types = proc_macro2::TokenStream::new();
408
409 let mut skipped = false;
410 for variant in de.variants.into_iter() {
411 let ds_attrs = DataSizeAttributes::parse(&variant.attrs);
412
413 if ds_attrs.skip {
414 skipped = true;
415 continue;
416 }
417
418 let variant_ident = variant.ident;
419
420 let mut field_match = proc_macro2::TokenStream::new();
421 let mut field_calc = proc_macro2::TokenStream::new();
422
423 match variant.fields {
424 syn::Fields::Named(fields) => {
425 let mut left = proc_macro2::TokenStream::new();
426
427 for field in fields.named.into_iter() {
428 let ident = field.ident.expect("named fields must have idents");
429 let ds_attrs = DataSizeAttributes::parse(&field.attrs);
430
431 if ds_attrs.skip {
432 left.extend(quote!(#ident:_));
433 } else {
434 left.extend(quote!(#ident ,));
435
436 let ty = field.ty;
437 if contains_generic(&generics, &ty) {
438 where_types.extend(quote!(#ty : datasize::DataSize,));
439 }
440 }
441
442 if !ds_attrs.skip {
443 if !field_calc.is_empty() {
444 field_calc.extend(quote!(+));
445 }
446 field_calc.extend(quote!(DataSize::estimate_heap_size(#ident)));
447 }
448 }
449
450 field_match.extend(quote! {
451 {#left}
452 });
453 }
454 syn::Fields::Unnamed(fields) => {
455 let mut left = proc_macro2::TokenStream::new();
456
457 for (idx, field) in fields.unnamed.into_iter().enumerate() {
458 let field_ds_attrs = DataSizeAttributes::parse(&field.attrs);
459
460 let ident = Ident::new(
461 &format!("{}f{}", if field_ds_attrs.skip { "_" } else { "" }, idx),
462 proc_macro2::Span::call_site(),
463 );
464
465 left.extend(quote!(#ident ,));
466
467 if !field_ds_attrs.skip {
468 if !field_calc.is_empty() {
469 field_calc.extend(quote!(+));
470 }
471 field_calc.extend(quote!(DataSize::estimate_heap_size(#ident)));
472
473 let ty = field.ty;
474 where_types.extend(quote!(#ty : datasize::DataSize,));
475 }
476 }
477
478 field_match.extend(quote! {
479 (#left)
480 });
481 }
482 syn::Fields::Unit => {
483 field_calc.extend(quote!(0));
484 }
485 }
486
487 if field_calc.is_empty() {
488 field_calc.extend(quote!(0));
489 }
490
491 match_arms.extend(quote!(
492 #name::#variant_ident #field_match => { #field_calc }
493 ));
494 }
495
496 if skipped {
498 match_arms.extend(quote! {
499 _ => 0,
500 })
501 }
502
503 let mut where_clause = proc_macro2::TokenStream::new();
504 if !where_types.is_empty() {
505 where_clause.extend(quote!(where #where_types));
506 }
507
508 let mut is_dynamic = true;
518 let static_heap_size = 0usize;
519
520 if match_arms.is_empty() {
522 match_arms.extend(quote!(_ => 0));
523 is_dynamic = false;
524 }
525
526 if let Some(WhereClause { ref predicates, .. }) = generics.where_clause {
528 where_clause.extend(quote!(#predicates));
529 }
530
531 TokenStream::from(quote! {
532 impl #generics DataSize for #name #generics #where_clause {
533
534 const IS_DYNAMIC: bool = #is_dynamic;
535 const STATIC_HEAP_SIZE: usize = #static_heap_size;
536
537 #[inline]
538 fn estimate_heap_size(&self) -> usize {
539 match self {
540 #match_arms
541 }
542 }
543 }
544 })
545}