1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::fmt::Debug;
use ndarray::*;
use ndarray_linalg::*;
use error::*;
pub mod neg;
pub mod add;
/// A value that operators consume and produce.
///
/// It holds exactly one of: a single scalar, a one-dimensional array, or a
/// two-dimensional array, all sharing the element type `A`.
#[derive(Debug, Clone, IntoEnum)]
pub enum Value<A: Scalar> {
    /// A single element of type `A`.
    Scalar(A),
    /// A one-dimensional array of `A`.
    Vector(Array1<A>),
    /// A two-dimensional array of `A`.
    Matrix(Array2<A>),
}
impl<A> Value<A>
where
    A: Scalar,
{
    /// Extracts the wrapped scalar by value.
    ///
    /// # Errors
    ///
    /// Returns an error converted from `CastError` when `self` is a
    /// `Vector` or a `Matrix`.
    pub fn as_scalar(&self) -> Result<A> {
        match *self {
            Value::Scalar(inner) => Ok(inner),
            // Spelled out (rather than `_`) so that a newly added variant
            // forces this cast to be revisited.
            Value::Vector(_) | Value::Matrix(_) => {
                let cast_failure = CastError {};
                Err(cast_failure.into())
            }
        }
    }

    /// Builds the scalar value `1`.
    ///
    /// NOTE(review): presumably used as the seed derivative when starting
    /// differentiation -- confirm against callers.
    pub fn identity() -> Self {
        let one = A::from_f64(1.0);
        Value::Scalar(one)
    }
}
/// An operator that takes a single [`Value`] argument.
///
/// Implementors must also be `Clone` and `Debug`.
pub trait UnaryOperator<A>: Clone + Debug
where
    A: Scalar,
{
    /// Evaluates the operator on `arg` and returns the resulting value.
    fn eval_value(&self, arg: &Value<A>) -> Result<Value<A>>;

    /// Evaluates the derivative of the operator given the argument `arg` and
    /// the incoming derivative `deriv`.
    ///
    /// NOTE(review): presumably `deriv` is the derivative with respect to the
    /// operator's output and the result is the one with respect to `arg` --
    /// confirm against the implementations in `neg`.
    fn eval_deriv(&self, arg: &Value<A>, deriv: &Value<A>) -> Result<Value<A>>;
}
/// Closed set of the unary operators defined by this module, wrapped in a
/// single type so they can be dispatched with a `match`.
#[derive(Debug, Clone, Copy, IntoEnum)]
pub enum UnaryOperatorAny {
    /// The operator defined in the `neg` submodule.
    Neg(neg::Neg),
}
/// Creates a `neg::Neg` operator converted into a [`UnaryOperatorAny`].
pub fn neg() -> UnaryOperatorAny {
    let op = neg::Neg {};
    op.into()
}
/// Forwards each call to the concrete operator stored in the active variant.
impl<A> UnaryOperator<A> for UnaryOperatorAny
where
    A: Scalar,
{
    /// Delegates to the wrapped operator's `eval_value`.
    fn eval_value(&self, arg: &Value<A>) -> Result<Value<A>> {
        // The enum is `Copy`, so the wrapped operator is bound by value.
        match *self {
            UnaryOperatorAny::Neg(inner) => inner.eval_value(arg),
        }
    }

    /// Delegates to the wrapped operator's `eval_deriv`.
    fn eval_deriv(&self, arg: &Value<A>, deriv: &Value<A>) -> Result<Value<A>> {
        match *self {
            UnaryOperatorAny::Neg(inner) => inner.eval_deriv(arg, deriv),
        }
    }
}
/// An operator that takes two [`Value`] arguments.
///
/// Implementors must also be `Clone` and `Debug`.
pub trait BinaryOperator<A>: Clone + Debug
where
    A: Scalar,
{
    /// Evaluates the operator on `lhs` and `rhs` and returns the resulting
    /// value.
    fn eval_value(&self, lhs: &Value<A>, rhs: &Value<A>) -> Result<Value<A>>;

    /// Evaluates the derivative of the operator given both arguments and the
    /// incoming derivative `deriv`, returning a pair of values.
    ///
    /// NOTE(review): presumably the pair holds the derivatives with respect
    /// to `lhs` and `rhs`, in that order -- confirm against the
    /// implementations in `add`.
    fn eval_deriv(
        &self,
        lhs: &Value<A>,
        rhs: &Value<A>,
        deriv: &Value<A>,
    ) -> Result<(Value<A>, Value<A>)>;
}
/// Closed set of the binary operators defined by this module, wrapped in a
/// single type so they can be dispatched with a `match`.
#[derive(Debug, Clone, Copy, IntoEnum)]
pub enum BinaryOperatorAny {
    /// The operator defined in the `add` submodule.
    Add(add::Add),
}
/// Creates an `add::Add` operator converted into a [`BinaryOperatorAny`].
pub fn add() -> BinaryOperatorAny {
    let op = add::Add {};
    op.into()
}
/// Forwards each call to the concrete operator stored in the active variant.
impl<A> BinaryOperator<A> for BinaryOperatorAny
where
    A: Scalar,
{
    /// Delegates to the wrapped operator's `eval_value`.
    fn eval_value(&self, lhs: &Value<A>, rhs: &Value<A>) -> Result<Value<A>> {
        // The enum is `Copy`, so the wrapped operator is bound by value.
        match *self {
            BinaryOperatorAny::Add(inner) => inner.eval_value(lhs, rhs),
        }
    }

    /// Delegates to the wrapped operator's `eval_deriv`.
    fn eval_deriv(
        &self,
        lhs: &Value<A>,
        rhs: &Value<A>,
        deriv: &Value<A>,
    ) -> Result<(Value<A>, Value<A>)> {
        match *self {
            BinaryOperatorAny::Add(inner) => inner.eval_deriv(lhs, rhs, deriv),
        }
    }
}