1#![doc = include_str!("../README.md")]
2
3use std::result;
4
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, quote_spanned, ToTokens};
7use syn::{
8 punctuated::Punctuated, spanned::Spanned, Attribute, Data, DataStruct, DeriveInput, Expr,
9 ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Lit, LitStr, Meta,
10 MetaNameValue, Token, Variant,
11};
12
/// Name of the helper attribute: `#[heap_size]` / `#[heap_size(...)]`.
const HEAP_IDENT: &str = "heap_size";
/// Option key of `#[heap_size(with = "some::mod")]`.
const HEAP_ATTR_WITH_IDENT: &str = "with";
/// Option key of `#[heap_size(skip)]`.
const HEAP_ATTR_SKIP_IDENT: &str = "skip";
19
20#[proc_macro_derive(HeapSize, attributes(heap_size))]
21pub fn heap(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22 let input: DeriveInput = match syn::parse(input) {
23 Ok(v) => v,
24 Err(e) => return e.into_compile_error().into(),
25 };
26
27 let tokens = match input.data {
28 Data::Struct(..) => render_struct(input),
29 Data::Enum(..) => render_enum(input),
30 Data::Union(..) => Err(syn::Error::new_spanned(
31 input,
32 "`Heap` can not be derived for a union",
33 )),
34 };
35 tokens.unwrap_or_else(syn::Error::into_compile_error).into()
36}
37
38type Result<T> = result::Result<T, syn::Error>;
39macro_rules! bail {
40 ($token:expr, $($arg:tt)+) => {{
41 return Err(syn::Error::new_spanned($token, format!($($arg)*)))
42 }};
43}
44
/// A parsed `heap_size` helper attribute.
enum HeapAttr {
    /// Bare `#[heap_size]` on an item that is not a field (e.g. the struct or
    /// enum definition). On the container it makes every field without its
    /// own attribute count as [`HeapAttr::Field`], and it is required for
    /// `#[heap_size(skip)]` to be accepted.
    Container(Meta),
    /// The field contributes `::heapsz::HeapSize::heap_size(<field ref>)`.
    /// Produced by a bare `#[heap_size]` on the field, or implied for an
    /// unannotated field by a container `#[heap_size]`.
    Field,
    /// `#[heap_size(with = "some::mod")]`: the field contributes
    /// `some::mod::heap_size(<field ref>)`. Holds the attribute (used for the
    /// generated call's span) and the module-path string literal.
    FieldWith(Meta, LitStr),
    /// `#[heap_size(skip)]` on a field, or on an enum variant (where it skips
    /// the variant's fields that carry no attribute of their own).
    FieldSkip(Meta),
}
55
56impl HeapAttr {
57 fn new<T: ToTokens>(
58 raw_attrs: &[Attribute],
59 is_field: bool,
60 is_variant: bool,
61 origin: T,
62 ) -> Result<Option<Self>> {
63 let mut attrs = vec![];
64 for attr in raw_attrs {
65 match &attr.meta {
66 Meta::List(meta_list) => {
67 if meta_list.path.is_ident(HEAP_IDENT) {
68 let heap_attrs = meta_list
69 .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
70 if heap_attrs.len() > 1 {
71 bail!(meta_list, "too many heap_size attributes");
72 }
73 attrs.extend(heap_attrs);
74 }
75 }
76 Meta::Path(path) => {
77 if path.is_ident(HEAP_IDENT) {
78 attrs.push(attr.meta.clone());
79 }
80 }
81 Meta::NameValue(_) => (),
82 }
83 }
84 let meta = if attrs.is_empty() {
85 return Ok(None);
86 } else if attrs.len() == 1 {
87 attrs.pop().unwrap()
88 } else {
89 bail!(origin, "too many heap_size attributes")
90 };
91
92 match meta {
93 Meta::Path(ref name) => {
94 if name.is_ident(HEAP_IDENT) {
95 if is_field {
96 Ok(Some(HeapAttr::Field))
97 } else {
98 Ok(Some(HeapAttr::Container(meta)))
99 }
100 } else if name.is_ident(HEAP_ATTR_SKIP_IDENT) {
101 if is_field || is_variant {
102 Ok(Some(HeapAttr::FieldSkip(meta)))
103 } else {
104 bail!(meta, "`#[heap_size(skip)]` is a field attribute")
105 }
106 } else if name.is_ident(HEAP_ATTR_WITH_IDENT) {
107 bail!(
108 meta,
109 "heap_size attribute `with` must be followed by \
110 a module path, `with = \"some::mod\"`"
111 )
112 } else {
113 let name = name.to_token_stream().to_string().replace(' ', "");
114 bail!(meta, "unknown heap_size attribute `{}`", name)
115 }
116 }
117 Meta::NameValue(MetaNameValue {
118 ref path,
119 value:
120 Expr::Lit(ExprLit {
121 lit: Lit::Str(ref mod_path),
122 ..
123 }),
124 ..
125 }) => {
126 if path.is_ident(HEAP_ATTR_WITH_IDENT) {
127 Ok(Some(HeapAttr::FieldWith(meta.clone(), mod_path.clone())))
128 } else {
129 let name = path.to_token_stream().to_string().replace(' ', "");
130 bail!(meta, "unknown heap_size attribute `{}`", name)
131 }
132 }
133 meta => {
134 let full = meta.to_token_stream().to_string();
135 bail!(meta, "unknown heap attribute `{}`", full)
136 }
137 }
138 }
139}
140
/// How a field is referred to inside the generated `heap_size` body.
enum MethodReceiver {
    /// Use the field's own identifier, e.g. a binding introduced by the
    /// pattern `Self::Variant { a, b }`.
    FieldIdent,
    /// Use the given identifier instead of the field name, e.g. the positional
    /// bindings `f_0`, `f_1` of `Self::Variant(f_0, f_1)`.
    Replace(Ident),
    /// Borrow the field through the given identifier: `&<ident>.<field>`,
    /// e.g. `&self.x` or `&self.0`.
    PrefixRef(Ident),
}
146
/// A field that contributes to the generated heap size.
struct HeapField {
    /// The field's resolved attribute. Expected to be `Field` or `FieldWith`;
    /// any other variant makes `method_heap_size` report an internal error.
    attr: HeapAttr,
    /// Tokens naming the field: its identifier, or its position for tuple fields.
    ident: TokenStream,
    /// The original field definition, used for spans and error reporting.
    field: Field,
}
152
153impl HeapField {
154 fn new(
155 index: usize,
156 field: Field,
157 container_attr: Option<&HeapAttr>,
158 variant_attr: Option<&HeapAttr>,
159 ) -> Result<Option<Self>> {
160 let require_container_attr = |meta| {
161 if let Some(HeapAttr::Container(_)) = container_attr {
162 Ok(None)
163 } else {
164 bail!(
165 meta,
166 "`#[heap_size(skip)]` is only allow with a container \
167 attribute `#[heap_size]`."
168 );
169 }
170 };
171 let attr = match HeapAttr::new(&field.attrs, true, false, &field)? {
172 None => {
173 if let Some(HeapAttr::FieldSkip(meta)) = variant_attr {
174 return require_container_attr(meta);
175 } else if let Some(HeapAttr::Container(_)) = container_attr {
176 HeapAttr::Field
177 } else {
178 return Ok(None);
179 }
180 }
181 Some(HeapAttr::FieldSkip(meta)) => return require_container_attr(&meta),
182 Some(attr) => attr,
183 };
184
185 let ident = field.ident.clone().map_or_else(
186 || {
187 let index = Index {
188 index: u32::try_from(index).unwrap(),
189 span: Span::call_site(),
190 };
191 quote!(#index)
192 },
193 |x| quote!(#x),
194 );
195
196 Ok(Some(HeapField { attr, ident, field }))
197 }
198
199 fn method_heap_size(&self, self_: &MethodReceiver) -> Result<TokenStream> {
200 let field_ident = &self.ident;
201 let ident = match self_ {
202 MethodReceiver::FieldIdent => {
203 quote_spanned!(self.field.span()=> #field_ident)
204 }
205 MethodReceiver::Replace(ident) => quote_spanned!(self.field.span()=> #ident),
206 MethodReceiver::PrefixRef(ident) => {
207 quote_spanned!(self.field.span()=> &#ident.#field_ident)
208 }
209 };
210 match self.attr {
211 HeapAttr::Field => Ok(quote_spanned! {self.field.span()=>
212 ::heapsz::HeapSize::heap_size(#ident)
213 }),
214 HeapAttr::FieldWith(ref meta, ref mod_path) => {
215 let path = syn::parse_str::<syn::Path>(&mod_path.value())?;
216 Ok(quote_spanned! {meta.span()=>
217 #path::heap_size(#ident)
218 })
219 }
220 HeapAttr::FieldSkip(_) => {
221 bail!(
222 self.field.clone(),
223 "internal error `#[heap_size(skip)]` field generates `fn heap_size()`",
224 );
225 }
226 HeapAttr::Container(ref meta) => {
227 bail!(
228 self.field.clone(),
229 "internal error unexpected container attribute is found on field: {}",
230 meta.to_token_stream().to_string()
231 );
232 }
233 }
234 }
235}
236
237fn render_struct(input: DeriveInput) -> Result<proc_macro2::TokenStream> {
238 let container_attrs = HeapAttr::new(&input.attrs, false, false, &input)?;
239
240 let ident = input.ident.clone();
241 let Data::Struct(data) = input.data else {
242 bail!(input, "{} should be a struct", ident);
243 };
244 let fields = match data {
245 DataStruct {
246 fields:
247 Fields::Named(FieldsNamed { named: fields, .. })
248 | Fields::Unnamed(FieldsUnnamed {
249 unnamed: fields, ..
250 }),
251 ..
252 } => fields.into_iter().collect(),
253 DataStruct {
254 fields: Fields::Unit,
255 ..
256 } => Vec::new(),
257 };
258
259 let mut heap_sizes = vec![];
260 let self_ = MethodReceiver::PrefixRef(Ident::new("self", Span::call_site()));
261 for (i, field) in fields.into_iter().enumerate() {
262 if let Some(f) = HeapField::new(i, field.clone(), container_attrs.as_ref(), None)? {
263 heap_sizes.push(f.method_heap_size(&self_)?);
264 }
265 }
266
267 let generics = &input.generics;
268 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269 Ok(quote! {
270 impl #impl_generics ::heapsz::HeapSize for #ident #ty_generics #where_clause {
271 fn heap_size(&self) -> usize {
272 0 #(+ #heap_sizes)*
273 }
274 }
275 })
276}
277
278fn render_enum(input: DeriveInput) -> Result<TokenStream> {
279 let container_attrs = HeapAttr::new(&input.attrs, false, false, &input)?;
280
281 let ident = input.ident.clone();
282 let Data::Enum(data) = input.data else {
283 bail!(input, "{} should be an enum", ident);
284 };
285 let mut rendered_vars = vec![];
286 for var in data.variants {
287 rendered_vars.push(render_enum_variant(var, container_attrs.as_ref())?);
288 }
289 let matches = if rendered_vars.is_empty() {
290 quote!(0)
291 } else {
292 quote! {
293 #[allow(unused_variables)]
294 match self {
295 #(#rendered_vars)*
296 }
297 }
298 };
299
300 let generics = &input.generics;
301 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
302 Ok(quote! {
303 impl #impl_generics ::heapsz::HeapSize for #ident #ty_generics #where_clause {
304 fn heap_size(&self) -> usize {
305 #matches
306 }
307 }
308 })
309}
310
311fn render_enum_variant(var: Variant, container_attr: Option<&HeapAttr>) -> Result<TokenStream> {
312 let var_attrs = HeapAttr::new(&var.attrs, false, true, &var)?;
313 let var_span = var.span();
314 let var_ident = var.ident;
315 let (match_arm, self_receivers, fields) = match var.fields {
316 Fields::Named(FieldsNamed { named: fields, .. }) => {
317 let idents = fields.iter().map(|f| f.ident.clone().unwrap());
318 let match_arm = quote_spanned! {var_span=>
319 Self::#var_ident { #(#idents,)* }
320 };
321 let self_receivers = fields
322 .iter()
323 .map(|_| MethodReceiver::FieldIdent)
324 .collect::<Vec<_>>();
325 (
326 match_arm,
327 self_receivers,
328 fields.into_iter().collect::<Vec<_>>(),
329 )
330 }
331 Fields::Unnamed(FieldsUnnamed {
332 unnamed: fields, ..
333 }) => {
334 let field_idents = fields
335 .iter()
336 .enumerate()
337 .map(|(i, f)| Ident::new(&format!("f_{i}"), f.span()))
338 .collect::<Vec<_>>();
339 let self_receivers = field_idents
340 .iter()
341 .map(|ident| MethodReceiver::Replace(ident.clone()))
342 .collect::<Vec<_>>();
343 let match_arm = quote_spanned! {var_span=>
344 Self::#var_ident(#(#field_idents,)*)
345 };
346 (
347 match_arm,
348 self_receivers,
349 fields.into_iter().collect::<Vec<_>>(),
350 )
351 }
352 Fields::Unit => {
353 let match_arm = quote_spanned! {var_span=>
354 Self::#var_ident
355 };
356 (match_arm, vec![], vec![])
357 }
358 };
359
360 let mut heap_sizes = vec![];
361 for (i, field) in fields.into_iter().enumerate() {
362 if let Some(f) = HeapField::new(i, field.clone(), container_attr, var_attrs.as_ref())? {
363 heap_sizes.push(f.method_heap_size(&self_receivers[i])?);
364 }
365 }
366
367 Ok(quote_spanned! {var_span=>
368 #match_arm => { 0 #(+ #heap_sizes)* }
369 })
370}