Skip to main content

lance_derive/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{Data, DeriveInput, Fields, parse_macro_input};
7
8/// Derive macro for the `DeepSizeOf` trait.
9///
10/// Generates an implementation that sums the `deep_size_of_children` of all
11/// fields (for structs) or the active variant's fields (for enums).
12#[proc_macro_derive(DeepSizeOf)]
13pub fn derive_deep_size_of(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = &input.ident;
16    let generics = &input.generics;
17
18    // Add DeepSizeOf bounds to all type parameters
19    let mut bounded_generics = generics.clone();
20    for param in &mut bounded_generics.params {
21        if let syn::GenericParam::Type(ref mut type_param) = *param {
22            type_param
23                .bounds
24                .push(syn::parse_quote!(lance_core::deepsize::DeepSizeOf));
25        }
26    }
27    let (impl_generics, _, where_clause) = bounded_generics.split_for_impl();
28    let (_, ty_generics, _) = generics.split_for_impl();
29
30    let body = match &input.data {
31        Data::Struct(data) => generate_struct_body(&data.fields),
32        Data::Enum(data) => {
33            let arms: Vec<_> = data
34                .variants
35                .iter()
36                .map(|variant| {
37                    let variant_ident = &variant.ident;
38                    match &variant.fields {
39                        Fields::Unit => {
40                            quote! { Self::#variant_ident => 0 }
41                        }
42                        Fields::Unnamed(fields) => {
43                            let bindings: Vec<_> = (0..fields.unnamed.len())
44                                .map(|i| {
45                                    syn::Ident::new(
46                                        &format!("__field_{}", i),
47                                        proc_macro2::Span::call_site(),
48                                    )
49                                })
50                                .collect();
51                            let sum = bindings.iter().map(|b| {
52                                quote! { lance_core::deepsize::DeepSizeOf::deep_size_of_children(#b, __context) }
53                            });
54                            quote! {
55                                Self::#variant_ident(#(#bindings),*) => {
56                                    0 #(+ #sum)*
57                                }
58                            }
59                        }
60                        Fields::Named(fields) => {
61                            let field_names: Vec<_> =
62                                fields.named.iter().map(|f| &f.ident).collect();
63                            let sum = field_names.iter().map(|f| {
64                                quote! { lance_core::deepsize::DeepSizeOf::deep_size_of_children(#f, __context) }
65                            });
66                            quote! {
67                                Self::#variant_ident { #(#field_names),* } => {
68                                    0 #(+ #sum)*
69                                }
70                            }
71                        }
72                    }
73                })
74                .collect();
75            quote! {
76                match self {
77                    #(#arms),*
78                }
79            }
80        }
81        Data::Union(_) => {
82            return syn::Error::new_spanned(&input, "DeepSizeOf cannot be derived for unions")
83                .to_compile_error()
84                .into();
85        }
86    };
87
88    let expanded = quote! {
89        impl #impl_generics lance_core::deepsize::DeepSizeOf for #name #ty_generics #where_clause {
90            fn deep_size_of_children(&self, __context: &mut lance_core::deepsize::Context) -> usize {
91                #body
92            }
93        }
94    };
95
96    TokenStream::from(expanded)
97}
98
99fn generate_struct_body(fields: &Fields) -> proc_macro2::TokenStream {
100    match fields {
101        Fields::Named(fields) => {
102            let field_sizes = fields.named.iter().map(|f| {
103                let name = &f.ident;
104                quote! { lance_core::deepsize::DeepSizeOf::deep_size_of_children(&self.#name, __context) }
105            });
106            quote! { 0 #(+ #field_sizes)* }
107        }
108        Fields::Unnamed(fields) => {
109            let field_sizes = (0..fields.unnamed.len()).map(|i| {
110                let index = syn::Index::from(i);
111                quote! { lance_core::deepsize::DeepSizeOf::deep_size_of_children(&self.#index, __context) }
112            });
113            quote! { 0 #(+ #field_sizes)* }
114        }
115        Fields::Unit => {
116            quote! { 0 }
117        }
118    }
119}