use num_complex::Complex;
use std::str::FromStr;
use std::sync::Arc;
use crate::core::{
ComplexMath,
Real,
};
use crate::err::InitializeError;
/// Evaluation capability: apply a function to a list of complex arguments.
pub trait Apply<T: Real> {
    /// Evaluates the function on `arg` and returns the single complex result.
    fn apply(&self, arg: Vec<Complex<T>>) -> Complex<T>;
}

/// Reports how many arguments a function expects.
pub trait Arity {
    /// Number of arguments [`Apply::apply`] is meant to receive.
    fn arity(&self) -> usize;
}

/// Marker combining [`Apply`] and [`Arity`].
///
/// The blanket impl below makes every type that implements both traits a
/// `FunctionCall` automatically, so this trait never needs a manual impl.
pub trait FunctionCall<T: Real>: Apply<T> + Arity {}

impl<U, T> FunctionCall<T> for U
where
    T: Real,
    U: Apply<T> + Arity,
{
}
// Counts a comma-separated list of identifiers at compile time, yielding a
// `usize` expression (0 for an empty list). Used by `functions!` to derive each
// built-in's arity from its closure parameters.
macro_rules! count_args {
    // Internal rule: every identifier contributes exactly one.
    (@one $x:ident) => { 1usize };
    () => { 0usize };
    ($($x:ident),+) => { 0usize $(+ count_args!(@one $x))+ };
}
/// Generates the built-in `FunctionKind` enum from a table of
/// `Variant => { name: "symbol", apply: |args...| body },` entries.
///
/// For every entry the macro emits:
/// - an enum variant;
/// - a `FromStr` arm mapping the exact (case-sensitive) symbol to the variant;
/// - an element of `FunctionKind::symbols()`, in table order;
/// - an `Arity` arm returning the number of closure parameters;
/// - an `Apply` arm that binds the parameters to the leading arguments and
///   evaluates `body`;
/// - a `Display` arm writing the symbol verbatim.
#[doc(hidden)]
macro_rules! functions {
    ($( $variant:ident => {
        name: $name:literal,
        apply: |$( $arg:ident ),+| $body:expr
    }, )*) => {
        /// Built-in mathematical functions, addressable by their symbol.
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub enum FunctionKind {
            $( $variant, )*
        }

        impl FromStr for FunctionKind {
            type Err = ();

            /// Parses an exact, case-sensitive function symbol.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $name => Ok(Self::$variant), )*
                    _ => Err(()),
                }
            }
        }

        impl FunctionKind {
            /// Returns every function symbol, in declaration order.
            pub fn symbols() -> &'static [&'static str] {
                &[$( $name, )*]
            }
        }

        impl Arity for FunctionKind {
            fn arity(&self) -> usize {
                match self {
                    $( Self::$variant => count_args!($($arg),+), )*
                }
            }
        }

        impl<T: Real> Apply<T> for FunctionKind {
            /// Evaluates the function on the first `self.arity()` arguments.
            ///
            /// Surplus arguments are ignored.
            ///
            /// # Panics
            ///
            /// Panics if fewer than `self.arity()` arguments are supplied; the
            /// message names the function and both counts.
            fn apply(&self, args: Vec<Complex<T>>) -> Complex<T> {
                let expected = self.arity();
                assert!(
                    args.len() >= expected,
                    "arity mismatch: `{}` expects {} argument(s), got {}",
                    self,
                    expected,
                    args.len(),
                );
                match self {
                    $( Self::$variant => {
                        let mut it = args.into_iter();
                        $( let $arg = it.next().expect("argument count checked above"); )+
                        $body
                    }, )*
                }
            }
        }

        impl std::fmt::Display for FunctionKind {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // `write_str` emits the symbol verbatim; `write!(f, $name)` would
                // interpret any `{`/`}` in the symbol as format directives.
                let symbol = match self {
                    $( Self::$variant => $name, )*
                };
                f.write_str(symbol)
            }
        }
    };
}
// Table of built-in functions. Each `name` is the exact, case-sensitive symbol
// accepted by `FunctionKind::from_str` and printed by `Display`. The number of
// closure parameters sets the function's arity (via `count_args!`). Entry order is
// the order returned by `FunctionKind::symbols()`, so do not reorder casually.
// The unary methods (`sin`, `abs`, `conj`, ...) and `powc`/`powi` are provided
// by `Complex<T>` / the crate's `ComplexMath` trait; their exact branch-cut and
// precision behavior lives there, not here.
functions! {
// Trigonometric functions and their inverses.
Sin => { name: "sin", apply: |x| x.sin() },
Cos => { name: "cos", apply: |x| x.cos() },
Tan => { name: "tan", apply: |x| x.tan() },
Asin => { name: "asin", apply: |x| x.asin() },
Acos => { name: "acos", apply: |x| x.acos() },
Atan => { name: "atan", apply: |x| x.atan() },
// Hyperbolic functions and their inverses.
Sinh => { name: "sinh", apply: |x| x.sinh() },
Cosh => { name: "cosh", apply: |x| x.cosh() },
Tanh => { name: "tanh", apply: |x| x.tanh() },
Asinh => { name: "asinh", apply: |x| x.asinh() },
Acosh => { name: "acosh", apply: |x| x.acosh() },
Atanh => { name: "atanh", apply: |x| x.atanh() },
// Exponential, logarithms and square root.
Exp => { name: "exp", apply: |x| x.exp() },
Ln => { name: "ln", apply: |x| x.ln() },
Log10 => { name: "log10", apply: |x| x.log10() },
Sqrt => { name: "sqrt", apply: |x| x.sqrt() },
// `abs` yields a complex value; the unit tests expect abs(3+4i) == 5+0i.
Abs => { name: "abs", apply: |x| x.abs() },
Conj => { name: "conj", apply: |x| x.conj() },
// Binary: complex power x^y.
Pow => { name: "pow", apply: |x, y| x.powc(y) },
// Binary: integer power. Only the real part of `y` is used, converted through
// `Real::to_i32`. NOTE(review): confirm how that conversion treats non-integral
// or out-of-range values (truncation/rounding/saturation) -- not visible here.
Powi => { name: "powi", apply: |x, y| x.powi(y.re.to_i32()) },
}
#[cfg(test)]
mod function_tests {
    use super::*;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    // Approximate equality for complex results of transcendental functions.
    fn eq(a: Complex<f64>, b: Complex<f64>) -> bool {
        (a - b).norm() < 1e-10
    }

