Skip to main content

openpql_pql_parser/ast/
num.rs

1use super::{
2    Display, Error, LalrError, Loc, LocInfo, NumValueFloat, NumValueInt,
3    Spanned, str,
4};
5
6impl Spanned for Num {
7    fn loc(&self) -> LocInfo {
8        self.loc
9    }
10}
11
12#[derive(Clone, PartialEq, derive_more::From, derive_more::Debug)]
13#[debug("{}", self.inner)]
14pub struct Num {
15    pub inner: NumValue,
16    pub loc: (Loc, Loc),
17}
18
19impl From<(NumValueFloat, (Loc, Loc))> for Num {
20    fn from((val, loc): (NumValueFloat, (Loc, Loc))) -> Self {
21        Self {
22            inner: val.into(),
23            loc,
24        }
25    }
26}
27
28impl From<(NumValueInt, (Loc, Loc))> for Num {
29    fn from((val, loc): (NumValueInt, (Loc, Loc))) -> Self {
30        Self {
31            inner: val.into(),
32            loc,
33        }
34    }
35}
36
37/// # Panics
38/// float parse won't fail /-?(\d+)?\.\d+/
39/// <https://doc.rust-lang.org/std/primitive.f64.html#method.from_str>
40impl<'input> TryFrom<(&'input str, (Loc, Loc), bool)> for Num {
41    type Error = LalrError<'input>;
42
43    fn try_from(
44        (src, loc, is_float): (&'input str, (Loc, Loc), bool),
45    ) -> Result<Self, Self::Error> {
46        if is_float {
47            Ok((src.parse::<NumValueFloat>().unwrap(), loc).into())
48        } else {
49            src.parse::<NumValueInt>().map_or_else(
50                |_| Err(Error::InvalidNumericValue(loc).into()),
51                |v| Ok((v, loc).into()),
52            )
53        }
54    }
55}
56
57#[derive(Clone, Copy, Debug, PartialEq, derive_more::From, Display)]
58pub enum NumValue {
59    #[display("{_0}")]
60    Int(NumValueInt),
61    #[display("{_0}")]
62    Float(NumValueFloat),
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Parses `src` and checks that it yields `expected` spanning the whole input.
    fn check<T>(src: &str, expected: T)
    where
        NumValue: From<T>,
    {
        let span = (0, src.len());
        assert_eq!(parse_num(src), Ok((NumValue::from(expected), span).into()));
    }

    #[test]
    fn test_num() {
        for (src, value) in [("0", 0), ("-1", -1)] {
            check(src, value);
        }
        for (src, value) in [("-1.5", -1.5), ("-.5", -0.5), (".5", 0.5)] {
            check(src, value);
        }
    }

    #[test]
    fn test_err() {
        // One extra digit past the maximum integer must overflow.
        let overflowing = format!("{}0", NumValueInt::MAX);
        let expected = Err(Error::InvalidNumericValue((0, overflowing.len())));
        assert_eq!(parse_num(&overflowing), expected);
    }

    #[test]
    fn test_dbg() {
        let num = Num::from((-123, (0, 1)));
        assert_eq!(format!("{:?}", num), "-123");
    }
}
101    #[test]
102    fn test_dbg() {
103        assert_eq!(format!("{:?}", Num::from((-123, (0, 1)))), "-123");
104    }
105}