arithmetic_parser/grammars/
mod.rs

//! Grammar functionality and a collection of standard grammars.
//!
//! # Defining grammars
//!
//! To define a [`Grammar`], you'll need a [`ParseLiteral`] implementation, which defines
//! how literals are parsed (numbers, strings, chars, hex- / base64-encoded byte sequences, etc.).
//! There are standard impls for primitive integers and floating-point numbers, as well as
//! for complex numbers and big integers (if the relevant features are on).
//!
//! You may define how to parse type annotations by implementing `Grammar` explicitly.
//! Alternatively, if you don't need type annotations, a `Grammar` can be obtained from
//! a [`ParseLiteral`] impl by wrapping it into [`Untyped`].
//!
//! Once you have a `Grammar`, you can supply it as a `Base` for [`Parse`]. `Parse` methods
//! allow parsing complete or streaming [`Block`](crate::Block)s of statements.
//! Note that the `Untyped` and [`Typed`] wrappers make an explicit `Parse` impl unnecessary.
//!
//! See the [`ParseLiteral`], [`Grammar`] and [`Parse`] docs for examples of various grammar
//! definitions.
20
21use nom::{
22    bytes::complete::take_while_m_n,
23    character::complete::{char as tag_char, digit1},
24    combinator::{map_res, not, opt, peek, recognize},
25    number::complete::{double, float},
26    sequence::{terminated, tuple},
27    Slice,
28};
29
30use core::{fmt, marker::PhantomData};
31
32mod traits;
33
34pub use self::traits::{
35    Features, Grammar, IntoInputSpan, MockTypes, Parse, ParseLiteral, Typed, Untyped,
36    WithMockedTypes,
37};
38
39use crate::{spans::NomResult, ErrorKind, InputSpan};
40
/// Numeric grammar in which every literal has the same type `T`.
///
/// The struct carries no data besides a `PhantomData<T>` marker; its only role is to
/// select the literal type at the type level.
#[derive(Debug)]
pub struct NumGrammar<T>(PhantomData<T>);

/// Numeric grammar whose literals are `f32` values.
pub type F32Grammar = NumGrammar<f32>;
/// Numeric grammar whose literals are `f64` values.
pub type F64Grammar = NumGrammar<f64>;
49
50impl<T: NumLiteral> ParseLiteral for NumGrammar<T> {
51    type Lit = T;
52
53    fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
54        T::parse(input)
55    }
56}
57
58/// Numeric literal used in `NumGrammar`s.
59pub trait NumLiteral: 'static + Clone + fmt::Debug {
60    /// Tries to parse a literal.
61    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self>;
62}
63
64/// Ensures that the child parser does not consume a part of a larger expression by rejecting
65/// if the part following the input is an alphanumeric char or `_`.
66///
67/// For example, `float` parses `-Inf`, which can lead to parser failure if it's a part of
68/// a larger expression (e.g., `-Infer(2, 3)`).
69pub fn ensure_no_overlap<'a, T>(
70    mut parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
71) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, T> {
72    let truncating_parser = move |input| {
73        parser(input).map(|(rest, number)| (maybe_truncate_consumed_input(input, rest), number))
74    };
75
76    terminated(
77        truncating_parser,
78        peek(not(take_while_m_n(1, 1, |c: char| {
79            c.is_ascii_alphabetic() || c == '_'
80        }))),
81    )
82}
83
/// Returns `true` if `byte` may be the first byte of a variable name,
/// i.e., an ASCII letter or an underscore.
fn can_start_a_var_name(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
87
88fn maybe_truncate_consumed_input<'a>(input: InputSpan<'a>, rest: InputSpan<'a>) -> InputSpan<'a> {
89    let relative_offset = rest.location_offset() - input.location_offset();
90    debug_assert!(relative_offset > 0, "num parser succeeded for empty string");
91    let last_consumed_byte_index = relative_offset - 1;
92
93    let input_fragment = *input.fragment();
94    let input_as_bytes = input_fragment.as_bytes();
95    if relative_offset < input_fragment.len()
96        && input_fragment.is_char_boundary(last_consumed_byte_index)
97        && input_as_bytes[last_consumed_byte_index] == b'.'
98        && can_start_a_var_name(input_as_bytes[relative_offset])
99    {
100        // The last char consumed by the parser is '.' and the next part looks like
101        // a method call. Shift the `rest` boundary to include '.'.
102        input.slice(last_consumed_byte_index..)
103    } else {
104        rest
105    }
106}
107
108macro_rules! impl_num_literal_for_uint {
109    ($($num:ident),+) => {
110        $(
111        impl NumLiteral for $num {
112            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
113                let parser = |s: InputSpan<'_>| s.fragment().parse().map_err(ErrorKind::literal);
114                map_res(digit1, parser)(input)
115            }
116        }
117        )+
118    };
119}
120
121impl_num_literal_for_uint!(u8, u16, u32, u64, u128);
122
123macro_rules! impl_num_literal_for_int {
124    ($($num:ident),+) => {
125        $(
126        impl NumLiteral for $num {
127            fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
128                let parser = |s: InputSpan<'_>| s.fragment().parse().map_err(ErrorKind::literal);
129                map_res(recognize(tuple((opt(tag_char('-')), digit1))), parser)(input)
130            }
131        }
132        )+
133    };
134}
135
136impl_num_literal_for_int!(i8, i16, i32, i64, i128);
137
138impl NumLiteral for f32 {
139    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
140        ensure_no_overlap(float)(input)
141    }
142}
143
144impl NumLiteral for f64 {
145    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
146        ensure_no_overlap(double)(input)
147    }
148}
149
#[cfg(feature = "num-complex")]
mod complex {
    //! `NumLiteral` impls for complex numbers from the `num-complex` crate.

