1use super::*;
2use std::fmt;
3
/// Error produced when an expression string cannot be tokenized or parsed.
#[derive(Clone, Debug)]
pub struct ParseError {
    /// Position in the input associated with the error. The lexer counts
    /// positions in `char`s (not bytes), starting at zero.
    pub pos: usize,
    /// Human-readable description of what went wrong.
    pub msg: String,
}
14
15impl fmt::Display for ParseError {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 write!(f, "parse error at position {}: {}", self.pos, self.msg)
18 }
19}
20
21impl std::error::Error for ParseError {}
22
/// Lexical tokens produced by the lexer and consumed by the parser.
#[derive(Clone, Debug, PartialEq)]
enum Token {
    /// Numeric literal, already converted to `f64`.
    Number(f64),
    /// Identifier: a variable, constant, or function name.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End of input.
    Eof,
}
39
/// Cursor over the input text, stored as individual `char`s.
struct Lexer {
    /// The whole input, one element per `char`.
    chars: Vec<char>,
    /// Index into `chars` of the next character to read.
    pos: usize,
}
44
45impl Lexer {
46 fn new(input: &str) -> Self {
47 Lexer { chars: input.chars().collect(), pos: 0 }
48 }
49
50 fn skip_whitespace(&mut self) {
51 while self.pos < self.chars.len() && self.chars[self.pos].is_ascii_whitespace() {
52 self.pos += 1;
53 }
54 }
55
56 fn next_token(&mut self) -> Result<(Token, usize), ParseError> {
57 self.skip_whitespace();
58 let start = self.pos;
59
60 if self.pos >= self.chars.len() {
61 return Ok((Token::Eof, start));
62 }
63
64 let ch = self.chars[self.pos];
65 self.pos += 1;
66
67 match ch {
68 '+' => Ok((Token::Plus, start)),
69 '-' => Ok((Token::Minus, start)),
70 '*' => Ok((Token::Star, start)),
71 '/' => Ok((Token::Slash, start)),
72 '^' => Ok((Token::Caret, start)),
73 '(' => Ok((Token::LParen, start)),
74 ')' => Ok((Token::RParen, start)),
75 ',' => Ok((Token::Comma, start)),
76 c if c.is_ascii_digit() || c == '.' => {
77 let mut s = String::new();
78 s.push(c);
79 while self.pos < self.chars.len()
80 && (self.chars[self.pos].is_ascii_digit() || self.chars[self.pos] == '.')
81 {
82 s.push(self.chars[self.pos]);
83 self.pos += 1;
84 }
85 if self.pos < self.chars.len()
91 && (self.chars[self.pos] == 'e' || self.chars[self.pos] == 'E')
92 {
93 let mut look = self.pos + 1;
94 if look < self.chars.len()
95 && (self.chars[look] == '+' || self.chars[look] == '-')
96 {
97 look += 1;
98 }
99 if look < self.chars.len() && self.chars[look].is_ascii_digit() {
100 s.push(self.chars[self.pos]);
101 self.pos += 1;
102 if self.chars[self.pos] == '+' || self.chars[self.pos] == '-' {
103 s.push(self.chars[self.pos]);
104 self.pos += 1;
105 }
106 while self.pos < self.chars.len()
107 && self.chars[self.pos].is_ascii_digit()
108 {
109 s.push(self.chars[self.pos]);
110 self.pos += 1;
111 }
112 }
113 }
114 let val: f64 = s.parse().map_err(|_| ParseError {
115 pos: start,
116 msg: format!("invalid number: {s}"),
117 })?;
118 Ok((Token::Number(val), start))
119 }
120 c if c.is_ascii_alphabetic() || c == '_' => {
121 let mut s = String::new();
122 s.push(c);
123 while self.pos < self.chars.len()
124 && (self.chars[self.pos].is_ascii_alphanumeric() || self.chars[self.pos] == '_' || self.chars[self.pos] == '.')
125 {
126 s.push(self.chars[self.pos]);
127 self.pos += 1;
128 }
129 Ok((Token::Ident(s), start))
130 }
131 _ => Err(ParseError {
132 pos: start,
133 msg: format!("unexpected character: '{ch}'"),
134 }),
135 }
136 }
137}
138
139struct Parser<'a> {
142 tokens: Vec<(Token, usize)>,
143 pos: usize,
144 bag: Option<&'a FunctionBag>,
145}
146
147impl<'a> Parser<'a> {
148 fn from_str(input: &str) -> Result<Self, ParseError> {
149 let mut lexer = Lexer::new(input);
150 let mut tokens = Vec::new();
151 loop {
152 let tok = lexer.next_token()?;
153 let is_eof = tok.0 == Token::Eof;
154 tokens.push(tok);
155 if is_eof { break; }
156 }
157 Ok(Parser { tokens, pos: 0, bag: None })
158 }
159
160 fn peek(&self) -> &Token {
161 &self.tokens[self.pos].0
162 }
163
164 fn peek_pos(&self) -> usize {
165 self.tokens[self.pos].1
166 }
167
168 fn advance(&mut self) -> &Token {
169 let tok = &self.tokens[self.pos].0;
170 if self.pos + 1 < self.tokens.len() {
171 self.pos += 1;
172 }
173 tok
174 }
175
176 fn expect(&mut self, expected: &Token) -> Result<(), ParseError> {
177 if self.peek() == expected {
178 self.advance();
179 Ok(())
180 } else {
181 Err(ParseError {
182 pos: self.peek_pos(),
183 msg: format!("expected {expected:?}, got {:?}", self.peek()),
184 })
185 }
186 }
187
188 fn parse_expr(&mut self) -> Result<E, ParseError> {
190 let mut left = self.parse_term()?;
191 loop {
192 match self.peek() {
193 Token::Plus => { self.advance(); let right = self.parse_term()?; left = left + right; }
194 Token::Minus => { self.advance(); let right = self.parse_term()?; left = left - right; }
195 _ => break,
196 }
197 }
198 Ok(left)
199 }
200
201 fn parse_term(&mut self) -> Result<E, ParseError> {
203 let mut left = self.parse_unary()?;
204 loop {
205 match self.peek() {
206 Token::Star => { self.advance(); let right = self.parse_unary()?; left = left * right; }
207 Token::Slash => { self.advance(); let right = self.parse_unary()?; left = left / right; }
208 _ => break,
209 }
210 }
211 Ok(left)
212 }
213
214 fn parse_unary(&mut self) -> Result<E, ParseError> {
216 if *self.peek() == Token::Minus {
217 self.advance();
218 let expr = self.parse_unary()?;
219 Ok(-expr)
220 } else {
221 self.parse_power()
222 }
223 }
224
225 fn parse_power(&mut self) -> Result<E, ParseError> {
227 let base = self.parse_atom()?;
228 if *self.peek() == Token::Caret {
229 self.advance();
230 let exp = self.parse_unary()?;
231 Ok(pow(base, exp))
232 } else {
233 Ok(base)
234 }
235 }
236
237 fn parse_atom(&mut self) -> Result<E, ParseError> {
239 match self.peek().clone() {
240 Token::Number(v) => {
241 self.advance();
242 Ok(constant(v))
243 }
244 Token::Ident(name) => {
245 self.advance();
246 if *self.peek() == Token::LParen {
247 self.advance(); let mut args = Vec::new();
250 if *self.peek() != Token::RParen {
251 args.push(self.parse_expr()?);
252 while *self.peek() == Token::Comma {
253 self.advance();
254 args.push(self.parse_expr()?);
255 }
256 }
257 self.expect(&Token::RParen)?;
258 build_function_call(&name, args, self.bag)
259 } else {
260 match name.as_str() {
262 "pi" => Ok(constant(std::f64::consts::PI)),
263 "e" => Ok(constant(std::f64::consts::E)),
264 _ => Ok(symbol(&name)),
265 }
266 }
267 }
268 Token::LParen => {
269 self.advance();
270 let expr = self.parse_expr()?;
271 self.expect(&Token::RParen)?;
272 Ok(expr)
273 }
274 Token::Eof => Err(ParseError {
275 pos: self.peek_pos(),
276 msg: "unexpected end of input".to_string(),
277 }),
278 _ => Err(ParseError {
279 pos: self.peek_pos(),
280 msg: format!("unexpected token: {:?}", self.peek()),
281 }),
282 }
283 }
284}
285
286fn build_function_call(name: &str, args: Vec<E>, bag: Option<&FunctionBag>) -> Result<E, ParseError> {
287 let lookup_name = if name == "H" { "heaviside" } else { name };
289 if let Some(bag) = bag
291 && let Some(result) = bag.call(lookup_name, &args)
292 {
293 return result.map_err(|msg| ParseError { pos: 0, msg });
294 }
295 let fnref = crate::function_by_name(lookup_name).ok_or_else(|| ParseError {
296 pos: 0,
297 msg: format!("unknown function: {name}"),
298 })?;
299 match fnref {
300 crate::FunctionRef::Unary(f) => {
301 if args.len() != 1 {
302 return Err(ParseError {
303 pos: 0,
304 msg: format!("{name} expects 1 argument, got {}", args.len()),
305 });
306 }
307 Ok(f(args.into_iter().next().unwrap()))
308 }
309 crate::FunctionRef::Binary(f) => {
310 if args.len() != 2 {
311 return Err(ParseError {
312 pos: 0,
313 msg: format!("{name} expects 2 arguments, got {}", args.len()),
314 });
315 }
316 let mut it = args.into_iter();
317 Ok(f(it.next().unwrap(), it.next().unwrap()))
318 }
319 crate::FunctionRef::Ternary(f) => {
320 if args.len() != 3 {
321 return Err(ParseError {
322 pos: 0,
323 msg: format!("{name} expects 3 arguments, got {}", args.len()),
324 });
325 }
326 let mut it = args.into_iter();
327 Ok(f(it.next().unwrap(), it.next().unwrap(), it.next().unwrap()))
328 }
329 }
330}
331
332pub fn parse(input: &str) -> Result<E, ParseError> {
353 parse_with_functions(input, &FunctionBag::new())
354}
355
356pub fn parse_with_functions(input: &str, bag: &FunctionBag) -> Result<E, ParseError> {
393 let mut parser = Parser::from_str(input)?;
394 parser.bag = Some(bag);
395 let expr = parser.parse_expr()?;
396 if *parser.peek() != Token::Eof {
397 return Err(ParseError {
398 pos: parser.peek_pos(),
399 msg: format!("unexpected token after expression: {:?}", parser.peek()),
400 });
401 }
402 Ok(expr)
403}
404
405impl std::str::FromStr for E {
406 type Err = ParseError;
407 fn from_str(s: &str) -> Result<E, ParseError> {
408 parse(s)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{constant, simple_func1, symbol};
    use std::collections::HashMap;

    /// Variable environment with no bindings, for closed expressions.
    fn no_vars() -> HashMap<&'static str, f64> {
        HashMap::new()
    }

    /// Panics unless `a` and `b` differ by strictly less than `tol`.
    fn approx(a: f64, b: f64, tol: f64) {
        let gap = (a - b).abs();
        assert!(gap < tol, "{a} !~= {b} (tol {tol})");
    }

    /// Evaluates an expression that needs no variable bindings.
    fn eval_closed(expr: E) -> f64 {
        expr.eval(&no_vars()).unwrap()
    }

    /// Parses `src` without user functions and checks its value.
    fn check(src: &str, expected: f64, tol: f64) {
        approx(eval_closed(parse(src).unwrap()), expected, tol);
    }

    /// Parses `src` against `bag` and checks its value.
    fn check_with(bag: &FunctionBag, src: &str, expected: f64, tol: f64) {
        approx(eval_closed(parse_with_functions(src, bag).unwrap()), expected, tol);
    }

    /// A bag holding only the symbolic function `sq(t) = t*t`.
    fn sq_bag() -> FunctionBag {
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sq", vec!["t".into()], parse("t*t").unwrap());
        bag
    }

    // ---- plain parsing and built-in functions ----

    #[test]
    fn parse_arithmetic() {
        check("1 + 2 * 3", 7.0, 1e-12);
    }

    #[test]
    fn parse_builtin_unary() {
        check("sin(0) + cos(0)", 1.0, 1e-12);
    }

    #[test]
    fn parse_builtin_binary_atan2() {
        check("atan2(1, 1)", std::f64::consts::FRAC_PI_4, 1e-12);
    }

    #[test]
    fn parse_builtin_sqrt_square_roundtrip() {
        check("sqrt(2) * sqrt(2)", 2.0, 1e-10);
    }

    #[test]
    fn parse_builtin_ternary_clamp() {
        check("clamp(5, 0, 1)", 1.0, 1e-12);
    }

    #[test]
    fn parse_heaviside_h_alias() {
        check("heaviside(0.5) + H(0.5)", 2.0, 1e-12);
    }

    #[test]
    fn parse_rejects_unknown_function() {
        let err = parse("nope(x)").unwrap_err();
        assert!(err.msg.contains("unknown function"), "{err}");
    }

    #[test]
    fn parse_rejects_wrong_arity() {
        let err = parse("sin(1, 2)").unwrap_err();
        assert!(err.msg.contains("1 argument"), "{err}");
    }

    #[test]
    fn parse_scientific_notation() {
        check("1e3", 1000.0, 1e-12);
        check("1e-12", 1e-12, 1e-20);
        check("2.5E+2", 250.0, 1e-12);

        // An exponent literal inside a larger expression that uses a variable.
        let with_x = parse("1.0 - x * x + 1e-12").unwrap();
        let env: HashMap<&'static str, f64> = HashMap::from([("x", 0.0)]);
        approx(with_x.eval(&env).unwrap(), 1.0 + 1e-12, 1e-20);

        // `exp` begins with `e` but must still lex as an identifier.
        check("2 * exp(0)", 2.0, 1e-12);
    }

    #[test]
    fn parse_rejects_bare_e_after_number() {
        let err = parse("2e").unwrap_err();
        let text = &err.msg;
        assert!(text.contains("unknown") || text.contains("unexpected"), "{err}");
    }

    // ---- parsing with a user function bag ----

    #[test]
    fn parse_with_functions_empty_bag_falls_through_to_builtins() {
        check_with(&FunctionBag::new(), "sin(0) + 1", 1.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_user_symbolic_call() {
        check_with(&sq_bag(), "sq(2.0)", 4.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_unknown_in_empty_bag_fails() {
        let empty = FunctionBag::new();
        let err = parse_with_functions("sq(1)", &empty).unwrap_err();
        assert!(err.msg.contains("unknown function"), "{err}");
    }

    #[test]
    fn parse_with_functions_shadows_builtin() {
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sin", vec!["x".into()], constant(7.0));
        check_with(&bag, "sin(0.5)", 7.0, 1e-12);
    }

    #[test]
    fn parse_with_functions_h_alias_still_works() {
        check_with(&FunctionBag::new(), "H(0.5)", 1.0, 1e-12);
    }

    // ---- FunctionBag registration helpers ----

    #[test]
    fn bag_add_e_func_round_trip() {
        let sq_e = simple_func1("sq", |v| v.clone() * v)(symbol("t"));
        let mut bag = FunctionBag::new();
        bag.add(sq_e).unwrap();
        check_with(&bag, "sq(3)", 9.0, 1e-12);
    }

    #[test]
    fn bag_add1_unary_closure() {
        let mut bag = FunctionBag::new();
        bag.add1(simple_func1("sq", |v| v.clone() * v)).unwrap();
        check_with(&bag, "sq(4)", 16.0, 1e-12);
    }

    #[test]
    fn bag_add2_binary_closure() {
        let mut bag = FunctionBag::new();
        bag.add2(simple_func2("hypot", |p, q| {
            crate::sqrt(p.clone() * p + q.clone() * q)
        }))
        .unwrap();
        check_with(&bag, "hypot(3, 4)", 5.0, 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn bag_addN_quaternary_closure() {
        let mut bag = FunctionBag::new();
        bag.addN(
            4,
            crate::simple_func("blend", 4, |parts: Vec<E>| {
                parts[0].clone() + parts[1].clone() + parts[2].clone() + parts[3].clone()
            }),
        )
        .unwrap();
        check_with(&bag, "blend(1, 2, 3, 4)", 10.0, 1e-12);
    }

    #[test]
    fn bag_add_rejects_non_func() {
        let mut bag = FunctionBag::new();
        let err = bag.add(constant(1.0)).unwrap_err();
        assert!(err.contains("expected Expr::Func"), "{err}");
    }

    #[test]
    fn parse_with_functions_rejects_wrong_arity() {
        let bag = sq_bag();
        let err = parse_with_functions("sq(1, 2)", &bag).unwrap_err();
        assert!(err.msg.contains("1 argument"), "{err}");
    }

    #[test]
    fn parameter_shadowing() {
        // The parameter `x` must not pick up the caller's binding for `x`.
        let mut bag = FunctionBag::new();
        bag.add_symbolic("sq", vec!["x".into()], parse("x*x").unwrap());
        let call = parse_with_functions("sq(3)", &bag).unwrap();
        let vars: HashMap<&str, f64> = HashMap::from([("x", 5.0)]);
        approx(call.eval(&vars).unwrap(), 9.0, 1e-12);
    }

    #[test]
    fn chained_user_functions_compose() {
        let mut bag = sq_bag();
        // `mag` is defined using the `sq` already in the bag.
        let mag_body = parse_with_functions("sqrt(sq(a) + sq(b))", &bag).unwrap();
        bag.add_symbolic("mag", vec!["a".into(), "b".into()], mag_body);
        check_with(&bag, "mag(3, 4)", 5.0, 1e-10);
    }

    #[test]
    fn bag_remove_and_contains() {
        let mut bag = sq_bag();
        assert!(bag.contains("sq"));
        assert!(bag.remove("sq"));
        assert!(!bag.contains("sq"));
        // A second removal finds nothing left to remove.
        assert!(!bag.remove("sq"));
    }

    #[test]
    fn bag_names_and_entries() {
        let mut bag = sq_bag();
        bag.add_symbolic("mag", vec!["a".into(), "b".into()], parse("a+b").unwrap());

        // Both listings are sorted before comparing.
        let mut names = bag.names();
        names.sort();
        assert_eq!(names, vec!["mag".to_string(), "sq".to_string()]);

        let mut entries: Vec<(String, usize)> = bag
            .entries()
            .map(|(fn_name, arity)| (fn_name.to_string(), arity))
            .collect();
        entries.sort();
        assert_eq!(entries, vec![("mag".to_string(), 2), ("sq".to_string(), 1)]);
    }
}