1use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, DeriveInput, Fields, PathSegment};
10
11struct EtchField {
13 ident: syn::Ident,
14 collection_id: u8,
15 map_kind: MapKind,
16 key_ty: syn::Type,
17 value_ty: syn::Type,
18}
19
/// The map flavours the derives know how to generate code for.
///
/// The variant is chosen from the last path segment of the field's type,
/// so it selects which `etchdb` helper functions the generated code calls.
#[derive(Clone, Copy, PartialEq)]
enum MapKind {
    /// The field's type path ends in `BTreeMap`.
    BTreeMap,
    /// The field's type path ends in `HashMap`.
    HashMap,
}
25
26fn parse_etch_fields(input: &DeriveInput) -> syn::Result<Vec<EtchField>> {
27 let data = match &input.data {
28 syn::Data::Struct(s) => s,
29 _ => {
30 return Err(syn::Error::new_spanned(
31 input,
32 "etch derives only work on structs",
33 ))
34 }
35 };
36 let fields = match &data.fields {
37 Fields::Named(f) => &f.named,
38 _ => {
39 return Err(syn::Error::new_spanned(
40 input,
41 "etch derives require named fields",
42 ))
43 }
44 };
45
46 let mut result = Vec::new();
47
48 for field in fields {
49 let mut collection_id: Option<u8> = None;
50
51 for attr in &field.attrs {
52 if !attr.path().is_ident("etch") {
53 continue;
54 }
55 attr.parse_nested_meta(|meta| {
56 if meta.path.is_ident("collection") {
57 let value = meta.value()?;
58 let lit: syn::LitInt = value.parse()?;
59 collection_id = Some(lit.base10_parse()?);
60 Ok(())
61 } else {
62 Err(meta.error("expected `collection = N`"))
63 }
64 })?;
65 }
66
67 let Some(id) = collection_id else {
68 continue;
69 };
70
71 let ident = field.ident.clone().unwrap();
72 let (map_kind, key_ty, value_ty) = parse_map_type(&field.ty).ok_or_else(|| {
73 syn::Error::new(
74 field.ty.span(),
75 "expected BTreeMap<K, V> or HashMap<K, V>",
76 )
77 })?;
78
79 result.push(EtchField {
80 ident,
81 collection_id: id,
82 map_kind,
83 key_ty,
84 value_ty,
85 });
86 }
87
88 if result.is_empty() {
89 return Err(syn::Error::new_spanned(
90 input,
91 "no fields annotated with #[etch(collection = N)]",
92 ));
93 }
94
95 let mut seen = std::collections::HashSet::new();
97 for f in &result {
98 if !seen.insert(f.collection_id) {
99 return Err(syn::Error::new_spanned(
100 &f.ident,
101 format!("duplicate collection id {}", f.collection_id),
102 ));
103 }
104 }
105
106 Ok(result)
107}
108
109fn parse_map_type(ty: &syn::Type) -> Option<(MapKind, syn::Type, syn::Type)> {
111 let path = match ty {
112 syn::Type::Path(p) => &p.path,
113 _ => return None,
114 };
115 let seg: &PathSegment = path.segments.last()?;
116 let kind = match seg.ident.to_string().as_str() {
117 "BTreeMap" => MapKind::BTreeMap,
118 "HashMap" => MapKind::HashMap,
119 _ => return None,
120 };
121 let args = match &seg.arguments {
122 syn::PathArguments::AngleBracketed(a) => a,
123 _ => return None,
124 };
125 let mut types = args.args.iter().filter_map(|a| match a {
126 syn::GenericArgument::Type(t) => Some(t.clone()),
127 _ => None,
128 });
129 let key = types.next()?;
130 let val = types.next()?;
131 Some((kind, key, val))
132}
133
134#[proc_macro_derive(Replayable, attributes(etch))]
140pub fn derive_replayable(input: TokenStream) -> TokenStream {
141 let input = parse_macro_input!(input as DeriveInput);
142 match derive_replayable_inner(&input) {
143 Ok(ts) => ts.into(),
144 Err(e) => e.to_compile_error().into(),
145 }
146}
147
148fn derive_replayable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
149 let fields = parse_etch_fields(input)?;
150 let name = &input.ident;
151 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
152
153 let arms: Vec<_> = fields
154 .iter()
155 .map(|f| {
156 let id = f.collection_id;
157 let field = &f.ident;
158 let key_ty = &f.key_ty;
159 let apply_fn = match f.map_kind {
160 MapKind::BTreeMap => quote! { etchdb::apply_op_with },
161 MapKind::HashMap => quote! { etchdb::apply_op_hash_with },
162 };
163 quote! {
164 #id => #apply_fn(&mut self.#field, op, |bytes| {
165 <#key_ty as etchdb::EtchKey>::from_bytes(bytes)
166 })?,
167 }
168 })
169 .collect();
170
171 Ok(quote! {
172 impl #impl_generics etchdb::Replayable for #name #ty_generics #where_clause {
173 fn apply(&mut self, ops: &[etchdb::Op]) -> etchdb::Result<()> {
174 for op in ops {
175 match op.collection() {
176 #(#arms)*
177 _ => {}
178 }
179 }
180 Ok(())
181 }
182 }
183 })
184}
185
186#[proc_macro_derive(Transactable, attributes(etch))]
193pub fn derive_transactable(input: TokenStream) -> TokenStream {
194 let input = parse_macro_input!(input as DeriveInput);
195 match derive_transactable_inner(&input) {
196 Ok(ts) => ts.into(),
197 Err(e) => e.to_compile_error().into(),
198 }
199}
200
201fn derive_transactable_inner(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
202 let fields = parse_etch_fields(input)?;
203 let name = &input.ident;
204 let tx_name = format_ident!("{}Tx", name);
205 let overlay_name = format_ident!("{}Overlay", name);
206 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
207
208 let overlay_fields: Vec<_> = fields
210 .iter()
211 .map(|f| {
212 let ident = &f.ident;
213 let k = &f.key_ty;
214 let v = &f.value_ty;
215 quote! { pub #ident: etchdb::Overlay<#k, #v> }
216 })
217 .collect();
218
219 let tx_fields: Vec<_> = fields
221 .iter()
222 .map(|f| {
223 let ident = &f.ident;
224 let k = &f.key_ty;
225 let v = &f.value_ty;
226 let m = map_type_tokens(f);
227 quote! { pub #ident: etchdb::Collection<'a, #k, #v, #m> }
228 })
229 .collect();
230
231 let begin_fields: Vec<_> = fields
233 .iter()
234 .map(|f| {
235 let ident = &f.ident;
236 let id = f.collection_id;
237 quote! { #ident: etchdb::Collection::new(&self.#ident, #id) }
238 })
239 .collect();
240
241 let finish_lets: Vec<_> = fields
243 .iter()
244 .map(|f| {
245 let ident = &f.ident;
246 let ops_name = format_ident!("{}_ops", ident);
247 let ov_name = format_ident!("{}_ov", ident);
248 quote! {
249 let (#ops_name, #ov_name) = tx.#ident.into_parts();
250 ops.extend(#ops_name);
251 }
252 })
253 .collect();
254
255 let finish_overlay_fields: Vec<_> = fields
256 .iter()
257 .map(|f| {
258 let ident = &f.ident;
259 let ov_name = format_ident!("{}_ov", ident);
260 quote! { #ident: #ov_name }
261 })
262 .collect();
263
264 let apply_stmts: Vec<_> = fields
266 .iter()
267 .map(|f| {
268 let ident = &f.ident;
269 let merge_fn = match f.map_kind {
270 MapKind::BTreeMap => quote! { etchdb::apply_overlay_btree },
271 MapKind::HashMap => quote! { etchdb::apply_overlay_hash },
272 };
273 quote! { #merge_fn(&mut self.#ident, overlay.#ident); }
274 })
275 .collect();
276
277 Ok(quote! {
278 pub struct #overlay_name {
279 #(#overlay_fields,)*
280 }
281
282 pub struct #tx_name<'a> {
283 #(#tx_fields,)*
284 }
285
286 impl #impl_generics etchdb::Transactable for #name #ty_generics #where_clause {
287 type Tx<'a> = #tx_name<'a>;
288 type Overlay = #overlay_name;
289
290 fn begin_tx(&self) -> #tx_name<'_> {
291 #tx_name {
292 #(#begin_fields,)*
293 }
294 }
295
296 fn finish_tx(tx: #tx_name<'_>) -> (::std::vec::Vec<etchdb::Op>, #overlay_name) {
297 let mut ops = ::std::vec::Vec::new();
298 #(#finish_lets)*
299 (ops, #overlay_name {
300 #(#finish_overlay_fields,)*
301 })
302 }
303
304 fn apply_overlay(&mut self, overlay: #overlay_name) {
305 #(#apply_stmts)*
306 }
307 }
308 })
309}
310
311fn map_type_tokens(f: &EtchField) -> proc_macro2::TokenStream {
312 let k = &f.key_ty;
313 let v = &f.value_ty;
314 match f.map_kind {
315 MapKind::BTreeMap => quote! { std::collections::BTreeMap<#k, #v> },
316 MapKind::HashMap => quote! { std::collections::HashMap<#k, #v> },
317 }
318}