    #[test]
    fn from_str_valid() {
        assert_eq!(FunctionKind::from_str("sin"), Ok(FunctionKind::Sin));
        assert_eq!(FunctionKind::from_str("cos"), Ok(FunctionKind::Cos));
        assert_eq!(FunctionKind::from_str("pow"), Ok(FunctionKind::Pow));
        assert_eq!(FunctionKind::from_str("powi"), Ok(FunctionKind::Powi));
    }

    #[test]
    fn from_str_invalid() {
        assert!(FunctionKind::from_str("SIN").is_err());
        assert!(FunctionKind::from_str("").is_err());
        assert!(FunctionKind::from_str("log").is_err());
    }

    // Every advertised symbol must parse back to a variant that prints as the same
    // symbol, and no symbol may be listed twice (a duplicate would make the later
    // `FromStr` arm unreachable).
    #[test]
    fn symbols_roundtrip_and_are_unique() {
        let symbols = FunctionKind::symbols();
        assert!(!symbols.is_empty());
        let mut seen = std::collections::HashSet::new();
        for &sym in symbols {
            assert!(seen.insert(sym), "duplicate symbol {}", sym);
            let kind = FunctionKind::from_str(sym).expect("every listed symbol must parse");
            assert_eq!(kind.to_string(), sym);
        }
    }

    #[test]
    fn arity_unary() {
        for f in [FunctionKind::Sin, FunctionKind::Cos, FunctionKind::Exp,
                  FunctionKind::Ln, FunctionKind::Sqrt, FunctionKind::Abs] {
            assert_eq!(f.arity(), 1);
        }
    }

    #[test]
    fn arity_binary() {
        assert_eq!(FunctionKind::Pow.arity(), 2);
        assert_eq!(FunctionKind::Powi.arity(), 2);
    }

    #[test]
    fn apply_sin_cos() {
        assert!(eq(FunctionKind::Sin.apply(vec![c(0.0, 0.0)]), c(0.0, 0.0)));
        assert!(eq(FunctionKind::Cos.apply(vec![c(0.0, 0.0)]), c(1.0, 0.0)));
    }

    #[test]
    fn apply_exp_ln_roundtrip() {
        let x = c(1.0, 1.0);
        let exp_x = FunctionKind::Exp.apply(vec![x]);
        assert!(eq(FunctionKind::Ln.apply(vec![exp_x]), x));
    }

    #[test]
    fn apply_abs_is_real() {
        assert!(eq(
            FunctionKind::Abs.apply(vec![c(3.0, 4.0)]),
            c(5.0, 0.0),
        ));
    }

