quantrs2_symengine_pure/parser/
mod.rs

1//! Expression parser for mathematical expressions.
2//!
3//! This module provides a parser for converting string representations
4//! of mathematical expressions into `Expression` objects.
5//!
6//! ## Supported Syntax
7//!
8//! - Numbers: `42`, `3.14`, `-2.5`, `1e-10`
9//! - Variables: `x`, `theta`, `alpha_1`
//! - Constants: `pi` (or `PI`), `e` (or `E`), `I` (or `i`, the imaginary unit)
11//! - Operators: `+`, `-`, `*`, `/`, `^` (power)
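//! - Functions: `sin`, `cos`, `tan`, `asin`/`arcsin`, `acos`/`arccos`, `atan`/`arctan`,
//!   `sinh`, `cosh`, `tanh`, `exp`, `log`/`ln`, `sqrt`, `abs`, `pow(base, exponent)`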
13//! - Parentheses: `(`, `)`
14//!
15//! ## Examples
16//!
17//! ```ignore
18//! use quantrs2_symengine_pure::parser::parse;
19//!
20//! let expr = parse("sin(x) + cos(y)").unwrap();
21//! let expr2 = parse("x^2 + 2*x + 1").unwrap();
22//! ```
23
24use crate::error::{SymEngineError, SymEngineResult};
25use crate::expr::Expression;
26use crate::ops::trig;
27
/// Token types for the lexer
///
/// Each call to `Lexer::next_token` yields one of these; the parser keeps a
/// single token of lookahead and dispatches on the variant.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal, already converted to `f64`. The lexer never puts a
    /// sign in here: a leading `-` is emitted as a separate `Minus` token.
    Number(f64),
    /// A run of alphanumeric characters and underscores; may name a variable,
    /// a constant, or a function.
    Identifier(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^` (exponentiation)
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,` (separates function-call arguments)
    Comma,
    /// End of input; returned once no characters remain.
    Eof,
}
43
/// Lexer for tokenizing mathematical expressions
struct Lexer {
    /// The input split into `char`s, so `pos` indexes whole characters
    /// rather than UTF-8 bytes.
    input: Vec<char>,
    /// Index into `input` of the next character to be read.
    pos: usize,
}
49
50impl Lexer {
51    fn new(input: &str) -> Self {
52        Self {
53            input: input.chars().collect(),
54            pos: 0,
55        }
56    }
57
58    fn peek(&self) -> Option<char> {
59        self.input.get(self.pos).copied()
60    }
61
62    fn advance(&mut self) -> Option<char> {
63        let c = self.peek();
64        self.pos += 1;
65        c
66    }
67
68    fn skip_whitespace(&mut self) {
69        while let Some(c) = self.peek() {
70            if c.is_whitespace() {
71                self.advance();
72            } else {
73                break;
74            }
75        }
76    }
77
78    fn read_number(&mut self) -> Token {
79        let mut s = String::new();
80        let mut has_dot = false;
81        let mut has_exp = false;
82
83        while let Some(c) = self.peek() {
84            if c.is_ascii_digit() {
85                s.push(c);
86                self.advance();
87            } else if c == '.' && !has_dot && !has_exp {
88                has_dot = true;
89                s.push(c);
90                self.advance();
91            } else if (c == 'e' || c == 'E') && !has_exp {
92                has_exp = true;
93                s.push(c);
94                self.advance();
95                // Handle optional sign after exponent
96                if let Some(next) = self.peek() {
97                    if next == '+' || next == '-' {
98                        s.push(next);
99                        self.advance();
100                    }
101                }
102            } else {
103                break;
104            }
105        }
106
107        let value = s.parse::<f64>().unwrap_or(0.0);
108        Token::Number(value)
109    }
110
111    fn read_identifier(&mut self) -> Token {
112        let mut s = String::new();
113
114        while let Some(c) = self.peek() {
115            if c.is_alphanumeric() || c == '_' {
116                s.push(c);
117                self.advance();
118            } else {
119                break;
120            }
121        }
122
123        Token::Identifier(s)
124    }
125
126    fn next_token(&mut self) -> SymEngineResult<Token> {
127        self.skip_whitespace();
128
129        match self.peek() {
130            None => Ok(Token::Eof),
131            Some(c) => {
132                if c.is_ascii_digit()
133                    || (c == '.'
134                        && self
135                            .input
136                            .get(self.pos + 1)
137                            .is_some_and(|n| n.is_ascii_digit()))
138                {
139                    Ok(self.read_number())
140                } else if c.is_alphabetic() || c == '_' {
141                    Ok(self.read_identifier())
142                } else {
143                    self.advance();
144                    match c {
145                        '+' => Ok(Token::Plus),
146                        '-' => Ok(Token::Minus),
147                        '*' => Ok(Token::Star),
148                        '/' => Ok(Token::Slash),
149                        '^' => Ok(Token::Caret),
150                        '(' => Ok(Token::LParen),
151                        ')' => Ok(Token::RParen),
152                        ',' => Ok(Token::Comma),
153                        _ => Err(SymEngineError::parse(format!("unexpected character: {c}"))),
154                    }
155                }
156            }
157        }
158    }
159}
160
/// Parser for mathematical expressions
struct Parser {
    /// Source of tokens; advanced one token at a time.
    lexer: Lexer,
    /// The token currently being examined (one token of lookahead).
    current: Token,
}
166
167impl Parser {
168    fn new(input: &str) -> SymEngineResult<Self> {
169        let mut lexer = Lexer::new(input);
170        let current = lexer.next_token()?;
171        Ok(Self { lexer, current })
172    }
173
174    fn advance(&mut self) -> SymEngineResult<()> {
175        self.current = self.lexer.next_token()?;
176        Ok(())
177    }
178
179    fn expect(&mut self, expected: Token) -> SymEngineResult<()> {
180        if std::mem::discriminant(&self.current) == std::mem::discriminant(&expected) {
181            self.advance()
182        } else {
183            Err(SymEngineError::parse(format!(
184                "expected {:?}, got {:?}",
185                expected, self.current
186            )))
187        }
188    }
189
190    /// Parse a complete expression
191    fn parse_expression(&mut self) -> SymEngineResult<Expression> {
192        self.parse_additive()
193    }
194
195    /// Parse additive expressions: a + b, a - b
196    fn parse_additive(&mut self) -> SymEngineResult<Expression> {
197        let mut left = self.parse_multiplicative()?;
198
199        loop {
200            match &self.current {
201                Token::Plus => {
202                    self.advance()?;
203                    let right = self.parse_multiplicative()?;
204                    left = left + right;
205                }
206                Token::Minus => {
207                    self.advance()?;
208                    let right = self.parse_multiplicative()?;
209                    left = left - right;
210                }
211                _ => break,
212            }
213        }
214
215        Ok(left)
216    }
217
218    /// Parse multiplicative expressions: a * b, a / b
219    fn parse_multiplicative(&mut self) -> SymEngineResult<Expression> {
220        let mut left = self.parse_power()?;
221
222        loop {
223            match &self.current {
224                Token::Star => {
225                    self.advance()?;
226                    let right = self.parse_power()?;
227                    left = left * right;
228                }
229                Token::Slash => {
230                    self.advance()?;
231                    let right = self.parse_power()?;
232                    left = left / right;
233                }
234                _ => break,
235            }
236        }
237
238        Ok(left)
239    }
240
241    /// Parse power expressions: a ^ b (right associative)
242    fn parse_power(&mut self) -> SymEngineResult<Expression> {
243        let base = self.parse_unary()?;
244
245        if matches!(self.current, Token::Caret) {
246            self.advance()?;
247            let exp = self.parse_power()?; // Right associative
248            Ok(base.pow(&exp))
249        } else {
250            Ok(base)
251        }
252    }
253
254    /// Parse unary expressions: -a, +a
255    fn parse_unary(&mut self) -> SymEngineResult<Expression> {
256        match &self.current {
257            Token::Minus => {
258                self.advance()?;
259                let expr = self.parse_unary()?;
260                Ok(expr.neg())
261            }
262            Token::Plus => {
263                self.advance()?;
264                self.parse_unary()
265            }
266            _ => self.parse_primary(),
267        }
268    }
269
270    /// Parse primary expressions: numbers, variables, function calls, parentheses
271    fn parse_primary(&mut self) -> SymEngineResult<Expression> {
272        match self.current.clone() {
273            Token::Number(n) => {
274                self.advance()?;
275                Expression::float(n)
276            }
277            Token::Identifier(name) => {
278                self.advance()?;
279
280                // Check if this is a function call
281                if matches!(self.current, Token::LParen) {
282                    self.parse_function_call(&name)
283                } else {
284                    // It's a variable or constant
285                    Ok(Self::get_constant_or_symbol(&name))
286                }
287            }
288            Token::LParen => {
289                self.advance()?;
290                let expr = self.parse_expression()?;
291                self.expect(Token::RParen)?;
292                Ok(expr)
293            }
294            _ => Err(SymEngineError::parse(format!(
295                "unexpected token: {:?}",
296                self.current
297            ))),
298        }
299    }
300
301    /// Parse a function call: func(args...)
302    fn parse_function_call(&mut self, name: &str) -> SymEngineResult<Expression> {
303        self.expect(Token::LParen)?;
304
305        let mut args = Vec::new();
306        if !matches!(self.current, Token::RParen) {
307            args.push(self.parse_expression()?);
308            while matches!(self.current, Token::Comma) {
309                self.advance()?;
310                args.push(self.parse_expression()?);
311            }
312        }
313
314        self.expect(Token::RParen)?;
315
316        // Match known functions
317        match name {
318            "sin" => {
319                if args.len() != 1 {
320                    return Err(SymEngineError::parse("sin requires 1 argument"));
321                }
322                Ok(trig::sin(&args[0]))
323            }
324            "cos" => {
325                if args.len() != 1 {
326                    return Err(SymEngineError::parse("cos requires 1 argument"));
327                }
328                Ok(trig::cos(&args[0]))
329            }
330            "tan" => {
331                if args.len() != 1 {
332                    return Err(SymEngineError::parse("tan requires 1 argument"));
333                }
334                Ok(trig::tan(&args[0]))
335            }
336            "exp" => {
337                if args.len() != 1 {
338                    return Err(SymEngineError::parse("exp requires 1 argument"));
339                }
340                Ok(trig::exp(&args[0]))
341            }
342            "log" | "ln" => {
343                if args.len() != 1 {
344                    return Err(SymEngineError::parse("log requires 1 argument"));
345                }
346                Ok(trig::log(&args[0]))
347            }
348            "sqrt" => {
349                if args.len() != 1 {
350                    return Err(SymEngineError::parse("sqrt requires 1 argument"));
351                }
352                Ok(trig::sqrt(&args[0]))
353            }
354            "abs" => {
355                if args.len() != 1 {
356                    return Err(SymEngineError::parse("abs requires 1 argument"));
357                }
358                Ok(trig::abs(&args[0]))
359            }
360            "sinh" => {
361                if args.len() != 1 {
362                    return Err(SymEngineError::parse("sinh requires 1 argument"));
363                }
364                Ok(trig::sinh(&args[0]))
365            }
366            "cosh" => {
367                if args.len() != 1 {
368                    return Err(SymEngineError::parse("cosh requires 1 argument"));
369                }
370                Ok(trig::cosh(&args[0]))
371            }
372            "tanh" => {
373                if args.len() != 1 {
374                    return Err(SymEngineError::parse("tanh requires 1 argument"));
375                }
376                Ok(trig::tanh(&args[0]))
377            }
378            "asin" | "arcsin" => {
379                if args.len() != 1 {
380                    return Err(SymEngineError::parse("asin requires 1 argument"));
381                }
382                Ok(trig::asin(&args[0]))
383            }
384            "acos" | "arccos" => {
385                if args.len() != 1 {
386                    return Err(SymEngineError::parse("acos requires 1 argument"));
387                }
388                Ok(trig::acos(&args[0]))
389            }
390            "atan" | "arctan" => {
391                if args.len() != 1 {
392                    return Err(SymEngineError::parse("atan requires 1 argument"));
393                }
394                Ok(trig::atan(&args[0]))
395            }
396            "pow" => {
397                if args.len() != 2 {
398                    return Err(SymEngineError::parse("pow requires 2 arguments"));
399                }
400                Ok(args[0].pow(&args[1]))
401            }
402            _ => Err(SymEngineError::parse(format!("unknown function: {name}"))),
403        }
404    }
405
406    /// Get a constant or create a symbol
407    fn get_constant_or_symbol(name: &str) -> Expression {
408        match name {
409            "pi" | "PI" => Expression::pi(),
410            "e" | "E" => Expression::e(),
411            "i" | "I" => Expression::i(),
412            _ => Expression::symbol(name),
413        }
414    }
415}
416
417/// Parse a mathematical expression from a string.
418///
419/// # Arguments
420/// * `input` - The expression string to parse
421///
422/// # Returns
423/// The parsed `Expression` or an error if parsing fails.
424///
425/// # Examples
426///
427/// ```ignore
428/// use quantrs2_symengine_pure::parser::parse;
429///
430/// let expr = parse("x^2 + 2*x + 1").unwrap();
431/// let expr2 = parse("sin(pi/2)").unwrap();
432/// ```
433///
434/// # Errors
435/// Returns `SymEngineError::ParseError` if the input is not a valid expression.
436pub fn parse(input: &str) -> SymEngineResult<Expression> {
437    if input.trim().is_empty() {
438        return Err(SymEngineError::parse("empty expression"));
439    }
440
441    let mut parser = Parser::new(input)?;
442    let expr = parser.parse_expression()?;
443
444    // Ensure we consumed all input
445    if !matches!(parser.current, Token::Eof) {
446        return Err(SymEngineError::parse(format!(
447            "unexpected token at end: {:?}",
448            parser.current
449        )));
450    }
451
452    Ok(expr)
453}
454
455/// Parse multiple expressions separated by semicolons.
456///
457/// # Arguments
458/// * `input` - String containing multiple expressions separated by `;`
459///
460/// # Returns
461/// Vector of parsed expressions.
462pub fn parse_many(input: &str) -> SymEngineResult<Vec<Expression>> {
463    input
464        .split(';')
465        .filter(|s| !s.trim().is_empty())
466        .map(parse)
467        .collect()
468}
469
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parses the source text, binds the listed `"name" => value` pairs, and
    /// evaluates the expression, panicking with the usual messages on failure.
    macro_rules! eval_with {
        ($source:expr $(, $name:literal => $value:expr)*) => {{
            let expr = parse($source).expect("should parse");
            // `mut` goes unused when the caller supplies no bindings.
            #[allow(unused_mut)]
            let mut values = HashMap::new();
            $(values.insert($name.to_string(), $value);)*
            expr.eval(&values).expect("should eval")
        }};
    }

    /// Asserts that `source` parses to a numeric literal within `tolerance`
    /// of `expected`.
    fn assert_parses_to_number(source: &str, expected: f64, tolerance: f64) {
        let expr = parse(source).expect("should parse");
        assert!(expr.is_number());
        assert!((expr.to_f64().expect("is number") - expected).abs() < tolerance);
    }

    #[test]
    fn test_parse_number() {
        assert_parses_to_number("42", 42.0, 1e-10);
    }

    #[test]
    fn test_parse_float() {
        assert_parses_to_number("3.14", 3.14, 1e-10);
    }

    #[test]
    fn test_parse_scientific() {
        assert_parses_to_number("1e-10", 1e-10, 1e-20);
    }

    #[test]
    fn test_parse_variable() {
        let expr = parse("x").expect("should parse");
        assert_eq!(expr.as_symbol(), Some("x"));
    }

    #[test]
    fn test_parse_constant_pi() {
        let expr = parse("pi").expect("should parse");
        assert_eq!(expr.as_symbol(), Some("pi"));
    }

    #[test]
    fn test_parse_addition() {
        let result = eval_with!("x + y", "x" => 3.0, "y" => 4.0);
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_subtraction() {
        let result = eval_with!("x - y", "x" => 10.0, "y" => 3.0);
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_multiplication() {
        let result = eval_with!("x * y", "x" => 3.0, "y" => 4.0);
        assert!((result - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_division() {
        let result = eval_with!("x / y", "x" => 12.0, "y" => 4.0);
        assert!((result - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_power() {
        let result = eval_with!("x ^ 2", "x" => 3.0);
        assert!((result - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_power_right_associative() {
        // 2^3^2 must group as 2^(3^2) = 2^9 = 512, not (2^3)^2 = 64
        let result = eval_with!("2^3^2");
        assert!((result - 512.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_unary_minus() {
        let result = eval_with!("-x", "x" => 5.0);
        assert!((result - (-5.0)).abs() < 1e-10);
    }

    #[test]
    fn test_parse_parentheses() {
        // (2 + 3) * 4 = 20
        let result = eval_with!("(x + y) * z", "x" => 2.0, "y" => 3.0, "z" => 4.0);
        assert!((result - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_complex_expression() {
        // 9 + 6 + 1 = 16
        let result = eval_with!("x^2 + 2*x + 1", "x" => 3.0);
        assert!((result - 16.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_sin() {
        // sin(0) = 0
        let result = eval_with!("sin(x)", "x" => 0.0);
        assert!(result.abs() < 1e-10);
    }

    #[test]
    fn test_parse_cos() {
        // cos(0) = 1
        let result = eval_with!("cos(x)", "x" => 0.0);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_exp() {
        // exp(0) = 1
        let result = eval_with!("exp(x)", "x" => 0.0);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_sqrt() {
        let result = eval_with!("sqrt(x)", "x" => 4.0);
        assert!((result - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_nested_functions() {
        // sin(cos(0)) = sin(1) ≈ 0.8414
        let result = eval_with!("sin(cos(x))", "x" => 0.0);
        assert!((result - 0.841_470_984_8).abs() < 1e-6);
    }

    #[test]
    fn test_parse_combined() {
        // sin² + cos² = 1 for any argument
        let result = eval_with!("sin(x)^2 + cos(x)^2", "x" => 1.5);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_many() {
        let exprs = parse_many("x + 1; y * 2; z ^ 3").expect("should parse");
        assert_eq!(exprs.len(), 3);
    }

    #[test]
    fn test_parse_empty_error() {
        assert!(parse("").is_err());
    }

    #[test]
    fn test_parse_invalid_syntax() {
        // Whether this parses is implementation-defined; the only requirement
        // checked here is that it does not panic.
        let _ = parse("x + + y");
    }
}
697}