1use super::{Graph, Scalar, Value}; use std::ops::{Add, Sub, Mul, Div};
3
4
5pub trait Operations<'a> {
6 fn add(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
7 fn sub(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
8 fn mul(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
9 fn div(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
10 fn exp(&self, lhs_val: Value) -> Value<'_>;
11 fn pow(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
12
13}
14
15impl<'a> Operations<'a> for Graph {
16 fn add(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
17 let lhs = lhs_val.idx;
18 let rhs = rhs_val.idx;
19
20 let mut scalars = self.scalars.borrow_mut();
21
22 let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
23 let data = scalars[lhs].data + scalars[rhs].data;
24
25 let idx = scalars.len();
26 let scalar = Scalar::new(data, idx, &[lhs, rhs], "+", requires_grad);
27 scalars.push(scalar);
28
29 Value {
30 idx,
31 graph: self,
32 }
33 }
34
35 fn sub(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
36 let lhs = lhs_val.idx;
37 let rhs = rhs_val.idx;
38
39 let mut scalars = self.scalars.borrow_mut();
40
41 let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
42 let data = scalars[lhs].data - scalars[rhs].data;
43
44 let idx = scalars.len();
45 let scalar = Scalar::new(data, idx, &[lhs, rhs], "-", requires_grad);
46 scalars.push(scalar);
47
48 Value {
49 idx,
50 graph: self,
51 }
52 }
53
54
55 fn mul(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
56 let lhs = lhs_val.idx;
57 let rhs = rhs_val.idx;
58
59 let mut scalars = self.scalars.borrow_mut();
60
61 let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
62 let data = scalars[lhs].data * scalars[rhs].data;
63 let idx = scalars.len();
64 scalars.push(Scalar::new(data, idx, &[lhs, rhs], "*", requires_grad));
65 Value {
66 idx,
67 graph: self,
68 }
69 }
70
71 fn div(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
72 let lhs = lhs_val.idx;
73 let rhs = rhs_val.idx;
74 let mut scalars = self.scalars.borrow_mut();
75
76 let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
77 let data = scalars[lhs].data / scalars[rhs].data;
78
79 let idx = scalars.len();
80 scalars.push(Scalar::new(data, idx, &[lhs, rhs], "/", requires_grad));
81 Value {
82 idx,
83 graph: self,
84 }
85 }
86
87 fn exp(&self, lhs_val: Value) -> Value<'_> {
88 let lhs = lhs_val.idx;
89
90 let mut scalars = self.scalars.borrow_mut();
91
92 let requires_grad = scalars[lhs].requires_grad;
93 let data = scalars[lhs].data.exp();
94 let idx = scalars.len();
95 scalars.push(Scalar::new(data, idx, &[lhs], "exp", requires_grad));
96 Value {
97 idx,
98 graph: self,
99 }
100 }
101
102 fn pow(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
103 let lhs = lhs_val.idx;
104 let rhs = rhs_val.idx;
105 let mut scalars = self.scalars.borrow_mut();
106
107 let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
108 let data = f32::powf(scalars[lhs].data, scalars[rhs].data);
109
110 let idx = scalars.len();
111 scalars.push(Scalar::new(data, idx, &[lhs, rhs], "pow", requires_grad));
112 Value {
113 idx,
114 graph: self,
115 }
116 }
117}
118
119
120impl<'a> Add<Value<'a>> for Value<'a> {
125 type Output = Value<'a>;
126
127 fn add(self, rhs: Self) -> Self::Output {
128 self.graph.add(self, rhs)
129 }
130}
131
132impl<'a> Add<f32> for Value<'a> {
133 type Output = Value<'a>;
134
135 fn add(self, rhs: f32) -> Self::Output {
136 let scalar = self.graph.scalar(rhs, false);
137 self.graph.add(self, scalar)
138 }
139}
140
141impl<'a> Add<Value<'a>> for f32 {
142 type Output = Value<'a>;
143
144 fn add(self, rhs: Value<'a>) -> Self::Output {
145 let scalar = rhs.graph.scalar(self, false);
146 rhs.graph.add(scalar, rhs)
147 }
148}
149
150impl<'a> Sub<Value<'a>> for Value<'a> {
155 type Output = Value<'a>;
156
157 fn sub(self, rhs: Self) -> Self::Output {
158 self.graph.sub(self, rhs)
159 }
160}
161
162impl<'a> Sub<f32> for Value<'a> {
163 type Output = Value<'a>;
164
165 fn sub(self, rhs: f32) -> Self::Output {
166 let scalar = self.graph.scalar(rhs, false);
167 self.graph.sub(self, scalar)
168 }
169}
170
171impl<'a> Sub<Value<'a>> for f32 {
172 type Output = Value<'a>;
173
174 fn sub(self, rhs: Value<'a>) -> Self::Output {
175 let scalar = rhs.graph.scalar(self, false);
176 rhs.graph.sub(scalar, rhs)
177 }
178}
179
180
181impl<'a> Mul<Value<'a>> for Value<'a> {
186 type Output = Value<'a>;
187
188 fn mul(self, rhs: Self) -> Self::Output {
189 self.graph.mul(self, rhs)
190 }
191}
192
193impl<'a> Mul<f32> for Value<'a> {
194 type Output = Value<'a>;
195
196 fn mul(self, rhs: f32) -> Self::Output {
197 let scalar = self.graph.scalar(rhs, false);
198 self.graph.mul(self, scalar)
199 }
200}
201
202impl<'a> Mul<Value<'a>> for f32 {
203 type Output = Value<'a>;
204
205 fn mul(self, rhs: Value<'a>) -> Self::Output {
206 let scalar = rhs.graph.scalar(self, false);
207 rhs.graph.mul(scalar, rhs)
208 }
209}
210
211
212
213impl<'a> Div<Value<'a>> for Value<'a> {
218 type Output = Value<'a>;
219
220 fn div(self, rhs: Self) -> Self::Output {
221 self.graph.div(self, rhs)
222 }
223}
224
225impl<'a> Div<f32> for Value<'a> {
226 type Output = Value<'a>;
227
228 fn div(self, rhs: f32) -> Self::Output {
229 let scalar = self.graph.scalar(rhs, false);
230 self.graph.div(self, scalar)
231 }
232}
233
234impl<'a> Div<Value<'a>> for f32 {
235 type Output = Value<'a>;
236
237 fn div(self, rhs: Value<'a>) -> Self::Output {
238 let scalar = rhs.graph.scalar(self, false);
239 rhs.graph.div(scalar, rhs)
240 }
241}