mathlex 0.4.1

Mathematical expression parser for LaTeX and plain text notation, producing a language-agnostic AST
Documentation
// Allow large error variants - boxing would be a breaking API change
#![allow(clippy::result_large_err)]

use super::*;

impl LatexParser {
    /// Converts a numeric literal into an integer or float expression.
    ///
    /// A literal containing a `.` is parsed as `f64` and wrapped in
    /// `ExprKind::Float`; any other literal is parsed as `i64` and wrapped in
    /// `ExprKind::Integer`.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::invalid_number` (carrying `span`) when the text
    /// cannot be parsed as the selected numeric type.
    pub(super) fn parse_number(&self, num_str: &str, span: Span) -> ParseResult<Expression> {
        let has_decimal_point = num_str.contains('.');

        // No decimal point: the literal must be a valid `i64`.
        if !has_decimal_point {
            return match num_str.parse::<i64>() {
                Ok(value) => Ok(ExprKind::Integer(value).into()),
                Err(_) => Err(ParseError::invalid_number(
                    num_str,
                    "invalid integer",
                    Some(span),
                )),
            };
        }

        // Decimal point present: the literal must be a valid `f64`.
        match num_str.parse::<f64>() {
            Ok(value) => Ok(ExprKind::Float(MathFloat::from(value)).into()),
            Err(_) => Err(ParseError::invalid_number(
                num_str,
                "invalid float",
                Some(span),
            )),
        }
    }

    /// Parses `\frac{num}{denom}`, promoting to a derivative when the pattern matches.
    ///
    /// Both braced operands are parsed with `in_fraction_context` set to `true`.
    /// The flag's previous value is saved on entry and restored afterwards on
    /// both the success and the error path, so that:
    ///
    /// * a failed operand does not leave the parser stuck in fraction context;
    /// * a nested `\frac` does not clear the flag for the remainder of the
    ///   enclosing fraction.
    ///
    /// After the operands are parsed, `try_parse_derivative` gets the first
    /// chance to interpret them; if it returns `None`, a plain
    /// `BinaryOp::Div` expression is produced.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing either braced operand or from
    /// `try_parse_derivative`.
    pub(super) fn parse_frac_command(&mut self) -> ParseResult<Expression> {
        // Remember the enclosing context so it can be restored exactly.
        let enclosing_context = self.in_fraction_context;
        self.in_fraction_context = true;

        // Parse both operands without early-returning, so the flag below is
        // restored even when one of them fails.
        let operands = self
            .braced(|p| p.parse_expression())
            .and_then(|numerator| {
                self.braced(|p| p.parse_expression())
                    .map(|denominator| (numerator, denominator))
            });

        self.in_fraction_context = enclosing_context;
        let (numerator, denominator) = operands?;

        // The operands are cloned because `try_parse_derivative` takes them by
        // value and they are still needed for the division fallback.
        if let Some(deriv) = self.try_parse_derivative(numerator.clone(), denominator.clone())? {
            return Ok(deriv);
        }

        Ok(ExprKind::Binary {
            op: BinaryOp::Div,
            left: Box::new(numerator),
            right: Box::new(denominator),
        }
        .into())
    }

    /// Parses the argument(s) of `\log` or `\ln`.
    ///
    /// When `is_log` is `true` and an underscore follows, the subscript is
    /// read as the logarithm base and the result is `log(arg, base)`.
    /// Otherwise the result is a one-argument call named `log` (for
    /// `is_log == true`) or `ln` (for `is_log == false`).
    ///
    /// # Errors
    ///
    /// Propagates errors from parsing the base or the argument.
    pub(super) fn parse_log_command(&mut self, is_log: bool) -> ParseResult<Expression> {
        // An explicit base is only recognised for `\log`, never for `\ln`.
        let explicit_base = if is_log && self.check(&LatexToken::Underscore) {
            self.next(); // consume _
            Some(self.parse_braced_or_atom()?)
        } else {
            None
        };

        let argument = self.parse_function_arg()?;

        let (name, args) = match explicit_base {
            // The base is stored after the argument.
            Some(base) => ("log", vec![argument, base]),
            None => (if is_log { "log" } else { "ln" }, vec![argument]),
        };

        Ok(ExprKind::Function {
            name: name.to_string(),
            args,
        }
        .into())
    }

