1use nom::{
22 bytes::complete::take_while_m_n,
23 character::complete::{char as tag_char, digit1},
24 combinator::{map_res, not, opt, peek, recognize},
25 number::complete::{double, float},
26 sequence::{terminated, tuple},
27 Slice,
28};
29
30use core::{fmt, marker::PhantomData};
31
32mod traits;
33
34pub use self::traits::{
35 Features, Grammar, IntoInputSpan, MockTypes, Parse, ParseLiteral, Typed, Untyped,
36 WithMockedTypes,
37};
38
39use crate::{spans::NomResult, ErrorKind, InputSpan};
40
/// Grammar whose literals are numbers of a single type `T` (for example, `f32` or `u64`).
///
/// This is a zero-sized marker type: it carries no data, only the literal type `T`,
/// which must implement [`NumLiteral`] for the grammar to parse literals.
#[derive(Debug)]
pub struct NumGrammar<T>(PhantomData<T>);

/// Number grammar with `f32` literals.
pub type F32Grammar = NumGrammar<f32>;
/// Number grammar with `f64` literals.
pub type F64Grammar = NumGrammar<f64>;
49
50impl<T: NumLiteral> ParseLiteral for NumGrammar<T> {
51 type Lit = T;
52
53 fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
54 T::parse(input)
55 }
56}
57
58pub trait NumLiteral: 'static + Clone + fmt::Debug {
60 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self>;
62}
63
64pub fn ensure_no_overlap<'a, T>(
70 mut parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
71) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, T> {
72 let truncating_parser = move |input| {
73 parser(input).map(|(rest, number)| (maybe_truncate_consumed_input(input, rest), number))
74 };
75
76 terminated(
77 truncating_parser,
78 peek(not(take_while_m_n(1, 1, |c: char| {
79 c.is_ascii_alphabetic() || c == '_'
80 }))),
81 )
82}
83
/// Checks whether `byte` may begin a variable name, i.e., is an ASCII letter or `_`.
fn can_start_a_var_name(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
87
88fn maybe_truncate_consumed_input<'a>(input: InputSpan<'a>, rest: InputSpan<'a>) -> InputSpan<'a> {
89 let relative_offset = rest.location_offset() - input.location_offset();
90 debug_assert!(relative_offset > 0, "num parser succeeded for empty string");
91 let last_consumed_byte_index = relative_offset - 1;
92
93 let input_fragment = *input.fragment();
94 let input_as_bytes = input_fragment.as_bytes();
95 if relative_offset < input_fragment.len()
96 && input_fragment.is_char_boundary(last_consumed_byte_index)
97 && input_as_bytes[last_consumed_byte_index] == b'.'
98 && can_start_a_var_name(input_as_bytes[relative_offset])
99 {
100 input.slice(last_consumed_byte_index..)
103 } else {
104 rest
105 }
106}
107
108macro_rules! impl_num_literal_for_uint {
109 ($($num:ident),+) => {
110 $(
111 impl NumLiteral for $num {
112 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
113 let parser = |s: InputSpan<'_>| s.fragment().parse().map_err(ErrorKind::literal);
114 map_res(digit1, parser)(input)
115 }
116 }
117 )+
118 };
119}
120
121impl_num_literal_for_uint!(u8, u16, u32, u64, u128);
122
123macro_rules! impl_num_literal_for_int {
124 ($($num:ident),+) => {
125 $(
126 impl NumLiteral for $num {
127 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
128 let parser = |s: InputSpan<'_>| s.fragment().parse().map_err(ErrorKind::literal);
129 map_res(recognize(tuple((opt(tag_char('-')), digit1))), parser)(input)
130 }
131 }
132 )+
133 };
134}
135
136impl_num_literal_for_int!(i8, i16, i32, i64, i128);
137
138impl NumLiteral for f32 {
139 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
140 ensure_no_overlap(float)(input)
141 }
142}
143
144impl NumLiteral for f64 {
145 fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
146 ensure_no_overlap(double)(input)
147 }
148}
149
/// Literal parsing for complex numbers (enabled by the `num-complex` feature).
#[cfg(feature = "num-complex")]
mod complex {
    use nom::{
        branch::alt,
        character::complete::one_of,
        combinator::{map, opt},
        number::complete::{double, float},
        sequence::tuple,
    };
    use num_complex::Complex;
    use num_traits::Num;

    use super::{ensure_no_overlap, NumLiteral};
    use crate::{InputSpan, NomResult};

