Skip to main content

mdbx_derive_macros/
lib.rs

1#[cfg(feature = "mdbx")]
2use heck::ToSnakeCase;
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use quote::{quote, quote_spanned};
6use syn::{
7    Data, DeriveInput, Fields, Index,
8    parse_macro_input,
9    spanned::Spanned,
10};
11#[cfg(feature = "mdbx")]
12use syn::{
13    Ident, Token, Type,
14    parse::{Parse, ParseStream},
15    punctuated::Punctuated,
16};
17
18#[proc_macro_derive(KeyObject)]
19pub fn derive(input: TokenStream) -> TokenStream {
20    let input = parse_macro_input!(input as DeriveInput);
21    let decode = decode_impl(&input);
22    // Encode implementation
23    let ident = input.ident;
24    let ts = match &input.data {
25        Data::Struct(st) => match &st.fields {
26            Fields::Named(fields) => {
27                let recur = fields.named.iter().map(|t| {
28                    let name = &t.ident;
29                    quote_spanned! {t.span()=>
30                        self.#name.key_encode()?.into_iter()
31                    }
32                });
33                quote! {
34                    [#(#recur),*].into_iter().flatten().collect()
35                }
36            }
37            Fields::Unnamed(fields) => {
38                let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
39                    let index = Index::from(idx);
40                    quote_spanned! {t.span()=>
41                        self.#index.key_encode()?.into_iter()
42                    }
43                });
44                quote! {
45                    [#(#recur),*].into_iter().flatten().collect()
46                }
47            }
48            _ => quote! {
49                compile_error!("Not supported")
50            },
51        },
52        _ => quote! {
53            compile_error!("Not supported struct")
54        },
55    };
56    #[cfg(feature = "mdbx")]
57    let table_object_impl = quote! {
58        impl mdbx_derive::mdbx::TableObject for #ident {
59            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
60                <Self as mdbx_derive::KeyObjectDecode>::key_decode(data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
61            }
62        }
63    };
64    #[cfg(not(feature = "mdbx"))]
65    let table_object_impl = quote! {};
66
67    let output = quote! {
68        impl mdbx_derive::KeyObjectEncode for #ident {
69            fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
70                Ok(#ts)
71            }
72        }
73
74        #table_object_impl
75
76        #decode
77    };
78    output.into()
79}
80
81fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
82    let ident = &input.ident;
83    let body = match &input.data {
84        Data::Struct(st) => {
85            let mut named = false;
86            let fs = match &st.fields {
87                Fields::Named(fields) => {
88                    named = true;
89                    Some(fields.named.iter())
90                }
91                Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
92                _ => None,
93            };
94
95            if let Some(fs) = fs {
96                let ranges = fs
97                    .clone()
98                    .scan(quote! {0}, |acc, x| {
99                        let ty = &x.ty;
100                        let ret = Some(quote_spanned! {x.span()=>
101                            (#acc)..(#acc + <#ty>::KEYSIZE)
102                        });
103
104                        *acc = quote! { #acc + <#ty>::KEYSIZE };
105                        ret
106                    })
107                    .collect_vec();
108                let recur = fs.clone().map(|t| {
109                    let ty = &t.ty;
110                    quote_spanned! {t.span()=>
111                        <#ty>::KEYSIZE
112                    }
113                });
114                let tyts = quote! {
115                    0 #(+ #recur)*
116                };
117
118                if named {
119                    let names = fs.clone().map(|t| {
120                        let name = &t.ident;
121                        quote_spanned! {t.span()=>
122                            #name
123                        }
124                    });
125                    let recur = fs.clone().zip(ranges).map(|(t, idx)| {
126                        let name = &t.ident;
127                        let ty = &t.ty;
128                        quote_spanned! {t.span()=>
129                            let #name = <#ty>::key_decode(bs[#idx].try_into().unwrap())?;
130                        }
131                    });
132                    quote! {
133                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
134                        #(#recur)*
135                        Ok(Self {
136                            #(#names),*
137                        })
138                    }
139                } else {
140                    let recur = fs.zip(ranges).map(|(t, idx)| {
141                        let ty = &t.ty;
142                        quote_spanned! {t.span()=>
143                            <#ty>::key_decode(bs[#idx].try_into().unwrap())?
144                        }
145                    });
146
147                    quote! {
148                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
149                        Ok(Self(#(#recur),*))
150                    }
151                }
152            } else {
153                quote! {
154                    compile_error("Not supported field")
155                }
156            }
157        }
158        _ => quote! {
159            compile_error!("Not supported struct")
160        },
161    };
162
163    let key_sz = match &input.data {
164        Data::Struct(st) => {
165            let ks = st.fields.iter().map(|f| {
166                let ty = &f.ty;
167                quote_spanned! {f.span()=>
168                    <#ty>::KEYSIZE
169                }
170            });
171
172            quote! {
173                0 #(+ #ks)*
174            }
175        }
176        _ => quote! { 0 },
177    };
178
179    let output = quote! {
180        impl mdbx_derive::KeyObjectDecode for #ident {
181            const KEYSIZE: usize = #key_sz ;
182            fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
183                #body
184            }
185        }
186    };
187    output
188}
189
190#[proc_macro_derive(BcsObject)]
191pub fn derive_bcs_object(input: TokenStream) -> TokenStream {
192    let input = parse_macro_input!(input as DeriveInput);
193    let ident = input.ident;
194
195    #[cfg(feature = "mdbx")]
196    let table_object_impl = quote! {
197        impl mdbx_derive::mdbx::TableObject for #ident {
198            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
199                mdbx_derive::bcs::from_bytes(&data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
200            }
201        }
202    };
203    #[cfg(not(feature = "mdbx"))]
204    let table_object_impl = quote! {};
205
206    let output = quote! {
207        impl mdbx_derive::TableObjectDecode for #ident {
208            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
209                Ok(mdbx_derive::bcs::from_bytes(&data_val)?)
210            }
211        }
212
213        #table_object_impl
214
215        impl mdbx_derive::TableObjectEncode for #ident {
216            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
217                Ok(mdbx_derive::bcs::to_bytes(&self)?)
218            }
219        }
220    };
221    output.into()
222}
223
224#[proc_macro_derive(ZstdBcsObject)]
225pub fn derive_zstd_bcs_object(input: TokenStream) -> TokenStream {
226    let input = parse_macro_input!(input as DeriveInput);
227    let ident = input.ident;
228
229    #[cfg(feature = "mdbx")]
230    let table_object_impl = quote! {
231        impl mdbx_derive::mdbx::TableObject for #ident {
232            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
233                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
234                    mdbx_derive::mdbx::Error::Corrupted
235                })?;
236                mdbx_derive::bcs::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
237            }
238        }
239    };
240    #[cfg(not(feature = "mdbx"))]
241    let table_object_impl = quote! {};
242
243    let output = quote! {
244        impl mdbx_derive::TableObjectDecode for #ident {
245            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
246                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
247                    mdbx_derive::Error::Zstd(e)
248                })?;
249                Ok(mdbx_derive::bcs::from_bytes(&decompressed)?)
250            }
251        }
252
253        #table_object_impl
254
255        impl mdbx_derive::TableObjectEncode for #ident {
256            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
257                let bs = mdbx_derive::bcs::to_bytes(&self)?;
258                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
259                    mdbx_derive::Error::Zstd(e)
260                })?;
261                Ok(compressed)
262            }
263        }
264    };
265    output.into()
266}
267
268#[proc_macro_derive(KeyAsTableObject)]
269pub fn derive_key_table_object(input: TokenStream) -> TokenStream {
270    let input = parse_macro_input!(input as DeriveInput);
271    let ident = input.ident;
272    let output = quote! {
273        impl mdbx_derive::TableObjectDecode for #ident {
274            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
275                <#ident as mdbx_derive::KeyObjectDecode>::key_decode(data_val)
276            }
277        }
278
279        impl mdbx_derive::TableObjectEncode for #ident {
280            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
281                <#ident as mdbx_derive::KeyObjectEncode>::key_encode(self)
282            }
283        }
284    };
285    output.into()
286}
287
288#[proc_macro_derive(ZstdPostcardObject)]
289pub fn derive_zstd_postcard(input: TokenStream) -> TokenStream {
290    let input = parse_macro_input!(input as DeriveInput);
291    let ident = input.ident;
292
293    #[cfg(feature = "mdbx")]
294    let table_object_impl = quote! {
295        impl mdbx_derive::mdbx::TableObject for #ident {
296            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
297                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
298                    mdbx_derive::mdbx::Error::Corrupted
299                })?;
300                mdbx_derive::postcard::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
301            }
302        }
303    };
304    #[cfg(not(feature = "mdbx"))]
305    let table_object_impl = quote! {};
306
307    let output = quote! {
308        impl mdbx_derive::TableObjectDecode for #ident {
309            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
310                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
311                    mdbx_derive::Error::Zstd(e)
312                })?;
313                Ok(mdbx_derive::postcard::from_bytes(&decompressed)?)
314            }
315        }
316
317        #table_object_impl
318
319        impl mdbx_derive::TableObjectEncode for #ident {
320            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
321                let bs = mdbx_derive::postcard::to_allocvec(&self)?;
322                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
323                    mdbx_derive::Error::Zstd(e)
324                })?;
325                Ok(compressed)
326            }
327        }
328    };
329    output.into()
330}
331
/// Derives zstd-compressed JSON value (de)serialization (requires the `json` feature).
///
/// Encoding serializes with `mdbx_derive::json::to_vec`, then compresses with
/// `mdbx_derive::zstd::encode_all` at compression level 1. Decoding runs
/// `mdbx_derive::zstd::decode_all`, then passes the decompressed buffer mutably to
/// `mdbx_derive::json::from_slice`. zstd failures are wrapped in `mdbx_derive::Error::Zstd`.
/// With the `mdbx` feature, the generated `mdbx_derive::mdbx::TableObject` impl maps any
/// decompression or decoding failure to `mdbx::Error::Corrupted`.
#[cfg(feature = "json")]
#[proc_macro_derive(ZstdJSONObject)]
pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
    let DeriveInput { ident: name, .. } = parse_macro_input!(input as DeriveInput);

    let table_decode = quote! {
        impl mdbx_derive::TableObjectDecode for #name {
            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
            }
        }
    };

    // Only emitted when this macro crate is built with the `mdbx` feature.
    let mdbx_decode = if cfg!(feature = "mdbx") {
        quote! {
            impl mdbx_derive::mdbx::TableObject for #name {
                fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
                    let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
                        mdbx_derive::mdbx::Error::Corrupted
                    })?;
                    mdbx_derive::json::from_slice(&mut decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
                }
            }
        }
    } else {
        proc_macro2::TokenStream::new()
    };

    let table_encode = quote! {
        impl mdbx_derive::TableObjectEncode for #name {
            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
                let bs = mdbx_derive::json::to_vec(&self)?;
                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(compressed)
            }
        }
    };

    let expanded = quote! {
        #table_decode

        #mdbx_decode

        #table_encode
    };
    expanded.into()
}
376
/// Parsed arguments of `generate_dbi_struct!`: `StructName, ErrorType, Table1, Table2, ...`.
#[cfg(feature = "mdbx")]
struct MacroInput {
    /// Name of the generated struct that holds one DBI handle per table.
    struct_name: Ident,
    /// Error type returned by the generated `new` / `new_ro` constructors.
    error_type: Type,
    /// Comma-separated table types (a trailing comma is accepted).
    tables: Punctuated<Type, Token![,]>,
}

