formulac 0.8.0

A math expression parser for Rust with complex-number support and extensible functions
Documentation
//! # operators.rs
//!
//! This module defines unary and binary operators used in mathematical expressions.
//! It provides enums for operator kinds, precedence information, and application logic.

use num_complex::Complex;
use std::str::FromStr;

use crate::core::{
    ComplexMath,
    Real,
};

/// Textual name of the differential operator (`"diff"`).
///
/// Note that this is *not* one of the [`OperatorKind`] symbols: `"diff"` is
/// rejected by `OperatorKind::from_str`.
/// NOTE(review): only the constant is declared in this module; presumably the
/// parser matches it as a function-style keyword — confirm at the use site.
pub const DIFFERENTIAL_OPERATOR_STR: &str = "diff";

/// Internal macro that builds the public `OperatorKind` enum from a table of
/// `"symbol" => Variant` entries.
///
/// For every entry it generates the enum variant, a `FromStr` arm, an element
/// of `OperatorKind::symbols()`, and the text returned by
/// `OperatorKind::symbol()` (which is also what `Display` writes).
/// Symbols are taken as `literal` fragments because they are used directly as
/// `match` patterns.
macro_rules! operator_kind {
    ($($symbol:literal => $kind:ident),* $(,)?) => {
        /// An operator symbol that can appear in a mathematical expression.
        ///
        /// Whether a symbol is valid as a unary and/or binary operator is
        /// decided by the `TryFrom<OperatorKind>` impls of `UnaryOperatorKind`
        /// and `BinaryOperatorKind`.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum OperatorKind {
            $( $kind ),*
        }

        impl FromStr for OperatorKind {
            /// Unit error: the only failure is an unrecognised symbol.
            type Err = ();

            /// Parses `s` only if it is exactly one of the declared symbols
            /// (no trimming is performed); anything else yields `Err(())`.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $symbol => Ok(Self::$kind), )*
                    _ => Err(()),
                }
            }
        }

        impl OperatorKind {
            /// Returns every declared operator symbol, in declaration order.
            pub fn symbols() -> &'static [&'static str] {
                &[$($symbol),*]
            }

            /// Returns the symbol this operator is written with (e.g. `"+"`).
            pub fn symbol(&self) -> &'static str {
                match self {
                    $( Self::$kind => $symbol, )*
                }
            }
        }

        impl std::fmt::Display for OperatorKind {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Write the symbol verbatim; using it as a `write!` format
                // string would misinterpret `{` / `}` characters.
                f.write_str(self.symbol())
            }
        }
    };
}

// Every operator symbol recognised by `OperatorKind`. Entry order fixes the
// variant order and the order returned by `OperatorKind::symbols()`.
// Whether a symbol may be used as a unary and/or binary operator is decided by
// the `unary_operator_kind!` / `binary_operators!` tables further down.
operator_kind! {
    "+" => Plus,
    "-" => Minus,
    "*" => Mul,
    "/" => Div,
    "^" => Pow,
}

#[doc(hidden)]
/// Internal macro that expands a table of unary operators into the
/// `UnaryOperatorKind` enum plus its conversion, evaluation and display impls.
///
/// Each entry names the variant, the `OperatorKind` token it is written with,
/// and a closure `Complex<T> -> Complex<T>` that evaluates it.
macro_rules! unary_operator_kind {
    ($($name:ident => { kind: $kind:ident, apply: $apply:expr }),* $(,)?) => {
        /// Represents a unary operator in a mathematical expression.
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub(crate) enum UnaryOperatorKind {
            $($name),*
        }

        impl TryFrom<OperatorKind> for UnaryOperatorKind {
            // `()` is returned for every token that has no unary form.
            type Error = ();

            fn try_from(token: OperatorKind) -> Result<Self, Self::Error> {
                let unary = match token {
                    $( OperatorKind::$kind => Some(Self::$name), )*
                    _ => None,
                };
                unary.ok_or(())
            }
        }

        impl UnaryOperatorKind {
            /// Evaluates this operator on the complex operand `x`.
            pub(crate) fn apply<T: Real>(&self, x: Complex<T>) -> Complex<T> {
                match *self {
                    $( Self::$name => ($apply)(x), )*
                }
            }
        }

        impl std::fmt::Display for UnaryOperatorKind {
            /// Writes the operator using the symbol of its `OperatorKind` token.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let token = match *self {
                    $( Self::$name => OperatorKind::$kind, )*
                };
                write!(f, "{}", token)
            }
        }
    };
}

// Unary forms: only `+` and `-` may prefix an operand; converting any other
// `OperatorKind` into `UnaryOperatorKind` yields `Err(())`.
unary_operator_kind! {
    // `+x` evaluates to `x` unchanged.
    Positive => { kind: Plus, apply: |x| x },
    // `-x` negates both the real and imaginary parts.
    Negative => { kind: Minus, apply: |x: Complex<_>| -x },
}

