Skip to main content

etchdb_derive/
lib.rs

//! Derive macros for etchdb.
//!
//! Generates `Replayable` and `Transactable` implementations for structs whose
//! map fields carry `#[etch(collection = N)]`, eliminating ~60 lines of
//! boilerplate per state type.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::spanned::Spanned;
9use syn::{DeriveInput, Fields, PathSegment, parse_macro_input};
10
11/// Parsed info about one `#[etch(collection = N)]` field.
12struct EtchField {
13    ident: syn::Ident,
14    collection_id: u8,
15    map_kind: MapKind,
16    key_ty: syn::Type,
17    value_ty: syn::Type,
18}
19
/// Which standard-library map type an `#[etch]` field is declared as.
///
/// Selects the replay helper (`apply_op_with` / `apply_op_hash_with`), the
/// overlay merge helper (`apply_overlay_btree` / `apply_overlay_hash`), and
/// the concrete map type named in the generated transaction struct.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MapKind {
    /// The field is a `BTreeMap<K, V>`.
    BTreeMap,
    /// The field is a `HashMap<K, V>`.
    HashMap,
}
25
26fn parse_etch_fields(input: &DeriveInput) -> syn::Result<Vec<EtchField>> {
27    let data = match &input.data {
28        syn::Data::Struct(s) => s,
29        _ => {
30            return Err(syn::Error::new_spanned(
31                input,
32                "etch derives only work on structs",
33            ));
34        }
35    };
36    let fields = match &data.fields {
37        Fields::Named(f) => &f.named,
38        _ => {
39            return Err(syn::Error::new_spanned(
40                input,
41                "etch derives require named fields",
42            ));
43        }
44    };
45
46    let mut result = Vec::new();
47
48    for field in fields {
49        let mut collection_id: Option<u8> = None;
50
51        for attr in &field.attrs {
52            if !attr.path().is_ident("etch") {
53                continue;
54            }
55            attr.parse_nested_meta(|meta| {
56                if meta.path.is_ident("collection") {
57                    let value = meta.value()?;
58                    let lit: syn::LitInt = value.parse()?;
59                    collection_id = Some(lit.base10_parse()?);
60                    Ok(())
61                } else {
62                    Err(meta.error("expected `collection = N`"))
63                }
64            })?;
65        }
66
67        let Some(id) = collection_id else {
68            continue;
69        };
70
71        let ident = field.ident.clone().unwrap();
72        let (map_kind, key_ty, value_ty) = parse_map_type(&field.ty).ok_or_else(|| {
73            syn::Error::new(field.ty.span(), "expected BTreeMap<K, V> or HashMap<K, V>")
74        })?;
75
76        result.push(EtchField {
77            ident,
78            collection_id: id,
79            map_kind,
80            key_ty,
81            value_ty,
82        });
83    }
84
85    if result.is_empty() {
86        return Err(syn::Error::new_spanned(
87            input,
88            "no fields annotated with #[etch(collection = N)]",
89        ));
90    }
91
92    // Check for duplicate collection IDs.
93    let mut seen = std::collections::HashSet::new();
94    for f in &result {
95        if !seen.insert(f.collection_id) {
96            return Err(syn::Error::new_spanned(
97                &f.ident,
98                format!("duplicate collection id {}", f.collection_id),
99            ));
100        }
101    }
102
103    Ok(result)
104}
105
106/// Extract (MapKind, K, V) from `BTreeMap<K, V>` or `HashMap<K, V>`.
107fn parse_map_type(ty: &syn::Type) -> Option<(MapKind, syn::Type, syn::Type)> {
108    let path = match ty {
109        syn::Type::Path(p) => &p.path,
110        _ => return None,
111    };
112    let seg: &PathSegment = path.segments.last()?;
113    let kind = match seg.ident.to_string().as_str() {
114        "BTreeMap" => MapKind::BTreeMap,
115        "HashMap" => MapKind::HashMap,
116        _ => return None,
117    };
118    let args = match &seg.arguments {
119        syn::PathArguments::AngleBracketed(a) => a,
120        _ => return None,
121    };
122    let mut types = args.args.iter().filter_map(|a| match a {
123        syn::GenericArgument::Type(t) => Some(t.clone()),
124        _ => None,
125    });
126    let key = types.next()?;
127    let val = types.next()?;
128    Some((kind, key, val))
129}
130
131/// Derive `Replayable` for a struct with `#[etch(collection = N)]` fields.
132///
133/// Generates an `apply` method that routes ops to the correct field based
134/// on the collection id, using `apply_op` for BTreeMap fields and
135/// `apply_op_hash` for HashMap fields.
136#[proc_macro_derive(Replayable, attributes(etch))]
137pub fn derive_replayable(input: TokenStream) -> TokenStream {
138    let input = parse_macro_input!(input as DeriveInput);
139    match derive_replayable_inner(&input) {
140        Ok(ts) => ts.into(),
141        Err(e) => e.to_compile_error().into(),
142    }
143}
144
145fn derive_replayable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
146    let fields = parse_etch_fields(input)?;
147    let name = &input.ident;
148    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
149
150    let arms: Vec<_> = fields
151        .iter()
152        .map(|f| {
153            let id = f.collection_id;
154            let field = &f.ident;
155            let key_ty = &f.key_ty;
156            let apply_fn = match f.map_kind {
157                MapKind::BTreeMap => quote! { etchdb::apply_op_with },
158                MapKind::HashMap => quote! { etchdb::apply_op_hash_with },
159            };
160            quote! {
161                #id => {
162                    if let Err(e) = #apply_fn(&mut self.#field, op, |bytes| {
163                        <#key_ty as etchdb::EtchKey>::from_bytes(bytes)
164                    }) {
165                        eprintln!("etchdb: skipped op on collection {}: {}", #id, e);
166                    }
167                }
168            }
169        })
170        .collect();
171
172    Ok(quote! {
173        impl #impl_generics etchdb::Replayable for #name #ty_generics #where_clause {
174            fn apply(&mut self, ops: &[etchdb::Op]) -> etchdb::Result<()> {
175                for op in ops {
176                    match op.collection() {
177                        #(#arms)*
178                        _ => {}
179                    }
180                }
181                Ok(())
182            }
183        }
184    })
185}
186
187/// Derive `Transactable` for a struct with `#[etch(collection = N)]` fields.
188///
189/// Generates:
190/// - A transaction struct (`{Name}Tx`) with `Collection` fields
191/// - An overlay struct (`{Name}Overlay`) with `Overlay` fields
192/// - The full `Transactable` trait implementation
193#[proc_macro_derive(Transactable, attributes(etch))]
194pub fn derive_transactable(input: TokenStream) -> TokenStream {
195    let input = parse_macro_input!(input as DeriveInput);
196    match derive_transactable_inner(&input) {
197        Ok(ts) => ts.into(),
198        Err(e) => e.to_compile_error().into(),
199    }
200}
201
202fn derive_transactable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
203    let fields = parse_etch_fields(input)?;
204    let name = &input.ident;
205    let tx_name = format_ident!("{}Tx", name);
206    let overlay_name = format_ident!("{}Overlay", name);
207    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
208
209    // Overlay struct fields.
210    let overlay_fields: Vec<_> = fields
211        .iter()
212        .map(|f| {
213            let ident = &f.ident;
214            let k = &f.key_ty;
215            let v = &f.value_ty;
216            quote! { pub #ident: etchdb::Overlay<#k, #v> }
217        })
218        .collect();
219
220    // Tx struct fields: one Collection per annotated field.
221    let tx_fields: Vec<_> = fields
222        .iter()
223        .map(|f| {
224            let ident = &f.ident;
225            let k = &f.key_ty;
226            let v = &f.value_ty;
227            let m = map_type_tokens(f);
228            quote! { pub #ident: etchdb::Collection<'a, #k, #v, #m> }
229        })
230        .collect();
231
232    // begin_tx: construct Collection for each field.
233    let begin_fields: Vec<_> = fields
234        .iter()
235        .map(|f| {
236            let ident = &f.ident;
237            let id = f.collection_id;
238            quote! { #ident: etchdb::Collection::new(&self.#ident, #id) }
239        })
240        .collect();
241
242    // finish_tx: destructure each Collection into ops + overlay.
243    let finish_lets: Vec<_> = fields
244        .iter()
245        .map(|f| {
246            let ident = &f.ident;
247            let ops_name = format_ident!("{}_ops", ident);
248            let ov_name = format_ident!("{}_ov", ident);
249            quote! {
250                let (#ops_name, #ov_name) = tx.#ident.into_parts();
251                ops.extend(#ops_name);
252            }
253        })
254        .collect();
255
256    let finish_overlay_fields: Vec<_> = fields
257        .iter()
258        .map(|f| {
259            let ident = &f.ident;
260            let ov_name = format_ident!("{}_ov", ident);
261            quote! { #ident: #ov_name }
262        })
263        .collect();
264
265    // apply_overlay: merge each overlay into committed state.
266    let apply_stmts: Vec<_> = fields
267        .iter()
268        .map(|f| {
269            let ident = &f.ident;
270            let merge_fn = match f.map_kind {
271                MapKind::BTreeMap => quote! { etchdb::apply_overlay_btree },
272                MapKind::HashMap => quote! { etchdb::apply_overlay_hash },
273            };
274            quote! { #merge_fn(&mut self.#ident, overlay.#ident); }
275        })
276        .collect();
277
278    Ok(quote! {
279        pub struct #overlay_name {
280            #(#overlay_fields,)*
281        }
282
283        pub struct #tx_name<'a> {
284            #(#tx_fields,)*
285        }
286
287        impl #impl_generics etchdb::Transactable for #name #ty_generics #where_clause {
288            type Tx<'a> = #tx_name<'a>;
289            type Overlay = #overlay_name;
290
291            fn begin_tx(&self) -> #tx_name<'_> {
292                #tx_name {
293                    #(#begin_fields,)*
294                }
295            }
296
297            fn finish_tx(tx: #tx_name<'_>) -> (::std::vec::Vec<etchdb::Op>, #overlay_name) {
298                let mut ops = ::std::vec::Vec::new();
299                #(#finish_lets)*
300                (ops, #overlay_name {
301                    #(#finish_overlay_fields,)*
302                })
303            }
304
305            fn apply_overlay(&mut self, overlay: #overlay_name) {
306                #(#apply_stmts)*
307            }
308        }
309    })
310}
311
312fn map_type_tokens(f: &EtchField) -> proc_macro2::TokenStream {
313    let k = &f.key_ty;
314    let v = &f.value_ty;
315    match f.map_kind {
316        MapKind::BTreeMap => quote! { std::collections::BTreeMap<#k, #v> },
317        MapKind::HashMap => quote! { std::collections::HashMap<#k, #v> },
318    }
319}