    #[test]
    fn apply_pow_binary() {
        assert!(eq(
            FunctionKind::Pow.apply(vec![c(2.0, 0.0), c(8.0, 0.0)]),
            c(256.0, 0.0),
        ));
    }

    #[test]
    fn apply_powi_integer_exp() {
        assert!(eq(
            FunctionKind::Powi.apply(vec![c(3.0, 0.0), c(4.0, 0.0)]),
            c(81.0, 0.0),
        ));
    }

    #[test]
    fn display() {
        assert_eq!(FunctionKind::Sin.to_string(), "sin");
        assert_eq!(FunctionKind::Log10.to_string(), "log10");
        assert_eq!(FunctionKind::Pow.to_string(), "pow");
    }
}
/// Type-erased callable behind a [`UserFn`]: receives the whole argument list
/// and returns a single complex value. Shareable across threads.
type FuncType<T> = dyn Fn(Vec<Complex<T>>) -> Complex<T> + Send + Sync;

/// A named, user-supplied function with a fixed arity and optional derivatives.
///
/// Cloning shares the callable through its [`Arc`]; the name and the
/// derivative list are cloned.
#[derive(Clone)]
pub struct UserFn<T>
where
    T: Real,
{
    /// Display name; together with `arity` it is what `PartialEq` compares.
    name: String,
    /// Number of arguments the callable expects.
    arity: usize,
    /// The callable itself.
    func: Arc<FuncType<T>>,
    /// Attached derivatives, indexed by argument position; empty when none were attached.
    deriv: Vec<UserFn<T>>,
}
impl<T: Real> UserFn<T> {
    /// Creates a user-defined function called `name` that takes exactly `N` arguments.
    ///
    /// The arity is the length of the array `func` receives, so it cannot disagree
    /// with the closure. The new function has no derivatives attached; use
    /// [`UserFn::with_derivative`] to add them.
    ///
    /// # Panics
    ///
    /// The returned function panics when it is applied to a number of arguments
    /// other than `N`. The message names the function and both counts.
    pub fn new<F, S, const N: usize>(name: S, func: F) -> Self
    where
        F: Fn([Complex<T>; N]) -> Complex<T> + Send + Sync + 'static,
        S: Into<String>,
    {
        let name: String = name.into();
        // Owned copy moved into the closure so the panic message can name the function.
        let label = name.clone();
        Self {
            func: Arc::new(move |args: Vec<Complex<T>>| {
                // `apply` is public, so a wrong argument count is a caller error, not an
                // unreachable state; report it with the counts involved.
                let arr: [Complex<T>; N] = args
                    .try_into()
                    .unwrap_or_else(|rest: Vec<Complex<T>>| {
                        panic!(
                            "arity mismatch: `{}` expects {} argument(s), got {}",
                            label,
                            N,
                            rest.len()
                        )
                    });
                func(arr)
            }),
            deriv: Vec::new(),
            arity: N,
            name,
        }
    }

    /// Attaches derivatives to this function, one per argument.
    ///
    /// Entry `i` of `diffs` is what [`UserFn::derivative`] returns for `i`.
    /// Presumably it is the partial derivative with respect to argument `i`; the
    /// caller is responsible for that. NOTE(review): the arity of each derivative
    /// is not validated here.
    ///
    /// # Errors
    ///
    /// Returns [`InitializeError::DerivativesNumberMismatched`] when the number of
    /// derivatives differs from this function's arity.
    pub fn with_derivative(mut self, diffs: impl IntoIterator<Item = Self>) -> Result<Self, InitializeError> {
        let diffs: Vec<Self> = diffs.into_iter().collect();
        if diffs.len() != self.arity {
            return Err(InitializeError::DerivativesNumberMismatched {
                expected: self.arity, number: diffs.len()
            });
        }
        self.deriv = diffs;
        Ok(self)
    }

