mathexpr 0.1.1

A fast, safe mathematical expression parser and evaluator with bytecode compilation
Documentation
//! Abstract Syntax Tree types for mathexpr.
//!
//! This module contains the AST representation of parsed mathematical expressions.

#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};

use crate::error::{CompileError, EvalError};

/// A parsed mathematical expression.
///
/// This is the AST representation produced by the parser. It can be compiled
/// into bytecode for efficient repeated evaluation.
///
/// The tree is recursive: operands of [`Expr::BinaryOp`] and
/// [`Expr::UnaryMinus`] are boxed sub-expressions, and the arguments of
/// [`Expr::FunctionCall`] are stored as a vector of sub-expressions.
///
/// Equality is structural. Only `PartialEq` is derived (not `Eq`) because
/// [`Expr::Number`] carries an `f64`; as a consequence two trees containing a
/// NaN literal never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A numeric literal (e.g., `42`, `3.14`), stored as an `f64`.
    Number(f64),
    /// A variable reference (e.g., `x`, `score`), stored by name.
    Variable(String),
    /// The current/input value, represented by `_` in expressions.
    CurrentValue,
    /// A binary operation (e.g., `a + b`, `x * y`).
    BinaryOp {
        /// The operator applied to the two operands.
        op: BinOp,
        /// Left operand (heap-allocated so the type can be recursive).
        left: Box<Expr>,
        /// Right operand (heap-allocated so the type can be recursive).
        right: Box<Expr>,
    },
    /// Unary negation (e.g., `-x`) of the boxed operand.
    UnaryMinus(Box<Expr>),
    /// A function call (e.g., `sqrt(x)`, `max(a, b)`).
    FunctionCall {
        /// Function name, kept as written; this type does not validate it.
        name: String,
        /// Function arguments, in call order.
        args: Vec<Expr>,
    },
}

/// Binary operators supported by mathexpr.
///
/// The type is a plain field-less enum and is `Copy`, so it is passed by
/// value. Use [`BinOp::eval`] to apply an operator to two `f64` operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`). Evaluation fails with a division-by-zero error when the
    /// right operand is zero.
    Div,
    /// Modulo (`%`). Evaluated with Rust's `%` on `f64`; evaluation fails with
    /// a division-by-zero error when the right operand is zero.
    Mod,
    /// Exponentiation (`^`). Evaluated with `libm::pow`.
    Pow,
}

impl BinOp {
    /// Apply this operator to `left` and `right` and return the result.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DivisionByZero`] for [`BinOp::Div`] and
    /// [`BinOp::Mod`] when `right` compares equal to zero (this includes
    /// `-0.0`). All other operators always succeed; non-finite results such
    /// as infinities or NaN are returned as ordinary values.
    #[inline]
    pub fn eval(self, left: f64, right: f64) -> Result<f64, EvalError> {
        match self {
            // Both division-like operators reject a zero divisor up front, so
            // the arms below only ever see a non-zero `right` for Div/Mod.
            BinOp::Div | BinOp::Mod if right == 0.0 => Err(EvalError::DivisionByZero),
            BinOp::Add => Ok(left + right),
            BinOp::Sub => Ok(left - right),
            BinOp::Mul => Ok(left * right),
            BinOp::Div => Ok(left / right),
            BinOp::Mod => Ok(left % right),
            BinOp::Pow => Ok(libm::pow(left, right)),
        }
    }
}