    use nom::{
        branch::alt,
        character::complete::{one_of, satisfy},
        combinator::{map, not, opt},
        number::complete::{double, float},
        sequence::{terminated, tuple},
    };
    use num_complex::Complex;
    use num_traits::Num;

    use super::{ensure_no_overlap, NumLiteral};
    use crate::{InputSpan, NomResult};

    /// Parses the imaginary-unit suffix (`i` or `j`) of a number, but only if it is not
    /// immediately followed by a char that can continue an identifier.
    ///
    /// Without this check, `1.inv()` would be parsed as the imaginary literal `1.i`
    /// followed by `nv()`, which `ensure_no_overlap` rejects. As a result, the whole
    /// literal would fail instead of yielding `1` followed by the `.inv()` method call.
    fn imaginary_suffix<'a>() -> impl FnMut(InputSpan<'a>) -> NomResult<'a, char> {
        terminated(
            one_of("ij"),
            not(satisfy(|c: char| c.is_ascii_alphanumeric() || c == '_')),
        )
    }

    /// Builds a parser for complex literals on top of a parser for real numbers.
    ///
    /// Accepted forms are: a bare imaginary unit `i` / `j` (parsed as `0 + 1i`), a real
    /// number `x` (parsed as `x + 0i`), and a real number with an imaginary suffix
    /// `xi` / `xj` (parsed as `0 + xi`). Sums such as `1 + 2i` are not handled here.
    fn complex_parser<'a, T: Num>(
        num_parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
    ) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, Complex<T>> {
        let i_parser = map(one_of("ij"), |_| Complex::new(T::zero(), T::one()));

        let parser = tuple((num_parser, opt(imaginary_suffix())));
        let parser = map(parser, |(value, maybe_imag)| {
            if maybe_imag.is_some() {
                Complex::new(T::zero(), value)
            } else {
                Complex::new(value, T::zero())
            }
        });

        alt((i_parser, parser))
    }

    impl NumLiteral for num_complex::Complex32 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            ensure_no_overlap(complex_parser(float))(input)
        }
    }

    impl NumLiteral for num_complex::Complex64 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            ensure_no_overlap(complex_parser(double))(input)
        }
    }
}
194
#[cfg(feature = "num-bigint")]
mod bigint {
    //! `NumLiteral` impls for arbitrary-precision integers from the `num-bigint` crate.

