use num_complex::Complex;
use std::str::FromStr;
use crate::core::{
ComplexMath,
Real,
};
/// Textual name of the differential operator: `"diff"`.
///
/// NOTE(review): this is not one of the [`OperatorKind`] symbols below; presumably it is
/// recognised as a function-style token elsewhere in the parser — confirm against callers.
pub const DIFFERENTIAL_OPERATOR_STR: &str = "diff";

/// Generates [`OperatorKind`] from a `"symbol" => Variant` table, together with its
/// `FromStr`, `Display` and `symbols()` implementations, so the symbol list is
/// written down exactly once.
macro_rules! operator_kind {
    ($($symbol:expr => $kind:ident),* $(,)?) => {
        /// A raw operator token as it appears in source text.
        ///
        /// It is later converted into a unary or binary operator kind via `TryFrom`.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum OperatorKind {
            $(
                #[doc = concat!("The `", $symbol, "` operator.")]
                $kind,
            )*
        }

        impl FromStr for OperatorKind {
            /// Returned when the input is not exactly one of [`OperatorKind::symbols`].
            type Err = ();

            /// Parses an operator from its exact symbol; no trimming is performed.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $symbol => Ok(Self::$kind), )*
                    _ => Err(()),
                }
            }
        }

        impl OperatorKind {
            /// Every recognised operator symbol, in declaration order.
            pub fn symbols() -> &'static [&'static str] {
                &[$($symbol),*]
            }
        }

        impl std::fmt::Display for OperatorKind {
            /// Writes the operator's symbol.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Emit the symbol verbatim. Using it as a `write!` format string would
                // misinterpret any `{` / `}` it contained.
                match self {
                    $( Self::$kind => f.write_str($symbol), )*
                }
            }
        }
    };
}