    /// Parses the body and closing delimiter of `\lfloor … \rfloor` or
    /// `\lceil … \rceil`.
    ///
    /// `is_floor` selects which pair is expected. On success the inner
    /// expression is wrapped in a one-argument `floor` or `ceil` function.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner expression, and returns a custom
    /// error at the current span when the matching closing command is
    /// not the next token.
    pub(super) fn parse_floor_ceil_command(&mut self, is_floor: bool) -> ParseResult<Expression> {
        let inner = self.parse_expression()?;

        let closer;
        let function_name;
        let missing_closer_msg;
        if is_floor {
            closer = "rfloor";
            function_name = "floor";
            missing_closer_msg = "expected \\rfloor after \\lfloor";
        } else {
            closer = "rceil";
            function_name = "ceil";
            missing_closer_msg = "expected \\rceil after \\lceil";
        }

        let at_closer = matches!(
            self.peek(),
            Some((LatexToken::Command(c), _)) if c == closer
        );
        if !at_closer {
            return Err(ParseError::custom(
                missing_closer_msg.to_string(),
                Some(self.current_span()),
            ));
        }

        self.next(); // consume the closing command
        Ok(ExprKind::Function {
            name: function_name.to_string(),
            args: vec![inner],
        }
        .into())
    }

    /// Parses `\delta` (Kronecker delta) or `\varepsilon` / `\epsilon`
    /// (Levi-Civita symbol).
    ///
    /// If the upcoming tokens look like tensor indices and at least one
    /// index is parsed, a `KroneckerDelta` (for `cmd == "delta"`) or
    /// `LeviCivita` (for any other `cmd`) expression is returned. In every
    /// other case the command becomes a plain variable named `cmd`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `parse_tensor_indices`.
    pub(super) fn parse_tensor_symbol_command(&mut self, cmd: &str) -> ParseResult<Expression> {
        if !self.looks_like_tensor_index() {
            return Ok(Expression::variable(cmd.to_string()).into());
        }

        let indices = self.parse_tensor_indices()?;
        if indices.is_empty() {
            return Ok(Expression::variable(cmd.to_string()).into());
        }

        let symbol = match cmd {
            "delta" => ExprKind::KroneckerDelta { indices },
            _ => ExprKind::LeviCivita { indices },
        };
        Ok(symbol.into())
    }

    /// Routes a LaTeX command name to the sub-parser that handles it.
    ///
    /// `cmd` is the command name without the leading backslash. `span` is
    /// used for the error raised on an unrecognised command and is forwarded
    /// to `parse_operatorname_command`.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::invalid_latex_command` for a command that is not
    /// listed here, and propagates any error from the selected sub-parser.
    pub(super) fn parse_command(&mut self, cmd: &str, span: Span) -> ParseResult<Expression> {
        match cmd {
            "frac" => self.parse_frac_command(),

            // `\sqrt[n]{x}` becomes root(x, n); `\sqrt{x}` becomes sqrt(x).
            "sqrt" => {
                let degree = if self.check(&LatexToken::LBracket) {
                    Some(self.bracketed(|p| p.parse_expression())?)
                } else {
                    None
                };
                let radicand = self.braced(|p| p.parse_expression())?;
                let (name, args) = match degree {
                    Some(n) => ("root", vec![radicand, n]),
                    None => ("sqrt", vec![radicand]),
                };
                Ok(ExprKind::Function {
                    name: name.to_string(),
                    args,
                }
                .into())
            }

            "delta" | "varepsilon" | "epsilon" => self.parse_tensor_symbol_command(cmd),

            // \pi is the only Greek letter mapped to a constant.
            "pi" => Ok(Expression::constant(MathConstant::Pi).into()),

            // Remaining Greek letters and \partial become variables named
            // after the command.
            "alpha" | "beta" | "gamma" | "zeta" | "eta" | "theta" | "iota" | "kappa" | "lambda"
            | "mu" | "nu" | "xi" | "omicron" | "rho" | "sigma" | "tau" | "upsilon" | "phi"
            | "chi" | "psi" | "omega" | "Gamma" | "Delta" | "Theta" | "Lambda" | "Xi" | "Pi"
            | "Sigma" | "Upsilon" | "Phi" | "Psi" | "Omega" | "partial" => {
                Ok(Expression::variable(cmd.to_string()).into())
            }

            // Functions that take exactly one argument.
            "sin" | "cos" | "tan" | "sec" | "csc" | "cot" | "arcsin" | "arccos" | "arctan"
            | "sinh" | "cosh" | "tanh" | "exp" | "det" | "min" | "max" | "gcd" | "lcm" | "abs"
            | "floor" | "ceil" | "sgn" | "trunc" | "rad" | "deg" => {
                let args = vec![self.parse_function_arg()?];
                Ok(ExprKind::Function {
                    name: cmd.to_string(),
                    args,
                }
                .into())
            }

            // Functions that take three arguments: clamp(x, lo, hi), lerp(a, b, t).
            "clamp" | "lerp" => self.parse_three_arg_function(cmd),

            // Logarithms and delimiter-style functions.
            "log" => self.parse_log_command(true),
            "ln" => self.parse_log_command(false),
            "lfloor" => self.parse_floor_ceil_command(true),
            "lceil" => self.parse_floor_ceil_command(false),
            "operatorname" => self.parse_operatorname_command(span),

            // Calculus and big operators.
            "int" => self.parse_integral(),
            "lim" => self.parse_limit(),
            "sum" => self.parse_sum(),
            "prod" => self.parse_product(),

            unknown => Err(ParseError::invalid_latex_command(unknown, Some(span))),
        }
    }

