memscope_derive/
lib.rs

1//! Procedural macros for memscope-rs memory tracking
2//!
3//! This crate provides the `#[derive(Trackable)]` macro for automatically
4//! implementing the `Trackable` trait for user-defined types.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
9
10/// Derive macro for automatically implementing the `Trackable` trait.
11///
12/// This macro generates implementations for:
13/// - `get_heap_ptr()`: Returns the struct's address for structs with heap allocations
14/// - `get_type_name()`: Returns the type name as a string literal
15/// - `get_size_estimate()`: Calculates total size including internal allocations
16/// - `get_internal_allocations()`: Lists all internal heap allocations
17///
18/// # Examples
19///
20/// ```rust
21/// use memscope_rs::Trackable;
22/// use memscope_derive::Trackable;
23///
24/// #[derive(Trackable)]
25/// struct UserData {
26///     name: String,
27///     scores: Vec<i32>,
28///     metadata: Box<HashMap<String, String>>,
29/// }
30/// ```
31///
32/// The macro handles:
33/// - Structs with named fields
34/// - Tuple structs
35/// - Unit structs
36/// - Enums with data
37/// - Nested types that implement `Trackable`
38#[proc_macro_derive(Trackable)]
39pub fn derive_trackable(input: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(input as DeriveInput);
41    let name = &input.ident;
42    let generics = &input.generics;
43    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
44
45    let expanded = match &input.data {
46        Data::Struct(data_struct) => {
47            let heap_ptr_impl = generate_heap_ptr_impl(&data_struct.fields);
48            let size_estimate_impl = generate_size_estimate_impl(&data_struct.fields);
49            let internal_allocations_impl = generate_internal_allocations_impl(&data_struct.fields);
50
51            quote! {
52                impl #impl_generics memscope_rs::Trackable for #name #ty_generics #where_clause {
53                    fn get_heap_ptr(&self) -> Option<usize> {
54                        #heap_ptr_impl
55                    }
56
57                    fn get_type_name(&self) -> &'static str {
58                        stringify!(#name)
59                    }
60
61                    fn get_size_estimate(&self) -> usize {
62                        #size_estimate_impl
63                    }
64
65                    fn get_internal_allocations(&self, var_name: &str) -> Vec<(usize, String)> {
66                        #internal_allocations_impl
67                    }
68                }
69            }
70        }
71        Data::Enum(data_enum) => {
72            let size_estimate_impl = generate_enum_size_estimate_impl(&data_enum.variants);
73            let internal_allocations_impl =
74                generate_enum_internal_allocations_impl(&data_enum.variants);
75
76            quote! {
77                impl #impl_generics memscope_rs::Trackable for #name #ty_generics #where_clause {
78                    fn get_heap_ptr(&self) -> Option<usize> {
79                        // For enums, use the enum instance address
80                        Some(self as *const _ as usize)
81                    }
82
83                    fn get_type_name(&self) -> &'static str {
84                        stringify!(#name)
85                    }
86
87                    fn get_size_estimate(&self) -> usize {
88                        #size_estimate_impl
89                    }
90
91                    fn get_internal_allocations(&self, var_name: &str) -> Vec<(usize, String)> {
92                        #internal_allocations_impl
93                    }
94                }
95            }
96        }
97        Data::Union(_) => {
98            // Unions are not supported for safety reasons
99            return syn::Error::new_spanned(
100                &input,
101                "Trackable cannot be derived for unions due to safety concerns",
102            )
103            .to_compile_error()
104            .into();
105        }
106    };
107
108    TokenStream::from(expanded)
109}
110
111/// Generate the `get_heap_ptr` implementation for structs
112fn generate_heap_ptr_impl(fields: &Fields) -> proc_macro2::TokenStream {
113    match fields {
114        Fields::Named(_) | Fields::Unnamed(_) => {
115            // Check if any field has heap allocations
116            let has_heap_fields = has_potential_heap_allocations(fields);
117
118            if has_heap_fields {
119                quote! {
120                    // Use the struct's address as the primary identifier
121                    Some(self as *const _ as usize)
122                }
123            } else {
124                quote! {
125                    // No heap allocations detected
126                    None
127                }
128            }
129        }
130        Fields::Unit => {
131            quote! {
132                // Unit structs have no heap allocations
133                None
134            }
135        }
136    }
137}
138
139/// Generate the `get_size_estimate` implementation
140fn generate_size_estimate_impl(fields: &Fields) -> proc_macro2::TokenStream {
141    match fields {
142        Fields::Named(fields_named) => {
143            let field_sizes = fields_named.named.iter().map(|field| {
144                let field_name = &field.ident;
145                quote! {
146                    total_size += memscope_rs::Trackable::get_size_estimate(&self.#field_name);
147                }
148            });
149
150            quote! {
151                let mut total_size = std::mem::size_of::<Self>();
152                #(#field_sizes)*
153                total_size
154            }
155        }
156        Fields::Unnamed(fields_unnamed) => {
157            let field_sizes = fields_unnamed.unnamed.iter().enumerate().map(|(i, _)| {
158                let index = syn::Index::from(i);
159                quote! {
160                    total_size += memscope_rs::Trackable::get_size_estimate(&self.#index);
161                }
162            });
163
164            quote! {
165                let mut total_size = std::mem::size_of::<Self>();
166                #(#field_sizes)*
167                total_size
168            }
169        }
170        Fields::Unit => {
171            quote! {
172                std::mem::size_of::<Self>()
173            }
174        }
175    }
176}
177
178/// Generate the `get_internal_allocations` implementation
179fn generate_internal_allocations_impl(fields: &Fields) -> proc_macro2::TokenStream {
180    match fields {
181        Fields::Named(fields_named) => {
182            let field_allocations = fields_named.named.iter().map(|field| {
183                let field_name = &field.ident;
184                let field_name_str = field_name.as_ref().unwrap().to_string();
185                quote! {
186                    if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(&self.#field_name) {
187                        allocations.push((ptr, format!("{}::{}", var_name, #field_name_str)));
188                    }
189                }
190            });
191
192            quote! {
193                let mut allocations = Vec::new();
194                #(#field_allocations)*
195                allocations
196            }
197        }
198        Fields::Unnamed(fields_unnamed) => {
199            let field_allocations = fields_unnamed.unnamed.iter().enumerate().map(|(i, _)| {
200                let index = syn::Index::from(i);
201                let index_str = i.to_string();
202                quote! {
203                    if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(&self.#index) {
204                        allocations.push((ptr, format!("{}::{}", var_name, #index_str)));
205                    }
206                }
207            });
208
209            quote! {
210                let mut allocations = Vec::new();
211                #(#field_allocations)*
212                allocations
213            }
214        }
215        Fields::Unit => {
216            quote! {
217                Vec::new()
218            }
219        }
220    }
221}
222
223/// Generate size estimate for enums
224fn generate_enum_size_estimate_impl(
225    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
226) -> proc_macro2::TokenStream {
227    let variant_arms = variants.iter().map(|variant| {
228        let variant_name = &variant.ident;
229        match &variant.fields {
230            Fields::Named(fields) => {
231                let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
232                let field_sizes = fields.named.iter().map(|field| {
233                    let field_name = &field.ident;
234                    quote! {
235                        total_size += memscope_rs::Trackable::get_size_estimate(#field_name);
236                    }
237                });
238
239                quote! {
240                    Self::#variant_name { #(#field_names),* } => {
241                        let mut total_size = std::mem::size_of::<Self>();
242                        #(#field_sizes)*
243                        total_size
244                    }
245                }
246            }
247            Fields::Unnamed(fields) => {
248                let field_patterns: Vec<_> = (0..fields.unnamed.len())
249                    .map(|i| {
250                        syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site())
251                    })
252                    .collect();
253                let field_sizes = field_patterns.iter().map(|field_name| {
254                    quote! {
255                        total_size += memscope_rs::Trackable::get_size_estimate(#field_name);
256                    }
257                });
258
259                quote! {
260                    Self::#variant_name(#(#field_patterns),*) => {
261                        let mut total_size = std::mem::size_of::<Self>();
262                        #(#field_sizes)*
263                        total_size
264                    }
265                }
266            }
267            Fields::Unit => {
268                quote! {
269                    Self::#variant_name => std::mem::size_of::<Self>()
270                }
271            }
272        }
273    });
274
275    quote! {
276        match self {
277            #(#variant_arms),*
278        }
279    }
280}
281
282/// Generate internal allocations for enums
283fn generate_enum_internal_allocations_impl(
284    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
285) -> proc_macro2::TokenStream {
286    let variant_arms = variants.iter().map(|variant| {
287        let variant_name = &variant.ident;
288        let variant_name_str = variant_name.to_string();
289        match &variant.fields {
290            Fields::Named(fields) => {
291                let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
292                let field_allocations = fields.named.iter().map(|field| {
293                    let field_name = &field.ident;
294                    let field_name_str = field_name.as_ref().unwrap().to_string();
295                    quote! {
296                        if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(#field_name) {
297                            allocations.push((ptr, format!("{}::{}::{}", var_name, #variant_name_str, #field_name_str)));
298                        }
299                    }
300                });
301                quote! {
302                    Self::#variant_name { #(#field_names),* } => {
303                        let mut allocations = Vec::new();
304                        #(#field_allocations)*
305                        allocations
306                    }
307                }
308            }
309            Fields::Unnamed(fields) => {
310                let field_patterns: Vec<_> = (0..fields.unnamed.len())
311                    .map(|i| syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site()))
312                    .collect();
313                let field_allocations = field_patterns.iter().enumerate().map(|(i, field_name)| {
314                    quote! {
315                        if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(#field_name) {
316                            allocations.push((ptr, format!("{}::{}::{}", var_name, #variant_name_str, #i)));
317                        }
318                    }
319                });
320                quote! {
321                    Self::#variant_name(#(#field_patterns),*) => {
322                        let mut allocations = Vec::new();
323                        #(#field_allocations)*
324                        allocations
325                    }
326                }
327            }
328            Fields::Unit => {
329                quote! {
330                    Self::#variant_name => Vec::new()
331                }
332            }
333        }
334    });
335
336    quote! {
337        match self {
338            #(#variant_arms),*
339        }
340    }
341}
342
343/// Check if fields potentially contain heap allocations
344fn has_potential_heap_allocations(fields: &Fields) -> bool {
345    match fields {
346        Fields::Named(fields_named) => fields_named
347            .named
348            .iter()
349            .any(|field| is_potentially_heap_allocated(&field.ty)),
350        Fields::Unnamed(fields_unnamed) => fields_unnamed
351            .unnamed
352            .iter()
353            .any(|field| is_potentially_heap_allocated(&field.ty)),
354        Fields::Unit => false,
355    }
356}
357
358/// Check if a type is potentially heap-allocated
359fn is_potentially_heap_allocated(ty: &Type) -> bool {
360    match ty {
361        Type::Path(type_path) => {
362            if let Some(segment) = type_path.path.segments.last() {
363                let type_name = segment.ident.to_string();
364                matches!(
365                    type_name.as_str(),
366                    "String"
367                        | "Vec"
368                        | "HashMap"
369                        | "BTreeMap"
370                        | "HashSet"
371                        | "BTreeSet"
372                        | "VecDeque"
373                        | "LinkedList"
374                        | "BinaryHeap"
375                        | "Box"
376                        | "Rc"
377                        | "Arc"
378                )
379            } else {
380                false
381            }
382        }
383        _ => false,
384    }
385}