    use nom::{
        character::complete::{char as tag_char, digit1},
        combinator::{map_res, opt, recognize},
        sequence::tuple,
    };
    use num_bigint::{BigInt, BigUint};
    use num_traits::Num;

    use super::NumLiteral;
    use crate::{ErrorKind, InputSpan, NomResult};

    /// Signed big integers: an optional leading `-` followed by decimal digits.
    impl NumLiteral for BigInt {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let signed_digits = recognize(tuple((opt(tag_char('-')), digit1)));
            map_res(signed_digits, |literal: InputSpan<'_>| {
                BigInt::from_str_radix(literal.fragment(), 10).map_err(ErrorKind::literal)
            })(input)
        }
    }

    /// Unsigned big integers: a non-empty run of decimal digits.
    impl NumLiteral for BigUint {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            map_res(digit1, |digits: InputSpan<'_>| {
                BigUint::from_str_radix(digits.fragment(), 10).map_err(ErrorKind::literal)
            })(input)
        }
    }
}
226
#[cfg(test)]
mod tests {
    //! Unit tests for the numeric literal parsers defined in this module.

    use super::*;
    use crate::Expr;

    use assert_matches::assert_matches;
    use nom::Err as NomErr;

    /// Checks how `f32` parsing treats a `.` directly after the digits: the `.` is consumed
    /// as part of the number unless the next char can start a variable name, i.e., the input
    /// looks like a method call such as `1.sin()`.
    #[test]
    fn parsing_numbers_with_dot() {
        #[derive(Debug, Clone, Copy)]
        struct Sample {
            input: &'static str,
            // Expected offset of the unconsumed remainder; since every input starts at
            // offset 0, this equals the number of consumed bytes.
            consumed: usize,
            value: f32,
        }

        #[rustfmt::skip]
        const SAMPLES: &[Sample] = &[
            Sample { input: "1.25+3", consumed: 4, value: 1.25 },

            // Cases in which '.' should be consumed.
            Sample { input: "1.", consumed: 2, value: 1.0 },
            Sample { input: "-1.", consumed: 3, value: -1.0 },
            Sample { input: "1. + 2.", consumed: 2, value: 1.0 },
            Sample { input: "1.+2.", consumed: 2, value: 1.0 },
            Sample { input: "1. .sin()", consumed: 2, value: 1.0 },

            // Cases in which '.' should not be consumed.
            Sample { input: "1.sin()", consumed: 1, value: 1.0 },
            Sample { input: "-3.sin()", consumed: 2, value: -3.0 },
            Sample { input: "-3.5.sin()", consumed: 4, value: -3.5 },
        ];

        for &sample in SAMPLES {
            let (rest, number) = <f32 as NumLiteral>::parse(InputSpan::new(sample.input)).unwrap();
            assert!(
                (number - sample.value).abs() < f32::EPSILON,
                "Failed sample: {:?}",
                sample
            );
            assert_eq!(
                rest.location_offset(),
                sample.consumed,
                "Failed sample: {:?}",
                sample
            );
        }
    }

    /// `Inf` / `-Inf` parse as infinite literals, but when `Inf` is only a prefix of
    /// an identifier (`Infty`, `Infer(..)`), the input must be parsed as a variable or
    /// a function call instead.
    #[cfg(feature = "std")]
    // ^-- This behavior is specific to `lexical-core` dependency, which is switched on with `std`.
    #[test]
    fn parsing_infinity() {
        let parsed = Untyped::<F32Grammar>::parse_statements("Inf").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == f32::INFINITY);

