1use heck::ToSnakeCase;
2use itertools::Itertools;
3use proc_macro::TokenStream;
4use quote::{quote, quote_spanned};
5use syn::{
6 Data, DeriveInput, Fields, Ident, Index, Token, Type,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10 spanned::Spanned,
11};
12
13#[proc_macro_derive(KeyObject)]
14pub fn derive(input: TokenStream) -> TokenStream {
15 let input = parse_macro_input!(input as DeriveInput);
16 let decode = decode_impl(&input);
17 let ident = input.ident;
19 let ts = match &input.data {
20 Data::Struct(st) => match &st.fields {
21 Fields::Named(fields) => {
22 let recur = fields.named.iter().map(|t| {
23 let name = &t.ident;
24 quote_spanned! {t.span()=>
25 self.#name.key_encode()?.into_iter()
26 }
27 });
28 quote! {
29 [#(#recur),*].into_iter().flatten().collect()
30 }
31 }
32 Fields::Unnamed(fields) => {
33 let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
34 let index = Index::from(idx);
35 quote_spanned! {t.span()=>
36 self.#index.key_encode()?.into_iter()
37 }
38 });
39 quote! {
40 [#(#recur),*].into_iter().flatten().collect()
41 }
42 }
43 _ => quote! {
44 compile_error!("Not supported")
45 },
46 },
47 _ => quote! {
48 compile_error!("Not supported struct")
49 },
50 };
51 let output = quote! {
52 impl mdbx_derive::KeyObjectEncode for #ident {
53 fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
54 Ok(#ts)
55 }
56 }
57
58 impl mdbx_derive::mdbx::TableObject for #ident {
59 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
60 <Self as mdbx_derive::KeyObjectDecode>::key_decode(data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
61 }
62 }
63
64 #decode
65 };
66 output.into()
67}
68
69fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
70 let ident = &input.ident;
71 let body = match &input.data {
72 Data::Struct(st) => {
73 let mut named = false;
74 let fs = match &st.fields {
75 Fields::Named(fields) => {
76 named = true;
77 Some(fields.named.iter())
78 }
79 Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
80 _ => None,
81 };
82
83 if let Some(fs) = fs {
84 let ranges = fs
85 .clone()
86 .scan(quote! {0}, |acc, x| {
87 let ty = &x.ty;
88 let ret = Some(quote_spanned! {x.span()=>
89 (#acc)..(#acc + <#ty>::KEYSIZE)
90 });
91
92 *acc = quote! { #acc + <#ty>::KEYSIZE };
93 ret
94 })
95 .collect_vec();
96 let recur = fs.clone().map(|t| {
97 let ty = &t.ty;
98 quote_spanned! {t.span()=>
99 <#ty>::KEYSIZE
100 }
101 });
102 let tyts = quote! {
103 0 #(+ #recur)*
104 };
105
106 if named {
107 let names = fs.clone().map(|t| {
108 let name = &t.ident;
109 quote_spanned! {t.span()=>
110 #name
111 }
112 });
113 let recur = fs.clone().zip(ranges).map(|(t, idx)| {
114 let name = &t.ident;
115 let ty = &t.ty;
116 quote_spanned! {t.span()=>
117 let #name = <#ty>::key_decode(bs[#idx].try_into().unwrap())?;
118 }
119 });
120 quote! {
121 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
122 #(#recur)*
123 Ok(Self {
124 #(#names),*
125 })
126 }
127 } else {
128 let recur = fs.zip(ranges).map(|(t, idx)| {
129 let ty = &t.ty;
130 quote_spanned! {t.span()=>
131 <#ty>::key_decode(bs[#idx].try_into().unwrap())?
132 }
133 });
134
135 quote! {
136 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
137 Ok(Self(#(#recur),*))
138 }
139 }
140 } else {
141 quote! {
142 compile_error("Not supported field")
143 }
144 }
145 }
146 _ => quote! {
147 compile_error!("Not supported struct")
148 },
149 };
150
151 let key_sz = match &input.data {
152 Data::Struct(st) => {
153 let ks = st.fields.iter().map(|f| {
154 let ty = &f.ty;
155 quote_spanned! {f.span()=>
156 <#ty>::KEYSIZE
157 }
158 });
159
160 quote! {
161 0 #(+ #ks)*
162 }
163 }
164 _ => quote! { 0 },
165 };
166
167 let output = quote! {
168 impl mdbx_derive::KeyObjectDecode for #ident {
169 const KEYSIZE: usize = #key_sz ;
170 fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
171 #body
172 }
173 }
174 };
175 output
176}
177
178#[proc_macro_derive(BcsObject)]
179pub fn derive_bcs_object(input: TokenStream) -> TokenStream {
180 let input = parse_macro_input!(input as DeriveInput);
181 let ident = input.ident;
182 let output = quote! {
183 impl mdbx_derive::TableObjectDecode for #ident {
184 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
185 Ok(mdbx_derive::bcs::from_bytes(&data_val)?)
186 }
187 }
188
189 impl mdbx_derive::mdbx::TableObject for #ident {
190 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
191 mdbx_derive::bcs::from_bytes(&data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
192 }
193 }
194
195 impl mdbx_derive::TableObjectEncode for #ident {
196 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
197 Ok(mdbx_derive::bcs::to_bytes(&self)?)
198 }
199 }
200 };
201 output.into()
202}
203
204#[proc_macro_derive(ZstdBcsObject)]
205pub fn derive_zstd_bcs_object(input: TokenStream) -> TokenStream {
206 let input = parse_macro_input!(input as DeriveInput);
207 let ident = input.ident;
208 let output = quote! {
209 impl mdbx_derive::TableObjectDecode for #ident {
210 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
211 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
212 mdbx_derive::Error::Zstd(e)
213 })?;
214 Ok(mdbx_derive::bcs::from_bytes(&decompressed)?)
215 }
216 }
217
218 impl mdbx_derive::mdbx::TableObject for #ident {
219 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
220 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
221 mdbx_derive::mdbx::Error::Corrupted
222 })?;
223 mdbx_derive::bcs::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
224 }
225 }
226
227 impl mdbx_derive::TableObjectEncode for #ident {
228 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
229 let bs = mdbx_derive::bcs::to_bytes(&self)?;
230 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
231 mdbx_derive::Error::Zstd(e)
232 })?;
233 Ok(compressed)
234 }
235 }
236 };
237 output.into()
238}
239
240#[proc_macro_derive(KeyAsTableObject)]
241pub fn derive_key_table_object(input: TokenStream) -> TokenStream {
242 let input = parse_macro_input!(input as DeriveInput);
243 let ident = input.ident;
244 let output = quote! {
245 impl mdbx_derive::TableObjectDecode for #ident {
246 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
247 <#ident as mdbx_derive::KeyObjectDecode>::key_decode(data_val)
248 }
249 }
250
251 impl mdbx_derive::TableObjectEncode for #ident {
252 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
253 <#ident as mdbx_derive::KeyObjectEncode>::key_encode(self)
254 }
255 }
256 };
257 output.into()
258}
259
260#[proc_macro_derive(ZstdBincodeObject)]
261pub fn derive_zstd_bindcode(input: TokenStream) -> TokenStream {
262 let input = parse_macro_input!(input as DeriveInput);
263 let ident = input.ident;
264 let output = quote! {
265 impl mdbx_derive::TableObjectDecode for #ident {
266 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
267 let config = mdbx_derive::bincode::config::standard();
268 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
269 mdbx_derive::Error::Zstd(e)
270 })?;
271 Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config)?.0)
272 }
273 }
274
275 impl mdbx_derive::mdbx::TableObject for #ident {
276 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
277 let config = mdbx_derive::bincode::config::standard();
278 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
279 mdbx_derive::mdbx::Error::Corrupted
280 })?;
281 Ok(mdbx_derive::bincode::decode_from_slice(&decompressed, config).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)?.0)
282 }
283 }
284
285 impl mdbx_derive::TableObjectEncode for #ident {
286 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
287 let config = mdbx_derive::bincode::config::standard();
288 let bs = mdbx_derive::bincode::encode_to_vec(&self, config)?;
289 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
290 mdbx_derive::Error::Zstd(e)
291 })?;
292 Ok(compressed)
293 }
294 }
295 };
296 output.into()
297}
298
/// Derive macro for `ZstdJSONObject` (only compiled with the `json` feature).
///
/// Emits `mdbx_derive::TableObjectDecode`, `mdbx_derive::mdbx::TableObject` and
/// `mdbx_derive::TableObjectEncode` impls. Values are serialized with
/// `mdbx_derive::json::to_vec` and the bytes are passed through
/// `mdbx_derive::zstd::encode_all` (called with `1` as its second argument); decoding
/// decompresses into a mutable buffer and hands it to `mdbx_derive::json::from_slice`.
/// `table_decode` wraps decompression failures in `mdbx_derive::Error::Zstd`, while
/// `TableObject::decode` reports every failure as `mdbx::Error::Corrupted`.
#[cfg(feature = "json")]
#[proc_macro_derive(ZstdJSONObject)]
pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
    let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput);

    TokenStream::from(quote! {
        impl mdbx_derive::TableObjectDecode for #ident {
            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
            }
        }

        impl mdbx_derive::mdbx::TableObject for #ident {
            fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
                    mdbx_derive::mdbx::Error::Corrupted
                })?;
                mdbx_derive::json::from_slice(&mut decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
            }
        }

        impl mdbx_derive::TableObjectEncode for #ident {
            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
                let bs = mdbx_derive::json::to_vec(&self)?;
                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(compressed)
            }
        }
    })
}
335
336struct MacroInput {
338 struct_name: Ident,
339 error_type: Type,
340 tables: Punctuated<Type, Token![,]>,
341}
342
343impl Parse for MacroInput {
344 fn parse(input: ParseStream) -> syn::Result<Self> {
345 let struct_name: Ident = input.parse()?;
346 input.parse::<Token![,]>()?;
347 let error_type: Type = input.parse()?;
348 input.parse::<Token![,]>()?;
349 let tables = input.parse_terminated(Type::parse, Token![,])?;
350 Ok(MacroInput {
351 struct_name,
352 error_type,
353 tables,
354 })
355 }
356}
357
358#[proc_macro]
359pub fn generate_dbi_struct(input: TokenStream) -> TokenStream {
360 let MacroInput {
361 struct_name,
362 error_type,
363 tables,
364 } = syn::parse_macro_input!(input as MacroInput);
365
366 let field_names: Vec<_> = tables
367 .iter()
368 .map(|table_type| {
369 let type_path = if let Type::Path(tp) = table_type {
370 tp
371 } else {
372 panic!("Expected a type path")
373 };
374 let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
375 let field_name_str = type_ident_str.to_snake_case();
376 Ident::new(&field_name_str, proc_macro2::Span::call_site())
377 })
378 .collect();
379
380 let field_statemens: Vec<_> = tables
381 .iter()
382 .map(|table_type| {
383 let type_path = if let Type::Path(tp) = table_type {
384 tp
385 } else {
386 panic!("Expected a type path")
387 };
388 let ty = type_path.path.segments.last().unwrap().ident.clone();
389 let field_name_str = ty.to_string().to_snake_case();
390 let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
391
392 quote! {
393 let flags = if <#ty as mdbx_derive::MDBXTable>::DUPSORT {
394 mdbx_derive::mdbx::DatabaseFlags::DUP_SORT
395 } else {
396 mdbx_derive::mdbx::DatabaseFlags::default()
397 };
398 let #ident = <#ty as mdbx_derive::MDBXTable>::create_table_tx(&tx, flags).await?;
399
400 }
401 })
402 .collect();
403
404 let ro_field_statemens: Vec<_> = tables
405 .iter()
406 .map(|table_type| {
407 let type_path = if let Type::Path(tp) = table_type {
408 tp
409 } else {
410 panic!("Expected a type path")
411 };
412 let ty = type_path.path.segments.last().unwrap().ident.clone();
413 let field_name_str = ty.to_string().to_snake_case();
414 let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
415
416 quote! {
417 let #ident = <#ty as mdbx_derive::MDBXTable>::open_table_tx(&tx).await?;
418
419 }
420 })
421 .collect();
422
423 let fields = tables
424 .iter()
425 .zip(field_names.iter())
426 .map(|(table_type, field_name)| {
427 let type_path = if let Type::Path(tp) = table_type {
428 tp
429 } else {
430 panic!()
431 };
432 let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
433 let doc_string = format!("DBI handle for the `{}` table.", type_ident_str);
434
435 quote! {
436 #[doc = #doc_string]
437 pub #field_name: u32,
438 }
439 });
440
441 let original_type_names: Vec<_> = tables
442 .iter()
443 .map(|table_type| {
444 let type_path = if let Type::Path(tp) = table_type {
445 tp
446 } else {
447 panic!("Expected a type path")
448 };
449 let type_ident_str = type_path.path.segments.last().unwrap().ident.to_string();
450 Ident::new(&type_ident_str, proc_macro2::Span::call_site())
451 })
452 .collect(); let rw_tables: Vec<_> = tables
455 .iter()
456 .map(|table_type| {
457 let type_path = if let Type::Path(tp) = table_type {
458 tp
459 } else {
460 panic!("Expected a type path")
461 };
462 let ty = type_path.path.segments.last().unwrap().ident.clone();
463 let field_name_str = ty.to_string().to_snake_case();
464 let ident = Ident::new(&field_name_str, proc_macro2::Span::call_site());
465 let wfname_tx = Ident::new(format!("write_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
466 let rfname_tx = Ident::new(format!("read_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
467 let dfname_tx = Ident::new(format!("del_{}_tx", &field_name_str).as_str(), proc_macro2::Span::call_site());
468 let cursor_fname = Ident::new(format!("{}_cursor", &field_name_str).as_str(), proc_macro2::Span::call_site());
469 quote! {
470 pub async fn #wfname_tx
471 (
472 &self,
473 tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
474 key: &<#ty as mdbx_derive::MDBXTable>::Key,
475 value: &<#ty as mdbx_derive::MDBXTable>::Value,
476 flags: mdbx_derive::mdbx::WriteFlags
477 ) -> Result<(), mdbx_derive::Error> {
478 tx.put(
479 self.#ident,
480 &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
481 &<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(value)?,
482 flags
483 ).await?;
484 Ok(())
485 }
486
487 pub async fn #rfname_tx <K: mdbx_derive::mdbx::TransactionKind>
488 (
489 &self,
490 tx: &mdbx_derive::mdbx::TransactionAny<K>,
491 key: &<#ty as mdbx_derive::MDBXTable>::Key
492 ) -> Result<Option< <#ty as mdbx_derive::MDBXTable>::Value >, mdbx_derive::Error> {
493 let v = tx.get::<Vec<u8>>(
494 self.#ident,
495 &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
496 ).await?;
497 if let Some(v) = v {
498 Ok(Some(<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectDecode>::table_decode(&v)?))
499 } else {
500 Ok(None)
501 }
502 }
503
504 pub async fn #dfname_tx
505 (
506 &self,
507 tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
508 key: &<#ty as mdbx_derive::MDBXTable>::Key,
509 value: Option<&<#ty as mdbx_derive::MDBXTable>::Value>
510 ) -> Result<bool, mdbx_derive::Error> {
511 let v = value.map(|v| <<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(v))
512 .transpose()?;
513 Ok(tx.del(
514 self.#ident,
515 &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
516 v.as_ref().map(|t| t.as_slice())
517 ).await?)
518 }
519
520 pub async fn #cursor_fname <K: mdbx_derive::mdbx::TransactionKind>
521 (
522 &self,
523 tx: &mdbx_derive::mdbx::TransactionAny<K>
524 ) -> Result<mdbx_derive::mdbx::CursorAny<K>, mdbx_derive::Error> {
525 Ok(tx.cursor_with_dbi(self.#ident).await?)
526 }
527 }
528 })
529 .collect();
530
531 let output = quote! {
532 #[derive(Debug, Clone, Copy)]
533 pub struct #struct_name {
534 #( #fields )*
535 }
536
537 impl #struct_name {
538 pub async fn new(
539 env: &mdbx_derive::mdbx::EnvironmentAny,
540 ) -> Result<Self, #error_type> {
541 let tx = env.begin_rw_txn().await?;
542
543 #(
544 #field_statemens
545 )*
546
547 tx.commit().await?;
548
549 Ok(Self {
550 #( #field_names, )*
551 })
552 }
553
554 pub async fn new_ro<K: mdbx_derive::mdbx::TransactionKind>(
555 tx: &mdbx_derive::mdbx::TransactionAny<K>
556 ) -> Result<Self, #error_type> {
557
558 #(
559 #ro_field_statemens
560 )*
561
562 Ok(Self {
563 #( #field_names, )*
564 })
565 }
566
567 #(
568 #rw_tables
569 )*
570 }
571
572 impl mdbx_derive::HasMDBXTables for #struct_name {
573 type Error = #error_type;
574 type Tables = mdbx_derive::tuple_list_type!(#( #original_type_names),*);
575 }
576 };
577
578 output.into()
579}