#[doc(hidden)]
/// Internal macro that expands a table of binary operators into the
/// `BinaryOperatorKind` enum plus its precedence, associativity, evaluation,
/// conversion and display logic.
///
/// Each entry gives the variant name, the `OperatorKind` token it is written
/// with, its precedence, whether it is left-associative, and a closure
/// `(Complex<T>, Complex<T>) -> Complex<T>` that evaluates it.
macro_rules! binary_operators {
    ($($name:ident => {
        kind: $kind:ident,
        precedence: $prec:expr,
        left_assoc: $assoc:expr,
        apply: $apply:expr
    }),* $(,)?) => {
        /// Represents a binary operator in a mathematical expression.
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub(crate) enum BinaryOperatorKind {
            $($name),*
        }

        impl TryFrom<OperatorKind> for BinaryOperatorKind {
            // The match below has no fallback arm, so every `OperatorKind`
            // must have a binary mapping (enforced at compile time) and this
            // error value is never actually produced.
            type Error = ();

            fn try_from(token: OperatorKind) -> Result<Self, Self::Error> {
                let binary = match token {
                    $( OperatorKind::$kind => Self::$name, )*
                };
                Ok(binary)
            }
        }

        impl BinaryOperatorKind {
            /// Precedence level declared for this operator in the
            /// `binary_operators!` table; larger values bind more tightly.
            #[inline]
            pub(crate) fn precedence(&self) -> u8 {
                match *self {
                    $( Self::$name => $prec, )*
                }
            }

            /// Whether this operator is declared left-associative in the
            /// `binary_operators!` table (`false` means right-associative).
            #[inline]
            pub(crate) fn is_left_assoc(&self) -> bool {
                match *self {
                    $( Self::$name => $assoc, )*
                }
            }

            /// Evaluates `l <op> r` on two complex operands.
            #[inline]
            pub(crate) fn apply<T: Real>(&self, l: Complex<T>, r: Complex<T>) -> Complex<T> {
                match *self {
                    $( Self::$name => ($apply)(l, r), )*
                }
            }
        }

        impl std::fmt::Display for BinaryOperatorKind {
            /// Writes the operator using the symbol of its `OperatorKind` token.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let token = match *self {
                    $( Self::$name => OperatorKind::$kind, )*
                };
                write!(f, "{}", token)
            }
        }
    };
}

// Binary forms with their precedence (a higher value binds tighter, so
// `*`/`/` bind tighter than `+`/`-`, and `^` tighter than both) and their
// associativity. `^` is the only operator declared right-associative
// (`left_assoc: false`).
binary_operators! {
    Add => { kind: Plus, precedence: 0, left_assoc: true,  apply: |l, r| l + r },
    Sub => { kind: Minus, precedence: 0, left_assoc: true,  apply: |l, r| l - r },
    Mul => { kind: Mul, precedence: 1, left_assoc: true,  apply: |l, r| l * r },
    Div => { kind: Div, precedence: 1, left_assoc: true,  apply: |l, r| l / r },
    // Explicit annotations let the `powc` method call resolve; `T` is the
    // generic parameter of the generated `apply<T: Real>`.
    Pow => { kind: Pow, precedence: 2, left_assoc: false, apply: |l: Complex<T>, r: Complex<T>| l.powc(r) },
}