/// Built-in functions supported by mathexpr.
///
/// Names are mapped to variants by [`BuiltinFn::from_name`], the number of
/// arguments each variant takes is given by [`BuiltinFn::arity`], and
/// [`BuiltinFn::eval`] computes the result from a slice of `f64` arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BuiltinFn {
    // Tier 1: Core math functions
    /// Absolute value: `abs(x)`
    Abs,
    /// Square root: `sqrt(x)`. Evaluates to NaN for negative `x`.
    Sqrt,
    /// Natural logarithm: `log(x)` or `ln(x)`. Evaluates to NaN for `x <= 0`.
    Log,
    /// Base-10 logarithm: `log10(x)`. Evaluates to NaN for `x <= 0`.
    Log10,
    /// Exponential: `exp(x)`
    Exp,
    /// Minimum: `min(a, b)`
    Min,
    /// Maximum: `max(a, b)`
    Max,
    /// Power: `pow(base, exp)`
    Pow,
    /// Modulo: `mod(a, b)`. Evaluation fails with a division-by-zero error
    /// when `b` is zero.
    Mod,

    // Tier 2: Trigonometric functions
    /// Sine: `sin(x)`
    Sin,
    /// Cosine: `cos(x)`
    Cos,
    /// Tangent: `tan(x)`
    Tan,
    /// Arc sine: `asin(x)`. Evaluates to NaN if `|x| > 1`.
    Asin,
    /// Arc cosine: `acos(x)`. Evaluates to NaN if `|x| > 1`.
    Acos,
    /// Arc tangent: `atan(x)`
    Atan,
    /// Hyperbolic sine: `sinh(x)`
    Sinh,
    /// Hyperbolic cosine: `cosh(x)`
    Cosh,
    /// Hyperbolic tangent: `tanh(x)`
    Tanh,

    // Tier 3: Rounding & utility functions
    /// Floor: `floor(x)`
    Floor,
    /// Ceiling: `ceil(x)`
    Ceil,
    /// Round to nearest: `round(x)`
    Round,
    /// Truncate toward zero: `trunc(x)`
    Trunc,
    /// Sign: `signum(x)`. Evaluates to `1.0` for positive `x`, `-1.0` for
    /// negative `x`, NaN for NaN, and returns a zero argument unchanged
    /// (so the sign of zero is preserved).
    Signum,
    /// Cube root: `cbrt(x)`
    Cbrt,
    /// Base-2 logarithm: `log2(x)`. Evaluates to NaN for `x <= 0`.
    Log2,
    /// Clamp to range: `clamp(x, min, max)`
    Clamp,

    // Constants (zero-arity)
    /// Pi constant: `pi` or `pi()`
    Pi,
    /// Euler's number: `e` or `e()`
    E,
}

impl BuiltinFn {
    /// Try to resolve a function name to a builtin.
    ///
    /// Returns `Some((function, arity))` if the name matches a builtin function,
    /// or `None` if the function is unknown.
    pub fn from_name(name: &str) -> Option<(Self, usize)> {
        match name {
            // Tier 1: Core math functions
            "abs" => Some((BuiltinFn::Abs, 1)),
            "sqrt" => Some((BuiltinFn::Sqrt, 1)),
            "log" | "ln" => Some((BuiltinFn::Log, 1)),
            "log10" => Some((BuiltinFn::Log10, 1)),
            "exp" => Some((BuiltinFn::Exp, 1)),
            "min" => Some((BuiltinFn::Min, 2)),
            "max" => Some((BuiltinFn::Max, 2)),
            "pow" => Some((BuiltinFn::Pow, 2)),
            "mod" => Some((BuiltinFn::Mod, 2)),

            // Tier 2: Trigonometric functions
            "sin" => Some((BuiltinFn::Sin, 1)),
            "cos" => Some((BuiltinFn::Cos, 1)),
            "tan" => Some((BuiltinFn::Tan, 1)),
            "asin" => Some((BuiltinFn::Asin, 1)),
            "acos" => Some((BuiltinFn::Acos, 1)),
            "atan" => Some((BuiltinFn::Atan, 1)),
            "sinh" => Some((BuiltinFn::Sinh, 1)),
            "cosh" => Some((BuiltinFn::Cosh, 1)),
            "tanh" => Some((BuiltinFn::Tanh, 1)),

            // Tier 3: Rounding & utility functions
            "floor" => Some((BuiltinFn::Floor, 1)),
            "ceil" => Some((BuiltinFn::Ceil, 1)),
            "round" => Some((BuiltinFn::Round, 1)),
            "trunc" => Some((BuiltinFn::Trunc, 1)),
            "signum" => Some((BuiltinFn::Signum, 1)),
            "cbrt" => Some((BuiltinFn::Cbrt, 1)),
            "log2" => Some((BuiltinFn::Log2, 1)),
            "clamp" => Some((BuiltinFn::Clamp, 3)),

            // Constants (zero-arity functions)
            "pi" => Some((BuiltinFn::Pi, 0)),
            "e" => Some((BuiltinFn::E, 0)),

            _ => None,
        }
    }

