formality_core/
parse.rs

1use std::{str::FromStr, sync::Arc};
2
3use crate::{
4    binder::CoreBinder,
5    cast::To,
6    collections::Set,
7    language::{CoreKind, CoreParameter, Language},
8    set,
9    term::CoreTerm,
10    variable::CoreBoundVar,
11};
12use std::fmt::Debug;
13
14/// Trait for parsing a [`Term<L>`](`crate::term::Term`) as input.
15pub trait CoreParse<L: Language>: Sized + Debug {
16    /// Parse a single instance of this type, returning an error if no such
17    /// instance is present.
18    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self>;
19
20    /// Parse many instances of self, expecting `close_char` to appear after the last instance
21    /// (`close_char` is not consumed).
22    fn parse_many<'t>(
23        scope: &Scope<L>,
24        mut text: &'t str,
25        close_char: char,
26    ) -> ParseResult<'t, Vec<Self>> {
27        let mut result = vec![];
28        while !skip_whitespace(text).starts_with(close_char) {
29            let (e, t) = Self::parse(scope, text)?;
30            result.push(e);
31            text = t;
32        }
33        Ok((result, text))
34    }
35
36    /// Comma separated list with optional trailing comma.
37    fn parse_comma<'t>(
38        scope: &Scope<L>,
39        mut text: &'t str,
40        close_char: char,
41    ) -> ParseResult<'t, Vec<Self>> {
42        let mut result = vec![];
43        while !skip_whitespace(text).starts_with(close_char) {
44            let (e, t) = Self::parse(scope, text)?;
45            result.push(e);
46            text = t;
47
48            if let Ok(((), t)) = expect_char(',', text) {
49                text = t;
50            } else {
51                break;
52            }
53        }
54
55        Ok((result, text))
56    }
57}
58
/// Tracks an error that occurred while parsing.
/// The parse error records the input text it saw, which will be
/// some suffix of the original input, along with a message.
///
/// The actual [`ParseResult`] type tracks a *set* of parse errors.
/// When parse errors are generated, there is just one (e.g., "expected identifier"),
/// but when there are choice points in the grammar (e.g., when parsing an enum),
/// those errors can be combined by [`require_unambiguous`].
///
/// The derived ordering compares `text` (as a string) first and `message` second.
#[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq)]
pub struct ParseError<'t> {
    /// Input that triggered the parse error. Some suffix
    /// of the original input; its length is what [`ParseError::offset`]
    /// uses to locate the error.
    pub text: &'t str,

    /// Message describing what was expected.
    pub message: String,
}
76
77impl<'t> ParseError<'t> {
78    /// Creates a single parse error at the given point. Returns
79    /// a set so that it can be wrapped as a [`ParseResult`].
80    pub fn at(text: &'t str, message: String) -> Set<Self> {
81        set![ParseError { text, message }]
82    }
83
84    /// Offset of this error relative to the starting point `text`
85    pub fn offset(&self, text: &str) -> usize {
86        assert!(text.ends_with(self.text));
87        text.len() - self.text.len()
88    }
89
90    /// Returns the text that was consumed before this error occurred,
91    /// with `text` is the starting point.
92    pub fn consumed_before<'s>(&self, text: &'s str) -> &'s str {
93        let o = self.offset(text);
94        &text[..o]
95    }
96}
97
/// Outcome of a parse. On success, holds the parsed value together with the input
/// that remains after it. On failure, holds the set of [`ParseError`]s collected.
pub type ParseResult<'t, T> = Result<(T, &'t str), Set<ParseError<'t>>>;
99
/// Tracks the variables in scope at this point in parsing.
#[derive(Clone, Debug)]
pub struct Scope<L: Language> {
    /// `(name, value)` pairs in the order they were bound. A later entry shadows
    /// an earlier one with the same name, because `lookup` searches from the end.
    bindings: Vec<(String, CoreParameter<L>)>,
}
105
106impl<L: Language> Scope<L> {
107    /// Creates a new scope with the given set of bindings.
108    pub fn new(bindings: impl IntoIterator<Item = (String, CoreParameter<L>)>) -> Self {
109        Self {
110            bindings: bindings.into_iter().collect(),
111        }
112    }
113
114    /// Look for a variable with the given name.
115    pub fn lookup(&self, name: &str) -> Option<CoreParameter<L>> {
116        self.bindings
117            .iter()
118            .rev()
119            .flat_map(|(n, p)| if name == n { Some(p.clone()) } else { None })
120            .next()
121    }
122
123    /// Create a new scope that extends `self` with `bindings`.
124    pub fn with_bindings(
125        &self,
126        bindings: impl IntoIterator<Item = (String, CoreParameter<L>)>,
127    ) -> Self {
128        let mut s = self.clone();
129        s.bindings.extend(bindings);
130        s
131    }
132}
133
/// Records a single binding, used when parsing a [`CoreBinder`].
///
/// The textual grammar is `$kind $name`, e.g. `ty Foo`.
#[derive(Clone, Debug)]
pub struct Binding<L: Language> {
    /// Name the user gave this variable during parsing.
    pub name: String,

