Skip to main content

redb_derive/
lib.rs

1#![deny(clippy::all, clippy::pedantic, clippy::disallowed_methods)]
2#![allow(
3    clippy::must_use_candidate,
4    clippy::redundant_closure_for_method_calls,
5    clippy::similar_names,
6    clippy::too_many_lines
7)]
8
9use proc_macro::TokenStream;
10use quote::quote;
11use syn::{Data, DeriveInput, Fields, GenericParam, Ident, parse_macro_input};
12
13#[proc_macro_derive(Key)]
14pub fn derive_key(input: TokenStream) -> TokenStream {
15    let input = parse_macro_input!(input as DeriveInput);
16
17    match generate_key_impl(&input) {
18        Ok(tokens) => tokens.into(),
19        Err(err) => err.to_compile_error().into(),
20    }
21}
22
23fn generate_key_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
24    let Data::Struct(_) = &input.data else {
25        return Err(syn::Error::new_spanned(
26            input,
27            "Key can only be derived for structs",
28        ));
29    };
30
31    let name = &input.ident;
32    let generics = &input.generics;
33    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34
35    Ok(quote! {
36        impl #impl_generics redb::Key for #name #ty_generics #where_clause {
37            fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
38                let value1 = #name::from_bytes(data1);
39                let value2 = #name::from_bytes(data2);
40                Ord::cmp(&value1, &value2)
41            }
42        }
43    })
44}
45
46#[proc_macro_derive(Value)]
47pub fn derive_value(input: TokenStream) -> TokenStream {
48    let input = parse_macro_input!(input as DeriveInput);
49
50    match generate_value_impl(&input) {
51        Ok(tokens) => tokens.into(),
52        Err(err) => err.to_compile_error().into(),
53    }
54}
55
56fn generate_value_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
57    let Data::Struct(data_struct) = &input.data else {
58        return Err(syn::Error::new_spanned(
59            input,
60            "Value can only be derived for structs",
61        ));
62    };
63
64    let name = &input.ident;
65    let generics = &input.generics;
66    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
67
68    let self_type = generate_self_type(name, generics)?;
69
70    let type_name_impl = generate_type_name(name, &data_struct.fields);
71    let as_bytes_impl = generate_as_bytes(&data_struct.fields);
72    let from_bytes_impl = generate_from_bytes(name, &data_struct.fields);
73    let fixed_width_impl = generate_fixed_width(&data_struct.fields);
74
75    Ok(quote! {
76        impl #impl_generics redb::Value for #name #ty_generics #where_clause {
77            type SelfType<'a> = #self_type
78            where
79                Self: 'a;
80            type AsBytes<'a> = Vec<u8>
81            where
82                Self: 'a;
83
84            fn fixed_width() -> Option<usize> {
85                #fixed_width_impl
86            }
87
88            fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
89            where
90                Self: 'a,
91            {
92                #from_bytes_impl
93            }
94
95            fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
96            where
97                Self: 'b,
98            {
99                #as_bytes_impl
100            }
101
102            fn type_name() -> redb::TypeName {
103                #type_name_impl
104            }
105        }
106    })
107}
108
109fn generate_self_type(
110    name: &syn::Ident,
111    generics: &syn::Generics,
112) -> syn::Result<proc_macro2::TokenStream> {
113    if generics.params.is_empty() {
114        Ok(quote! { #name })
115    } else {
116        let mut params = vec![];
117        for param in &generics.params {
118            match param {
119                GenericParam::Lifetime(_) => params.push(quote! { 'a }),
120                GenericParam::Type(type_param) => {
121                    return Err(syn::Error::new_spanned(
122                        type_param,
123                        "Value derivation is not implemented for structs with type parameters",
124                    ));
125                }
126                GenericParam::Const(const_param) => {
127                    return Err(syn::Error::new_spanned(
128                        const_param,
129                        "Value derivation is not implemented for structs with const parameters",
130                    ));
131                }
132            }
133        }
134
135        Ok(quote! { #name<#(#params),*> })
136    }
137}
138
139fn generate_type_name(struct_name: &Ident, fields: &Fields) -> proc_macro2::TokenStream {
140    match fields {
141        Fields::Named(fields_named) => {
142            let field_strings: Vec<_> = fields_named
143                .named
144                .iter()
145                .map(|field| {
146                    let field_name = field.ident.as_ref().unwrap();
147                    let field_type = &field.ty;
148                    quote! {
149                        format!("{}: {}", stringify!(#field_name), <#field_type>::type_name().name())
150                    }
151                })
152                .collect();
153
154            if field_strings.is_empty() {
155                quote! {
156                    redb::TypeName::new(&format!("{} {{}}",
157                        stringify!(#struct_name),
158                    ))
159                }
160            } else {
161                quote! {
162                    redb::TypeName::new(&format!("{} {{{}}}",
163                        stringify!(#struct_name),
164                        [#(#field_strings),*].join(", ")
165                    ))
166                }
167            }
168        }
169        Fields::Unnamed(fields_unnamed) => {
170            let field_strings: Vec<_> = fields_unnamed
171                .unnamed
172                .iter()
173                .map(|field| {
174                    let field_type = &field.ty;
175                    quote! {
176                        <#field_type>::type_name().name()
177                    }
178                })
179                .collect();
180
181            if field_strings.is_empty() {
182                quote! {
183                    redb::TypeName::new(&format!("{}()",
184                        stringify!(#struct_name),
185                    ))
186                }
187            } else {
188                quote! {
189                    redb::TypeName::new(&format!("{}({})",
190                        stringify!(#struct_name),
191                        [#(#field_strings),*].join(", ")
192                    ))
193                }
194            }
195        }
196        Fields::Unit => {
197            quote! {
198                redb::TypeName::new(stringify!(#struct_name))
199            }
200        }
201    }
202}
203
204fn get_field_types(fields: &Fields) -> Vec<syn::Type> {
205    match fields {
206        Fields::Named(fields_named) => fields_named
207            .named
208            .iter()
209            .map(|field| &field.ty)
210            .cloned()
211            .collect(),
212        Fields::Unnamed(fields_unnamed) => fields_unnamed
213            .unnamed
214            .iter()
215            .map(|field| &field.ty)
216            .cloned()
217            .collect(),
218        Fields::Unit => vec![],
219    }
220}
221
222fn generate_fixed_width(fields: &Fields) -> proc_macro2::TokenStream {
223    let field_types = get_field_types(fields);
224    quote! {
225        let mut total_width = 0usize;
226        #(
227            total_width += <#field_types>::fixed_width()?;
228        )*
229        Some(total_width)
230    }
231}
232
233fn generate_as_bytes(fields: &Fields) -> proc_macro2::TokenStream {
234    let field_types = get_field_types(fields);
235    let field_accessors = match fields {
236        Fields::Named(fields_named) => fields_named
237            .named
238            .iter()
239            .map(|field| {
240                let name = &field.ident;
241                quote! { #name }
242            })
243            .collect(),
244        Fields::Unnamed(_) => (0..field_types.len())
245            .map(|i| {
246                let index = syn::Index::from(i);
247                quote! { #index }
248            })
249            .collect(),
250        Fields::Unit => Vec::new(),
251    };
252
253    let num_fields = field_types.len();
254
255    if num_fields == 0 {
256        quote! { Vec::new() }
257    } else if num_fields == 1 {
258        let field_accessor = &field_accessors[0];
259        let field_type = &field_types[0];
260        quote! {
261            {
262                let field_bytes = <#field_type>::as_bytes(&value.#field_accessor);
263                field_bytes.as_ref().to_vec()
264            }
265        }
266    } else {
267        let field_types_except_last = &field_types[..num_fields - 1];
268        let field_accessors_except_last = &field_accessors[..num_fields - 1];
269
270        quote! {
271            {
272                let mut result = Vec::new();
273
274                #(
275                    if <#field_types_except_last>::fixed_width().is_none() {
276                        let field_bytes = <#field_types_except_last>::as_bytes(&value.#field_accessors_except_last);
277                        let bytes: &[u8] = field_bytes.as_ref();
278                        let len = bytes.len();
279                        if len < 254 {
280                            result.push(len.try_into().unwrap());
281                        } else if len <= u16::MAX.into() {
282                            let u16_len: u16 = len.try_into().unwrap();
283                            result.push(254u8);
284                            result.extend_from_slice(&u16_len.to_le_bytes());
285                        } else {
286                            let u32_len: u32 = len.try_into().unwrap();
287                            result.push(255u8);
288                            result.extend_from_slice(&u32_len.to_le_bytes());
289                        }
290                    }
291                )*
292
293                #(
294                    {
295                        let field_bytes = <#field_types>::as_bytes(&value.#field_accessors);
296                        result.extend_from_slice(field_bytes.as_ref());
297                    }
298                )*
299
300                result
301            }
302        }
303    }
304}
305
306fn generate_from_bytes(name: &Ident, fields: &Fields) -> proc_macro2::TokenStream {
307    let field_types = get_field_types(fields);
308    let field_vars: Vec<_> = (0..field_types.len())
309        .map(|i| quote::format_ident!("field_{}", i))
310        .collect();
311    let num_fields = field_types.len();
312
313    let body = if num_fields == 0 {
314        quote! {}
315    } else if num_fields == 1 {
316        let field_var = &field_vars[0];
317        let field_type = &field_types[0];
318        quote! {
319            let #field_var = <#field_type>::from_bytes(data);
320        }
321    } else {
322        let field_types_except_last = &field_types[..num_fields - 1];
323        let field_vars_except_last = &field_vars[..num_fields - 1];
324        let last_field_var = field_vars.last();
325        let last_field_type = field_types.last();
326
327        quote! {
328            let mut offset = 0usize;
329            let mut var_lengths = Vec::new();
330
331            #(
332                if <#field_types_except_last>::fixed_width().is_none() {
333                    let (len, bytes_read) = match data[offset] {
334                        0u8..=253u8 => (data[offset] as usize, 1usize),
335                        254u8 => (
336                            u16::from_le_bytes(data[offset + 1..offset + 3].try_into().unwrap()) as usize,
337                            3usize,
338                        ),
339                        255u8 => (
340                            u32::from_le_bytes(data[offset + 1..offset + 5].try_into().unwrap()) as usize,
341                            5usize,
342                        ),
343                    };
344                    var_lengths.push(len);
345                    offset += bytes_read;
346                }
347            )*
348
349            let mut var_index = 0;
350            #(
351                let #field_vars_except_last = if let Some(fixed_width) = <#field_types_except_last>::fixed_width() {
352                    let field_data = &data[offset..offset + fixed_width];
353                    offset += fixed_width;
354                    <#field_types_except_last>::from_bytes(field_data)
355                } else {
356                    let len = var_lengths[var_index];
357                    let field_data = &data[offset..offset + len];
358                    offset += len;
359                    var_index += 1;
360                    <#field_types_except_last>::from_bytes(field_data)
361                };
362            )*
363
364            let #last_field_var = if let Some(fixed_width) = <#last_field_type>::fixed_width() {
365                let field_data = &data[offset..offset + fixed_width];
366                <#last_field_type>::from_bytes(field_data)
367            } else {
368                <#last_field_type>::from_bytes(&data[offset..])
369            };
370        }
371    };
372    match fields {
373        Fields::Named(fields_named) => {
374            let field_names: Vec<_> = fields_named
375                .named
376                .iter()
377                .map(|field| &field.ident)
378                .collect();
379
380            quote! {
381                {
382                    #body
383                    #name {
384                        #(#field_names: #field_vars),*
385                    }
386                }
387            }
388        }
389        Fields::Unnamed(_) => {
390            quote! {
391                {
392                    #body
393                    #name(#(#field_vars),*)
394                }
395            }
396        }
397        Fields::Unit => {
398            quote! { #name }
399        }
400    }
401}