1use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
3use proc_macro2::{Punct, Spacing::*};
4use quote::{quote, quote_spanned, ToTokens};
5use std::collections::VecDeque;
6use syn::{
7 parse::{Parse, ParseStream},
8 spanned::Spanned,
9};
10
/// A named argument together with its explicit type.
///
/// NOTE(review): not referenced anywhere in this section of the file;
/// presumably used by other modules of the crate — confirm before changing.
#[derive(Debug, Clone)]
pub struct Arg {
    /// The argument's identifier.
    pub name: syn::Ident,
    /// The argument's declared type.
    pub typ: syn::Type,
}
17
18pub fn parse_prusti(tokens: TokenStream) -> syn::Result<TokenStream> {
19 let parsed = PrustiTokenStream::new(tokens).parse()?;
20 syn::parse2::<syn::Expr>(parsed.clone())?;
23 Ok(parsed)
24}
25pub fn parse_prusti_pledge(tokens: TokenStream) -> syn::Result<TokenStream> {
26 let (reference, rhs) = PrustiTokenStream::new(tokens).parse_pledge()?;
30 if let Some(reference) = reference {
31 if reference.to_string() != "result" {
32 return err(
33 reference.span(),
34 "reference of after_expiry must be \"result\"",
35 );
36 }
37 }
38 syn::parse2::<syn::Expr>(rhs.clone())?;
39 Ok(rhs)
40}
41
42pub fn parse_prusti_assert_pledge(tokens: TokenStream) -> syn::Result<(TokenStream, TokenStream)> {
43 let (reference, lhs, rhs) = PrustiTokenStream::new(tokens).parse_assert_pledge()?;
47 if let Some(reference) = reference {
48 if reference.to_string() != "result" {
49 return err(
50 reference.span(),
51 "reference of assert_on_expiry must be \"result\"",
52 );
53 }
54 }
55 syn::parse2::<syn::Expr>(lhs.clone())?;
56 syn::parse2::<syn::Expr>(rhs.clone())?;
57 Ok((lhs, rhs))
58}
59
60pub fn parse_type_cond_spec(tokens: TokenStream) -> syn::Result<TypeCondSpecRefinement> {
61 syn::parse2(tokens)
62}
63
64fn error(span: Span, msg: &str) -> syn::Error {
83 syn::Error::new(span, msg)
84}
85
86fn err<T>(span: Span, msg: &str) -> syn::Result<T> {
88 Err(error(span, msg))
89}
90
/// A token stream in which Prusti-specific syntax (operators such as `==>`,
/// quantifiers, `outer`, spec entailments) has been recognised as dedicated
/// [`PrustiToken`]s, ready to be translated back into plain Rust tokens.
#[derive(Debug, Clone)]
struct PrustiTokenStream {
    /// Remaining tokens; parsing consumes them from the front.
    tokens: VecDeque<PrustiToken>,
    /// Span of the whole source stream, used for errors that have no more
    /// precise token to point at.
    source_span: Span,
}
97
98impl PrustiTokenStream {
99 fn new(source: TokenStream) -> Self {
101 let source_span = source.span();
102 let source = source.into_iter().collect::<Vec<_>>();
103
104 let mut pos = 0;
105 let mut tokens = VecDeque::new();
106
107 while pos < source.len() {
110 pos += 1;
112 tokens.push_back(match (&source[pos - 1], source.get(pos), source.get(pos + 1), source.get(pos + 2)) {
113 (
114 TokenTree::Punct(p1),
115 Some(TokenTree::Punct(p2)),
116 Some(TokenTree::Punct(p3)),
117 Some(TokenTree::Punct(p4)),
118 ) if let Some(op) = PrustiToken::parse_op4(p1, p2, p3, p4) => {
119 pos += 3;
122 op
123 }
124 (
125 TokenTree::Punct(p1),
126 Some(TokenTree::Punct(p2)),
127 Some(TokenTree::Punct(p3)),
128 _
129 ) if let Some(op) = PrustiToken::parse_op3(p1, p2, p3) => {
130 pos += 2;
133 op
134 }
135 (
136 TokenTree::Punct(p1),
137 Some(TokenTree::Punct(p2)),
138 _,
139 _,
140 ) if let Some(op) = PrustiToken::parse_op2(p1, p2) => {
141 pos += 1;
144 op
145 }
146 (TokenTree::Ident(ident), _, _, _) if ident == "outer" =>
147 PrustiToken::Outer(ident.span()),
148 (TokenTree::Ident(ident), _, _, _) if ident == "forall" =>
149 PrustiToken::Quantifier(ident.span(), Quantifier::Forall),
150 (TokenTree::Ident(ident), _, _, _) if ident == "exists" =>
151 PrustiToken::Quantifier(ident.span(), Quantifier::Exists),
152 (TokenTree::Punct(punct), _, _, _)
153 if punct.as_char() == ',' && punct.spacing() == Alone =>
154 PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Comma)),
155 (TokenTree::Punct(punct), _, _, _)
156 if punct.as_char() == ';' && punct.spacing() == Alone =>
157 PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Semicolon)),
158 (TokenTree::Punct(punct), _, _, _)
159 if punct.as_char() == '=' && punct.spacing() == Alone =>
160 PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Assign)),
161 (token @ TokenTree::Punct(punct), _, _, _) if punct.spacing() == Joint => {
162 tokens.push_back(PrustiToken::Token(token.clone()));
165 while let Some(token @ TokenTree::Punct(p)) = source.get(pos) {
166 pos += 1;
167 tokens.push_back(PrustiToken::Token(token.clone()));
168 if p.spacing() != Joint {
169 break;
170 }
171 }
172 continue;
173 }
174 (TokenTree::Group(group), _, _, _) => PrustiToken::Group(
175 group.span(),
176 group.delimiter(),
177 Box::new(Self::new(group.stream())),
178 ),
179 (token, _, _, _) => PrustiToken::Token(token.clone()),
180 });
181 }
182 Self {
183 tokens,
184 source_span,
185 }
186 }
187
188 fn is_empty(&self) -> bool {
189 self.tokens.is_empty()
190 }
191
192 fn parse_rest<T, F>(mut self, f: F) -> syn::Result<T>
193 where
194 F: FnOnce(&mut Self) -> syn::Result<T>,
195 {
196 let result = f(&mut self)?;
197 if !self.is_empty() {
198 let start = self.tokens.front().expect("unreachable").span();
199 let end = self.tokens.back().expect("unreachable").span();
200 let span = join_spans(start, end);
201 return err(span, "unexpected extra tokens");
202 }
203 Ok(result)
204 }
205
206 fn parse(mut self) -> syn::Result<TokenStream> {
209 self.expr_bp(0)
210 }
211
212 fn parse_rust_only(self) -> syn::Result<TokenStream> {
215 Ok(TokenStream::from_iter(
216 self.tokens
217 .into_iter()
218 .map(|token| match token {
219 PrustiToken::Group(_, _, box stream) => stream.parse_rust_only(),
220 PrustiToken::Token(tree) => Ok(tree.to_token_stream()),
221 PrustiToken::BinOp(span, PrustiBinaryOp::Rust(op)) => Ok(op.to_tokens(span)),
222 _ => err(token.span(), "unexpected Prusti syntax"),
223 })
224 .collect::<Result<Vec<_>, _>>()?,
225 ))
226 }
227
228 fn parse_pledge(self) -> syn::Result<(Option<TokenStream>, TokenStream)> {
231 let mut pledge_ops = self.split(PrustiBinaryOp::Rust(RustOp::Arrow), false);
232 if pledge_ops.len() == 1 {
233 Ok((None, pledge_ops[0].expr_bp(0)?))
234 } else if pledge_ops.len() == 2 {
235 Ok((Some(pledge_ops[0].expr_bp(0)?), pledge_ops[1].expr_bp(0)?))
236 } else {
237 err(Span::call_site(), "too many arrows in after_expiry")
238 }
239 }
240
241 fn parse_assert_pledge(self) -> syn::Result<(Option<TokenStream>, TokenStream, TokenStream)> {
244 let mut pledge_ops = self.split(PrustiBinaryOp::Rust(RustOp::Arrow), false);
245 let (reference, body) = match (pledge_ops.pop(), pledge_ops.pop(), pledge_ops.pop()) {
246 (Some(body), None, _) => (None, body),
247 (Some(body), Some(mut reference), None) => (Some(reference.expr_bp(0)?), body),
248 _ => return err(Span::call_site(), "too many arrows in assert_on_expiry"),
249 };
250 let mut body_parts = body.split(PrustiBinaryOp::Rust(RustOp::Comma), false);
251 if body_parts.len() == 2 {
252 Ok((
253 reference,
254 body_parts[0].expr_bp(0)?,
255 body_parts[1].expr_bp(0)?,
256 ))
257 } else {
258 err(Span::call_site(), "missing assertion")
259 }
260 }
261
262 fn expr_bp(&mut self, min_bp: u8) -> syn::Result<TokenStream> {
267 let mut lhs = match self.tokens.pop_front() {
268 Some(PrustiToken::Group(span, delimiter, box stream)) => {
269 let mut group = proc_macro2::Group::new(delimiter, stream.parse()?);
270 group.set_span(span);
271 TokenTree::Group(group).to_token_stream()
272 }
273 Some(PrustiToken::Outer(span)) => {
274 let _stream = self
275 .pop_group(Delimiter::Parenthesis)
276 .ok_or_else(|| error(span, "expected parenthesized expression after outer"))?;
277 todo!()
278 }
279 Some(PrustiToken::Quantifier(span, kind)) => {
280 let mut stream = self.pop_group(Delimiter::Parenthesis).ok_or_else(|| {
281 error(span, "expected parenthesized expression after quantifier")
282 })?;
283 let args = stream
284 .pop_closure_args()
285 .ok_or_else(|| error(span, "expected quantifier body"))?;
286
287 {
288 let cl_args = args.clone().parse_rust_only()?;
292 let check_cl = quote! { | #cl_args | 0 };
293 let parsed_cl = syn::parse2::<syn::ExprClosure>(check_cl)?;
294 for pat in parsed_cl.inputs {
295 match pat {
296 syn::Pat::Type(_) => {}
297 _ => {
298 return err(
299 pat.span(),
300 "quantifier arguments must have explicit types",
301 )
302 }
303 }
304 }
305 };
306
307 let triggers = stream.extract_triggers()?;
308 if args.is_empty() {
309 return err(span, "a quantifier must have at least one argument");
310 }
311 let args = args.parse()?;
312 let body = stream.parse()?;
313 kind.translate(span, triggers, args, body)
314 }
315
316 Some(PrustiToken::SpecEnt(span, _)) | Some(PrustiToken::CallDesc(span, _)) => {
317 return err(span, "unexpected operator")
318 }
319
320 Some(PrustiToken::BinOp(span, PrustiBinaryOp::Rust(op))) => op.to_tokens(span),
322
323 Some(PrustiToken::BinOp(span, _)) => return err(span, "unexpected binary operator"),
324 Some(PrustiToken::Token(token)) => token.to_token_stream(),
325 None => return Ok(TokenStream::new()),
326 };
327 loop {
328 let (span, op) = match self.tokens.front() {
329 Some(PrustiToken::Group(span, delimiter, box stream)) => {
334 let mut group = proc_macro2::Group::new(*delimiter, stream.clone().parse()?);
335 group.set_span(*span);
336 lhs.extend(TokenTree::Group(group).to_token_stream());
337 self.tokens.pop_front();
338 continue;
339 }
340 Some(PrustiToken::Token(token)) => {
341 lhs.extend(token.to_token_stream());
342 self.tokens.pop_front();
343 continue;
344 }
345
346 Some(PrustiToken::SpecEnt(span, once)) => {
347 let span = *span;
348 let once = *once;
349 self.tokens.pop_front();
350 let args = self
351 .pop_closure_args()
352 .ok_or_else(|| error(span, "expected closure arguments"))?;
353 let nested_closure_specs = self.pop_group_of_nested_specs(span)?;
354 lhs = translate_spec_ent(
355 span,
356 once,
357 lhs,
358 args.split(PrustiBinaryOp::Rust(RustOp::Comma), true)
359 .into_iter()
360 .map(|stream| stream.parse())
361 .collect::<Result<Vec<_>, _>>()?,
362 nested_closure_specs,
363 );
364 continue;
365 }
366
367 Some(PrustiToken::CallDesc(..)) => todo!("call desc"),
368
369 Some(PrustiToken::BinOp(span, op)) => (*span, *op),
370 Some(PrustiToken::Outer(span)) => return err(*span, "unexpected outer"),
371 Some(PrustiToken::Quantifier(span, _)) => {
372 return err(*span, "unexpected quantifier")
373 }
374
375 None => break,
376 };
377 let (l_bp, r_bp) = op.binding_power();
378 if l_bp < min_bp {
379 break;
380 }
381 self.tokens.pop_front();
382 let rhs = self.expr_bp(r_bp)?;
383
384 if !matches!(op, PrustiBinaryOp::Rust(_)) && rhs.is_empty() {
391 return err(span, "expected expression");
392 }
393 lhs = op.translate(span, lhs, rhs);
394 }
395 Ok(lhs)
396 }
397
398 fn pop_group(&mut self, delimiter: Delimiter) -> Option<Self> {
399 match self.tokens.pop_front() {
400 Some(PrustiToken::Group(_, del, box stream)) if del == delimiter => Some(stream),
401 _ => None,
402 }
403 }
404
405 fn pop_closure_args(&mut self) -> Option<Self> {
406 let mut tokens = VecDeque::new();
407
408 if matches!(
410 self.tokens.front(),
411 Some(PrustiToken::BinOp(_, PrustiBinaryOp::Or))
412 ) {
413 return Some(Self {
414 tokens,
415 source_span: self.source_span,
416 });
417 }
418
419 if !self.tokens.pop_front()?.is_closure_brace() {
420 return None;
421 }
422 loop {
423 let token = self.tokens.pop_front()?;
424 if token.is_closure_brace() {
425 break;
426 }
427 tokens.push_back(token);
428 }
429
430 Some(Self {
431 tokens,
432 source_span: self.source_span,
433 })
434 }
435
436 fn pop_parenthesized_group(&mut self) -> syn::Result<Self> {
437 match self.tokens.pop_front() {
438 Some(PrustiToken::Group(_span, Delimiter::Parenthesis, box group)) => {
439 Ok(group) }
441 _ => Err(error(self.source_span, "expected parenthesized group")),
442 }
443 }
444
445 fn pop_single_nested_spec(&mut self) -> syn::Result<NestedSpec<Self>> {
446 let first = self
447 .tokens
448 .pop_front()
449 .ok_or_else(|| error(self.source_span, "expected nested spec"))?;
450 if let PrustiToken::Token(TokenTree::Ident(spec_type)) = first {
451 match spec_type.to_string().as_ref() {
452 "requires" => Ok(NestedSpec::Requires(self.pop_parenthesized_group()?)),
453 "ensures" => Ok(NestedSpec::Ensures(self.pop_parenthesized_group()?)),
454 "pure" => Ok(NestedSpec::Pure),
455 other => err(
456 self.source_span,
457 format!("unexpected nested spec type: {other}").as_ref(),
458 ),
459 }
460 } else {
461 err(self.source_span, "expected identifier")
462 }
463 }
464
465 fn pop_group_of_nested_specs(
466 &mut self,
467 span: Span,
468 ) -> syn::Result<Vec<NestedSpec<TokenStream>>> {
469 let group_of_specs = self
470 .pop_group(Delimiter::Bracket)
471 .ok_or_else(|| error(span, "expected nested specification in brackets"))?;
472 let parsed = group_of_specs
473 .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
474 .into_iter()
475 .map(|stream| stream.parse_rest(|stream| stream.pop_single_nested_spec()))
476 .map(|stream| stream.and_then(|s| s.parse()))
477 .collect::<syn::Result<Vec<NestedSpec<TokenStream>>>>()?;
478 Ok(parsed)
479 }
480
481 fn split(self, split_on: PrustiBinaryOp, allow_trailing: bool) -> Vec<Self> {
482 if self.tokens.is_empty() {
483 return vec![];
484 }
485 let mut res = self
486 .tokens
487 .into_iter()
488 .collect::<Vec<_>>()
489 .split(|token| matches!(token, PrustiToken::BinOp(_, t) if *t == split_on))
490 .map(|group| Self {
491 tokens: group.iter().cloned().collect(),
492 source_span: self.source_span,
493 })
494 .collect::<Vec<_>>();
495 if allow_trailing && res.len() > 1 && res[res.len() - 1].tokens.is_empty() {
496 res.pop();
497 }
498 res
499 }
500
501 fn extract_triggers(&mut self) -> syn::Result<Vec<Vec<TokenStream>>> {
502 let len = self.tokens.len();
503 if len < 4 {
504 return Ok(vec![]);
505 }
506 match [
507 &self.tokens[len - 4],
508 &self.tokens[len - 3],
509 &self.tokens[len - 2],
510 &self.tokens[len - 1],
511 ] {
512 [PrustiToken::BinOp(_, PrustiBinaryOp::Rust(RustOp::Comma)), PrustiToken::Token(TokenTree::Ident(ident)), PrustiToken::BinOp(_, PrustiBinaryOp::Rust(RustOp::Assign)), PrustiToken::Group(triggers_span, Delimiter::Bracket, box triggers)]
513 if ident == "triggers" =>
514 {
515 let triggers = triggers
516 .clone()
517 .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
518 .into_iter()
519 .map(|mut stream| {
520 stream
521 .pop_group(Delimiter::Parenthesis)
522 .ok_or_else(|| {
523 error(*triggers_span, "trigger sets must be tuples of expressions")
524 })?
525 .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
526 .into_iter()
527 .map(|stream| stream.parse())
528 .collect::<Result<Vec<_>, _>>()
529 })
530 .collect::<Result<Vec<_>, _>>();
531 self.tokens.truncate(len - 4);
532 triggers
533 }
534 _ => Ok(vec![]),
535 }
536 }
537}
538
/// A type-conditional spec refinement, parsed from input of the form
/// `where T: A + B, U: C, [requires(...), ensures(...), pure]`.
#[derive(Debug)]
pub struct TypeCondSpecRefinement {
    /// The `where` predicates; each is a type bound without lifetimes.
    pub trait_bounds: Vec<syn::PredicateType>,
    /// The bracketed nested specifications, with Prusti syntax translated.
    pub specs: Vec<NestedSpec<TokenStream>>,
}
544
545impl Parse for TypeCondSpecRefinement {
546 fn parse(input: ParseStream) -> syn::Result<Self> {
547 input
548 .parse::<syn::Token![where]>()
549 .map_err(with_type_cond_spec_example)?;
550 Ok(TypeCondSpecRefinement {
551 trait_bounds: parse_trait_bounds(input)?,
552 specs: PrustiTokenStream::new(input.parse().unwrap())
553 .parse_rest(|pts| pts.pop_group_of_nested_specs(input.span()))?,
554 })
555 }
556}
557
558fn parse_trait_bounds(input: ParseStream) -> syn::Result<Vec<syn::PredicateType>> {
559 let mut bounds: Vec<syn::PredicateType> = Vec::new();
560 loop {
561 let predicate = input
562 .parse::<syn::WherePredicate>()
563 .map_err(with_type_cond_spec_example)?;
564 bounds.push(validate_predicate(predicate)?);
565 input
566 .parse::<syn::token::Comma>()
567 .map_err(with_type_cond_spec_example)?;
568 if input.peek(syn::token::Bracket) || input.is_empty() {
569 break;
572 }
573 }
574 Ok(bounds)
575}
576
577fn validate_predicate(predicate: syn::WherePredicate) -> syn::Result<syn::PredicateType> {
578 use syn::WherePredicate::*;
579
580 match predicate {
581 Type(type_bound) => {
582 validate_trait_bounds(&type_bound)?;
583 Ok(type_bound)
584 }
585 Lifetime(lifetime_bound) => disallowed_lifetime_error(lifetime_bound.span()),
586 Eq(eq_bound) => err(
587 eq_bound.span(),
588 "equality constraints are not allowed in type-conditional spec refinements",
589 ),
590 }
591}
592
593fn disallowed_lifetime_error<T>(span: Span) -> syn::Result<T> {
594 err(
595 span,
596 "lifetimes are not allowed in type-conditional spec refinement trait bounds",
597 )
598}
599
600fn validate_trait_bounds(trait_bounds: &syn::PredicateType) -> syn::Result<()> {
601 if let Some(lifetimes) = &trait_bounds.lifetimes {
602 return disallowed_lifetime_error(lifetimes.span());
603 }
604 for bound in &trait_bounds.bounds {
605 match bound {
606 syn::TypeParamBound::Lifetime(lt) => {
607 return disallowed_lifetime_error(lt.span());
608 }
609 syn::TypeParamBound::Trait(trait_bound) => {
610 if let Some(lt) = &trait_bound.lifetimes {
611 return disallowed_lifetime_error(lt.span());
612 }
613 }
614 }
615 }
616
617 Ok(())
618}
619
620fn with_type_cond_spec_example(mut err: syn::Error) -> syn::Error {
621 err.combine(error(err.span(), "expected where constraint and specifications in brackets, e.g.: `refine_spec(where T: A + B, U: C, [requires(...), ...])`"));
622 err
623}
624
/// A single nested specification inside a `[...]` list: `requires(body)`,
/// `ensures(body)` or `pure`.
///
/// `T` is the representation of the body: a `PrustiTokenStream` before
/// translation, a `TokenStream` after.
#[derive(Debug)]
pub enum NestedSpec<T> {
    /// `requires(...)` with its body.
    Requires(T),
    /// `ensures(...)` with its body.
    Ensures(T),
    /// `pure`, which has no body.
    Pure,
}
632
633impl NestedSpec<PrustiTokenStream> {
634 fn parse(self) -> syn::Result<NestedSpec<TokenStream>> {
635 Ok(match self {
636 NestedSpec::Requires(stream) => NestedSpec::Requires(stream.parse()?),
637 NestedSpec::Ensures(stream) => NestedSpec::Ensures(stream.parse()?),
638 NestedSpec::Pure => NestedSpec::Pure,
639 })
640 }
641}
642
/// A token of a [`PrustiTokenStream`].
#[derive(Debug, Clone)]
enum PrustiToken {
    /// A delimited group whose contents have been tokenized recursively.
    Group(Span, Delimiter, Box<PrustiTokenStream>),
    /// Any other Rust token, passed through unchanged.
    Token(TokenTree),
    /// A Prusti or Rust binary operator recognised by the tokenizer.
    BinOp(Span, PrustiBinaryOp),
    /// The `outer` keyword.
    Outer(Span),
    /// A `forall` or `exists` quantifier keyword.
    Quantifier(Span, Quantifier),
    /// Specification entailment: `|=` (`false`) or `|=!` (`true`); the flag
    /// is passed as `once` to `translate_spec_ent`.
    SpecEnt(Span, bool),
    /// Call description: `~>` (`false`) or `~>!` (`true`).
    CallDesc(Span, bool),
}
654
655fn translate_spec_ent(
656 span: Span,
657 once: bool,
658 cl_expr: TokenStream,
659 cl_args: Vec<TokenStream>,
660 contract: Vec<NestedSpec<TokenStream>>,
661) -> TokenStream {
662 let once = if once {
663 quote_spanned! { span => true }
664 } else {
665 quote_spanned! { span => false }
666 };
667
668 let arg_count = cl_args.len();
670 let generics_args = (0..arg_count)
671 .map(|i| TokenTree::Ident(proc_macro2::Ident::new(&format!("GA{i}"), span)))
672 .collect::<Vec<_>>();
673 let generic_res = TokenTree::Ident(proc_macro2::Ident::new("GR", span));
674
675 let extract_args = (0..arg_count)
676 .map(|i| TokenTree::Ident(proc_macro2::Ident::new(&format!("__extract_arg{i}"), span)))
677 .collect::<Vec<_>>();
678 let extract_args_decl = extract_args
679 .iter()
680 .zip(generics_args.iter())
681 .map(|(ident, arg_type)| {
682 quote_spanned! { span =>
683 #[prusti::spec_only]
684 fn #ident<
685 #(#generics_args),* ,
686 #generic_res,
687 F: FnOnce( #(#generics_args),* ) -> #generic_res
688 >(_f: &F) -> #arg_type { unreachable!() }
689 }
690 })
691 .collect::<Vec<_>>();
692
693 let preconds = contract
694 .iter()
695 .filter_map(|spec| match spec {
696 NestedSpec::Requires(stream) => Some(stream.clone()),
697 _ => None,
698 })
699 .collect::<Vec<_>>();
700 let postconds = contract
701 .into_iter()
702 .filter_map(|spec| match spec {
703 NestedSpec::Ensures(stream) => Some(stream),
704 _ => None,
705 })
706 .collect::<Vec<_>>();
707
708 quote_spanned! { span => {
711 let __cl_ref = & #cl_expr;
712 #(#extract_args_decl)*
713 #[prusti::spec_only]
714 fn __extract_res<
715 #(#generics_args),* ,
716 #generic_res,
717 F: FnOnce( #(#generics_args),* ) -> #generic_res
718 >(_f: &F) -> #generic_res { unreachable!() }
719 #( let #cl_args = #extract_args(__cl_ref); )*
720 let result = __extract_res(__cl_ref);
721 specification_entailment(
722 #once,
723 __cl_ref,
724 ( #( #[prusti::spec_only] || -> bool { #preconds }, )* ),
725 ( #( #[prusti::spec_only] || -> bool { #postconds }, )* ),
726 )
727 } }
728}
729
/// The kind of a quantifier keyword; translated into a call to
/// `::prusti_contracts::forall` or `::prusti_contracts::exists`.
#[derive(Debug, Clone)]
enum Quantifier {
    /// `forall(|args| body)`.
    Forall,
    /// `exists(|args| body)`.
    Exists,
}
735
736impl Quantifier {
737 fn translate(
738 &self,
739 span: Span,
740 triggers: Vec<Vec<TokenStream>>,
741 args: TokenStream,
742 body: TokenStream,
743 ) -> TokenStream {
744 let full_span = join_spans(span, body.span());
745 let trigger_sets = triggers
746 .into_iter()
747 .map(|set| {
748 let triggers = TokenStream::from_iter(set.into_iter().map(|trigger| {
749 quote_spanned! { trigger.span() =>
750 #[prusti::spec_only] | #args | ( #trigger ), }
751 }));
752 quote_spanned! { full_span => ( #triggers ) }
753 })
754 .collect::<Vec<_>>();
755 let body = quote_spanned! { body.span() => #body };
756 match self {
757 Self::Forall => quote_spanned! { full_span => ::prusti_contracts::forall(
758 ( #( #trigger_sets, )* ),
759 #[prusti::spec_only] | #args | -> bool { #body }
760 ) },
761 Self::Exists => quote_spanned! { full_span => ::prusti_contracts::exists(
762 ( #( #trigger_sets, )* ),
763 #[prusti::spec_only] | #args | -> bool { #body }
764 ) },
765 }
766 }
767}
768
769fn operator2(op: &str, p1: &Punct, p2: &Punct) -> bool {
777 let chars = op.chars().collect::<Vec<_>>();
778 [p1.as_char(), p2.as_char()] == chars[0..2] && p1.spacing() == Joint && p2.spacing() == Alone
779}
780
781fn operator3(op: &str, p1: &Punct, p2: &Punct, p3: &Punct) -> bool {
782 let chars = op.chars().collect::<Vec<_>>();
783 [p1.as_char(), p2.as_char(), p3.as_char()] == chars[0..3]
784 && p1.spacing() == Joint
785 && p2.spacing() == Joint
786 && p3.spacing() == Alone
787}
788
789fn operator4(op: &str, p1: &Punct, p2: &Punct, p3: &Punct, p4: &Punct) -> bool {
790 let chars = op.chars().collect::<Vec<_>>();
791 [p1.as_char(), p2.as_char(), p3.as_char(), p4.as_char()] == chars[0..4]
792 && p1.spacing() == Joint
793 && p2.spacing() == Joint
794 && p3.spacing() == Joint
795 && p4.spacing() == Alone
796}
797
798impl PrustiToken {
799 fn span(&self) -> Span {
800 match self {
801 Self::Group(span, _, _)
802 | Self::BinOp(span, _)
803 | Self::Outer(span)
804 | Self::Quantifier(span, _)
805 | Self::SpecEnt(span, _)
806 | Self::CallDesc(span, _) => *span,
807 Self::Token(tree) => tree.span(),
808 }
809 }
810
811 fn is_closure_brace(&self) -> bool {
812 matches!(self, Self::Token(TokenTree::Punct(p))
813 if p.as_char() == '|' && p.spacing() == proc_macro2::Spacing::Alone)
814 }
815
816 fn parse_op2(p1: &Punct, p2: &Punct) -> Option<Self> {
817 let span = join_spans(p1.span(), p2.span());
818 Some(Self::BinOp(
819 span,
820 if operator2("&&", p1, p2) {
821 PrustiBinaryOp::And
822 } else if operator2("||", p1, p2) {
823 PrustiBinaryOp::Or
824 } else if operator2("->", p1, p2) {
825 PrustiBinaryOp::Implies
826 } else if operator2("..", p1, p2) {
827 PrustiBinaryOp::Rust(RustOp::Range)
828 } else if operator2("+=", p1, p2) {
829 PrustiBinaryOp::Rust(RustOp::AddAssign)
830 } else if operator2("-=", p1, p2) {
831 PrustiBinaryOp::Rust(RustOp::SubtractAssign)
832 } else if operator2("*=", p1, p2) {
833 PrustiBinaryOp::Rust(RustOp::MultiplyAssign)
834 } else if operator2("/=", p1, p2) {
835 PrustiBinaryOp::Rust(RustOp::DivideAssign)
836 } else if operator2("%=", p1, p2) {
837 PrustiBinaryOp::Rust(RustOp::ModuloAssign)
838 } else if operator2("&=", p1, p2) {
839 PrustiBinaryOp::Rust(RustOp::BitAndAssign)
840 } else if operator2("^=", p1, p2) {
843 PrustiBinaryOp::Rust(RustOp::BitXorAssign)
844 } else if operator2("=>", p1, p2) {
845 PrustiBinaryOp::Rust(RustOp::Arrow)
846 } else if operator2("|=", p1, p2) {
847 return Some(Self::SpecEnt(span, false));
848 } else if operator2("~>", p1, p2) {
849 return Some(Self::CallDesc(span, false));
850 } else {
851 return None;
852 },
853 ))
854 }
855
856 fn parse_op3(p1: &Punct, p2: &Punct, p3: &Punct) -> Option<Self> {
857 let span = join_spans(join_spans(p1.span(), p2.span()), p3.span());
858 Some(Self::BinOp(
859 span,
860 if operator3("==>", p1, p2, p3) {
861 PrustiBinaryOp::Implies
862 } else if operator3("<==", p1, p2, p3) {
863 PrustiBinaryOp::ImpliesReverse
864 } else if operator3("===", p1, p2, p3) {
865 PrustiBinaryOp::SnapEq
866 } else if operator3("!==", p1, p2, p3) {
867 PrustiBinaryOp::SnapNe
868 } else if operator3("..=", p1, p2, p3) {
869 PrustiBinaryOp::Rust(RustOp::RangeInclusive)
870 } else if operator3("<<=", p1, p2, p3) {
871 PrustiBinaryOp::Rust(RustOp::LeftShiftAssign)
872 } else if operator3(">>=", p1, p2, p3) {
873 PrustiBinaryOp::Rust(RustOp::RightShiftAssign)
874 } else if operator3("|=!", p1, p2, p3) {
875 return Some(Self::SpecEnt(span, true));
876 } else if operator3("~>!", p1, p2, p3) {
877 return Some(Self::CallDesc(span, true));
878 } else {
879 return None;
880 },
881 ))
882 }
883
884 fn parse_op4(p1: &Punct, p2: &Punct, p3: &Punct, p4: &Punct) -> Option<Self> {
885 let span = join_spans(
886 join_spans(join_spans(p1.span(), p2.span()), p3.span()),
887 p4.span(),
888 );
889 Some(Self::BinOp(
890 span,
891 if operator4("<==>", p1, p2, p3, p4) {
892 PrustiBinaryOp::Iff
893 } else {
894 return None;
895 },
896 ))
897 }
898}
899
/// A binary operator recognised by the tokenizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrustiBinaryOp {
    /// A plain Rust operator, emitted unchanged.
    Rust(RustOp),
    /// `<==>`: logical equivalence.
    Iff,
    /// `==>` (also produced for `->`): implication.
    Implies,
    /// `<==`: reverse implication (the right side implies the left).
    ImpliesReverse,
    /// `||`.
    Or,
    /// `&&`.
    And,
    /// `===`: snapshot equality.
    SnapEq,
    /// `!==`: snapshot inequality.
    SnapNe,
}
911
912impl PrustiBinaryOp {
913 fn binding_power(&self) -> (u8, u8) {
927 match self {
929 Self::Rust(_) => (0, 0),
930 Self::Iff => (4, 3),
931 Self::Implies => (6, 5),
932 Self::ImpliesReverse => (5, 6),
933 Self::Or => (7, 8),
934 Self::And => (9, 10),
935 Self::SnapEq => (11, 12),
936 Self::SnapNe => (11, 12),
937 }
938 }
939
940 fn translate(&self, span: Span, raw_lhs: TokenStream, raw_rhs: TokenStream) -> TokenStream {
941 let lhs = quote_spanned! { raw_lhs.span() => (#raw_lhs) };
943 let rhs = quote_spanned! { raw_rhs.span() => (#raw_rhs) };
944 match self {
945 Self::Rust(op) => op.translate(span, raw_lhs, raw_rhs),
946 Self::Iff => {
947 let joined_span = join_spans(lhs.span(), rhs.span());
948 quote_spanned! { joined_span => #lhs == #rhs }
949 }
950 Self::Implies => {
954 let joined_span = join_spans(lhs.span(), rhs.span());
955 let not_lhs = quote_spanned! { lhs.span() => !#lhs };
957 quote_spanned! { joined_span => #not_lhs || #rhs }
958 }
959 Self::ImpliesReverse => {
960 let joined_span = join_spans(lhs.span(), rhs.span());
961 let not_rhs = quote_spanned! { rhs.span() => !#rhs };
963 quote_spanned! { joined_span => #not_rhs || #lhs }
964 }
965 Self::Or => quote_spanned! { span => #lhs || #rhs },
966 Self::And => quote_spanned! { span => #lhs && #rhs },
967 Self::SnapEq => {
968 let joined_span = join_spans(lhs.span(), rhs.span());
969 quote_spanned! { joined_span => snapshot_equality(&#lhs, &#rhs) }
970 }
971 Self::SnapNe => {
972 let joined_span = join_spans(lhs.span(), rhs.span());
973 quote_spanned! { joined_span => !snapshot_equality(&#lhs, &#rhs) }
974 }
975 }
976 }
977}
978
/// Rust operators that the tokenizer recognises as operator tokens (so that
/// they are never mistaken for Prusti operators) and emits unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RustOp {
    /// `..=`
    RangeInclusive,
    /// `<<=`
    LeftShiftAssign,
    /// `>>=`
    RightShiftAssign,
    /// `..`
    Range,
    /// `+=`
    AddAssign,
    /// `-=`
    SubtractAssign,
    /// `*=`
    MultiplyAssign,
    /// `/=`
    DivideAssign,
    /// `%=`
    ModuloAssign,
    /// `&=`
    BitAndAssign,
    /// `^=`
    BitXorAssign,
    /// `=>`
    Arrow,
    /// `,`
    Comma,
    /// `;`
    Semicolon,
    /// `=`
    Assign,
}
999
1000impl RustOp {
1001 fn translate(&self, span: Span, lhs: TokenStream, rhs: TokenStream) -> TokenStream {
1002 let op = self.to_tokens(span);
1003 quote! { #lhs #op #rhs }
1004 }
1005
1006 fn to_tokens(self, span: Span) -> TokenStream {
1007 match self {
1008 Self::RangeInclusive => quote_spanned! { span => ..= },
1009 Self::LeftShiftAssign => quote_spanned! { span => <<= },
1010 Self::RightShiftAssign => quote_spanned! { span => >>= },
1011 Self::Range => quote_spanned! { span => .. },
1012 Self::AddAssign => quote_spanned! { span => += },
1013 Self::SubtractAssign => quote_spanned! { span => -= },
1014 Self::MultiplyAssign => quote_spanned! { span => *= },
1015 Self::DivideAssign => quote_spanned! { span => /= },
1016 Self::ModuloAssign => quote_spanned! { span => %= },
1017 Self::BitAndAssign => quote_spanned! { span => &= },
1018 Self::BitXorAssign => quote_spanned! { span => ^= },
1020 Self::Arrow => quote_spanned! { span => => },
1021 Self::Comma => quote_spanned! { span => , },
1022 Self::Semicolon => quote_spanned! { span => ; },
1023 Self::Assign => quote_spanned! { span => = },
1024 }
1025 }
1026}
1027
1028fn join_spans(s1: Span, s2: Span) -> Span {
1029 if cfg!(test) {
1031 s1.join(s2).unwrap_or(s1)
1033 } else {
1034 s1.unwrap()
1036 .join(s2.unwrap())
1037 .expect("Failed to join spans!")
1038 .into()
1039 }
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `$result` is an `Err` whose `to_string()` equals
    /// `$expected`.
    macro_rules! assert_error {
        ( $result:expr, $expected:expr ) => {{
            let _res = $result;
            assert!(_res.is_err());
            let _err = _res.unwrap_err();
            assert_eq!(_err.to_string(), $expected);
        }};
    }

    /// Checks the stringified Rust output of `parse_prusti` for implication,
    /// snapshot (in)equality, associativity and precedence, quantifiers with
    /// and without triggers, and Prusti syntax inside a macro call.
    #[test]
    fn test_preparser() {
        assert_eq!(
            parse_prusti("a ==> b".parse().unwrap())
                .unwrap()
                .to_string(),
            "! (a) || (b)",
        );
        assert_eq!(
            parse_prusti("a === b + c".parse().unwrap())
                .unwrap()
                .to_string(),
            "snapshot_equality (& (a) , & (b + c))",
        );
        assert_eq!(
            parse_prusti("a !== b + c".parse().unwrap())
                .unwrap()
                .to_string(),
            "! snapshot_equality (& (a) , & (b + c))",
        );
        // `==>` is right-associative.
        assert_eq!(
            parse_prusti("a ==> b ==> c".parse().unwrap())
                .unwrap()
                .to_string(),
            "! (a) || (! (b) || (c))",
        );
        // `&&` binds tighter than `==>`, and `||` tighter than `==>`.
        assert_eq!(
            parse_prusti("(a ==> b && c) ==> d || e".parse().unwrap())
                .unwrap()
                .to_string(),
            "! ((! (a) || ((b) && (c)))) || ((d) || (e))",
        );
        assert_eq!(
            parse_prusti("forall(|x: i32| a ==> b)".parse().unwrap())
                .unwrap()
                .to_string(),
            ":: prusti_contracts :: forall (() , # [prusti :: spec_only] | x : i32 | -> bool { ! (a) || (b) })",
        );
        assert_eq!(
            parse_prusti("exists(|x: i32| a === b)".parse().unwrap()).unwrap().to_string(),
            ":: prusti_contracts :: exists (() , # [prusti :: spec_only] | x : i32 | -> bool { snapshot_equality (& (a) , & (b)) })",
        );
        // Each trigger set becomes a tuple of spec-only closures.
        assert_eq!(
            parse_prusti("forall(|x: i32| a ==> b, triggers = [(c,), (d, e)])".parse().unwrap()).unwrap().to_string(),
            ":: prusti_contracts :: forall (((# [prusti :: spec_only] | x : i32 | (c) ,) , (# [prusti :: spec_only] | x : i32 | (d) , # [prusti :: spec_only] | x : i32 | (e) ,) ,) , # [prusti :: spec_only] | x : i32 | -> bool { ! (a) || (b) })",
        );
        // Prusti operators are also translated inside macro-call arguments.
        assert_eq!(
            parse_prusti("assert!(a === b ==> b)".parse().unwrap())
                .unwrap()
                .to_string(),
            "assert ! (! (snapshot_equality (& (a) , & (b))) || (b))",
        );
    }

    /// Tests for `parse_type_cond_spec` (`where <bounds>, [<specs>]`).
    mod type_cond_specs {
        use std::assert_matches::assert_matches;

        use super::*;

        /// Malformed inputs produce the expected `syn` / preparser messages.
        #[test]
        fn invalid_args() {
            let err_invalid_bounds = "expected one of: `for`, parentheses, `fn`, `unsafe`, `extern`, identifier, `::`, `<`, square brackets, `*`, `&`, `!`, `impl`, `_`, lifetime";
            assert_error!(
                parse_type_cond_spec(quote! { [requires(false)] }),
                "expected `where`"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where [requires(false)] }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! { [requires(false)], T: A }),
                "expected `where`"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where [requires(false)], T: A }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! {}),
                format!("unexpected end of input, {}", "expected `where`")
            );
            assert_error!(parse_type_cond_spec(quote! { T: A }), "expected `where`");
            // Every bound must be followed by a comma.
            assert_error!(parse_type_cond_spec(quote! { where T: A }), "expected `,`");
            assert_error!(
                parse_type_cond_spec(quote! { where T: A, }),
                "expected nested specification in brackets"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where T: A, {} }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! { where T: A [requires(false)] }),
                "expected `,`"
            );
            // Nothing may follow the bracketed spec list.
            assert_error!(
                parse_type_cond_spec(quote! { where T: A, [requires(false)], "nope" }),
                "unexpected extra tokens"
            );
        }

        /// Several bounds and several specs are parsed in order.
        #[test]
        fn multiple_bounds_multiple_specs() {
            let constraint = parse_type_cond_spec(
                quote! { where T: A+B+Foo<i32>, U: C, [requires(true), ensures(false), pure]},
            )
            .unwrap();

            assert_bounds_eq(
                &constraint.trait_bounds,
                &[quote! { T : A + B + Foo < i32 > }, quote! { U : C }],
            );
            match &constraint.specs[0] {
                NestedSpec::Requires(ts) => assert_eq!(ts.to_string(), "true"),
                _ => panic!(),
            }
            match &constraint.specs[1] {
                NestedSpec::Ensures(ts) => assert_eq!(ts.to_string(), "false"),
                _ => panic!(),
            }
            assert_matches!(&constraint.specs[2], NestedSpec::Pure);
            assert_eq!(constraint.specs.len(), 3);
        }

        /// An empty spec list is accepted.
        #[test]
        fn no_specs() {
            let constraint = parse_type_cond_spec(quote! { where T: A, []}).unwrap();
            assert_bounds_eq(&constraint.trait_bounds, &[quote! { T : A }]);
            assert!(constraint.specs.is_empty());
        }

        /// Trait bounds may use a path.
        #[test]
        fn fully_qualified_trait_path() {
            let constraint =
                parse_type_cond_spec(quote! { where T: path::to::A, [requires(true)]}).unwrap();
            assert_bounds_eq(&constraint.trait_bounds, &[quote! { T : path :: to :: A }]);
        }

        /// Tuple types inside generic bounds (which contain commas) do not
        /// confuse the bound/spec separation.
        #[test]
        fn tuple_generics() {
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32,), Output = i32>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32,)>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32, bool)>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32, bool,)>, []}).is_ok());
        }

        /// Compares parsed bounds with expected predicates given as tokens.
        fn assert_bounds_eq(parsed: &[syn::PredicateType], quotes: &[TokenStream]) {
            assert_eq!(parsed.len(), quotes.len());
            for (parsed, quote) in parsed.iter().zip(quotes.iter()) {
                assert_eq!(
                    syn::WherePredicate::Type(parsed.clone()),
                    syn::parse_quote! { #quote }
                );
            }
        }
    }
}