mdbx_derive_macros/
lib.rs

1use itertools::Itertools;
2use proc_macro::TokenStream;
3use quote::{quote, quote_spanned};
4use syn::{Data, DeriveInput, Fields, Index, parse_macro_input, spanned::Spanned};
5
6#[proc_macro_derive(KeyObject)]
7pub fn derive(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let decode = decode_impl(&input);
10    // Encode implementation
11    let ident = input.ident;
12    let ts = match &input.data {
13        Data::Struct(st) => match &st.fields {
14            Fields::Named(fields) => {
15                let recur = fields.named.iter().map(|t| {
16                    let name = &t.ident;
17                    quote_spanned! {t.span()=>
18                        self.#name.key_encode()?.into_iter()
19                    }
20                });
21                quote! {
22                    [#(#recur),*].into_iter().flatten().collect()
23                }
24            }
25            Fields::Unnamed(fields) => {
26                let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
27                    let index = Index::from(idx);
28                    quote_spanned! {t.span()=>
29                        self.#index.key_encode()?.into_iter()
30                    }
31                });
32                quote! {
33                    [#(#recur),*].into_iter().flatten().collect()
34                }
35            }
36            _ => quote! {
37                compile_error!("Not supported")
38            },
39        },
40        _ => quote! {
41            compile_error!("Not supported struct")
42        },
43    };
44    let output = quote! {
45        impl mdbx_derive::KeyObjectEncode for #ident {
46            fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
47                Ok(#ts)
48            }
49        }
50
51        #decode
52    };
53    output.into()
54}
55
56fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
57    let ident = &input.ident;
58    let body = match &input.data {
59        Data::Struct(st) => {
60            let mut named = false;
61            let fs = match &st.fields {
62                Fields::Named(fields) => {
63                    named = true;
64                    Some(fields.named.iter())
65                }
66                Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
67                _ => None,
68            };
69
70            if let Some(fs) = fs {
71                let ranges = fs
72                    .clone()
73                    .scan(quote! {0}, |acc, x| {
74                        let ty = &x.ty;
75                        let ret = Some(quote_spanned! {x.span()=>
76                            (#acc)..(#acc + #ty::KEYSIZE)
77                        });
78
79                        *acc = quote! { #acc + #ty::KEYSIZE };
80                        ret
81                    })
82                    .collect_vec();
83                let recur = fs.clone().map(|t| {
84                    let ty = &t.ty;
85                    quote_spanned! {t.span()=>
86                        <#ty>::KEYSIZE
87                    }
88                });
89                let tyts = quote! {
90                    0 #(+ #recur)*
91                };
92
93                if named {
94                    let names = fs.clone().map(|t| {
95                        let name = &t.ident;
96                        quote_spanned! {t.span()=>
97                            #name
98                        }
99                    });
100                    let recur = fs.clone().zip(ranges).map(|(t, idx)| {
101                        let name = &t.ident;
102                        let ty = &t.ty;
103                        quote_spanned! {t.span()=>
104                            let #name = #ty::key_decode(bs[#idx].try_into().unwrap())?;
105                        }
106                    });
107                    quote! {
108                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::Corrupted)?;
109                        #(#recur)*
110                        Ok(Self {
111                            #(#names),*
112                        })
113                    }
114                } else {
115                    let recur = fs.zip(ranges).map(|(t, idx)| {
116                        let ty = &t.ty;
117                        quote_spanned! {t.span()=>
118                            #ty::key_decode(bs[#idx].try_into().unwrap())?
119                        }
120                    });
121
122                    quote! {
123                        let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::Corrupted)?;
124                        Ok(Self(#(#recur),*))
125                    }
126                }
127            } else {
128                quote! {
129                    compile_error("Not supported field")
130                }
131            }
132        }
133        _ => quote! {
134            compile_error!("Not supported struct")
135        },
136    };
137
138    let key_sz = match &input.data {
139        Data::Struct(st) => {
140            let ks = st.fields.iter().map(|f| {
141                let ty = &f.ty;
142                quote_spanned! {f.span()=>
143                    <#ty>::KEYSIZE
144                }
145            });
146
147            quote! {
148                0 #(+ #ks)*
149            }
150        }
151        _ => quote! { 0 },
152    };
153
154    let output = quote! {
155        impl mdbx_derive::KeyObjectDecode for #ident {
156            const KEYSIZE: usize = #key_sz ;
157            fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
158                #body
159            }
160        }
161    };
162    output
163}
164
165#[proc_macro_derive(ZstdBincodeObject)]
166pub fn derive_zstd_bindcode(input: TokenStream) -> TokenStream {
167    let input = parse_macro_input!(input as DeriveInput);
168    let ident = input.ident;
169    let output = quote! {
170        impl mdbx_derive::TableObjectDecode for #ident {
171            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
172                let config = mdbx_derive::bincode::config::standard();
173                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
174                    mdbx_derive::Error::Zstd(e)
175                })?;
176                Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config)?.0)
177            }
178        }
179
180        impl mdbx_derive::mdbx::TableObject for #ident {
181            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
182                let config = mdbx_derive::bincode::config::standard();
183                let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
184                    mdbx_derive::mdbx::Error::Corrupted
185                })?;
186                Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?.0)
187            }
188        }
189
190        impl mdbx_derive::TableObjectEncode for #ident {
191            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
192                let config = mdbx_derive::bincode::config::standard();
193                let bs = mdbx_derive::bincode::encode_to_vec(&self, config)?;
194                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
195                    mdbx_derive::Error::Zstd(e)
196                })?;
197                Ok(compressed)
198            }
199        }
200    };
201    output.into()
202}
203
/// Derives zstd-compressed JSON storage for a table value type.
///
/// Only compiled with the `json` feature. Generates three impls for the
/// annotated type (generics are carried over):
/// - `mdbx_derive::TableObjectEncode`: serializes with `mdbx_derive::json::to_vec`,
///   then compresses with zstd at level 1.
/// - `mdbx_derive::TableObjectDecode`: the inverse. zstd failures become
///   `Error::Zstd`; JSON failures convert through `?`.
/// - `mdbx_derive::mdbx::TableObject`: delegates to `table_decode` and reports
///   every failure as `mdbx::Error::Corrupted`.
#[cfg(feature = "json")]
#[proc_macro_derive(ZstdJSONObject)]
pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let ident = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // `::core::result::Result` is spelled out so a user-side `Result` alias
    // cannot break the generated signatures.
    let output = quote! {
        impl #impl_generics mdbx_derive::TableObjectDecode for #ident #ty_generics #where_clause {
            fn table_decode(data_val: &[u8]) -> ::core::result::Result<Self, mdbx_derive::Error> {
                // NOTE(review): the buffer is handed over as `&mut`, presumably
                // because `mdbx_derive::json` may parse in place (simd-json
                // style). Confirm before downgrading to a shared borrow.
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val)
                    .map_err(mdbx_derive::Error::Zstd)?;
                ::core::result::Result::Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
            }
        }

        impl #impl_generics mdbx_derive::mdbx::TableObject for #ident #ty_generics #where_clause {
            fn decode(data_val: &[u8]) -> ::core::result::Result<Self, mdbx_derive::mdbx::Error> {
                <Self as mdbx_derive::TableObjectDecode>::table_decode(data_val)
                    .map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
            }
        }

        impl #impl_generics mdbx_derive::TableObjectEncode for #ident #ty_generics #where_clause {
            fn table_encode(&self) -> ::core::result::Result<::std::vec::Vec<u8>, mdbx_derive::Error> {
                let json = mdbx_derive::json::to_vec(self)?;
                mdbx_derive::zstd::encode_all(json.as_slice(), 1)
                    .map_err(mdbx_derive::Error::Zstd)
            }
        }
    };
    output.into()
}