    /// The bound var representation.
    pub bound_var: CoreBoundVar<L>,
}
143
144impl<L, T> CoreParse<L> for Vec<T>
145where
146    L: Language,
147    T: CoreParse<L>,
148{
149    #[tracing::instrument(level = "trace", ret)]
150    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
151        let ((), text) = expect_char('[', text)?;
152        let (v, text) = T::parse_comma(scope, text, ']')?;
153        let ((), text) = expect_char(']', text)?;
154        Ok((v, text))
155    }
156}
157
158impl<L, T> CoreParse<L> for Set<T>
159where
160    L: Language,
161    T: CoreParse<L> + Ord,
162{
163    #[tracing::instrument(level = "trace", ret)]
164    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
165        let ((), text) = expect_char('{', text)?;
166        let (v, text) = T::parse_comma(scope, text, '}')?;
167        let ((), text) = expect_char('}', text)?;
168        let s = v.into_iter().collect();
169        Ok((s, text))
170    }
171}
172
173impl<L, T> CoreParse<L> for Option<T>
174where
175    L: Language,
176    T: CoreParse<L>,
177{
178    #[tracing::instrument(level = "trace", ret)]
179    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
180        match T::parse(scope, text) {
181            Ok((value, text)) => Ok((Some(value), text)),
182            Err(_) => Ok((None, text)),
183        }
184    }
185}
186
187/// Binding grammar is `$kind $name`, e.g., `ty Foo`.
188impl<L: Language> CoreParse<L> for Binding<L> {
189    #[tracing::instrument(level = "trace", ret)]
190    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
191        let (kind, text) = <CoreKind<L>>::parse(scope, text)?;
192        let (name, text) = identifier(text)?;
193        let bound_var = CoreBoundVar::fresh(kind);
194        Ok((Binding { name, bound_var }, text))
195    }
196}
197
198/// Parse a binder: find the names in scope, parse the contents, and then
199/// replace names with debruijn indices.
200impl<L, T> CoreParse<L> for CoreBinder<L, T>
201where
202    L: Language,
203    T: CoreTerm<L>,
204{
205    #[tracing::instrument(level = "trace", ret)]
206    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
207        let ((), text) = expect_char(L::BINDING_OPEN, text)?;
208        let (bindings, text) = Binding::parse_comma(scope, text, '>')?;
209        let ((), text) = expect_char(L::BINDING_CLOSE, text)?;
210
211        // parse the contents with those names in scope
212        let scope1 =
213            scope.with_bindings(bindings.iter().map(|b| (b.name.clone(), b.bound_var.to())));
214        let (data, text) = T::parse(&scope1, text)?;
215
216        let kvis: Vec<CoreBoundVar<L>> = bindings.iter().map(|b| b.bound_var).collect();
217        Ok((CoreBinder::new(kvis, data), text))
218    }
219}
220
221impl<L, T> CoreParse<L> for Arc<T>
222where
223    L: Language,
224    T: CoreParse<L>,
225{
226    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
227        let (data, text) = T::parse(scope, text)?;
228        Ok((Arc::new(data), text))
229    }
230}
231
232impl<L> CoreParse<L> for usize
233where
234    L: Language,
235{
236    #[tracing::instrument(level = "trace", ret)]
237    fn parse<'t>(_scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
238        number(text)
239    }
240}
241
242impl<L> CoreParse<L> for u32
243where
244    L: Language,
245{
246    #[tracing::instrument(level = "trace", ret)]
247    fn parse<'t>(_scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
248        number(text)
249    }
250}
251
252impl<L> CoreParse<L> for u64
253where
254    L: Language,
255{
256    #[tracing::instrument(level = "trace", ret)]
257    fn parse<'t>(_scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
258        number(text)
259    }
260}
261
262/// Extract the next character from input, returning an error if we've reached the input.
263///
264/// Warning: does not skip whitespace.
265fn char(text: &str) -> ParseResult<'_, char> {
266    let ch = match text.chars().next() {
267        Some(c) => c,
268        None => return Err(ParseError::at(text, "unexpected end of input".to_string())),
269    };
270    Ok((ch, &text[char::len_utf8(ch)..]))
271}
272
273/// Extract a number from the input, erroring if the input does not start with a number.
274#[tracing::instrument(level = "trace", ret)]
275pub fn number<T>(text0: &str) -> ParseResult<'_, T>
276where
277    T: FromStr + Debug,
278{
279    let (id, text1) = accumulate(text0, char::is_numeric, char::is_numeric, "number")?;
280    match T::from_str(&id) {
281        Ok(t) => Ok((t, text1)),
282        Err(_) => Err(ParseError::at(text0, format!("invalid number"))),
283    }
284}
285
286/// Consume next character and require that it be `ch`.
287#[tracing::instrument(level = "trace", ret)]
288pub fn expect_char(ch: char, text0: &str) -> ParseResult<'_, ()> {
289    let text1 = skip_whitespace(text0);
290    let (ch1, text1) = char(text1)?;
291    if ch == ch1 {
292        Ok(((), text1))
293    } else {
294        Err(ParseError::at(text0, format!("expected `{}`", ch)))
295    }
296}
297
298/// Consume a comma if one is present.
299#[tracing::instrument(level = "trace", ret)]
300pub fn skip_trailing_comma(text: &str) -> &str {
301    text.strip_prefix(',').unwrap_or(text)
302}
303
304/// Extracts a maximal identifier from the start of text,
305/// following the usual rules.
306#[tracing::instrument(level = "trace", ret)]
307pub fn identifier(text: &str) -> ParseResult<'_, String> {
308    accumulate(
309        text,
310        |ch| matches!(ch, 'a'..='z' | 'A'..='Z' | '_'),
311        |ch| matches!(ch, 'a'..='z' | 'A'..='Z' | '_' | '0'..='9'),
312        "identifier",
313    )
314}
315
316/// Consume next identifier, requiring that it be equal to `expected`.
317#[tracing::instrument(level = "trace", ret)]
318pub fn expect_keyword<'t>(expected: &str, text0: &'t str) -> ParseResult<'t, ()> {
319    match identifier(text0) {
320        Ok((ident, text1)) if &*ident == expected => Ok(((), text1)),
321        _ => Err(ParseError::at(text0, format!("expected `{}`", expected))),
322    }
323}
324
325/// Reject next identifier if it is the given keyword. Consumes nothing.
326#[tracing::instrument(level = "trace", ret)]
327pub fn reject_keyword<'t>(expected: &str, text0: &'t str) -> ParseResult<'t, ()> {
328    match expect_keyword(expected, text0) {
329        Ok(_) => Err(ParseError::at(
330            text0,
331            format!("found keyword `{}`", expected),
332        )),
333        Err(_) => Ok(((), text0)),
334    }
335}
336
337/// Convenience function for use when generating code: calls the closure it is given
338/// as argument. Used to introduce new scope for name bindings.
339pub fn try_parse<'a, R>(f: impl Fn() -> ParseResult<'a, R>) -> ParseResult<'a, R> {
340    f()
341}
342
343/// Used at choice points in the grammar. Iterates over all possible parses, looking
344/// for a single successful parse. If there are multiple successful parses, that
345/// indicates an ambiguous grammar, so we panic. If there are no successful parses,
346/// tries to come up with the best error it can: it prefers errors that arise from "partially successful"
347/// parses (e.g., parses that consume some input before failing), but if there are none of those,
348/// it will give an error at `text` saying that we expected to find a `expected`.
349pub fn require_unambiguous<'t, R>(
350    text: &'t str,
351    f: impl IntoIterator<Item = ParseResult<'t, R>>,
352    expected: &'static str,
353) -> ParseResult<'t, R>
354where
355    R: std::fmt::Debug,
356{
357    let mut errors = set![];
358    let mut results = vec![];
359    for result in f {
360        match result {
361            Ok(v) => results.push(v),
362            Err(es) => {
363                for e in es {
364                    // only include an error if the error resulted after at least
365                    // one non-whitespace character was consumed
366                    if !skip_whitespace(e.consumed_before(text)).is_empty() {
367                        errors.insert(e);
368                    }
369                }
370            }
371        }
372    }
373
374    if results.len() > 1 {
375        // More than one *positive* result indicates an ambiguous grammar, which is a programmer bug,
376        // not a fault of the input, so we panic (rather than returning Err)
377        panic!("parsing ambiguity: {results:?}");
378    } else if results.len() == 1 {
379        Ok(results.pop().unwrap())
380    } else if errors.is_empty() {
381        Err(ParseError::at(text, format!("{} expected", expected)))
382    } else {
383        Err(errors)
384    }
385}
386
387/// Extracts a maximal identifier from the start of text,
388/// following the usual rules.
389fn accumulate<'t>(
390    text0: &'t str,
391    start_test: impl Fn(char) -> bool,
392    continue_test: impl Fn(char) -> bool,
393    description: &'static str,
394) -> ParseResult<'t, String> {
395    let text1 = skip_whitespace(text0);
396    let mut buffer = String::new();
397
398    let (ch, text1) = char(text1)?;
399    if !start_test(ch) {
400        return Err(ParseError::at(text0, format!("{} expected", description)));
401    }
402    buffer.push(ch);
403
404    let mut text1 = text1;
405    while let Ok((ch, t)) = char(text1) {
406        if !continue_test(ch) {
407            break;
408        }
409
410        buffer.push(ch);
411        text1 = t;
412    }
413
414    Ok((buffer, text1))
415}
416
417impl<L: Language, A: CoreParse<L>, B: CoreParse<L>> CoreParse<L> for (A, B) {
418    #[tracing::instrument(level = "trace", ret)]
419    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
420        let ((), text) = expect_char('(', text)?;
421        let (a, text) = A::parse(scope, text)?;
422        let ((), text) = expect_char(',', text)?;
423        let (b, text) = B::parse(scope, text)?;
424        let text = skip_trailing_comma(text);
425        let ((), text) = expect_char(')', text)?;
426        Ok(((a, b), text))
427    }
428}
429
430impl<L: Language> CoreParse<L> for () {
431    #[tracing::instrument(level = "trace", ret)]
432    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
433        let ((), text) = expect_char('(', text)?;
434        let ((), text) = expect_char(')', text)?;
435        Ok(((), text))
436    }
437}
438
439impl<L: Language, A: CoreParse<L>, B: CoreParse<L>, C: CoreParse<L>> CoreParse<L> for (A, B, C) {
440    #[tracing::instrument(level = "trace", ret)]
441    fn parse<'t>(scope: &Scope<L>, text: &'t str) -> ParseResult<'t, Self> {
442        let ((), text) = expect_char('(', text)?;
443        let (a, text) = A::parse(scope, text)?;
444        let ((), text) = expect_char(',', text)?;
445        let (b, text) = B::parse(scope, text)?;
446        let ((), text) = expect_char(',', text)?;
447        let (c, text) = C::parse(scope, text)?;
448        let text = skip_trailing_comma(text);
449        let ((), text) = expect_char(')', text)?;
450        Ok(((a, b, c), text))
451    }
452}
453
/// Skips leading whitespace and `//` line comments, in any interleaving.
///
/// A comment runs up to and including the next `\n`. A comment with no newline
/// consumes the rest of the input. The returned slice is a suffix of `text`.
pub fn skip_whitespace(text: &str) -> &str {
    let mut rest = text;
    loop {
        let trimmed = rest.trim_start();
        let after_comment = match trimmed.strip_prefix("//") {
            Some(comment) => comment.find('\n').map_or("", |nl| &comment[nl + 1..]),
            None => trimmed,
        };

        // Both steps only ever shorten the slice, so equal length means no progress.
        if after_comment.len() == rest.len() {
            return after_comment;
        }
        rest = after_comment;
    }
}