1use std::collections::HashMap;
38
39use crate::errors::AlkahestError;
40use crate::kernel::{Domain, ExprId, ExprPool};
41
/// Error produced when an input string cannot be parsed into an expression.
///
/// The stable error code (`E-PARSE-001`..`E-PARSE-003`) and a remediation
/// hint are exposed through this type's `AlkahestError` implementation.
#[derive(Debug, Clone)]
pub struct ParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Byte range `(start, end)` of the offending input, when known.
    pub span: Option<(usize, usize)>,
    /// Error category selector: 1 = lexical, 2 = syntax, anything else =
    /// unknown function (mapped to `E-PARSE-001`, `-002`, `-003`).
    code_idx: u8,
}
56
57impl ParseError {
58 fn lex(msg: impl Into<String>, span: (usize, usize)) -> Self {
59 ParseError {
60 message: msg.into(),
61 span: Some(span),
62 code_idx: 1,
63 }
64 }
65
66 fn syntax(msg: impl Into<String>, span: (usize, usize)) -> Self {
67 ParseError {
68 message: msg.into(),
69 span: Some(span),
70 code_idx: 2,
71 }
72 }
73
74 fn unknown_func(msg: impl Into<String>, span: (usize, usize)) -> Self {
75 ParseError {
76 message: msg.into(),
77 span: Some(span),
78 code_idx: 3,
79 }
80 }
81}
82
83impl std::fmt::Display for ParseError {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "[{}] {}", self.code(), self.message)?;
86 if let Some((s, e)) = self.span {
87 write!(f, " (bytes {s}–{e})")?;
88 }
89 Ok(())
90 }
91}
92
93impl std::error::Error for ParseError {}
94
95impl AlkahestError for ParseError {
96 fn code(&self) -> &'static str {
97 match self.code_idx {
98 1 => "E-PARSE-001",
99 2 => "E-PARSE-002",
100 _ => "E-PARSE-003",
101 }
102 }
103
104 fn remediation(&self) -> Option<&'static str> {
105 match self.code_idx {
106 1 => Some("only ASCII arithmetic expressions are supported"),
107 2 => Some("check parentheses and operator placement"),
108 _ => Some("use a known function: sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, atan2, exp, log, sqrt, abs, sign, floor, ceil, round, erf, erfc, gamma"),
109 }
110 }
111
112 fn span(&self) -> Option<(usize, usize)> {
113 self.span
114 }
115}
116
/// Kind of a lexed token.
///
/// The `Debug` form of a variant (e.g. `RParen`) is used verbatim in some
/// error messages.
#[derive(Debug, Clone, PartialEq)]
enum Tok {
    /// Numeric literal text, e.g. `42`, `3.14`, `1e-3`.
    Num(String),
    /// Identifier text: `[A-Za-z_][A-Za-z0-9_]*`.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `**` (alternative power operator)
    StarStar,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End-of-input sentinel appended after the last real token.
    Eof,
}
136
137#[derive(Debug, Clone)]
138struct Token {
139 tok: Tok,
140 offset: usize, }
142
143fn tokenize(src: &str) -> Result<Vec<Token>, ParseError> {
148 let bytes = src.as_bytes();
149 let n = bytes.len();
150 let mut pos = 0;
151 let mut tokens = Vec::new();
152
153 while pos < n {
154 let b = bytes[pos];
155
156 if b == b' ' || b == b'\t' || b == b'\r' || b == b'\n' {
158 pos += 1;
159 continue;
160 }
161
162 if b.is_ascii_digit() || (b == b'.' && pos + 1 < n && bytes[pos + 1].is_ascii_digit()) {
164 let start = pos;
165 while pos < n && bytes[pos].is_ascii_digit() {
166 pos += 1;
167 }
168 if pos < n && bytes[pos] == b'.' {
169 pos += 1;
170 while pos < n && bytes[pos].is_ascii_digit() {
171 pos += 1;
172 }
173 }
174 if pos < n && (bytes[pos] == b'e' || bytes[pos] == b'E') {
175 pos += 1;
176 if pos < n && (bytes[pos] == b'+' || bytes[pos] == b'-') {
177 pos += 1;
178 }
179 while pos < n && bytes[pos].is_ascii_digit() {
180 pos += 1;
181 }
182 }
183 tokens.push(Token {
184 tok: Tok::Num(src[start..pos].to_owned()),
185 offset: start,
186 });
187 continue;
188 }
189
190 if b.is_ascii_alphabetic() || b == b'_' {
192 let start = pos;
193 while pos < n && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_') {
194 pos += 1;
195 }
196 tokens.push(Token {
197 tok: Tok::Ident(src[start..pos].to_owned()),
198 offset: start,
199 });
200 continue;
201 }
202
203 if b == b'*' && pos + 1 < n && bytes[pos + 1] == b'*' {
205 tokens.push(Token {
206 tok: Tok::StarStar,
207 offset: pos,
208 });
209 pos += 2;
210 continue;
211 }
212
213 let tok = match b {
214 b'+' => Tok::Plus,
215 b'-' => Tok::Minus,
216 b'*' => Tok::Star,
217 b'/' => Tok::Slash,
218 b'^' => Tok::Caret,
219 b'(' => Tok::LParen,
220 b')' => Tok::RParen,
221 b',' => Tok::Comma,
222 _ => {
223 return Err(ParseError::lex(
224 format!("unexpected character {:?}", b as char),
225 (pos, pos + 1),
226 ))
227 }
228 };
229 tokens.push(Token { tok, offset: pos });
230 pos += 1;
231 }
232
233 tokens.push(Token {
234 tok: Tok::Eof,
235 offset: n,
236 });
237 Ok(tokens)
238}
239
/// Binding power of binary `+` and `-` (lowest; left-associative).
const BP_ADD: u8 = 10;
/// Binding power of binary `*` and `/` (left-associative).
const BP_MUL: u8 = 20;
/// Binding power at which the operand of unary `+`/`-` is parsed.
///
/// It lies between `BP_MUL` and `BP_POW`, so `-x^2` groups as `-(x^2)` while
/// `-x*y` groups as `(-x)*y`.
const BP_UNARY: u8 = 25;
/// Binding power of `^` / `**` (highest; right-associative).
const BP_POW: u8 = 30;
248
249fn infix_bp(tok: &Tok) -> u8 {
250 match tok {
251 Tok::Plus | Tok::Minus => BP_ADD,
252 Tok::Star | Tok::Slash => BP_MUL,
253 Tok::Caret | Tok::StarStar => BP_POW,
254 _ => 0,
255 }
256}
257
/// Function names accepted in call syntax `name(...)`.
///
/// Any other identifier followed by `(` is rejected with `E-PARSE-003`.
/// The remediation hint for that error lists these same names, so keep the
/// two in sync.
const KNOWN_FUNCS: &[&str] = &[
    // trigonometric and hyperbolic
    "sin", "cos", "tan", "sinh", "cosh", "tanh",
    // inverse trigonometric
    "asin", "acos", "atan", "atan2",
    // exponential, logarithm and roots
    "exp", "log", "sqrt",
    // piecewise / rounding
    "abs", "sign", "floor", "ceil", "round",
    // special functions
    "erf", "erfc", "gamma",
];
266
267fn is_known_func(name: &str) -> bool {
268 KNOWN_FUNCS.contains(&name)
269}
270
271struct Parser<'a> {
276 tokens: Vec<Token>,
277 pos: usize,
278 pool: &'a ExprPool,
279 symbols: &'a mut HashMap<String, ExprId>,
280}
281
282impl<'a> Parser<'a> {
283 fn new(
284 tokens: Vec<Token>,
285 pool: &'a ExprPool,
286 symbols: &'a mut HashMap<String, ExprId>,
287 ) -> Self {
288 Parser {
289 tokens,
290 pos: 0,
291 pool,
292 symbols,
293 }
294 }
295
296 fn peek(&self) -> &Token {
297 &self.tokens[self.pos]
298 }
299
300 fn advance(&mut self) -> Token {
301 let tok = self.tokens[self.pos].clone();
302 if tok.tok != Tok::Eof {
303 self.pos += 1;
304 }
305 tok
306 }
307
308 fn expect(&mut self, expected: &Tok) -> Result<Token, ParseError> {
309 let tok = self.advance();
310 if &tok.tok == expected {
311 Ok(tok)
312 } else {
313 let label = format!("{expected:?}");
314 if tok.tok == Tok::Eof {
315 Err(ParseError::syntax(
316 format!("expected {label} but reached end of input"),
317 (tok.offset, tok.offset),
318 ))
319 } else {
320 Err(ParseError::syntax(
321 format!("expected {label}"),
322 (tok.offset, tok.offset + 1),
323 ))
324 }
325 }
326 }
327
328 fn parse_expr(&mut self, rbp: u8) -> Result<ExprId, ParseError> {
329 let tok = self.advance();
330 let mut left = self.nud(tok)?;
331 loop {
332 let lbp = infix_bp(&self.peek().tok);
333 if lbp <= rbp {
334 break;
335 }
336 let op = self.advance();
337 left = self.led(op, left)?;
338 }
339 Ok(left)
340 }
341
342 fn nud(&mut self, tok: Token) -> Result<ExprId, ParseError> {
344 let pool = self.pool;
345 match &tok.tok {
346 Tok::Num(s) => {
347 let s = s.clone();
348 if s.contains('.') || s.to_ascii_lowercase().contains('e') {
349 Ok(pool.float(s.parse::<f64>().unwrap(), 53))
350 } else {
351 let n: i64 = s.parse().map_err(|_| {
352 ParseError::lex(
353 format!("integer literal out of range: {s}"),
354 (tok.offset, tok.offset + s.len()),
355 )
356 })?;
357 Ok(pool.integer(n))
358 }
359 }
360
361 Tok::Ident(name) => {
362 let name = name.clone();
363 if self.peek().tok == Tok::LParen {
364 self.parse_funcall(&name, tok.offset)
365 } else {
366 let id = if let Some(&id) = self.symbols.get(&name) {
368 id
369 } else {
370 let id = pool.symbol(name.clone(), Domain::Real);
371 self.symbols.insert(name, id);
372 id
373 };
374 Ok(id)
375 }
376 }
377
378 Tok::Minus => {
379 let operand = self.parse_expr(BP_UNARY)?;
380 let neg1 = self.pool.integer(-1i64);
382 Ok(self.pool.mul(vec![neg1, operand]))
383 }
384
385 Tok::Plus => self.parse_expr(BP_UNARY),
386
387 Tok::LParen => {
388 if self.peek().tok == Tok::RParen {
389 return Err(ParseError::syntax(
390 "empty parentheses",
391 (tok.offset, tok.offset + 1),
392 ));
393 }
394 let inner = self.parse_expr(0)?;
395 self.expect(&Tok::RParen)?;
396 Ok(inner)
397 }
398
399 other => Err(ParseError::syntax(
400 format!("unexpected token {other:?}"),
401 (tok.offset, tok.offset + 1),
402 )),
403 }
404 }
405
406 fn led(&mut self, op: Token, left: ExprId) -> Result<ExprId, ParseError> {
408 let pool = self.pool;
409 match op.tok {
410 Tok::Plus => {
411 let right = self.parse_expr(BP_ADD)?;
412 Ok(pool.add(vec![left, right]))
413 }
414 Tok::Minus => {
415 let right = self.parse_expr(BP_ADD)?;
416 let neg1 = pool.integer(-1i64);
418 let neg_right = pool.mul(vec![neg1, right]);
419 Ok(pool.add(vec![left, neg_right]))
420 }
421 Tok::Star => {
422 let right = self.parse_expr(BP_MUL)?;
423 Ok(pool.mul(vec![left, right]))
424 }
425 Tok::Slash => {
426 let right = self.parse_expr(BP_MUL)?;
427 let neg1 = pool.integer(-1i64);
429 let inv = pool.pow(right, neg1);
430 Ok(pool.mul(vec![left, inv]))
431 }
432 Tok::Caret | Tok::StarStar => {
433 let right = self.parse_expr(BP_POW - 1)?;
435 Ok(pool.pow(left, right))
436 }
437 other => Err(ParseError::syntax(
438 format!("unexpected token {other:?} in infix position"),
439 (op.offset, op.offset + 1),
440 )),
441 }
442 }
443
444 fn parse_funcall(&mut self, name: &str, offset: usize) -> Result<ExprId, ParseError> {
445 if !is_known_func(name) {
446 return Err(ParseError::unknown_func(
447 format!("unknown function '{name}'"),
448 (offset, offset + name.len()),
449 ));
450 }
451 self.advance(); let mut args = Vec::new();
453 if self.peek().tok != Tok::RParen {
454 args.push(self.parse_expr(0)?);
455 while self.peek().tok == Tok::Comma {
456 self.advance(); args.push(self.parse_expr(0)?);
458 }
459 }
460 self.expect(&Tok::RParen)?;
461 Ok(self.pool.func(name, args))
462 }
463}
464
465pub fn parse(
501 src: &str,
502 pool: &ExprPool,
503 symbols: &mut HashMap<String, ExprId>,
504) -> Result<ExprId, ParseError> {
505 let tokens = tokenize(src)?;
506 let first = &tokens[0];
507 if first.tok == Tok::Eof {
508 return Err(ParseError::syntax("empty expression", (0, 0)));
509 }
510 let mut parser = Parser::new(tokens, pool, symbols);
511 let expr = parser.parse_expr(0)?;
512 let tail = parser.peek();
513 if tail.tok != Tok::Eof {
514 let off = tail.offset;
515 return Err(ParseError::syntax(
516 format!("unexpected token {:?}", tail.tok),
517 (off, off + 1),
518 ));
519 }
520 Ok(expr)
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh pool plus a symbol table pre-seeded with the real symbol `x`.
    fn pool_and_x() -> (ExprPool, ExprId, HashMap<String, ExprId>) {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let syms = HashMap::from([("x".to_owned(), x)]);
        (pool, x, syms)
    }

    #[test]
    fn integer_literal() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let e = parse("42", &pool, &mut syms).unwrap();
        assert_eq!(e, pool.integer(42i64));
    }

    #[test]
    fn float_literal() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        parse("3.14", &pool, &mut syms).unwrap();
    }

    #[test]
    fn identifier_symbol() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("x", &pool, &mut syms).unwrap();
        assert_eq!(e, x);
    }

    #[test]
    fn addition() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("x + 1", &pool, &mut syms).unwrap();
        let expected = pool.add(vec![x, pool.integer(1i64)]);
        assert_eq!(e, expected);
    }

    #[test]
    fn unary_minus() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("-x", &pool, &mut syms).unwrap();
        let neg1 = pool.integer(-1i64);
        let expected = pool.mul(vec![neg1, x]);
        assert_eq!(e, expected);
    }

    #[test]
    fn power_right_assoc() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let e = parse("2^3^2", &pool, &mut syms).unwrap();
        // 2^3^2 must group as 2^(3^2).
        let two = pool.integer(2i64);
        let three = pool.integer(3i64);
        let inner = pool.pow(three, two);
        let expected = pool.pow(two, inner);
        assert_eq!(e, expected);
    }

    #[test]
    fn function_call() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("sin(x)", &pool, &mut syms).unwrap();
        let expected = pool.func("sin", vec![x]);
        assert_eq!(e, expected);
    }

    #[test]
    fn atan2_two_args() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        parse("atan2(1, 2)", &pool, &mut syms).unwrap();
    }

    #[test]
    fn unknown_function_error() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let err = parse("foo(x)", &pool, &mut syms).unwrap_err();
        assert_eq!(err.code(), "E-PARSE-003");
    }

    #[test]
    fn unknown_function_span_covers_name() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let err = parse("foo(x)", &pool, &mut syms).unwrap_err();
        assert_eq!(err.span, Some((0, 3)));
    }

    #[test]
    fn lex_error() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let err = parse("x # y", &pool, &mut syms).unwrap_err();
        assert_eq!(err.code(), "E-PARSE-001");
    }

    #[test]
    fn lex_error_span() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let err = parse("x # y", &pool, &mut syms).unwrap_err();
        assert_eq!(err.span, Some((2, 3)));
    }

    #[test]
    fn empty_expression_error() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        let err = parse("", &pool, &mut syms).unwrap_err();
        assert_eq!(err.code(), "E-PARSE-002");
    }

    #[test]
    fn unclosed_paren_error() {
        let (pool, _x, mut syms) = pool_and_x();
        let err = parse("(x + 1", &pool, &mut syms).unwrap_err();
        assert_eq!(err.code(), "E-PARSE-002");
    }

    #[test]
    fn trailing_token_error() {
        let (pool, _x, mut syms) = pool_and_x();
        let err = parse("x y", &pool, &mut syms).unwrap_err();
        assert_eq!(err.code(), "E-PARSE-002");
    }

    #[test]
    fn multiplication_binds_tighter_than_addition() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("2 * x + 1", &pool, &mut syms).unwrap();
        let two_x = pool.mul(vec![pool.integer(2i64), x]);
        let expected = pool.add(vec![two_x, pool.integer(1i64)]);
        assert_eq!(e, expected);
    }

    #[test]
    fn subtraction_is_addition_of_negation() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("x - 1", &pool, &mut syms).unwrap();
        let neg_one = pool.mul(vec![pool.integer(-1i64), pool.integer(1i64)]);
        let expected = pool.add(vec![x, neg_one]);
        assert_eq!(e, expected);
    }

    #[test]
    fn division_is_multiplication_by_inverse() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("x / 2", &pool, &mut syms).unwrap();
        let inv = pool.pow(pool.integer(2i64), pool.integer(-1i64));
        let expected = pool.mul(vec![x, inv]);
        assert_eq!(e, expected);
    }

    #[test]
    fn double_star_is_power() {
        let (pool, x, mut syms) = pool_and_x();
        let starred = parse("x ** 2", &pool, &mut syms).unwrap();
        let caret = parse("x ^ 2", &pool, &mut syms).unwrap();
        assert_eq!(starred, caret);
        assert_eq!(starred, pool.pow(x, pool.integer(2i64)));
    }

    #[test]
    fn unary_minus_binds_looser_than_power() {
        let (pool, x, mut syms) = pool_and_x();
        let e = parse("-x^2", &pool, &mut syms).unwrap();
        let x_sq = pool.pow(x, pool.integer(2i64));
        let expected = pool.mul(vec![pool.integer(-1i64), x_sq]);
        assert_eq!(e, expected);
    }

    #[test]
    fn auto_intern_new_symbol() {
        let pool = ExprPool::new();
        let mut syms = HashMap::new();
        parse("y + 1", &pool, &mut syms).unwrap();
        assert!(syms.contains_key("y"));
    }
}