loupe_derive/lib.rs
//! Companion of the [`loupe`](../loupe/index.html) crate.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6 parse, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, Index,
7};
8
9/// Procedural macro to implement the `loupe::MemoryUsage` trait
10/// automatically for structs and enums.
11///
12/// All struct fields and enum variants must implement `MemoryUsage`
13/// trait. If it's not possible, the `#[loupe(skip)]` attribute can be
14/// used on a field or a variant to instruct the derive procedural
15/// macro to skip that item.
16///
17/// # Example
18///
19/// ```rust,ignore
20/// #[derive(MemoryUsage)]
21/// struct Point {
22/// x: i32,
23/// y: i32,
24/// }
25///
26/// struct Mystery { ptr: *const i32 }
27///
28/// #[derive(MemoryUsage)]
29/// struct S {
30/// points: Vec<Point>,
31///
32/// #[loupe(skip)]
33/// other: Mystery,
34/// }
35/// ```
36#[proc_macro_derive(MemoryUsage, attributes(loupe))]
37pub fn derive_memory_usage(input: TokenStream) -> TokenStream {
38 let derive_input: DeriveInput = parse(input).unwrap();
39
40 match derive_input.data {
41 Data::Struct(ref struct_data) => {
42 derive_memory_usage_for_struct(&derive_input.ident, struct_data, &derive_input.generics)
43 }
44
45 Data::Enum(ref enum_data) => {
46 derive_memory_usage_for_enum(&derive_input.ident, enum_data, &derive_input.generics)
47 }
48
49 Data::Union(_) => panic!("unions are not yet implemented"),
50 /*
51 // TODO: unions.
52 // We have no way of knowing which union member is active, so we should
53 // refuse to derive an impl except for unions where all members are
54 // primitive types or arrays of them.
55 Data::Union(ref union_data) => {
56 derive_memory_usage_union(union_data)
57 },
58 */
59 }
60}
61
/// Folds `iter` into a single value, using its first item as the seed.
///
/// `function` is only ever applied between items yielded by `iter`;
/// `empty` never takes part in the computation and is returned as is
/// when `iter` yields no item at all. An iterator with a single item
/// returns that item untouched.
fn join_fold<I, F, B>(iter: I, function: F, empty: B) -> B
where
    I: Iterator<Item = B>,
    F: FnMut(B, I::Item) -> B,
{
    // `Iterator::reduce` is the stabilized name of `Iterator::fold_first`.
    iter.reduce(function).unwrap_or(empty)
}
74
75fn derive_memory_usage_for_struct(
76 struct_name: &Ident,
77 data: &DataStruct,
78 generics: &Generics,
79) -> TokenStream {
80 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81
82 let sum = join_fold(
83 // Check all fields of the `struct`.
84 match &data.fields {
85 // Field has the form:
86 //
87 // F { x, y }
88 Fields::Named(ref fields) => fields
89 .named
90 .iter()
91 .filter_map(|field| {
92 if must_skip(&field.attrs) {
93 return None;
94 }
95
96 let ident = field.ident.as_ref().unwrap();
97 let span = ident.span();
98
99 Some(quote_spanned!(
100 span => loupe::MemoryUsage::size_of_val(&self.#ident, visited) - std::mem::size_of_val(&self.#ident)
101 ))
102 })
103 .collect(),
104
105 // Field has the form:
106 //
107 // F
108 Fields::Unit => vec![],
109
110 // Field has the form:
111 //
112 // F(x, y)
113 Fields::Unnamed(ref fields) => fields
114 .unnamed
115 .iter()
116 .enumerate()
117 .filter_map(|(nth, field)| {
118 if must_skip(&field.attrs) {
119 return None;
120 }
121
122 let ident = Index::from(nth);
123
124 Some(quote! { loupe::MemoryUsage::size_of_val(&self.#ident, visited) - std::mem::size_of_val(&self.#ident) })
125 })
126 .collect(),
127 }
128 .into_iter(),
129 |x, y| quote! { #x + #y },
130 quote! { 0 },
131 );
132
133 // Implement the `MemoryUsage` trait for `struct_name`.
134 (quote! {
135 #[allow(dead_code)]
136 impl #impl_generics loupe::MemoryUsage for #struct_name #ty_generics
137 #where_clause
138 {
139 fn size_of_val(&self, visited: &mut loupe::MemoryUsageTracker) -> usize {
140 std::mem::size_of_val(self) + #sum
141 }
142 }
143 })
144 .into()
145}
146
147fn derive_memory_usage_for_enum(
148 enum_name: &Ident,
149 data: &DataEnum,
150 generics: &Generics,
151) -> TokenStream {
152 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
153
154 let match_arms = join_fold(
155 data.variants
156 .iter()
157 .map(|variant| {
158 let ident = &variant.ident;
159 let span = ident.span();
160
161 // Check all the variants of the `enum`.
162 //
163 // We want to generate something like this:
164 //
165 // Self::Variant ... => { ... }
166 // ^^^^^^^ ^^^ ^^^
167 // | | |
168 // | | given by the `sum` variable
169 // | given by the `pattern` variable
170 // given by the `ident` variable
171 //
172 // Let's compute the `pattern` and `sum` parts.
173 let (pattern, mut sum) = match variant.fields {
174 // Variant has the form:
175 //
176 // V { x, y }
177 //
178 // We want to generate:
179 //
180 // Self::V { x, y } => { /* memory usage of x + y */ }
181 Fields::Named(ref fields) => {
182 // Collect the identifiers.
183 let identifiers = fields.named.iter().map(|field| {
184 let ident = field.ident.as_ref().unwrap();
185 let span = ident.span();
186
187 quote_spanned!(span => #ident)
188 });
189
190 // Generate the `pattern` part.
191 let pattern = {
192 let pattern = join_fold(
193 identifiers.clone(),
194 |x, y| quote! { #x , #y },
195 quote! {}
196 );
197
198 quote! { { #pattern } }
199 };
200
201 // Generate the `sum` part.
202 let sum = {
203 let sum = join_fold(
204 identifiers.map(|ident| quote! {
205 loupe::MemoryUsage::size_of_val(#ident, visited) - std::mem::size_of_val(#ident)
206 }),
207 |x, y| quote! { #x + #y },
208 quote! { 0 },
209 );
210
211 quote! { #sum }
212 };
213
214 (pattern, sum)
215 }
216
217 // Variant has the form:
218 //
219 // V
220 //
221 // We want to generate:
222 //
223 // Self::V => { 0 }
224 Fields::Unit => {
225 let pattern = quote! {};
226 let sum = quote! { 0 };
227
228 (pattern, sum)
229 },
230
231 // Variant has the form:
232 //
233 // V(x, y)
234 //
235 // We want to generate:
236 //
237 // Self::V(x, y) => { /* memory usage of x + y */ }
238 Fields::Unnamed(ref fields) => {
239 // Collect the identifiers. They are unnamed,
240 // so let's use the `xi` convention where `i`
241 // is the identifier index.
242 let identifiers = fields
243 .unnamed
244 .iter()
245 .enumerate()
246 .map(|(nth, _field)| {
247 let ident = format_ident!("x{}", Index::from(nth));
248
249 quote! { #ident }
250 });
251
252 // Generate the `pattern` part.
253 let pattern = {
254 let pattern = join_fold(
255 identifiers.clone(),
256 |x, y| quote! { #x , #y },
257 quote! {}
258 );
259
260 quote! { ( #pattern ) }
261 };
262
263 // Generate the `sum` part.
264 let sum = {
265 let sum = join_fold(
266 identifiers.map(|ident| quote! {
267 loupe::MemoryUsage::size_of_val(#ident, visited) - std::mem::size_of_val(#ident)
268 }),
269 |x, y| quote! { #x + #y },
270 quote! { 0 },
271 );
272
273 quote! { #sum }
274 };
275
276 (pattern, sum)
277 }
278 };
279
280 if must_skip(&variant.attrs) {
281 sum = quote! { 0 };
282 }
283
284 // At this step, `pattern` and `sum` are well
285 // defined. Let's generate the full arm for the
286 // `match` statement.
287 quote_spanned! { span => Self::#ident#pattern => #sum }
288 }
289 ),
290 |x, y| quote! { #x , #y },
291 quote! {},
292 );
293
294 // Implement the `MemoryUsage` trait for `enum_name`.
295 (quote! {
296 #[allow(dead_code)]
297 impl #impl_generics loupe::MemoryUsage for #enum_name #ty_generics
298 #where_clause
299 {
300 fn size_of_val(&self, visited: &mut loupe::MemoryUsageTracker) -> usize {
301 std::mem::size_of_val(self) + match self {
302 #match_arms
303 }
304 }
305 }
306 })
307 .into()
308}
309
310fn must_skip(attrs: &[Attribute]) -> bool {
311 attrs.iter().any(|attr| {
312 attr.path.is_ident("loupe") && matches!(attr.parse_args::<Ident>(), Ok(a) if a == "skip")
313 })
314}