#![doc = include_str!("./lib.md")]
use attribute_derive::{Attribute, FromAttr};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
/// Field-level options parsed from a `#[get_size(...)]` attribute.
///
/// The three options are declared mutually exclusive through the `conflicts` lists below.
#[derive(FromAttr, Default, Debug)]
#[attribute(ident = get_size)]
struct StructFieldAttribute {
    /// `size = N`: add the constant `N` to the total instead of measuring the field.
    #[attribute(conflicts = [size_fn, ignore])]
    size: Option<usize>,
    /// `size_fn = f`: add the result of calling `f` with a reference to the field.
    #[attribute(conflicts = [size, ignore])]
    size_fn: Option<syn::Ident>,
    /// `ignore`: the field contributes nothing to the total.
    #[attribute(conflicts = [size, size_fn])]
    ignore: bool,
}
fn extract_ignored_generics_list(list: &Vec<syn::Attribute>) -> Vec<syn::PathSegment> {
let mut collection = Vec::new();
for attr in list {
let mut list = extract_ignored_generics(attr);
collection.append(&mut list);
}
collection
}
/// Returns the generic parameters listed in a `#[get_size(ignore(...))]` attribute.
///
/// Every path inside `ignore(...)` contributes all of its segments, in order. Attributes
/// other than `get_size`, a `get_size` attribute that is not a parenthesized list, or an
/// `ignore` entry that ends the attribute yield no segments. Parse errors are discarded;
/// segments collected before the error are still returned.
fn extract_ignored_generics(attr: &syn::Attribute) -> Vec<syn::PathSegment> {
    let mut segments = Vec::new();
    let list = match &attr.meta {
        syn::Meta::List(list) if list.path.is_ident("get_size") => list,
        _ => return segments,
    };
    // Best effort: a malformed attribute simply contributes fewer (or no) segments.
    let _ = list.parse_nested_meta(|entry| {
        let lists_generics = entry.path.is_ident("ignore") && !entry.input.is_empty();
        if lists_generics {
            entry.parse_nested_meta(|generic| {
                segments.extend(generic.path.segments);
                Ok(())
            })
        } else {
            Ok(())
        }
    });
    segments
}
/// Gathers every generic parameter marked as ignored anywhere in the derive input.
///
/// Attributes are read from the container first, then (for structs) from each field,
/// or (for enums) from each variant followed by that variant's fields. Unions only
/// contribute their container attributes.
fn collect_all_ignored_generics(ast: &syn::DeriveInput) -> Vec<syn::PathSegment> {
    // Attribute lists in declaration order; the result preserves this order.
    let mut sources = vec![&ast.attrs];
    match &ast.data {
        syn::Data::Struct(data_struct) => {
            sources.extend(data_struct.fields.iter().map(|field| &field.attrs));
        }
        syn::Data::Enum(data_enum) => {
            for variant in &data_enum.variants {
                sources.push(&variant.attrs);
                sources.extend(variant.fields.iter().map(|field| &field.attrs));
            }
        }
        syn::Data::Union(_) => {}
    }

    let mut ignored = Vec::new();
    for attrs in sources {
        ignored.extend(extract_ignored_generics_list(attrs));
    }
    ignored
}
/// Adds a `::get_size2::GetSize` bound to every type parameter of `generics`.
///
/// Type parameters whose identifier matches a segment in `ignored` are left unbounded.
/// Lifetime and const parameters are never touched.
fn add_trait_bounds(mut generics: syn::Generics, ignored: &[syn::PathSegment]) -> syn::Generics {
    for type_param in generics.type_params_mut() {
        if ignored.iter().any(|segment| segment.ident == type_param.ident) {
            continue;
        }
        type_param
            .bounds
            .push(syn::parse_quote!(::get_size2::GetSize));
    }
    generics
}
/// Derives `GetSize` for structs and enums.
///
/// Any expansion error is converted with `syn::Error::to_compile_error` and emitted
/// in place of the impl, instead of panicking.
#[proc_macro_derive(GetSize, attributes(get_size))]
pub fn derive_get_size(input: TokenStream) -> TokenStream {
    derive_get_size_impl(input).unwrap_or_else(|err| err.to_compile_error().into())
}
#[expect(clippy::too_many_lines, reason = "Needs refactoring")]
fn derive_get_size_impl(input: TokenStream) -> syn::Result<TokenStream> {
let ast: syn::DeriveInput = syn::parse(input)?;
let name = &ast.ident;
let ignored = collect_all_ignored_generics(&ast);
let generics = add_trait_bounds(ast.generics, &ignored);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
match ast.data {
syn::Data::Enum(data_enum) => {
if data_enum.variants.is_empty() {
let generated = quote! {
impl ::get_size2::GetSize for #name {}
};
return Ok(generated.into());
}
let mut cmds = Vec::with_capacity(data_enum.variants.len());
for variant in data_enum.variants {
let ident = &variant.ident;
match &variant.fields {
syn::Fields::Unnamed(unnamed_fields) => {
let num_fields = unnamed_fields.unnamed.len();
let mut field_idents = Vec::with_capacity(num_fields);
for i in 0..num_fields {
field_idents.push(format_ident!("v{i}"));
}
let mut field_cmds = Vec::with_capacity(num_fields);
for (i, _field) in unnamed_fields.unnamed.iter().enumerate() {
let field_ident = format_ident!("v{i}");
field_cmds.push(quote! {
let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
total += total_add;
});
}
cmds.push(quote! {
Self::#ident(#(#field_idents,)*) => {
let mut total = 0;
#(#field_cmds)*;
(total, tracker)
}
});
}
syn::Fields::Named(named_fields) => {
let mut field_idents = Vec::new();
let mut field_cmds = Vec::new();
let mut skipped_field = false;
for field in &named_fields.named {
let field_ident = field.ident.as_ref().ok_or_else(|| {
syn::Error::new_spanned(field, "Expected named field")
})?;
let attr = StructFieldAttribute::from_attributes(&field.attrs)
.map_err(|err| syn::Error::new_spanned(field, err.to_string()))?;
if attr.ignore {
skipped_field = true;
continue;
}
field_idents.push(field_ident);
field_cmds.push(quote! {
let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(#field_ident, tracker);
total += total_add;
});
}
let pattern = if skipped_field {
quote! { Self::#ident { #(#field_idents,)* .. } }
} else {
quote! { Self::#ident { #(#field_idents,)* } }
};
cmds.push(quote! {
#pattern => {
let mut total = 0;
#(#field_cmds)*
(total, tracker)
}
});
}
syn::Fields::Unit => {
cmds.push(quote! {
Self::#ident => (0, tracker),
});
}
}
}
let generated = quote! {
impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
fn get_heap_size(&self) -> usize {
let tracker = get_size2::StandardTracker::default();
let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
total
}
fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
&self,
tracker: TRACKER,
) -> (usize, TRACKER) {
match self {
#(#cmds)*
}
}
}
};
Ok(generated.into())
}
syn::Data::Union(_data_union) => Err(syn::Error::new_spanned(
name,
"Deriving GetSize for unions is currently not supported.",
)),
syn::Data::Struct(data_struct) => {
if data_struct.fields.is_empty() {
let generated = quote! {
impl ::get_size2::GetSize for #name {}
};
return Ok(generated.into());
}
let mut cmds = Vec::with_capacity(data_struct.fields.len());
let mut unidentified_fields_count = 0;
for field in &data_struct.fields {
let attr = StructFieldAttribute::from_attributes(&field.attrs)
.map_err(|err| syn::Error::new_spanned(field, err.to_string()))?;
if let Some(size) = attr.size {
cmds.push(quote! {
total += #size;
});
continue;
} else if let Some(size_fn) = attr.size_fn {
let ident = field.ident.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
field,
"get_size(size_fn = ...) is only supported on named fields",
)
})?;
cmds.push(quote! {
total += #size_fn(&self.#ident);
});
continue;
} else if attr.ignore {
continue;
}
if let Some(ident) = field.ident.as_ref() {
cmds.push(quote! {
let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#ident, tracker);
total += total_add;
});
} else {
let current_index = syn::Index::from(unidentified_fields_count);
cmds.push(quote! {
let (total_add, tracker) = ::get_size2::GetSize::get_heap_size_with_tracker(&self.#current_index, tracker);
total += total_add;
});
unidentified_fields_count += 1;
}
}
let generated = quote! {
impl #impl_generics ::get_size2::GetSize for #name #ty_generics #where_clause {
fn get_heap_size(&self) -> usize {
let tracker = get_size2::StandardTracker::default();
let (total, _) = ::get_size2::GetSize::get_heap_size_with_tracker(self, tracker);
total
}
fn get_heap_size_with_tracker<TRACKER: ::get_size2::GetSizeTracker>(
&self,
tracker: TRACKER,
) -> (usize, TRACKER) {
let mut total = 0;
#(#cmds)*;
(total, tracker)
}
}
};
Ok(generated.into())
}
}
}