1use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::spanned::Spanned;
9use syn::{DeriveInput, Fields, PathSegment, parse_macro_input};
10
/// One struct field annotated with `#[etch(collection = N)]`, as extracted by
/// `parse_etch_fields`.
struct EtchField {
    /// Name of the annotated field.
    ident: syn::Ident,
    /// The `N` from `#[etch(collection = N)]`.
    collection_id: u8,
    /// Which map type the field was declared as.
    map_kind: MapKind,
    /// First type argument of the map (the key type).
    key_ty: syn::Type,
    /// Second type argument of the map (the value type).
    value_ty: syn::Type,
}
19
/// The map types accepted on an annotated field, recognised by the last path
/// segment of the field's declared type.
#[derive(Clone, Copy, PartialEq)]
enum MapKind {
    /// The type's last path segment is `BTreeMap`.
    BTreeMap,
    /// The type's last path segment is `HashMap`.
    HashMap,
}
25
26fn parse_etch_fields(input: &DeriveInput) -> syn::Result<Vec<EtchField>> {
27 let data = match &input.data {
28 syn::Data::Struct(s) => s,
29 _ => {
30 return Err(syn::Error::new_spanned(
31 input,
32 "etch derives only work on structs",
33 ));
34 }
35 };
36 let fields = match &data.fields {
37 Fields::Named(f) => &f.named,
38 _ => {
39 return Err(syn::Error::new_spanned(
40 input,
41 "etch derives require named fields",
42 ));
43 }
44 };
45
46 let mut result = Vec::new();
47
48 for field in fields {
49 let mut collection_id: Option<u8> = None;
50
51 for attr in &field.attrs {
52 if !attr.path().is_ident("etch") {
53 continue;
54 }
55 attr.parse_nested_meta(|meta| {
56 if meta.path.is_ident("collection") {
57 let value = meta.value()?;
58 let lit: syn::LitInt = value.parse()?;
59 collection_id = Some(lit.base10_parse()?);
60 Ok(())
61 } else {
62 Err(meta.error("expected `collection = N`"))
63 }
64 })?;
65 }
66
67 let Some(id) = collection_id else {
68 continue;
69 };
70
71 let ident = field.ident.clone().unwrap();
72 let (map_kind, key_ty, value_ty) = parse_map_type(&field.ty).ok_or_else(|| {
73 syn::Error::new(field.ty.span(), "expected BTreeMap<K, V> or HashMap<K, V>")
74 })?;
75
76 result.push(EtchField {
77 ident,
78 collection_id: id,
79 map_kind,
80 key_ty,
81 value_ty,
82 });
83 }
84
85 if result.is_empty() {
86 return Err(syn::Error::new_spanned(
87 input,
88 "no fields annotated with #[etch(collection = N)]",
89 ));
90 }
91
92 let mut seen = std::collections::HashSet::new();
94 for f in &result {
95 if !seen.insert(f.collection_id) {
96 return Err(syn::Error::new_spanned(
97 &f.ident,
98 format!("duplicate collection id {}", f.collection_id),
99 ));
100 }
101 }
102
103 Ok(result)
104}
105
106fn parse_map_type(ty: &syn::Type) -> Option<(MapKind, syn::Type, syn::Type)> {
108 let path = match ty {
109 syn::Type::Path(p) => &p.path,
110 _ => return None,
111 };
112 let seg: &PathSegment = path.segments.last()?;
113 let kind = match seg.ident.to_string().as_str() {
114 "BTreeMap" => MapKind::BTreeMap,
115 "HashMap" => MapKind::HashMap,
116 _ => return None,
117 };
118 let args = match &seg.arguments {
119 syn::PathArguments::AngleBracketed(a) => a,
120 _ => return None,
121 };
122 let mut types = args.args.iter().filter_map(|a| match a {
123 syn::GenericArgument::Type(t) => Some(t.clone()),
124 _ => None,
125 });
126 let key = types.next()?;
127 let val = types.next()?;
128 Some((kind, key, val))
129}
130
131#[proc_macro_derive(Replayable, attributes(etch))]
137pub fn derive_replayable(input: TokenStream) -> TokenStream {
138 let input = parse_macro_input!(input as DeriveInput);
139 match derive_replayable_inner(&input) {
140 Ok(ts) => ts.into(),
141 Err(e) => e.to_compile_error().into(),
142 }
143}
144
145fn derive_replayable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
146 let fields = parse_etch_fields(input)?;
147 let name = &input.ident;
148 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
149
150 let arms: Vec<_> = fields
151 .iter()
152 .map(|f| {
153 let id = f.collection_id;
154 let field = &f.ident;
155 let key_ty = &f.key_ty;
156 let apply_fn = match f.map_kind {
157 MapKind::BTreeMap => quote! { etchdb::apply_op_with },
158 MapKind::HashMap => quote! { etchdb::apply_op_hash_with },
159 };
160 quote! {
161 #id => {
162 if let Err(e) = #apply_fn(&mut self.#field, op, |bytes| {
163 <#key_ty as etchdb::EtchKey>::from_bytes(bytes)
164 }) {
165 eprintln!("etchdb: skipped op on collection {}: {}", #id, e);
166 }
167 }
168 }
169 })
170 .collect();
171
172 Ok(quote! {
173 impl #impl_generics etchdb::Replayable for #name #ty_generics #where_clause {
174 fn apply(&mut self, ops: &[etchdb::Op]) -> etchdb::Result<()> {
175 for op in ops {
176 match op.collection() {
177 #(#arms)*
178 _ => {}
179 }
180 }
181 Ok(())
182 }
183 }
184 })
185}
186
187#[proc_macro_derive(Transactable, attributes(etch))]
194pub fn derive_transactable(input: TokenStream) -> TokenStream {
195 let input = parse_macro_input!(input as DeriveInput);
196 match derive_transactable_inner(&input) {
197 Ok(ts) => ts.into(),
198 Err(e) => e.to_compile_error().into(),
199 }
200}
201
202fn derive_transactable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
203 let fields = parse_etch_fields(input)?;
204 let name = &input.ident;
205 let tx_name = format_ident!("{}Tx", name);
206 let overlay_name = format_ident!("{}Overlay", name);
207 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
208
209 let overlay_fields: Vec<_> = fields
211 .iter()
212 .map(|f| {
213 let ident = &f.ident;
214 let k = &f.key_ty;
215 let v = &f.value_ty;
216 quote! { pub #ident: etchdb::Overlay<#k, #v> }
217 })
218 .collect();
219
220 let tx_fields: Vec<_> = fields
222 .iter()
223 .map(|f| {
224 let ident = &f.ident;
225 let k = &f.key_ty;
226 let v = &f.value_ty;
227 let m = map_type_tokens(f);
228 quote! { pub #ident: etchdb::Collection<'a, #k, #v, #m> }
229 })
230 .collect();
231
232 let begin_fields: Vec<_> = fields
234 .iter()
235 .map(|f| {
236 let ident = &f.ident;
237 let id = f.collection_id;
238 quote! { #ident: etchdb::Collection::new(&self.#ident, #id) }
239 })
240 .collect();
241
242 let finish_lets: Vec<_> = fields
244 .iter()
245 .map(|f| {
246 let ident = &f.ident;
247 let ops_name = format_ident!("{}_ops", ident);
248 let ov_name = format_ident!("{}_ov", ident);
249 quote! {
250 let (#ops_name, #ov_name) = tx.#ident.into_parts();
251 ops.extend(#ops_name);
252 }
253 })
254 .collect();
255
256 let finish_overlay_fields: Vec<_> = fields
257 .iter()
258 .map(|f| {
259 let ident = &f.ident;
260 let ov_name = format_ident!("{}_ov", ident);
261 quote! { #ident: #ov_name }
262 })
263 .collect();
264
265 let apply_stmts: Vec<_> = fields
267 .iter()
268 .map(|f| {
269 let ident = &f.ident;
270 let merge_fn = match f.map_kind {
271 MapKind::BTreeMap => quote! { etchdb::apply_overlay_btree },
272 MapKind::HashMap => quote! { etchdb::apply_overlay_hash },
273 };
274 quote! { #merge_fn(&mut self.#ident, overlay.#ident); }
275 })
276 .collect();
277
278 Ok(quote! {
279 pub struct #overlay_name {
280 #(#overlay_fields,)*
281 }
282
283 pub struct #tx_name<'a> {
284 #(#tx_fields,)*
285 }
286
287 impl #impl_generics etchdb::Transactable for #name #ty_generics #where_clause {
288 type Tx<'a> = #tx_name<'a>;
289 type Overlay = #overlay_name;
290
291 fn begin_tx(&self) -> #tx_name<'_> {
292 #tx_name {
293 #(#begin_fields,)*
294 }
295 }
296
297 fn finish_tx(tx: #tx_name<'_>) -> (::std::vec::Vec<etchdb::Op>, #overlay_name) {
298 let mut ops = ::std::vec::Vec::new();
299 #(#finish_lets)*
300 (ops, #overlay_name {
301 #(#finish_overlay_fields,)*
302 })
303 }
304
305 fn apply_overlay(&mut self, overlay: #overlay_name) {
306 #(#apply_stmts)*
307 }
308 }
309 })
310}
311
312fn map_type_tokens(f: &EtchField) -> proc_macro2::TokenStream {
313 let k = &f.key_ty;
314 let v = &f.value_ty;
315 match f.map_kind {
316 MapKind::BTreeMap => quote! { std::collections::BTreeMap<#k, #v> },
317 MapKind::HashMap => quote! { std::collections::HashMap<#k, #v> },
318 }
319}