prusti_specs/specifications/
preparser.rs

1/// The preparser processes Prusti syntax into Rust syntax.
2use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
3use proc_macro2::{Punct, Spacing::*};
4use quote::{quote, quote_spanned, ToTokens};
5use std::collections::VecDeque;
6use syn::{
7    parse::{Parse, ParseStream},
8    spanned::Spanned,
9};
10
11/// The representation of an argument to a quantifier (for example `a: i32`)
12#[derive(Debug, Clone)]
13pub struct Arg {
14    pub name: syn::Ident,
15    pub typ: syn::Type,
16}
17
18pub fn parse_prusti(tokens: TokenStream) -> syn::Result<TokenStream> {
19    let parsed = PrustiTokenStream::new(tokens).parse()?;
20    // to make sure we catch errors in the Rust syntax early (and with the
21    // correct spans), we try to parse the resulting stream using syn here
22    syn::parse2::<syn::Expr>(parsed.clone())?;
23    Ok(parsed)
24}
25pub fn parse_prusti_pledge(tokens: TokenStream) -> syn::Result<TokenStream> {
26    // TODO: pledges with reference that is not "result" are not supported;
27    // for this reason we assert here that the reference (if there is any) is "result"
28    // then return the RHS only
29    let (reference, rhs) = PrustiTokenStream::new(tokens).parse_pledge()?;
30    if let Some(reference) = reference {
31        if reference.to_string() != "result" {
32            return err(
33                reference.span(),
34                "reference of after_expiry must be \"result\"",
35            );
36        }
37    }
38    syn::parse2::<syn::Expr>(rhs.clone())?;
39    Ok(rhs)
40}
41
42pub fn parse_prusti_assert_pledge(tokens: TokenStream) -> syn::Result<(TokenStream, TokenStream)> {
43    // TODO: pledges with reference that is not "result" are not supported;
44    // for this reason we assert here that the reference (if there is any) is "result"
45    // then return the RHS only
46    let (reference, lhs, rhs) = PrustiTokenStream::new(tokens).parse_assert_pledge()?;
47    if let Some(reference) = reference {
48        if reference.to_string() != "result" {
49            return err(
50                reference.span(),
51                "reference of assert_on_expiry must be \"result\"",
52            );
53        }
54    }
55    syn::parse2::<syn::Expr>(lhs.clone())?;
56    syn::parse2::<syn::Expr>(rhs.clone())?;
57    Ok((lhs, rhs))
58}
59
60pub fn parse_type_cond_spec(tokens: TokenStream) -> syn::Result<TypeCondSpecRefinement> {
61    syn::parse2(tokens)
62}
63
64/*
65Preparsing consists of two stages:
66
671. In [PrustiTokenStream::new], we map a [TokenStream] to a [PrustiTokenStream]
68   by identifying unary and binary operators. We also take care of Rust binary
69   operators that have lower precedence than the Prusti ones. Note that at this
70   token-based stage, a "binary operator" includes e.g. the semicolon for
71   statement sequencing.
72
732. In [PrustiTokenStream::parse], we perform the actual parsing as well as the
74   translation back to Rust syntax. The parser is a Pratt parser with binding
75   powers defined in [PrustiBinaryOp::binding_power]. Performing translation to
76   Rust syntax in this step allows us to not have to define data types for the
77   Prusti AST, as we reuse the Rust AST (i.e. [TokenTree] and [TokenStream]).
78*/
79
80/// The preparser reuses [syn::Result] to integrate with the rest of the specs
81/// library, even though syn is not used here.
82fn error(span: Span, msg: &str) -> syn::Error {
83    syn::Error::new(span, msg)
84}
85
86/// Same as `error`, conveniently packaged as `syn::Result::Err`.
87fn err<T>(span: Span, msg: &str) -> syn::Result<T> {
88    Err(error(span, msg))
89}
90
91#[derive(Debug, Clone)]
92struct PrustiTokenStream {
93    tokens: VecDeque<PrustiToken>,
94    source_span: Span,
95    // TODO: can we somehow update the span after popping stuff?
96}
97
98impl PrustiTokenStream {
99    /// Constructs a stream of Prusti tokens from a stream of Rust tokens.
100    fn new(source: TokenStream) -> Self {
101        let source_span = source.span();
102        let source = source.into_iter().collect::<Vec<_>>();
103
104        let mut pos = 0;
105        let mut tokens = VecDeque::new();
106
107        // TODO: figure out syntax for spec entailments (|= is taken in Rust)
108
109        while pos < source.len() {
110            // no matter what tokens we see, we will consume at least one
111            pos += 1;
112            tokens.push_back(match (&source[pos - 1], source.get(pos), source.get(pos + 1), source.get(pos + 2)) {
113                (
114                    TokenTree::Punct(p1),
115                    Some(TokenTree::Punct(p2)),
116                    Some(TokenTree::Punct(p3)),
117                    Some(TokenTree::Punct(p4)),
118                ) if let Some(op) = PrustiToken::parse_op4(p1, p2, p3, p4) => {
119                    // this was a four-character operator, consume three
120                    // additional tokens
121                    pos += 3;
122                    op
123                }
124                (
125                    TokenTree::Punct(p1),
126                    Some(TokenTree::Punct(p2)),
127                    Some(TokenTree::Punct(p3)),
128                    _
129                ) if let Some(op) = PrustiToken::parse_op3(p1, p2, p3) => {
130                    // this was a three-character operator, consume two
131                    // additional tokens
132                    pos += 2;
133                    op
134                }
135                (
136                    TokenTree::Punct(p1),
137                    Some(TokenTree::Punct(p2)),
138                    _,
139                    _,
140                ) if let Some(op) = PrustiToken::parse_op2(p1, p2) => {
141                    // this was a two-character operator, consume one
142                    // additional token
143                    pos += 1;
144                    op
145                }
146                (TokenTree::Ident(ident), _, _, _) if ident == "outer" =>
147                    PrustiToken::Outer(ident.span()),
148                (TokenTree::Ident(ident), _, _, _) if ident == "forall" =>
149                    PrustiToken::Quantifier(ident.span(), Quantifier::Forall),
150                (TokenTree::Ident(ident), _, _, _) if ident == "exists" =>
151                    PrustiToken::Quantifier(ident.span(), Quantifier::Exists),
152                (TokenTree::Punct(punct), _, _, _)
153                    if punct.as_char() == ',' && punct.spacing() == Alone =>
154                    PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Comma)),
155                (TokenTree::Punct(punct), _, _, _)
156                    if punct.as_char() == ';' && punct.spacing() == Alone =>
157                    PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Semicolon)),
158                (TokenTree::Punct(punct), _, _, _)
159                    if punct.as_char() == '=' && punct.spacing() == Alone =>
160                    PrustiToken::BinOp(punct.span(), PrustiBinaryOp::Rust(RustOp::Assign)),
161                (token @ TokenTree::Punct(punct), _, _, _) if punct.spacing() == Joint => {
162                    // make sure to fully consume any Rust operator
163                    // to avoid later mis-identifying its suffix
164                    tokens.push_back(PrustiToken::Token(token.clone()));
165                    while let Some(token @ TokenTree::Punct(p)) = source.get(pos) {
166                        pos += 1;
167                        tokens.push_back(PrustiToken::Token(token.clone()));
168                        if p.spacing() != Joint {
169                            break;
170                        }
171                    }
172                    continue;
173                }
174                (TokenTree::Group(group), _, _, _) => PrustiToken::Group(
175                    group.span(),
176                    group.delimiter(),
177                    Box::new(Self::new(group.stream())),
178                ),
179                (token, _, _, _) => PrustiToken::Token(token.clone()),
180            });
181        }
182        Self {
183            tokens,
184            source_span,
185        }
186    }
187
188    fn is_empty(&self) -> bool {
189        self.tokens.is_empty()
190    }
191
192    fn parse_rest<T, F>(mut self, f: F) -> syn::Result<T>
193    where
194        F: FnOnce(&mut Self) -> syn::Result<T>,
195    {
196        let result = f(&mut self)?;
197        if !self.is_empty() {
198            let start = self.tokens.front().expect("unreachable").span();
199            let end = self.tokens.back().expect("unreachable").span();
200            let span = join_spans(start, end);
201            return err(span, "unexpected extra tokens");
202        }
203        Ok(result)
204    }
205
206    /// Processes a Prusti token stream back into Rust syntax.
207    /// Prusti-specific syntax is allowed and translated.
208    fn parse(mut self) -> syn::Result<TokenStream> {
209        self.expr_bp(0)
210    }
211
212    /// Processes a Prusti token stream back into Rust syntax.
213    /// Prusti-specific syntax is not allowed and will raise an error.
214    fn parse_rust_only(self) -> syn::Result<TokenStream> {
215        Ok(TokenStream::from_iter(
216            self.tokens
217                .into_iter()
218                .map(|token| match token {
219                    PrustiToken::Group(_, _, box stream) => stream.parse_rust_only(),
220                    PrustiToken::Token(tree) => Ok(tree.to_token_stream()),
221                    PrustiToken::BinOp(span, PrustiBinaryOp::Rust(op)) => Ok(op.to_tokens(span)),
222                    _ => err(token.span(), "unexpected Prusti syntax"),
223                })
224                .collect::<Result<Vec<_>, _>>()?,
225        ))
226    }
227
228    /// Processes a Prusti token stream for a pledge, in the form `a => b` or
229    /// just `b`.
230    fn parse_pledge(self) -> syn::Result<(Option<TokenStream>, TokenStream)> {
231        let mut pledge_ops = self.split(PrustiBinaryOp::Rust(RustOp::Arrow), false);
232        if pledge_ops.len() == 1 {
233            Ok((None, pledge_ops[0].expr_bp(0)?))
234        } else if pledge_ops.len() == 2 {
235            Ok((Some(pledge_ops[0].expr_bp(0)?), pledge_ops[1].expr_bp(0)?))
236        } else {
237            err(Span::call_site(), "too many arrows in after_expiry")
238        }
239    }
240
241    /// Processes a Prusti token stream for an assert pledge, in the form `a =>
242    /// b, c` or `b, c`.
243    fn parse_assert_pledge(self) -> syn::Result<(Option<TokenStream>, TokenStream, TokenStream)> {
244        let mut pledge_ops = self.split(PrustiBinaryOp::Rust(RustOp::Arrow), false);
245        let (reference, body) = match (pledge_ops.pop(), pledge_ops.pop(), pledge_ops.pop()) {
246            (Some(body), None, _) => (None, body),
247            (Some(body), Some(mut reference), None) => (Some(reference.expr_bp(0)?), body),
248            _ => return err(Span::call_site(), "too many arrows in assert_on_expiry"),
249        };
250        let mut body_parts = body.split(PrustiBinaryOp::Rust(RustOp::Comma), false);
251        if body_parts.len() == 2 {
252            Ok((
253                reference,
254                body_parts[0].expr_bp(0)?,
255                body_parts[1].expr_bp(0)?,
256            ))
257        } else {
258            err(Span::call_site(), "missing assertion")
259        }
260    }
261
262    /// The core of the Pratt parser algorithm. [self.tokens] is the source of
263    /// "lexemes". [min_bp] is the minimum binding power we need to see when
264    /// identifying a binary operator.
265    /// See https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
266    fn expr_bp(&mut self, min_bp: u8) -> syn::Result<TokenStream> {
267        let mut lhs = match self.tokens.pop_front() {
268            Some(PrustiToken::Group(span, delimiter, box stream)) => {
269                let mut group = proc_macro2::Group::new(delimiter, stream.parse()?);
270                group.set_span(span);
271                TokenTree::Group(group).to_token_stream()
272            }
273            Some(PrustiToken::Outer(span)) => {
274                let _stream = self
275                    .pop_group(Delimiter::Parenthesis)
276                    .ok_or_else(|| error(span, "expected parenthesized expression after outer"))?;
277                todo!()
278            }
279            Some(PrustiToken::Quantifier(span, kind)) => {
280                let mut stream = self.pop_group(Delimiter::Parenthesis).ok_or_else(|| {
281                    error(span, "expected parenthesized expression after quantifier")
282                })?;
283                let args = stream
284                    .pop_closure_args()
285                    .ok_or_else(|| error(span, "expected quantifier body"))?;
286
287                {
288                    // for quantifiers, argument types must be explicit
289                    // here we parse the closure with syn and check each
290                    // argument has a type annotation
291                    let cl_args = args.clone().parse_rust_only()?;
292                    let check_cl = quote! { | #cl_args | 0 };
293                    let parsed_cl = syn::parse2::<syn::ExprClosure>(check_cl)?;
294                    for pat in parsed_cl.inputs {
295                        match pat {
296                            syn::Pat::Type(_) => {}
297                            _ => {
298                                return err(
299                                    pat.span(),
300                                    "quantifier arguments must have explicit types",
301                                )
302                            }
303                        }
304                    }
305                };
306
307                let triggers = stream.extract_triggers()?;
308                if args.is_empty() {
309                    return err(span, "a quantifier must have at least one argument");
310                }
311                let args = args.parse()?;
312                let body = stream.parse()?;
313                kind.translate(span, triggers, args, body)
314            }
315
316            Some(PrustiToken::SpecEnt(span, _)) | Some(PrustiToken::CallDesc(span, _)) => {
317                return err(span, "unexpected operator")
318            }
319
320            // some Rust binary operators can appear on their own, e.g. `(..)`
321            Some(PrustiToken::BinOp(span, PrustiBinaryOp::Rust(op))) => op.to_tokens(span),
322
323            Some(PrustiToken::BinOp(span, _)) => return err(span, "unexpected binary operator"),
324            Some(PrustiToken::Token(token)) => token.to_token_stream(),
325            None => return Ok(TokenStream::new()),
326        };
327        loop {
328            let (span, op) = match self.tokens.front() {
329                // If we see a group or token, we simply add them to the
330                // current LHS. This way fragments of Rust code with higher-
331                // precedence operators (e.g. plus) are connected into atoms
332                // as far as our parser is concerned.
333                Some(PrustiToken::Group(span, delimiter, box stream)) => {
334                    let mut group = proc_macro2::Group::new(*delimiter, stream.clone().parse()?);
335                    group.set_span(*span);
336                    lhs.extend(TokenTree::Group(group).to_token_stream());
337                    self.tokens.pop_front();
338                    continue;
339                }
340                Some(PrustiToken::Token(token)) => {
341                    lhs.extend(token.to_token_stream());
342                    self.tokens.pop_front();
343                    continue;
344                }
345
346                Some(PrustiToken::SpecEnt(span, once)) => {
347                    let span = *span;
348                    let once = *once;
349                    self.tokens.pop_front();
350                    let args = self
351                        .pop_closure_args()
352                        .ok_or_else(|| error(span, "expected closure arguments"))?;
353                    let nested_closure_specs = self.pop_group_of_nested_specs(span)?;
354                    lhs = translate_spec_ent(
355                        span,
356                        once,
357                        lhs,
358                        args.split(PrustiBinaryOp::Rust(RustOp::Comma), true)
359                            .into_iter()
360                            .map(|stream| stream.parse())
361                            .collect::<Result<Vec<_>, _>>()?,
362                        nested_closure_specs,
363                    );
364                    continue;
365                }
366
367                Some(PrustiToken::CallDesc(..)) => todo!("call desc"),
368
369                Some(PrustiToken::BinOp(span, op)) => (*span, *op),
370                Some(PrustiToken::Outer(span)) => return err(*span, "unexpected outer"),
371                Some(PrustiToken::Quantifier(span, _)) => {
372                    return err(*span, "unexpected quantifier")
373                }
374
375                None => break,
376            };
377            let (l_bp, r_bp) = op.binding_power();
378            if l_bp < min_bp {
379                break;
380            }
381            self.tokens.pop_front();
382            let rhs = self.expr_bp(r_bp)?;
383
384            // In [new], when identifying consecutive sequences of operators,
385            // we delegate to `parse_op*` which identifies Rust operators. In
386            // some cases, such as `..`, the binary operator does not actually
387            // require a RHS. Thus we only emit this error for operators that
388            // Prusti defines, as the actual Rust operators will raise a parse
389            // error after desugaring anyway.
390            if !matches!(op, PrustiBinaryOp::Rust(_)) && rhs.is_empty() {
391                return err(span, "expected expression");
392            }
393            lhs = op.translate(span, lhs, rhs);
394        }
395        Ok(lhs)
396    }
397
398    fn pop_group(&mut self, delimiter: Delimiter) -> Option<Self> {
399        match self.tokens.pop_front() {
400            Some(PrustiToken::Group(_, del, box stream)) if del == delimiter => Some(stream),
401            _ => None,
402        }
403    }
404
405    fn pop_closure_args(&mut self) -> Option<Self> {
406        let mut tokens = VecDeque::new();
407
408        // special case: empty closure might be parsed as a logical or
409        if matches!(
410            self.tokens.front(),
411            Some(PrustiToken::BinOp(_, PrustiBinaryOp::Or))
412        ) {
413            return Some(Self {
414                tokens,
415                source_span: self.source_span,
416            });
417        }
418
419        if !self.tokens.pop_front()?.is_closure_brace() {
420            return None;
421        }
422        loop {
423            let token = self.tokens.pop_front()?;
424            if token.is_closure_brace() {
425                break;
426            }
427            tokens.push_back(token);
428        }
429
430        Some(Self {
431            tokens,
432            source_span: self.source_span,
433        })
434    }
435
436    fn pop_parenthesized_group(&mut self) -> syn::Result<Self> {
437        match self.tokens.pop_front() {
438            Some(PrustiToken::Group(_span, Delimiter::Parenthesis, box group)) => {
439                Ok(group) // TODO: need to clone()?
440            }
441            _ => Err(error(self.source_span, "expected parenthesized group")),
442        }
443    }
444
445    fn pop_single_nested_spec(&mut self) -> syn::Result<NestedSpec<Self>> {
446        let first = self
447            .tokens
448            .pop_front()
449            .ok_or_else(|| error(self.source_span, "expected nested spec"))?;
450        if let PrustiToken::Token(TokenTree::Ident(spec_type)) = first {
451            match spec_type.to_string().as_ref() {
452                "requires" => Ok(NestedSpec::Requires(self.pop_parenthesized_group()?)),
453                "ensures" => Ok(NestedSpec::Ensures(self.pop_parenthesized_group()?)),
454                "pure" => Ok(NestedSpec::Pure),
455                other => err(
456                    self.source_span,
457                    format!("unexpected nested spec type: {other}").as_ref(),
458                ),
459            }
460        } else {
461            err(self.source_span, "expected identifier")
462        }
463    }
464
465    fn pop_group_of_nested_specs(
466        &mut self,
467        span: Span,
468    ) -> syn::Result<Vec<NestedSpec<TokenStream>>> {
469        let group_of_specs = self
470            .pop_group(Delimiter::Bracket)
471            .ok_or_else(|| error(span, "expected nested specification in brackets"))?;
472        let parsed = group_of_specs
473            .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
474            .into_iter()
475            .map(|stream| stream.parse_rest(|stream| stream.pop_single_nested_spec()))
476            .map(|stream| stream.and_then(|s| s.parse()))
477            .collect::<syn::Result<Vec<NestedSpec<TokenStream>>>>()?;
478        Ok(parsed)
479    }
480
481    fn split(self, split_on: PrustiBinaryOp, allow_trailing: bool) -> Vec<Self> {
482        if self.tokens.is_empty() {
483            return vec![];
484        }
485        let mut res = self
486            .tokens
487            .into_iter()
488            .collect::<Vec<_>>()
489            .split(|token| matches!(token, PrustiToken::BinOp(_, t) if *t == split_on))
490            .map(|group| Self {
491                tokens: group.iter().cloned().collect(),
492                source_span: self.source_span,
493            })
494            .collect::<Vec<_>>();
495        if allow_trailing && res.len() > 1 && res[res.len() - 1].tokens.is_empty() {
496            res.pop();
497        }
498        res
499    }
500
501    fn extract_triggers(&mut self) -> syn::Result<Vec<Vec<TokenStream>>> {
502        let len = self.tokens.len();
503        if len < 4 {
504            return Ok(vec![]);
505        }
506        match [
507            &self.tokens[len - 4],
508            &self.tokens[len - 3],
509            &self.tokens[len - 2],
510            &self.tokens[len - 1],
511        ] {
512            [PrustiToken::BinOp(_, PrustiBinaryOp::Rust(RustOp::Comma)), PrustiToken::Token(TokenTree::Ident(ident)), PrustiToken::BinOp(_, PrustiBinaryOp::Rust(RustOp::Assign)), PrustiToken::Group(triggers_span, Delimiter::Bracket, box triggers)]
513                if ident == "triggers" =>
514            {
515                let triggers = triggers
516                    .clone()
517                    .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
518                    .into_iter()
519                    .map(|mut stream| {
520                        stream
521                            .pop_group(Delimiter::Parenthesis)
522                            .ok_or_else(|| {
523                                error(*triggers_span, "trigger sets must be tuples of expressions")
524                            })?
525                            .split(PrustiBinaryOp::Rust(RustOp::Comma), true)
526                            .into_iter()
527                            .map(|stream| stream.parse())
528                            .collect::<Result<Vec<_>, _>>()
529                    })
530                    .collect::<Result<Vec<_>, _>>();
531                self.tokens.truncate(len - 4);
532                triggers
533            }
534            _ => Ok(vec![]),
535        }
536    }
537}
538
539#[derive(Debug)]
540pub struct TypeCondSpecRefinement {
541    pub trait_bounds: Vec<syn::PredicateType>,
542    pub specs: Vec<NestedSpec<TokenStream>>,
543}
544
545impl Parse for TypeCondSpecRefinement {
546    fn parse(input: ParseStream) -> syn::Result<Self> {
547        input
548            .parse::<syn::Token![where]>()
549            .map_err(with_type_cond_spec_example)?;
550        Ok(TypeCondSpecRefinement {
551            trait_bounds: parse_trait_bounds(input)?,
552            specs: PrustiTokenStream::new(input.parse().unwrap())
553                .parse_rest(|pts| pts.pop_group_of_nested_specs(input.span()))?,
554        })
555    }
556}
557
558fn parse_trait_bounds(input: ParseStream) -> syn::Result<Vec<syn::PredicateType>> {
559    let mut bounds: Vec<syn::PredicateType> = Vec::new();
560    loop {
561        let predicate = input
562            .parse::<syn::WherePredicate>()
563            .map_err(with_type_cond_spec_example)?;
564        bounds.push(validate_predicate(predicate)?);
565        input
566            .parse::<syn::token::Comma>()
567            .map_err(with_type_cond_spec_example)?;
568        if input.peek(syn::token::Bracket) || input.is_empty() {
569            // now expecting specs in []
570            // also breaking when empty, to handle that as missing specs rather than a missing constraint
571            break;
572        }
573    }
574    Ok(bounds)
575}
576
577fn validate_predicate(predicate: syn::WherePredicate) -> syn::Result<syn::PredicateType> {
578    use syn::WherePredicate::*;
579
580    match predicate {
581        Type(type_bound) => {
582            validate_trait_bounds(&type_bound)?;
583            Ok(type_bound)
584        }
585        Lifetime(lifetime_bound) => disallowed_lifetime_error(lifetime_bound.span()),
586        Eq(eq_bound) => err(
587            eq_bound.span(),
588            "equality constraints are not allowed in type-conditional spec refinements",
589        ),
590    }
591}
592
593fn disallowed_lifetime_error<T>(span: Span) -> syn::Result<T> {
594    err(
595        span,
596        "lifetimes are not allowed in type-conditional spec refinement trait bounds",
597    )
598}
599
600fn validate_trait_bounds(trait_bounds: &syn::PredicateType) -> syn::Result<()> {
601    if let Some(lifetimes) = &trait_bounds.lifetimes {
602        return disallowed_lifetime_error(lifetimes.span());
603    }
604    for bound in &trait_bounds.bounds {
605        match bound {
606            syn::TypeParamBound::Lifetime(lt) => {
607                return disallowed_lifetime_error(lt.span());
608            }
609            syn::TypeParamBound::Trait(trait_bound) => {
610                if let Some(lt) = &trait_bound.lifetimes {
611                    return disallowed_lifetime_error(lt.span());
612                }
613            }
614        }
615    }
616
617    Ok(())
618}
619
620fn with_type_cond_spec_example(mut err: syn::Error) -> syn::Error {
621    err.combine(error(err.span(), "expected where constraint and specifications in brackets, e.g.: `refine_spec(where T: A + B, U: C, [requires(...), ...])`"));
622    err
623}
624
/// A specification nested inside another specification, as used by spec
/// entailments and type-conditional spec refinements.
///
/// `T` is the representation of the contained condition: the preparser first
/// builds `NestedSpec<PrustiTokenStream>` and then translates it into
/// `NestedSpec<TokenStream>`.
#[derive(Debug)]
pub enum NestedSpec<T> {
    /// A nested `requires(...)` precondition.
    Requires(T),
    /// A nested `ensures(...)` postcondition.
    Ensures(T),
    /// A nested `pure` marker.
    Pure,
}
632
633impl NestedSpec<PrustiTokenStream> {
634    fn parse(self) -> syn::Result<NestedSpec<TokenStream>> {
635        Ok(match self {
636            NestedSpec::Requires(stream) => NestedSpec::Requires(stream.parse()?),
637            NestedSpec::Ensures(stream) => NestedSpec::Ensures(stream.parse()?),
638            NestedSpec::Pure => NestedSpec::Pure,
639        })
640    }
641}
642
643#[derive(Debug, Clone)]
644enum PrustiToken {
645    Group(Span, Delimiter, Box<PrustiTokenStream>),
646    Token(TokenTree),
647    BinOp(Span, PrustiBinaryOp),
648    // TODO: add note about unops not sharing a variant, descriptions ...
649    Outer(Span),
650    Quantifier(Span, Quantifier),
651    SpecEnt(Span, bool),
652    CallDesc(Span, bool),
653}
654
655fn translate_spec_ent(
656    span: Span,
657    once: bool,
658    cl_expr: TokenStream,
659    cl_args: Vec<TokenStream>,
660    contract: Vec<NestedSpec<TokenStream>>,
661) -> TokenStream {
662    let once = if once {
663        quote_spanned! { span => true }
664    } else {
665        quote_spanned! { span => false }
666    };
667
668    // TODO: move extraction function generation into "fn_type_extractor"
669    let arg_count = cl_args.len();
670    let generics_args = (0..arg_count)
671        .map(|i| TokenTree::Ident(proc_macro2::Ident::new(&format!("GA{i}"), span)))
672        .collect::<Vec<_>>();
673    let generic_res = TokenTree::Ident(proc_macro2::Ident::new("GR", span));
674
675    let extract_args = (0..arg_count)
676        .map(|i| TokenTree::Ident(proc_macro2::Ident::new(&format!("__extract_arg{i}"), span)))
677        .collect::<Vec<_>>();
678    let extract_args_decl = extract_args
679        .iter()
680        .zip(generics_args.iter())
681        .map(|(ident, arg_type)| {
682            quote_spanned! { span =>
683                #[prusti::spec_only]
684                fn #ident<
685                    #(#generics_args),* ,
686                    #generic_res,
687                    F: FnOnce( #(#generics_args),* ) -> #generic_res
688                >(_f: &F) -> #arg_type { unreachable!() }
689            }
690        })
691        .collect::<Vec<_>>();
692
693    let preconds = contract
694        .iter()
695        .filter_map(|spec| match spec {
696            NestedSpec::Requires(stream) => Some(stream.clone()),
697            _ => None,
698        })
699        .collect::<Vec<_>>();
700    let postconds = contract
701        .into_iter()
702        .filter_map(|spec| match spec {
703            NestedSpec::Ensures(stream) => Some(stream),
704            _ => None,
705        })
706        .collect::<Vec<_>>();
707
708    // TODO: figure out `outer`
709
710    quote_spanned! { span => {
711        let __cl_ref = & #cl_expr;
712        #(#extract_args_decl)*
713        #[prusti::spec_only]
714        fn __extract_res<
715            #(#generics_args),* ,
716            #generic_res,
717            F: FnOnce( #(#generics_args),* ) -> #generic_res
718        >(_f: &F) -> #generic_res { unreachable!() }
719        #( let #cl_args = #extract_args(__cl_ref); )*
720        let result = __extract_res(__cl_ref);
721        specification_entailment(
722            #once,
723            __cl_ref,
724            ( #( #[prusti::spec_only] || -> bool { #preconds }, )* ),
725            ( #( #[prusti::spec_only] || -> bool { #postconds }, )* ),
726        )
727    } }
728}
729
/// The kind of a quantifier expression.
#[derive(Clone, Debug)]
enum Quantifier {
    /// Universal quantification, translated to `::prusti_contracts::forall`.
    Forall,
    /// Existential quantification, translated to `::prusti_contracts::exists`.
    Exists,
}
735
736impl Quantifier {
737    fn translate(
738        &self,
739        span: Span,
740        triggers: Vec<Vec<TokenStream>>,
741        args: TokenStream,
742        body: TokenStream,
743    ) -> TokenStream {
744        let full_span = join_spans(span, body.span());
745        let trigger_sets = triggers
746            .into_iter()
747            .map(|set| {
748                let triggers = TokenStream::from_iter(set.into_iter().map(|trigger| {
749                    quote_spanned! { trigger.span() =>
750                    #[prusti::spec_only] | #args | ( #trigger ), }
751                }));
752                quote_spanned! { full_span => ( #triggers ) }
753            })
754            .collect::<Vec<_>>();
755        let body = quote_spanned! { body.span() => #body };
756        match self {
757            Self::Forall => quote_spanned! { full_span => ::prusti_contracts::forall(
758                ( #( #trigger_sets, )* ),
759                #[prusti::spec_only] | #args | -> bool { #body }
760            ) },
761            Self::Exists => quote_spanned! { full_span => ::prusti_contracts::exists(
762                ( #( #trigger_sets, )* ),
763                #[prusti::spec_only] | #args | -> bool { #body }
764            ) },
765        }
766    }
767}
768
769// For Prusti-specific operators, in [operator2], [operator3], and [operator4]
770// we mainly care about the spacing of the last [Punct], as this lets us
771// know that the last character is not itself part of an actual Rust
772// operator.
773//
774// "==>" should still have the expected spacing of [Joint, Joint, Alone]
775// even though "==" and ">" are separate Rust operators.
776fn operator2(op: &str, p1: &Punct, p2: &Punct) -> bool {
777    let chars = op.chars().collect::<Vec<_>>();
778    [p1.as_char(), p2.as_char()] == chars[0..2] && p1.spacing() == Joint && p2.spacing() == Alone
779}
780
781fn operator3(op: &str, p1: &Punct, p2: &Punct, p3: &Punct) -> bool {
782    let chars = op.chars().collect::<Vec<_>>();
783    [p1.as_char(), p2.as_char(), p3.as_char()] == chars[0..3]
784        && p1.spacing() == Joint
785        && p2.spacing() == Joint
786        && p3.spacing() == Alone
787}
788
789fn operator4(op: &str, p1: &Punct, p2: &Punct, p3: &Punct, p4: &Punct) -> bool {
790    let chars = op.chars().collect::<Vec<_>>();
791    [p1.as_char(), p2.as_char(), p3.as_char(), p4.as_char()] == chars[0..4]
792        && p1.spacing() == Joint
793        && p2.spacing() == Joint
794        && p3.spacing() == Joint
795        && p4.spacing() == Alone
796}
797
798impl PrustiToken {
799    fn span(&self) -> Span {
800        match self {
801            Self::Group(span, _, _)
802            | Self::BinOp(span, _)
803            | Self::Outer(span)
804            | Self::Quantifier(span, _)
805            | Self::SpecEnt(span, _)
806            | Self::CallDesc(span, _) => *span,
807            Self::Token(tree) => tree.span(),
808        }
809    }
810
811    fn is_closure_brace(&self) -> bool {
812        matches!(self, Self::Token(TokenTree::Punct(p))
813            if p.as_char() == '|' && p.spacing() == proc_macro2::Spacing::Alone)
814    }
815
816    fn parse_op2(p1: &Punct, p2: &Punct) -> Option<Self> {
817        let span = join_spans(p1.span(), p2.span());
818        Some(Self::BinOp(
819            span,
820            if operator2("&&", p1, p2) {
821                PrustiBinaryOp::And
822            } else if operator2("||", p1, p2) {
823                PrustiBinaryOp::Or
824            } else if operator2("->", p1, p2) {
825                PrustiBinaryOp::Implies
826            } else if operator2("..", p1, p2) {
827                PrustiBinaryOp::Rust(RustOp::Range)
828            } else if operator2("+=", p1, p2) {
829                PrustiBinaryOp::Rust(RustOp::AddAssign)
830            } else if operator2("-=", p1, p2) {
831                PrustiBinaryOp::Rust(RustOp::SubtractAssign)
832            } else if operator2("*=", p1, p2) {
833                PrustiBinaryOp::Rust(RustOp::MultiplyAssign)
834            } else if operator2("/=", p1, p2) {
835                PrustiBinaryOp::Rust(RustOp::DivideAssign)
836            } else if operator2("%=", p1, p2) {
837                PrustiBinaryOp::Rust(RustOp::ModuloAssign)
838            } else if operator2("&=", p1, p2) {
839                PrustiBinaryOp::Rust(RustOp::BitAndAssign)
840            //} else if operator2("|=", p1, p2) {
841            //    PrustiBinaryOp::Rust(RustOp::BitOrAssign)
842            } else if operator2("^=", p1, p2) {
843                PrustiBinaryOp::Rust(RustOp::BitXorAssign)
844            } else if operator2("=>", p1, p2) {
845                PrustiBinaryOp::Rust(RustOp::Arrow)
846            } else if operator2("|=", p1, p2) {
847                return Some(Self::SpecEnt(span, false));
848            } else if operator2("~>", p1, p2) {
849                return Some(Self::CallDesc(span, false));
850            } else {
851                return None;
852            },
853        ))
854    }
855
856    fn parse_op3(p1: &Punct, p2: &Punct, p3: &Punct) -> Option<Self> {
857        let span = join_spans(join_spans(p1.span(), p2.span()), p3.span());
858        Some(Self::BinOp(
859            span,
860            if operator3("==>", p1, p2, p3) {
861                PrustiBinaryOp::Implies
862            } else if operator3("<==", p1, p2, p3) {
863                PrustiBinaryOp::ImpliesReverse
864            } else if operator3("===", p1, p2, p3) {
865                PrustiBinaryOp::SnapEq
866            } else if operator3("!==", p1, p2, p3) {
867                PrustiBinaryOp::SnapNe
868            } else if operator3("..=", p1, p2, p3) {
869                PrustiBinaryOp::Rust(RustOp::RangeInclusive)
870            } else if operator3("<<=", p1, p2, p3) {
871                PrustiBinaryOp::Rust(RustOp::LeftShiftAssign)
872            } else if operator3(">>=", p1, p2, p3) {
873                PrustiBinaryOp::Rust(RustOp::RightShiftAssign)
874            } else if operator3("|=!", p1, p2, p3) {
875                return Some(Self::SpecEnt(span, true));
876            } else if operator3("~>!", p1, p2, p3) {
877                return Some(Self::CallDesc(span, true));
878            } else {
879                return None;
880            },
881        ))
882    }
883
884    fn parse_op4(p1: &Punct, p2: &Punct, p3: &Punct, p4: &Punct) -> Option<Self> {
885        let span = join_spans(
886            join_spans(join_spans(p1.span(), p2.span()), p3.span()),
887            p4.span(),
888        );
889        Some(Self::BinOp(
890            span,
891            if operator4("<==>", p1, p2, p3, p4) {
892                PrustiBinaryOp::Iff
893            } else {
894                return None;
895            },
896        ))
897    }
898}
899
900#[derive(Debug, Clone, Copy, PartialEq, Eq)]
901enum PrustiBinaryOp {
902    Rust(RustOp),
903    Iff,
904    Implies,
905    ImpliesReverse,
906    Or,
907    And,
908    SnapEq,
909    SnapNe,
910}
911
912impl PrustiBinaryOp {
913    /// This function defines both the precedence and associativity of each
914    /// binary operator. The result is used in [PrustiTokenStream::expr_bp].
915    /// The value is the "power" with which each side of the binary operator
916    /// binds to its LHS resp. RHS. So, given:
917    ///
918    /// binop1  expr  binop2
919    ///
920    /// Where binop1 has binding power (_, 3) and binop2 (4, _), then binop2
921    /// will bind to expr first, as 4 > 3.
922    ///
923    /// Associativity is likewise defined by making sure that each side of the
924    /// binding power is different. (4, 3) is right associative, (3, 4) is left
925    /// associative.
926    fn binding_power(&self) -> (u8, u8) {
927        // TODO: should <== and ==> have the same binding power? === and !==?
928        match self {
929            Self::Rust(_) => (0, 0),
930            Self::Iff => (4, 3),
931            Self::Implies => (6, 5),
932            Self::ImpliesReverse => (5, 6),
933            Self::Or => (7, 8),
934            Self::And => (9, 10),
935            Self::SnapEq => (11, 12),
936            Self::SnapNe => (11, 12),
937        }
938    }
939
940    fn translate(&self, span: Span, raw_lhs: TokenStream, raw_rhs: TokenStream) -> TokenStream {
941        // TODO: enforce types more strictly with type ascriptions
942        let lhs = quote_spanned! { raw_lhs.span() => (#raw_lhs) };
943        let rhs = quote_spanned! { raw_rhs.span() => (#raw_rhs) };
944        match self {
945            Self::Rust(op) => op.translate(span, raw_lhs, raw_rhs),
946            Self::Iff => {
947                let joined_span = join_spans(lhs.span(), rhs.span());
948                quote_spanned! { joined_span => #lhs == #rhs }
949            }
950            // implication is desugared into this form to avoid evaluation
951            // order issues: `f(a, b)` makes Rust evaluate both `a` and `b`
952            // before the `f` call
953            Self::Implies => {
954                let joined_span = join_spans(lhs.span(), rhs.span());
955                // preserve span of LHS
956                let not_lhs = quote_spanned! { lhs.span() => !#lhs };
957                quote_spanned! { joined_span => #not_lhs || #rhs }
958            }
959            Self::ImpliesReverse => {
960                let joined_span = join_spans(lhs.span(), rhs.span());
961                // preserve span of RHS
962                let not_rhs = quote_spanned! { rhs.span() => !#rhs };
963                quote_spanned! { joined_span => #not_rhs || #lhs }
964            }
965            Self::Or => quote_spanned! { span => #lhs || #rhs },
966            Self::And => quote_spanned! { span => #lhs && #rhs },
967            Self::SnapEq => {
968                let joined_span = join_spans(lhs.span(), rhs.span());
969                quote_spanned! { joined_span => snapshot_equality(&#lhs, &#rhs) }
970            }
971            Self::SnapNe => {
972                let joined_span = join_spans(lhs.span(), rhs.span());
973                quote_spanned! { joined_span => !snapshot_equality(&#lhs, &#rhs) }
974            }
975        }
976    }
977}
978
979#[derive(Debug, Clone, Copy, PartialEq, Eq)]
980enum RustOp {
981    RangeInclusive,
982    LeftShiftAssign,
983    RightShiftAssign,
984    Range,
985    AddAssign,
986    SubtractAssign,
987    MultiplyAssign,
988    DivideAssign,
989    ModuloAssign,
990    BitAndAssign,
991    // FIXME: |= spec entailment
992    // BitOrAssign,
993    BitXorAssign,
994    Arrow,
995    Comma,
996    Semicolon,
997    Assign,
998}
999
1000impl RustOp {
1001    fn translate(&self, span: Span, lhs: TokenStream, rhs: TokenStream) -> TokenStream {
1002        let op = self.to_tokens(span);
1003        quote! { #lhs #op #rhs }
1004    }
1005
1006    fn to_tokens(self, span: Span) -> TokenStream {
1007        match self {
1008            Self::RangeInclusive => quote_spanned! { span => ..= },
1009            Self::LeftShiftAssign => quote_spanned! { span => <<= },
1010            Self::RightShiftAssign => quote_spanned! { span => >>= },
1011            Self::Range => quote_spanned! { span => .. },
1012            Self::AddAssign => quote_spanned! { span => += },
1013            Self::SubtractAssign => quote_spanned! { span => -= },
1014            Self::MultiplyAssign => quote_spanned! { span => *= },
1015            Self::DivideAssign => quote_spanned! { span => /= },
1016            Self::ModuloAssign => quote_spanned! { span => %= },
1017            Self::BitAndAssign => quote_spanned! { span => &= },
1018            // Self::BitOrAssign => quote_spanned! { span => |= },
1019            Self::BitXorAssign => quote_spanned! { span => ^= },
1020            Self::Arrow => quote_spanned! { span => => },
1021            Self::Comma => quote_spanned! { span => , },
1022            Self::Semicolon => quote_spanned! { span => ; },
1023            Self::Assign => quote_spanned! { span => = },
1024        }
1025    }
1026}
1027
1028fn join_spans(s1: Span, s2: Span) -> Span {
1029    // Tests don't run in the proc macro context, so this gets a little funky for them
1030    if cfg!(test) {
1031        // During tests we don't care so much about returning a default
1032        s1.join(s2).unwrap_or(s1)
1033    } else {
1034        // This works even when compiled with stable, unlike `s1.join(s2)`
1035        s1.unwrap()
1036            .join(s2.unwrap())
1037            .expect("Failed to join spans!")
1038            .into()
1039    }
1040}
1041
/// Unit tests for the preparser; they run outside a proc-macro context.
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `$result` is an `Err` whose `Display` text equals
    /// `$expected` (a `&str` or `String`).
    macro_rules! assert_error {
        ( $result:expr, $expected:expr ) => {{
            let _res = $result;
            assert!(_res.is_err());
            let _err = _res.unwrap_err();
            assert_eq!(_err.to_string(), $expected);
        }};
    }

    /// Pins the exact token-string output of `parse_prusti` for implication,
    /// snapshot (in)equality, nested/parenthesised operators, quantifiers with
    /// and without triggers, and Prusti operators inside a macro call.
    #[test]
    fn test_preparser() {
        // `==>` desugars to `!(lhs) || (rhs)`
        assert_eq!(
            parse_prusti("a ==> b".parse().unwrap())
                .unwrap()
                .to_string(),
            "! (a) || (b)",
        );
        // `===` binds tighter than `+` is not the case: the RHS is `b + c`
        assert_eq!(
            parse_prusti("a === b + c".parse().unwrap())
                .unwrap()
                .to_string(),
            "snapshot_equality (& (a) , & (b + c))",
        );
        assert_eq!(
            parse_prusti("a !== b + c".parse().unwrap())
                .unwrap()
                .to_string(),
            "! snapshot_equality (& (a) , & (b + c))",
        );
        // `==>` is right-associative
        assert_eq!(
            parse_prusti("a ==> b ==> c".parse().unwrap())
                .unwrap()
                .to_string(),
            "! (a) || (! (b) || (c))",
        );
        // `&&` / `||` bind tighter than `==>`; explicit parentheses are kept
        assert_eq!(
            parse_prusti("(a ==> b && c) ==> d || e".parse().unwrap())
                .unwrap()
                .to_string(),
            "! ((! (a) || ((b) && (c)))) || ((d) || (e))",
        );
        // quantifier without triggers: empty trigger tuple `()`
        assert_eq!(
            parse_prusti("forall(|x: i32| a ==> b)".parse().unwrap())
                .unwrap()
                .to_string(),
            ":: prusti_contracts :: forall (() , # [prusti :: spec_only] | x : i32 | -> bool { ! (a) || (b) })",
        );
        assert_eq!(
            parse_prusti("exists(|x: i32| a === b)".parse().unwrap()).unwrap().to_string(),
            ":: prusti_contracts :: exists (() , # [prusti :: spec_only] | x : i32 | -> bool { snapshot_equality (& (a) , & (b)) })",
        );
        // each trigger set becomes a tuple of `|args| (trigger),` closures
        assert_eq!(
            parse_prusti("forall(|x: i32| a ==> b, triggers = [(c,), (d, e)])".parse().unwrap()).unwrap().to_string(),
            ":: prusti_contracts :: forall (((# [prusti :: spec_only] | x : i32 | (c) ,) , (# [prusti :: spec_only] | x : i32 | (d) , # [prusti :: spec_only] | x : i32 | (e) ,) ,) , # [prusti :: spec_only] | x : i32 | -> bool { ! (a) || (b) })",
        );
        // Prusti operators are also rewritten inside macro invocations
        assert_eq!(
            parse_prusti("assert!(a === b ==> b)".parse().unwrap())
                .unwrap()
                .to_string(),
            "assert ! (! (snapshot_equality (& (a) , & (b))) || (b))",
        );
    }

    /// Tests for `parse_type_cond_spec` (type-conditional specs of the form
    /// `where <bounds>, [<nested specs>]`).
    mod type_cond_specs {
        // NOTE(review): `std::assert_matches` is unstable; this relies on the
        // crate enabling the corresponding nightly feature (not visible here).
        use std::assert_matches::assert_matches;

        use super::*;

        /// Malformed inputs must fail with the exact syn error messages below.
        #[test]
        fn invalid_args() {
            let err_invalid_bounds = "expected one of: `for`, parentheses, `fn`, `unsafe`, `extern`, identifier, `::`, `<`, square brackets, `*`, `&`, `!`, `impl`, `_`, lifetime";
            assert_error!(
                parse_type_cond_spec(quote! { [requires(false)] }),
                "expected `where`"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where [requires(false)] }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! { [requires(false)], T: A }),
                "expected `where`"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where [requires(false)], T: A }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! {}),
                format!("unexpected end of input, {}", "expected `where`")
            );
            assert_error!(parse_type_cond_spec(quote! { T: A }), "expected `where`");
            assert_error!(parse_type_cond_spec(quote! { where T: A }), "expected `,`");
            assert_error!(
                parse_type_cond_spec(quote! { where T: A,  }),
                "expected nested specification in brackets"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where T: A, {} }),
                err_invalid_bounds
            );
            assert_error!(
                parse_type_cond_spec(quote! { where T: A [requires(false)] }),
                "expected `,`"
            );
            assert_error!(
                parse_type_cond_spec(quote! { where T: A, [requires(false)], "nope" }),
                "unexpected extra tokens"
            );
        }

        /// Several bounds and several nested specs are parsed in order.
        #[test]
        fn multiple_bounds_multiple_specs() {
            let constraint = parse_type_cond_spec(
                quote! { where T: A+B+Foo<i32>, U: C, [requires(true), ensures(false), pure]},
            )
            .unwrap();

            assert_bounds_eq(
                &constraint.trait_bounds,
                &[quote! { T : A + B + Foo < i32 > }, quote! { U : C }],
            );
            match &constraint.specs[0] {
                NestedSpec::Requires(ts) => assert_eq!(ts.to_string(), "true"),
                _ => panic!(),
            }
            match &constraint.specs[1] {
                NestedSpec::Ensures(ts) => assert_eq!(ts.to_string(), "false"),
                _ => panic!(),
            }
            assert_matches!(&constraint.specs[2], NestedSpec::Pure);
            assert_eq!(constraint.specs.len(), 3);
        }

        /// An empty spec list `[]` is accepted and yields no specs.
        #[test]
        fn no_specs() {
            let constraint = parse_type_cond_spec(quote! { where T: A, []}).unwrap();
            assert_bounds_eq(&constraint.trait_bounds, &[quote! { T : A }]);
            assert!(constraint.specs.is_empty());
        }

        /// Trait bounds may use a path, not just a bare identifier.
        #[test]
        fn fully_qualified_trait_path() {
            let constraint =
                parse_type_cond_spec(quote! { where T: path::to::A, [requires(true)]}).unwrap();
            assert_bounds_eq(&constraint.trait_bounds, &[quote! { T : path :: to :: A }]);
        }

        /// Tuple generic arguments (including trailing commas) must parse.
        #[test]
        fn tuple_generics() {
            // just check that parsing succeeds
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32,), Output = i32>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32,)>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32, bool)>, []}).is_ok());
            assert!(parse_type_cond_spec(quote! { where T: Fn<(i32, bool,)>, []}).is_ok());
        }

        /// Compares parsed bounds pairwise with `quotes`, each parsed as a
        /// `syn::WherePredicate`; the lengths must also match.
        fn assert_bounds_eq(parsed: &[syn::PredicateType], quotes: &[TokenStream]) {
            assert_eq!(parsed.len(), quotes.len());
            for (parsed, quote) in parsed.iter().zip(quotes.iter()) {
                assert_eq!(
                    syn::WherePredicate::Type(parsed.clone()),
                    syn::parse_quote! { #quote }
                );
            }
        }
    }
}