    /// Parses `\operatorname{name}` followed by its argument as a function call.
    ///
    /// The braced group must consist solely of letter tokens, which are
    /// concatenated to form the function name. The argument is then read
    /// with `parse_function_arg`, yielding a one-argument function.
    ///
    /// # Errors
    ///
    /// * the opening or closing brace is missing;
    /// * an `Eof` token is reached inside the braces;
    /// * the braces contain anything other than letters;
    /// * the name is empty;
    /// * the argument fails to parse.
    ///
    /// Errors raised directly by this method are reported at `span`.
    pub(super) fn parse_operatorname_command(&mut self, span: Span) -> ParseResult<Expression> {
        self.consume(LatexToken::LBrace)?;

        // Accumulate letters until the closing brace (or the end of the
        // token list, in which case the `consume` below reports the problem).
        let mut name = String::new();
        loop {
            match self.peek() {
                Some((LatexToken::Letter(ch), _)) => {
                    let letter = *ch;
                    self.next();
                    name.push(letter);
                }
                Some((LatexToken::RBrace, _)) | None => break,
                Some((LatexToken::Eof, _)) => {
                    return Err(ParseError::unexpected_eof(
                        vec!["operator name"],
                        Some(span),
                    ));
                }
                Some(_) => {
                    return Err(ParseError::custom(
                        r"\operatorname{} must contain a plain name".to_string(),
                        Some(span),
                    ));
                }
            }
        }
        self.consume(LatexToken::RBrace)?;

        if name.is_empty() {
            return Err(ParseError::custom(
                r"\operatorname{} name must not be empty".to_string(),
                Some(span),
            ));
        }

        let args = vec![self.parse_function_arg()?];
        Ok(ExprKind::Function { name, args }.into())
    }

    /// Parses a parenthesized call with exactly three arguments: `\name(a, b, c)`.
    ///
    /// Dispatched for `\clamp` and `\lerp`. The arguments are full
    /// expressions separated by commas; the result is a function named
    /// `name` holding them in source order.
    ///
    /// # Errors
    ///
    /// Fails when a parenthesis or comma is missing, or when any of the
    /// three argument expressions fails to parse.
    pub(super) fn parse_three_arg_function(&mut self, name: &str) -> ParseResult<Expression> {
        const ARG_COUNT: usize = 3;

        self.consume(LatexToken::LParen)?;
        let mut args = Vec::with_capacity(ARG_COUNT);
        for position in 0..ARG_COUNT {
            // A comma precedes every argument except the first.
            if position > 0 {
                self.consume(LatexToken::Comma)?;
            }
            args.push(self.parse_expression()?);
        }
        self.consume(LatexToken::RParen)?;

        Ok(ExprKind::Function {
            name: name.to_string(),
            args,
        }
        .into())
    }

