crdts_macro_derive/
lib.rs

1use std::collections::HashMap;
2
3use convert_case::{Case, Casing};
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::{quote, quote_spanned};
6use syn::parse::{Parse, ParseStream, Parser, Result};
7use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, Type};
8
9struct Args(Type);
10
11impl Parse for Args {
12    fn parse(input: ParseStream) -> Result<Self> {
13        Ok(Args(input.parse()?))
14    }
15}
16
17#[proc_macro_attribute]
18pub fn crdt(
19    args: proc_macro::TokenStream,
20    input: proc_macro::TokenStream,
21) -> proc_macro::TokenStream {
22    let mut ast = parse_macro_input!(input as DeriveInput);
23    let args = parse_macro_input!(args as Args);
24
25    let v_clock_type = args.0;
26
27    // If the struct has named fields, add a v_clock field to it
28    if let syn::Data::Struct(ref mut struct_data) = ast.data {
29        if let syn::Fields::Named(fields) = &mut struct_data.fields {
30            fields.named.push(
31                syn::Field::parse_named
32                    .parse2(quote! { v_clock: crdts::VClock<#v_clock_type> })
33                    .unwrap(),
34            );
35        } else {
36            panic!("`crdt` can only be used on `struct`s that have named fields");
37        }
38    } else {
39        panic!("`crdt` can only be used on `struct`s");
40    }
41
42    // add `CRDT` derive for the struct
43    let gen = quote! {
44        #[derive(crdts_macro::CRDT, Default, std::fmt::Debug, Clone, PartialEq, Eq, crdts_macro::serde::Serialize, crdts_macro::serde::Deserialize)]
45        #[serde(crate = "crdts_macro::serde")]
46        #ast
47    };
48
49    gen.into()
50}
51
52#[proc_macro_derive(CRDT)]
53pub fn crdt_macro_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54    let input = syn::parse(input).unwrap();
55    let expanded = impl_crdt_macro(input);
56    proc_macro::TokenStream::from(expanded)
57}
58
59fn impl_crdt_macro(input: syn::DeriveInput) -> TokenStream {
60    let name = &input.ident;
61    let data = &input.data;
62
63    let fields = list_fields(data);
64
65    let m_error_name = Ident::new(&(name.to_string() + "CmRDTError"), Span::call_site());
66    let m_error_enum = build_m_error(&fields);
67
68    let v_error_name = Ident::new(&(name.to_string() + "CvRDTError"), Span::call_site());
69    let v_error_enum = build_v_error(&fields);
70
71    let op_name = Ident::new(&(name.to_string() + "CrdtOp"), Span::call_site());
72    let op_param = build_op(&fields);
73
74    let impl_apply = impl_apply(&fields);
75    let impl_validate = impl_validate(&fields);
76
77    let impl_merge = impl_merge(&fields);
78    let impl_validate_merge = impl_validate_merge(&fields);
79
80    quote! {
81        #[derive(std::fmt::Debug, PartialEq, Eq)]
82        pub enum #m_error_name {
83            NoneOp,
84            #m_error_enum
85        }
86
87        impl std::fmt::Display for #m_error_name {
88            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89                std::fmt::Debug::fmt(&self, f)
90            }
91        }
92
93        impl std::error::Error for #m_error_name {}
94
95        #[allow(clippy::type_complexity)]
96        #[derive(std::fmt::Debug, Clone, PartialEq, Eq, crdts_macro::serde::Serialize, crdts_macro::serde::Deserialize)]
97        #[serde(crate = "crdts_macro::serde")]
98        pub struct #op_name {
99            #op_param
100        }
101
102        impl crdts::CmRDT for #name {
103            type Op = #op_name;
104            type Validation = #m_error_name;
105
106            fn apply(&mut self, op: Self::Op) {
107                #impl_apply
108            }
109
110            fn validate_op(&self, op: &Self::Op) -> Result<(), Self::Validation> {
111                #impl_validate
112            }
113        }
114
115        #[derive(std::fmt::Debug, PartialEq, Eq)]
116        pub enum #v_error_name {
117            #v_error_enum
118        }
119
120        impl std::fmt::Display for #v_error_name {
121            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122                std::fmt::Debug::fmt(&self, f)
123            }
124        }
125
126        impl std::error::Error for #v_error_name {}
127
128        impl crdts::CvRDT for #name {
129            type Validation = #v_error_name;
130
131            fn validate_merge(&self, other: &Self) -> Result<(), Self::Validation> {
132                #impl_validate_merge
133                Ok(())
134            }
135
136            fn merge(&mut self, other: Self) {
137                #impl_merge
138            }
139        }
140    }
141}
142
143fn list_fields(data: &Data) -> HashMap<String, Type> {
144    if let Data::Struct(DataStruct {
145        fields: Fields::Named(fields),
146        ..
147    }) = data
148    {
149        fields
150            .named
151            .iter()
152            .map(|f| (f.ident.as_ref().unwrap().to_string(), f.ty.clone()))
153            .collect()
154    } else {
155        HashMap::new()
156    }
157}
158
159fn build_m_error(fields: &HashMap<String, Type>) -> TokenStream {
160    fields
161        .iter()
162        .map(|(field_name, field_type)| {
163            let pascal_name = field_name.to_case(Case::Pascal);
164            let name = Ident::new(&pascal_name, Span::call_site());
165            quote_spanned! { Span::call_site() =>
166                #name(<#field_type as crdts::CmRDT>::Validation),
167            }
168        })
169        .collect::<TokenStream>()
170}
171
172fn build_v_error(fields: &HashMap<String, Type>) -> TokenStream {
173    fields
174        .iter()
175        .map(|(name, ty)| {
176            let pascal_name = name.to_case(Case::Pascal);
177            let name = Ident::new(&pascal_name, Span::call_site());
178            quote_spanned! { Span::call_site() =>
179                #name(<#ty as crdts::CvRDT>::Validation),
180            }
181        })
182        .collect::<TokenStream>()
183}
184
185fn build_op(fields: &HashMap<String, Type>) -> TokenStream {
186    let mut tokens = TokenStream::new();
187    for (name, ty) in fields {
188        let (name, is_vclock) = if name == "v_clock" {
189            (Ident::new("dot", Span::call_site()), true)
190        } else {
191            (
192                Ident::new(&format!("{}_op", name), Span::call_site()),
193                false,
194            )
195        };
196        let op_type = if is_vclock {
197            quote! {<#ty as crdts::CmRDT>::Op}
198        } else {
199            quote! {Option<<#ty as crdts::CmRDT>::Op>}
200        };
201        tokens.extend(quote_spanned! {Span::call_site() =>
202            pub #name: #op_type,
203        });
204    }
205    tokens
206}
207
208fn impl_apply(fields: &HashMap<String, Type>) -> TokenStream {
209    let op_params = op_params(fields);
210    let nones = count_none(fields);
211
212    let apply = fields.keys().filter(|f| *f != "v_clock").map(|f| {
213        let field = Ident::new(f, Span::call_site());
214        let op = Ident::new(&(f.to_owned() + "_op"), Span::call_site());
215
216        quote_spanned! { Span::call_site() =>
217            if let Some(#op) = #op {
218                self.#field.apply(#op);
219            }
220        }
221    });
222
223    quote! {
224        let Self::Op { dot, #op_params } = op;
225        if self.v_clock.get(&dot.actor) >= dot.counter {
226            return;
227        }
228        match (#op_params) {
229            (#nones) => return,
230            (#op_params) => { #(#apply)* }
231        }
232        self.v_clock.apply(dot);
233    }
234}
235
236fn impl_validate(fields: &HashMap<String, Type>) -> TokenStream {
237    let op_params = op_params(fields);
238    let nones = count_none(fields);
239
240    let validate = fields.keys().filter(|f| f != &"v_clock").map(|f| {
241        let pascal_name = f.to_case(Case::Pascal);
242        let error_name = Ident::new(&pascal_name, Span::call_site());
243        let field = Ident::new(f, Span::call_site());
244        let op = Ident::new(&(f.to_owned() + "_op"), Span::call_site());
245        quote_spanned! { Span::call_site() =>
246            if let Some(#op) = #op {
247                self.#field.validate_op(#op).map_err(Self::Validation::#error_name)?;
248            }
249        }
250    });
251
252    quote! {
253        let Self::Op {
254            dot,
255            #op_params
256        } = op;
257        self.v_clock.validate_op(dot).map_err(Self::Validation::VClock)?;
258        match (#op_params) {
259            (#nones) => return Err(Self::Validation::NoneOp),
260            (#op_params) => {
261                #(#validate)*
262                return Ok(());
263            }
264        }
265    }
266}
267
268fn impl_merge(fields: &HashMap<String, Type>) -> TokenStream {
269    fields
270        .keys()
271        .map(|f| {
272            let field = Ident::new(f, Span::call_site());
273            quote_spanned! {
274                Span::call_site() => self.#field.merge(other.#field);
275            }
276        })
277        .collect()
278}
279
280fn impl_validate_merge(fields: &HashMap<String, Type>) -> TokenStream {
281    fields
282        .keys()
283        .map(|field| {
284            let error_name = Ident::new(&field.to_case(Case::Pascal), Span::call_site());
285            let field = Ident::new(field, Span::call_site());
286            quote! {
287                self.#field.validate_merge(&other.#field)
288                    .map_err(Self::Validation::#error_name)?;
289            }
290        })
291        .collect()
292}
293
294fn count_none(fields: &HashMap<String, Type>) -> TokenStream {
295    fields
296        .keys()
297        .filter(|&f| f != "v_clock")
298        .map(|_| quote!(None,))
299        .collect::<Vec<_>>()
300        .into_iter()
301        .collect::<TokenStream>()
302}
303
304fn op_params(fields: &HashMap<String, Type>) -> TokenStream {
305    fields
306        .keys()
307        .filter(|f| *f != "v_clock")
308        .map(|f| format!("{}_op", f))
309        .map(|i| Ident::new(&i, Span::call_site()))
310        .map(|i| quote!(#i,))
311        .collect()
312}