prusti_specs/
print_counterexample.rs

1use super::rewriter::AstRewriter;
2use itertools::Itertools;
3use proc_macro2::{Span, TokenStream};
4use quote::{quote_spanned, ToTokens};
5use syn::{
6    parse::Parser, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, Expr, ExprLit,
7    Fields, Generics, Ident, Lit, Pat, PatLit, Token,
8};
9
10pub fn rewrite_struct(
11    attrs: TokenStream,
12    item_struct: syn::ItemStruct,
13) -> syn::Result<Vec<syn::Item>> {
14    let res = rewrite_internal_struct(attrs, item_struct);
15    match res {
16        Ok(result) => Ok(result),
17        Err(err) => Err(err.into()),
18    }
19}
20
21pub fn rewrite_enum(attrs: TokenStream, item_enum: syn::ItemEnum) -> syn::Result<Vec<syn::Item>> {
22    let res = rewrite_internal_enum(attrs, item_enum);
23    match res {
24        Ok(result) => Ok(result),
25        Err(err) => Err(err.into()),
26    }
27}
28
29type TypeCounterexampleResult<R> = Result<R, TypeCounterexampleError>;
30
31#[derive(Debug)]
32enum TypeCounterexampleError {
33    ArgumentsDoNotMatch(proc_macro2::Span),
34    WrongFirstArgument(proc_macro2::Span),
35    AtLeastOneArgument(proc_macro2::Span),
36    WrongNumberOfArguemnts(proc_macro2::Span),
37    InvalidName(proc_macro2::Span),
38    InvalidArgument(proc_macro2::Span, String, String),
39    ParsingError(syn::Error),
40}
41
42impl std::convert::From<TypeCounterexampleError> for syn::Error {
43    fn from(err: TypeCounterexampleError) -> Self {
44        match err {
45            TypeCounterexampleError::ArgumentsDoNotMatch(span) => {
46                syn::Error::new(span, "Number of arguments and number of {} do not match")
47            }
48            TypeCounterexampleError::WrongFirstArgument(span) => {
49                syn::Error::new(span, "First argument must be a string literal")
50            }
51            TypeCounterexampleError::AtLeastOneArgument(span) => {
52                syn::Error::new(span, "At least one argument is expected")
53            }
54            TypeCounterexampleError::InvalidName(span) => {
55                syn::Error::new(span, "Invalid argument name")
56            }
57            TypeCounterexampleError::InvalidArgument(span, name, arg) => {
58                syn::Error::new(span, format!("`{name}` does not have a field named {arg}"))
59            }
60            TypeCounterexampleError::WrongNumberOfArguemnts(span) => {
61                syn::Error::new(span, "Number of arguments are incorrect")
62            }
63            TypeCounterexampleError::ParsingError(parse_err) => parse_err,
64        }
65    }
66}
67
68fn rewrite_internal_struct(
69    attr: TokenStream,
70    item_struct: syn::ItemStruct,
71) -> TypeCounterexampleResult<Vec<syn::Item>> {
72    let parser = Punctuated::<Pat, Token![,]>::parse_terminated;
73    let attrs = match parser.parse(attr.clone().into()) {
74        Ok(result) => result,
75        Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
76    };
77    let len = attrs.len();
78
79    let (first_arg, args) = process_attr(&attrs, len)?;
80    let mut rewriter = AstRewriter::new();
81    let spec_id = rewriter.generate_spec_id();
82    let spec_id_str = spec_id.to_string();
83    let item_span = item_struct.span();
84    let item_name = syn::Ident::new(
85        &format!(
86            "prusti_print_counterexample_item_{}_{}",
87            item_struct.ident, spec_id
88        ),
89        item_span,
90    );
91    let mut args2: Punctuated<Pat, Token![,]> = attrs
92        .into_iter()
93        .skip(1)
94        .unique()
95        .collect::<Punctuated<Pat, Token![,]>>();
96    //add trailing punctuation
97    if !args2.empty_or_trailing() {
98        args2.push_punct(<syn::Token![,]>::default());
99    }
100
101    // clippy false positive (https://github.com/rust-lang/rust-clippy/issues/10577)
102    #[allow(clippy::redundant_clone)]
103    let typ = item_struct.ident.clone();
104
105    let spec_item = match item_struct.fields {
106        Fields::Named(_) => {
107            let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
108                #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
109                #[prusti::spec_only]
110                #[prusti::counterexample_print]
111                #[prusti::spec_id = #spec_id_str]
112                fn #item_name(self){
113                    if let #typ{#args2 ..} = self{
114                        #first_arg
115                        #args
116                    }
117                }
118            };
119            spec_item
120        }
121        Fields::Unnamed(ref fields_unnamed) => {
122            //check if all args are correct
123            check_validity_of_args(
124                args2,
125                fields_unnamed.unnamed.len() as u32,
126                &item_struct.ident.to_string(),
127            )?;
128
129            let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
130                #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
131                #[prusti::spec_only]
132                #[prusti::counterexample_print]
133                #[prusti::spec_id = #spec_id_str]
134                fn #item_name(self){
135                    if let #typ{..} = self{
136                        #first_arg
137                        #args
138                    }
139                }
140            };
141            spec_item
142        }
143        Fields::Unit => {
144            if len == 1 {
145                let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
146                    #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
147                    #[prusti::spec_only]
148                    #[prusti::counterexample_print]
149                    #[prusti::spec_id = #spec_id_str]
150                    fn #item_name(self){
151                        if let #typ{..} = self{
152                            #first_arg
153                        }
154                    }
155                };
156                spec_item
157            } else {
158                return Err(TypeCounterexampleError::WrongNumberOfArguemnts(attr.span()));
159            }
160        }
161    };
162
163    let item_impl = generate_generics(
164        item_struct.span(),
165        item_struct.ident.clone(),
166        &item_struct.generics,
167        spec_item.into_token_stream(),
168    );
169    Ok(vec![syn::Item::Impl(item_impl)])
170}
171
172fn rewrite_internal_enum(
173    attr: TokenStream,
174    item_enum: syn::ItemEnum,
175) -> TypeCounterexampleResult<Vec<syn::Item>> {
176    let parser = Punctuated::<Pat, Token![,]>::parse_terminated;
177    let attrs = match parser.parse(attr.clone().into()) {
178        Ok(result) => result,
179        Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
180    };
181    let item_span = item_enum.span();
182    let len = attrs.len();
183    if len != 0 {
184        return Err(TypeCounterexampleError::WrongNumberOfArguemnts(item_span));
185    }
186    let mut spec_items: Vec<syn::ItemFn> = vec![];
187    let enum_name = item_enum.ident.clone();
188    let mut rewriter = AstRewriter::new();
189    let spec_id = rewriter.generate_spec_id();
190    let spec_id_str = spec_id.to_string(); //Does this have to be unique?
191
192    for variant in &item_enum.variants {
193        if let Some(custom_print) = variant.attrs.iter().find(|attr| {
194            attr.path.get_ident().map(|x| x.to_string()) == Some("print_counterexample".to_string())
195        }) {
196            let variant_name = variant.ident.clone();
197            let item_span = variant.ident.span();
198            let item_name = syn::Ident::new(
199                &format!(
200                    "prusti_print_counterexample_variant_{}_{}",
201                    variant.ident, spec_id
202                ),
203                item_span,
204            );
205            let variant_name_str = variant_name.to_string();
206            let parser = Punctuated::<Pat, Token![,]>::parse_terminated; //parse_separated_nonempty;
207            let attrs = match custom_print.parse_args_with(parser) {
208                Ok(result) => result,
209                Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
210            };
211
212            let len = attrs.len();
213            let (first_arg, args) = process_attr(&attrs, len)?;
214            match &variant.fields {
215                Fields::Named(_) => {
216                    let mut args2: Punctuated<Pat, Token![,]> = attrs
217                        .into_iter()
218                        .skip(1)
219                        .unique()
220                        .collect::<Punctuated<Pat, Token![,]>>(
221                    );
222                    if !args2.empty_or_trailing() {
223                        args2.push_punct(<syn::Token![,]>::default());
224                    }
225                    let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
226                        #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
227                        #[prusti::spec_only]
228                        #[prusti::counterexample_print  = #variant_name_str]
229                        #[prusti::spec_id = #spec_id_str]
230                        fn #item_name(self) {
231                            if let #enum_name::#variant_name{#args2 ..} = self{
232                                #first_arg
233                                #args
234                            }
235                        }
236                    };
237                    spec_items.push(spec_item);
238                }
239                Fields::Unnamed(fields_unnamed) => {
240                    let args2: Punctuated<Pat, Token![,]> = attrs
241                        .into_iter()
242                        .skip(1)
243                        .unique()
244                        .collect::<Punctuated<Pat, Token![,]>>(
245                    );
246
247                    //check if all args are correct
248                    check_validity_of_args(
249                        args2,
250                        fields_unnamed.unnamed.len() as u32,
251                        &item_enum.ident.to_string(),
252                    )?;
253                    let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
254                        #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
255                        #[prusti::spec_only]
256                        #[prusti::counterexample_print = #variant_name_str]
257                        #[prusti::spec_id = #spec_id_str]
258                        fn #item_name(self) {
259                            if let #enum_name::#variant_name(..) = self{
260                                #first_arg
261                                #args
262                            }
263                        }
264                    };
265                    spec_items.push(spec_item);
266                }
267                Fields::Unit => {
268                    if len == 1 {
269                        let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
270                            #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
271                            #[prusti::spec_only]
272                            #[prusti::counterexample_print = #variant_name_str]
273                            #[prusti::spec_id = #spec_id_str]
274                            fn #item_name(self) {
275                                if let #enum_name::#variant_name = self{
276                                    #first_arg
277                                }
278                            }
279                        };
280                        spec_items.push(spec_item);
281                    } else {
282                        return Err(TypeCounterexampleError::WrongNumberOfArguemnts(attr.span()));
283                    }
284                }
285            }
286        }
287    }
288    let mut spec_item_as_tokens = TokenStream::new();
289    for x in spec_items {
290        x.to_tokens(&mut spec_item_as_tokens);
291    }
292
293    let item_impl = generate_generics(
294        item_enum.span(),
295        item_enum.ident.clone(),
296        &item_enum.generics,
297        spec_item_as_tokens.into_token_stream(),
298    );
299    let mut item_enum_new = item_enum;
300    for variant in &mut item_enum_new.variants {
301        //remove all macros inside the enum
302        variant.attrs.retain(|attr| {
303            attr.path.get_ident().map(|x| x.to_string()) != Some("print_counterexample".to_string())
304        });
305    }
306    Ok(vec![
307        syn::Item::Enum(item_enum_new),
308        syn::Item::Impl(item_impl),
309    ])
310}
311
312fn process_attr(
313    attrs: &Punctuated<Pat, Token![,]>,
314    len: usize,
315) -> TypeCounterexampleResult<(TokenStream, TokenStream)> {
316    let mut attrs_iter = attrs.iter();
317    let callsite_span = Span::call_site();
318    //first arg
319    let first_as_token = if let Some(text) = attrs_iter.next() {
320        let span = text.span();
321        match text {
322            Pat::Lit(PatLit {
323                attrs: _,
324                expr:
325                    box Expr::Lit(ExprLit {
326                        attrs: _,
327                        lit: Lit::Str(lit_str),
328                    }),
329            }) => {
330                let value = lit_str.value();
331                let count = value.matches("{}").count();
332                if count != len - 1 {
333                    return Err(TypeCounterexampleError::ArgumentsDoNotMatch(span));
334                }
335                quote_spanned! {callsite_span=> #value;}
336            }
337            _ => return Err(TypeCounterexampleError::WrongFirstArgument(span)),
338        }
339    } else {
340        return Err(TypeCounterexampleError::AtLeastOneArgument(attrs.span()));
341    };
342    //other args
343    let args_as_token = attrs_iter
344        .map(|pat| match pat {
345            Pat::Ident(pat_ident) => {
346                quote_spanned! {callsite_span=> #pat_ident; }
347            }
348            Pat::Lit(PatLit {
349                attrs: _,
350                expr:
351                    box Expr::Lit(ExprLit {
352                        attrs: _,
353                        lit: Lit::Int(lit_int),
354                    }),
355            }) => {
356                quote_spanned! {callsite_span=> #lit_int; }
357            }
358            _ => {
359                let err: syn::Error = TypeCounterexampleError::InvalidName(callsite_span).into();
360                err.to_compile_error()
361            }
362        })
363        .collect::<TokenStream>();
364    Ok((first_as_token, args_as_token))
365}
366fn check_validity_of_args(
367    args: Punctuated<Pat, Token![,]>,
368    len: u32,
369    name: &String,
370) -> TypeCounterexampleResult<()> {
371    for arg in &args {
372        if let Pat::Lit(PatLit {
373            attrs: _,
374            expr:
375                box Expr::Lit(ExprLit {
376                    attrs: _,
377                    lit: Lit::Int(lit_int),
378                }),
379        }) = arg
380        {
381            let value: u32 = match lit_int.base10_parse() {
382                Ok(result) => result,
383                Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
384            };
385            if value >= len {
386                return Err(TypeCounterexampleError::InvalidArgument(
387                    arg.span(),
388                    name.to_string(),
389                    value.to_string(),
390                ));
391            }
392        } else {
393            return Err(TypeCounterexampleError::InvalidName(arg.span()));
394        }
395    }
396    Ok(())
397}
398
399fn generate_generics(
400    item_span: Span,
401    typ: Ident,
402    generics: &Generics,
403    spec_item: TokenStream,
404) -> syn::ItemImpl {
405    let generics_idents = generics
406        .params
407        .iter()
408        .filter_map(|generic_param| match generic_param {
409            syn::GenericParam::Type(type_param) => Some(type_param.ident.clone()),
410            _ => None,
411        })
412        .collect::<syn::punctuated::Punctuated<_, syn::Token![,]>>();
413    let item_impl: syn::ItemImpl = parse_quote_spanned! {item_span=>
414        impl #generics #typ <#generics_idents> {
415            #spec_item
416        }
417    };
418    item_impl
419}