1extern crate either;
2#[macro_use]
3extern crate quote;
4extern crate proc_macro;
5extern crate syn;
6
7use either::Either;
8use proc_macro as pm;
9use proc_macro2 as pm2;
10use syn::parse_macro_input;
11use syn::spanned::Spanned;
12
13mod kw {
14 syn::custom_keyword!(equal);
15 syn::custom_keyword!(equal_with);
16 syn::custom_keyword!(methods);
17 syn::custom_keyword!(model);
18 syn::custom_keyword!(post);
19 syn::custom_keyword!(pre);
20 syn::custom_keyword!(tested);
21 syn::custom_keyword!(type_parameters);
22}
23
/// How an argument is handed to the method being exercised.
// The shared `By` prefix is deliberate: it keeps use sites such as
// `PassingMode::ByRef` readable, so the clippy naming lint is silenced.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PassingMode {
    /// The declared argument type is not a reference.
    ByValue,
    /// The declared argument type is a shared reference (`&T`).
    ByRef,
    /// The declared argument type is a mutable reference (`&mut T`).
    ByRefMut,
}
30
31struct Argument {
32 name: syn::Ident,
33 ty: syn::Type,
34 passing_mode: PassingMode,
35}
36
37struct Method {
38 name: syn::Ident,
39 inputs: Vec<Argument>,
41 process_result: Option<syn::Path>,
42 }
44
45impl syn::parse::Parse for Method {
46 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
47 let method_item: syn::TraitItemMethod = input.parse()?;
48
49 if let Some(ref defaultness) = method_item.default {
50 return Err(syn::Error::new(defaultness.span(), "unexpected `default`"));
51 }
52 if let Some(ref constness) = method_item.sig.constness {
53 return Err(syn::Error::new(constness.span(), "unexpected `const`"));
54 }
55 if let Some(ref asyncness) = method_item.sig.asyncness {
56 return Err(syn::Error::new(asyncness.span(), "unexpected `async`"));
57 }
58 if let Some(ref unsafety) = method_item.sig.unsafety {
59 return Err(syn::Error::new(unsafety.span(), "unexpected `unsafe`"));
60 }
61
62 let (receivers, args) = method_item
63 .sig
64 .inputs
65 .iter()
66 .map(|input| match input {
67 syn::FnArg::Receiver(receiver) => Either::Left(receiver),
68 syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => {
69 let ident = match **pat {
70 syn::Pat::Ident(syn::PatIdent { ref ident, .. }) => ident.clone(),
71 ref pat => {
72 syn::Ident::new("_", pat.span())
75 }
76 };
77 match **ty {
78 syn::Type::Reference(syn::TypeReference {
79 ref mutability,
80 ref elem,
81 ..
82 }) => Either::Right(Argument {
83 name: ident,
84 ty: (**elem).clone(),
85 passing_mode: if mutability.is_some() {
86 PassingMode::ByRefMut
87 } else {
88 PassingMode::ByRef
89 },
90 }),
91 ref ty => Either::Right(Argument {
92 name: ident,
93 ty: ty.clone(),
94 passing_mode: PassingMode::ByValue,
95 }),
96 }
97 }
98 })
99 .partition::<Vec<_>, _>(Either::is_left);
100
101 let receivers: Vec<_> = receivers.into_iter().filter_map(Either::left).collect();
102 let args: Vec<_> = args.into_iter().filter_map(Either::right).collect();
103
104 let receiver = receivers.first();
105 if let Some(receiver) = receiver {
106 if receiver.reference.is_none() {
107 return Err(syn::Error::new(
108 receiver.span(),
109 "unexpected by-value receiver",
110 ));
111 }
112 } else {
113 return Err(syn::Error::new(
114 method_item.span(),
115 "unexpected method with no receiver",
116 ));
117 }
118
119 Ok(Self {
120 name: method_item.sig.ident,
121 process_result: None,
123 inputs: args,
124 })
131 }
132}
133
134struct Specification {
135 model: syn::Path,
136 tested: syn::Path,
137 type_params: Vec<syn::TypeParam>,
138 methods: Vec<Method>,
139 post: Vec<syn::Stmt>,
140 pre: Vec<syn::Stmt>,
141}
142
143impl syn::parse::Parse for Specification {
144 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
145 use syn::{braced, parenthesized, Token};
146
147 let mut model: Option<syn::Path> = None;
148 let mut tested: Option<syn::Path> = None;
149 let mut type_params: Vec<syn::TypeParam> = vec![];
150 let mut methods: Vec<Method> = vec![];
151 let mut post: Vec<syn::Stmt> = vec![];
152 let mut pre: Vec<syn::Stmt> = vec![];
153
154 while !input.is_empty() {
155 let lookahead = input.lookahead1();
156 if lookahead.peek(kw::model) {
157 let _: kw::model = input.parse()?;
158 let _: Token![=] = input.parse()?;
159 model = Some(input.parse()?);
160 } else if lookahead.peek(kw::tested) {
161 let _: kw::tested = input.parse()?;
162 let _: Token![=] = input.parse()?;
163 tested = Some(input.parse()?);
164 } else if lookahead.peek(kw::type_parameters) {
165 let _: kw::type_parameters = input.parse()?;
166 let _: Token![=] = input.parse()?;
167 let generics: syn::Generics = input.parse()?;
168 type_params = generics.type_params().cloned().collect();
169 } else if lookahead.peek(kw::methods) {
170 let outer;
171 let mut inner;
172 let _: kw::methods = input.parse()?;
173 braced!(outer in input);
174
175 while !outer.is_empty() {
176 let lookahead = outer.lookahead1();
177 let process = if lookahead.peek(kw::equal) {
178 let _: kw::equal = outer.parse()?;
179 None
180 } else if lookahead.peek(kw::equal_with) {
181 let _: kw::equal_with = outer.parse()?;
182 let path;
183 parenthesized!(path in outer);
184 Some(path.parse()?)
185 } else {
186 return Err(lookahead.error());
187 };
188
189 braced!(inner in outer);
190 while !inner.is_empty() {
191 let mut method: Method = inner.parse()?;
192 method.process_result = process.clone();
193 methods.push(method);
194 }
195 }
196 } else if lookahead.peek(kw::post) {
197 let inner;
198 let _: kw::post = input.parse()?;
199 braced!(inner in input);
200 while !inner.is_empty() {
201 post.push(inner.parse()?);
202 }
203 } else if lookahead.peek(kw::pre) {
204 let inner;
205 let _: kw::pre = input.parse()?;
206 braced!(inner in input);
207 while !inner.is_empty() {
208 pre.push(inner.parse()?);
209 }
210 } else {
211 return Err(lookahead.error());
212 }
213
214 if input.peek(Token![,]) {
215 let _: Token![,] = input.parse()?;
216 }
217 }
218
219 let model = match model {
220 Some(model) => model,
221 None => return Err(input.error("missing `model`")),
222 };
223
224 let tested = match tested {
225 Some(tested) => tested,
226 None => return Err(input.error("missing `tested`")),
227 };
228
229 Ok(Self {
230 model,
231 tested,
232 type_params,
233 methods,
234 post,
235 pre,
236 })
237 }
238}
239
240impl quote::ToTokens for Method {
241 fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
242 use pm2::{Delimiter, Group, Punct, Spacing};
243 use quote::TokenStreamExt;
244
245 tokens.append(self.name.clone());
246
247 if !self.inputs.is_empty() {
248 let mut fields = pm2::TokenStream::new();
249 for input in &self.inputs {
250 fields.append(input.name.clone());
251 fields.append(Punct::new(':', Spacing::Joint));
252 input.ty.to_tokens(&mut fields);
253 fields.append(Punct::new(',', Spacing::Joint));
254 }
255 tokens.append(Group::new(Delimiter::Brace, fields));
256 }
257 }
258}
259
260struct MethodTest<'s> {
261 method: &'s Method,
262 compare: bool,
263}
264
265impl<'s> quote::ToTokens for MethodTest<'s> {
266 fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
267 let args: Vec<_> = self
268 .method
269 .inputs
270 .iter()
271 .map(|input| {
272 let input_name = &input.name;
273 match input.passing_mode {
274 PassingMode::ByValue => quote! { #input_name.clone() },
275 PassingMode::ByRef => quote! { #input_name },
276 PassingMode::ByRefMut => quote! { &mut *#input_name },
277 }
278 })
279 .collect();
280
281 let method_name = &self.method.name;
282
283 let keys: Vec<_> = self.method.inputs.iter().map(|input| &input.name).collect();
284 let pattern = if keys.is_empty() {
285 quote! { Op::#method_name }
286 } else {
287 quote! { Op::#method_name { #(ref #keys),* } }
288 };
289
290 let process_tested_res = self
291 .method
292 .process_result
293 .as_ref()
294 .map(|p| quote! { #p(tested_res) })
295 .unwrap_or(quote! { tested_res });
296
297 if self.compare {
298 let process_model_res = self
299 .method
300 .process_result
301 .as_ref()
302 .map(|p| quote! { #p(model_res) })
303 .unwrap_or(quote! { model_res });
304 tokens.extend(quote! {
305 #pattern => {
306 let model_res = model.#method_name(#(#args),*);
307 let tested_res = tested.#method_name(#(#args),*);
308 let model_res = #process_model_res;
309 let tested_res = #process_tested_res;
310 assert_eq!(model_res, tested_res);
311 }
312 });
313 } else {
314 tokens.extend(quote! {
315 #pattern => {
316 let _ = tested.#method_name(#(#args),*);
317 }
318 });
319 }
320 }
321}
322
323struct OperationEnum<'s> {
324 spec: &'s Specification,
325}
326
327impl<'s> quote::ToTokens for OperationEnum<'s> {
328 #[allow(clippy::cognitive_complexity)]
329 fn to_tokens(&self, tokens: &mut pm2::TokenStream) {
330 let type_params_with_bounds = &self.spec.type_params;
331 let type_params: Vec<_> = type_params_with_bounds
332 .iter()
333 .map(|tp| tp.ident.clone())
334 .collect();
335
336 let model = &self.spec.model;
337 let tested = &self.spec.tested;
338 let variants = &self.spec.methods;
339
340 let comp_method_tests: Vec<_> = self
341 .spec
342 .methods
343 .iter()
344 .map(|method| MethodTest {
345 method,
346 compare: true,
347 })
348 .collect();
349
350 let method_tests: Vec<_> = self
351 .spec
352 .methods
353 .iter()
354 .map(|method| MethodTest {
355 method,
356 compare: false,
357 })
358 .collect();
359
360 let format_calls: Vec<_> = self
361 .spec
362 .methods
363 .iter()
364 .map(|method| {
365 let args: Vec<_> = method
366 .inputs
367 .iter()
368 .map(|input| match input.passing_mode {
369 PassingMode::ByValue => "{:?}",
370 PassingMode::ByRef => "&{:?}",
371 PassingMode::ByRefMut => "&mut {:?}",
372 })
373 .collect();
374
375 let method_name = &method.name;
376 let format_str = format!("v.{}({});", method_name, args.join(", "));
377 let keys: Vec<_> = method.inputs.iter().map(|input| &input.name).collect();
378 let pattern = if keys.is_empty() {
379 quote! { Op::#method_name }
380 } else {
381 quote! { Op::#method_name { #(#keys),* } }
382 };
383
384 quote! { #pattern =>
385 write!(f, #format_str, #(#keys),*)
386 }
387 })
388 .collect();
389
390 let post = &self.spec.post;
391 let pre = &self.spec.pre;
392
393 tokens.extend(quote! {
394 #[allow(non_camel_case_types)]
395 #[derive(arbitrary::Arbitrary, Clone, Debug, PartialEq)]
396 pub enum Op<#(#type_params_with_bounds),*> {
397 #(#variants),*
398 }
399
400 impl<#(#type_params_with_bounds),*> std::fmt::Display for Op<#(#type_params),*> {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 match self {
403 #(#format_calls),*
404 }
405 }
406 }
407
408 impl<#(#type_params_with_bounds),*> Op<#(#type_params),*> {
409 pub fn execute(self, tested: &mut #tested) {
410 match &self {
411 #(#method_tests),*
412 }
413 }
414
415 pub fn execute_and_compare(self, model: &mut #model, tested: &mut #tested) {
416 #(#pre)*
417 match &self {
418 #(#comp_method_tests),*
419 }
420 #(#post)*
421 }
422 }
423 })
424 }
425}
426
427#[proc_macro]
428pub fn arbitrary_stateful_operations(input: pm::TokenStream) -> pm::TokenStream {
429 let parsed_spec = parse_macro_input!(input as Specification);
430
431 let operation_enum = OperationEnum { spec: &parsed_spec };
432
433 let output = quote! {
434 mod op {
435 use super::*;
436 #operation_enum
437 }
438 };
439
440 output.into()
441}