    /// The name this function was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the derivative attached for argument index `var`.
    ///
    /// Returns `None` when no derivatives were attached or when `var` is out of range.
    pub fn derivative(&self, var: usize) -> Option<&Self> {
        self.deriv.get(var)
    }
}
impl<T> Arity for UserFn<T>
where
    T: Real,
{
    /// Number of arguments this function expects (the `N` it was built with).
    fn arity(&self) -> usize {
        let Self { arity, .. } = self;
        *arity
    }
}
impl<T> Apply<T> for UserFn<T>
where
    T: Real,
{
    /// Forwards the argument list unchanged to the stored callable.
    ///
    /// The callable panics when `args.len()` differs from the function's arity.
    fn apply(&self, args: Vec<Complex<T>>) -> Complex<T> {
        let callable = &*self.func;
        callable(args)
    }
}
impl<T: Real> std::fmt::Debug for UserFn<T> {
    /// Shows only `name` and `arity`. The callable and derivatives are elided,
    /// so the output ends with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("UserFn");
        out.field("name", &self.name);
        out.field("arity", &self.arity);
        out.finish_non_exhaustive()
    }
}
impl<T> PartialEq for UserFn<T>
where
    T: Real,
{
    /// Two user functions are equal when they share both name and arity.
    ///
    /// The callables and derivatives are not compared, because closures cannot be.
    fn eq(&self, other: &Self) -> bool {
        (self.name.as_str(), self.arity) == (other.name.as_str(), other.arity)
    }
}
#[cfg(test)]
mod userfn_tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    #[test]
    fn apply_unary() {
        let f = UserFn::new(
            "inc",
            |[x] : [Complex<f64>; 1]| x + Complex::ONE,
        );
        assert_eq!(f.apply(vec![Complex::ZERO]), Complex::ONE);
    }

    #[test]
    fn apply_binary() {
        let f = UserFn::new(
            "add",
            |[x, y]| x + y,
        );
        assert_eq!(
            f.apply(vec![c(1.0, 0.0), c(2.0, 0.0)]),
            c(3.0, 0.0),
        );
    }

    #[test]
    fn apply_ternary() {
        let f = UserFn::new(
            "sum",
            |[x, y, z]| x + y + z,
        );
        assert_eq!(
            f.apply(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0)]),
            c(6.0, 0.0),
        );
    }

    // Calling with the wrong number of arguments is a caller error and must panic
    // with a message that mentions the arity mismatch.
    #[test]
    #[should_panic(expected = "arity mismatch")]
    fn apply_with_wrong_argument_count_panics() {
        let f = UserFn::new("add", |[x, y]: [Complex<f64>; 2]| x + y);
        let _ = f.apply(vec![c(1.0, 0.0)]);
    }

    #[test]
    fn partial_eq() {
        let f1 = UserFn::new("f", |[x] : [Complex<f64>; 1]| x);
        let f2 = UserFn::new("f", |[x] : [Complex<f64>; 1]| x + x);
        let f3 = UserFn::new("g", |[x] : [Complex<f64>; 1]| x);
        let f4 = UserFn::new("f", |[x, y] : [Complex<f64>; 2]| x + y);
        assert_eq!(f1, f2);
        assert_ne!(f1, f3);
        assert_ne!(f1, f4);
    }

    #[test]
    fn without_derivative() {
        let f = UserFn::new("f", |[x] : [Complex<f64>; 1]| x * x);
        assert!(f.derivative(0).is_none());
    }

    #[test]
    fn with_analytic_derivative() {
        let df = UserFn::new(
            "square_deriv",
            |[x]| c(2.0, 0.0) * x,
        );
        let f = UserFn::new(
            "square",
            |[x]| x * x,
        ).with_derivative(vec![df])
            .unwrap();
        let deriv = f.derivative(0).expect("should exist");
        let result = deriv.apply(vec![c(4.0, 0.0)]);
        assert_abs_diff_eq!(result.re, 8.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result.im, 0.0, epsilon = 1e-12);
    }

    // `with_derivative` must reject a derivative list whose length differs from the arity.
    #[test]
    fn with_derivative_rejects_wrong_count() {
        let f = UserFn::new("add", |[x, y]: [Complex<f64>; 2]| x + y);
        let d_dx = UserFn::new("d_add_dx", |[_x, _y]: [Complex<f64>; 2]| c(1.0, 0.0));
        let err = f.with_derivative(vec![d_dx]).unwrap_err();
        assert!(matches!(
            err,
            InitializeError::DerivativesNumberMismatched { expected: 2, number: 1 }
        ));

        let g = UserFn::new("id", |[x]: [Complex<f64>; 1]| x);
        let err = g.with_derivative(Vec::new()).unwrap_err();
        assert!(matches!(
            err,
            InitializeError::DerivativesNumberMismatched { expected: 1, number: 0 }
        ));
    }

    #[test]
    fn debug_contains_name_and_arity() {
        let f = UserFn::new(
            "mul",
            |[x] : [Complex<f64>; 1]| x * x,
        );
        let s = format!("{:?}", f);
        assert!(s.contains("mul"));
        assert!(s.contains("arity"));
    }
}