1#[cfg(feature = "mdbx")]
2use heck::ToSnakeCase;
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use quote::{quote, quote_spanned};
6use syn::{
7 Data, DeriveInput, Fields, Index,
8 parse_macro_input,
9 spanned::Spanned,
10};
11#[cfg(feature = "mdbx")]
12use syn::{
13 Ident, Token, Type,
14 parse::{Parse, ParseStream},
15 punctuated::Punctuated,
16};
17
18#[proc_macro_derive(KeyObject)]
19pub fn derive(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as DeriveInput);
21 let decode = decode_impl(&input);
22 let ident = input.ident;
24 let ts = match &input.data {
25 Data::Struct(st) => match &st.fields {
26 Fields::Named(fields) => {
27 let recur = fields.named.iter().map(|t| {
28 let name = &t.ident;
29 quote_spanned! {t.span()=>
30 self.#name.key_encode()?.into_iter()
31 }
32 });
33 quote! {
34 [#(#recur),*].into_iter().flatten().collect()
35 }
36 }
37 Fields::Unnamed(fields) => {
38 let recur = fields.unnamed.iter().enumerate().map(|(idx, t)| {
39 let index = Index::from(idx);
40 quote_spanned! {t.span()=>
41 self.#index.key_encode()?.into_iter()
42 }
43 });
44 quote! {
45 [#(#recur),*].into_iter().flatten().collect()
46 }
47 }
48 _ => quote! {
49 compile_error!("Not supported")
50 },
51 },
52 _ => quote! {
53 compile_error!("Not supported struct")
54 },
55 };
56 #[cfg(feature = "mdbx")]
57 let table_object_impl = quote! {
58 impl mdbx_derive::mdbx::TableObject for #ident {
59 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
60 <Self as mdbx_derive::KeyObjectDecode>::key_decode(data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
61 }
62 }
63 };
64 #[cfg(not(feature = "mdbx"))]
65 let table_object_impl = quote! {};
66
67 let output = quote! {
68 impl mdbx_derive::KeyObjectEncode for #ident {
69 fn key_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
70 Ok(#ts)
71 }
72 }
73
74 #table_object_impl
75
76 #decode
77 };
78 output.into()
79}
80
81fn decode_impl(input: &DeriveInput) -> proc_macro2::TokenStream {
82 let ident = &input.ident;
83 let body = match &input.data {
84 Data::Struct(st) => {
85 let mut named = false;
86 let fs = match &st.fields {
87 Fields::Named(fields) => {
88 named = true;
89 Some(fields.named.iter())
90 }
91 Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
92 _ => None,
93 };
94
95 if let Some(fs) = fs {
96 let ranges = fs
97 .clone()
98 .scan(quote! {0}, |acc, x| {
99 let ty = &x.ty;
100 let ret = Some(quote_spanned! {x.span()=>
101 (#acc)..(#acc + <#ty>::KEYSIZE)
102 });
103
104 *acc = quote! { #acc + <#ty>::KEYSIZE };
105 ret
106 })
107 .collect_vec();
108 let recur = fs.clone().map(|t| {
109 let ty = &t.ty;
110 quote_spanned! {t.span()=>
111 <#ty>::KEYSIZE
112 }
113 });
114 let tyts = quote! {
115 0 #(+ #recur)*
116 };
117
118 if named {
119 let names = fs.clone().map(|t| {
120 let name = &t.ident;
121 quote_spanned! {t.span()=>
122 #name
123 }
124 });
125 let recur = fs.clone().zip(ranges).map(|(t, idx)| {
126 let name = &t.ident;
127 let ty = &t.ty;
128 quote_spanned! {t.span()=>
129 let #name = <#ty>::key_decode(bs[#idx].try_into().unwrap())?;
130 }
131 });
132 quote! {
133 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
134 #(#recur)*
135 Ok(Self {
136 #(#names),*
137 })
138 }
139 } else {
140 let recur = fs.zip(ranges).map(|(t, idx)| {
141 let ty = &t.ty;
142 quote_spanned! {t.span()=>
143 <#ty>::key_decode(bs[#idx].try_into().unwrap())?
144 }
145 });
146
147 quote! {
148 let bs: [u8; #tyts] = val.try_into().map_err(|_| mdbx_derive::Error::IncorrectSchema(val.to_vec()))?;
149 Ok(Self(#(#recur),*))
150 }
151 }
152 } else {
153 quote! {
154 compile_error("Not supported field")
155 }
156 }
157 }
158 _ => quote! {
159 compile_error!("Not supported struct")
160 },
161 };
162
163 let key_sz = match &input.data {
164 Data::Struct(st) => {
165 let ks = st.fields.iter().map(|f| {
166 let ty = &f.ty;
167 quote_spanned! {f.span()=>
168 <#ty>::KEYSIZE
169 }
170 });
171
172 quote! {
173 0 #(+ #ks)*
174 }
175 }
176 _ => quote! { 0 },
177 };
178
179 let output = quote! {
180 impl mdbx_derive::KeyObjectDecode for #ident {
181 const KEYSIZE: usize = #key_sz ;
182 fn key_decode(val: &[u8]) -> Result<Self, mdbx_derive::Error> {
183 #body
184 }
185 }
186 };
187 output
188}
189
190#[proc_macro_derive(BcsObject)]
191pub fn derive_bcs_object(input: TokenStream) -> TokenStream {
192 let input = parse_macro_input!(input as DeriveInput);
193 let ident = input.ident;
194
195 #[cfg(feature = "mdbx")]
196 let table_object_impl = quote! {
197 impl mdbx_derive::mdbx::TableObject for #ident {
198 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
199 mdbx_derive::bcs::from_bytes(&data_val).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
200 }
201 }
202 };
203 #[cfg(not(feature = "mdbx"))]
204 let table_object_impl = quote! {};
205
206 let output = quote! {
207 impl mdbx_derive::TableObjectDecode for #ident {
208 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
209 Ok(mdbx_derive::bcs::from_bytes(&data_val)?)
210 }
211 }
212
213 #table_object_impl
214
215 impl mdbx_derive::TableObjectEncode for #ident {
216 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
217 Ok(mdbx_derive::bcs::to_bytes(&self)?)
218 }
219 }
220 };
221 output.into()
222}
223
224#[proc_macro_derive(ZstdBcsObject)]
225pub fn derive_zstd_bcs_object(input: TokenStream) -> TokenStream {
226 let input = parse_macro_input!(input as DeriveInput);
227 let ident = input.ident;
228
229 #[cfg(feature = "mdbx")]
230 let table_object_impl = quote! {
231 impl mdbx_derive::mdbx::TableObject for #ident {
232 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
233 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
234 mdbx_derive::mdbx::Error::Corrupted
235 })?;
236 mdbx_derive::bcs::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
237 }
238 }
239 };
240 #[cfg(not(feature = "mdbx"))]
241 let table_object_impl = quote! {};
242
243 let output = quote! {
244 impl mdbx_derive::TableObjectDecode for #ident {
245 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
246 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
247 mdbx_derive::Error::Zstd(e)
248 })?;
249 Ok(mdbx_derive::bcs::from_bytes(&decompressed)?)
250 }
251 }
252
253 #table_object_impl
254
255 impl mdbx_derive::TableObjectEncode for #ident {
256 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
257 let bs = mdbx_derive::bcs::to_bytes(&self)?;
258 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
259 mdbx_derive::Error::Zstd(e)
260 })?;
261 Ok(compressed)
262 }
263 }
264 };
265 output.into()
266}
267
268#[proc_macro_derive(KeyAsTableObject)]
269pub fn derive_key_table_object(input: TokenStream) -> TokenStream {
270 let input = parse_macro_input!(input as DeriveInput);
271 let ident = input.ident;
272 let output = quote! {
273 impl mdbx_derive::TableObjectDecode for #ident {
274 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
275 <#ident as mdbx_derive::KeyObjectDecode>::key_decode(data_val)
276 }
277 }
278
279 impl mdbx_derive::TableObjectEncode for #ident {
280 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
281 <#ident as mdbx_derive::KeyObjectEncode>::key_encode(self)
282 }
283 }
284 };
285 output.into()
286}
287
288#[proc_macro_derive(ZstdPostcardObject)]
289pub fn derive_zstd_postcard(input: TokenStream) -> TokenStream {
290 let input = parse_macro_input!(input as DeriveInput);
291 let ident = input.ident;
292
293 #[cfg(feature = "mdbx")]
294 let table_object_impl = quote! {
295 impl mdbx_derive::mdbx::TableObject for #ident {
296 fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
297 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
298 mdbx_derive::mdbx::Error::Corrupted
299 })?;
300 mdbx_derive::postcard::from_bytes(&decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
301 }
302 }
303 };
304 #[cfg(not(feature = "mdbx"))]
305 let table_object_impl = quote! {};
306
307 let output = quote! {
308 impl mdbx_derive::TableObjectDecode for #ident {
309 fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
310 let decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
311 mdbx_derive::Error::Zstd(e)
312 })?;
313 Ok(mdbx_derive::postcard::from_bytes(&decompressed)?)
314 }
315 }
316
317 #table_object_impl
318
319 impl mdbx_derive::TableObjectEncode for #ident {
320 fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
321 let bs = mdbx_derive::postcard::to_allocvec(&self)?;
322 let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
323 mdbx_derive::Error::Zstd(e)
324 })?;
325 Ok(compressed)
326 }
327 }
328 };
329 output.into()
330}
331
/// Derives zstd-compressed JSON storage (only with the `json` feature).
/// `TableObjectEncode` serializes with `json::to_vec` and compresses at zstd
/// level 1. `TableObjectDecode` decompresses and parses, reporting zstd
/// failures as `mdbx_derive::Error::Zstd`. With the `mdbx` feature it also
/// derives `mdbx_derive::mdbx::TableObject`, mapping every failure to
/// `mdbx_derive::mdbx::Error::Corrupted`.
#[cfg(feature = "json")]
#[proc_macro_derive(ZstdJSONObject)]
pub fn derive_zstd_json(input: TokenStream) -> TokenStream {
    let DeriveInput { ident: target, .. } = parse_macro_input!(input as DeriveInput);

    // NOTE(review): the generated code passes `&mut decompressed`, which suggests
    // `mdbx_derive::json::from_slice` takes a mutable buffer (e.g. simd-json) — confirm.
    let decode_tokens = quote! {
        impl mdbx_derive::TableObjectDecode for #target {
            fn table_decode(data_val: &[u8]) -> Result<Self, mdbx_derive::Error> {
                let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(mdbx_derive::json::from_slice(&mut decompressed)?)
            }
        }
    };

    // Only emitted when this crate is compiled with the `mdbx` feature.
    let mdbx_tokens = if cfg!(feature = "mdbx") {
        quote! {
            impl mdbx_derive::mdbx::TableObject for #target {
                fn decode(data_val: &[u8]) -> Result<Self, mdbx_derive::mdbx::Error> {
                    let mut decompressed = mdbx_derive::zstd::decode_all(data_val).map_err(|_| {
                        mdbx_derive::mdbx::Error::Corrupted
                    })?;
                    mdbx_derive::json::from_slice(&mut decompressed).map_err(|_| mdbx_derive::mdbx::Error::Corrupted)
                }
            }
        }
    } else {
        proc_macro2::TokenStream::new()
    };

    let encode_tokens = quote! {
        impl mdbx_derive::TableObjectEncode for #target {
            fn table_encode(&self) -> Result<Vec<u8>, mdbx_derive::Error> {
                let bs = mdbx_derive::json::to_vec(&self)?;
                let compressed = mdbx_derive::zstd::encode_all(std::io::Cursor::new(bs), 1).map_err(|e| {
                    mdbx_derive::Error::Zstd(e)
                })?;
                Ok(compressed)
            }
        }
    };

    TokenStream::from(quote! {
        #decode_tokens

        #mdbx_tokens

        #encode_tokens
    })
}
376
/// Parsed arguments of `generate_dbi_struct!`:
/// `StructName, ErrorType, Table1, Table2, ...` (trailing comma allowed).
#[cfg(feature = "mdbx")]
struct MacroInput {
    /// Name of the generated struct that holds one DBI handle per table.
    struct_name: Ident,
    /// Error type returned by the generated constructors.
    error_type: Type,
    /// Table types, in the order their fields are generated.
    tables: Punctuated<Type, Token![,]>,
}

