1use std::collections::HashMap;
2
3use convert_case::{Case, Casing};
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::{quote, quote_spanned};
6use syn::parse::{Parse, ParseStream, Parser, Result};
7use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, Type};
8
9struct Args(Type);
10
11impl Parse for Args {
12 fn parse(input: ParseStream) -> Result<Self> {
13 Ok(Args(input.parse()?))
14 }
15}
16
17#[proc_macro_attribute]
18pub fn crdt(
19 args: proc_macro::TokenStream,
20 input: proc_macro::TokenStream,
21) -> proc_macro::TokenStream {
22 let mut ast = parse_macro_input!(input as DeriveInput);
23 let args = parse_macro_input!(args as Args);
24
25 let v_clock_type = args.0;
26
27 if let syn::Data::Struct(ref mut struct_data) = ast.data {
29 if let syn::Fields::Named(fields) = &mut struct_data.fields {
30 fields.named.push(
31 syn::Field::parse_named
32 .parse2(quote! { v_clock: crdts::VClock<#v_clock_type> })
33 .unwrap(),
34 );
35 } else {
36 panic!("`crdt` can only be used on `struct`s that have named fields");
37 }
38 } else {
39 panic!("`crdt` can only be used on `struct`s");
40 }
41
42 let gen = quote! {
44 #[derive(crdts_macro::CRDT, Default, std::fmt::Debug, Clone, PartialEq, Eq, crdts_macro::serde::Serialize, crdts_macro::serde::Deserialize)]
45 #[serde(crate = "crdts_macro::serde")]
46 #ast
47 };
48
49 gen.into()
50}
51
52#[proc_macro_derive(CRDT)]
53pub fn crdt_macro_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54 let input = syn::parse(input).unwrap();
55 let expanded = impl_crdt_macro(input);
56 proc_macro::TokenStream::from(expanded)
57}
58
59fn impl_crdt_macro(input: syn::DeriveInput) -> TokenStream {
60 let name = &input.ident;
61 let data = &input.data;
62
63 let fields = list_fields(data);
64
65 let m_error_name = Ident::new(&(name.to_string() + "CmRDTError"), Span::call_site());
66 let m_error_enum = build_m_error(&fields);
67
68 let v_error_name = Ident::new(&(name.to_string() + "CvRDTError"), Span::call_site());
69 let v_error_enum = build_v_error(&fields);
70
71 let op_name = Ident::new(&(name.to_string() + "CrdtOp"), Span::call_site());
72 let op_param = build_op(&fields);
73
74 let impl_apply = impl_apply(&fields);
75 let impl_validate = impl_validate(&fields);
76
77 let impl_merge = impl_merge(&fields);
78 let impl_validate_merge = impl_validate_merge(&fields);
79
80 quote! {
81 #[derive(std::fmt::Debug, PartialEq, Eq)]
82 pub enum #m_error_name {
83 NoneOp,
84 #m_error_enum
85 }
86
87 impl std::fmt::Display for #m_error_name {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 std::fmt::Debug::fmt(&self, f)
90 }
91 }
92
93 impl std::error::Error for #m_error_name {}
94
95 #[allow(clippy::type_complexity)]
96 #[derive(std::fmt::Debug, Clone, PartialEq, Eq, crdts_macro::serde::Serialize, crdts_macro::serde::Deserialize)]
97 #[serde(crate = "crdts_macro::serde")]
98 pub struct #op_name {
99 #op_param
100 }
101
102 impl crdts::CmRDT for #name {
103 type Op = #op_name;
104 type Validation = #m_error_name;
105
106 fn apply(&mut self, op: Self::Op) {
107 #impl_apply
108 }
109
110 fn validate_op(&self, op: &Self::Op) -> Result<(), Self::Validation> {
111 #impl_validate
112 }
113 }
114
115 #[derive(std::fmt::Debug, PartialEq, Eq)]
116 pub enum #v_error_name {
117 #v_error_enum
118 }
119
120 impl std::fmt::Display for #v_error_name {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 std::fmt::Debug::fmt(&self, f)
123 }
124 }
125
126 impl std::error::Error for #v_error_name {}
127
128 impl crdts::CvRDT for #name {
129 type Validation = #v_error_name;
130
131 fn validate_merge(&self, other: &Self) -> Result<(), Self::Validation> {
132 #impl_validate_merge
133 Ok(())
134 }
135
136 fn merge(&mut self, other: Self) {
137 #impl_merge
138 }
139 }
140 }
141}
142
143fn list_fields(data: &Data) -> HashMap<String, Type> {
144 if let Data::Struct(DataStruct {
145 fields: Fields::Named(fields),
146 ..
147 }) = data
148 {
149 fields
150 .named
151 .iter()
152 .map(|f| (f.ident.as_ref().unwrap().to_string(), f.ty.clone()))
153 .collect()
154 } else {
155 HashMap::new()
156 }
157}
158
159fn build_m_error(fields: &HashMap<String, Type>) -> TokenStream {
160 fields
161 .iter()
162 .map(|(field_name, field_type)| {
163 let pascal_name = field_name.to_case(Case::Pascal);
164 let name = Ident::new(&pascal_name, Span::call_site());
165 quote_spanned! { Span::call_site() =>
166 #name(<#field_type as crdts::CmRDT>::Validation),
167 }
168 })
169 .collect::<TokenStream>()
170}
171
172fn build_v_error(fields: &HashMap<String, Type>) -> TokenStream {
173 fields
174 .iter()
175 .map(|(name, ty)| {
176 let pascal_name = name.to_case(Case::Pascal);
177 let name = Ident::new(&pascal_name, Span::call_site());
178 quote_spanned! { Span::call_site() =>
179 #name(<#ty as crdts::CvRDT>::Validation),
180 }
181 })
182 .collect::<TokenStream>()
183}
184
185fn build_op(fields: &HashMap<String, Type>) -> TokenStream {
186 let mut tokens = TokenStream::new();
187 for (name, ty) in fields {
188 let (name, is_vclock) = if name == "v_clock" {
189 (Ident::new("dot", Span::call_site()), true)
190 } else {
191 (
192 Ident::new(&format!("{}_op", name), Span::call_site()),
193 false,
194 )
195 };
196 let op_type = if is_vclock {
197 quote! {<#ty as crdts::CmRDT>::Op}
198 } else {
199 quote! {Option<<#ty as crdts::CmRDT>::Op>}
200 };
201 tokens.extend(quote_spanned! {Span::call_site() =>
202 pub #name: #op_type,
203 });
204 }
205 tokens
206}
207
208fn impl_apply(fields: &HashMap<String, Type>) -> TokenStream {
209 let op_params = op_params(fields);
210 let nones = count_none(fields);
211
212 let apply = fields.keys().filter(|f| *f != "v_clock").map(|f| {
213 let field = Ident::new(f, Span::call_site());
214 let op = Ident::new(&(f.to_owned() + "_op"), Span::call_site());
215
216 quote_spanned! { Span::call_site() =>
217 if let Some(#op) = #op {
218 self.#field.apply(#op);
219 }
220 }
221 });
222
223 quote! {
224 let Self::Op { dot, #op_params } = op;
225 if self.v_clock.get(&dot.actor) >= dot.counter {
226 return;
227 }
228 match (#op_params) {
229 (#nones) => return,
230 (#op_params) => { #(#apply)* }
231 }
232 self.v_clock.apply(dot);
233 }
234}
235
236fn impl_validate(fields: &HashMap<String, Type>) -> TokenStream {
237 let op_params = op_params(fields);
238 let nones = count_none(fields);
239
240 let validate = fields.keys().filter(|f| f != &"v_clock").map(|f| {
241 let pascal_name = f.to_case(Case::Pascal);
242 let error_name = Ident::new(&pascal_name, Span::call_site());
243 let field = Ident::new(f, Span::call_site());
244 let op = Ident::new(&(f.to_owned() + "_op"), Span::call_site());
245 quote_spanned! { Span::call_site() =>
246 if let Some(#op) = #op {
247 self.#field.validate_op(#op).map_err(Self::Validation::#error_name)?;
248 }
249 }
250 });
251
252 quote! {
253 let Self::Op {
254 dot,
255 #op_params
256 } = op;
257 self.v_clock.validate_op(dot).map_err(Self::Validation::VClock)?;
258 match (#op_params) {
259 (#nones) => return Err(Self::Validation::NoneOp),
260 (#op_params) => {
261 #(#validate)*
262 return Ok(());
263 }
264 }
265 }
266}
267
268fn impl_merge(fields: &HashMap<String, Type>) -> TokenStream {
269 fields
270 .keys()
271 .map(|f| {
272 let field = Ident::new(f, Span::call_site());
273 quote_spanned! {
274 Span::call_site() => self.#field.merge(other.#field);
275 }
276 })
277 .collect()
278}
279
280fn impl_validate_merge(fields: &HashMap<String, Type>) -> TokenStream {
281 fields
282 .keys()
283 .map(|field| {
284 let error_name = Ident::new(&field.to_case(Case::Pascal), Span::call_site());
285 let field = Ident::new(field, Span::call_site());
286 quote! {
287 self.#field.validate_merge(&other.#field)
288 .map_err(Self::Validation::#error_name)?;
289 }
290 })
291 .collect()
292}
293
294fn count_none(fields: &HashMap<String, Type>) -> TokenStream {
295 fields
296 .keys()
297 .filter(|&f| f != "v_clock")
298 .map(|_| quote!(None,))
299 .collect::<Vec<_>>()
300 .into_iter()
301 .collect::<TokenStream>()
302}
303
304fn op_params(fields: &HashMap<String, Type>) -> TokenStream {
305 fields
306 .keys()
307 .filter(|f| *f != "v_clock")
308 .map(|f| format!("{}_op", f))
309 .map(|i| Ident::new(&i, Span::call_site()))
310 .map(|i| quote!(#i,))
311 .collect()
312}