Skip to main content

alkahest_cas/
parse.rs

1//! V2-21 — Pratt recursive-descent expression parser (Rust port).
2//!
3//! Mirrors `python/alkahest/_parse.py` exactly: same grammar, same function
4//! names, same precedence levels.  The Python layer can delegate to this once
5//! the PyO3 binding is wired up.
6//!
7//! # Grammar (informal)
8//!
9//! ```text
10//! expr     ::= term (('+' | '-') term)*
//! term     ::= unary (('*' | '/') unary)*
//! unary    ::= ('-' | '+') unary | factor
//! factor   ::= primary (('^' | '**') unary)?   -- right-assoc
//! primary  ::= NUMBER | IDENT | IDENT '(' args? ')' | '(' expr ')'
15//! args     ::= expr (',' expr)*
16//! ```
17//!
18//! Binding powers (Pratt):
19//! - `+` / `-` infix: 10
20//! - `*` / `/` infix: 20
21//! - `^` / `**` infix: 30 (right-associative: right-bp = 29)
22//! - unary `-` / `+`: 25
23//!
24//! # Example
25//!
26//! ```
27//! use alkahest_cas::{ExprPool, parse};
28//! use alkahest_cas::kernel::Domain;
29//! use std::collections::HashMap;
30//!
31//! let pool = ExprPool::new();
32//! let x = pool.symbol("x", Domain::Real);
33//! let mut syms = HashMap::from([("x".to_owned(), x)]);
34//! let e = parse("x^2 + 2*x + 1", &pool, &mut syms).unwrap();
35//! ```
36
37use std::collections::HashMap;
38
39use crate::errors::AlkahestError;
40use crate::kernel::{Domain, ExprId, ExprPool};
41
42// ---------------------------------------------------------------------------
43// Error type
44// ---------------------------------------------------------------------------
45
/// Error returned by [`parse`] when the input cannot be lexed or parsed.
///
/// Each error reports one of the diagnostic codes `E-PARSE-001`,
/// `E-PARSE-002` or `E-PARSE-003`, selected by the private `code_idx` field,
/// plus an optional byte-offset span into the source string.
#[derive(Debug, Clone)]
pub struct ParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Byte offsets `(start, end)` into the source string, when known.
    pub span: Option<(usize, usize)>,
    /// Diagnostic-code selector: 1 = E-PARSE-001, 2 = E-PARSE-002,
    /// 3 = E-PARSE-003.
    code_idx: u8,
}
56
57impl ParseError {
58    fn lex(msg: impl Into<String>, span: (usize, usize)) -> Self {
59        ParseError {
60            message: msg.into(),
61            span: Some(span),
62            code_idx: 1,
63        }
64    }
65
66    fn syntax(msg: impl Into<String>, span: (usize, usize)) -> Self {
67        ParseError {
68            message: msg.into(),
69            span: Some(span),
70            code_idx: 2,
71        }
72    }
73
74    fn unknown_func(msg: impl Into<String>, span: (usize, usize)) -> Self {
75        ParseError {
76            message: msg.into(),
77            span: Some(span),
78            code_idx: 3,
79        }
80    }
81}
82
83impl std::fmt::Display for ParseError {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "[{}] {}", self.code(), self.message)?;
86        if let Some((s, e)) = self.span {
87            write!(f, " (bytes {s}–{e})")?;
88        }
89        Ok(())
90    }
91}
92
93impl std::error::Error for ParseError {}
94
95impl AlkahestError for ParseError {
96    fn code(&self) -> &'static str {
97        match self.code_idx {
98            1 => "E-PARSE-001",
99            2 => "E-PARSE-002",
100            _ => "E-PARSE-003",
101        }
102    }
103
104    fn remediation(&self) -> Option<&'static str> {
105        match self.code_idx {
106            1 => Some("only ASCII arithmetic expressions are supported"),
107            2 => Some("check parentheses and operator placement"),
108            _ => Some("use a known function: sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, atan2, exp, log, sqrt, abs, sign, floor, ceil, round, erf, erfc, gamma"),
109        }
110    }
111
112    fn span(&self) -> Option<(usize, usize)> {
113        self.span
114    }
115}
116
117// ---------------------------------------------------------------------------
118// Token
119// ---------------------------------------------------------------------------
120
/// Kind of a lexical token, with the source text attached where it matters.
#[derive(Debug, Clone, PartialEq)]
enum Tok {
    /// Integer or float literal, kept as its source text.
    Num(String),
    /// Identifier or function name.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `**`
    StarStar,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End-of-input marker.
    Eof,
}
136
137#[derive(Debug, Clone)]
138struct Token {
139    tok: Tok,
140    offset: usize, // byte offset in source
141}
142
143// ---------------------------------------------------------------------------
144// Lexer
145// ---------------------------------------------------------------------------
146
147fn tokenize(src: &str) -> Result<Vec<Token>, ParseError> {
148    let bytes = src.as_bytes();
149    let n = bytes.len();
150    let mut pos = 0;
151    let mut tokens = Vec::new();
152
153    while pos < n {
154        let b = bytes[pos];
155
156        // Whitespace
157        if b == b' ' || b == b'\t' || b == b'\r' || b == b'\n' {
158            pos += 1;
159            continue;
160        }
161
162        // Number: digits optionally followed by '.digits' and/or 'e[+-]digits'
163        if b.is_ascii_digit() || (b == b'.' && pos + 1 < n && bytes[pos + 1].is_ascii_digit()) {
164            let start = pos;
165            while pos < n && bytes[pos].is_ascii_digit() {
166                pos += 1;
167            }
168            if pos < n && bytes[pos] == b'.' {
169                pos += 1;
170                while pos < n && bytes[pos].is_ascii_digit() {
171                    pos += 1;
172                }
173            }
174            if pos < n && (bytes[pos] == b'e' || bytes[pos] == b'E') {
175                pos += 1;
176                if pos < n && (bytes[pos] == b'+' || bytes[pos] == b'-') {
177                    pos += 1;
178                }
179                while pos < n && bytes[pos].is_ascii_digit() {
180                    pos += 1;
181                }
182            }
183            tokens.push(Token {
184                tok: Tok::Num(src[start..pos].to_owned()),
185                offset: start,
186            });
187            continue;
188        }
189
190        // Identifier
191        if b.is_ascii_alphabetic() || b == b'_' {
192            let start = pos;
193            while pos < n && (bytes[pos].is_ascii_alphanumeric() || bytes[pos] == b'_') {
194                pos += 1;
195            }
196            tokens.push(Token {
197                tok: Tok::Ident(src[start..pos].to_owned()),
198                offset: start,
199            });
200            continue;
201        }
202
203        // `**` must come before `*`
204        if b == b'*' && pos + 1 < n && bytes[pos + 1] == b'*' {
205            tokens.push(Token {
206                tok: Tok::StarStar,
207                offset: pos,
208            });
209            pos += 2;
210            continue;
211        }
212
213        let tok = match b {
214            b'+' => Tok::Plus,
215            b'-' => Tok::Minus,
216            b'*' => Tok::Star,
217            b'/' => Tok::Slash,
218            b'^' => Tok::Caret,
219            b'(' => Tok::LParen,
220            b')' => Tok::RParen,
221            b',' => Tok::Comma,
222            _ => {
223                return Err(ParseError::lex(
224                    format!("unexpected character {:?}", b as char),
225                    (pos, pos + 1),
226                ))
227            }
228        };
229        tokens.push(Token { tok, offset: pos });
230        pos += 1;
231    }
232
233    tokens.push(Token {
234        tok: Tok::Eof,
235        offset: n,
236    });
237    Ok(tokens)
238}
239
240// ---------------------------------------------------------------------------
241// Binding powers
242// ---------------------------------------------------------------------------
243
/// Left binding power of infix `+` and `-` (loosest).
const BP_ADD: u8 = 10;
/// Left binding power of infix `*` and `/`.
const BP_MUL: u8 = 20;
/// Left binding power of `^` and `**` (tightest).
const BP_POW: u8 = 30;
/// Binding power used for the operand of unary `-` / `+`; it sits between
/// `BP_MUL` and `BP_POW`.
const BP_UNARY: u8 = 25;
248
249fn infix_bp(tok: &Tok) -> u8 {
250    match tok {
251        Tok::Plus | Tok::Minus => BP_ADD,
252        Tok::Star | Tok::Slash => BP_MUL,
253        Tok::Caret | Tok::StarStar => BP_POW,
254        _ => 0,
255    }
256}
257
258// ---------------------------------------------------------------------------
259// Known function names
260// ---------------------------------------------------------------------------
261
/// Function names accepted in call position.
const KNOWN_FUNCS: &[&str] = &[
    "sin", "cos", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan", "atan2", "exp", "log",
    "sqrt", "abs", "sign", "floor", "ceil", "round", "erf", "erfc", "gamma",
];