    /// Builds a complex-literal parser on top of `num_parser`, which parses real numbers.
    ///
    /// Accepted forms, tried in this order:
    ///
    /// - a standalone `i` or `j`, parsed as the imaginary unit;
    /// - a real number followed by `i` or `j` (e.g., `2.5i`), parsed as a purely imaginary value;
    /// - a real number without a suffix, parsed as a purely real value.
    fn complex_parser<'a, T: Num>(
        num_parser: impl FnMut(InputSpan<'a>) -> NomResult<'a, T>,
    ) -> impl FnMut(InputSpan<'a>) -> NomResult<'a, Complex<T>> {
        const IMAGINARY_SUFFIXES: &str = "ij";

        let imaginary_unit = map(one_of(IMAGINARY_SUFFIXES), |_| {
            Complex::new(T::zero(), T::one())
        });
        let number_with_optional_suffix = map(
            tuple((num_parser, opt(one_of(IMAGINARY_SUFFIXES)))),
            |(value, suffix)| match suffix {
                Some(_) => Complex::new(T::zero(), value),
                None => Complex::new(value, T::zero()),
            },
        );

        alt((imaginary_unit, number_with_optional_suffix))
    }

    /// Single-precision complex literals, built on `nom`'s `float` parser.
    impl NumLiteral for num_complex::Complex32 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut parser = ensure_no_overlap(complex_parser(float));
            parser(input)
        }
    }

    /// Double-precision complex literals, built on `nom`'s `double` parser.
    impl NumLiteral for num_complex::Complex64 {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut parser = ensure_no_overlap(complex_parser(double));
            parser(input)
        }
    }
}
194
/// Literal parsing for arbitrary-precision integers (enabled by the `num-bigint` feature).
#[cfg(feature = "num-bigint")]
mod bigint {
    use nom::{
        character::complete::{char as tag_char, digit1},
        combinator::{map_res, opt, recognize},
        sequence::tuple,
    };
    use num_bigint::{BigInt, BigUint};
    use num_traits::Num;

    use super::NumLiteral;
    use crate::{ErrorKind, InputSpan, NomResult};