        let parsed = Untyped::<F32Grammar>::parse_statements("-Inf").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == -f32::INFINITY);

        let parsed = Untyped::<F32Grammar>::parse_statements("Infty").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Variable);

        let parsed = Untyped::<F32Grammar>::parse_statements("Infer(1)").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Function { .. });

        // With a leading `-`, the identifier is wrapped into a unary negation.
        let parsed = Untyped::<F32Grammar>::parse_statements("-Infty").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Unary { .. });

        let parsed = Untyped::<F32Grammar>::parse_statements("-Infer(2, 3)").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Unary { .. });
    }

    /// A standalone `i` is the imaginary-unit literal in the complex grammar, but it must
    /// not be split off the start of a longer identifier such as `ix`.
    #[cfg(feature = "num-complex")]
    #[test]
    fn parsing_i() {
        use crate::UnaryOp;
        use num_complex::Complex32;

        type C32Grammar = Untyped<NumGrammar<Complex32>>;

        let parsed = C32Grammar::parse_statements("i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_lhs = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*i_as_lhs, Expr::Literal(lit) if lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("5 - i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_rhs = &ret.binary_rhs().unwrap().extra;
        assert_matches!(*i_as_rhs, Expr::Literal(lit) if lit == Complex32::i());

        // `i` should not be parsed as a literal if it's a part of larger expression.
        let parsed = C32Grammar::parse_statements("ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let variable = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*variable, Expr::Variable);

        // `-i` is a negation of the `i` literal rather than a single literal.
        let parsed = C32Grammar::parse_statements("-i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let negation_expr = &ret.binary_lhs().unwrap().extra;
        let inner_lhs = match negation_expr {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {:?}", negation_expr),
        };
        assert_matches!(inner_lhs, Expr::Literal(lit) if *lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("-ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let var_negation = &ret.binary_lhs().unwrap().extra;
        let negated_var = match var_negation {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {:?}", var_negation),
        };
        assert_matches!(negated_var, Expr::Variable);
    }

    /// Round-trips representative values for every unsigned primitive, including the
    /// maximum values of `u64` and `u128`.
    #[test]
    fn uint_parsers() {
        let (_, u8_val) = <u8 as NumLiteral>::parse(InputSpan::new("3")).unwrap();
        assert_eq!(u8_val, 3);
        let (_, u16_val) = <u16 as NumLiteral>::parse(InputSpan::new("33333")).unwrap();
        assert_eq!(u16_val, 33_333);
        let (_, u32_val) = <u32 as NumLiteral>::parse(InputSpan::new("1111111111")).unwrap();
        assert_eq!(u32_val, 1_111_111_111);
        let (_, u64_val) =
            <u64 as NumLiteral>::parse(InputSpan::new(&u64::MAX.to_string())).unwrap();
        assert_eq!(u64_val, u64::MAX);
        let (_, u128_val) =
            <u128 as NumLiteral>::parse(InputSpan::new(&u128::MAX.to_string())).unwrap();
        assert_eq!(u128_val, u128::MAX);
    }

    /// Checks both `i8` bounds and that an out-of-range value is reported as
    /// a recoverable `ErrorKind::Literal` error.
    #[test]
    fn int_parsers() {
        let (_, min_val) = <i8 as NumLiteral>::parse(InputSpan::new("-128")).unwrap();
        assert_eq!(min_val, -128);
        let (_, max_val) = <i8 as NumLiteral>::parse(InputSpan::new("127")).unwrap();
        assert_eq!(max_val, 127);

        let err = <i8 as NumLiteral>::parse(InputSpan::new("128")).unwrap_err();
        let err_kind = match &err {
            NomErr::Error(err) => err.kind(),
            _ => panic!("Unexpected error type: {:?}", err),
        };
        assert_matches!(err_kind, ErrorKind::Literal(_));
    }

    /// Compares the big-integer parsers against `num-bigint`'s own `parse_bytes` for
    /// digit strings of increasing length, with and without a leading `-`.
    #[cfg(feature = "num-bigint")]
    #[test]
    fn bigint_parsers() {
        use num_bigint::{BigInt, BigUint};

        for len in 1..500 {
            let input = "1".repeat(len);
            let (_, value) = <BigUint as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            assert_eq!(value, BigUint::parse_bytes(input.as_bytes(), 10).unwrap());

            let (_, value) = <BigInt as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            let expected_value = BigInt::parse_bytes(input.as_bytes(), 10).unwrap();
            assert_eq!(value, expected_value);
            let (_, value) =
                <BigInt as NumLiteral>::parse(InputSpan::new(&format!("-{}", input))).unwrap();
            assert_eq!(value, -expected_value);
        }
    }
}