1use astro_float::{BigFloat, Radix};
4use nom::branch::alt;
5use nom::bytes::complete::{is_not, take_while};
6use nom::character::complete::{char, digit1, satisfy};
7use nom::combinator::{cut, map, opt, recognize, value};
8use nom::multi::{fold_many0, many0, many1, separated_list0};
9use nom::sequence::{delimited, pair, preceded, terminated, tuple};
10use nom::IResult;
11
12use crate::Number;
13use crate::RM;
14use crate::{ast, PREC};
15
16pub fn parse_stmt_list(input: &str) -> IResult<&str, Vec<ast::Stmt>> {
18 separated_list0(newline1, parse_stmt)(input)
19}
20
21fn newline1(input: &str) -> IResult<&str, &str> {
22 recognize(many1(char('\n')))(input)
23}
24
25fn whitespace0(input: &str) -> IResult<&str, &str> {
26 recognize(many0(alt((char(' '), char('\t')))))(input)
27}
28
29pub fn parse_stmt(input: &str) -> IResult<&str, ast::Stmt> {
31 terminated(
32 alt((
33 map(
34 tuple((
35 preceded(whitespace0, parse_symbol),
36 delimited(
37 char('('),
38 separated_list0(
39 delimited(whitespace0, char(','), whitespace0),
40 parse_symbol,
41 ),
42 char(')'),
43 ),
44 delimited(whitespace0, char('='), whitespace0),
45 terminated(parse_expr, whitespace0),
46 )),
47 |(name, params, _, body)| ast::Stmt::FuncDef { name, params, body },
48 ),
49 map(
50 tuple((
51 delimited(whitespace0, parse_symbol, whitespace0),
52 char('='),
53 delimited(whitespace0, parse_expr, whitespace0),
54 )),
55 |(name, _, value)| ast::Stmt::Assignment { name, value },
56 ),
57 map(delimited(whitespace0, parse_expr, whitespace0), |expr| {
58 ast::Stmt::ExprStmt(expr)
59 }),
60 )),
61 opt(parse_comment),
62 )(input)
63}
64
65fn parse_comment(input: &str) -> IResult<&str, ()> {
66 value((), pair(char(';'), is_not("\n")))(input)
67}
68
69pub fn parse_expr(input: &str) -> IResult<&str, ast::Expr> {
71 let (input, first_term) = parse_term(input)?;
72 let x = fold_many0(
74 tuple((parse_addop, parse_term)),
75 || first_term.clone(),
76 |acc, (op, term)| ast::Expr::BinaryExpr {
77 lhs: Box::new(acc),
78 rhs: Box::new(term),
79 op,
80 },
81 )(input);
82 x
83}
84
85fn parse_term(input: &str) -> IResult<&str, ast::Expr> {
86 let (input, first_factor) = parse_unary_expr(input)?;
87 let x = fold_many0(
88 tuple((parse_mulop, parse_unary_expr)),
89 || first_factor.clone(),
90 |acc, (op, factor)| ast::Expr::BinaryExpr {
91 lhs: Box::new(acc),
92 rhs: Box::new(factor),
93 op,
94 },
95 )(input);
96 x
97}
98
99fn parse_unary_expr(input: &str) -> IResult<&str, ast::Expr> {
100 alt((
101 map(
102 tuple((parse_unop, preceded(whitespace0, parse_expr))),
103 |(op, expr)| ast::Expr::UnaryExpr {
104 op,
105 data: Box::new(expr),
106 },
107 ),
108 parse_exponent,
109 ))(input)
110}
111
112fn parse_unop(input: &str) -> IResult<&str, ast::UnaryOp> {
113 map(char('-'), |_| ast::UnaryOp::Negate)(input)
114}
115
116fn parse_exponent(input: &str) -> IResult<&str, ast::Expr> {
117 let (input, first_base) = parse_parens(input)?;
118 let x = fold_many0(
119 tuple((char('^'), parse_exponent)),
120 || first_base.clone(),
121 |base, (_, expt)| ast::Expr::BinaryExpr {
122 lhs: Box::new(base),
123 rhs: Box::new(expt),
124 op: ast::BinaryOp::Power,
125 },
126 )(input);
127 x
128}
129
130fn parse_parens(input: &str) -> IResult<&str, ast::Expr> {
131 alt((
132 delimited(
133 preceded(whitespace0, char('(')),
134 parse_expr,
135 terminated(char(')'), whitespace0),
136 ),
137 parse_function_call,
138 map(parse_atom, |atom| ast::Expr::AtomExpr(atom)),
139 ))(input)
140}
141
142fn parse_function_call(input: &str) -> IResult<&str, ast::Expr> {
143 map(
144 tuple((
145 parse_symbol,
146 delimited(char('('), separated_list0(char(','), parse_expr), char(')')),
147 )),
148 |(function, args)| ast::Expr::FunctionCall { function, args },
149 )(input)
150}
151
152fn parse_addop(input: &str) -> IResult<&str, ast::BinaryOp> {
153 delimited(
154 whitespace0,
155 map(alt((char('+'), char('-'))), |c: char| match c {
156 '+' => ast::BinaryOp::Plus,
157 '-' => ast::BinaryOp::Minus,
158 _ => unreachable!(),
159 }),
160 whitespace0,
161 )(input)
162}
163
164fn parse_mulop(input: &str) -> IResult<&str, ast::BinaryOp> {
165 delimited(
166 whitespace0,
167 map(alt((char('*'), char('/'))), |c: char| match c {
168 '*' => ast::BinaryOp::Times,
169 '/' => ast::BinaryOp::Divide,
170 _ => unreachable!(),
171 }),
172 whitespace0,
173 )(input)
174}
175
176fn parse_atom(input: &str) -> IResult<&str, ast::Atom> {
177 delimited(
178 whitespace0,
179 alt((
180 map(parse_number, |num| ast::Atom::Num(num)),
181 map(parse_symbol, |sym| ast::Atom::Symbol(sym)),
182 )),
183 whitespace0,
184 )(input)
185}
186
187fn recognize_number(input: &str) -> IResult<&str, &str> {
188 recognize(tuple((
189 opt(alt((char('+'), char('-')))),
190 alt((
191 map(tuple((digit1, opt(pair(char('.'), opt(digit1))))), |_| ()),
192 map(tuple((char('.'), digit1)), |_| ()),
193 )),
194 opt(tuple((
195 alt((char('e'), char('E'))),
196 opt(alt((char('+'), char('-')))),
197 cut(digit1),
198 ))),
199 )))(input)
200}
201
202fn parse_number(input: &str) -> IResult<&str, Number> {
203 map(recognize_number, |s: &str| {
204 BigFloat::parse(s, Radix::Dec, PREC, RM)
205 })(input)
206}
207
208fn parse_symbol(input: &str) -> IResult<&str, String> {
209 map(
210 recognize(tuple((
211 satisfy(|c| is_symbol_character(c) && !c.is_ascii_digit()),
212 take_while(is_symbol_character),
213 ))),
214 |s: &str| s.to_string(),
215 )(input)
216}
217
/// Returns `true` if `c` may appear in a symbol: any Unicode alphanumeric
/// character. `_` and punctuation are excluded.
fn is_symbol_character(c: char) -> bool {
    char::is_alphanumeric(c)
}
221
#[cfg(test)]
mod tests {
    use crate::{context::Context, eval, PREC};

