1extern crate alloc;
2
3#[cfg(test)]
4use std::rc::Rc;
5#[cfg(not(test))]
6use alloc::rc::Rc;
7
8use crate::Real;
9use crate::context::EvalContext;
10use crate::error::ExprError;
11use crate::eval::eval_ast;
12use crate::lexer::{Lexer, Token};
13use crate::types::{AstExpr, TokenKind};
14#[cfg(not(test))]
15use crate::{Box, Vec};
16
17use alloc::borrow::Cow;
18use alloc::string::{String, ToString};
19#[cfg(not(test))]
20use alloc::format;
21#[cfg(not(test))]
22use alloc::vec;
23
24#[cfg(not(test))]
26use alloc::collections::BTreeSet as HashSet;
27#[cfg(test)]
28use std::collections::HashSet;
29
30struct PrattParser<'a> {
31 lexer: Lexer<'a>,
32 current: Option<Token>,
33 errors: Vec<ExprError>,
34 recursion_depth: usize,
35 max_recursion_depth: usize,
36 reserved_vars: Option<HashSet<Cow<'a, str>>>, context_vars: Option<HashSet<Cow<'a, str>>>, }
39
/// Left/right binding powers of an infix operator.
///
/// The parser keeps consuming an operator while its `left` power is at least the current
/// minimum, and parses the right operand with a minimum derived from `right`.
#[derive(Debug, Clone, Copy)]
struct BindingPower {
    /// How strongly the operator binds to the expression on its left.
    left: u8,
    /// Minimum binding power used when parsing the right-hand operand.
    right: u8,
}
46
47impl BindingPower {
48 const fn new(left: u8, right: u8) -> Self {
49 Self { left, right }
50 }
51
52 const fn left_assoc(power: u8) -> Self {
54 Self::new(power, power + 1)
55 }
56
57 const fn right_assoc(power: u8) -> Self {
59 Self::new(power, power)
60 }
61}
62
63impl<'a> PrattParser<'a> {
64 fn new(input: &'a str) -> Self {
65 let mut lexer = Lexer::new(input);
66 let current = lexer.next_token();
67 Self {
68 lexer,
69 current,
70 errors: Vec::new(),
71 recursion_depth: 0,
72 max_recursion_depth: 2000, reserved_vars: None,
74 context_vars: None,
75 }
76 }
77
78 fn with_reserved_vars_and_context(
79 input: &'a str,
80 reserved_vars: Option<&'a [String]>,
81 context_vars: Option<&'a [String]>,
82 ) -> Self {
83 let mut parser = Self::new(input);
84 if let Some(vars) = reserved_vars {
85 let mut set = HashSet::new();
86 for v in vars {
87 set.insert(Cow::Borrowed(v.as_str()));
88 }
89 parser.reserved_vars = Some(set);
90 }
91 if let Some(vars) = context_vars {
92 let mut set = HashSet::new();
93 for v in vars {
94 set.insert(Cow::Borrowed(v.as_str()));
95 }
96 parser.context_vars = Some(set);
97 }
98 parser
99 }
100
101 fn peek(&self) -> Option<&Token> {
102 self.current.as_ref()
103 }
104
105 fn next(&mut self) -> Option<Token> {
106 let tok = self.current.take();
107 self.current = self.lexer.next_token();
108 tok
109 }
110
111 fn expect(&mut self, kind: TokenKind, error_msg: &str) -> Result<Token, ExprError> {
112 if let Some(tok) = self.peek() {
113 if tok.kind == kind {
114 return Ok(self.next().unwrap());
115 }
116
117 if kind == TokenKind::Close {
120 let position = tok.position;
121 let found = tok.text.clone().unwrap_or_else(|| "unknown".to_string());
122 return Err(ExprError::UnmatchedParenthesis { position, found });
123 }
124 }
125
126 let position = self.peek().map(|t| t.position).unwrap_or(0);
127 let found = self
128 .peek()
129 .and_then(|t| t.text.clone())
130 .unwrap_or_else(|| "end of input".to_string());
131
132 let err = ExprError::Syntax(format!(
133 "{} at position {}, found '{}'",
134 error_msg, position, found
135 ));
136 self.errors.push(err.clone());
137 Err(err)
138 }
139
140 fn get_binding_power(op: &str) -> Option<BindingPower> {
142 match op {
143 "," | ";" => Some(BindingPower::left_assoc(1)), "||" => Some(BindingPower::left_assoc(2)), "&&" => Some(BindingPower::left_assoc(3)), "|" => Some(BindingPower::left_assoc(4)), "&" => Some(BindingPower::left_assoc(6)), "==" | "!=" | "<" | ">" | "<=" | ">=" | "<>" => Some(BindingPower::left_assoc(7)), "<<" | ">>" | "<<<" | ">>>" => Some(BindingPower::left_assoc(8)), "+" | "-" => Some(BindingPower::left_assoc(9)), "*" | "/" | "%" => Some(BindingPower::left_assoc(10)), "^" => Some(BindingPower::right_assoc(15)), "**" => Some(BindingPower::right_assoc(16)), _ => None,
155 }
156 }
157
158 fn get_prefix_binding_power(op: &str) -> Option<u8> {
160 match op {
161 "+" | "-" | "~" => Some(14), _ => None,
163 }
164 }
165
166 fn parse_postfix(&mut self, lhs: AstExpr) -> Result<AstExpr, ExprError> {
168 let mut result = lhs;
169
170 loop {
172 if let Some(tok) = self.peek() {
173 match (tok.kind, tok.text.as_deref()) {
174 (TokenKind::Open, Some("(")) => {
175 result = self.parse_function_call(result)?;
177 }
178 (TokenKind::Open, Some("[")) => {
179 result = self.parse_array_access(result)?;
181 }
182 (TokenKind::Operator, Some(".")) => {
183 result = self.parse_attribute_access(result)?;
185 }
186 _ => break, }
188 } else {
189 break;
190 }
191 }
192
193 Ok(result)
194 }
195
196 fn expect_closing(
198 &mut self,
199 kind: TokenKind,
200 expected: &str,
201 opening_position: usize,
202 ) -> Result<(), ExprError> {
203 if let Some(tok) = self.peek() {
204 if tok.kind == kind {
205 self.next(); return Ok(());
207 }
208
209 let position = tok.position;
211 let found = tok.text.clone().unwrap_or_else(|| "unknown".to_string());
212
213 return Err(ExprError::Syntax(format!(
214 "Expected {} at position {}, found '{}' (opening at position {})",
215 expected, position, found, opening_position
216 )));
217 }
218
219 Err(ExprError::Syntax(format!(
221 "Expected {} but found end of input (opening at position {})",
222 expected, opening_position
223 )))
224 }
225
226 fn parse_parenthesized_expr(&mut self) -> Result<AstExpr, ExprError> {
228 let open_position = self.peek().map(|t| t.position).unwrap_or(0);
229 self.next(); let expr = self.parse_expr_unified(0, true)?;
234
235 if let Some(tok) = self.peek() {
237 if tok.kind == TokenKind::Close {
238 self.next(); return Ok(expr);
240 }
241
242 let position = tok.position;
244 let found = tok.text.clone().unwrap_or_else(|| "unknown".to_string());
245 return Err(ExprError::Syntax(format!(
246 "Expected closing parenthesis ')' but found '{}' at position {} (opening at position {})",
247 found, position, open_position
248 )));
249 }
250
251 Err(ExprError::Syntax(format!(
253 "Expected closing parenthesis ')' but found end of input (opening at position {})",
254 open_position
255 )))
256 }
257
258 fn parse_function_call(&mut self, expr: AstExpr) -> Result<AstExpr, ExprError> {
260 let name = match &expr {
261 AstExpr::Variable(name) => name.clone(),
262 AstExpr::Attribute { attr, .. } => attr.clone(),
263 _ => {
264 return Err(ExprError::Syntax(
265 "Function call on non-function expression".to_string(),
266 ));
267 }
268 };
269
270 self.next(); let mut args = Vec::new();
273
274 if let Some(tok) = self.peek() {
276 if tok.kind != TokenKind::Close {
277 let arg = self.parse_expr_unified(0, false)?;
279 args.push(arg);
280
281 while let Some(next_tok) = self.peek() {
283 if next_tok.kind == TokenKind::Separator
284 && next_tok.text.as_deref() == Some(",")
285 {
286 self.next(); let arg = self.parse_expr_unified(0, false)?;
290 args.push(arg);
291 } else if next_tok.kind == TokenKind::Close {
292 break;
293 } else {
294 let position = next_tok.position;
296 let found = next_tok
297 .text
298 .clone()
299 .unwrap_or_else(|| "unknown".to_string());
300 return Err(ExprError::Syntax(format!(
301 "Expected ',' or ')' but found '{}' at position {} in function call",
302 found, position
303 )));
304 }
305 }
306 }
307 }
308
309 if let Some(tok) = self.peek() {
311 if tok.kind == TokenKind::Close {
312 self.next(); } else {
314 let position = tok.position;
316 let found = tok.text.clone().unwrap_or_else(|| "unknown".to_string());
317 return Err(ExprError::Syntax(format!(
318 "Expected closing parenthesis ')' but found '{}' at position {} in function call",
319 found, position
320 )));
321 }
322 } else {
323 let open_position = self.lexer.get_original_input().len()
325 - self.lexer.get_remaining_input().unwrap_or("").len();
326 return Err(ExprError::UnmatchedParenthesis {
327 position: open_position,
328 found: "(".to_string(),
329 });
330 }
331
332 if name == "pow" && args.len() == 1 {
334 args.push(AstExpr::Constant(2.0));
336 } else if name == "atan2" && args.len() == 1 {
337 args.push(AstExpr::Constant(1.0));
339 }
340
341 if name == "polynomial" && args.len() == 1 {
343 }
345
346 Ok(AstExpr::Function { name, args })
347 }
348
349 fn parse_array_access(&mut self, expr: AstExpr) -> Result<AstExpr, ExprError> {
351 let name = match &expr {
352 AstExpr::Variable(name) => name.clone(),
353 _ => {
354 let position = self.peek().map(|t| t.position).unwrap_or(0);
355 return Err(ExprError::Syntax(format!(
356 "Array access on non-array expression at position {}",
357 position
358 )));
359 }
360 };
361
362 let open_position = self.peek().map(|t| t.position).unwrap_or(0);
363 self.next(); let index = self.parse_expr_unified(0, true)?;
367
368 self.expect_closing(TokenKind::Close, "closing bracket ']'", open_position)?;
370
371 Ok(AstExpr::Array {
372 name,
373 index: Box::new(index),
374 })
375 }
376
377 fn parse_attribute_access(&mut self, expr: AstExpr) -> Result<AstExpr, ExprError> {
379 let dot_position = self.peek().map(|t| t.position).unwrap_or(0);
380 self.next(); let attr_tok = self.expect(TokenKind::Variable, "Expected attribute name")?;
384
385 let attr = attr_tok.text.unwrap_or_default();
386
387 #[cfg(test)]
388 println!("Parsing attribute access: expr={:?}, attr={}", expr, attr);
389
390 match expr {
392 AstExpr::Variable(base) => {
393 #[cfg(test)]
394 println!("Creating attribute node: {}.{}", base, attr);
395
396 let result = AstExpr::Attribute { base, attr };
397 self.parse_postfix(result)
399 }
400 _ => {
401 #[cfg(test)]
402 println!("Error: Attribute access on non-variable expression");
403
404 Err(ExprError::Syntax(format!(
405 "Attribute access on non-object expression at position {}",
406 dot_position
407 )))
408 }
409 }
410 }
411
412 fn parse_expr_unified(&mut self, min_bp: u8, allow_comma: bool) -> Result<AstExpr, ExprError> {
414 self.recursion_depth += 1;
416 if self.recursion_depth > self.max_recursion_depth {
417 self.recursion_depth -= 1;
418 return Err(ExprError::RecursionLimit(format!(
419 "Expression too complex: exceeded maximum recursion depth of {}",
420 self.max_recursion_depth
421 )));
422 }
423
424 let mut lhs = self.parse_prefix_or_primary(allow_comma)?;
426
427 lhs = self.parse_postfix(lhs)?;
429
430 lhs = self.parse_infix_operators(lhs, min_bp, allow_comma)?;
432
433 lhs = self.parse_juxtaposition(lhs, allow_comma)?;
435
436 self.recursion_depth -= 1;
438
439 Ok(lhs)
440 }
441
442 fn parse_prefix_or_primary(&mut self, allow_comma: bool) -> Result<AstExpr, ExprError> {
443 if let Some(tok) = self.peek() {
444 if tok.kind == TokenKind::Operator {
445 let op = tok.text.as_deref().unwrap_or("");
446 let op_position = tok.position;
447 if let Some(r_bp) = Self::get_prefix_binding_power(op) {
448 let op_str = String::from(op);
450
451 self.next();
453
454 if self.peek().is_none() {
456 return Err(ExprError::Syntax(format!(
457 "Expected expression after '{}' at position {}",
458 op_str, op_position
459 )));
460 }
461
462 let rhs = self.parse_expr_unified(r_bp, allow_comma)?;
464
465 if op_str == "-" {
467 Ok(AstExpr::Function {
468 name: String::from("neg"),
469 args: vec![rhs],
470 })
471 } else {
472 Ok(rhs)
474 }
475 } else {
476 self.parse_primary()
477 }
478 } else {
479 self.parse_primary()
480 }
481 } else {
482 self.parse_primary()
483 }
484 }
485
486 fn parse_infix_operators(
487 &mut self,
488 mut lhs: AstExpr,
489 min_bp: u8,
490 allow_comma: bool,
491 ) -> Result<AstExpr, ExprError> {
492 loop {
493 let op_text = if let Some(tok) = self.peek() {
495 if tok.kind == TokenKind::Operator {
496 tok.text.as_deref().unwrap_or("")
497 } else if tok.kind == TokenKind::Separator
498 && (tok.text.as_deref() == Some(",") || tok.text.as_deref() == Some(";"))
499 {
500 if allow_comma {
502 tok.text.as_deref().unwrap_or("")
503 } else {
504 break;
505 }
506 } else {
507 break;
508 }
509 } else {
510 break;
511 };
512
513 let op = String::from(op_text);
515
516 let Some(bp) = Self::get_binding_power(&op) else {
518 break;
519 };
520
521 if bp.left < min_bp {
523 break;
524 }
525
526 self.next();
528
529 let rhs = if op == "^" || op == "**" {
531 self.parse_expr_unified(bp.right - 1, allow_comma)?
532 } else {
533 self.parse_expr_unified(bp.right, allow_comma)?
534 };
535
536 lhs = AstExpr::Function {
538 name: op,
539 args: vec![lhs, rhs],
540 };
541 }
542 Ok(lhs)
543 }
544
545 fn parse_juxtaposition(&mut self, lhs: AstExpr, allow_comma: bool) -> Result<AstExpr, ExprError> {
546 let mut lhs = lhs;
547 if let Some(tok) = self.peek() {
548 let is_valid_lhs = matches!(&lhs, AstExpr::Variable(_));
549 let is_valid_rhs = matches!(
550 tok.kind,
551 TokenKind::Number | TokenKind::Variable | TokenKind::Open
552 ) || (tok.kind == TokenKind::Operator
553 && (tok.text.as_deref() == Some("-")
554 || tok.text.as_deref() == Some("+")
555 || tok.text.as_deref() == Some("~")));
556
557 let is_reserved_var = match &lhs {
559 AstExpr::Variable(name) => {
560 let reserved = self
561 .reserved_vars
562 .as_ref()
563 .map(|s| s.contains(name.as_str()))
564 .unwrap_or(false);
565 let in_context = self
566 .context_vars
567 .as_ref()
568 .map(|s| s.contains(name.as_str()))
569 .unwrap_or(false);
570 reserved || in_context
571 }
572 _ => false,
573 };
574
575 if is_valid_lhs && is_valid_rhs && !is_reserved_var {
576 let func_name = match &lhs {
578 AstExpr::Variable(name) => name.clone(),
579 _ => unreachable!(),
580 };
581 let arg = self.parse_expr_unified(16, allow_comma)?; lhs = AstExpr::Function {
586 name: func_name,
587 args: vec![arg],
588 };
589 }
590 }
591 Ok(lhs)
592 }
593
594 fn parse_expr(&mut self, min_bp: u8) -> Result<AstExpr, ExprError> {
596 self.parse_expr_unified(min_bp, true)
597 }
598
599 fn parse_primary(&mut self) -> Result<AstExpr, ExprError> {
601 let tok = match self.peek() {
602 Some(tok) => tok,
603 None => return Err(ExprError::Syntax("Unexpected end of input".to_string())),
604 };
605
606 match tok.kind {
607 TokenKind::Number => {
608 let val = tok.value.unwrap_or(0.0);
609 self.next();
610 Ok(AstExpr::Constant(val))
611 }
612 TokenKind::Variable => {
613 let name = match &tok.text {
614 Some(name) => name.clone(),
615 None => return Err(ExprError::Syntax("Variable name is missing".to_string())),
616 };
617 self.next();
618 Ok(AstExpr::Variable(name))
619 }
620 TokenKind::Open if tok.text.as_deref() == Some("(") => self.parse_parenthesized_expr(),
621 TokenKind::Close => {
622 let position = tok.position;
624 let found = tok.text.clone().unwrap_or_else(|| ")".to_string());
625 Err(ExprError::Syntax(format!(
626 "Unexpected closing parenthesis at position {}: '{}'",
627 position, found
628 )))
629 }
630 _ => {
631 let position = tok.position;
632 let found = tok.text.clone().unwrap_or_else(|| "unknown".to_string());
633 Err(ExprError::Syntax(format!(
634 "Unexpected token at position {}: '{}'",
635 position, found
636 )))
637 }
638 }
639 }
640
641 fn check_expression_length(&self, input: &str) -> Result<(), ExprError> {
646 const MAX_EXPRESSION_LENGTH: usize = 10000; if input.len() > MAX_EXPRESSION_LENGTH {
648 return Err(ExprError::Syntax(format!(
649 "Expression too long: {} characters (maximum is {})",
650 input.len(),
651 MAX_EXPRESSION_LENGTH
652 )));
653 }
654 Ok(())
655 }
656
657 fn parse(&mut self) -> Result<AstExpr, ExprError> {
659 if let Some(remaining) = self.lexer.get_remaining_input() {
661 self.check_expression_length(remaining)?;
662 }
663
664 self.recursion_depth = 0;
666
667 let expr = self.parse_expr(0)?;
669
670 #[cfg(test)]
671 println!("Parsed expression: {:?}", expr);
672
673 if let Some(tok) = self.peek() {
675 if tok.kind == TokenKind::Error
677 || (tok.kind == TokenKind::Operator
678 && tok.text.as_deref().is_some_and(|t| t.trim().is_empty()))
679 {
680 self.next();
681 } else if tok.kind == TokenKind::Close {
682 return Err(ExprError::Syntax(format!(
684 "Unexpected closing parenthesis at position {}: check for balanced parentheses",
685 tok.position
686 )));
687 } else {
688 return Err(ExprError::Syntax(format!(
690 "Unexpected token at position {}: '{}'",
691 tok.position,
692 tok.text.clone().unwrap_or_else(|| "unknown".to_string())
693 )));
694 }
695 }
696
697 Ok(expr)
698 }
699}
700
701pub fn parse_expression(input: &str) -> Result<AstExpr, ExprError> {
704 parse_expression_with_context(input, None, None)
705}
706
707pub fn parse_expression_with_reserved(
709 input: &str,
710 reserved_vars: Option<&[String]>,
711) -> Result<AstExpr, ExprError> {
712 parse_expression_with_context(input, reserved_vars, None)
713}
714
715pub fn parse_expression_with_context(
717 input: &str,
718 reserved_vars: Option<&[String]>,
719 context_vars: Option<&[String]>,
720) -> Result<AstExpr, ExprError> {
721 if input.contains("<=")
723 || input.contains(">=")
724 || input.contains("==")
725 || input.contains("!=")
726 || input.contains("<")
727 || input.contains(">")
728 || input.contains("?")
729 || input.contains(":")
730 {
731 return Err(ExprError::Syntax(
732 "Comparison operators (<, >, <=, >=, ==, !=) and ternary expressions (? :) are not supported".to_string()
733 ));
734 }
735
736 let mut parser =
738 PrattParser::with_reserved_vars_and_context(input, reserved_vars, context_vars);
739 parser.parse()
740}
741
742pub fn interp<'a>(
803 expression: &str,
804 ctx: Option<Rc<EvalContext<'a>>>,
805) -> crate::error::Result<Real> {
806 use alloc::rc::Rc;
807 if let Some(ctx_rc) = ctx {
809 if let Some(cache) = ctx_rc.ast_cache.as_ref() {
811 use alloc::borrow::ToOwned;
812 let expr_key: Cow<'a, str> = Cow::Owned(expression.to_owned());
813 let ast_rc_opt = {
815 let cache_borrow = cache.borrow();
816 cache_borrow.get(expr_key.as_ref()).cloned()
817 };
818 if let Some(ast_rc) = ast_rc_opt {
819 eval_ast(&ast_rc, Some(Rc::clone(&ctx_rc)))
820 } else {
821 let mut context_vars: Vec<String> = ctx_rc
822 .variables
823 .keys()
824 .map(String::clone)
825 .collect();
826 context_vars.extend(
827 ctx_rc.constants
828 .keys()
829 .map(String::clone)
830 );
831 match parse_expression_with_context(expression, None, Some(&context_vars)) {
832 Ok(ast) => {
833 let ast_rc = Rc::new(ast);
834 {
835 let mut cache_borrow = cache.borrow_mut();
836 cache_borrow.insert(expr_key.to_string(), ast_rc.clone());
837 }
838 eval_ast(&ast_rc, Some(Rc::clone(&ctx_rc)))
839 }
840 Err(err) => Err(err),
841 }
842 }
843 } else {
844 let mut context_vars: Vec<String> = ctx_rc
846 .variables
847 .keys()
848 .map(|k: &String| k.as_str().to_string())
849 .collect();
850 context_vars.extend(ctx_rc.constants.keys().map(|k: &String| k.as_str().to_string()));
851 match parse_expression_with_context(expression, None, Some(&context_vars)) {
852 Ok(ast) => eval_ast(&ast, Some(Rc::clone(&ctx_rc))),
853 Err(err) => Err(err),
854 }
855 }
856 } else {
857 match parse_expression(expression) {
858 Ok(ast) => eval_ast(&ast, None),
859 Err(err) => Err(err),
860 }
861 }
862}
863
864#[cfg(test)]
865use std::boxed::Box;
866#[cfg(test)]
867use std::format;
868#[cfg(test)]
869use std::vec::Vec;
870
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::{log, sin};

    /// Renders `expr` as an indented tree so test output shows a readable AST.
    fn debug_ast(expr: &AstExpr, indent: usize) -> String {
        let spaces = " ".repeat(indent);
        match expr {
            AstExpr::Constant(val) => format!("{}Constant({})", spaces, val),
            AstExpr::Variable(name) => format!("{}Variable({})", spaces, name),
            AstExpr::Function { name, args } => {
                let mut result = format!("{}Function({}, [\n", spaces, name);
                for arg in args {
                    result.push_str(&format!("{},\n", debug_ast(arg, indent + 2)));
                }
                result.push_str(&format!("{}])", spaces));
                result
            }
            AstExpr::Array { name, index } => {
                format!(
                    "{}Array({}, {})",
                    spaces,
                    name,
                    debug_ast(index, indent + 2)
                )
            }
            AstExpr::Attribute { base, attr } => {
                format!("{}Attribute({}, {})", spaces, base, attr)
            }
        }
    }

    /// Asserts that `expr` is a call of `name` with exactly one argument, and returns it.
    fn single_arg<'e>(expr: &'e AstExpr, name: &str) -> &'e AstExpr {
        match expr {
            AstExpr::Function { name: actual, args } if actual == name => {
                assert_eq!(args.len(), 1, "{} should have exactly one argument", name);
                &args[0]
            }
            other => panic!("Expected function node {}, got {:?}", name, other),
        }
    }

    #[test]
    fn test_unknown_variable_and_function_eval() {
        // A function name used as a bare variable must not evaluate successfully.
        let sin_var_ast = AstExpr::Variable("sin".to_string());
        let err = eval_ast(&sin_var_ast, None).unwrap_err();
        println!("Error when evaluating 'sin' as a variable: {:?}", err);
    }

    #[test]
    fn test_parse_postfix_chained_juxtaposition() {
        let ast = parse_expression("sin cos tan x").unwrap_or_else(|e| panic!("Parse error: {}", e));
        println!("AST for 'sin cos tan x':\n{}", debug_ast(&ast, 0));

        let sin_arg = single_arg(&ast, "sin");
        let cos_arg = single_arg(sin_arg, "cos");
        let tan_arg = single_arg(cos_arg, "tan");
        assert!(
            matches!(tan_arg, AstExpr::Variable(var) if var == "x"),
            "Expected variable x as argument to tan, got {:?}",
            tan_arg
        );
    }

    #[test]
    fn test_pow_arity_ast() {
        let ast = parse_expression("pow(2)").unwrap_or_else(|e| panic!("Parse error: {}", e));
        println!("AST for pow(2): {:?}", ast);

        match ast {
            AstExpr::Function { ref name, ref args } if name == "pow" => {
                // `pow(x)` is expanded to `pow(x, 2)`.
                assert_eq!(args.len(), 2);
                assert!(matches!(args[0], AstExpr::Constant(c) if c == 2.0));
                assert!(matches!(args[1], AstExpr::Constant(c) if c == 2.0));
            }
            _ => panic!("Expected function node for pow"),
        }
    }

    #[test]
    fn test_parse_postfix_array_and_attribute_access() {
        let ast = parse_expression("sin(arr[0])").unwrap_or_else(|e| panic!("Parse error: {}", e));
        match single_arg(&ast, "sin") {
            AstExpr::Array { name, index } => {
                assert_eq!(name, "arr");
                assert!(
                    matches!(**index, AstExpr::Constant(val) if val == 0.0),
                    "Expected constant as array index"
                );
            }
            other => panic!("Expected array as argument to sin, got {:?}", other),
        }

        // A call on an attribute keeps only the attribute name as the function name.
        let ast = parse_expression("foo.bar(x)").unwrap_or_else(|e| panic!("Parse error: {}", e));
        let arg = single_arg(&ast, "bar");
        assert!(
            matches!(arg, AstExpr::Variable(var) if var == "x"),
            "Expected variable as argument to foo.bar, got {:?}",
            arg
        );
    }

    #[test]
    fn test_parse_postfix_function_call_after_attribute() {
        let ast = parse_expression("foo.bar(1)").unwrap();
        let arg = single_arg(&ast, "bar");
        assert!(
            matches!(arg, AstExpr::Constant(val) if *val == 1.0),
            "Expected constant as argument to foo.bar"
        );
    }

    #[test]
    fn test_parse_postfix_array_access_complex_index() {
        let ast = parse_expression("arr[1+2*3]").unwrap();
        match ast {
            AstExpr::Array { name, index } => {
                assert_eq!(name, "arr");
                match *index {
                    AstExpr::Function {
                        name: ref n,
                        args: ref a,
                    } if n == "+" => assert_eq!(a.len(), 2),
                    _ => panic!("Expected function as array index"),
                }
            }
            _ => panic!("Expected array AST node"),
        }
    }

    #[test]
    fn test_atan2_function() {
        let cases = [(1.0, 2.0, 0.4636), (2.0, 1.0, 1.1071), (1.0, 1.0, 0.7854)];
        for (y, x, expected) in cases {
            let expr = format!("atan2({},{})", y, x);
            let result = interp(&expr, None).unwrap();
            println!("{} = {}", expr, result);
            assert!(
                (result - expected).abs() < 1e-3,
                "{} should be approximately {}",
                expr,
                expected
            );
        }
    }

    #[test]
    fn test_pow_arity_eval() {
        let result = interp("pow(2)", None).unwrap();
        println!("pow(2) = {}", result);
        assert_eq!(result, 4.0);

        let result2 = interp("pow(2, 3)", None).unwrap();
        println!("pow(2, 3) = {}", result2);
        assert_eq!(result2, 8.0);
    }

    #[test]
    fn test_function_juxtaposition() {
        let result = interp("sin 0.5", None).unwrap();
        println!("sin 0.5 = {}", result);
        assert!(
            (result - sin(0.5, 0.0)).abs() < 1e-6,
            "sin 0.5 should work with juxtaposition"
        );

        let result2 = interp("sin cos 0", None).unwrap();
        println!("sin cos 0 = {}", result2);
        assert!(
            (result2 - sin(1.0, 0.0)).abs() < 1e-6,
            "sin cos 0 should be sin(cos(0)) = sin(1)"
        );

        let result3 = interp("abs(-42)", None).unwrap();
        println!("abs(-42) = {}", result3);
        assert_eq!(result3, 42.0, "abs(-42) should be 42.0");
    }

    #[test]
    fn test_function_application_juxtaposition_ast() {
        let ast = parse_expression("sin x").unwrap_or_else(|e| panic!("Parse error: {:?}", e));
        println!("AST for sin x: {:?}", ast);
        let arg = single_arg(&ast, "sin");
        assert!(
            matches!(arg, AstExpr::Variable(var) if var == "x"),
            "Expected variable as sin arg"
        );

        let ast2 = parse_expression("abs(-42)").unwrap_or_else(|e| panic!("Parse error: {:?}", e));
        println!("AST for abs(-42): {:?}", ast2);
        let neg = single_arg(&ast2, "abs");
        let value = single_arg(neg, "neg");
        assert!(
            matches!(value, AstExpr::Constant(c) if *c == 42.0),
            "Expected constant as neg arg"
        );
    }

    #[test]
    fn test_function_recognition() {
        let result = interp("asin sin 0.5", None).unwrap();
        println!("asin sin 0.5 = {}", result);
        assert!((result - 0.5).abs() < 1e-6, "asin(sin(0.5)) should be 0.5");

        let result2 = interp("sin(0.5)", None).unwrap();
        println!("sin(0.5) = {}", result2);
        assert!(
            (result2 - sin(0.5, 0.0)).abs() < 1e-6,
            "sin(0.5) should work"
        );
    }

    #[test]
    fn test_parse_postfix_attribute_on_function_result_should_error() {
        let ast = parse_expression("(sin x).foo");
        assert!(
            ast.is_err(),
            "Attribute access on function result should be rejected"
        );
    }

    #[test]
    fn test_parse_comma_in_parens_and_top_level() {
        assert!(parse_expression("(1,2)").is_ok(), "Comma in parens should be allowed");
        assert!(parse_expression("1,2,3").is_ok(), "Top-level comma should be allowed");
        assert!(
            parse_expression("(1,2),3").is_ok(),
            "Nested comma outside parens should be allowed"
        );
    }

    #[test]
    fn test_deeply_nested_function_calls() {
        let expr = "abs(abs(abs(abs(abs(abs(abs(abs(abs(abs(-12345))))))))))";
        assert!(
            parse_expression(expr).is_ok(),
            "Deeply nested function calls should be parsed correctly"
        );

        let unbalanced = "abs(abs(abs(abs(abs(abs(abs(abs(abs(abs(-12345)))))))))";
        match parse_expression(unbalanced) {
            Err(ExprError::UnmatchedParenthesis { position: _, found }) => {
                assert_eq!(
                    found, "(",
                    "The unmatched parenthesis should be an opening one"
                );
            }
            _ => panic!("Expected UnmatchedParenthesis error for unbalanced parentheses"),
        }
    }

    #[test]
    fn test_parse_binary_op_deep_right_assoc_pow() {
        fn count_right_assoc_pow(expr: &AstExpr) -> usize {
            match expr {
                AstExpr::Function { name, args } if name == "^" && args.len() == 2 => {
                    1 + count_right_assoc_pow(&args[1])
                }
                _ => 0,
            }
        }

        let ast = parse_expression("2^2^2^2^2").unwrap();
        assert_eq!(
            count_right_assoc_pow(&ast),
            4,
            "Should be right-associative chain of 4 '^'"
        );
    }

    #[test]
    fn test_deeply_nested_function_calls_with_debugging() {
        let expr = "abs(abs(abs(abs(abs(abs(abs(abs(abs(abs(-12345))))))))))";
        println!("Testing expression with debugging: {}", expr);

        let mut lexer = Lexer::new(expr);
        let mut tokens = Vec::new();
        while let Some(tok) = lexer.next_token() {
            tokens.push(tok);
        }

        println!("Tokens:");
        for (i, token) in tokens.iter().enumerate() {
            println!("  {}: {:?}", i, token);
        }

        let open_count = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::Open && t.text.as_deref() == Some("("))
            .count();
        let close_count = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::Close && t.text.as_deref() == Some(")"))
            .count();

        println!("Opening parentheses: {}", open_count);
        println!("Closing parentheses: {}", close_count);
        assert_eq!(
            open_count, close_count,
            "Number of opening and closing parentheses should match"
        );

        assert!(
            parse_expression(expr).is_ok(),
            "Deeply nested function calls should be parsed correctly"
        );
    }

    #[test]
    fn test_parse_binary_op_mixed_unary_and_power() {
        // Unary minus binds looser than `^`.
        for input in ["-2^2", "-2^-2"] {
            let ast = parse_expression(input).unwrap();
            match ast {
                AstExpr::Function { name, args } if name == "neg" => match &args[0] {
                    AstExpr::Function {
                        name: n2,
                        args: args2,
                    } if n2 == "^" => assert_eq!(args2.len(), 2),
                    _ => panic!("Expected ^ as argument to neg in {}", input),
                },
                _ => panic!("Expected neg as top-level function in {}", input),
            }
        }

        let ast2 = parse_expression("(-2)^2").unwrap();
        match ast2 {
            AstExpr::Function { name, args } if name == "^" => match &args[0] {
                AstExpr::Function {
                    name: n2,
                    args: args2,
                } if n2 == "neg" => assert_eq!(args2.len(), 1),
                _ => panic!("Expected neg as left arg to ^"),
            },
            _ => panic!("Expected ^ as top-level function"),
        }
    }

    #[test]
    fn test_parse_binary_op_mixed_precedence() {
        let ast = parse_expression("2+3*4^2-5/6").unwrap();
        match ast {
            AstExpr::Function { name, args } if name == "-" => assert_eq!(args.len(), 2),
            _ => panic!("Expected - as top-level function"),
        }
    }

    #[test]
    fn test_parse_primary_paren_errors() {
        assert!(
            parse_expression("((1+2)").is_err(),
            "Unmatched parenthesis should be rejected"
        );
        assert!(
            parse_expression("1+)").is_err(),
            "Unmatched parenthesis should be rejected"
        );
    }

    #[test]
    fn test_parse_primary_variable_and_number_edge_cases() {
        match parse_expression("foo_bar123").unwrap() {
            AstExpr::Variable(name) => assert_eq!(name, "foo_bar123"),
            _ => panic!("Expected variable node"),
        }

        match parse_expression("1e-2").unwrap() {
            AstExpr::Constant(val) => assert!((val - 0.01).abs() < 1e-10),
            _ => panic!("Expected constant node"),
        }

        match parse_expression("1.2e+3").unwrap() {
            AstExpr::Constant(val) => assert!((val - 1200.0).abs() < 1e-10),
            _ => panic!("Expected constant node"),
        }
    }

    #[test]
    fn test_parse_decimal_with_leading_dot() {
        let ast = parse_expression(".5").unwrap_or_else(|e| panic!("Parse error: {}", e));
        match ast {
            AstExpr::Constant(val) => assert_eq!(val, 0.5),
            _ => panic!("Expected constant node"),
        }
    }

    #[test]
    fn test_log() {
        assert!((log(1000.0, 0.0) - 3.0).abs() < 1e-10);
        assert!((log(100.0, 0.0) - 2.0).abs() < 1e-10);
        assert!((log(10.0, 0.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_eval_invalid_function_arity() {
        match interp("sin(1, 2)", None) {
            Err(ExprError::InvalidFunctionCall {
                name,
                expected,
                found,
            }) => {
                assert_eq!(name, "sin");
                assert_eq!(expected, 1);
                assert_eq!(found, 2);
            }
            Err(err) => panic!(
                "Expected InvalidFunctionCall error for sin(1, 2), got: {:?}",
                err
            ),
            Ok(value) => panic!("sin(1, 2) should return an error, got {}", value),
        }

        // The single-argument `pow` shorthand still evaluates.
        assert_eq!(interp("pow(2)", None).unwrap(), 4.0);
    }
}