// etchdb_derive/lib.rs

1//! Derive macros for etchdb.
2//!
3//! Generates `Replayable` and `Transactable` implementations from annotated
4//! structs, eliminating ~60 lines of boilerplate per state type.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, DeriveInput, Fields, PathSegment};
10
11/// Parsed info about one `#[etch(collection = N)]` field.
12struct EtchField {
13    ident: syn::Ident,
14    collection_id: u8,
15    map_kind: MapKind,
16    key_ty: syn::Type,
17    value_ty: syn::Type,
18}
19
/// The concrete map type backing an annotated field.
///
/// Selects which `etchdb` helpers the derives emit: `BTreeMap` fields use
/// `apply_op_with` / `apply_overlay_btree`, `HashMap` fields use
/// `apply_op_hash_with` / `apply_overlay_hash`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MapKind {
    /// The field is declared as `BTreeMap<K, V>`.
    BTreeMap,
    /// The field is declared as `HashMap<K, V>`.
    HashMap,
}
25
26fn parse_etch_fields(input: &DeriveInput) -> syn::Result<Vec<EtchField>> {
27    let data = match &input.data {
28        syn::Data::Struct(s) => s,
29        _ => {
30            return Err(syn::Error::new_spanned(
31                input,
32                "etch derives only work on structs",
33            ))
34        }
35    };
36    let fields = match &data.fields {
37        Fields::Named(f) => &f.named,
38        _ => {
39            return Err(syn::Error::new_spanned(
40                input,
41                "etch derives require named fields",
42            ))
43        }
44    };
45
46    let mut result = Vec::new();
47
48    for field in fields {
49        let mut collection_id: Option<u8> = None;
50
51        for attr in &field.attrs {
52            if !attr.path().is_ident("etch") {
53                continue;
54            }
55            attr.parse_nested_meta(|meta| {
56                if meta.path.is_ident("collection") {
57                    let value = meta.value()?;
58                    let lit: syn::LitInt = value.parse()?;
59                    collection_id = Some(lit.base10_parse()?);
60                    Ok(())
61                } else {
62                    Err(meta.error("expected `collection = N`"))
63                }
64            })?;
65        }
66
67        let Some(id) = collection_id else {
68            continue;
69        };
70
71        let ident = field.ident.clone().unwrap();
72        let (map_kind, key_ty, value_ty) = parse_map_type(&field.ty).ok_or_else(|| {
73            syn::Error::new(
74                field.ty.span(),
75                "expected BTreeMap<K, V> or HashMap<K, V>",
76            )
77        })?;
78
79        result.push(EtchField {
80            ident,
81            collection_id: id,
82            map_kind,
83            key_ty,
84            value_ty,
85        });
86    }
87
88    if result.is_empty() {
89        return Err(syn::Error::new_spanned(
90            input,
91            "no fields annotated with #[etch(collection = N)]",
92        ));
93    }
94
95    // Check for duplicate collection IDs.
96    let mut seen = std::collections::HashSet::new();
97    for f in &result {
98        if !seen.insert(f.collection_id) {
99            return Err(syn::Error::new_spanned(
100                &f.ident,
101                format!("duplicate collection id {}", f.collection_id),
102            ));
103        }
104    }
105
106    Ok(result)
107}
108
109/// Extract (MapKind, K, V) from `BTreeMap<K, V>` or `HashMap<K, V>`.
110fn parse_map_type(ty: &syn::Type) -> Option<(MapKind, syn::Type, syn::Type)> {
111    let path = match ty {
112        syn::Type::Path(p) => &p.path,
113        _ => return None,
114    };
115    let seg: &PathSegment = path.segments.last()?;
116    let kind = match seg.ident.to_string().as_str() {
117        "BTreeMap" => MapKind::BTreeMap,
118        "HashMap" => MapKind::HashMap,
119        _ => return None,
120    };
121    let args = match &seg.arguments {
122        syn::PathArguments::AngleBracketed(a) => a,
123        _ => return None,
124    };
125    let mut types = args.args.iter().filter_map(|a| match a {
126        syn::GenericArgument::Type(t) => Some(t.clone()),
127        _ => None,
128    });
129    let key = types.next()?;
130    let val = types.next()?;
131    Some((kind, key, val))
132}
133
134/// Derive `Replayable` for a struct with `#[etch(collection = N)]` fields.
135///
136/// Generates an `apply` method that routes ops to the correct field based
137/// on the collection id, using `apply_op` for BTreeMap fields and
138/// `apply_op_hash` for HashMap fields.
139#[proc_macro_derive(Replayable, attributes(etch))]
140pub fn derive_replayable(input: TokenStream) -> TokenStream {
141    let input = parse_macro_input!(input as DeriveInput);
142    match derive_replayable_inner(&input) {
143        Ok(ts) => ts.into(),
144        Err(e) => e.to_compile_error().into(),
145    }
146}
147
148fn derive_replayable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
149    let fields = parse_etch_fields(input)?;
150    let name = &input.ident;
151    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
152
153    let arms: Vec<_> = fields
154        .iter()
155        .map(|f| {
156            let id = f.collection_id;
157            let field = &f.ident;
158            let key_ty = &f.key_ty;
159            let apply_fn = match f.map_kind {
160                MapKind::BTreeMap => quote! { etchdb::apply_op_with },
161                MapKind::HashMap => quote! { etchdb::apply_op_hash_with },
162            };
163            quote! {
164                #id => #apply_fn(&mut self.#field, op, |bytes| {
165                    <#key_ty as etchdb::EtchKey>::from_bytes(bytes)
166                })?,
167            }
168        })
169        .collect();
170
171    Ok(quote! {
172        impl #impl_generics etchdb::Replayable for #name #ty_generics #where_clause {
173            fn apply(&mut self, ops: &[etchdb::Op]) -> etchdb::Result<()> {
174                for op in ops {
175                    match op.collection() {
176                        #(#arms)*
177                        _ => {}
178                    }
179                }
180                Ok(())
181            }
182        }
183    })
184}
185
186/// Derive `Transactable` for a struct with `#[etch(collection = N)]` fields.
187///
188/// Generates:
189/// - A transaction struct (`{Name}Tx`) with `Collection` fields
190/// - An overlay struct (`{Name}Overlay`) with `Overlay` fields
191/// - The full `Transactable` trait implementation
192#[proc_macro_derive(Transactable, attributes(etch))]
193pub fn derive_transactable(input: TokenStream) -> TokenStream {
194    let input = parse_macro_input!(input as DeriveInput);
195    match derive_transactable_inner(&input) {
196        Ok(ts) => ts.into(),
197        Err(e) => e.to_compile_error().into(),
198    }
199}
200
201fn derive_transactable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
202    let fields = parse_etch_fields(input)?;
203    let name = &input.ident;
204    let tx_name = format_ident!("{}Tx", name);
205    let overlay_name = format_ident!("{}Overlay", name);
206    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
207
208    // Overlay struct fields.
209    let overlay_fields: Vec<_> = fields
210        .iter()
211        .map(|f| {
212            let ident = &f.ident;
213            let k = &f.key_ty;
214            let v = &f.value_ty;
215            quote! { pub #ident: etchdb::Overlay<#k, #v> }
216        })
217        .collect();
218
219    // Tx struct fields: one Collection per annotated field.
220    let tx_fields: Vec<_> = fields
221        .iter()
222        .map(|f| {
223            let ident = &f.ident;
224            let k = &f.key_ty;
225            let v = &f.value_ty;
226            let m = map_type_tokens(f);
227            quote! { pub #ident: etchdb::Collection<'a, #k, #v, #m> }
228        })
229        .collect();
230
231    // begin_tx: construct Collection for each field.
232    let begin_fields: Vec<_> = fields
233        .iter()
234        .map(|f| {
235            let ident = &f.ident;
236            let id = f.collection_id;
237            quote! { #ident: etchdb::Collection::new(&self.#ident, #id) }
238        })
239        .collect();
240
241    // finish_tx: destructure each Collection into ops + overlay.
242    let finish_lets: Vec<_> = fields
243        .iter()
244        .map(|f| {
245            let ident = &f.ident;
246            let ops_name = format_ident!("{}_ops", ident);
247            let ov_name = format_ident!("{}_ov", ident);
248            quote! {
249                let (#ops_name, #ov_name) = tx.#ident.into_parts();
250                ops.extend(#ops_name);
251            }
252        })
253        .collect();
254
255    let finish_overlay_fields: Vec<_> = fields
256        .iter()
257        .map(|f| {
258            let ident = &f.ident;
259            let ov_name = format_ident!("{}_ov", ident);
260            quote! { #ident: #ov_name }
261        })
262        .collect();
263
264    // apply_overlay: merge each overlay into committed state.
265    let apply_stmts: Vec<_> = fields
266        .iter()
267        .map(|f| {
268            let ident = &f.ident;
269            let merge_fn = match f.map_kind {
270                MapKind::BTreeMap => quote! { etchdb::apply_overlay_btree },
271                MapKind::HashMap => quote! { etchdb::apply_overlay_hash },
272            };
273            quote! { #merge_fn(&mut self.#ident, overlay.#ident); }
274        })
275        .collect();
276
277    Ok(quote! {
278        pub struct #overlay_name {
279            #(#overlay_fields,)*
280        }
281
282        pub struct #tx_name<'a> {
283            #(#tx_fields,)*
284        }
285
286        impl #impl_generics etchdb::Transactable for #name #ty_generics #where_clause {
287            type Tx<'a> = #tx_name<'a>;
288            type Overlay = #overlay_name;
289
290            fn begin_tx(&self) -> #tx_name<'_> {
291                #tx_name {
292                    #(#begin_fields,)*
293                }
294            }
295
296            fn finish_tx(tx: #tx_name<'_>) -> (::std::vec::Vec<etchdb::Op>, #overlay_name) {
297                let mut ops = ::std::vec::Vec::new();
298                #(#finish_lets)*
299                (ops, #overlay_name {
300                    #(#finish_overlay_fields,)*
301                })
302            }
303
304            fn apply_overlay(&mut self, overlay: #overlay_name) {
305                #(#apply_stmts)*
306            }
307        }
308    })
309}
310
311fn map_type_tokens(f: &EtchField) -> proc_macro2::TokenStream {
312    let k = &f.key_ty;
313    let v = &f.value_ty;
314    match f.map_kind {
315        MapKind::BTreeMap => quote! { std::collections::BTreeMap<#k, #v> },
316        MapKind::HashMap => quote! { std::collections::HashMap<#k, #v> },
317    }
318}