mdbx_derive_macros/
lib.rs1use itertools::Itertools;
2use proc_macro::TokenStream;
3use quote::{quote, quote_spanned};
4use syn::{Data, DeriveInput, Fields, Index, parse_macro_input, spanned::Spanned};
5
6#[proc_macro_derive(KeyObject)]
7pub fn derive(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 let decode = decode_impl(&input);
10 let ident = input.ident;
12 let ts = match &input.data {
13 Data::Struct(st) => match &st.fields {
14 Fields::Named(fields) => {
15 let recur = fields.named.iter().map(|t| {
16 let name = &t.ident;
17 quote_spanned! {t.span()=>
18 self.#name.key_encode()?.into_iter()
19 }
20 });
21 quote! {
22 [#(#recur),*].into_iter().flatten().collect()
23 }
24 }
25 Fields::Unnamed(fields) => {
26 let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
27 let index = Index::from(idx);
28 quote_spanned! {t.span()=>
29 self.#index.key_encode()?.into_iter()
30 }
31 });
32 quote! {
33 [#(#recur),*].into_iter().flatten().collect()
34 }
35 }
36 _ => quote! {
37 compile_error!("Not supported")
38 },
39 },
40 _ => quote! {
41 compile_error!("Not supported struct")
42 },
43 };
44 let output = quote! {
45 impl mdbx_derive::KeyObjectEncode for #ident {
46 fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
47 Ok(#ts)
48 }
49 }
50
51 #decode
52 };
53 output.into()
54}
55
56fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
57 let ident = &input.ident;
58 let body = match &input.data {
59 Data::Struct(st) => {
60 let mut named = false;
61 let fs = match &st.fields {
62 Fields::Named(fields) => {
63 named = true;
64 Some(fields.named.iter())
65 }
66 Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
67 _ => None,
68 };
69
70 if let Some(fs) = fs {
71 let ranges = fs
72 .clone()
73 .scan(quote! {0}, |acc, x| {
74 let ty = &x.ty;
75 let ret = Some(quote_spanned! {x.span()=>
76 (#acc)..(#acc + #ty::KEYSIZE)
77 });
78
79 *acc = quote! { #acc + #ty::KEYSIZE };
80 ret
81 })
82 .collect_vec();
83 let recur = fs.clone().map(|t| {
84 let ty = &t.ty;
85 quote_spanned! {t.span()=>
86 <#ty>::KEYSIZE
87 }
88 });
89 let tyts = quote! {
90 0 #(+ #recur)*
91 };
92
93 if named {
94 let names = fs.clone().map(|t| {
95 let name = &t.ident;
96 quote_spanned! {t.span()=>
97 #name
98 }
99 });
100 let recur = fs.clone().zip(ranges).map(|(t, idx)| {
101 let name = &t.ident;
102 let ty = &t.ty;
103 quote_spanned! {t.span()=>
104 let #name = #ty::key_decode(bs[#idx].try_into().unwrap())?;
105 }
106 });
107 quote! {
108 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::Corrupted)?;
109 #(#recur)*
110 Ok(Self {
111 #(#names),*
112 })
113 }
114 } else {
115 let recur = fs.zip(ranges).map(|(t, idx)| {
116 let ty = &t.ty;
117 quote_spanned! {t.span()=>
118 #ty::key_decode(bs[#idx].try_into().unwrap())?
119 }
120 });
121
122 quote! {
123 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::Corrupted)?;
124 Ok(Self(#(#recur),*))
125 }
126 }
127 } else {
128 quote! {
129 compile_error("Not supported field")
130 }
131 }
132 }
133 _ => quote! {
134 compile_error!("Not supported struct")
135 },
136 };
137
138 let key_sz = match &input.data {
139 Data::Struct(st) => {
140 let ks = st.fields.iter().map(|f| {
141 let ty = &f.ty;
142 quote_spanned! {f.span()=>
143 <#ty>::KEYSIZE
144 }
145 });
146
147 quote! {
148 0 #(+ #ks)*
149 }
150 }
151 _ => quote! { 0 },
152 };
153
154 let output = quote! {
155 impl mdbx_derive::KeyObjectDecode for #ident {
156 const KEYSIZE: usize = #key_sz ;
157 fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
158 #body
159 }
160 }
161 };
162 output
163}
164
165#[proc_macro_derive(ZstdBincodeObject)]
166pub fn derive_zstd_bindcode(input: TokenStream) -> TokenStream {
167 let input = parse_macro_input!(input as DeriveInput);
168 let ident = input.ident;
169 let output = quote! {
170 impl mdbx_derive::TableObjectDecode for #ident {
171 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
172 let config = mdbx_derive::bincode::config::standard();
173 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
174 mdbx_derive::Error::Zstd(e)
175 })?;
176 Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config)?.0)
177 }
178 }
179
180 impl mdbx_derive::mdbx::TableObject for #ident {
181 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
182 let config = mdbx_derive::bincode::config::standard();
183 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
184 mdbx_derive::mdbx::Error::Corrupted
185 })?;
186 Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?.0)
187 }
188 }
189
190 impl mdbx_derive::TableObjectEncode for #ident {
191 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
192 let config = mdbx_derive::bincode::config::standard();
193 let bs = mdbx_derive::bincode::encode_to_vec(&self, config)?;
194 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
195 mdbx_derive::Error::Zstd(e)
196 })?;
197 Ok(compressed)
198 }
199 }
200 };
201 output.into()
202}
203
204#[cfg(feature = "json")]
205#[proc_macro_derive(ZstdJSONObject)]
206pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
207 let input = parse_macro_input!(input as DeriveInput);
208 let ident = input.ident;
209 let output = quote! {
210 impl mdbx_derive::TableObjectDecode for #ident {
211 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
212 let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
213 mdbx_derive::Error::Zstd(e)
214 })?;
215 Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
216 }
217 }
218
219 impl mdbx_derive::mdbx::TableObject for #ident {
220 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
221 let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
222 mdbx_derive::mdbx::Error::Corrupted
223 })?;
224 Ok(mdbx_derive::json::from_slice(&mut decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?)
225 }
226 }
227
228 impl mdbx_derive::TableObjectEncode for #ident {
229 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
230 let bs = mdbx_derive::json::to_vec(&self)?;
231 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
232 mdbx_derive::Error::Zstd(e)
233 })?;
234 Ok(compressed)
235 }
236 }
237 };
238 output.into()
239}