precise_calc/
parser.rs

1//! Contains functions used to parse strings into an AST.
2
3use astro_float::{BigFloat, Radix};
4use nom::branch::alt;
5use nom::bytes::complete::{is_not, take_while};
6use nom::character::complete::{char, digit1, satisfy};
7use nom::combinator::{cut, map, opt, recognize, value};
8use nom::multi::{fold_many0, many0, many1, separated_list0};
9use nom::sequence::{delimited, pair, preceded, terminated, tuple};
10use nom::IResult;
11
12use crate::Number;
13use crate::RM;
14use crate::{ast, PREC};
15
16/// Parse a list of statements separated by newlines.
17pub fn parse_stmt_list(input: &str) -> IResult<&str, Vec<ast::Stmt>> {
18    separated_list0(newline1, parse_stmt)(input)
19}
20
21fn newline1(input: &str) -> IResult<&str, &str> {
22    recognize(many1(char('\n')))(input)
23}
24
25fn whitespace0(input: &str) -> IResult<&str, &str> {
26    recognize(many0(alt((char(' '), char('\t')))))(input)
27}
28
29/// Parse a single statement.
30pub fn parse_stmt(input: &str) -> IResult<&str, ast::Stmt> {
31    terminated(
32        alt((
33            map(
34                tuple((
35                    preceded(whitespace0, parse_symbol),
36                    delimited(
37                        char('('),
38                        separated_list0(
39                            delimited(whitespace0, char(','), whitespace0),
40                            parse_symbol,
41                        ),
42                        char(')'),
43                    ),
44                    delimited(whitespace0, char('='), whitespace0),
45                    terminated(parse_expr, whitespace0),
46                )),
47                |(name, params, _, body)| ast::Stmt::FuncDef { name, params, body },
48            ),
49            map(
50                tuple((
51                    delimited(whitespace0, parse_symbol, whitespace0),
52                    char('='),
53                    delimited(whitespace0, parse_expr, whitespace0),
54                )),
55                |(name, _, value)| ast::Stmt::Assignment { name, value },
56            ),
57            map(delimited(whitespace0, parse_expr, whitespace0), |expr| {
58                ast::Stmt::ExprStmt(expr)
59            }),
60        )),
61        opt(parse_comment),
62    )(input)
63}
64
65fn parse_comment(input: &str) -> IResult<&str, ()> {
66    value((), pair(char(';'), is_not("\n")))(input)
67}
68
69/// Parse a single expression.
70pub fn parse_expr(input: &str) -> IResult<&str, ast::Expr> {
71    let (input, first_term) = parse_term(input)?;
72    // I cannot for the life of me figure out why I need to bind this or clone first_term but here we are.
73    let x = fold_many0(
74        tuple((parse_addop, parse_term)),
75        || first_term.clone(),
76        |acc, (op, term)| ast::Expr::BinaryExpr {
77            lhs: Box::new(acc),
78            rhs: Box::new(term),
79            op,
80        },
81    )(input);
82    x
83}
84
85fn parse_term(input: &str) -> IResult<&str, ast::Expr> {
86    let (input, first_factor) = parse_unary_expr(input)?;
87    let x = fold_many0(
88        tuple((parse_mulop, parse_unary_expr)),
89        || first_factor.clone(),
90        |acc, (op, factor)| ast::Expr::BinaryExpr {
91            lhs: Box::new(acc),
92            rhs: Box::new(factor),
93            op,
94        },
95    )(input);
96    x
97}
98
99fn parse_unary_expr(input: &str) -> IResult<&str, ast::Expr> {
100    alt((
101        map(
102            tuple((parse_unop, preceded(whitespace0, parse_expr))),
103            |(op, expr)| ast::Expr::UnaryExpr {
104                op,
105                data: Box::new(expr),
106            },
107        ),
108        parse_exponent,
109    ))(input)
110}
111
112fn parse_unop(input: &str) -> IResult<&str, ast::UnaryOp> {
113    map(char('-'), |_| ast::UnaryOp::Negate)(input)
114}
115
116fn parse_exponent(input: &str) -> IResult<&str, ast::Expr> {
117    let (input, first_base) = parse_parens(input)?;
118    let x = fold_many0(
119        tuple((char('^'), parse_exponent)),
120        || first_base.clone(),
121        |base, (_, expt)| ast::Expr::BinaryExpr {
122            lhs: Box::new(base),
123            rhs: Box::new(expt),
124            op: ast::BinaryOp::Power,
125        },
126    )(input);
127    x
128}
129
130fn parse_parens(input: &str) -> IResult<&str, ast::Expr> {
131    alt((
132        delimited(
133            preceded(whitespace0, char('(')),
134            parse_expr,
135            terminated(char(')'), whitespace0),
136        ),
137        parse_function_call,
138        map(parse_atom, |atom| ast::Expr::AtomExpr(atom)),
139    ))(input)
140}
141
142fn parse_function_call(input: &str) -> IResult<&str, ast::Expr> {
143    map(
144        tuple((
145            parse_symbol,
146            delimited(char('('), separated_list0(char(','), parse_expr), char(')')),
147        )),
148        |(function, args)| ast::Expr::FunctionCall { function, args },
149    )(input)
150}
151
152fn parse_addop(input: &str) -> IResult<&str, ast::BinaryOp> {
153    delimited(
154        whitespace0,
155        map(alt((char('+'), char('-'))), |c: char| match c {
156            '+' => ast::BinaryOp::Plus,
157            '-' => ast::BinaryOp::Minus,
158            _ => unreachable!(),
159        }),
160        whitespace0,
161    )(input)
162}
163
164fn parse_mulop(input: &str) -> IResult<&str, ast::BinaryOp> {
165    delimited(
166        whitespace0,
167        map(alt((char('*'), char('/'))), |c: char| match c {
168            '*' => ast::BinaryOp::Times,
169            '/' => ast::BinaryOp::Divide,
170            _ => unreachable!(),
171        }),
172        whitespace0,
173    )(input)
174}
175
176fn parse_atom(input: &str) -> IResult<&str, ast::Atom> {
177    delimited(
178        whitespace0,
179        alt((
180            map(parse_number, |num| ast::Atom::Num(num)),
181            map(parse_symbol, |sym| ast::Atom::Symbol(sym)),
182        )),
183        whitespace0,
184    )(input)
185}
186
187fn recognize_number(input: &str) -> IResult<&str, &str> {
188    recognize(tuple((
189        opt(alt((char('+'), char('-')))),
190        alt((
191            map(tuple((digit1, opt(pair(char('.'), opt(digit1))))), |_| ()),
192            map(tuple((char('.'), digit1)), |_| ()),
193        )),
194        opt(tuple((
195            alt((char('e'), char('E'))),
196            opt(alt((char('+'), char('-')))),
197            cut(digit1),
198        ))),
199    )))(input)
200}
201
202fn parse_number(input: &str) -> IResult<&str, Number> {
203    map(recognize_number, |s: &str| {
204        BigFloat::parse(s, Radix::Dec, PREC, RM)
205    })(input)
206}
207
208fn parse_symbol(input: &str) -> IResult<&str, String> {
209    map(
210        recognize(tuple((
211            satisfy(|c| is_symbol_character(c) && !c.is_ascii_digit()),
212            take_while(is_symbol_character),
213        ))),
214        |s: &str| s.to_string(),
215    )(input)
216}
217
/// Whether `c` may appear in an identifier: any Unicode alphanumeric character.
fn is_symbol_character(c: char) -> bool {
    char::is_alphanumeric(c)
}
221
#[cfg(test)]
mod tests {
    use crate::{context::Context, eval, PREC};