/// Whether `name` is one of [`KNOWN_FUNCS`] (exact, case-sensitive match).
fn is_known_func(name: &str) -> bool {
    KNOWN_FUNCS.iter().any(|&known| known == name)
}
270
271// ---------------------------------------------------------------------------
272// Parser
273// ---------------------------------------------------------------------------
274
275struct Parser<'a> {
276    tokens: Vec<Token>,
277    pos: usize,
278    pool: &'a ExprPool,
279    symbols: &'a mut HashMap<String, ExprId>,
280}
281
282impl<'a> Parser<'a> {
283    fn new(
284        tokens: Vec<Token>,
285        pool: &'a ExprPool,
286        symbols: &'a mut HashMap<String, ExprId>,
287    ) -> Self {
288        Parser {
289            tokens,
290            pos: 0,
291            pool,
292            symbols,
293        }
294    }
295
296    fn peek(&self) -> &Token {
297        &self.tokens[self.pos]
298    }
299
300    fn advance(&mut self) -> Token {
301        let tok = self.tokens[self.pos].clone();
302        if tok.tok != Tok::Eof {
303            self.pos += 1;
304        }
305        tok
306    }
307
308    fn expect(&mut self, expected: &Tok) -> Result<Token, ParseError> {
309        let tok = self.advance();
310        if &tok.tok == expected {
311            Ok(tok)
312        } else {
313            let label = format!("{expected:?}");
314            if tok.tok == Tok::Eof {
315                Err(ParseError::syntax(
316                    format!("expected {label} but reached end of input"),
317                    (tok.offset, tok.offset),
318                ))
319            } else {
320                Err(ParseError::syntax(
321                    format!("expected {label}"),
322                    (tok.offset, tok.offset + 1),
323                ))
324            }
325        }
326    }
327
328    fn parse_expr(&mut self, rbp: u8) -> Result<ExprId, ParseError> {
329        let tok = self.advance();
330        let mut left = self.nud(tok)?;
331        loop {
332            let lbp = infix_bp(&self.peek().tok);
333            if lbp <= rbp {
334                break;
335            }
336            let op = self.advance();
337            left = self.led(op, left)?;
338        }
339        Ok(left)
340    }
341
342    /// Null denotation — prefix position / atom.
343    fn nud(&mut self, tok: Token) -> Result<ExprId, ParseError> {
344        let pool = self.pool;
345        match &tok.tok {
346            Tok::Num(s) => {
347                let s = s.clone();
348                if s.contains('.') || s.to_ascii_lowercase().contains('e') {
349                    Ok(pool.float(s.parse::<f64>().unwrap(), 53))
350                } else {
351                    let n: i64 = s.parse().map_err(|_| {
352                        ParseError::lex(
353                            format!("integer literal out of range: {s}"),
354                            (tok.offset, tok.offset + s.len()),
355                        )
356                    })?;
357                    Ok(pool.integer(n))
358                }
359            }
360
361            Tok::Ident(name) => {
362                let name = name.clone();
363                if self.peek().tok == Tok::LParen {
364                    self.parse_funcall(&name, tok.offset)
365                } else {
366                    // Look up in caller-supplied map, or intern a new Real symbol.
367                    let id = if let Some(&id) = self.symbols.get(&name) {
368                        id
369                    } else {
370                        let id = pool.symbol(name.clone(), Domain::Real);
371                        self.symbols.insert(name, id);
372                        id
373                    };
374                    Ok(id)
375                }
376            }
377
378            Tok::Minus => {
379                let operand = self.parse_expr(BP_UNARY)?;
380                // -x  →  (-1) * x
381                let neg1 = self.pool.integer(-1i64);
382                Ok(self.pool.mul(vec![neg1, operand]))
383            }
384
385            Tok::Plus => self.parse_expr(BP_UNARY),
386
387            Tok::LParen => {
388                if self.peek().tok == Tok::RParen {
389                    return Err(ParseError::syntax(
390                        "empty parentheses",
391                        (tok.offset, tok.offset + 1),
392                    ));
393                }
394                let inner = self.parse_expr(0)?;
395                self.expect(&Tok::RParen)?;
396                Ok(inner)
397            }
398
399            other => Err(ParseError::syntax(
400                format!("unexpected token {other:?}"),
401                (tok.offset, tok.offset + 1),
402            )),
403        }
404    }
405
406    /// Left denotation — infix position.
407    fn led(&mut self, op: Token, left: ExprId) -> Result<ExprId, ParseError> {
408        let pool = self.pool;
409        match op.tok {
410            Tok::Plus => {
411                let right = self.parse_expr(BP_ADD)?;
412                Ok(pool.add(vec![left, right]))
413            }
414            Tok::Minus => {
415                let right = self.parse_expr(BP_ADD)?;
416                // left - right  →  left + (-1)*right
417                let neg1 = pool.integer(-1i64);
418                let neg_right = pool.mul(vec![neg1, right]);
419                Ok(pool.add(vec![left, neg_right]))
420            }
421            Tok::Star => {
422                let right = self.parse_expr(BP_MUL)?;
423                Ok(pool.mul(vec![left, right]))
424            }
425            Tok::Slash => {
426                let right = self.parse_expr(BP_MUL)?;
427                // left / right  →  left * right^(-1)
428                let neg1 = pool.integer(-1i64);
429                let inv = pool.pow(right, neg1);
430                Ok(pool.mul(vec![left, inv]))
431            }
432            Tok::Caret | Tok::StarStar => {
433                // Right-associative: right-bp = BP_POW - 1
434                let right = self.parse_expr(BP_POW - 1)?;
435                Ok(pool.pow(left, right))
436            }
437            other => Err(ParseError::syntax(
438                format!("unexpected token {other:?} in infix position"),
439                (op.offset, op.offset + 1),
440            )),
441        }
442    }
443
444    fn parse_funcall(&mut self, name: &str, offset: usize) -> Result<ExprId, ParseError> {
445        if !is_known_func(name) {
446            return Err(ParseError::unknown_func(
447                format!("unknown function '{name}'"),
448                (offset, offset + name.len()),
449            ));
450        }
451        self.advance(); // consume "("
452        let mut args = Vec::new();
453        if self.peek().tok != Tok::RParen {
454            args.push(self.parse_expr(0)?);
455            while self.peek().tok == Tok::Comma {
456                self.advance(); // consume ","
457                args.push(self.parse_expr(0)?);
458            }
459        }
460        self.expect(&Tok::RParen)?;
461        Ok(self.pool.func(name, args))
462    }
463}
464
465// ---------------------------------------------------------------------------
466// Public entry point
467// ---------------------------------------------------------------------------
468
469/// Parse a mathematical expression string into an [`ExprId`].
470///
471/// Uses a Pratt (top-down operator precedence) recursive-descent parser.
472/// The grammar supports integer/float literals, identifiers, arithmetic
473/// operators (`+`, `-`, `*`, `/`, `^`, `**`), unary `-`/`+`, parentheses,
474/// and a fixed set of mathematical functions:
475/// `sin`, `cos`, `tan`, `sinh`, `cosh`, `tanh`, `asin`, `acos`, `atan`,
476/// `atan2`, `exp`, `log`, `sqrt`, `abs`, `sign`, `floor`, `ceil`, `round`,
477/// `erf`, `erfc`, `gamma`.
478///
479/// `symbols` maps identifier names to pre-existing [`ExprId`]s.  Identifiers
480/// not in the map are interned as new `Domain::Real` symbols and added to the
481/// map so they are reused within the same call.
482///
483/// # Errors
484///
485/// Returns [`ParseError`] (`E-PARSE-001` lexical, `E-PARSE-002` syntactic,
486/// `E-PARSE-003` unknown function) on failure, with a byte-offset span.
487///
488/// # Example
489///
490/// ```
491/// use alkahest_cas::{ExprPool, parse};
492/// use alkahest_cas::kernel::Domain;
493/// use std::collections::HashMap;
494///
495/// let pool = ExprPool::new();
496/// let x = pool.symbol("x", Domain::Real);
497/// let mut syms = HashMap::from([("x".to_owned(), x)]);
498/// let e = parse("sin(x)^2 + cos(x)^2", &pool, &mut syms).unwrap();
499/// ```
500pub fn parse(
501    src: &str,
502    pool: &ExprPool,
503    symbols: &mut HashMap<String, ExprId>,
504) -> Result<ExprId, ParseError> {
505    let tokens = tokenize(src)?;
506    let first = &tokens[0];
507    if first.tok == Tok::Eof {
508        return Err(ParseError::syntax("empty expression", (0, 0)));
509    }
510    let mut parser = Parser::new(tokens, pool, symbols);
511    let expr = parser.parse_expr(0)?;
512    let tail = parser.peek();
513    if tail.tok != Tok::Eof {
514        let off = tail.offset;
515        return Err(ParseError::syntax(
516            format!("unexpected token {:?}", tail.tok),
517            (off, off + 1),
518        ));
519    }
520    Ok(expr)
521}
522
523// ---------------------------------------------------------------------------
524// Tests
525// ---------------------------------------------------------------------------
526
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty pool paired with an empty symbol table.
    fn blank() -> (ExprPool, HashMap<String, ExprId>) {
        (ExprPool::new(), HashMap::new())
    }

    /// A pool in which `x` is already interned and registered in the table.
    fn pool_and_x() -> (ExprPool, ExprId, HashMap<String, ExprId>) {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let mut syms = HashMap::new();
        syms.insert("x".to_owned(), x);
        (pool, x, syms)
    }

    /// Parse `src` against a blank pool and return the error's diagnostic code.
    fn error_code(src: &str) -> &'static str {
        let (pool, mut syms) = blank();
        parse(src, &pool, &mut syms).unwrap_err().code()
    }

    #[test]
    fn integer_literal() {
        let (pool, mut syms) = blank();
        let parsed = parse("42", &pool, &mut syms).unwrap();
        assert_eq!(parsed, pool.integer(42i64));
    }

    #[test]
    fn float_literal() {
        let (pool, mut syms) = blank();
        parse("3.14", &pool, &mut syms).unwrap();
    }

    #[test]
    fn identifier_symbol() {
        let (pool, x, mut syms) = pool_and_x();
        let parsed = parse("x", &pool, &mut syms).unwrap();
        assert_eq!(parsed, x);
    }

    #[test]
    fn addition() {
        let (pool, x, mut syms) = pool_and_x();
        let parsed = parse("x + 1", &pool, &mut syms).unwrap();
        let one = pool.integer(1i64);
        assert_eq!(parsed, pool.add(vec![x, one]));
    }

    #[test]
    fn unary_minus() {
        let (pool, x, mut syms) = pool_and_x();
        let parsed = parse("-x", &pool, &mut syms).unwrap();
        let minus_one = pool.integer(-1i64);
        assert_eq!(parsed, pool.mul(vec![minus_one, x]));
    }

    #[test]
    fn power_right_assoc() {
        let (pool, mut syms) = blank();
        // 2^3^2 must group as 2^(3^2), not (2^3)^2.
        let parsed = parse("2^3^2", &pool, &mut syms).unwrap();
        let two = pool.integer(2i64);
        let three = pool.integer(3i64);
        let three_squared = pool.pow(three, two);
        assert_eq!(parsed, pool.pow(two, three_squared));
    }

    #[test]
    fn function_call() {
        let (pool, x, mut syms) = pool_and_x();
        let parsed = parse("sin(x)", &pool, &mut syms).unwrap();
        assert_eq!(parsed, pool.func("sin", vec![x]));
    }

    #[test]
    fn atan2_two_args() {
        let (pool, mut syms) = blank();
        parse("atan2(1, 2)", &pool, &mut syms).unwrap();
    }

    #[test]
    fn unknown_function_error() {
        assert_eq!(error_code("foo(x)"), "E-PARSE-003");
    }

    #[test]
    fn lex_error() {
        assert_eq!(error_code("x # y"), "E-PARSE-001");
    }

    #[test]
    fn empty_expression_error() {
        assert_eq!(error_code(""), "E-PARSE-002");
    }

    #[test]
    fn auto_intern_new_symbol() {
        let (pool, mut syms) = blank();
        parse("y + 1", &pool, &mut syms).unwrap();
        assert!(syms.contains_key("y"));
    }
}