    /// Parses the argument that follows a function command.
    ///
    /// Three forms are accepted:
    /// * `{expr}` — a braced expression;
    /// * `(expr)` — a parenthesized expression;
    /// * anything else — handed to `parse_power`, which allows unbraced
    ///   arguments such as `\sin x`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the chosen sub-parser, including a missing
    /// closing delimiter.
    pub(super) fn parse_function_arg(&mut self) -> ParseResult<Expression> {
        if self.check(&LatexToken::LBrace) {
            return self.braced(|p| p.parse_expression());
        }

        if !self.check(&LatexToken::LParen) {
            // Unbraced simple argument.
            return self.parse_power();
        }

        self.next(); // consume (
        let inner = self.parse_expression()?;
        self.consume(LatexToken::RParen)?;
        Ok(inner)
    }

    /// Parses either a braced expression `{...}` or, when no opening brace
    /// is next, a single primary expression.
    ///
    /// # Errors
    ///
    /// Propagates errors from `braced` or `parse_primary`.
    pub(super) fn parse_braced_or_atom(&mut self) -> ParseResult<Expression> {
        let starts_with_brace = self.check(&LatexToken::LBrace);
        if !starts_with_brace {
            // No brace: take one atom (number, letter, etc.).
            return self.parse_primary();
        }
        self.braced(|p| p.parse_expression())
    }

    /// Helper: runs `parser_fn` between a required `{` and a required `}`.
    ///
    /// Returns whatever `parser_fn` produced once the closing brace has been
    /// consumed.
    ///
    /// # Errors
    ///
    /// Fails if the opening brace is missing, if `parser_fn` fails (in which
    /// case no closing brace is consumed), or if the closing brace is missing.
    pub(super) fn braced<F, T>(&mut self, parser_fn: F) -> ParseResult<T>
    where
        F: FnOnce(&mut Self) -> ParseResult<T>,
    {
        self.consume(LatexToken::LBrace)?;
        parser_fn(self).and_then(|inner| {
            self.consume(LatexToken::RBrace)?;
            Ok(inner)
        })
    }

    /// Helper: runs `parser_fn` between a required `[` and a required `]`.
    ///
    /// Returns whatever `parser_fn` produced once the closing bracket has
    /// been consumed.
    ///
    /// # Errors
    ///
    /// Fails if the opening bracket is missing, if `parser_fn` fails (in
    /// which case no closing bracket is consumed), or if the closing bracket
    /// is missing.
    pub(super) fn bracketed<F, T>(&mut self, parser_fn: F) -> ParseResult<T>
    where
        F: FnOnce(&mut Self) -> ParseResult<T>,
    {
        self.consume(LatexToken::LBracket)?;
        parser_fn(self).and_then(|inner| {
            self.consume(LatexToken::RBracket)?;
            Ok(inner)
        })
    }

    /// Reports whether an upcoming `^` or `_` introduces a tensor index
    /// rather than a power or an ordinary subscript.
    ///
    /// Returns `true` only when the next token is `^` or `_` and the token
    /// at `self.pos + 1` is either:
    /// * a letter (`^i`, `_j`);
    /// * a command (`^\mu`, `_\nu`);
    /// * an opening brace immediately followed by a letter or a command
    ///   (`^{ij}`, `_{kl}`).
    ///
    /// Anything else — for example the number in `\delta^2` — yields `false`.
    /// No tokens are consumed.
    pub(super) fn looks_like_tensor_index(&self) -> bool {
        let at_marker = self.check(&LatexToken::Caret) || self.check(&LatexToken::Underscore);
        if !at_marker {
            return false;
        }

        // Index of the token immediately after the `^` / `_` marker.
        let after_marker = self.pos + 1;
        let Some((following, _)) = self.tokens.get(after_marker) else {
            return false;
        };

        match following {
            // Bare index: a single letter or a command such as \mu.
            LatexToken::Letter(_) | LatexToken::Command(_) => true,
            // Braced group: decide by the first token inside the braces.
            LatexToken::LBrace => match self.tokens.get(after_marker + 1) {
                Some((first_inside, _)) => matches!(
                    first_inside,
                    LatexToken::Letter(_) | LatexToken::Command(_)
                ),
                None => false,
            },
            // Numbers and every other token mean this is not a tensor index.
            _ => false,
        }
    }