// Declaration order fixes both the variant order and the order returned by `symbols()`.
operator_kind! {
    "+" => Plus,
    "-" => Minus,
    "*" => Mul,
    "/" => Div,
    "^" => Pow,
}
#[doc(hidden)]
macro_rules! unary_operator_kind {
($($name:ident => { kind: $kind:ident, apply: $apply:expr }),* $(,)?) => {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum UnaryOperatorKind {
$($name),*
}
impl TryFrom<OperatorKind> for UnaryOperatorKind {
type Error = (); fn try_from(k: OperatorKind) -> Result<Self, Self::Error> {
match k {
$( OperatorKind::$kind => Ok(Self::$name), )*
_ => Err(()),
}
}
}
impl UnaryOperatorKind {
pub(crate) fn apply<T: Real>(&self, x: Complex<T>) -> Complex<T> {
match self {
$( Self::$name => $apply(x), )*
}
}
}
impl std::fmt::Display for UnaryOperatorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
$( Self::$name => write!(f, "{}", OperatorKind::$kind), )*
}
}
}
};
}
unary_operator_kind! {
Positive => { kind: Plus, apply: |x| x },
Negative => { kind: Minus, apply: |x: Complex<_>| -x },
}
#[doc(hidden)]
macro_rules! binary_operators {
($($name:ident => {
kind: $kind:ident,
precedence: $prec:expr,
left_assoc: $assoc:expr,
apply: $apply:expr
}),* $(,)?) => {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum BinaryOperatorKind {
$($name),*
}
impl TryFrom<OperatorKind> for BinaryOperatorKind {
type Error = (); fn try_from(k: OperatorKind) -> Result<Self, Self::Error> {
match k {
$( OperatorKind::$kind => Ok(Self::$name), )*
}
}
}
impl BinaryOperatorKind {
#[inline]
pub(crate) fn precedence(&self) -> u8 {
match self {
$( Self::$name => $prec, )*
}
}
#[inline]
pub(crate) fn is_left_assoc(&self) -> bool {
match self {
$( Self::$name => $assoc, )*
}
}
#[inline]
pub(crate) fn apply<T: Real>(&self, l: Complex<T>, r: Complex<T>) -> Complex<T> {
match self {
$(Self::$name => $apply(l, r),)*
}
}
}
impl std::fmt::Display for BinaryOperatorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
$( Self::$name => write!(f, "{}", OperatorKind::$kind), )*
}
}
}
};
}
binary_operators! {
Add => { kind: Plus, precedence: 0, left_assoc: true, apply: |l, r| l + r },
Sub => { kind: Minus, precedence: 0, left_assoc: true, apply: |l, r| l - r },
Mul => { kind: Mul, precedence: 1, left_assoc: true, apply: |l, r| l * r },
Div => { kind: Div, precedence: 1, left_assoc: true, apply: |l, r| l / r },
Pow => { kind: Pow, precedence: 2, left_assoc: false, apply: |l: Complex<T>, r: Complex<T>| l.powc(r) },
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// Shorthand constructor for `Complex<f64>`.
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    /// Approximate equality, for results of `powc` and other inexact operations.
    fn eq(a: Complex<f64>, b: Complex<f64>) -> bool {
        (a - b).norm() < 1.0e-10
    }

    #[test]
    fn symbols_contains_all() {
        let syms = OperatorKind::symbols();
        for s in ["+", "-", "*", "/", "^"] {
            assert!(syms.contains(&s), "missing symbol: {s}");
        }
    }

    #[test]
    fn symbols_round_trip_through_parse_and_display() {
        for &s in OperatorKind::symbols() {
            let kind: OperatorKind = s
                .parse()
                .unwrap_or_else(|_| panic!("listed symbol {s:?} does not parse"));
            assert_eq!(kind.to_string(), s);
        }
    }

    #[test]
    fn parse_rejects_unknown_or_padded_input() {
        // `FromStr` matches exact symbols only; surrounding whitespace is not trimmed.
        for s in ["", " +", "+ ", "**", "%", "plus"] {
            assert_eq!(s.parse::<OperatorKind>(), Err(()), "unexpectedly parsed {s:?}");
        }
    }

    mod unary {
        use super::*;

        #[test]
        fn from_valid_symbols() {
            assert_eq!(UnaryOperatorKind::try_from(OperatorKind::Plus), Ok(UnaryOperatorKind::Positive));
            assert_eq!(UnaryOperatorKind::try_from(OperatorKind::Minus), Ok(UnaryOperatorKind::Negative));
        }

        #[test]
        fn from_invalid_symbol() {
            assert!(UnaryOperatorKind::try_from(OperatorKind::Mul).is_err());
            assert!(UnaryOperatorKind::try_from(OperatorKind::Div).is_err());
            assert!(UnaryOperatorKind::try_from(OperatorKind::Pow).is_err());
        }

        #[test]
        fn apply_positive_is_identity() {
            let cases = [c(0.0, 0.0), c(3.0, 0.0), c(-2.0, 5.0)];
            for x in cases {
                assert_eq!(UnaryOperatorKind::Positive.apply(x), x);
            }
        }

        #[test]
        fn apply_negative_negates() {
            assert_eq!(UnaryOperatorKind::Negative.apply(c(3.0, 4.0)), c(-3.0, -4.0));
            assert_eq!(UnaryOperatorKind::Negative.apply(c(0.0, 0.0)), c(0.0, 0.0));
        }

        #[test]
        fn display() {
            assert_eq!(UnaryOperatorKind::Positive.to_string(), "+");
            assert_eq!(UnaryOperatorKind::Negative.to_string(), "-");
        }
    }

    mod binary {
        use super::*;

        #[test]
        fn from_valid_symbols() {
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Plus), Ok(BinaryOperatorKind::Add));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Minus), Ok(BinaryOperatorKind::Sub));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Mul), Ok(BinaryOperatorKind::Mul));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Div), Ok(BinaryOperatorKind::Div));
            assert_eq!(BinaryOperatorKind::try_from(OperatorKind::Pow), Ok(BinaryOperatorKind::Pow));
        }

        #[test]
        fn precedence_ordering() {
            assert_eq!(BinaryOperatorKind::Add.precedence(), BinaryOperatorKind::Sub.precedence());
            assert_eq!(BinaryOperatorKind::Mul.precedence(), BinaryOperatorKind::Div.precedence());
            assert!(BinaryOperatorKind::Mul.precedence() > BinaryOperatorKind::Add.precedence());
            assert!(BinaryOperatorKind::Div.precedence() > BinaryOperatorKind::Sub.precedence());
            assert!(BinaryOperatorKind::Pow.precedence() > BinaryOperatorKind::Mul.precedence());
            assert!(BinaryOperatorKind::Pow.precedence() > BinaryOperatorKind::Div.precedence());
        }

        #[test]
        fn associativity() {
            assert!(BinaryOperatorKind::Add.is_left_assoc());
            assert!(BinaryOperatorKind::Sub.is_left_assoc());
            assert!(BinaryOperatorKind::Mul.is_left_assoc());
            assert!(BinaryOperatorKind::Div.is_left_assoc());
            assert!(!BinaryOperatorKind::Pow.is_left_assoc());
        }

        #[test]
        fn apply_real() {
            let (a, b) = (c(6.0, 0.0), c(2.0, 0.0));
            assert_eq!(BinaryOperatorKind::Add.apply(a, b), c(8.0, 0.0));
            assert_eq!(BinaryOperatorKind::Sub.apply(a, b), c(4.0, 0.0));
            assert_eq!(BinaryOperatorKind::Mul.apply(a, b), c(12.0, 0.0));
            assert_eq!(BinaryOperatorKind::Div.apply(a, b), c(3.0, 0.0));
        }

        #[test]
        fn apply_pow_real() {
            assert!(eq(
                BinaryOperatorKind::Pow.apply(c(2.0, 0.0), c(10.0, 0.0)),
                c(1024.0, 0.0),
            ));
        }

        #[test]
        fn apply_add_complex() {
            assert_eq!(
                BinaryOperatorKind::Add.apply(c(1.0, 2.0), c(3.0, 4.0)),
                c(4.0, 6.0),
            );
        }

        #[test]
        fn apply_sub_complex() {
            assert_eq!(
                BinaryOperatorKind::Sub.apply(c(1.0, 2.0), c(3.0, 4.0)),
                c(-2.0, -2.0),
            );
        }

        #[test]
        fn apply_mul_complex() {
            assert!(eq(
                BinaryOperatorKind::Mul.apply(c(1.0, 1.0), c(1.0, -1.0)),
                c(2.0, 0.0),
            ));
        }

        #[test]
        fn apply_div_complex() {
            // (4 + 2i) / (1 + i) = 3 - i
            assert!(eq(
                BinaryOperatorKind::Div.apply(c(4.0, 2.0), c(1.0, 1.0)),
                c(3.0, -1.0),
            ));
        }

        #[test]
        fn apply_pow_complex() {
            assert!(eq(
                BinaryOperatorKind::Pow.apply(c(0.0, 1.0), c(2.0, 0.0)),
                c(-1.0, 0.0),
            ));
        }

        #[test]
        fn div_by_zero_does_not_panic() {
            let result = BinaryOperatorKind::Div.apply(c(1.0, 0.0), c(0.0, 0.0));
            assert!(result.re.is_infinite() || result.re.is_nan());
        }

        #[test]
        fn display() {
            assert_eq!(BinaryOperatorKind::Add.to_string(), "+");
            assert_eq!(BinaryOperatorKind::Sub.to_string(), "-");
            assert_eq!(BinaryOperatorKind::Mul.to_string(), "*");
            assert_eq!(BinaryOperatorKind::Div.to_string(), "/");
            assert_eq!(BinaryOperatorKind::Pow.to_string(), "^");
        }

        #[test]
        fn display_matches_source_token() {
            for &s in OperatorKind::symbols() {
                let kind: OperatorKind = s.parse().unwrap();
                let binary = BinaryOperatorKind::try_from(kind).unwrap();
                assert_eq!(binary.to_string(), s);
            }
        }
    }
}