    use super::*;

    #[test]
    fn test_parse_number() {
        // Every literal form accepted by the grammar must be consumed completely.
        for literal in ["123", "123.456", "123.", ".5", "123E10", "-12.45E-10"] {
            let (rest, _) = recognize_number(literal).unwrap();
            assert!(rest.is_empty(), "{:?} left {:?} unconsumed", literal, rest);
        }

        let (rest, num) = parse_number("123").unwrap();
        assert!(rest.is_empty());
        assert_eq!(num, BigFloat::from_f64(123_f64, PREC));
        let (rest, num) = parse_number("10e10").unwrap();
        assert!(rest.is_empty());
        assert_eq!(num, BigFloat::from_f64(10e10_f64, PREC));
        let (rest, num) = parse_number("-12.45E-10").unwrap();
        assert!(rest.is_empty());
        assert_eq!(num, BigFloat::parse("-12.45e-10", Radix::Dec, PREC, RM));
    }

    #[test]
    fn test_parse_expr() {
        let (rest, expr) = parse_expr("123 + 456 + 7").unwrap();
        assert!(rest.is_empty());

        let ctx = Context::new();

        assert_eq!(
            eval::eval_expr(&expr, &ctx).unwrap(),
            BigFloat::from_f64(123_f64 + 456_f64 + 7_f64, PREC)
        );

        let (rest, expr) = parse_expr("sqrt(1) + 3").unwrap();
        assert!(rest.is_empty());
        match &expr {
            ast::Expr::BinaryExpr {
                lhs,
                op: ast::BinaryOp::Plus,
                ..
            } => assert!(matches!(**lhs, ast::Expr::FunctionCall { .. })),
            other => panic!("expected `sqrt(1) + 3` to be an addition, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_fn_call() {
        let (rest, expr) = parse_function_call("g(  x , y)").unwrap();
        assert!(rest.is_empty());
        match expr {
            ast::Expr::FunctionCall { function, args } => {
                assert_eq!(function, "g");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected a function call, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_stmt() {
        let (_rest, stmt) = parse_stmt("sqrt(1) + 2 * 3;").unwrap();
        assert!(matches!(stmt, ast::Stmt::ExprStmt(_)));
    }

    #[test]
    fn test_parse_fn_def() {
        let (_rest, def) = parse_stmt("f(x) = x + 1;").unwrap();
        match def {
            ast::Stmt::FuncDef { name, params, .. } => {
                assert_eq!(name, "f");
                assert_eq!(params, vec!["x"]);
            }
            _ => panic!("expected `f(x) = x + 1` to be a function definition"),
        }
        let (_rest, def) = parse_stmt("f(x) = sqrt(x) + 1;").unwrap();
        assert!(matches!(def, ast::Stmt::FuncDef { .. }));
    }

    #[test]
    fn test_parse_stmt_list() {
        let (rest, stmts) = parse_stmt_list("x=5 ; here is an EOL comment\n1+2").unwrap();
        assert_eq!(stmts.len(), 2);
        assert!(rest.is_empty());
    }
}
281}