1use std::str::FromStr;
2
3use raekna_common::{
4 expression::{Expression, Literal},
5 function_name::FunctionName,
6};
7
8use crate::{
9 errors::ParserResult,
10 lexer::{Operator, Token, TokenTree},
11 ParserError,
12};
13
14pub fn parse(raw_expr: &'_ str) -> ParserResult<Expression> {
15 if raw_expr.is_empty() {
16 return Err(ParserError::EmptyExpression);
17 }
18 TokenTree::parse_input(raw_expr)
19 .map_err(ParserError::NomError)
20 .and_then(|(_, tt)| convert_token_tree(tt, true))
21}
22
23fn convert_token_tree(token_tree: TokenTree, allow_variable_def: bool) -> ParserResult<Expression> {
24 let mut parser = Parser::new(token_tree.num_operators);
25 parser.convert_token_tree(token_tree, allow_variable_def)?;
26 parser.finish()
27}
28
29struct Parser {
30 variable: Option<String>,
31 operators: Vec<Operator>,
32 expressions: Vec<Option<Expression>>,
33 is_sign: bool,
34 should_negate: bool,
35}
36
37impl Parser {
38 fn new(num_operators: usize) -> Self {
39 Self {
40 variable: None,
41 operators: Vec::with_capacity(num_operators),
42 expressions: vec![],
43 is_sign: true,
44 should_negate: false,
45 }
46 }
47
48 fn finish(mut self) -> ParserResult<Expression> {
49 Self::collapse_expressions(&mut self.expressions, &mut self.operators, &self.variable)
50 }
51
52 fn convert_token_tree(
53 &mut self,
54 token_tree: TokenTree,
55 allow_variable_def: bool,
56 ) -> ParserResult<()> {
57 for (i, token) in token_tree.tokens.into_iter().enumerate() {
58 let expr = match token {
59 Token::Literal(literal) => {
60 let sn = match literal {
61 Literal::Integer(i) => {
62 Literal::Integer(if self.should_negate { -i } else { i })
63 }
64 Literal::Float(f) => {
65 Literal::Float(if self.should_negate { -f } else { f })
66 }
67 };
68 self.is_sign = false;
69 self.should_negate = false;
70 Some(Expression::Literal(sn))
71 }
72 Token::Operator(operator) => {
73 if self.is_sign {
74 match operator {
75 Operator::Add => {}
76 Operator::Subtract => self.should_negate = true,
77 Operator::Multiply => return Err(ParserError::InvalidSign('*')),
78 Operator::Divide => return Err(ParserError::InvalidSign('/')),
79 Operator::Modulo => return Err(ParserError::InvalidSign('%')),
80 Operator::Power => return Err(ParserError::InvalidSign('^')),
81 }
82 self.is_sign = false;
83 } else {
84 self.operators.push(operator);
85 self.is_sign = true;
86 }
87 None
88 }
89 Token::Function(name, args) => {
90 let args = args
91 .into_iter()
92 .map(|a| convert_token_tree(a, false))
93 .collect::<ParserResult<Vec<_>>>()?;
94 let function = {
95 use FunctionName::*;
96 let function = FunctionName::from_str(&name)
97 .map_err(|_| ParserError::UnknownFunctionName(name))?;
98 match function {
99 Ceil if args.len() == 2 => CeilPrec,
100 Floor if args.len() == 2 => FloorPrec,
101 Round if args.len() == 2 => RoundPrec,
102 _ => function,
103 }
104 };
105 let expr = Expression::Function(function, args);
106 let expr = self.maybe_negate(expr);
107 self.is_sign = false;
108 self.should_negate = false;
109 Some(expr)
110 }
111 Token::VariableDefinition(name) => {
112 if i != 0 || !allow_variable_def {
113 return Err(ParserError::InvalidVariableDefinition(name));
114 }
115 self.variable = Some(name);
116 None
117 }
118 Token::VariableReference(name) => {
119 let expr = Expression::VariableRef(name);
120 let expr = self.maybe_negate(expr);
121 self.is_sign = false;
122 self.should_negate = false;
123 Some(expr)
124 }
125 Token::Nested(nested_tree) => {
126 let expr = convert_token_tree(nested_tree, false)?;
127 let expr = self.maybe_negate(expr);
128 self.is_sign = false;
129 self.should_negate = false;
130 Some(expr)
131 }
132 };
133 if let Some(expr) = expr {
134 self.expressions.push(Some(expr));
135 }
136 }
137 match (self.expressions.is_empty(), self.operators.is_empty()) {
138 (true, true) => Err(ParserError::EmptyExpression),
139 (true, false) => {
140 let mut expressions = vec![];
141 let mut operators = vec![];
142 std::mem::swap(&mut expressions, &mut self.expressions);
143 std::mem::swap(&mut operators, &mut self.operators);
144 Err(ParserError::InvalidExpression {
145 expressions,
146 operators,
147 })
148 }
149 _ => Ok(()),
150 }
151 }
152
153 fn collapse_expressions(
154 exprs: &mut [Option<Expression>],
155 operators: &mut [Operator],
156 variable: &Option<String>,
157 ) -> ParserResult<Expression> {
158 let expr = if exprs.len() != operators.len() + 1 {
159 return Err(ParserError::InvalidExpression {
160 expressions: exprs.to_owned(),
161 operators: operators.to_owned(),
162 });
163 } else if exprs.len() == 1 {
164 let mut res = None;
165 std::mem::swap(&mut exprs[0], &mut res);
166 res.unwrap()
167 } else {
168 let mut last_operator = (0, operators[0]);
169 for (i, o) in operators.iter().enumerate().skip(1) {
170 match (last_operator.1, o) {
171 (Operator::Power, _)
172 | (
173 Operator::Multiply | Operator::Divide | Operator::Modulo,
174 Operator::Multiply
175 | Operator::Divide
176 | Operator::Modulo
177 | Operator::Add
178 | Operator::Subtract,
179 )
180 | (Operator::Add | Operator::Subtract, Operator::Add | Operator::Subtract) => {
181 last_operator = (i, *o)
182 }
183 _ => {}
184 }
185 }
186 let left = if last_operator.0 == 0 {
187 let mut left = None;
188 std::mem::swap(&mut exprs[0], &mut left);
189 left.unwrap()
190 } else {
191 Self::collapse_expressions(
192 &mut exprs[..last_operator.0 + 1],
193 &mut operators[..last_operator.0],
194 variable,
195 )?
196 };
197 let right = if last_operator.0 == operators.len().saturating_sub(1) {
198 let mut left = None;
199 std::mem::swap(&mut exprs[operators.len()], &mut left);
200 left.unwrap()
201 } else {
202 Self::collapse_expressions(
203 &mut exprs[last_operator.0 + 1..],
204 &mut operators[last_operator.0 + 1..],
205 variable,
206 )?
207 };
208 let function_name = match last_operator.1 {
209 Operator::Add => FunctionName::Add,
210 Operator::Subtract => FunctionName::Subtract,
211 Operator::Multiply => FunctionName::Multiply,
212 Operator::Divide => FunctionName::Divide,
213 Operator::Modulo => FunctionName::Modulus,
214 Operator::Power => FunctionName::Power,
215 };
216 Expression::Function(function_name, vec![left, right])
217 };
218 if let Some(name) = variable {
219 Ok(Expression::Variable(name.clone(), Box::new(expr)))
220 } else {
221 Ok(expr)
222 }
223 }
224
225 fn maybe_negate(&mut self, expr: Expression) -> Expression {
226 let expr = if self.should_negate {
227 Expression::Function(FunctionName::Negate, vec![expr])
228 } else {
229 expr
230 };
231 self.is_sign = false;
232 self.should_negate = false;
233 expr
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::{Operator, Token};

    /// Converts `tt` (variable definitions allowed) and returns the error,
    /// failing the test with the produced expression if conversion succeeds.
    ///
    /// Used instead of `#[should_panic]`, which would also count unrelated
    /// panics (e.g. an `unwrap` inside the parser) as a pass.
    fn conversion_error(tt: TokenTree) -> ParserError {
        match convert_token_tree(tt, true) {
            Ok(expr) => panic!("expected a parser error, got {:?}", expr),
            Err(err) => err,
        }
    }

    mod operator_ordering {
        use super::*;

        #[test]
        fn same_operator_repeated() {
            let tt = TokenTree {
                num_operators: 3,
                tokens: vec![
                    Token::Literal(Literal::Integer(1)),
                    Token::Operator(Operator::Add),
                    Token::Literal(Literal::Integer(2)),
                    Token::Operator(Operator::Add),
                    Token::Literal(Literal::Integer(3)),
                    Token::Operator(Operator::Add),
                    Token::Literal(Literal::Integer(4)),
                ],
            };

            let expected = Expression::Function(
                FunctionName::Add,
                vec![
                    Expression::Function(
                        FunctionName::Add,
                        vec![
                            Expression::Function(
                                FunctionName::Add,
                                vec![
                                    Expression::Literal(Literal::Integer(1)),
                                    Expression::Literal(Literal::Integer(2)),
                                ],
                            ),
                            Expression::Literal(Literal::Integer(3)),
                        ],
                    ),
                    Expression::Literal(Literal::Integer(4)),
                ],
            );
            let actual = convert_token_tree(tt, true).unwrap();

            assert_eq!(actual, expected);
        }

        #[test]
        fn different_operators_mixed() {
            let tt = TokenTree {
                num_operators: 6,
                tokens: vec![
                    Token::Literal(Literal::Integer(1)),
                    Token::Operator(Operator::Add),
                    Token::Literal(Literal::Integer(2)),
                    Token::Operator(Operator::Multiply),
                    Token::Literal(Literal::Integer(3)),
                    Token::Operator(Operator::Power),
                    Token::Literal(Literal::Integer(4)),
                    Token::Operator(Operator::Divide),
                    Token::Literal(Literal::Integer(5)),
                    Token::Operator(Operator::Subtract),
                    Token::Literal(Literal::Integer(6)),
                    Token::Operator(Operator::Add),
                    Token::Literal(Literal::Integer(7)),
                ],
            };

            let expected = Expression::Function(
                FunctionName::Add,
                vec![
                    Expression::Function(
                        FunctionName::Subtract,
                        vec![
                            Expression::Function(
                                FunctionName::Add,
                                vec![
                                    Expression::Literal(Literal::Integer(1)),
                                    Expression::Function(
                                        FunctionName::Divide,
                                        vec![
                                            Expression::Function(
                                                FunctionName::Multiply,
                                                vec![
                                                    Expression::Literal(Literal::Integer(2)),
                                                    Expression::Function(
                                                        FunctionName::Power,
                                                        vec![
                                                            Expression::Literal(Literal::Integer(
                                                                3,
                                                            )),
                                                            Expression::Literal(Literal::Integer(
                                                                4,
                                                            )),
                                                        ],
                                                    ),
                                                ],
                                            ),
                                            Expression::Literal(Literal::Integer(5)),
                                        ],
                                    ),
                                ],
                            ),
                            Expression::Literal(Literal::Integer(6)),
                        ],
                    ),
                    Expression::Literal(Literal::Integer(7)),
                ],
            );
            let actual = convert_token_tree(tt, true).unwrap();

            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn mix_of_expressions() {
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::VariableDefinition("var_def".to_owned()),
                Token::Function(
                    "sqrt".to_owned(),
                    vec![TokenTree {
                        num_operators: 0,
                        tokens: vec![Token::Literal(Literal::Integer(1))],
                    }],
                ),
                Token::Operator(Operator::Multiply),
                Token::Nested(TokenTree {
                    num_operators: 1,
                    tokens: vec![
                        Token::VariableReference("my_var".to_owned()),
                        Token::Operator(Operator::Add),
                        Token::VariableReference("my_second_var".to_owned()),
                    ],
                }),
            ],
        };

        let expected = Expression::Variable(
            "var_def".to_owned(),
            Box::new(Expression::Function(
                FunctionName::Multiply,
                vec![
                    Expression::Function(
                        FunctionName::SquareRoot,
                        vec![Expression::Literal(Literal::Integer(1))],
                    ),
                    Expression::Function(
                        FunctionName::Add,
                        vec![
                            Expression::VariableRef("my_var".to_owned()),
                            Expression::VariableRef("my_second_var".to_owned()),
                        ],
                    ),
                ],
            )),
        );
        let actual = convert_token_tree(tt, true).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn negative_number() {
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::Operator(Operator::Subtract),
                Token::Literal(Literal::Integer(2)),
            ],
        };

        let expected = Expression::Literal(Literal::Integer(-2));
        let actual = convert_token_tree(tt, true).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn function_with_negative_arguments() {
        let tt = TokenTree {
            num_operators: 0,
            tokens: vec![Token::Function(
                "add".to_owned(),
                vec![
                    TokenTree {
                        num_operators: 1,
                        tokens: vec![
                            Token::Operator(Operator::Subtract),
                            Token::Literal(Literal::Integer(1)),
                        ],
                    },
                    TokenTree {
                        num_operators: 1,
                        tokens: vec![
                            Token::Operator(Operator::Subtract),
                            Token::Literal(Literal::Integer(2)),
                        ],
                    },
                ],
            )],
        };

        let expected = Expression::Function(
            FunctionName::Add,
            vec![
                Expression::Literal(Literal::Integer(-1)),
                Expression::Literal(Literal::Integer(-2)),
            ],
        );
        let actual = convert_token_tree(tt, true).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn unary_plus_is_ignored() {
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::Operator(Operator::Add),
                Token::Literal(Literal::Integer(3)),
            ],
        };

        let expected = Expression::Literal(Literal::Integer(3));
        let actual = convert_token_tree(tt, true).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn negated_nested_expression() {
        // Non-literal operands are negated by wrapping them in `Negate`.
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::Operator(Operator::Subtract),
                Token::Nested(TokenTree {
                    num_operators: 0,
                    tokens: vec![Token::Literal(Literal::Integer(1))],
                }),
            ],
        };

        let expected = Expression::Function(
            FunctionName::Negate,
            vec![Expression::Literal(Literal::Integer(1))],
        );
        let actual = convert_token_tree(tt, true).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn leading_multiply_is_invalid_sign() {
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::Operator(Operator::Multiply),
                Token::Literal(Literal::Integer(2)),
            ],
        };

        assert!(matches!(conversion_error(tt), ParserError::InvalidSign('*')));
    }

    #[test]
    fn trailing_operator_is_invalid_expression() {
        let tt = TokenTree {
            num_operators: 1,
            tokens: vec![
                Token::Literal(Literal::Integer(1)),
                Token::Operator(Operator::Add),
            ],
        };

        assert!(matches!(
            conversion_error(tt),
            ParserError::InvalidExpression { .. }
        ));
    }

    #[test]
    fn invalid_function_name() {
        let tt = TokenTree {
            num_operators: 0,
            tokens: vec![Token::Function("invalid".to_owned(), vec![])],
        };

        assert!(matches!(
            conversion_error(tt),
            ParserError::UnknownFunctionName(name) if name == "invalid"
        ));
    }

    #[test]
    fn only_variable_def() {
        let tt = TokenTree {
            num_operators: 0,
            tokens: vec![Token::VariableDefinition("my_var".to_owned())],
        };

        assert!(matches!(
            conversion_error(tt),
            ParserError::EmptyExpression
        ));
    }

    #[test]
    fn variable_def_is_not_first_token() {
        let tt = TokenTree {
            num_operators: 0,
            tokens: vec![
                Token::Literal(Literal::Integer(1)),
                Token::VariableDefinition("invalid".to_owned()),
            ],
        };

        assert!(matches!(
            conversion_error(tt),
            ParserError::InvalidVariableDefinition(name) if name == "invalid"
        ));
    }

    #[test]
    fn variable_def_in_nested_token_tree() {
        let tt = TokenTree {
            num_operators: 0,
            tokens: vec![
                Token::Nested(TokenTree {
                    num_operators: 0,
                    tokens: vec![
                        Token::VariableDefinition("invalid".to_owned()),
                        Token::Literal(Literal::Integer(1)),
                    ],
                }),
                Token::Literal(Literal::Integer(1)),
            ],
        };

        assert!(matches!(
            conversion_error(tt),
            ParserError::InvalidVariableDefinition(name) if name == "invalid"
        ));
    }
}