arithmetic_typing/ast/
mod.rs

//! ASTs for type annotations and their parsing logic.
//!
//! # Overview
//!
//! This module contains types representing the AST of parsed type annotations; for example,
//! [`TypeAst`] and [`FunctionAst`]. Both of these types expose a `parse` method, which
//! allows integrating them into `nom` parsers.
8
9use nom::{
10    branch::alt,
11    bytes::complete::{tag, take, take_until, take_while, take_while1, take_while_m_n},
12    character::complete::char as tag_char,
13    combinator::{cut, map, map_res, not, opt, peek, recognize},
14    multi::{many0, separated_list0, separated_list1},
15    sequence::{delimited, preceded, separated_pair, terminated, tuple},
16};
17
18use arithmetic_parser::{with_span, ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned};
19
20mod conversion;
21#[cfg(test)]
22mod tests;
23
24pub use self::conversion::AstConversionError;
25pub(crate) use self::conversion::AstConversionState;
26
/// Type annotation after parsing.
///
/// Compared to [`Type`], this enum corresponds to AST, not to the logical presentation
/// of a type. Several variants (e.g., [`Self::Ident`] and [`Self::Param`]) carry no payload:
/// the parser only records *which kind* of annotation was encountered.
///
/// [`Type`]: crate::Type
///
/// # Examples
///
/// ```
/// use arithmetic_parser::InputSpan;
/// # use arithmetic_typing::ast::TypeAst;
/// # use assert_matches::assert_matches;
///
/// # fn main() -> anyhow::Result<()> {
/// let input = InputSpan::new("(Num, ('T) -> ('T, 'T))");
/// let (_, ty) = TypeAst::parse(input)?;
/// let elements = match ty.extra {
///     TypeAst::Tuple(elements) => elements,
///     _ => unreachable!(),
/// };
/// assert_eq!(elements.start[0].extra, TypeAst::Ident);
/// assert_matches!(
///     &elements.start[1].extra,
///     TypeAst::Function { .. }
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum TypeAst<'a> {
    /// Type placeholder (`_`). Corresponds to a certain type that is not specified, like `_`
    /// in type annotations in Rust.
    Some,
    /// Any type (`any`).
    Any,
    /// Dynamically applied constraints: the `dyn` keyword optionally followed by bounds,
    /// e.g., `dyn Foo + Bar`. A bare `dyn` results in default (empty) constraints.
    Dyn(TypeConstraintsAst<'a>),
    /// Non-ticked identifier, e.g., `Bool`.
    Ident,
    /// Ticked identifier, e.g., `'T`.
    Param,
    /// Functional type without constraints, e.g., `(Num) -> Bool`.
    Function(Box<FunctionAst<'a>>),
    /// Functional type with constraints, i.e., a function definition preceded
    /// by a `for<...>` clause.
    FunctionWithConstraints {
        /// Constraints on function params.
        constraints: Spanned<'a, ConstraintsAst<'a>>,
        /// Function body.
        function: Box<Spanned<'a, FunctionAst<'a>>>,
    },
    /// Tuple type; for example, `(Num, Bool)`.
    Tuple(TupleAst<'a>),
    /// Slice type; for example, `[Num]` or `[(Num, T); N]`.
    Slice(SliceAst<'a>),
    /// Object type; for example, `{ len: Num }`. Not to be confused with object constraints.
    Object(ObjectAst<'a>),
}
86
impl<'a> TypeAst<'a> {
    /// Parses `input` as a type. This parser can be composed using `nom` infrastructure.
    ///
    /// On success, the output is the parsed type wrapped in [`Spanned`], paired
    /// with the remaining (unparsed) input, as usual for `nom` parsers.
    ///
    /// # Errors
    ///
    /// Returns a `nom` error if the start of `input` does not correspond to any
    /// of the supported type annotations.
    pub fn parse(input: InputSpan<'a>) -> NomResult<'a, Spanned<'a, Self>> {
        with_span(type_definition)(input)
    }
}
93
/// [`TypeAst`] wrapped in [`Spanned`]. This is the output type of [`TypeAst::parse()`]
/// and the type used for nested types (tuple elements, slice elements, object fields,
/// function return types).
pub type SpannedTypeAst<'a> = Spanned<'a, TypeAst<'a>>;
96
/// Parsed tuple type, such as `(Num, Bool)` or `(() -> Num, ...[Num; _])`.
///
/// A tuple consists of up to three parts: the `start` elements, an optional `middle`
/// introduced by `...`, and the `end` elements following the middle.
#[derive(Debug, Clone, PartialEq)]
pub struct TupleAst<'a> {
    /// Elements at the beginning of the tuple, e.g., `Num` and `Bool`
    /// in `(Num, Bool, ...[T; _])`.
    pub start: Vec<SpannedTypeAst<'a>>,
    /// Middle of the tuple, e.g., `[T; _]` in `(Num, Bool, ...[T; _])`.
    pub middle: Option<Spanned<'a, SliceAst<'a>>>,
    /// Elements at the end of the tuple, e.g., `Bool` in `(...[Num; _], Bool)`.
    /// Guaranteed to be empty if `middle` is not present.
    pub end: Vec<SpannedTypeAst<'a>>,
}
109
/// Parsed slice type, such as `[Num; N]` or `[Num]`.
#[derive(Debug, Clone, PartialEq)]
pub struct SliceAst<'a> {
    /// Element of this slice; for example, `Num` in `[Num; N]`.
    pub element: Box<SpannedTypeAst<'a>>,
    /// Length of this slice; for example, `N` in `[Num; N]`. If the length is omitted
    /// (as in `[Num]`), it is parsed as [`TupleLenAst::Dynamic`].
    pub length: Spanned<'a, TupleLenAst>,
}
118
/// Parsed functional type.
///
/// In contrast to [`Function`], this struct corresponds to AST, not to the logical representation
/// of functional types. Constraints on type params (the `for<...>` clause) are not a part
/// of this struct; see [`TypeAst::FunctionWithConstraints`] for those.
///
/// [`Function`]: crate::Function
///
/// # Examples
///
/// ```
/// use arithmetic_parser::InputSpan;
/// # use assert_matches::assert_matches;
/// # use arithmetic_typing::ast::{FunctionAst, TypeAst};
///
/// # fn main() -> anyhow::Result<()> {
/// let input = InputSpan::new("([Num; N]) -> Num");
/// let (rest, ty) = FunctionAst::parse(input)?;
/// assert!(rest.fragment().is_empty());
/// assert_matches!(ty.args.extra.start[0].extra, TypeAst::Slice(_));
/// assert_eq!(ty.return_type.extra, TypeAst::Ident);
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FunctionAst<'a> {
    /// Function arguments, parsed as a tuple (and thus enclosed in parentheses).
    pub args: Spanned<'a, TupleAst<'a>>,
    /// Return type of the function, i.e., the type after `->`.
    pub return_type: SpannedTypeAst<'a>,
}
150
impl<'a> FunctionAst<'a> {
    /// Parses `input` as a functional type. This parser can be composed using `nom` infrastructure.
    ///
    /// The expected input is an argument tuple followed by `->` and the return type.
    /// Unlike [`TypeAst::parse()`], the output is not wrapped in [`Spanned`], and a leading
    /// `for<...>` constraints clause is not accepted by this parser.
    ///
    /// # Errors
    ///
    /// Returns a `nom` error if the start of `input` is not a functional type.
    pub fn parse(input: InputSpan<'a>) -> NomResult<'a, Self> {
        fn_definition(input)
    }
}
157
/// Parsed tuple length, i.e., the optional part after `;` in a slice type such as `[Num; N]`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum TupleLenAst {
    /// Length placeholder (`_`). Corresponds to any single length.
    Some,
    /// Dynamic tuple length. This length is *implicit*, as in `[Num]`. As such, it has
    /// an empty span.
    Dynamic,
    /// Reference to a length; for example, `N` in `[Num; N]`.
    Ident,
}
170
/// Parameter constraints, e.g. `for<len! N; 'T: Lin>`.
///
/// The clause may contain static lengths (introduced by `len!`), constraints on type
/// params, or both; in the latter case, the lengths come first and are separated
/// from the type params by `;`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConstraintsAst<'a> {
    /// Static lengths, e.g., `N` in `for<len! N>`.
    pub static_lengths: Vec<Spanned<'a>>,
    /// Type constraints: pairs of a type param (parsed as a ticked identifier, such as `'T`)
    /// and the bounds placed on it.
    pub type_params: Vec<(Spanned<'a>, TypeConstraintsAst<'a>)>,
}
180
/// Bounds that can be placed on a type variable.
///
/// Bounds consist of an optional object constraint followed by zero or more
/// `+`-separated named constraints; e.g., `{ x: 'T } + Foo + Bar`.
#[derive(Debug, Default, Clone, PartialEq)]
#[non_exhaustive]
pub struct TypeConstraintsAst<'a> {
    /// Object constraint, such as `{ x: 'T }`.
    pub object: Option<ObjectAst<'a>>,
    /// Spans corresponding to constraints, e.g. `Foo` and `Bar` in `Foo + Bar`.
    pub terms: Vec<Spanned<'a>>,
}
190
/// Object type or constraint, such as `{ x: Num, y: [(Num, Bool)] }`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ObjectAst<'a> {
    /// Fields of the object as (field name, field type) pairs, in the order
    /// of their appearance in the source.
    pub fields: Vec<(Spanned<'a>, SpannedTypeAst<'a>)>,
}
198
199/// Whitespace and comments.
200fn ws(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
201    fn narrow_ws(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
202        take_while1(|c: char| c.is_ascii_whitespace())(input)
203    }
204
205    fn long_comment_body(input: InputSpan<'_>) -> NomResult<'_, InputSpan<'_>> {
206        cut(take_until("*/"))(input)
207    }
208
209    let comment = preceded(tag("//"), take_while(|c: char| c != '\n'));
210    let long_comment = delimited(tag("/*"), long_comment_body, tag("*/"));
211    let ws_line = alt((narrow_ws, comment, long_comment));
212    recognize(many0(ws_line))(input)
213}
214
215/// Comma separator.
216fn comma_sep(input: InputSpan<'_>) -> NomResult<'_, char> {
217    delimited(ws, tag_char(','), ws)(input)
218}
219
220fn ident(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
221    preceded(
222        peek(take_while_m_n(1, 1, |c: char| {
223            c.is_ascii_alphabetic() || c == '_'
224        })),
225        map(
226            take_while1(|c: char| c.is_ascii_alphanumeric() || c == '_'),
227            Spanned::from,
228        ),
229    )(input)
230}
231
232fn not_keyword(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
233    map_res(ident, |ident| {
234        if *ident.fragment() == "as" {
235            Err(ParseErrorKind::Type(anyhow::anyhow!(
236                "`as` is a reserved keyword"
237            )))
238        } else {
239            Ok(ident)
240        }
241    })(input)
242}
243
244fn type_param_ident(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_>> {
245    preceded(tag_char('\''), ident)(input)
246}
247
248fn comma_separated_types(input: InputSpan<'_>) -> NomResult<'_, Vec<SpannedTypeAst<'_>>> {
249    separated_list0(delimited(ws, tag_char(','), ws), with_span(type_definition))(input)
250}
251
252fn tuple_middle(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_, SliceAst<'_>>> {
253    preceded(terminated(tag("..."), ws), with_span(slice_definition))(input)
254}
255
/// Output of the `tuple_tail` parser: the tuple middle (`...[_]`) together with
/// the elements following it.
type TupleTailAst<'a> = (Spanned<'a, SliceAst<'a>>, Vec<SpannedTypeAst<'a>>);
257
258fn tuple_tail(input: InputSpan<'_>) -> NomResult<'_, TupleTailAst<'_>> {
259    tuple((
260        tuple_middle,
261        map(
262            opt(preceded(comma_sep, comma_separated_types)),
263            Option::unwrap_or_default,
264        ),
265    ))(input)
266}
267
268fn tuple_definition(input: InputSpan<'_>) -> NomResult<'_, TupleAst<'_>> {
269    let maybe_comma = opt(comma_sep);
270
271    let main_parser = alt((
272        map(tuple_tail, |(middle, end)| TupleAst {
273            start: Vec::new(),
274            middle: Some(middle),
275            end,
276        }),
277        map(
278            tuple((comma_separated_types, opt(preceded(comma_sep, tuple_tail)))),
279            |(start, maybe_tail)| {
280                if let Some((middle, end)) = maybe_tail {
281                    TupleAst {
282                        start,
283                        middle: Some(middle),
284                        end,
285                    }
286                } else {
287                    TupleAst {
288                        start,
289                        middle: None,
290                        end: Vec::new(),
291                    }
292                }
293            },
294        ),
295    ));
296
297    preceded(
298        terminated(tag_char('('), ws),
299        // Once we've encountered the opening `(`, the input *must* correspond to the parser.
300        cut(terminated(
301            main_parser,
302            tuple((maybe_comma, ws, tag_char(')'))),
303        )),
304    )(input)
305}
306
307fn tuple_len(input: InputSpan<'_>) -> NomResult<'_, Spanned<'_, TupleLenAst>> {
308    let semicolon = tuple((ws, tag_char(';'), ws));
309    let empty = map(take(0_usize), Spanned::from);
310    map(alt((preceded(semicolon, not_keyword), empty)), |id| {
311        id.map_extra(|()| match *id.fragment() {
312            "_" => TupleLenAst::Some,
313            "" => TupleLenAst::Dynamic,
314            _ => TupleLenAst::Ident,
315        })
316    })(input)
317}
318
319fn slice_definition(input: InputSpan<'_>) -> NomResult<'_, SliceAst<'_>> {
320    preceded(
321        terminated(tag_char('['), ws),
322        // Once we've encountered the opening `[`, the input *must* correspond to the parser.
323        cut(terminated(
324            map(
325                tuple((with_span(type_definition), tuple_len)),
326                |(element, length)| SliceAst {
327                    element: Box::new(element),
328                    length,
329                },
330            ),
331            tuple((ws, tag_char(']'))),
332        )),
333    )(input)
334}
335
336fn object(input: InputSpan<'_>) -> NomResult<'_, ObjectAst<'_>> {
337    let colon = tuple((ws, tag_char(':'), ws));
338    let object_field = separated_pair(ident, colon, with_span(type_definition));
339    let object_body = terminated(separated_list1(comma_sep, object_field), opt(comma_sep));
340    let object = preceded(
341        terminated(tag_char('{'), ws),
342        cut(terminated(object_body, tuple((ws, tag_char('}'))))),
343    );
344    map(object, |fields| ObjectAst { fields })(input)
345}
346
347fn constraint_sep(input: InputSpan<'_>) -> NomResult<'_, ()> {
348    map(tuple((ws, tag_char('+'), ws)), drop)(input)
349}
350
351fn simple_type_bounds(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
352    map(separated_list1(constraint_sep, not_keyword), |terms| {
353        TypeConstraintsAst {
354            object: None,
355            terms,
356        }
357    })(input)
358}
359
360fn type_bounds(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
361    alt((
362        map(
363            tuple((
364                object,
365                opt(preceded(
366                    constraint_sep,
367                    separated_list1(constraint_sep, not_keyword),
368                )),
369            )),
370            |(object, terms)| TypeConstraintsAst {
371                object: Some(object),
372                terms: terms.unwrap_or_default(),
373            },
374        ),
375        simple_type_bounds,
376    ))(input)
377}
378
379fn type_params(input: InputSpan<'_>) -> NomResult<'_, Vec<(Spanned<'_>, TypeConstraintsAst<'_>)>> {
380    let type_bounds = preceded(tuple((ws, tag_char(':'), ws)), type_bounds);
381    let type_param = tuple((type_param_ident, type_bounds));
382    separated_list1(comma_sep, type_param)(input)
383}
384
385/// Function params, including the `for` keyword and `<>` brackets.
386fn constraints(input: InputSpan<'_>) -> NomResult<'_, ConstraintsAst<'_>> {
387    let semicolon = tuple((ws, tag_char(';'), ws));
388
389    let len_params = preceded(
390        terminated(tag("len!"), ws),
391        separated_list1(comma_sep, not_keyword),
392    );
393
394    let params_parser = alt((
395        map(
396            tuple((len_params, opt(preceded(semicolon, type_params)))),
397            |(static_lengths, type_params)| (static_lengths, type_params.unwrap_or_default()),
398        ),
399        map(type_params, |type_params| (vec![], type_params)),
400    ));
401
402    let constraints_parser = tuple((
403        terminated(tag("for"), ws),
404        terminated(tag_char('<'), ws),
405        cut(terminated(params_parser, tuple((ws, tag_char('>'))))),
406    ));
407
408    map(
409        constraints_parser,
410        |(_, _, (static_lengths, type_params))| ConstraintsAst {
411            static_lengths,
412            type_params,
413        },
414    )(input)
415}
416
417fn return_type(input: InputSpan<'_>) -> NomResult<'_, SpannedTypeAst<'_>> {
418    preceded(tuple((ws, tag("->"), ws)), cut(with_span(type_definition)))(input)
419}
420
421fn fn_or_tuple(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
422    map(
423        tuple((with_span(tuple_definition), opt(return_type))),
424        |(args, return_type)| {
425            if let Some(return_type) = return_type {
426                TypeAst::Function(Box::new(FunctionAst { args, return_type }))
427            } else {
428                TypeAst::Tuple(args.extra)
429            }
430        },
431    )(input)
432}
433
434fn fn_definition(input: InputSpan<'_>) -> NomResult<'_, FunctionAst<'_>> {
435    map(
436        tuple((with_span(tuple_definition), return_type)),
437        |(args, return_type)| FunctionAst { args, return_type },
438    )(input)
439}
440
441fn fn_definition_with_constraints(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
442    map(
443        tuple((with_span(constraints), ws, cut(with_span(fn_definition)))),
444        |(constraints, _, function)| TypeAst::FunctionWithConstraints {
445            constraints,
446            function: Box::new(function),
447        },
448    )(input)
449}
450
451fn not_ident_char(input: InputSpan<'_>) -> NomResult<'_, ()> {
452    peek(not(take_while_m_n(1, 1, |c: char| {
453        c.is_ascii_alphanumeric() || c == '_'
454    })))(input)
455}
456
457fn any_type(input: InputSpan<'_>) -> NomResult<'_, ()> {
458    terminated(map(tag("any"), drop), not_ident_char)(input)
459}
460
461fn dyn_type(input: InputSpan<'_>) -> NomResult<'_, TypeConstraintsAst<'_>> {
462    map(
463        preceded(
464            terminated(tag("dyn"), not_ident_char),
465            opt(preceded(ws, type_bounds)),
466        ),
467        Option::unwrap_or_default,
468    )(input)
469}
470
471fn free_ident(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
472    map(not_keyword, |id| match *id.fragment() {
473        "_" => TypeAst::Some,
474        _ => TypeAst::Ident,
475    })(input)
476}
477
478fn type_definition(input: InputSpan<'_>) -> NomResult<'_, TypeAst<'_>> {
479    alt((
480        fn_or_tuple,
481        fn_definition_with_constraints,
482        map(type_param_ident, |_| TypeAst::Param),
483        map(slice_definition, TypeAst::Slice),
484        map(object, TypeAst::Object),
485        map(dyn_type, TypeAst::Dyn),
486        map(any_type, |()| TypeAst::Any),
487        free_ident,
488    ))(input)
489}