    use super::*;

    #[test]
    fn test_parse_number() {
        // Each literal must be recognized in full, with nothing left over.
        for &literal in &["123", "123.456", "123E10", "-12.45E-10"] {
            let (rest, matched) = recognize_number(literal).unwrap();
            assert_eq!(matched, literal);
            assert!(rest.is_empty());
        }

        let (_rest, num) = parse_number("123").unwrap();
        assert_eq!(num, BigFloat::from_f64(123_f64, PREC));
        let (_rest, num) = parse_number("10e10").unwrap();
        assert_eq!(num, BigFloat::from_f64(10e10_f64, PREC));
        let (_rest, num) = parse_number("-12.45E-10").unwrap();
        assert_eq!(num, BigFloat::parse("-12.45e-10", Radix::Dec, PREC, RM));
    }

    #[test]
    fn test_parse_expr() {
        let (rest, expr) = parse_expr("123 + 456 + 7").unwrap();
        assert!(rest.is_empty());

        let ctx = Context::new();

        assert_eq!(
            eval::eval_expr(&expr, &ctx).unwrap(),
            BigFloat::from_f64(123_f64 + 456_f64 + 7_f64, PREC)
        );

        // A function call followed by a binary operator is one top-level `+`.
        let (rest, expr) = parse_expr("sqrt(1) + 3").unwrap();
        assert!(rest.is_empty());
        assert!(matches!(
            expr,
            ast::Expr::BinaryExpr {
                op: ast::BinaryOp::Plus,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_fn_call() {
        let (rest, expr) = parse_function_call("g( x , y)").unwrap();
        assert!(rest.is_empty());
        match expr {
            ast::Expr::FunctionCall { function, args } => {
                assert_eq!(function, "g");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected a function call, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_stmt() {
        let (_rest, stmt) = parse_stmt("sqrt(1) + 2 * 3;").unwrap();
        assert!(matches!(stmt, ast::Stmt::ExprStmt(_)));
    }

    #[test]
    fn test_parse_fn_def() {
        for &source in &["f(x) = x + 1;", "f(x) = sqrt(x) + 1;"] {
            let (_rest, def) = parse_stmt(source).unwrap();
            match def {
                ast::Stmt::FuncDef { name, params, .. } => {
                    assert_eq!(name, "f");
                    assert_eq!(params, vec!["x"]);
                }
                _ => panic!("expected a function definition for {:?}", source),
            }
        }
    }

    #[test]
    fn test_parse_stmt_list() {
        let (rest, stmts) = parse_stmt_list("x=5 ; here is an EOL comment\n1+2").unwrap();
        assert_eq!(stmts.len(), 2);
        assert!(rest.is_empty());
        assert!(matches!(stmts[0], ast::Stmt::Assignment { .. }));
        assert!(matches!(stmts[1], ast::Stmt::ExprStmt(_)));
    }
}
281}