// arael_sym/parse.rs

1use super::*;
2use std::fmt;
3
/// Failure reported when source text cannot be turned into an expression.
///
/// Pairs the location of the problem in the input with a message meant
/// to be shown to a person.
#[derive(Debug, Clone)]
pub struct ParseError {
    /// Byte offset into the input at which the problem was detected.
    pub pos: usize,
    /// Description of the problem, written for human readers.
    pub msg: String,
}

impl fmt::Display for ParseError {
    /// Renders the error as `parse error at position <pos>: <msg>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { pos, msg } = self;
        write!(f, "parse error at position {pos}: {msg}")
    }
}

impl std::error::Error for ParseError {}
22
23// --- Tokens ---
24
/// Lexical unit handed from the lexer to the parser.
///
/// The `Debug` rendering of a token appears verbatim in parser error
/// messages, so variant names are user-visible.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal, already converted to `f64`.
    Number(f64),
    /// Identifier: a variable, named constant, or function name.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End-of-input marker; always the last token of a stream.
    Eof,
}
39
40struct Lexer {
41    chars: Vec<char>,
42    pos: usize,
43}
44
45impl Lexer {
46    fn new(input: &str) -> Self {
47        Lexer { chars: input.chars().collect(), pos: 0 }
48    }
49
50    fn skip_whitespace(&mut self) {
51        while self.pos < self.chars.len() && self.chars[self.pos].is_ascii_whitespace() {
52            self.pos += 1;
53        }
54    }
55
56    fn next_token(&mut self) -> Result<(Token, usize), ParseError> {
57        self.skip_whitespace();
58        let start = self.pos;
59
60        if self.pos >= self.chars.len() {
61            return Ok((Token::Eof, start));
62        }
63
64        let ch = self.chars[self.pos];
65        self.pos += 1;
66
67        match ch {
68            '+' => Ok((Token::Plus, start)),
69            '-' => Ok((Token::Minus, start)),
70            '*' => Ok((Token::Star, start)),
71            '/' => Ok((Token::Slash, start)),
72            '^' => Ok((Token::Caret, start)),
73            '(' => Ok((Token::LParen, start)),
74            ')' => Ok((Token::RParen, start)),
75            ',' => Ok((Token::Comma, start)),
76            c if c.is_ascii_digit() || c == '.' => {
77                let mut s = String::new();
78                s.push(c);
79                while self.pos < self.chars.len()
80                    && (self.chars[self.pos].is_ascii_digit() || self.chars[self.pos] == '.')
81                {
82                    s.push(self.chars[self.pos]);
83                    self.pos += 1;
84                }
85                // Optional scientific exponent: e / E, optional sign, >=1 digit.
86                // Only consume `e`/`E` if immediately followed by a digit or
87                // signed digit -- otherwise leave it for the identifier path
88                // (e.g. `2e` must stay an error, `2 * exp(x)` must parse the
89                // `exp` ident cleanly when the number ends).
90                if self.pos < self.chars.len()
91                    && (self.chars[self.pos] == 'e' || self.chars[self.pos] == 'E')
92                {
93                    let mut look = self.pos + 1;
94                    if look < self.chars.len()
95                        && (self.chars[look] == '+' || self.chars[look] == '-')
96                    {
97                        look += 1;
98                    }
99                    if look < self.chars.len() && self.chars[look].is_ascii_digit() {
100                        s.push(self.chars[self.pos]);
101                        self.pos += 1;
102                        if self.chars[self.pos] == '+' || self.chars[self.pos] == '-' {
103                            s.push(self.chars[self.pos]);
104                            self.pos += 1;
105                        }
106                        while self.pos < self.chars.len()
107                            && self.chars[self.pos].is_ascii_digit()
108                        {
109                            s.push(self.chars[self.pos]);
110                            self.pos += 1;
111                        }
112                    }
113                }
114                let val: f64 = s.parse().map_err(|_| ParseError {
115                    pos: start,
116                    msg: format!("invalid number: {s}"),
117                })?;
118                Ok((Token::Number(val), start))
119            }
120            c if c.is_ascii_alphabetic() || c == '_' => {
121                let mut s = String::new();
122                s.push(c);
123                while self.pos < self.chars.len()
124                    && (self.chars[self.pos].is_ascii_alphanumeric() || self.chars[self.pos] == '_' || self.chars[self.pos] == '.')
125                {
126                    s.push(self.chars[self.pos]);
127                    self.pos += 1;
128                }
129                Ok((Token::Ident(s), start))
130            }
131            _ => Err(ParseError {
132                pos: start,
133                msg: format!("unexpected character: '{ch}'"),
134            }),
135        }
136    }
137}
138
139// --- Parser ---
140
141struct Parser<'a> {
142    tokens: Vec<(Token, usize)>,
143    pos: usize,
144    bag: Option<&'a FunctionBag>,
145}
146
147impl<'a> Parser<'a> {
148    fn from_str(input: &str) -> Result<Self, ParseError> {
149        let mut lexer = Lexer::new(input);
150        let mut tokens = Vec::new();
151        loop {
152            let tok = lexer.next_token()?;
153            let is_eof = tok.0 == Token::Eof;
154            tokens.push(tok);
155            if is_eof { break; }
156        }
157        Ok(Parser { tokens, pos: 0, bag: None })
158    }
159
160    fn peek(&self) -> &Token {
161        &self.tokens[self.pos].0
162    }
163
164    fn peek_pos(&self) -> usize {
165        self.tokens[self.pos].1
166    }
167
168    fn advance(&mut self) -> &Token {
169        let tok = &self.tokens[self.pos].0;
170        if self.pos + 1 < self.tokens.len() {
171            self.pos += 1;
172        }
173        tok
174    }
175
176    fn expect(&mut self, expected: &Token) -> Result<(), ParseError> {
177        if self.peek() == expected {
178            self.advance();
179            Ok(())
180        } else {
181            Err(ParseError {
182                pos: self.peek_pos(),
183                msg: format!("expected {expected:?}, got {:?}", self.peek()),
184            })
185        }
186    }
187
188    // expr = term (('+' | '-') term)*
189    fn parse_expr(&mut self) -> Result<E, ParseError> {
190        let mut left = self.parse_term()?;
191        loop {
192            match self.peek() {
193                Token::Plus => { self.advance(); let right = self.parse_term()?; left = left + right; }
194                Token::Minus => { self.advance(); let right = self.parse_term()?; left = left - right; }
195                _ => break,
196            }
197        }
198        Ok(left)
199    }
200
201    // term = unary (('*' | '/') unary)*
202    fn parse_term(&mut self) -> Result<E, ParseError> {
203        let mut left = self.parse_unary()?;
204        loop {
205            match self.peek() {
206                Token::Star => { self.advance(); let right = self.parse_unary()?; left = left * right; }
207                Token::Slash => { self.advance(); let right = self.parse_unary()?; left = left / right; }
208                _ => break,
209            }
210        }
211        Ok(left)
212    }
213
214    // unary = '-' unary | power
215    fn parse_unary(&mut self) -> Result<E, ParseError> {
216        if *self.peek() == Token::Minus {
217            self.advance();
218            let expr = self.parse_unary()?;
219            Ok(-expr)
220        } else {
221            self.parse_power()
222        }
223    }
224
225    // power = atom ('^' power)?   (right-associative)
226    fn parse_power(&mut self) -> Result<E, ParseError> {
227        let base = self.parse_atom()?;
228        if *self.peek() == Token::Caret {
229            self.advance();
230            let exp = self.parse_unary()?;
231            Ok(pow(base, exp))
232        } else {
233            Ok(base)
234        }
235    }
236
237    // atom = NUMBER | IDENT | IDENT '(' args ')' | '(' expr ')'
238    fn parse_atom(&mut self) -> Result<E, ParseError> {
239        match self.peek().clone() {
240            Token::Number(v) => {
241                self.advance();
242                Ok(constant(v))
243            }
244            Token::Ident(name) => {
245                self.advance();
246                if *self.peek() == Token::LParen {
247                    // Function call
248                    self.advance(); // consume '('
249                    let mut args = Vec::new();
250                    if *self.peek() != Token::RParen {
251                        args.push(self.parse_expr()?);
252                        while *self.peek() == Token::Comma {
253                            self.advance();
254                            args.push(self.parse_expr()?);
255                        }
256                    }
257                    self.expect(&Token::RParen)?;
258                    build_function_call(&name, args, self.bag)
259                } else {
260                    // Named constant or symbol
261                    match name.as_str() {
262                        "pi" => Ok(constant(std::f64::consts::PI)),
263                        "e" => Ok(constant(std::f64::consts::E)),
264                        _ => Ok(symbol(&name)),
265                    }
266                }
267            }
268            Token::LParen => {
269                self.advance();
270                let expr = self.parse_expr()?;
271                self.expect(&Token::RParen)?;
272                Ok(expr)
273            }
274            Token::Eof => Err(ParseError {
275                pos: self.peek_pos(),
276                msg: "unexpected end of input".to_string(),
277            }),
278            _ => Err(ParseError {
279                pos: self.peek_pos(),
280                msg: format!("unexpected token: {:?}", self.peek()),
281            }),
282        }
283    }
284}
285
286fn build_function_call(name: &str, args: Vec<E>, bag: Option<&FunctionBag>) -> Result<E, ParseError> {
287    // "H" is a parser-only alias for heaviside; normalize before lookup.
288    let lookup_name = if name == "H" { "heaviside" } else { name };
289    // User-defined functions in the bag take priority over built-ins.
290    if let Some(bag) = bag
291        && let Some(result) = bag.call(lookup_name, &args)
292    {
293        return result.map_err(|msg| ParseError { pos: 0, msg });
294    }
295    let fnref = crate::function_by_name(lookup_name).ok_or_else(|| ParseError {
296        pos: 0,
297        msg: format!("unknown function: {name}"),
298    })?;
299    match fnref {
300        crate::FunctionRef::Unary(f) => {
301            if args.len() != 1 {
302                return Err(ParseError {
303                    pos: 0,
304                    msg: format!("{name} expects 1 argument, got {}", args.len()),
305                });
306            }
307            Ok(f(args.into_iter().next().unwrap()))
308        }
309        crate::FunctionRef::Binary(f) => {
310            if args.len() != 2 {
311                return Err(ParseError {
312                    pos: 0,
313                    msg: format!("{name} expects 2 arguments, got {}", args.len()),
314                });
315            }
316            let mut it = args.into_iter();
317            Ok(f(it.next().unwrap(), it.next().unwrap()))
318        }
319        crate::FunctionRef::Ternary(f) => {
320            if args.len() != 3 {
321                return Err(ParseError {
322                    pos: 0,
323                    msg: format!("{name} expects 3 arguments, got {}", args.len()),
324                });
325            }
326            let mut it = args.into_iter();
327            Ok(f(it.next().unwrap(), it.next().unwrap(), it.next().unwrap()))
328        }
329    }
330}
331
332/// Parse a string into a symbolic expression, using only built-in
333/// functions.
334///
335/// Supports standard infix notation with `+`, `-`, `*`, `/`, `^` (power),
336/// parentheses, and function calls (`sin`, `cos`, `tan`, `asin`, `acos`,
337/// `atan`, `atan2`, `sinh`, `cosh`, `tanh`, `exp`, `ln`, `log2`, `log10`,
338/// `sqrt`, `abs`, `heaviside` (alias `H`), `clamp`, `pow`, `rad_diff`,
339/// `rad_sum`, `safe_atan2`, `safe_sqrt`, `safe_asin`, `safe_acos`,
340/// `identity`). See the full list in [`crate::FUNCTIONS`].
341///
342/// The identifiers `pi` and `e` are recognized as named constants.
343/// All other identifiers become symbolic variables.
344///
345/// Use [`parse_with_functions`] to also recognise user-defined
346/// functions (e.g. from [`FunctionBag::add_symbolic`]) during parsing.
347///
348/// # Errors
349///
350/// Returns a [`ParseError`] on invalid syntax, unknown functions, or
351/// wrong argument counts.
352pub fn parse(input: &str) -> Result<E, ParseError> {
353    parse_with_functions(input, &FunctionBag::new())
354}
355
356/// Parse a string into a symbolic expression, consulting a
357/// [`FunctionBag`] before the built-in function table.
358///
359/// Names in `bag` take priority over built-ins with the same name
360/// (shadowing). If a function name is not in `bag`, the parser falls
361/// back to [`crate::function_by_name`], so built-ins remain available
362/// regardless of what's in the bag. Passing `&FunctionBag::new()`
363/// (empty) is equivalent to calling [`parse`].
364///
365/// # Example
366///
367/// ```
368/// use arael_sym::{parse_with_functions, FunctionBag, parse, constant, symbol};
369/// use std::collections::HashMap;
370///
371/// let mut bag = FunctionBag::new();
372/// // Register a user-defined function whose body uses `t` as a symbol
373/// // to stand for the formal parameter.
374/// bag.add_symbolic("sq", vec!["t".to_string()], parse("t*t").unwrap());
375///
376/// let e = parse_with_functions("sq(3)", &bag).unwrap();
377/// let vars: HashMap<&str, f64> = HashMap::new();
378/// assert_eq!(e.eval(&vars).unwrap(), 9.0);
379/// ```
380///
381/// # Errors
382///
383/// As [`parse`], plus:
384/// - Arity mismatch between a call-site and the bag's stored parameter
385///   list.
386///
387/// # See also
388///
389/// [`examples/calc_demo.rs`](https://github.com/harakas/arael/blob/master/examples/calc_demo.rs)
390/// is an end-to-end REPL example that uses this function for both
391/// expression evaluation and runtime function definitions.
392pub fn parse_with_functions(input: &str, bag: &FunctionBag) -> Result<E, ParseError> {
393    let mut parser = Parser::from_str(input)?;
394    parser.bag = Some(bag);
395    let expr = parser.parse_expr()?;
396    if *parser.peek() != Token::Eof {
397        return Err(ParseError {
398            pos: parser.peek_pos(),
399            msg: format!("unexpected token after expression: {:?}", parser.peek()),
400        });
401    }
402    Ok(expr)
403}
404
405impl std::str::FromStr for E {
406    type Err = ParseError;
407    fn from_str(s: &str) -> Result<E, ParseError> {
408        parse(s)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{constant, simple_func1, symbol};
    use std::collections::HashMap;

    /// Variable environment with no bindings.
    fn noenv() -> HashMap<&'static str, f64> {
        HashMap::new()
    }

    /// Asserts that `a` and `b` differ by less than `tol`.
    fn approx(a: f64, b: f64, tol: f64) {
        assert!((a - b).abs() < tol, "{a} !~= {b} (tol {tol})");
    }

    /// Parses `src` with built-ins only and evaluates it without variables.
    fn eval_plain(src: &str) -> f64 {
        parse(src).unwrap().eval(&noenv()).unwrap()
    }

    /// Parses `src` against `bag` and evaluates it without variables.
    fn eval_in(bag: &FunctionBag, src: &str) -> f64 {
        parse_with_functions(src, bag).unwrap().eval(&noenv()).unwrap()
    }

    /// Bag containing only the symbolic function `sq(t) = t*t`.
    fn sq_bag() -> FunctionBag {
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sq", vec!["t".into()], parse("t*t").unwrap());
        bag
    }

    // --- Backward compat: plain `parse` keeps its existing surface. ---

    #[test]
    fn parse_arithmetic() {
        approx(eval_plain("1 + 2 * 3"), 7.0, 1e-12);
    }

    #[test]
    fn parse_builtin_unary() {
        approx(eval_plain("sin(0) + cos(0)"), 1.0, 1e-12);
    }

    #[test]
    fn parse_builtin_binary_atan2() {
        approx(eval_plain("atan2(1, 1)"), std::f64::consts::FRAC_PI_4, 1e-12);
    }

    #[test]
    fn parse_builtin_sqrt_square_roundtrip() {
        approx(eval_plain("sqrt(2) * sqrt(2)"), 2.0, 1e-10);
    }

    #[test]
    fn parse_builtin_ternary_clamp() {
        approx(eval_plain("clamp(5, 0, 1)"), 1.0, 1e-12);
    }

    #[test]
    fn parse_heaviside_h_alias() {
        approx(eval_plain("heaviside(0.5) + H(0.5)"), 2.0, 1e-12);
    }

    #[test]
    fn parse_rejects_unknown_function() {
        let err = parse("nope(x)").unwrap_err();
        assert!(err.msg.contains("unknown function"), "{err}");
    }

    #[test]
    fn parse_rejects_wrong_arity() {
        let err = parse("sin(1, 2)").unwrap_err();
        assert!(err.msg.contains("1 argument"), "{err}");
    }

    #[test]
    fn parse_scientific_notation() {
        // Literal-only cases: positive exponent, negative exponent, explicit
        // `+` with a decimal mantissa, and an `e` that starts an identifier
        // rather than an exponent.
        let closed_cases = [
            ("1e3", 1000.0, 1e-12),
            ("1e-12", 1e-12, 1e-20),
            ("2.5E+2", 250.0, 1e-12),
            ("2 * exp(0)", 2.0, 1e-12),
        ];
        for (src, want, tol) in closed_cases {
            approx(eval_plain(src), want, tol);
        }

        // Exponent literal inside a larger expression with a variable.
        let env: HashMap<&'static str, f64> = HashMap::from([("x", 0.0)]);
        let in_context = parse("1.0 - x * x + 1e-12").unwrap();
        approx(in_context.eval(&env).unwrap(), 1.0 + 1e-12, 1e-20);
    }

    #[test]
    fn parse_rejects_bare_e_after_number() {
        // `2e` has no digit after the `e`, so the lexer emits the number `2`
        // followed by a separate identifier token `e`; the parser then
        // rejects that identifier as trailing input.
        let err = parse("2e").unwrap_err();
        let recognised = err.msg.contains("unknown") || err.msg.contains("unexpected");
        assert!(recognised, "{err}");
    }

    // --- parse_with_functions behaviour. ---

    #[test]
    fn parse_with_functions_empty_bag_falls_through_to_builtins() {
        approx(eval_in(&FunctionBag::new(), "sin(0) + 1"), 1.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_user_symbolic_call() {
        approx(eval_in(&sq_bag(), "sq(2.0)"), 4.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_unknown_in_empty_bag_fails() {
        let empty = FunctionBag::new();
        let err = parse_with_functions("sq(1)", &empty).unwrap_err();
        assert!(err.msg.contains("unknown function"), "{err}");
    }

    #[test]
    fn parse_with_functions_shadows_builtin() {
        // The bag's `sin` ignores its argument and always yields 7.
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sin", vec!["x".into()], constant(7.0));
        approx(eval_in(&bag, "sin(0.5)"), 7.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_h_alias_still_works() {
        approx(eval_in(&FunctionBag::new(), "H(0.5)"), 1.0, 1e-12);
    }

    #[test]
    fn bag_add_e_func_round_trip() {
        // Register an expression that is already a function application.
        let square = simple_func1("sq", |t| t.clone() * t);
        let mut bag = FunctionBag::new();
        bag.add(square(symbol("t"))).unwrap();
        approx(eval_in(&bag, "sq(3)"), 9.0, 1e-12);
    }

    #[test]
    fn bag_add1_unary_closure() {
        // Register the `simple_func1` closure itself rather than an
        // expression built from it.
        let mut bag = FunctionBag::new();
        bag.add1(simple_func1("sq", |t| t.clone() * t)).unwrap();
        approx(eval_in(&bag, "sq(4)"), 16.0, 1e-12);
    }

    #[test]
    fn bag_add2_binary_closure() {
        let hypot = simple_func2("hypot", |a, b| crate::sqrt(a.clone() * a + b.clone() * b));
        let mut bag = FunctionBag::new();
        bag.add2(hypot).unwrap();
        approx(eval_in(&bag, "hypot(3, 4)"), 5.0, 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn bag_addN_quaternary_closure() {
        // Four arguments: one more than the largest built-in arity.
        let blend = crate::simple_func("blend", 4, |args: Vec<E>| {
            args[0].clone() + args[1].clone() + args[2].clone() + args[3].clone()
        });
        let mut bag = FunctionBag::new();
        bag.addN(4, blend).unwrap();
        approx(eval_in(&bag, "blend(1, 2, 3, 4)"), 10.0, 1e-12);
    }

    #[test]
    fn bag_add_rejects_non_func() {
        let mut bag = FunctionBag::new();
        let err = bag.add(constant(1.0)).unwrap_err();
        assert!(err.contains("expected Expr::Func"), "{err}");
    }

    #[test]
    fn parse_with_functions_rejects_wrong_arity() {
        let bag = sq_bag();
        let err = parse_with_functions("sq(1, 2)", &bag).unwrap_err();
        assert!(err.msg.contains("1 argument"), "{err}");
    }

    #[test]
    fn parameter_shadowing() {
        // `sq` is registered with formal parameter `x`, and the evaluation
        // environment also binds `x = 5`. `sq(3)` must give 9, not 25: the
        // parameter hides the outer variable inside the function body.
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sq", vec!["x".into()], parse("x*x").unwrap());
        let call = parse_with_functions("sq(3)", &bag).unwrap();
        let vars: HashMap<&str, f64> = HashMap::from([("x", 5.0)]);
        approx(call.eval(&vars).unwrap(), 9.0, 1e-12);
    }

    #[test]
    fn chained_user_functions_compose() {
        let mut bag = sq_bag();
        // The body of `mag` is parsed against the bag as it stands now,
        // which already holds `sq`, so both `sq` calls resolve.
        let mag_body = parse_with_functions("sqrt(sq(a) + sq(b))", &bag).unwrap();
        bag.add_symbolic("mag", vec!["a".into(), "b".into()], mag_body);
        approx(eval_in(&bag, "mag(3, 4)"), 5.0, 1e-10);
    }

    #[test]
    fn bag_remove_and_contains() {
        let mut bag = sq_bag();
        assert!(bag.contains("sq"));
        assert!(bag.remove("sq"));
        assert!(!bag.contains("sq"));
        // Removing a second time reports that nothing was there.
        assert!(!bag.remove("sq"));
    }

    #[test]
    fn bag_names_and_entries() {
        let mut bag = sq_bag();
        bag.add_symbolic("mag", vec!["a".into(), "b".into()], parse("a+b").unwrap());

        let mut names = bag.names();
        names.sort();
        assert_eq!(names, vec!["mag".to_string(), "sq".to_string()]);

        let mut entries: Vec<(String, usize)> = bag
            .entries()
            .map(|(name, arity)| (name.to_string(), arity))
            .collect();
        entries.sort();
        assert_eq!(entries, vec![("mag".to_string(), 2), ("sq".to_string(), 1)]);
    }
}