    /// Parses the tensor indices that start at the current token.
    ///
    /// Up to three index groups are read, in this order: an upper group
    /// (`^...`), a lower group (`_...`), and one further upper group for
    /// mixed notation such as `T^i_j^k`. Each group may be braced (`^{ij}`)
    /// or a single index (`^i`). The returned vector lists the indices in
    /// source order, tagged `IndexType::Upper` or `IndexType::Lower`.
    ///
    /// The first group encountered is consumed unconditionally. Any later
    /// `^` / `_` is consumed only if `looks_like_tensor_index` accepts it;
    /// otherwise parsing stops and the marker is left in the token stream.
    /// This keeps a trailing power or numeric subscript — the `^2` in
    /// `\delta_{ij}^2`, the `_2` in `\delta^i_2` — from being misread as an
    /// index group and rejected with "expected letter in tensor index".
    ///
    /// # Errors
    ///
    /// Propagates errors from `parse_index_group`.
    pub(super) fn parse_tensor_indices(&mut self) -> ParseResult<Vec<TensorIndex>> {
        let mut indices = Vec::new();
        let mut consumed_group = false;

        // Group order: upper, lower, then upper again (mixed notation).
        for is_upper in [true, false, true] {
            let marker = if is_upper {
                LatexToken::Caret
            } else {
                LatexToken::Underscore
            };
            if !self.check(&marker) {
                continue;
            }

            // After the first group, a marker not followed by a letter or
            // command is not an index group; leave it unconsumed.
            if consumed_group && !self.looks_like_tensor_index() {
                break;
            }

            self.next(); // consume ^ or _
            let index_type = if is_upper {
                IndexType::Upper
            } else {
                IndexType::Lower
            };
            indices.extend(self.parse_index_group(index_type)?);
            consumed_group = true;
        }

        Ok(indices)
    }

    /// Parses one tensor index and tags it with `index_type`.
    ///
    /// A letter token yields an index named by that letter; a command token
    /// yields an index named by the command. The token is consumed on
    /// success only.
    ///
    /// # Errors
    ///
    /// Returns a custom error at the offending token's span for any other
    /// token, and `ParseError::unexpected_eof` when no token is available.
    fn parse_single_tensor_index(&mut self, index_type: IndexType) -> ParseResult<TensorIndex> {
        // Work out the index name first; the token is consumed afterwards.
        let name = match self.peek() {
            Some((LatexToken::Letter(letter), _)) => letter.to_string(),
            Some((LatexToken::Command(command), _)) => command.clone(),
            Some((_, span)) => {
                return Err(ParseError::custom(
                    "expected letter in tensor index".to_string(),
                    Some(*span),
                ));
            }
            None => {
                return Err(ParseError::unexpected_eof(
                    vec!["tensor index"],
                    Some(self.current_span()),
                ));
            }
        };

        self.next();
        Ok(TensorIndex { name, index_type })
    }

    /// Parses one group of tensor indices, all tagged with `index_type`.
    ///
    /// Without an opening brace, exactly one index is read. With a brace,
    /// indices are read until a closing brace or an `Eof` token is next,
    /// and the closing brace is then required.
    ///
    /// # Errors
    ///
    /// Propagates errors from `parse_single_tensor_index` and fails when the
    /// closing brace of a braced group is missing.
    pub(super) fn parse_index_group(
        &mut self,
        index_type: IndexType,
    ) -> ParseResult<Vec<TensorIndex>> {
        // Unbraced form: a single index.
        if !self.check(&LatexToken::LBrace) {
            let only = self.parse_single_tensor_index(index_type)?;
            return Ok(vec![only]);
        }

        // Braced form: any number of indices up to the closing brace.
        self.next(); // consume {
        let mut group = Vec::new();
        loop {
            let at_end = self.check(&LatexToken::RBrace) || self.check(&LatexToken::Eof);
            if at_end {
                break;
            }
            group.push(self.parse_single_tensor_index(index_type)?);
        }
        self.consume(LatexToken::RBrace)?;
        Ok(group)
    }
}