#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    // Helpers
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    /// Approximate equality, for results of inexact operations such as `powc`.
    fn eq(a: Complex<f64>, b: Complex<f64>) -> bool {
        (a - b).norm() < 1.0e-10
    }

    // ── OperatorKind ──────────────────────────────────────────
    #[test]
    fn symbols_contains_all() {
        let syms = OperatorKind::symbols();
        for s in ["+", "-", "*", "/", "^"] {
            assert!(syms.contains(&s), "missing symbol: {s}");
        }
    }

    #[test]
    fn symbols_are_unique() {
        let syms = OperatorKind::symbols();
        for (i, a) in syms.iter().enumerate() {
            assert!(!syms[i + 1..].contains(a), "duplicate symbol: {a}");
        }
    }

    #[test]
    fn parse_round_trips_every_symbol() {
        for &s in OperatorKind::symbols() {
            let kind: OperatorKind = s.parse().expect("every listed symbol must parse");
            assert_eq!(kind.to_string(), s);
        }
    }

    #[test]
    fn parse_rejects_unknown_symbols() {
        for s in ["", " ", "%", "**", " +", "+ ", DIFFERENTIAL_OPERATOR_STR] {
            assert!(s.parse::<OperatorKind>().is_err(), "unexpectedly parsed: {s:?}");
        }
    }

    // ── UnaryOperatorKind ─────────────────────────────────────
    mod unary {
        use super::*;

        #[test]
        fn from_valid_symbols() {
            assert_eq!(UnaryOperatorKind::try_from(OperatorKind::Plus), Ok(UnaryOperatorKind::Positive));
            assert_eq!(UnaryOperatorKind::try_from(OperatorKind::Minus), Ok(UnaryOperatorKind::Negative));
        }

        #[test]
        fn from_invalid_symbol() {
            assert!(UnaryOperatorKind::try_from(OperatorKind::Mul).is_err());
            assert!(UnaryOperatorKind::try_from(OperatorKind::Div).is_err());
            assert!(UnaryOperatorKind::try_from(OperatorKind::Pow).is_err());
        }

        #[test]
        fn apply_positive_is_identity() {
            let cases = [c(0.0, 0.0), c(3.0, 0.0), c(-2.0, 5.0)];
            for x in cases {
                assert_eq!(UnaryOperatorKind::Positive.apply(x), x);
            }
        }

        #[test]
        fn apply_negative_negates() {
            assert_eq!(UnaryOperatorKind::Negative.apply(c(3.0, 4.0)), c(-3.0, -4.0));
            assert_eq!(UnaryOperatorKind::Negative.apply(c(0.0, 0.0)), c(0.0, 0.0));
        }

        #[test]
        fn display() {
            assert_eq!(UnaryOperatorKind::Positive.to_string(), "+");
            assert_eq!(UnaryOperatorKind::Negative.to_string(), "-");
        }
    }

    // ── BinaryOperatorKind ────────────────────────────────────
    mod binary {
        use super::*;

        // -- TryFrom --
        #[test]
        fn from_valid_symbols() {
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Plus), Ok(BinaryOperatorKind::Add));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Minus), Ok(BinaryOperatorKind::Sub));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Mul), Ok(BinaryOperatorKind::Mul));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Div), Ok(BinaryOperatorKind::Div));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Pow), Ok(BinaryOperatorKind::Pow));
        }

        // -- precedence --
        #[test]
        fn precedence_ordering() {
            assert_eq!(BinaryOperatorKind::Add.precedence(), BinaryOperatorKind::Sub.precedence());
            assert!(BinaryOperatorKind::Mul.precedence() > BinaryOperatorKind::Add.precedence());
            assert!(BinaryOperatorKind::Div.precedence() > BinaryOperatorKind::Sub.precedence());
            assert!(BinaryOperatorKind::Pow.precedence() > BinaryOperatorKind::Mul.precedence());
        }

        // -- is_left_assoc --
        #[test]
        fn associativity() {
            assert!(BinaryOperatorKind::Add.is_left_assoc());
            assert!(BinaryOperatorKind::Sub.is_left_assoc());
            assert!(BinaryOperatorKind::Mul.is_left_assoc());
            assert!(BinaryOperatorKind::Div.is_left_assoc());
            assert!(!BinaryOperatorKind::Pow.is_left_assoc()); // right association
        }

        // -- apply: Real --
        #[test]
        fn apply_real() {
            let (a, b) = (c(6.0, 0.0), c(2.0, 0.0));
            assert_eq!(BinaryOperatorKind::Add.apply(a, b), c(8.0,  0.0));
            assert_eq!(BinaryOperatorKind::Sub.apply(a, b), c(4.0,  0.0));
            assert_eq!(BinaryOperatorKind::Mul.apply(a, b), c(12.0, 0.0));
            assert_eq!(BinaryOperatorKind::Div.apply(a, b), c(3.0,  0.0));
        }

        #[test]
        fn apply_pow_real() {
            assert!(eq(
                BinaryOperatorKind::Pow.apply(c(2.0, 0.0), c(10.0, 0.0)),
                c(1024.0, 0.0),
            ));
        }

        // -- apply: Complex --
        #[test]
        fn apply_add_complex() {
            // (1+2i) + (3+4i) = 4+6i
            assert_eq!(
                BinaryOperatorKind::Add.apply(c(1.0, 2.0), c(3.0, 4.0)),
                c(4.0, 6.0),
            );
        }

        #[test]
        fn apply_sub_div_complex() {
            // (5+5i) - (2+3i) = 3+2i
            assert_eq!(
                BinaryOperatorKind::Sub.apply(c(5.0, 5.0), c(2.0, 3.0)),
                c(3.0, 2.0),
            );
            // (2+4i) / (1+i) = 3+i
            assert!(eq(
                BinaryOperatorKind::Div.apply(c(2.0, 4.0), c(1.0, 1.0)),
                c(3.0, 1.0),
            ));
        }

        #[test]
        fn apply_mul_complex() {
            // (1+i)(1-i) = 2
            assert!(eq(
                BinaryOperatorKind::Mul.apply(c(1.0, 1.0), c(1.0, -1.0)),
                c(2.0, 0.0),
            ));
        }

        #[test]
        fn apply_pow_complex() {
            // i^2 = -1
            assert!(eq(
                BinaryOperatorKind::Pow.apply(c(0.0, 1.0), c(2.0, 0.0)),
                c(-1.0, 0.0),
            ));
        }

        // -- apply: edge cases --
        #[test]
        fn div_by_zero_does_not_panic() {
            let result = BinaryOperatorKind::Div.apply(c(1.0, 0.0), c(0.0, 0.0));
            assert!(result.re.is_infinite() || result.re.is_nan());
        }

        // -- Display --
        #[test]
        fn display() {
            assert_eq!(BinaryOperatorKind::Add.to_string(), "+");
            assert_eq!(BinaryOperatorKind::Sub.to_string(), "-");
            assert_eq!(BinaryOperatorKind::Mul.to_string(), "*");
            assert_eq!(BinaryOperatorKind::Div.to_string(), "/");
            assert_eq!(BinaryOperatorKind::Pow.to_string(), "^");
        }
    }
}