Skip to main content

anodized_core/annotate/
syntax.rs

1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{ToTokens, TokenStreamExt};
3use syn::{
4    Attribute, Expr, Ident, Pat, Token,
5    parse::{Parse, ParseStream, Result},
6    punctuated::Punctuated,
7    token,
8};
9
10/// Raw spec arguments, i.e. as they appear in the `#[spec(...)]` proc macro invocation.
11///
12/// Can represent a well-formed but invalid spec so that e.g. `anodized-fmt` may work with it.
13#[derive(Debug, Clone)]
14pub struct SpecArgs {
15    pub args: Punctuated<SpecArg, Token![,]>,
16}
17
18impl Parse for SpecArgs {
19    fn parse(input: ParseStream) -> Result<Self> {
20        Ok(Self {
21            args: Punctuated::<SpecArg, Token![,]>::parse_terminated(input)?,
22        })
23    }
24}
25
26impl SpecArgs {
27    /// Check whether the spec arguments are sorted correctly, ignoring unknown keywords.
28    pub fn is_sorted(&self) -> bool {
29        self.args
30            .iter()
31            .filter(|arg| !matches!(arg.keyword, Keyword::Unknown(_)))
32            .is_sorted_by_key(|arg| &arg.keyword)
33    }
34}
35
36/// A single spec argument.
37#[derive(Debug, Clone)]
38pub struct SpecArg {
39    pub attrs: Vec<Attribute>,
40    pub keyword: Keyword,
41    pub keyword_span: Span,
42    pub colon: Option<Token![:]>,
43    pub value: SpecArgValue,
44}
45
46impl Parse for SpecArg {
47    fn parse(input: ParseStream) -> Result<Self> {
48        let attrs = input.call(Attribute::parse_outer)?;
49        let (keyword, keyword_span) = Keyword::parse(input)?;
50
51        let (colon, value) = if input.peek(Token![:]) {
52            (
53                input.parse()?,
54                match keyword {
55                    Keyword::Binds => SpecArgValue::parse_pat_or_expr(input)?,
56                    Keyword::Captures => SpecArgValue::Captures(input.parse()?),
57                    _ => SpecArgValue::parse_expr_or_pat(input)?,
58                },
59            )
60        } else {
61            (None, SpecArgValue::None)
62        };
63
64        Ok(Self {
65            attrs,
66            keyword,
67            keyword_span,
68            colon,
69            value,
70        })
71    }
72}
73/// Each [`SpecArg`]'s value needs to be parsed in a way that allows invalid specs, e.g.
74/// forms which do not correspond directly to an [`syn::Expr`] in standard Rust.
75///
76/// NOTE:
77/// a [`SpecArgValue`] may hold unrelated syntactic elements such as ['syn::Expr`], [`syn::Pat`],
78/// and even fragments that would never appear as part of a valid Rust program.
79#[derive(Debug, Clone, Eq, PartialEq)]
80pub enum SpecArgValue {
81    None,
82    Expr(Expr),
83    Pat(Pat),
84    Captures(Captures),
85}
86
87impl SpecArgValue {
88    /// Return the `Expr` or fail.
89    pub fn try_into_expr(self) -> Result<Expr> {
90        if let Self::Expr(expr) = self {
91            return Ok(expr);
92        };
93        Err(syn::Error::new_spanned(self, "expected an expression"))
94    }
95
96    /// Return the `Pat` or fail.
97    pub fn try_into_pat(self) -> Result<Pat> {
98        if let Self::Pat(pat) = self {
99            return Ok(pat);
100        };
101        Err(syn::Error::new_spanned(self, "expected a pattern"))
102    }
103
104    /// Return the `Captures` or fail.
105    pub fn try_into_captures(self) -> Result<Captures> {
106        if let Self::Captures(captures) = self {
107            return Ok(captures);
108        };
109        Err(syn::Error::new_spanned(
110            self,
111            "expected captures: expression `as` pattern",
112        ))
113    }
114
115    /// Try to parse as `Expr` then as `Pat`.
116    fn parse_expr_or_pat(input: ParseStream) -> Result<Self> {
117        if let Ok(expr) = Self::parse_expr_or_nothing(input) {
118            Ok(Self::Expr(expr))
119        } else if let Ok(pat) = Self::parse_pat_or_nothing(input) {
120            Ok(Self::Pat(pat))
121        } else {
122            Err(input.error("expected an expression or a pattern"))
123        }
124    }
125
126    /// Try to parse as `Pat` then as `Expr`.
127    fn parse_pat_or_expr(input: ParseStream) -> Result<Self> {
128        if let Ok(pat) = Self::parse_pat_or_nothing(input) {
129            Ok(Self::Pat(pat))
130        } else if let Ok(expr) = Self::parse_expr_or_nothing(input) {
131            Ok(Self::Expr(expr))
132        } else {
133            Err(input.error("expected a pattern or an expression"))
134        }
135    }
136
137    /// Try to parse as `Expr` but consume no input on failure.
138    fn parse_expr_or_nothing(input: ParseStream<'_>) -> Result<Expr> {
139        use syn::parse::discouraged::Speculative;
140        let fork = input.fork();
141        match Expr::parse(&fork) {
142            Ok(expr) => {
143                input.advance_to(&fork);
144                Ok(expr)
145            }
146            Err(err) => Err(err),
147        }
148    }
149
150    /// Try to parse as `Pat` but consume no input on failure.
151    fn parse_pat_or_nothing(input: ParseStream<'_>) -> Result<Pat> {
152        use syn::parse::discouraged::Speculative;
153        let fork = input.fork();
154        match Pat::parse_single(&fork) {
155            Ok(pat) => {
156                input.advance_to(&fork);
157                Ok(pat)
158            }
159            Err(err) => Err(err),
160        }
161    }
162}
163
164impl ToTokens for SpecArgValue {
165    fn to_tokens(&self, tokens: &mut TokenStream) {
166        match self {
167            SpecArgValue::None => {}
168            SpecArgValue::Expr(expr) => expr.to_tokens(tokens),
169            SpecArgValue::Pat(pat) => pat.to_tokens(tokens),
170            SpecArgValue::Captures(captures) => captures.to_tokens(tokens),
171        }
172    }
173}
174
175/// A group of capture expressions, either a single one or a list.
176/// These are not composed of top level [`syn::Expr`] expressions.
177#[derive(Debug, Clone, Eq, PartialEq)]
178pub enum Captures {
179    One(Box<CaptureExpr>),
180    Many {
181        bracket: token::Bracket,
182        elems: Punctuated<CaptureExpr, Token![,]>,
183    },
184}
185
186impl Parse for Captures {
187    fn parse(input: ParseStream) -> Result<Self> {
188        // For bracketed input, we need to distinguish between:
189        // 1. `[a, b, c]` - an array of capture expressions
190        // 2. `[a, b, c] as slice` - a single capture with an array expr
191        //
192        // Multiple captures are in brackets, not followed by `as`
193        if input.peek(token::Bracket) && !input.peek2(Token![as]) {
194            // Parse as an array of captures
195            let content;
196            let bracket = syn::bracketed!(content in input);
197            let elems = Punctuated::parse_terminated(&content)?;
198            Ok(Captures::Many { bracket, elems })
199        } else {
200            // Otherwise parse as one capture
201            Ok(Captures::One(input.parse()?))
202        }
203    }
204}
205
206impl ToTokens for Captures {
207    fn to_tokens(&self, tokens: &mut TokenStream) {
208        match self {
209            Self::One(capture_expr) => capture_expr.to_tokens(tokens),
210            Self::Many { bracket, elems } => bracket.surround(tokens, |tokens| {
211                elems.to_tokens(tokens);
212            }),
213        }
214    }
215}
216
217/// The form in a `captures` clause: <expression> `as` <pattern>.
218#[derive(Debug, Clone, Eq, PartialEq)]
219pub struct CaptureExpr {
220    pub expr: Option<Expr>,
221    pub as_: Option<Token![as]>,
222    pub pat: Option<Pat>,
223}
224
225impl ToTokens for CaptureExpr {
226    fn to_tokens(&self, tokens: &mut TokenStream) {
227        if let Some(expr) = &self.expr {
228            expr.to_tokens(tokens);
229        }
230        if let Some(as_) = &self.as_ {
231            as_.to_tokens(tokens);
232        }
233        if let Some(pat) = &self.pat {
234            pat.to_tokens(tokens);
235        }
236    }
237}
238
239impl Parse for CaptureExpr {
240    fn parse(input: ParseStream) -> Result<Self> {
241        let tokens_before_as = take_until_comma_or_last_as(input)?;
242        let expr = syn::parse::Parser::parse2(Expr::parse, tokens_before_as).ok();
243        // TODO: need to check that the entirety of `tokens_before_as` was consumed
244        let as_ = if input.peek(Token![as]) {
245            Some(input.parse()?)
246        } else {
247            None
248        };
249        let pat = SpecArgValue::parse_pat_or_nothing(input).ok();
250        Ok(Self { expr, as_, pat })
251    }
252}
253
254/// Custom keywords for parsing. This allows us to use `requires`, `ensures`, etc.,
255/// as if they were built-in Rust keywords during parsing.
256pub mod kw {
257    syn::custom_keyword!(functional);
258    syn::custom_keyword!(pure);
259    syn::custom_keyword!(total);
260    syn::custom_keyword!(deterministic);
261    syn::custom_keyword!(effectfree);
262    syn::custom_keyword!(infallible);
263    syn::custom_keyword!(terminating);
264    syn::custom_keyword!(requires);
265    syn::custom_keyword!(maintains);
266    syn::custom_keyword!(captures);
267    syn::custom_keyword!(binds);
268    syn::custom_keyword!(ensures);
269    syn::custom_keyword!(decreases);
270}
271
272#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
273pub enum Keyword {
274    Unknown(Ident),
275    Functional,
276    Pure,
277    Total,
278    Deterministic,
279    Effectfree,
280    Infallible,
281    Terminating,
282    Requires,
283    Maintains,
284    Captures,
285    Binds,
286    Ensures,
287    Decreases,
288}
289
290impl Keyword {
291    fn parse(input: ParseStream) -> Result<(Self, Span)> {
292        use Keyword::*;
293        Ok(if input.peek(kw::functional) {
294            let keyword: kw::functional = input.parse()?;
295            (Functional, keyword.span)
296        } else if input.peek(kw::pure) {
297            let keyword: kw::pure = input.parse()?;
298            (Pure, keyword.span)
299        } else if input.peek(kw::total) {
300            let keyword: kw::total = input.parse()?;
301            (Total, keyword.span)
302        } else if input.peek(kw::deterministic) {
303            let keyword: kw::deterministic = input.parse()?;
304            (Deterministic, keyword.span)
305        } else if input.peek(kw::effectfree) {
306            let keyword: kw::effectfree = input.parse()?;
307            (Effectfree, keyword.span)
308        } else if input.peek(kw::infallible) {
309            let keyword: kw::infallible = input.parse()?;
310            (Infallible, keyword.span)
311        } else if input.peek(kw::terminating) {
312            let keyword: kw::terminating = input.parse()?;
313            (Terminating, keyword.span)
314        } else if input.peek(kw::requires) {
315            let keyword: kw::requires = input.parse()?;
316            (Requires, keyword.span)
317        } else if input.peek(kw::maintains) {
318            let token: kw::maintains = input.parse()?;
319            (Maintains, token.span)
320        } else if input.peek(kw::captures) {
321            let token: kw::captures = input.parse()?;
322            (Captures, token.span)
323        } else if input.peek(kw::binds) {
324            let token: kw::binds = input.parse()?;
325            (Binds, token.span)
326        } else if input.peek(kw::ensures) {
327            let token: kw::ensures = input.parse()?;
328            (Ensures, token.span)
329        } else if input.peek(kw::decreases) {
330            let token: kw::decreases = input.parse()?;
331            (Decreases, token.span)
332        } else {
333            let ident: Ident = input.parse()?;
334            let span = ident.span();
335            (Unknown(ident), span)
336        })
337    }
338}
339
340impl std::fmt::Display for Keyword {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        match self {
343            Keyword::Unknown(ident) => write!(f, "{}", ident),
344            Keyword::Functional => write!(f, "functional"),
345            Keyword::Pure => write!(f, "pure"),
346            Keyword::Total => write!(f, "total"),
347            Keyword::Deterministic => write!(f, "deterministic"),
348            Keyword::Effectfree => write!(f, "effectfree"),
349            Keyword::Infallible => write!(f, "infallible"),
350            Keyword::Terminating => write!(f, "terminating"),
351            Keyword::Requires => write!(f, "requires"),
352            Keyword::Maintains => write!(f, "maintains"),
353            Keyword::Captures => write!(f, "captures"),
354            Keyword::Binds => write!(f, "binds"),
355            Keyword::Ensures => write!(f, "ensures"),
356            Keyword::Decreases => write!(f, "decreases"),
357        }
358    }
359}
360
361/// Find the last `as` token that is before a comma (or end of input).
362/// Consume and return tokens before the last `as`.
363/// If no `as` is encountered, consume and return all tokens before a comma.
364/// Groups (delimited by `()`, `[]`, `{}`) are considered atomically,
365/// so any `as` or comma inside them is ignored.
366fn take_until_comma_or_last_as(input: ParseStream) -> Result<TokenStream> {
367    use syn::parse::discouraged::Speculative;
368    let fork = input.fork();
369    let mut peeked_tokens = TokenStream::new();
370    let mut consumed_tokens = TokenStream::new();
371    let mut has_seen_as = false;
372    while !fork.is_empty() && !fork.peek(Token![,]) {
373        if fork.peek(Token![as]) {
374            has_seen_as = true;
375            // Consumed peeked tokens
376            consumed_tokens.extend(peeked_tokens);
377            peeked_tokens = TokenStream::new();
378            input.advance_to(&fork);
379        }
380        let token: TokenTree = fork.parse()?;
381        peeked_tokens.append(token);
382    }
383    if has_seen_as {
384        Ok(consumed_tokens)
385    } else {
386        input.advance_to(&fork);
387        Ok(peeked_tokens)
388    }
389}