1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
9
10#[proc_macro_derive(Trackable)]
39pub fn derive_trackable(input: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(input as DeriveInput);
41 let name = &input.ident;
42 let generics = &input.generics;
43 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
44
45 let expanded = match &input.data {
46 Data::Struct(data_struct) => {
47 let heap_ptr_impl = generate_heap_ptr_impl(&data_struct.fields);
48 let size_estimate_impl = generate_size_estimate_impl(&data_struct.fields);
49 let internal_allocations_impl = generate_internal_allocations_impl(&data_struct.fields);
50
51 quote! {
52 impl #impl_generics memscope_rs::Trackable for #name #ty_generics #where_clause {
53 fn get_heap_ptr(&self) -> Option<usize> {
54 #heap_ptr_impl
55 }
56
57 fn get_type_name(&self) -> &'static str {
58 stringify!(#name)
59 }
60
61 fn get_size_estimate(&self) -> usize {
62 #size_estimate_impl
63 }
64
65 fn get_internal_allocations(&self, var_name: &str) -> Vec<(usize, String)> {
66 #internal_allocations_impl
67 }
68 }
69 }
70 }
71 Data::Enum(data_enum) => {
72 let size_estimate_impl = generate_enum_size_estimate_impl(&data_enum.variants);
73 let internal_allocations_impl =
74 generate_enum_internal_allocations_impl(&data_enum.variants);
75
76 quote! {
77 impl #impl_generics memscope_rs::Trackable for #name #ty_generics #where_clause {
78 fn get_heap_ptr(&self) -> Option<usize> {
79 Some(self as *const _ as usize)
81 }
82
83 fn get_type_name(&self) -> &'static str {
84 stringify!(#name)
85 }
86
87 fn get_size_estimate(&self) -> usize {
88 #size_estimate_impl
89 }
90
91 fn get_internal_allocations(&self, var_name: &str) -> Vec<(usize, String)> {
92 #internal_allocations_impl
93 }
94 }
95 }
96 }
97 Data::Union(_) => {
98 return syn::Error::new_spanned(
100 &input,
101 "Trackable cannot be derived for unions due to safety concerns",
102 )
103 .to_compile_error()
104 .into();
105 }
106 };
107
108 TokenStream::from(expanded)
109}
110
111fn generate_heap_ptr_impl(fields: &Fields) -> proc_macro2::TokenStream {
113 match fields {
114 Fields::Named(_) | Fields::Unnamed(_) => {
115 let has_heap_fields = has_potential_heap_allocations(fields);
117
118 if has_heap_fields {
119 quote! {
120 Some(self as *const _ as usize)
122 }
123 } else {
124 quote! {
125 None
127 }
128 }
129 }
130 Fields::Unit => {
131 quote! {
132 None
134 }
135 }
136 }
137}
138
139fn generate_size_estimate_impl(fields: &Fields) -> proc_macro2::TokenStream {
141 match fields {
142 Fields::Named(fields_named) => {
143 let field_sizes = fields_named.named.iter().map(|field| {
144 let field_name = &field.ident;
145 quote! {
146 total_size += memscope_rs::Trackable::get_size_estimate(&self.#field_name);
147 }
148 });
149
150 quote! {
151 let mut total_size = std::mem::size_of::<Self>();
152 #(#field_sizes)*
153 total_size
154 }
155 }
156 Fields::Unnamed(fields_unnamed) => {
157 let field_sizes = fields_unnamed.unnamed.iter().enumerate().map(|(i, _)| {
158 let index = syn::Index::from(i);
159 quote! {
160 total_size += memscope_rs::Trackable::get_size_estimate(&self.#index);
161 }
162 });
163
164 quote! {
165 let mut total_size = std::mem::size_of::<Self>();
166 #(#field_sizes)*
167 total_size
168 }
169 }
170 Fields::Unit => {
171 quote! {
172 std::mem::size_of::<Self>()
173 }
174 }
175 }
176}
177
178fn generate_internal_allocations_impl(fields: &Fields) -> proc_macro2::TokenStream {
180 match fields {
181 Fields::Named(fields_named) => {
182 let field_allocations = fields_named.named.iter().map(|field| {
183 let field_name = &field.ident;
184 let field_name_str = field_name.as_ref().unwrap().to_string();
185 quote! {
186 if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(&self.#field_name) {
187 allocations.push((ptr, format!("{}::{}", var_name, #field_name_str)));
188 }
189 }
190 });
191
192 quote! {
193 let mut allocations = Vec::new();
194 #(#field_allocations)*
195 allocations
196 }
197 }
198 Fields::Unnamed(fields_unnamed) => {
199 let field_allocations = fields_unnamed.unnamed.iter().enumerate().map(|(i, _)| {
200 let index = syn::Index::from(i);
201 let index_str = i.to_string();
202 quote! {
203 if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(&self.#index) {
204 allocations.push((ptr, format!("{}::{}", var_name, #index_str)));
205 }
206 }
207 });
208
209 quote! {
210 let mut allocations = Vec::new();
211 #(#field_allocations)*
212 allocations
213 }
214 }
215 Fields::Unit => {
216 quote! {
217 Vec::new()
218 }
219 }
220 }
221}
222
223fn generate_enum_size_estimate_impl(
225 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
226) -> proc_macro2::TokenStream {
227 let variant_arms = variants.iter().map(|variant| {
228 let variant_name = &variant.ident;
229 match &variant.fields {
230 Fields::Named(fields) => {
231 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
232 let field_sizes = fields.named.iter().map(|field| {
233 let field_name = &field.ident;
234 quote! {
235 total_size += memscope_rs::Trackable::get_size_estimate(#field_name);
236 }
237 });
238
239 quote! {
240 Self::#variant_name { #(#field_names),* } => {
241 let mut total_size = std::mem::size_of::<Self>();
242 #(#field_sizes)*
243 total_size
244 }
245 }
246 }
247 Fields::Unnamed(fields) => {
248 let field_patterns: Vec<_> = (0..fields.unnamed.len())
249 .map(|i| {
250 syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site())
251 })
252 .collect();
253 let field_sizes = field_patterns.iter().map(|field_name| {
254 quote! {
255 total_size += memscope_rs::Trackable::get_size_estimate(#field_name);
256 }
257 });
258
259 quote! {
260 Self::#variant_name(#(#field_patterns),*) => {
261 let mut total_size = std::mem::size_of::<Self>();
262 #(#field_sizes)*
263 total_size
264 }
265 }
266 }
267 Fields::Unit => {
268 quote! {
269 Self::#variant_name => std::mem::size_of::<Self>()
270 }
271 }
272 }
273 });
274
275 quote! {
276 match self {
277 #(#variant_arms),*
278 }
279 }
280}
281
282fn generate_enum_internal_allocations_impl(
284 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
285) -> proc_macro2::TokenStream {
286 let variant_arms = variants.iter().map(|variant| {
287 let variant_name = &variant.ident;
288 let variant_name_str = variant_name.to_string();
289 match &variant.fields {
290 Fields::Named(fields) => {
291 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
292 let field_allocations = fields.named.iter().map(|field| {
293 let field_name = &field.ident;
294 let field_name_str = field_name.as_ref().unwrap().to_string();
295 quote! {
296 if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(#field_name) {
297 allocations.push((ptr, format!("{}::{}::{}", var_name, #variant_name_str, #field_name_str)));
298 }
299 }
300 });
301 quote! {
302 Self::#variant_name { #(#field_names),* } => {
303 let mut allocations = Vec::new();
304 #(#field_allocations)*
305 allocations
306 }
307 }
308 }
309 Fields::Unnamed(fields) => {
310 let field_patterns: Vec<_> = (0..fields.unnamed.len())
311 .map(|i| syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site()))
312 .collect();
313 let field_allocations = field_patterns.iter().enumerate().map(|(i, field_name)| {
314 quote! {
315 if let Some(ptr) = memscope_rs::Trackable::get_heap_ptr(#field_name) {
316 allocations.push((ptr, format!("{}::{}::{}", var_name, #variant_name_str, #i)));
317 }
318 }
319 });
320 quote! {
321 Self::#variant_name(#(#field_patterns),*) => {
322 let mut allocations = Vec::new();
323 #(#field_allocations)*
324 allocations
325 }
326 }
327 }
328 Fields::Unit => {
329 quote! {
330 Self::#variant_name => Vec::new()
331 }
332 }
333 }
334 });
335
336 quote! {
337 match self {
338 #(#variant_arms),*
339 }
340 }
341}
342
343fn has_potential_heap_allocations(fields: &Fields) -> bool {
345 match fields {
346 Fields::Named(fields_named) => fields_named
347 .named
348 .iter()
349 .any(|field| is_potentially_heap_allocated(&field.ty)),
350 Fields::Unnamed(fields_unnamed) => fields_unnamed
351 .unnamed
352 .iter()
353 .any(|field| is_potentially_heap_allocated(&field.ty)),
354 Fields::Unit => false,
355 }
356}
357
358fn is_potentially_heap_allocated(ty: &Type) -> bool {
360 match ty {
361 Type::Path(type_path) => {
362 if let Some(segment) = type_path.path.segments.last() {
363 let type_name = segment.ident.to_string();
364 matches!(
365 type_name.as_str(),
366 "String"
367 | "Vec"
368 | "HashMap"
369 | "BTreeMap"
370 | "HashSet"
371 | "BTreeSet"
372 | "VecDeque"
373 | "LinkedList"
374 | "BinaryHeap"
375 | "Box"
376 | "Rc"
377 | "Arc"
378 )
379 } else {
380 false
381 }
382 }
383 _ => false,
384 }
385}