Skip to main content

arbitrary_model_tests/
lib.rs

1extern crate either;
2#[macro_use]
3extern crate quote;
4extern crate proc_macro;
5extern crate syn;
6
7use either::Either;
8use proc_macro as pm;
9use proc_macro2 as pm2;
10use syn::parse_macro_input;
11use syn::spanned::Spanned;
12
13mod kw {
14    syn::custom_keyword!(equal);
15    syn::custom_keyword!(equal_with);
16    syn::custom_keyword!(methods);
17    syn::custom_keyword!(model);
18    syn::custom_keyword!(post);
19    syn::custom_keyword!(pre);
20    syn::custom_keyword!(tested);
21    syn::custom_keyword!(type_parameters);
22}
23
/// How a method parameter is declared, and therefore how the stored argument is
/// handed to the method when an operation is executed.
// All variants intentionally share the `By` prefix (by value / by ref / by mut ref),
// which `clippy::enum_variant_names` would otherwise flag.
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PassingMode {
    /// Declared as `name: T`.
    ByValue,
    /// Declared as `name: &T`.
    ByRef,
    /// Declared as `name: &mut T`.
    ByRefMut,
}
30
31struct Argument {
32    name: syn::Ident,
33    ty: syn::Type,
34    passing_mode: PassingMode,
35}
36
37struct Method {
38    name: syn::Ident,
39    // self_mut: bool,
40    inputs: Vec<Argument>,
41    process_result: Option<syn::Path>,
42    // output: syn::Type
43}
44
45impl syn::parse::Parse for Method {
46    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
47        let method_item: syn::TraitItemMethod = input.parse()?;
48
49        if let Some(ref defaultness) = method_item.default {
50            return Err(syn::Error::new(defaultness.span(), "unexpected `default`"));
51        }
52        if let Some(ref constness) = method_item.sig.constness {
53            return Err(syn::Error::new(constness.span(), "unexpected `const`"));
54        }
55        if let Some(ref asyncness) = method_item.sig.asyncness {
56            return Err(syn::Error::new(asyncness.span(), "unexpected `async`"));
57        }
58        if let Some(ref unsafety) = method_item.sig.unsafety {
59            return Err(syn::Error::new(unsafety.span(), "unexpected `unsafe`"));
60        }
61
62        let (receivers, args) = method_item
63            .sig
64            .inputs
65            .iter()
66            .map(|input| match input {
67                syn::FnArg::Receiver(receiver) => Either::Left(receiver),
68                syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => {
69                    let ident = match **pat {
70                        syn::Pat::Ident(syn::PatIdent { ref ident, .. }) => ident.clone(),
71                        ref pat => {
72                            //error_stream.extend(
73                            //    syn::Error::new(pat.span(), "unexpected `unsafe`").to_compile_error());
74                            syn::Ident::new("_", pat.span())
75                        }
76                    };
77                    match **ty {
78                        syn::Type::Reference(syn::TypeReference {
79                            ref mutability,
80                            ref elem,
81                            ..
82                        }) => Either::Right(Argument {
83                            name: ident,
84                            ty: (**elem).clone(),
85                            passing_mode: if mutability.is_some() {
86                                PassingMode::ByRefMut
87                            } else {
88                                PassingMode::ByRef
89                            },
90                        }),
91                        ref ty => Either::Right(Argument {
92                            name: ident,
93                            ty: ty.clone(),
94                            passing_mode: PassingMode::ByValue,
95                        }),
96                    }
97                }
98            })
99            .partition::<Vec<_>, _>(Either::is_left);
100
101        let receivers: Vec<_> = receivers.into_iter().filter_map(Either::left).collect();
102        let args: Vec<_> = args.into_iter().filter_map(Either::right).collect();
103
104        let receiver = receivers.first();
105        if let Some(receiver) = receiver {
106            if receiver.reference.is_none() {
107                return Err(syn::Error::new(
108                    receiver.span(),
109                    "unexpected by-value receiver",
110                ));
111            }
112        } else {
113            return Err(syn::Error::new(
114                method_item.span(),
115                "unexpected method with no receiver",
116            ));
117        }
118
119        Ok(Self {
120            name: method_item.sig.ident,
121            // self_mut: receiver.map_or(false, |r| r.mutability.is_some()),
122            process_result: None,
123            inputs: args,
124            /*output: match method_item.sig.output {
125                syn::ReturnType::Default =>
126                    syn::parse_str("()").unwrap(),
127                syn::ReturnType::Type(_, typ) =>
128                    (*typ).clone()
129            }*/
130        })
131    }
132}
133
134struct Specification {
135    model: syn::Path,
136    tested: syn::Path,
137    type_params: Vec<syn::TypeParam>,
138    methods: Vec<Method>,
139    post: Vec<syn::Stmt>,
140    pre: Vec<syn::Stmt>,
141}
142
143impl syn::parse::Parse for Specification {
144    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
145        use syn::{braced, parenthesized, Token};
146
147        let mut model: Option<syn::Path> = None;
148        let mut tested: Option<syn::Path> = None;
149        let mut type_params: Vec<syn::TypeParam> = vec![];
150        let mut methods: Vec<Method> = vec![];
151        let mut post: Vec<syn::Stmt> = vec![];
152        let mut pre: Vec<syn::Stmt> = vec![];
153
154        while !input.is_empty() {
155            let lookahead = input.lookahead1();
156            if lookahead.peek(kw::model) {
157                let _: kw::model = input.parse()?;
158                let _: Token![=] = input.parse()?;
159                model = Some(input.parse()?);
160            } else if lookahead.peek(kw::tested) {
161                let _: kw::tested = input.parse()?;
162                let _: Token![=] = input.parse()?;
163                tested = Some(input.parse()?);
164            } else if lookahead.peek(kw::type_parameters) {
165                let _: kw::type_parameters = input.parse()?;
166                let _: Token![=] = input.parse()?;
167                let generics: syn::Generics = input.parse()?;
168                type_params = generics.type_params().cloned().collect();
169            } else if lookahead.peek(kw::methods) {
170                let outer;
171                let mut inner;
172                let _: kw::methods = input.parse()?;
173                braced!(outer in input);
174
175                while !outer.is_empty() {
176                    let lookahead = outer.lookahead1();
177                    let process = if lookahead.peek(kw::equal) {
178                        let _: kw::equal = outer.parse()?;
179                        None
180                    } else if lookahead.peek(kw::equal_with) {
181                        let _: kw::equal_with = outer.parse()?;
182                        let path;
183                        parenthesized!(path in outer);
184                        Some(path.parse()?)
185                    } else {
186                        return Err(lookahead.error());
187                    };
188
189                    braced!(inner in outer);
190                    while !inner.is_empty() {
191                        let mut method: Method = inner.parse()?;
192                        method.process_result = process.clone();
193                        methods.push(method);
194                    }
195                }
196            } else if lookahead.peek(kw::post) {
197                let inner;
198                let _: kw::post = input.parse()?;
199                braced!(inner in input);
200                while !inner.is_empty() {
201                    post.push(inner.parse()?);
202                }
203            } else if lookahead.peek(kw::pre) {
204                let inner;
205                let _: kw::pre = input.parse()?;
206                braced!(inner in input);
207                while !inner.is_empty() {
208                    pre.push(inner.parse()?);
209                }
210            } else {
211                return Err(lookahead.error());
212            }
213
214            if input.peek(Token![,]) {
215                let _: Token![,] = input.parse()?;
216            }
217        }
218
219        let model = match model {
220            Some(model) => model,
221            None => return Err(input.error("missing `model`")),
222        };
223
224        let tested = match tested {
225            Some(tested) => tested,
226            None => return Err(input.error("missing `tested`")),
227        };
228
229        Ok(Self {
230            model,
231            tested,
232            type_params,
233            methods,
234            post,
235            pre,
236        })
237    }
238}
239
240impl quote::ToTokens for Method {
241    fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
242        use pm2::{Delimiter, Group, Punct, Spacing};
243        use quote::TokenStreamExt;
244
245        tokens.append(self.name.clone());
246
247        if !self.inputs.is_empty() {
248            let mut fields = pm2::TokenStream::new();
249            for input in &self.inputs {
250                fields.append(input.name.clone());
251                fields.append(Punct::new(':', Spacing::Joint));
252                input.ty.to_tokens(&mut fields);
253                fields.append(Punct::new(',', Spacing::Joint));
254            }
255            tokens.append(Group::new(Delimiter::Brace, fields));
256        }
257    }
258}
259
260struct MethodTest<'s> {
261    method: &'s Method,
262    compare: bool,
263}
264
265impl<'s> quote::ToTokens for MethodTest<'s> {
266    fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
267        let args: Vec<_> = self
268            .method
269            .inputs
270            .iter()
271            .map(|input| {
272                let input_name = &input.name;
273                match input.passing_mode {
274                    PassingMode::ByValue => quote! { #input_name.clone() },
275                    PassingMode::ByRef => quote! { #input_name },
276                    PassingMode::ByRefMut => quote! { &mut *#input_name },
277                }
278            })
279            .collect();
280
281        let method_name = &self.method.name;
282
283        let keys: Vec<_> = self.method.inputs.iter().map(|input| &input.name).collect();
284        let pattern = if keys.is_empty() {
285            quote! { Op::#method_name }
286        } else {
287            quote! { Op::#method_name { #(ref #keys),* } }
288        };
289
290        let process_tested_res = self
291            .method
292            .process_result
293            .as_ref()
294            .map(|p| quote! { #p(tested_res) })
295            .unwrap_or(quote! { tested_res });
296
297        if self.compare {
298            let process_model_res = self
299                .method
300                .process_result
301                .as_ref()
302                .map(|p| quote! { #p(model_res) })
303                .unwrap_or(quote! { model_res });
304            tokens.extend(quote! {
305                #pattern => {
306                    let model_res = model.#method_name(#(#args),*);
307                    let tested_res = tested.#method_name(#(#args),*);
308                    let model_res = #process_model_res;
309                    let tested_res = #process_tested_res;
310                    assert_eq!(model_res, tested_res);
311                }
312            });
313        } else {
314            tokens.extend(quote! {
315                #pattern => {
316                    let _ = tested.#method_name(#(#args),*);
317                }
318            });
319        }
320    }
321}
322
323struct OperationEnum<'s> {
324    spec: &'s Specification,
325}
326
327impl<'s> quote::ToTokens for OperationEnum<'s> {
328    #[allow(clippy::cognitive_complexity)]
329    fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
330        let type_params_with_bounds = &self.spec.type_params;
331        let type_params: Vec<_> = type_params_with_bounds
332            .iter()
333            .map(|tp| tp.ident.clone())
334            .collect();
335
336        let model = &self.spec.model;
337        let tested = &self.spec.tested;
338        let variants = &self.spec.methods;
339
340        let comp_method_tests: Vec<_> = self
341            .spec
342            .methods
343            .iter()
344            .map(|method| MethodTest {
345                method,
346                compare: true,
347            })
348            .collect();
349
350        let method_tests: Vec<_> = self
351            .spec
352            .methods
353            .iter()
354            .map(|method| MethodTest {
355                method,
356                compare: false,
357            })
358            .collect();
359
360        let format_calls: Vec<_> = self
361            .spec
362            .methods
363            .iter()
364            .map(|method| {
365                let args: Vec<_> = method
366                    .inputs
367                    .iter()
368                    .map(|input| match input.passing_mode {
369                        PassingMode::ByValue => "{:?}",
370                        PassingMode::ByRef => "&{:?}",
371                        PassingMode::ByRefMut => "&mut {:?}",
372                    })
373                    .collect();
374
375                let method_name = &method.name;
376                let format_str = format!("v.{}({});", method_name, args.join(", "));
377                let keys: Vec<_> = method.inputs.iter().map(|input| &input.name).collect();
378                let pattern = if keys.is_empty() {
379                    quote! { Op::#method_name }
380                } else {
381                    quote! { Op::#method_name { #(#keys),* } }
382                };
383
384                quote! { #pattern =>
385                    write!(f, #format_str, #(#keys),*)
386                }
387            })
388            .collect();
389
390        let post = &self.spec.post;
391        let pre = &self.spec.pre;
392
393        tokens.extend(quote! {
394            #[allow(non_camel_case_types)]
395            #[derive(arbitrary::Arbitrary, Clone, Debug, PartialEq)]
396            pub enum Op<#(#type_params_with_bounds),*> {
397                #(#variants),*
398            }
399
400            impl<#(#type_params_with_bounds),*> std::fmt::Display for Op<#(#type_params),*> {
401                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402                    match self {
403                        #(#format_calls),*
404                    }
405                }
406            }
407
408            impl<#(#type_params_with_bounds),*> Op<#(#type_params),*> {
409                pub fn execute(self, tested: &mut #tested) {
410                    match &self {
411                        #(#method_tests),*
412                    }
413                }
414
415                pub fn execute_and_compare(self, model: &mut #model, tested: &mut #tested) {
416                    #(#pre)*
417                    match &self {
418                        #(#comp_method_tests),*
419                    }
420                    #(#post)*
421                }
422            }
423        })
424    }
425}
426
427#[proc_macro]
428pub fn arbitrary_stateful_operations(input: pm::TokenStream) -> pm::TokenStream {
429    let parsed_spec = parse_macro_input!(input as Specification);
430
431    let operation_enum = OperationEnum { spec: &parsed_spec };
432
433    let output = quote! {
434        mod op {
435            use super::*;
436            #operation_enum
437        }
438    };
439
440    output.into()
441}