#[cfg(feature = "mdbx")]
impl Parse for MacroInput {
    /// Parses `Ident , Type , Type (, Type)* ,?`. Both commas after the struct name and the
    /// error type are mandatory; the table list consumes the rest of the input.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let struct_name = input.parse::<Ident>()?;
        let _: Token![,] = input.parse()?;
        let error_type = input.parse::<Type>()?;
        let _: Token![,] = input.parse()?;
        Ok(Self {
            struct_name,
            error_type,
            tables: input.parse_terminated(Type::parse, Token![,])?,
        })
    }
}
399
/// Generates a struct of DBI handles for a fixed set of MDBX tables, plus typed helpers.
///
/// Invocation: `generate_dbi_struct!(StructName, ErrorType, TableA, path::TableB, ...)`.
/// Every table argument must be a type path whose type implements `mdbx_derive::MDBXTable`.
/// The snake_case form of its last path segment names the generated field and helpers
/// (`path::TableB` -> `table_b`). The full type is used wherever the table is referenced, so
/// path-qualified tables do not need to be imported by their bare name.
///
/// Generated items:
/// - `pub struct StructName` (`Debug`, `Clone`, `Copy`) with one `pub <table>: u32` per table;
/// - `new(env)`: begins a RW transaction, calls `create_table_tx` for each table (with
///   `DatabaseFlags::DUP_SORT` when `MDBXTable::DUPSORT` is true), commits, and returns the
///   handles;
/// - `new_ro(tx)`: calls `open_table_tx` for each table inside the given transaction;
/// - per table: `write_<table>_tx`, `read_<table>_tx`, `del_<table>_tx`, `<table>_cursor`;
/// - `impl mdbx_derive::HasMDBXTables` with `Error = ErrorType` and the listed table types.
///
/// A table argument that is not a type path produces a compile error at that argument.
#[cfg(feature = "mdbx")]
#[proc_macro]
pub fn generate_dbi_struct(input: TokenStream) -> TokenStream {
    let MacroInput {
        struct_name,
        error_type,
        tables,
    } = syn::parse_macro_input!(input as MacroInput);

    let span = proc_macro2::Span::call_site();
    let mut struct_fields = Vec::with_capacity(tables.len());
    let mut create_fields = Vec::with_capacity(tables.len());
    let mut open_fields = Vec::with_capacity(tables.len());
    let mut accessors = Vec::with_capacity(tables.len());

    for table in tables.iter() {
        let (type_name, field) = match table_names(table) {
            Ok(names) => names,
            Err(err) => return err.to_compile_error().into(),
        };
        let field_str = field.to_string();
        let write_fn = Ident::new(&format!("write_{}_tx", field_str), span);
        let read_fn = Ident::new(&format!("read_{}_tx", field_str), span);
        let del_fn = Ident::new(&format!("del_{}_tx", field_str), span);
        let cursor_fn = Ident::new(&format!("{}_cursor", field_str), span);
        let doc_string = format!("DBI handle for the `{}` table.", type_name);

        struct_fields.push(quote! {
            #[doc = #doc_string]
            pub #field: u32,
        });

        // Emitted as struct-literal field initialisers (evaluated in the order written) rather
        // than `let` bindings. This way a table named e.g. `Flags` or `Tx` cannot shadow the
        // transaction, the flags temporary, or another table's handle.
        create_fields.push(quote! {
            #field: {
                let flags = if <#table as mdbx_derive::MDBXTable>::DUPSORT {
                    mdbx_derive::mdbx::DatabaseFlags::DUP_SORT
                } else {
                    mdbx_derive::mdbx::DatabaseFlags::default()
                };
                <#table as mdbx_derive::MDBXTable>::create_table_tx(&tx, flags).await?
            },
        });

        open_fields.push(quote! {
            #field: <#table as mdbx_derive::MDBXTable>::open_table_tx(&tx).await?,
        });

        accessors.push(quote! {
            pub async fn #write_fn
            (
                &self,
                tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
                key: &<#table as mdbx_derive::MDBXTable>::Key,
                value: &<#table as mdbx_derive::MDBXTable>::Value,
                flags: mdbx_derive::mdbx::WriteFlags
            ) -> Result<(), mdbx_derive::Error> {
                tx.put(
                    self.#field,
                    &<<#table as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                    &<<#table as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(value)?,
                    flags
                ).await?;
                Ok(())
            }

            pub async fn #read_fn <K: mdbx_derive::mdbx::TransactionKind>
            (
                &self,
                tx: &mdbx_derive::mdbx::TransactionAny<K>,
                key: &<#table as mdbx_derive::MDBXTable>::Key
            ) -> Result<Option< <#table as mdbx_derive::MDBXTable>::Value >, mdbx_derive::Error> {
                let v = tx.get::<Vec<u8>>(
                    self.#field,
                    &<<#table as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                ).await?;
                if let Some(v) = v {
                    Ok(Some(<<#table as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectDecode>::table_decode(&v)?))
                } else {
                    Ok(None)
                }
            }

            pub async fn #del_fn
            (
                &self,
                tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
                key: &<#table as mdbx_derive::MDBXTable>::Key,
                value: Option<&<#table as mdbx_derive::MDBXTable>::Value>
            ) -> Result<bool, mdbx_derive::Error> {
                let v = value.map(|v| <<#table as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(v))
                        .transpose()?;
                Ok(tx.del(
                    self.#field,
                    &<<#table as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                    v.as_ref().map(|t| t.as_slice())
                ).await?)
            }

            pub async fn #cursor_fn <K: mdbx_derive::mdbx::TransactionKind>
            (
                &self,
                tx: &mdbx_derive::mdbx::TransactionAny<K>
            ) -> Result<mdbx_derive::mdbx::CursorAny<K>, mdbx_derive::Error> {
                Ok(tx.cursor_with_dbi(self.#field).await?)
            }
        });
    }

    let table_types = tables.iter();
    let output = quote! {
        #[derive(Debug, Clone, Copy)]
        pub struct #struct_name {
            #( #struct_fields )*
        }

        impl #struct_name {
            pub async fn new(
                env: &mdbx_derive::mdbx::EnvironmentAny,
            ) -> Result<Self, #error_type> {
                let tx = env.begin_rw_txn().await?;
                let dbis = Self {
                    #( #create_fields )*
                };
                tx.commit().await?;
                Ok(dbis)
            }

            pub async fn new_ro<K: mdbx_derive::mdbx::TransactionKind>(
                tx: &mdbx_derive::mdbx::TransactionAny<K>
            ) -> Result<Self, #error_type> {
                Ok(Self {
                    #( #open_fields )*
                })
            }

            #( #accessors )*
        }

        impl mdbx_derive::HasMDBXTables for #struct_name {
            type Error = #error_type;
            type Tables = mdbx_derive::tuple_list_type!(#( #table_types ),*);
        }
    };

    output.into()
}

/// Returns the last path segment's name (e.g. `UserInfo`) and the snake_case field identifier
/// derived from it (`user_info`) for one `generate_dbi_struct!` table argument. Returns an
/// error spanned on the argument when it is not a type path.
#[cfg(feature = "mdbx")]
fn table_names(table_type: &Type) -> syn::Result<(String, Ident)> {
    let segment = match table_type {
        Type::Path(tp) => tp.path.segments.last(),
        _ => None,
    }
    .ok_or_else(|| syn::Error::new_spanned(table_type, "Expected a type path"))?;
    let type_name = segment.ident.to_string();
    let field = Ident::new(&type_name.to_snake_case(), proc_macro2::Span::call_site());
    Ok((type_name, field))
}