1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use cauchy::Scalar;
use serde_derive::{Deserialize, Serialize};
/// A single-argument scalar operation.
#[derive(Clone, Copy, Debug)]
#[derive(Serialize, Deserialize)]
pub enum Unary {
    /// Negation: `-x`.
    Neg,
    /// `conj(x) * x`; equals `x * x` when `conj` is the identity (real scalars).
    Square,
    /// Exponential: `e^x`.
    Exp,
    /// Natural logarithm.
    Ln,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,
}
impl Unary {
    /// Applies this operation to `arg` and returns the result.
    pub fn eval_value<A: Scalar>(&self, arg: A) -> A {
        match self {
            Unary::Neg => -arg,
            // `conj(x) * x`, which is `x * x` when `conj` is the identity (real scalars).
            Unary::Square => arg.conj() * arg,
            Unary::Exp => arg.exp(),
            Unary::Ln => arg.ln(),
            Unary::Sin => arg.sin(),
            Unary::Cos => arg.cos(),
            Unary::Tan => arg.tan(),
            Unary::Sinh => arg.sinh(),
            Unary::Cosh => arg.cosh(),
            Unary::Tanh => arg.tanh(),
        }
    }

    /// Chain rule for this operation: returns `f'(arg) * deriv`, where `f` is the
    /// function computed by [`Unary::eval_value`].
    ///
    /// `arg` is the point at which the derivative is taken; `deriv` is the incoming
    /// derivative that gets scaled by the local derivative.
    pub fn eval_deriv<A: Scalar>(&self, arg: A, deriv: A) -> A {
        match self {
            Unary::Neg => -deriv,
            // NOTE(review): this is 2x for real scalars. For complex `A` the value
            // `conj(z) * z` is not holomorphic, so `2 * conj(z)` is only one possible
            // convention — confirm it is the one callers expect.
            Unary::Square => {
                A::from_f64(2.0).expect("2.0 is representable as a scalar") * arg.conj() * deriv
            }
            Unary::Exp => arg.exp() * deriv,
            Unary::Ln => deriv / arg,
            Unary::Sin => arg.cos() * deriv,
            Unary::Cos => -arg.sin() * deriv,
            // d/dx tan(x) = sec^2(x) = 1 / cos^2(x). Unlike d/dx cos(x), this is not
            // negated; the previous `-deriv / cos^2` had the sign wrong.
            Unary::Tan => {
                let c = arg.cos();
                deriv / (c * c)
            }
            Unary::Sinh => arg.cosh() * deriv,
            Unary::Cosh => arg.sinh() * deriv,
            Unary::Tanh => deriv / (arg.cosh() * arg.cosh()),
        }
    }
}
/// A two-argument scalar operation `f(lhs, rhs)`.
#[derive(Clone, Copy, Debug)]
#[derive(Serialize, Deserialize)]
pub enum Binary {
    /// Sum: `lhs + rhs`.
    Add,
    /// Product: `lhs * rhs`.
    Mul,
    /// Quotient: `lhs / rhs`.
    Div,
    /// Power: `lhs` raised to the exponent `rhs`.
    Pow,
}
impl Binary {
    /// Applies this operation to `lhs` and `rhs` and returns the result.
    pub fn eval_value<A: Scalar>(&self, lhs: A, rhs: A) -> A {
        match self {
            Binary::Add => lhs + rhs,
            Binary::Mul => lhs * rhs,
            Binary::Div => lhs / rhs,
            Binary::Pow => lhs.pow(rhs),
        }
    }

    /// Chain rule for this operation `f(lhs, rhs)`.
    ///
    /// Returns `(df/dlhs * deriv, df/drhs * deriv)`: both partial derivatives are
    /// evaluated at `(lhs, rhs)` and scaled by the incoming derivative `deriv`.
    pub fn eval_deriv<A: Scalar>(&self, lhs: A, rhs: A, deriv: A) -> (A, A) {
        match self {
            Binary::Add => (deriv, deriv),
            Binary::Mul => (rhs * deriv, lhs * deriv),
            Binary::Div => (deriv / rhs, -lhs * deriv / (rhs * rhs)),
            Binary::Pow => {
                let one = A::from_f64(1.0).expect("1.0 is representable as a scalar");
                // d(l^r)/dl = r * l^(r-1);  d(l^r)/dr = ln(l) * l^r.
                // Both partials must be scaled by `deriv`, like every other variant.
                // NOTE(review): for real `A`, `lhs.ln()` is not finite when lhs <= 0, so the
                // rhs partial is NaN/inf there — confirm callers stay out of that domain.
                (
                    rhs * lhs.pow(rhs - one) * deriv,
                    lhs.ln() * lhs.pow(rhs) * deriv,
                )
            }
        }
    }
}