1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//! Values and operators of the calculation graph.

use cauchy::Scalar;
use serde_derive::{Deserialize, Serialize};

/// Unary operators that can appear as nodes of the calculation graph.
///
/// NOTE(review): serde-derived (de)serialization may encode variants by their
/// index in non-self-describing formats, so the variant order should be treated
/// as part of the wire format — do not reorder without checking.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Unary {
    /// Negation, `-x`.
    Neg,
    /// Computed as `conj(x) * x`: `x²` for real scalars, `|x|²` for complex ones.
    Square,
    /// Exponential, `exp(x)`.
    Exp,
    /// Natural logarithm, `ln(x)`.
    Ln,
    /// Sine, `sin(x)`.
    Sin,
    /// Cosine, `cos(x)`.
    Cos,
    /// Tangent, `tan(x)`.
    Tan,
    /// Hyperbolic sine, `sinh(x)`.
    Sinh,
    /// Hyperbolic cosine, `cosh(x)`.
    Cosh,
    /// Hyperbolic tangent, `tanh(x)`.
    Tanh,
}

impl Unary {
    /// Evaluate the operator at `arg` and return its value.
    ///
    /// `Square` is computed as `conj(arg) * arg`. For real scalars this is
    /// `arg²`; for complex scalars it is the squared modulus `|arg|²`.
    pub fn eval_value<A: Scalar>(&self, arg: A) -> A {
        match self {
            Unary::Neg => -arg,
            Unary::Square => arg.conj() * arg,
            Unary::Exp => arg.exp(),
            Unary::Ln => arg.ln(),
            Unary::Sin => arg.sin(),
            Unary::Cos => arg.cos(),
            Unary::Tan => arg.tan(),
            Unary::Sinh => arg.sinh(),
            Unary::Cosh => arg.cosh(),
            Unary::Tanh => arg.tanh(),
        }
    }

    /// Evaluate the local derivative of the operator at `arg`, multiplied by
    /// `deriv`, the derivative propagated down from the upper part of the
    /// graph (chain rule).
    ///
    /// # Panics
    ///
    /// `Square` panics if the scalar type cannot represent `2.0`.
    pub fn eval_deriv<A: Scalar>(&self, arg: A, deriv: A) -> A {
        match self {
            Unary::Neg => -deriv,
            // For real scalars this is d/dx x² = 2x.
            // NOTE(review): for complex scalars this keeps the original
            // `2 * conj(x)` convention — confirm it is the intended one.
            Unary::Square => {
                let two = A::from_f64(2.0).expect("2.0 must be representable as a scalar");
                two * arg.conj() * deriv
            }
            Unary::Exp => arg.exp() * deriv,
            Unary::Ln => deriv / arg,
            Unary::Sin => arg.cos() * deriv,
            Unary::Cos => -arg.sin() * deriv,
            // d/dx tan(x) = sec²(x) = 1 / cos²(x); unlike cos, there is no minus sign.
            Unary::Tan => deriv / (arg.cos() * arg.cos()),
            Unary::Sinh => arg.cosh() * deriv,
            Unary::Cosh => arg.sinh() * deriv,
            // d/dx tanh(x) = sech²(x) = 1 / cosh²(x).
            Unary::Tanh => deriv / (arg.cosh() * arg.cosh()),
        }
    }
}

/// Binary operators that can appear as nodes of the calculation graph.
///
/// NOTE(review): serde-derived (de)serialization may encode variants by their
/// index in non-self-describing formats, so the variant order should be treated
/// as part of the wire format — do not reorder without checking.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Binary {
    /// Addition, `lhs + rhs`.
    Add,
    /// Multiplication, `lhs * rhs`.
    Mul,
    /// Division, `lhs / rhs`.
    Div,
    /// Power, `lhs` raised to the `rhs`-th power.
    Pow,
}

impl Binary {
    /// Evaluate the operator on `lhs` and `rhs` and return its value.
    pub fn eval_value<A: Scalar>(&self, lhs: A, rhs: A) -> A {
        match self {
            Binary::Add => lhs + rhs,
            Binary::Mul => lhs * rhs,
            Binary::Div => lhs / rhs,
            Binary::Pow => lhs.pow(rhs),
        }
    }

    /// Evaluate the partial derivatives of the operator with respect to `lhs`
    /// and `rhs`, each multiplied by `deriv`, the derivative propagated down
    /// from the upper part of the graph (chain rule).
    ///
    /// Returns `(∂/∂lhs * deriv, ∂/∂rhs * deriv)`.
    ///
    /// # Panics
    ///
    /// `Pow` panics if the scalar type cannot represent `1.0`.
    pub fn eval_deriv<A: Scalar>(&self, lhs: A, rhs: A, deriv: A) -> (A, A) {
        match self {
            Binary::Add => (deriv, deriv),
            Binary::Mul => (rhs * deriv, lhs * deriv),
            Binary::Div => (deriv / rhs, -lhs * deriv / (rhs * rhs)),
            Binary::Pow => {
                let one = A::from_f64(1.0).expect("1.0 must be representable as a scalar");
                // ∂/∂x x^y = y · x^(y-1);  ∂/∂y x^y = x^y · ln(x).
                // Both partials must carry the upstream `deriv`, like the other arms.
                (
                    rhs * lhs.pow(rhs - one) * deriv,
                    lhs.pow(rhs) * lhs.ln() * deriv,
                )
            }
        }
    }
}