#[cfg(feature = "mdbx")]
impl Parse for MacroInput {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let struct_name = input.parse::<Ident>()?;
        let _: Token![,] = input.parse()?;
        let error_type = input.parse::<Type>()?;
        let _: Token![,] = input.parse()?;
        // Everything that remains is a comma-separated list of table types.
        let tables = Punctuated::<Type, Token![,]>::parse_terminated(input)?;
        Ok(Self {
            struct_name,
            error_type,
            tables,
        })
    }
}
399
/// Generates a `Copy` struct holding one `u32` DBI handle per table, together with:
/// - `new(env)`: opens a RW transaction, creates every table (with `DUP_SORT`
///   when the table's `DUPSORT` is true), commits, and returns the handles;
/// - `new_ro(tx)`: opens every table inside an existing transaction;
/// - per-table `write_*_tx`, `read_*_tx`, `del_*_tx` and `*_cursor` helpers;
/// - an `mdbx_derive::HasMDBXTables` impl listing the table types.
///
/// Field and method names come from the snake_case form of each table type's
/// last path segment. The generated code names each table by the full type as
/// written, so path-qualified types such as `tables::Foo` resolve at the call site.
/// A table argument that is not a type path yields a compile error.
#[cfg(feature = "mdbx")]
#[proc_macro]
pub fn generate_dbi_struct(input: TokenStream) -> TokenStream {
    let MacroInput {
        struct_name,
        error_type,
        tables,
    } = syn::parse_macro_input!(input as MacroInput);

    // Resolve each table's final path segment once. A non-path type becomes a
    // spanned compile error instead of a proc-macro panic.
    let type_idents = match tables
        .iter()
        .map(table_type_ident)
        .collect::<syn::Result<Vec<Ident>>>()
    {
        Ok(idents) => idents,
        Err(err) => return err.to_compile_error().into(),
    };

    // Full table types (path and generics preserved) for use in generated code.
    let table_types: Vec<&Type> = tables.iter().collect();

    // `FooTable` -> `foo_table`: the struct field holding that table's DBI.
    let field_names: Vec<Ident> = type_idents
        .iter()
        .map(|ty_ident| {
            Ident::new(
                &ty_ident.to_string().to_snake_case(),
                proc_macro2::Span::call_site(),
            )
        })
        .collect();

    let fields = type_idents
        .iter()
        .zip(&field_names)
        .map(|(ty_ident, field_name)| {
            let doc_string = format!("DBI handle for the `{}` table.", ty_ident);
            quote! {
                #[doc = #doc_string]
                pub #field_name: u32,
            }
        });

    let create_statements = table_types
        .iter()
        .zip(&field_names)
        .map(|(ty, field_name)| {
            quote! {
                let flags = if <#ty as mdbx_derive::MDBXTable>::DUPSORT {
                    mdbx_derive::mdbx::DatabaseFlags::DUP_SORT
                } else {
                    mdbx_derive::mdbx::DatabaseFlags::default()
                };
                let #field_name = <#ty as mdbx_derive::MDBXTable>::create_table_tx(&tx, flags).await?;
            }
        });

    let open_statements = table_types
        .iter()
        .zip(&field_names)
        .map(|(ty, field_name)| {
            quote! {
                let #field_name = <#ty as mdbx_derive::MDBXTable>::open_table_tx(&tx).await?;
            }
        });

    let accessors = table_types
        .iter()
        .zip(&field_names)
        .map(|(ty, field_name)| {
            let snake = field_name.to_string();
            let write_fn = Ident::new(&format!("write_{}_tx", snake), proc_macro2::Span::call_site());
            let read_fn = Ident::new(&format!("read_{}_tx", snake), proc_macro2::Span::call_site());
            let del_fn = Ident::new(&format!("del_{}_tx", snake), proc_macro2::Span::call_site());
            let cursor_fn = Ident::new(&format!("{}_cursor", snake), proc_macro2::Span::call_site());
            quote! {
                pub async fn #write_fn
                (
                    &self,
                    tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
                    key: &<#ty as mdbx_derive::MDBXTable>::Key,
                    value: &<#ty as mdbx_derive::MDBXTable>::Value,
                    flags: mdbx_derive::mdbx::WriteFlags
                ) -> Result<(), mdbx_derive::Error> {
                    tx.put(
                        self.#field_name,
                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                        &<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(value)?,
                        flags
                    ).await?;
                    Ok(())
                }

                pub async fn #read_fn <K: mdbx_derive::mdbx::TransactionKind>
                (
                    &self,
                    tx: &mdbx_derive::mdbx::TransactionAny<K>,
                    key: &<#ty as mdbx_derive::MDBXTable>::Key
                ) -> Result<Option< <#ty as mdbx_derive::MDBXTable>::Value >, mdbx_derive::Error> {
                    let v = tx.get::<Vec<u8>>(
                        self.#field_name,
                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                    ).await?;
                    if let Some(v) = v {
                        Ok(Some(<<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectDecode>::table_decode(&v)?))
                    } else {
                        Ok(None)
                    }
                }

                pub async fn #del_fn
                (
                    &self,
                    tx: &mdbx_derive::mdbx::TransactionAny<mdbx_derive::mdbx::RW>,
                    key: &<#ty as mdbx_derive::MDBXTable>::Key,
                    value: Option<&<#ty as mdbx_derive::MDBXTable>::Value>
                ) -> Result<bool, mdbx_derive::Error> {
                    let v = value.map(|v| <<#ty as mdbx_derive::MDBXTable>::Value as mdbx_derive::TableObjectEncode>::table_encode(v))
                        .transpose()?;
                    Ok(tx.del(
                        self.#field_name,
                        &<<#ty as mdbx_derive::MDBXTable>::Key as mdbx_derive::KeyObjectEncode>::key_encode(key)?,
                        v.as_ref().map(|t| t.as_slice())
                    ).await?)
                }

                pub async fn #cursor_fn <K: mdbx_derive::mdbx::TransactionKind>
                (
                    &self,
                    tx: &mdbx_derive::mdbx::TransactionAny<K>
                ) -> Result<mdbx_derive::mdbx::CursorAny<K>, mdbx_derive::Error> {
                    Ok(tx.cursor_with_dbi(self.#field_name).await?)
                }
            }
        });

    let output = quote! {
        #[derive(Debug, Clone, Copy)]
        pub struct #struct_name {
            #( #fields )*
        }

        impl #struct_name {
            pub async fn new(
                env: &mdbx_derive::mdbx::EnvironmentAny,
            ) -> Result<Self, #error_type> {
                let tx = env.begin_rw_txn().await?;

                #( #create_statements )*

                tx.commit().await?;

                Ok(Self {
                    #( #field_names, )*
                })
            }

            pub async fn new_ro<K: mdbx_derive::mdbx::TransactionKind>(
                tx: &mdbx_derive::mdbx::TransactionAny<K>
            ) -> Result<Self, #error_type> {
                #( #open_statements )*

                Ok(Self {
                    #( #field_names, )*
                })
            }

            #( #accessors )*
        }

        impl mdbx_derive::HasMDBXTables for #struct_name {
            type Error = #error_type;
            type Tables = mdbx_derive::tuple_list_type!(#( #table_types ),*);
        }
    };

    output.into()
}

/// Returns the identifier of a table type's final path segment (`Foo` for
/// `crate::tables::Foo`), used to derive field and method names.
#[cfg(feature = "mdbx")]
fn table_type_ident(table_type: &Type) -> syn::Result<Ident> {
    match table_type {
        // Types forwarded through a `macro_rules!` `$t:ty` arrive wrapped in an
        // invisible group; look through it.
        Type::Group(group) => table_type_ident(&group.elem),
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.clone())
            .ok_or_else(|| syn::Error::new_spanned(table_type, "Expected a type path")),
        _ => Err(syn::Error::new_spanned(table_type, "Expected a type path")),
    }
}
622}