1use super::rewriter::AstRewriter;
2use itertools::Itertools;
3use proc_macro2::{Span, TokenStream};
4use quote::{quote_spanned, ToTokens};
5use syn::{
6 parse::Parser, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, Expr, ExprLit,
7 Fields, Generics, Ident, Lit, Pat, PatLit, Token,
8};
9
10pub fn rewrite_struct(
11 attrs: TokenStream,
12 item_struct: syn::ItemStruct,
13) -> syn::Result<Vec<syn::Item>> {
14 let res = rewrite_internal_struct(attrs, item_struct);
15 match res {
16 Ok(result) => Ok(result),
17 Err(err) => Err(err.into()),
18 }
19}
20
21pub fn rewrite_enum(attrs: TokenStream, item_enum: syn::ItemEnum) -> syn::Result<Vec<syn::Item>> {
22 let res = rewrite_internal_enum(attrs, item_enum);
23 match res {
24 Ok(result) => Ok(result),
25 Err(err) => Err(err.into()),
26 }
27}
28
29type TypeCounterexampleResult<R> = Result<R, TypeCounterexampleError>;
30
31#[derive(Debug)]
32enum TypeCounterexampleError {
33 ArgumentsDoNotMatch(proc_macro2::Span),
34 WrongFirstArgument(proc_macro2::Span),
35 AtLeastOneArgument(proc_macro2::Span),
36 WrongNumberOfArguemnts(proc_macro2::Span),
37 InvalidName(proc_macro2::Span),
38 InvalidArgument(proc_macro2::Span, String, String),
39 ParsingError(syn::Error),
40}
41
42impl std::convert::From<TypeCounterexampleError> for syn::Error {
43 fn from(err: TypeCounterexampleError) -> Self {
44 match err {
45 TypeCounterexampleError::ArgumentsDoNotMatch(span) => {
46 syn::Error::new(span, "Number of arguments and number of {} do not match")
47 }
48 TypeCounterexampleError::WrongFirstArgument(span) => {
49 syn::Error::new(span, "First argument must be a string literal")
50 }
51 TypeCounterexampleError::AtLeastOneArgument(span) => {
52 syn::Error::new(span, "At least one argument is expected")
53 }
54 TypeCounterexampleError::InvalidName(span) => {
55 syn::Error::new(span, "Invalid argument name")
56 }
57 TypeCounterexampleError::InvalidArgument(span, name, arg) => {
58 syn::Error::new(span, format!("`{name}` does not have a field named {arg}"))
59 }
60 TypeCounterexampleError::WrongNumberOfArguemnts(span) => {
61 syn::Error::new(span, "Number of arguments are incorrect")
62 }
63 TypeCounterexampleError::ParsingError(parse_err) => parse_err,
64 }
65 }
66}
67
68fn rewrite_internal_struct(
69 attr: TokenStream,
70 item_struct: syn::ItemStruct,
71) -> TypeCounterexampleResult<Vec<syn::Item>> {
72 let parser = Punctuated::<Pat, Token![,]>::parse_terminated;
73 let attrs = match parser.parse(attr.clone().into()) {
74 Ok(result) => result,
75 Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
76 };
77 let len = attrs.len();
78
79 let (first_arg, args) = process_attr(&attrs, len)?;
80 let mut rewriter = AstRewriter::new();
81 let spec_id = rewriter.generate_spec_id();
82 let spec_id_str = spec_id.to_string();
83 let item_span = item_struct.span();
84 let item_name = syn::Ident::new(
85 &format!(
86 "prusti_print_counterexample_item_{}_{}",
87 item_struct.ident, spec_id
88 ),
89 item_span,
90 );
91 let mut args2: Punctuated<Pat, Token![,]> = attrs
92 .into_iter()
93 .skip(1)
94 .unique()
95 .collect::<Punctuated<Pat, Token![,]>>();
96 if !args2.empty_or_trailing() {
98 args2.push_punct(<syn::Token![,]>::default());
99 }
100
101 #[allow(clippy::redundant_clone)]
103 let typ = item_struct.ident.clone();
104
105 let spec_item = match item_struct.fields {
106 Fields::Named(_) => {
107 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
108 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
109 #[prusti::spec_only]
110 #[prusti::counterexample_print]
111 #[prusti::spec_id = #spec_id_str]
112 fn #item_name(self){
113 if let #typ{#args2 ..} = self{
114 #first_arg
115 #args
116 }
117 }
118 };
119 spec_item
120 }
121 Fields::Unnamed(ref fields_unnamed) => {
122 check_validity_of_args(
124 args2,
125 fields_unnamed.unnamed.len() as u32,
126 &item_struct.ident.to_string(),
127 )?;
128
129 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
130 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
131 #[prusti::spec_only]
132 #[prusti::counterexample_print]
133 #[prusti::spec_id = #spec_id_str]
134 fn #item_name(self){
135 if let #typ{..} = self{
136 #first_arg
137 #args
138 }
139 }
140 };
141 spec_item
142 }
143 Fields::Unit => {
144 if len == 1 {
145 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
146 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
147 #[prusti::spec_only]
148 #[prusti::counterexample_print]
149 #[prusti::spec_id = #spec_id_str]
150 fn #item_name(self){
151 if let #typ{..} = self{
152 #first_arg
153 }
154 }
155 };
156 spec_item
157 } else {
158 return Err(TypeCounterexampleError::WrongNumberOfArguemnts(attr.span()));
159 }
160 }
161 };
162
163 let item_impl = generate_generics(
164 item_struct.span(),
165 item_struct.ident.clone(),
166 &item_struct.generics,
167 spec_item.into_token_stream(),
168 );
169 Ok(vec![syn::Item::Impl(item_impl)])
170}
171
172fn rewrite_internal_enum(
173 attr: TokenStream,
174 item_enum: syn::ItemEnum,
175) -> TypeCounterexampleResult<Vec<syn::Item>> {
176 let parser = Punctuated::<Pat, Token![,]>::parse_terminated;
177 let attrs = match parser.parse(attr.clone().into()) {
178 Ok(result) => result,
179 Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
180 };
181 let item_span = item_enum.span();
182 let len = attrs.len();
183 if len != 0 {
184 return Err(TypeCounterexampleError::WrongNumberOfArguemnts(item_span));
185 }
186 let mut spec_items: Vec<syn::ItemFn> = vec![];
187 let enum_name = item_enum.ident.clone();
188 let mut rewriter = AstRewriter::new();
189 let spec_id = rewriter.generate_spec_id();
190 let spec_id_str = spec_id.to_string(); for variant in &item_enum.variants {
193 if let Some(custom_print) = variant.attrs.iter().find(|attr| {
194 attr.path.get_ident().map(|x| x.to_string()) == Some("print_counterexample".to_string())
195 }) {
196 let variant_name = variant.ident.clone();
197 let item_span = variant.ident.span();
198 let item_name = syn::Ident::new(
199 &format!(
200 "prusti_print_counterexample_variant_{}_{}",
201 variant.ident, spec_id
202 ),
203 item_span,
204 );
205 let variant_name_str = variant_name.to_string();
206 let parser = Punctuated::<Pat, Token![,]>::parse_terminated; let attrs = match custom_print.parse_args_with(parser) {
208 Ok(result) => result,
209 Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
210 };
211
212 let len = attrs.len();
213 let (first_arg, args) = process_attr(&attrs, len)?;
214 match &variant.fields {
215 Fields::Named(_) => {
216 let mut args2: Punctuated<Pat, Token![,]> = attrs
217 .into_iter()
218 .skip(1)
219 .unique()
220 .collect::<Punctuated<Pat, Token![,]>>(
221 );
222 if !args2.empty_or_trailing() {
223 args2.push_punct(<syn::Token![,]>::default());
224 }
225 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
226 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
227 #[prusti::spec_only]
228 #[prusti::counterexample_print = #variant_name_str]
229 #[prusti::spec_id = #spec_id_str]
230 fn #item_name(self) {
231 if let #enum_name::#variant_name{#args2 ..} = self{
232 #first_arg
233 #args
234 }
235 }
236 };
237 spec_items.push(spec_item);
238 }
239 Fields::Unnamed(fields_unnamed) => {
240 let args2: Punctuated<Pat, Token![,]> = attrs
241 .into_iter()
242 .skip(1)
243 .unique()
244 .collect::<Punctuated<Pat, Token![,]>>(
245 );
246
247 check_validity_of_args(
249 args2,
250 fields_unnamed.unnamed.len() as u32,
251 &item_enum.ident.to_string(),
252 )?;
253 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
254 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
255 #[prusti::spec_only]
256 #[prusti::counterexample_print = #variant_name_str]
257 #[prusti::spec_id = #spec_id_str]
258 fn #item_name(self) {
259 if let #enum_name::#variant_name(..) = self{
260 #first_arg
261 #args
262 }
263 }
264 };
265 spec_items.push(spec_item);
266 }
267 Fields::Unit => {
268 if len == 1 {
269 let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=>
270 #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case, irrefutable_let_patterns)]
271 #[prusti::spec_only]
272 #[prusti::counterexample_print = #variant_name_str]
273 #[prusti::spec_id = #spec_id_str]
274 fn #item_name(self) {
275 if let #enum_name::#variant_name = self{
276 #first_arg
277 }
278 }
279 };
280 spec_items.push(spec_item);
281 } else {
282 return Err(TypeCounterexampleError::WrongNumberOfArguemnts(attr.span()));
283 }
284 }
285 }
286 }
287 }
288 let mut spec_item_as_tokens = TokenStream::new();
289 for x in spec_items {
290 x.to_tokens(&mut spec_item_as_tokens);
291 }
292
293 let item_impl = generate_generics(
294 item_enum.span(),
295 item_enum.ident.clone(),
296 &item_enum.generics,
297 spec_item_as_tokens.into_token_stream(),
298 );
299 let mut item_enum_new = item_enum;
300 for variant in &mut item_enum_new.variants {
301 variant.attrs.retain(|attr| {
303 attr.path.get_ident().map(|x| x.to_string()) != Some("print_counterexample".to_string())
304 });
305 }
306 Ok(vec![
307 syn::Item::Enum(item_enum_new),
308 syn::Item::Impl(item_impl),
309 ])
310}
311
312fn process_attr(
313 attrs: &Punctuated<Pat, Token![,]>,
314 len: usize,
315) -> TypeCounterexampleResult<(TokenStream, TokenStream)> {
316 let mut attrs_iter = attrs.iter();
317 let callsite_span = Span::call_site();
318 let first_as_token = if let Some(text) = attrs_iter.next() {
320 let span = text.span();
321 match text {
322 Pat::Lit(PatLit {
323 attrs: _,
324 expr:
325 box Expr::Lit(ExprLit {
326 attrs: _,
327 lit: Lit::Str(lit_str),
328 }),
329 }) => {
330 let value = lit_str.value();
331 let count = value.matches("{}").count();
332 if count != len - 1 {
333 return Err(TypeCounterexampleError::ArgumentsDoNotMatch(span));
334 }
335 quote_spanned! {callsite_span=> #value;}
336 }
337 _ => return Err(TypeCounterexampleError::WrongFirstArgument(span)),
338 }
339 } else {
340 return Err(TypeCounterexampleError::AtLeastOneArgument(attrs.span()));
341 };
342 let args_as_token = attrs_iter
344 .map(|pat| match pat {
345 Pat::Ident(pat_ident) => {
346 quote_spanned! {callsite_span=> #pat_ident; }
347 }
348 Pat::Lit(PatLit {
349 attrs: _,
350 expr:
351 box Expr::Lit(ExprLit {
352 attrs: _,
353 lit: Lit::Int(lit_int),
354 }),
355 }) => {
356 quote_spanned! {callsite_span=> #lit_int; }
357 }
358 _ => {
359 let err: syn::Error = TypeCounterexampleError::InvalidName(callsite_span).into();
360 err.to_compile_error()
361 }
362 })
363 .collect::<TokenStream>();
364 Ok((first_as_token, args_as_token))
365}
366fn check_validity_of_args(
367 args: Punctuated<Pat, Token![,]>,
368 len: u32,
369 name: &String,
370) -> TypeCounterexampleResult<()> {
371 for arg in &args {
372 if let Pat::Lit(PatLit {
373 attrs: _,
374 expr:
375 box Expr::Lit(ExprLit {
376 attrs: _,
377 lit: Lit::Int(lit_int),
378 }),
379 }) = arg
380 {
381 let value: u32 = match lit_int.base10_parse() {
382 Ok(result) => result,
383 Err(err) => return Err(TypeCounterexampleError::ParsingError(err)),
384 };
385 if value >= len {
386 return Err(TypeCounterexampleError::InvalidArgument(
387 arg.span(),
388 name.to_string(),
389 value.to_string(),
390 ));
391 }
392 } else {
393 return Err(TypeCounterexampleError::InvalidName(arg.span()));
394 }
395 }
396 Ok(())
397}
398
399fn generate_generics(
400 item_span: Span,
401 typ: Ident,
402 generics: &Generics,
403 spec_item: TokenStream,
404) -> syn::ItemImpl {
405 let generics_idents = generics
406 .params
407 .iter()
408 .filter_map(|generic_param| match generic_param {
409 syn::GenericParam::Type(type_param) => Some(type_param.ident.clone()),
410 _ => None,
411 })
412 .collect::<syn::punctuated::Punctuated<_, syn::Token![,]>>();
413 let item_impl: syn::ItemImpl = parse_quote_spanned! {item_span=>
414 impl #generics #typ <#generics_idents> {
415 #spec_item
416 }
417 };
418 item_impl
419}