    /// Evaluate the function with given arguments.
    #[inline]
    pub fn eval(self, args: &[f64]) -> Result<f64, EvalError> {
        match self {
            // Tier 1: Core math functions
            BuiltinFn::Abs => Ok(libm::fabs(args[0])),
            BuiltinFn::Sqrt => {
                if args[0] < 0.0 {
                    Ok(f64::NAN)
                } else {
                    Ok(libm::sqrt(args[0]))
                }
            }
            BuiltinFn::Log => {
                if args[0] <= 0.0 {
                    Ok(f64::NAN)
                } else {
                    Ok(libm::log(args[0]))
                }
            }
            BuiltinFn::Log10 => {
                if args[0] <= 0.0 {
                    Ok(f64::NAN)
                } else {
                    Ok(libm::log10(args[0]))
                }
            }
            BuiltinFn::Exp => Ok(libm::exp(args[0])),
            BuiltinFn::Min => Ok(libm::fmin(args[0], args[1])),
            BuiltinFn::Max => Ok(libm::fmax(args[0], args[1])),
            BuiltinFn::Pow => Ok(libm::pow(args[0], args[1])),
            BuiltinFn::Mod => {
                if args[1] == 0.0 {
                    Err(EvalError::DivisionByZero)
                } else {
                    Ok(libm::fmod(args[0], args[1]))
                }
            }

            // Tier 2: Trigonometric functions
            BuiltinFn::Sin => Ok(libm::sin(args[0])),
            BuiltinFn::Cos => Ok(libm::cos(args[0])),
            BuiltinFn::Tan => Ok(libm::tan(args[0])),
            BuiltinFn::Asin => Ok(libm::asin(args[0])), // Returns NaN if |x| > 1
            BuiltinFn::Acos => Ok(libm::acos(args[0])), // Returns NaN if |x| > 1
            BuiltinFn::Atan => Ok(libm::atan(args[0])),
            BuiltinFn::Sinh => Ok(libm::sinh(args[0])),
            BuiltinFn::Cosh => Ok(libm::cosh(args[0])),
            BuiltinFn::Tanh => Ok(libm::tanh(args[0])),

            // Tier 3: Rounding & utility functions
            BuiltinFn::Floor => Ok(libm::floor(args[0])),
            BuiltinFn::Ceil => Ok(libm::ceil(args[0])),
            BuiltinFn::Round => Ok(libm::round(args[0])),
            BuiltinFn::Trunc => Ok(libm::trunc(args[0])),
            BuiltinFn::Signum => {
                // libm doesn't have signum, implement manually
                if args[0].is_nan() {
                    Ok(f64::NAN)
                } else if args[0] == 0.0 {
                    Ok(args[0]) // Preserve sign of zero
                } else if args[0] > 0.0 {
                    Ok(1.0)
                } else {
                    Ok(-1.0)
                }
            }
            BuiltinFn::Cbrt => Ok(libm::cbrt(args[0])),
            BuiltinFn::Log2 => {
                if args[0] <= 0.0 {
                    Ok(f64::NAN)
                } else {
                    Ok(libm::log2(args[0]))
                }
            }
            BuiltinFn::Clamp => {
                let (x, min, max) = (args[0], args[1], args[2]);
                Ok(libm::fmin(libm::fmax(x, min), max))
            }

            // Constants
            BuiltinFn::Pi => Ok(core::f64::consts::PI),
            BuiltinFn::E => Ok(core::f64::consts::E),
        }
    }

    /// Get the arity (number of arguments) for this function.
    pub fn arity(self) -> usize {
        match self {
            // Zero-arity (constants)
            BuiltinFn::Pi | BuiltinFn::E => 0,

            // Single-argument functions
            BuiltinFn::Abs
            | BuiltinFn::Sqrt
            | BuiltinFn::Log
            | BuiltinFn::Log10
            | BuiltinFn::Log2
            | BuiltinFn::Exp
            | BuiltinFn::Sin
            | BuiltinFn::Cos
            | BuiltinFn::Tan
            | BuiltinFn::Asin
            | BuiltinFn::Acos
            | BuiltinFn::Atan
            | BuiltinFn::Sinh
            | BuiltinFn::Cosh
            | BuiltinFn::Tanh
            | BuiltinFn::Floor
            | BuiltinFn::Ceil
            | BuiltinFn::Round
            | BuiltinFn::Trunc
            | BuiltinFn::Signum
            | BuiltinFn::Cbrt => 1,

            // Two-argument functions
            BuiltinFn::Min | BuiltinFn::Max | BuiltinFn::Pow | BuiltinFn::Mod => 2,

            // Three-argument functions
            BuiltinFn::Clamp => 3,
        }
    }

    /// Check if this function can be called with the given number of arguments.
    pub fn check_arity(self, got: usize, name: &str) -> Result<(), CompileError> {
        let expected = self.arity();
        if got != expected {
            Err(CompileError::WrongArity {
                name: String::from(name),
                expected,
                got,
            })
        } else {
            Ok(())
        }
    }
}