mdbx_derive_macros/
lib.rs

1use heck::ToSnakeCase;
2use itertools::Itertools;
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5use syn::{
6    Data, DeriveInput, Fields, Ident, Index, Token, Type,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    spanned::Spanned,
11};
12
13#[proc_macro_derive(KeyObject)]
14pub fn derive(input: TokenStream) -> TokenStream {
15    let input = parse_macro_input!(input as DeriveInput);
16    let decode = decode_impl(&input);
17    // Encode implementation
18    let ident = input.ident;
19    let ts = match &input.data {
20        Data::Struct(st) => match &st.fields {
21            Fields::Named(fields) => {
22                let recur = fields.named.iter().map(|t| {
23                    let name = &t.ident;
24                    quote_spanned! {t.span()=>
25                        self.#name.key_encode()?.into_iter()
26                    }
27                });
28                quote! {
29                    [#(#recur),*].into_iter().flatten().collect()
30                }
31            }
32            Fields::Unnamed(fields) => {
33                let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
34                    let index = Index::from(idx);
35                    quote_spanned! {t.span()=>
36                        self.#index.key_encode()?.into_iter()
37                    }
38                });
39                quote! {
40                    [#(#recur),*].into_iter().flatten().collect()
41                }
42            }
43            _ => quote! {
44                compile_error!("Not supported")
45            },
46        },
47        _ => quote! {
48            compile_error!("Not supported struct")
49        },
50    };
51    let output = quote! {
52        impl mdbx_derive::KeyObjectEncode for #ident {
53            fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
54                Ok(#ts)
55            }
56        }
57
58        impl mdbx_derive::mdbx::TableObject for #ident {
59            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
60                <Self as mdbx_derive::KeyObjectDecode>::key_decode(data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
61            }
62        }
63
64        #decode
65    };
66    output.into()
67}
68
69fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
70    let ident = &input.ident;
71    let body = match &input.data {
72        Data::Struct(st) => {
73            let mut named = false;
74            let fs = match &st.fields {
75                Fields::Named(fields) => {
76                    named = true;
77                    Some(fields.named.iter())
78                }
79                Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
80                _ => None,
81            };
82
83            if let Some(fs) = fs {
84                let ranges = fs
85                    .clone()
86                    .scan(quote! {0}, |acc, x| {
87                        let ty = &x.ty;
88                        let ret = Some(quote_spanned! {x.span()=>
89                            (#acc)..(#acc + <#ty>::KEYSIZE)
90                        });
91
92                        *acc = quote! { #acc + <#ty>::KEYSIZE };
93                        ret
94                    })
95                    .collect_vec();
96                let recur = fs.clone().map(|t| {
97                    let ty = &t.ty;
98                    quote_spanned! {t.span()=>
99                        <#ty>::KEYSIZE
100                    }
101                });
102                let tyts = quote! {
103                    0 #(+ #recur)*
104                };
105
106                if named {
107                    let names = fs.clone().map(|t| {
108                        let name = &t.ident;
109                        quote_spanned! {t.span()=>
110                            #name
111                        }
112                    });
113                    let recur = fs.clone().zip(ranges).map(|(t, idx)| {
114                        let name = &t.ident;
115                        let ty = &t.ty;
116                        quote_spanned! {t.span()=>
117                            let #name = <#ty>::key_decode(bs[#idx].try_into().unwrap())?;
118                        }
119                    });
120                    quote! {
121                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
122                        #(#recur)*
123                        Ok(Self {
124                            #(#names),*
125                        })
126                    }
127                } else {
128                    let recur = fs.zip(ranges).map(|(t, idx)| {
129                        let ty = &t.ty;
130                        quote_spanned! {t.span()=>
131                            <#ty>::key_decode(bs[#idx].try_into().unwrap())?
132                        }
133                    });
134
135                    quote! {
136                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
137                        Ok(Self(#(#recur),*))
138                    }
139                }
140            } else {
141                quote! {
142                    compile_error("Not supported field")
143                }
144            }
145        }
146        _ => quote! {
147            compile_error!("Not supported struct")
148        },
149    };
150
151    let key_sz = match &input.data {
152        Data::Struct(st) => {
153            let ks = st.fields.iter().map(|f| {
154                let ty = &f.ty;
155                quote_spanned! {f.span()=>
156                    <#ty>::KEYSIZE
157                }
158            });
159
160            quote! {
161                0 #(+ #ks)*
162            }
163        }
164        _ => quote! { 0 },
165    };
166
167    let output = quote! {
168        impl mdbx_derive::KeyObjectDecode for #ident {
169            const KEYSIZE: usize = #key_sz ;
170            fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
171                #body
172            }
173        }
174    };
175    output
176}
177
178#[proc_macro_derive(BcsObject)]
179pub fn derive_bcs_object(input: TokenStream) -> TokenStream {
180    let input = parse_macro_input!(input as DeriveInput);
181    let ident = input.ident;
182    let output = quote! {
183        impl mdbx_derive::TableObjectDecode for #ident {
184            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
185                Ok(mdbx_derive::bcs::from_bytes(&data_val)?)
186            }
187        }
188
189        impl mdbx_derive::mdbx::TableObject for #ident {
190            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
191                mdbx_derive::bcs::from_bytes(&data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
192            }
193        }
194
195        impl mdbx_derive::TableObjectEncode for #ident {
196            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
197                Ok(mdbx_derive::bcs::to_bytes(&self)?)
198            }
199        }
200    };
201    output.into()
202}
203
204#[proc_macro_derive(ZstdBcsObject)]
205pub fn derive_zstd_bcs_object(input: TokenStream) -> TokenStream {
206    let input = parse_macro_input!(input as DeriveInput);
207    let ident = input.ident;
208    let output = quote! {
209        impl mdbx_derive::TableObjectDecode for #ident {
210            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
211                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
212                    mdbx_derive::Error::Zstd(e)
213                })?;
214                Ok(mdbx_derive::bcs::from_bytes(&decompressed)?)
215            }
216        }
217
218        impl mdbx_derive::mdbx::TableObject for #ident {
219            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
220                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
221                    mdbx_derive::mdbx::Error::Corrupted
222                })?;
223                mdbx_derive::bcs::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
224            }
225        }
226
227        impl mdbx_derive::TableObjectEncode for #ident {
228            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
229                let bs = mdbx_derive::bcs::to_bytes(&self)?;
230                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
231                    mdbx_derive::Error::Zstd(e)
232                })?;
233                Ok(compressed)
234            }
235        }
236    };
237    output.into()
238}
239
240#[proc_macro_derive(KeyAsTableObject)]
241pub fn derive_key_table_object(input: TokenStream) -> TokenStream {
242    let input = parse_macro_input!(input as DeriveInput);
243    let ident = input.ident;
244    let output = quote! {
245        impl mdbx_derive::TableObjectDecode for #ident {
246            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
247                <#ident as mdbx_derive::KeyObjectDecode>::key_decode(data_val)
248            }
249        }
250
251        impl mdbx_derive::TableObjectEncode for #ident {
252            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
253                <#ident as mdbx_derive::KeyObjectEncode>::key_encode(self)
254            }
255        }
256    };
257    output.into()
258}
259
260#[proc_macro_derive(ZstdBincodeObject)]
261pub fn derive_zstd_bindcode(input: TokenStream) -> TokenStream {
262    let input = parse_macro_input!(input as DeriveInput);
263    let ident = input.ident;
264    let output = quote! {
265        impl mdbx_derive::TableObjectDecode for #ident {
266            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
267                let config = mdbx_derive::bincode::config::standard();
268                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
269                    mdbx_derive::Error::Zstd(e)
270                })?;
271                Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config)?.0)
272            }
273        }
274
275        impl mdbx_derive::mdbx::TableObject for #ident {
276            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
277                let config = mdbx_derive::bincode::config::standard();
278                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
279                    mdbx_derive::mdbx::Error::Corrupted
280                })?;
281                Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?.0)
282            }
283        }
284
285        impl mdbx_derive::TableObjectEncode for #ident {
286            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
287                let config = mdbx_derive::bincode::config::standard();
288                let bs = mdbx_derive::bincode::encode_to_vec(&self, config)?;
289                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
290                    mdbx_derive::Error::Zstd(e)
291                })?;
292                Ok(compressed)
293            }
294        }
295    };
296    output.into()
297}
298
/// Derives zstd-compressed JSON (de)serialization of a table value.
///
/// Only compiled with this crate's `json` feature. Encoding uses
/// `mdbx_derive::json::to_vec`, then compresses with zstd at level 1. Decoding
/// decompresses, then parses with `mdbx_derive::json::from_slice`. Generates three impls:
/// - `TableObjectDecode` and `TableObjectEncode`, where zstd failures become
///   `Error::Zstd`.
/// - `mdbx::TableObject`, whose decode maps every failure to `mdbx::Error::Corrupted`.
#[cfg(feature = "json")]
#[proc_macro_derive(ZstdJSONObject)]
pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let ident = input.ident;
    // `::core` / `::std` paths keep the expansion independent of caller-side aliases
    // such as `type Result<T> = ...`.
    // NOTE(review): `from_slice` is handed `&mut` of the buffer. Presumably the
    // re-exported JSON parser may parse in place — confirm before changing it to `&`.
    let output = quote! {
        impl mdbx_derive::TableObjectDecode for #ident {
            fn table_decode(data_val: &[u8]) -> ::core::result::Result<Self, mdbx_derive::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val)
                    .map_err(mdbx_derive::Error::Zstd)?;
                ::core::result::Result::Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
            }
        }

        impl mdbx_derive::mdbx::TableObject for #ident {
            fn decode(data_val: &[u8]) -> ::core::result::Result<Self, mdbx_derive::mdbx::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val)
                    .map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?;
                mdbx_derive::json::from_slice(&mut decompressed)
                    .map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
            }
        }

        impl mdbx_derive::TableObjectEncode for #ident {
            fn table_encode(&self) -> ::core::result::Result<::std::vec::Vec<u8>, mdbx_derive::Error> {
                let bs = mdbx_derive::json::to_vec(&self)?;
                let compressed = mdbx_derive::zstd::encode_all(::std::io::Cursor::new(bs), 1)
                    .map_err(mdbx_derive::Error::Zstd)?;
                ::core::result::Result::Ok(compressed)
            }
        }
    };
    output.into()
}
335
336// Helper struct to parse the macro's input
337struct MacroInput {
338    struct_name: Ident,
339    error_type: Type,
340    tables: Punctuated<Type, Token![,]>,
341}
342
343impl Parse for MacroInput {
344    fn parse(input: ParseStream) -> syn::Result<Self> {
345        let struct_name: Ident = input.parse()?;
346        input.parse::<Token![,]>()?;
347        let error_type: Type = input.parse()?;
348        input.parse::<Token![,]>()?;
349        let tables = input.parse_terminated(Type::parse, Token![,])?;
350        Ok(MacroInput {
351            struct_name,
352            error_type,
353            tables,
354        })
355    }
356}
357
358#[proc_macro]
359pub fn generate_dbi_struct(input: TokenStream) -> TokenStream {
360    let MacroInput {
361        struct_name,
362        error_type,
363        tables,
364    } = syn::parse_macro_input!(input as MacroInput);
365
366    let field_names: Vec<_> = tables
367        .iter()
368        .map(|table_type| {
369            let type_path = if let Type::Path(tp) = table_type {
370                tp
371            } else {
372                panic!("Expected a type path")
373            };
374            let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
375            let field_name_str = type_ident_str.to_snake_case();
376            Ident::new(&field_name_str, proc_macro2::Span::call_site())
377        })
378        .collect();
379
380    let field_statemens: Vec<_> = tables
381        .iter()
382        .map(|table_type| {
383            let type_path = if let Type::Path(tp) = table_type {
384                tp
385            } else {
386                panic!("Expected a type path")
387            };
388            let ty = type_path.path.segments.last().unwrap().ident.clone();
389            let field_name_str = ty.to_string().to_snake_case();
390            let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
391
392            quote! {
393                let flags = if <#ty as mdbx_derive::MDBXTable>::DUPSORT {
394                    mdbx_derive::mdbx::DatabaseFlags::DUP_SORT
395                } else {
396                    mdbx_derive::mdbx::DatabaseFlags::default()
397                };
398                let #ident = <#ty as mdbx_derive::MDBXTable>::create_table_tx(&tx, flags).await?;
399
400            }
401        })
402        .collect();
403
404    let ro_field_statemens: Vec<_> = tables
405        .iter()
406        .map(|table_type| {
407            let type_path = if let Type::Path(tp) = table_type {
408                tp
409            } else {
410                panic!("Expected a type path")
411            };
412            let ty = type_path.path.segments.last().unwrap().ident.clone();
413            let field_name_str = ty.to_string().to_snake_case();
414            let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
415
416            quote! {
417                let #ident = <#ty as mdbx_derive::MDBXTable>::open_table_tx(&tx).await?;
418
419            }
420        })
421        .collect();
422
423    let fields = tables
424        .iter()
425        .zip(field_names.iter())
426        .map(|(table_type, field_name)| {
427            let type_path = if let Type::Path(tp) = table_type {
428                tp
429            } else {
430                panic!()
431            };
432            let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
433            let doc_string = format!("DBI handle for the `{}` table.", type_ident_str);
434
435            quote! {
436                #[doc = #doc_string]
437                pub #field_name: u32,
438            }
439        });
440
441    let original_type_names: Vec<_> = tables
442        .iter()
443        .map(|table_type| {
444            let type_path = if let Type::Path(tp) = table_type {
445                tp
446            } else {
447                panic!("Expected a type path")
448            };
449            let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
450            Ident::new(&type_ident_str, proc_macro2::Span::call_site())
451        })
452        .collect(); // The crucial change is here!
453
454    let rw_tables: Vec<_> = tables
455        .iter()
456        .map(|table_type| {
457            let type_path = if let Type::Path(tp) = table_type {
458                tp
459            } else {
460                panic!("Expected a type path")
461            };
462            let ty = type_path.path.segments.last().unwrap().ident.clone();
463            let field_name_str = ty.to_string().to_snake_case();
464            let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
465            let wfname_tx = Ident::new(format!("write_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
466            let rfname_tx = Ident::new(format!("read_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
467            let dfname_tx = Ident::new(format!("del_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
468            let cursor_fname = Ident::new(format!("{}_cursor", &field_name_str).as_str(), proc_macro2::Span::call_site());
469            quote! {
470                pub async fn #wfname_tx
471                (
472                    &self,
473                    tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
474                    key: &<#ty as mdbx_derive::MDBXTable>::Key,
475                    value: &<#ty as mdbx_derive::MDBXTable>::Value,
476                    flags: mdbx_derive::mdbx::WriteFlags
477                ) -> Result<(), mdbx_derive::Error> {
478                    tx.put(
479                        self.#ident,
480                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
481                        &<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(value)?,
482                        flags
483                    ).await?;
484                    Ok(())
485                }
486
487                pub async fn #rfname_tx <K: mdbx_derive::mdbx::TransactionKind>
488                (
489                    &self,
490                    tx: &mdbx_derive::mdbx::TransactionAny<K>,
491                    key: &<#ty as mdbx_derive::MDBXTable>::Key
492                ) -> Result<Option< <#ty as mdbx_derive::MDBXTable>::Value >, mdbx_derive::Error> {
493                    let v = tx.get::<Vec<u8>>(
494                        self.#ident,
495                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
496                    ).await?;
497                    if let Some(v) = v {
498                        Ok(Some(<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectDecode>::table_decode(&v)?))
499                    } else {
500                        Ok(None)
501                    }
502                }
503
504                pub async fn #dfname_tx
505                (
506                    &self,
507                    tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
508                    key: &<#ty as mdbx_derive::MDBXTable>::Key,
509                    value: Option<&<#ty as mdbx_derive::MDBXTable>::Value>
510                ) -> Result<bool, mdbx_derive::Error> {
511                    let v = value.map(|v| <<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(v))
512                            .transpose()?;
513                    Ok(tx.del(
514                        self.#ident,
515                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
516                        v.as_ref().map(|t| t.as_slice())
517                    ).await?)
518                }
519
520                pub async fn #cursor_fname <K: mdbx_derive::mdbx::TransactionKind>
521                (
522                    &self,
523                    tx: &mdbx_derive::mdbx::TransactionAny<K>
524                ) -> Result<mdbx_derive::mdbx::CursorAny<K>, mdbx_derive::Error> {
525                    Ok(tx.cursor_with_dbi(self.#ident).await?)
526                }
527            }
528        })
529        .collect();
530
531    let output = quote! {
532        #[derive(Debug, Clone, Copy)]
533        pub struct #struct_name {
534            #( #fields )*
535        }
536
537        impl #struct_name {
538            pub async fn new(
539                env: &mdbx_derive::mdbx::EnvironmentAny,
540            ) -> Result<Self, #error_type> {
541                let tx = env.begin_rw_txn().await?;
542
543                #(
544                    #field_statemens
545                )*
546
547                tx.commit().await?;
548
549                Ok(Self {
550                    #( #field_names, )*
551                })
552            }
553
554            pub async fn new_ro<K: mdbx_derive::mdbx::TransactionKind>(
555                tx: &mdbx_derive::mdbx::TransactionAny<K>
556            ) -> Result<Self, #error_type> {
557
558                #(
559                    #ro_field_statemens
560                )*
561
562                Ok(Self {
563                    #( #field_names, )*
564                })
565            }
566
567            #(
568                #rw_tables
569            )*
570        }
571
572        impl mdbx_derive::HasMDBXTables for #struct_name {
573            type Error = #error_type;
574            type Tables = mdbx_derive::tuple_list_type!(#( #original_type_names),*);
575        }
576    };
577
578    output.into()
579}