quantrs2_symengine_pure/parser/
mod.rs1use crate::error::{SymEngineError, SymEngineResult};
25use crate::expr::Expression;
26use crate::ops::trig;
27
/// A single lexical token produced by [`Lexer`] and consumed by [`Parser`].
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// A numeric literal; integers and scientific notation are all stored as `f64`.
    Number(f64),
    /// A bare name: a function name, a known constant, or a free symbol.
    Identifier(String),
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,` — separates function-call arguments.
    Comma,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^` — exponentiation.
    Caret,
    /// End of input.
    Eof,
}
43
/// Character cursor over the raw expression text.
struct Lexer {
    /// The input decoded into `char`s so the cursor indexes code points, not bytes.
    input: Vec<char>,
    /// Index of the next unread character; advancing past the end is allowed and
    /// simply makes every subsequent peek return `None`.
    pos: usize,
}
49
50impl Lexer {
51 fn new(input: &str) -> Self {
52 Self {
53 input: input.chars().collect(),
54 pos: 0,
55 }
56 }
57
58 fn peek(&self) -> Option<char> {
59 self.input.get(self.pos).copied()
60 }
61
62 fn advance(&mut self) -> Option<char> {
63 let c = self.peek();
64 self.pos += 1;
65 c
66 }
67
68 fn skip_whitespace(&mut self) {
69 while let Some(c) = self.peek() {
70 if c.is_whitespace() {
71 self.advance();
72 } else {
73 break;
74 }
75 }
76 }
77
78 fn read_number(&mut self) -> Token {
79 let mut s = String::new();
80 let mut has_dot = false;
81 let mut has_exp = false;
82
83 while let Some(c) = self.peek() {
84 if c.is_ascii_digit() {
85 s.push(c);
86 self.advance();
87 } else if c == '.' && !has_dot && !has_exp {
88 has_dot = true;
89 s.push(c);
90 self.advance();
91 } else if (c == 'e' || c == 'E') && !has_exp {
92 has_exp = true;
93 s.push(c);
94 self.advance();
95 if let Some(next) = self.peek() {
97 if next == '+' || next == '-' {
98 s.push(next);
99 self.advance();
100 }
101 }
102 } else {
103 break;
104 }
105 }
106
107 let value = s.parse::<f64>().unwrap_or(0.0);
108 Token::Number(value)
109 }
110
111 fn read_identifier(&mut self) -> Token {
112 let mut s = String::new();
113
114 while let Some(c) = self.peek() {
115 if c.is_alphanumeric() || c == '_' {
116 s.push(c);
117 self.advance();
118 } else {
119 break;
120 }
121 }
122
123 Token::Identifier(s)
124 }
125
126 fn next_token(&mut self) -> SymEngineResult<Token> {
127 self.skip_whitespace();
128
129 match self.peek() {
130 None => Ok(Token::Eof),
131 Some(c) => {
132 if c.is_ascii_digit()
133 || (c == '.'
134 && self
135 .input
136 .get(self.pos + 1)
137 .is_some_and(|n| n.is_ascii_digit()))
138 {
139 Ok(self.read_number())
140 } else if c.is_alphabetic() || c == '_' {
141 Ok(self.read_identifier())
142 } else {
143 self.advance();
144 match c {
145 '+' => Ok(Token::Plus),
146 '-' => Ok(Token::Minus),
147 '*' => Ok(Token::Star),
148 '/' => Ok(Token::Slash),
149 '^' => Ok(Token::Caret),
150 '(' => Ok(Token::LParen),
151 ')' => Ok(Token::RParen),
152 ',' => Ok(Token::Comma),
153 _ => Err(SymEngineError::parse(format!("unexpected character: {c}"))),
154 }
155 }
156 }
157 }
158 }
159}
160
161struct Parser {
163 lexer: Lexer,
164 current: Token,
165}
166
167impl Parser {
168 fn new(input: &str) -> SymEngineResult<Self> {
169 let mut lexer = Lexer::new(input);
170 let current = lexer.next_token()?;
171 Ok(Self { lexer, current })
172 }
173
174 fn advance(&mut self) -> SymEngineResult<()> {
175 self.current = self.lexer.next_token()?;
176 Ok(())
177 }
178
179 fn expect(&mut self, expected: Token) -> SymEngineResult<()> {
180 if std::mem::discriminant(&self.current) == std::mem::discriminant(&expected) {
181 self.advance()
182 } else {
183 Err(SymEngineError::parse(format!(
184 "expected {:?}, got {:?}",
185 expected, self.current
186 )))
187 }
188 }
189
190 fn parse_expression(&mut self) -> SymEngineResult<Expression> {
192 self.parse_additive()
193 }
194
195 fn parse_additive(&mut self) -> SymEngineResult<Expression> {
197 let mut left = self.parse_multiplicative()?;
198
199 loop {
200 match &self.current {
201 Token::Plus => {
202 self.advance()?;
203 let right = self.parse_multiplicative()?;
204 left = left + right;
205 }
206 Token::Minus => {
207 self.advance()?;
208 let right = self.parse_multiplicative()?;
209 left = left - right;
210 }
211 _ => break,
212 }
213 }
214
215 Ok(left)
216 }
217
218 fn parse_multiplicative(&mut self) -> SymEngineResult<Expression> {
220 let mut left = self.parse_power()?;
221
222 loop {
223 match &self.current {
224 Token::Star => {
225 self.advance()?;
226 let right = self.parse_power()?;
227 left = left * right;
228 }
229 Token::Slash => {
230 self.advance()?;
231 let right = self.parse_power()?;
232 left = left / right;
233 }
234 _ => break,
235 }
236 }
237
238 Ok(left)
239 }
240
241 fn parse_power(&mut self) -> SymEngineResult<Expression> {
243 let base = self.parse_unary()?;
244
245 if matches!(self.current, Token::Caret) {
246 self.advance()?;
247 let exp = self.parse_power()?; Ok(base.pow(&exp))
249 } else {
250 Ok(base)
251 }
252 }
253
254 fn parse_unary(&mut self) -> SymEngineResult<Expression> {
256 match &self.current {
257 Token::Minus => {
258 self.advance()?;
259 let expr = self.parse_unary()?;
260 Ok(expr.neg())
261 }
262 Token::Plus => {
263 self.advance()?;
264 self.parse_unary()
265 }
266 _ => self.parse_primary(),
267 }
268 }
269
270 fn parse_primary(&mut self) -> SymEngineResult<Expression> {
272 match self.current.clone() {
273 Token::Number(n) => {
274 self.advance()?;
275 Expression::float(n)
276 }
277 Token::Identifier(name) => {
278 self.advance()?;
279
280 if matches!(self.current, Token::LParen) {
282 self.parse_function_call(&name)
283 } else {
284 Ok(Self::get_constant_or_symbol(&name))
286 }
287 }
288 Token::LParen => {
289 self.advance()?;
290 let expr = self.parse_expression()?;
291 self.expect(Token::RParen)?;
292 Ok(expr)
293 }
294 _ => Err(SymEngineError::parse(format!(
295 "unexpected token: {:?}",
296 self.current
297 ))),
298 }
299 }
300
301 fn parse_function_call(&mut self, name: &str) -> SymEngineResult<Expression> {
303 self.expect(Token::LParen)?;
304
305 let mut args = Vec::new();
306 if !matches!(self.current, Token::RParen) {
307 args.push(self.parse_expression()?);
308 while matches!(self.current, Token::Comma) {
309 self.advance()?;
310 args.push(self.parse_expression()?);
311 }
312 }
313
314 self.expect(Token::RParen)?;
315
316 match name {
318 "sin" => {
319 if args.len() != 1 {
320 return Err(SymEngineError::parse("sin requires 1 argument"));
321 }
322 Ok(trig::sin(&args[0]))
323 }
324 "cos" => {
325 if args.len() != 1 {
326 return Err(SymEngineError::parse("cos requires 1 argument"));
327 }
328 Ok(trig::cos(&args[0]))
329 }
330 "tan" => {
331 if args.len() != 1 {
332 return Err(SymEngineError::parse("tan requires 1 argument"));
333 }
334 Ok(trig::tan(&args[0]))
335 }
336 "exp" => {
337 if args.len() != 1 {
338 return Err(SymEngineError::parse("exp requires 1 argument"));
339 }
340 Ok(trig::exp(&args[0]))
341 }
342 "log" | "ln" => {
343 if args.len() != 1 {
344 return Err(SymEngineError::parse("log requires 1 argument"));
345 }
346 Ok(trig::log(&args[0]))
347 }
348 "sqrt" => {
349 if args.len() != 1 {
350 return Err(SymEngineError::parse("sqrt requires 1 argument"));
351 }
352 Ok(trig::sqrt(&args[0]))
353 }
354 "abs" => {
355 if args.len() != 1 {
356 return Err(SymEngineError::parse("abs requires 1 argument"));
357 }
358 Ok(trig::abs(&args[0]))
359 }
360 "sinh" => {
361 if args.len() != 1 {
362 return Err(SymEngineError::parse("sinh requires 1 argument"));
363 }
364 Ok(trig::sinh(&args[0]))
365 }
366 "cosh" => {
367 if args.len() != 1 {
368 return Err(SymEngineError::parse("cosh requires 1 argument"));
369 }
370 Ok(trig::cosh(&args[0]))
371 }
372 "tanh" => {
373 if args.len() != 1 {
374 return Err(SymEngineError::parse("tanh requires 1 argument"));
375 }
376 Ok(trig::tanh(&args[0]))
377 }
378 "asin" | "arcsin" => {
379 if args.len() != 1 {
380 return Err(SymEngineError::parse("asin requires 1 argument"));
381 }
382 Ok(trig::asin(&args[0]))
383 }
384 "acos" | "arccos" => {
385 if args.len() != 1 {
386 return Err(SymEngineError::parse("acos requires 1 argument"));
387 }
388 Ok(trig::acos(&args[0]))
389 }
390 "atan" | "arctan" => {
391 if args.len() != 1 {
392 return Err(SymEngineError::parse("atan requires 1 argument"));
393 }
394 Ok(trig::atan(&args[0]))
395 }
396 "pow" => {
397 if args.len() != 2 {
398 return Err(SymEngineError::parse("pow requires 2 arguments"));
399 }
400 Ok(args[0].pow(&args[1]))
401 }
402 _ => Err(SymEngineError::parse(format!("unknown function: {name}"))),
403 }
404 }
405
406 fn get_constant_or_symbol(name: &str) -> Expression {
408 match name {
409 "pi" | "PI" => Expression::pi(),
410 "e" | "E" => Expression::e(),
411 "i" | "I" => Expression::i(),
412 _ => Expression::symbol(name),
413 }
414 }
415}
416
417pub fn parse(input: &str) -> SymEngineResult<Expression> {
437 if input.trim().is_empty() {
438 return Err(SymEngineError::parse("empty expression"));
439 }
440
441 let mut parser = Parser::new(input)?;
442 let expr = parser.parse_expression()?;
443
444 if !matches!(parser.current, Token::Eof) {
446 return Err(SymEngineError::parse(format!(
447 "unexpected token at end: {:?}",
448 parser.current
449 )));
450 }
451
452 Ok(expr)
453}
454
455pub fn parse_many(input: &str) -> SymEngineResult<Vec<Expression>> {
463 input
464 .split(';')
465 .filter(|s| !s.trim().is_empty())
466 .map(parse)
467 .collect()
468}
469
#[cfg(test)]
// `3.14` in `test_parse_float` is an intentional literal, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a variable-binding map from `(name, value)` pairs.
    fn vars(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn test_parse_number() {
        let expr = parse("42").expect("should parse");
        assert!(expr.is_number());
        assert!((expr.to_f64().expect("is number") - 42.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_float() {
        let expr = parse("3.14").expect("should parse");
        assert!(expr.is_number());
        assert!((expr.to_f64().expect("is number") - 3.14).abs() < 1e-10);
    }

    #[test]
    fn test_parse_scientific() {
        let expr = parse("1e-10").expect("should parse");
        assert!(expr.is_number());
        assert!((expr.to_f64().expect("is number") - 1e-10).abs() < 1e-20);

        let expr = parse("2.5E3").expect("should parse");
        assert!((expr.to_f64().expect("is number") - 2500.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_variable() {
        let expr = parse("x").expect("should parse");
        assert_eq!(expr.as_symbol(), Some("x"));
    }

    #[test]
    fn test_parse_constant_pi() {
        let expr = parse("pi").expect("should parse");
        assert_eq!(expr.as_symbol(), Some("pi"));
    }

    #[test]
    fn test_parse_addition() {
        let expr = parse("x + y").expect("should parse");
        let result = expr.eval(&vars(&[("x", 3.0), ("y", 4.0)])).expect("should eval");
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_subtraction() {
        let expr = parse("x - y").expect("should parse");
        let result = expr.eval(&vars(&[("x", 10.0), ("y", 3.0)])).expect("should eval");
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_multiplication() {
        let expr = parse("x * y").expect("should parse");
        let result = expr.eval(&vars(&[("x", 3.0), ("y", 4.0)])).expect("should eval");
        assert!((result - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_division() {
        let expr = parse("x / y").expect("should parse");
        let result = expr.eval(&vars(&[("x", 12.0), ("y", 4.0)])).expect("should eval");
        assert!((result - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_power() {
        let expr = parse("x ^ 2").expect("should parse");
        let result = expr.eval(&vars(&[("x", 3.0)])).expect("should eval");
        assert!((result - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_power_right_associative() {
        // 2^(3^2) = 2^9 = 512, not (2^3)^2 = 64.
        let expr = parse("2^3^2").expect("should parse");
        let result = expr.eval(&HashMap::new()).expect("should eval");
        assert!((result - 512.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_unary_minus() {
        let expr = parse("-x").expect("should parse");
        let result = expr.eval(&vars(&[("x", 5.0)])).expect("should eval");
        assert!((result - (-5.0)).abs() < 1e-10);
    }

    #[test]
    fn test_parse_unary_minus_in_product() {
        let expr = parse("x * -y").expect("should parse");
        let result = expr.eval(&vars(&[("x", 3.0), ("y", 4.0)])).expect("should eval");
        assert!((result - (-12.0)).abs() < 1e-10);
    }

    #[test]
    fn test_parse_parentheses() {
        let expr = parse("(x + y) * z").expect("should parse");
        let result = expr
            .eval(&vars(&[("x", 2.0), ("y", 3.0), ("z", 4.0)]))
            .expect("should eval");
        assert!((result - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_complex_expression() {
        let expr = parse("x^2 + 2*x + 1").expect("should parse");
        let result = expr.eval(&vars(&[("x", 3.0)])).expect("should eval");
        assert!((result - 16.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_sin() {
        let expr = parse("sin(x)").expect("should parse");
        let result = expr.eval(&vars(&[("x", 0.0)])).expect("should eval");
        assert!(result.abs() < 1e-10);
    }

    #[test]
    fn test_parse_cos() {
        let expr = parse("cos(x)").expect("should parse");
        let result = expr.eval(&vars(&[("x", 0.0)])).expect("should eval");
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_exp() {
        let expr = parse("exp(x)").expect("should parse");
        let result = expr.eval(&vars(&[("x", 0.0)])).expect("should eval");
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_sqrt() {
        let expr = parse("sqrt(x)").expect("should parse");
        let result = expr.eval(&vars(&[("x", 4.0)])).expect("should eval");
        assert!((result - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_nested_functions() {
        // sin(cos(0)) = sin(1)
        let expr = parse("sin(cos(x))").expect("should parse");
        let result = expr.eval(&vars(&[("x", 0.0)])).expect("should eval");
        assert!((result - 0.841_470_984_8).abs() < 1e-6);
    }

    #[test]
    fn test_parse_combined() {
        // Pythagorean identity holds for any x.
        let expr = parse("sin(x)^2 + cos(x)^2").expect("should parse");
        let result = expr.eval(&vars(&[("x", 1.5)])).expect("should eval");
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_many() {
        let exprs = parse_many("x + 1; y * 2; z ^ 3").expect("should parse");
        assert_eq!(exprs.len(), 3);
    }

    #[test]
    fn test_parse_many_skips_empty_segments() {
        let exprs = parse_many("x; ; y;").expect("should parse");
        assert_eq!(exprs.len(), 2);
    }

    #[test]
    fn test_parse_empty_error() {
        assert!(parse("").is_err());
        assert!(parse("   ").is_err());
    }

    #[test]
    fn test_parse_repeated_unary_plus() {
        // A second `+` is accepted as a unary plus on `y`.
        let expr = parse("x + + y").expect("unary plus is accepted");
        let result = expr.eval(&vars(&[("x", 3.0), ("y", 4.0)])).expect("should eval");
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_invalid_syntax() {
        let invalid = [
            "x + * y",   // operator where an operand is required
            "*",         // lone operator
            "(x + y",    // unclosed parenthesis
            "x )",       // stray closing parenthesis
            "x y",       // trailing token after a complete expression
            "x $ y",     // character the lexer does not accept
            "foo(x)",    // unknown function
            "sin(x, y)", // wrong arity for a unary function
            "pow(x)",    // wrong arity for pow
        ];
        for input in invalid {
            assert!(parse(input).is_err(), "expected a parse error for {input:?}");
        }
    }
}