    /// Signed big integers: an optional leading `-` followed by decimal digits.
    impl NumLiteral for BigInt {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let signed_digits = recognize(tuple((opt(tag_char('-')), digit1)));
            let mut parser = map_res(signed_digits, |s: InputSpan<'_>| {
                BigInt::from_str_radix(s.fragment(), 10).map_err(ErrorKind::literal)
            });
            parser(input)
        }
    }

    /// Unsigned big integers: a non-empty run of decimal digits.
    impl NumLiteral for BigUint {
        fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
            let mut parser = map_res(digit1, |s: InputSpan<'_>| {
                BigUint::from_str_radix(s.fragment(), 10).map_err(ErrorKind::literal)
            });
            parser(input)
        }
    }
}
226
// Unit tests for number literal parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Expr;

    use assert_matches::assert_matches;
    use nom::Err as NomErr;

    // Checks how `f32` parsing treats a `.` after the number: it is consumed unless it is
    // directly followed by something that can start a variable name (as in `1.sin()`).
    #[test]
    fn parsing_numbers_with_dot() {
        #[derive(Debug, Clone, Copy)]
        struct Sample {
            input: &'static str,
            // Expected offset of the remaining input; since every input starts at offset 0,
            // this equals the number of consumed bytes.
            consumed: usize,
            value: f32,
        }

        #[rustfmt::skip]
        const SAMPLES: &[Sample] = &[
            Sample { input: "1.25+3", consumed: 4, value: 1.25 },

            // A trailing `.` that is not followed by a name is consumed with the number.
            Sample { input: "1.", consumed: 2, value: 1.0 },
            Sample { input: "-1.", consumed: 3, value: -1.0 },
            Sample { input: "1. + 2.", consumed: 2, value: 1.0 },
            Sample { input: "1.+2.", consumed: 2, value: 1.0 },
            Sample { input: "1. .sin()", consumed: 2, value: 1.0 },

            // A trailing `.` that is directly followed by a name is left in the remaining input.
            Sample { input: "1.sin()", consumed: 1, value: 1.0 },
            Sample { input: "-3.sin()", consumed: 2, value: -3.0 },
            Sample { input: "-3.5.sin()", consumed: 4, value: -3.5 },
        ];

        for &sample in SAMPLES {
            let (rest, number) = <f32 as NumLiteral>::parse(InputSpan::new(sample.input)).unwrap();
            assert!(
                (number - sample.value).abs() < f32::EPSILON,
                "Failed sample: {:?}",
                sample
            );
            assert_eq!(
                rest.location_offset(),
                sample.consumed,
                "Failed sample: {:?}",
                sample
            );
        }
    }

    // `Inf` / `-Inf` are parsed as infinite literals, while identifiers that merely start
    // with `Inf` (`Infty`, `Infer(..)`) must be parsed as variables / function calls.
    #[cfg(feature = "std")]
    #[test]
    fn parsing_infinity() {
        let parsed = Untyped::<F32Grammar>::parse_statements("Inf").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == f32::INFINITY);

        let parsed = Untyped::<F32Grammar>::parse_statements("-Inf").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == -f32::INFINITY);

        let parsed = Untyped::<F32Grammar>::parse_statements("Infty").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Variable);

        let parsed = Untyped::<F32Grammar>::parse_statements("Infer(1)").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Function { .. });

        // With a leading `-`, the name is not a literal, so the result is a negation.
        let parsed = Untyped::<F32Grammar>::parse_statements("-Infty").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Unary { .. });

        let parsed = Untyped::<F32Grammar>::parse_statements("-Infer(2, 3)").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Unary { .. });
    }

    // A standalone `i` is the imaginary unit, but `i` at the start of a longer name is not.
    #[cfg(feature = "num-complex")]
    #[test]
    fn parsing_i() {
        use crate::UnaryOp;
        use num_complex::Complex32;

        type C32Grammar = Untyped<NumGrammar<Complex32>>;

        let parsed = C32Grammar::parse_statements("i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        assert_matches!(ret, Expr::Literal(lit) if lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_lhs = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*i_as_lhs, Expr::Literal(lit) if lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("5 - i").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let i_as_rhs = &ret.binary_rhs().unwrap().extra;
        assert_matches!(*i_as_rhs, Expr::Literal(lit) if lit == Complex32::i());

        // `ix` is a variable name, not the imaginary unit followed by `x`.
        let parsed = C32Grammar::parse_statements("ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let variable = &ret.binary_lhs().unwrap().extra;
        assert_matches!(*variable, Expr::Variable);

        let parsed = C32Grammar::parse_statements("-i + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let negation_expr = &ret.binary_lhs().unwrap().extra;
        let inner_lhs = match negation_expr {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {:?}", negation_expr),
        };
        assert_matches!(inner_lhs, Expr::Literal(lit) if *lit == Complex32::i());

        let parsed = C32Grammar::parse_statements("-ix + 5").unwrap();
        let ret = parsed.return_value.unwrap().extra;
        let var_negation = &ret.binary_lhs().unwrap().extra;
        let negated_var = match var_negation {
            Expr::Unary { inner, op } if op.extra == UnaryOp::Neg => &inner.extra,
            _ => panic!("Unexpected LHS: {:?}", var_negation),
        };
        assert_matches!(negated_var, Expr::Variable);
    }

    // Unsigned parsers handle values up to the maximum of each type.
    #[test]
    fn uint_parsers() {
        let (_, u8_val) = <u8 as NumLiteral>::parse(InputSpan::new("3")).unwrap();
        assert_eq!(u8_val, 3);
        let (_, u16_val) = <u16 as NumLiteral>::parse(InputSpan::new("33333")).unwrap();
        assert_eq!(u16_val, 33_333);
        let (_, u32_val) = <u32 as NumLiteral>::parse(InputSpan::new("1111111111")).unwrap();
        assert_eq!(u32_val, 1_111_111_111);
        let (_, u64_val) =
            <u64 as NumLiteral>::parse(InputSpan::new(&u64::MAX.to_string())).unwrap();
        assert_eq!(u64_val, u64::MAX);
        let (_, u128_val) =
            <u128 as NumLiteral>::parse(InputSpan::new(&u128::MAX.to_string())).unwrap();
        assert_eq!(u128_val, u128::MAX);
    }

    // Signed parsers accept the full range of the type; an out-of-range value
    // (`128` for `i8`) yields a literal error.
    #[test]
    fn int_parsers() {
        let (_, min_val) = <i8 as NumLiteral>::parse(InputSpan::new("-128")).unwrap();
        assert_eq!(min_val, -128);
        let (_, max_val) = <i8 as NumLiteral>::parse(InputSpan::new("127")).unwrap();
        assert_eq!(max_val, 127);

        let err = <i8 as NumLiteral>::parse(InputSpan::new("128")).unwrap_err();
        let err_kind = match &err {
            NomErr::Error(err) => err.kind(),
            _ => panic!("Unexpected error type: {:?}", err),
        };
        assert_matches!(err_kind, ErrorKind::Literal(_));
    }

    // Big-integer parsers agree with `parse_bytes` for numbers of many lengths,
    // including negative `BigInt`s.
    #[cfg(feature = "num-bigint")]
    #[test]
    fn bigint_parsers() {
        use num_bigint::{BigInt, BigUint};

        for len in 1..500 {
            let input = "1".repeat(len);
            let (_, value) = <BigUint as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            assert_eq!(value, BigUint::parse_bytes(input.as_bytes(), 10).unwrap());

            let (_, value) = <BigInt as NumLiteral>::parse(InputSpan::new(&input)).unwrap();
            let expected_value = BigInt::parse_bytes(input.as_bytes(), 10).unwrap();
            assert_eq!(value, expected_value);
            let (_, value) =
                <BigInt as NumLiteral>::parse(InputSpan::new(&format!("-{}", input))).unwrap();
            assert_eq!(value, -expected_value);
        }
    }
}