// ferrite/autograd/scalar/operations.rs

1use super::{Graph, Scalar, Value};  // Import Graph from parent module
2use std::ops::{Add, Sub, Mul, Div};
3
4
5pub trait Operations<'a> {
6  fn add(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
7  fn sub(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
8  fn mul(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
9  fn div(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
10  fn exp(&self, lhs_val: Value) -> Value<'_>;
11  fn pow(&self, lhs_val: Value, rhs_val: Value) -> Value<'_>;
12
13}
14
15impl<'a> Operations<'a> for Graph {
16  fn add(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
17    let lhs = lhs_val.idx;
18    let rhs = rhs_val.idx;
19
20    let mut scalars = self.scalars.borrow_mut();
21
22    let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
23    let data = scalars[lhs].data + scalars[rhs].data;
24
25    let idx = scalars.len();
26    let scalar = Scalar::new(data, idx, &[lhs, rhs], "+", requires_grad);
27    scalars.push(scalar);
28    
29    Value {
30      idx,
31      graph: self,
32    }
33  }
34
35  fn sub(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
36    let lhs = lhs_val.idx;
37    let rhs = rhs_val.idx;
38
39    let mut scalars = self.scalars.borrow_mut();
40
41    let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
42    let data = scalars[lhs].data - scalars[rhs].data;
43
44    let idx = scalars.len();
45    let scalar = Scalar::new(data, idx, &[lhs, rhs], "-", requires_grad);
46    scalars.push(scalar);
47    
48    Value {
49      idx,
50      graph: self,
51    }
52  }
53
54
55  fn mul(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
56    let lhs = lhs_val.idx;
57    let rhs = rhs_val.idx;
58
59    let mut scalars = self.scalars.borrow_mut();
60
61    let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
62    let data = scalars[lhs].data * scalars[rhs].data;
63    let idx = scalars.len();
64    scalars.push(Scalar::new(data, idx, &[lhs, rhs], "*", requires_grad));
65    Value {
66      idx,
67      graph: self,
68    }
69  }
70
71  fn div(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
72    let lhs = lhs_val.idx;
73    let rhs = rhs_val.idx;
74    let mut scalars = self.scalars.borrow_mut();
75
76    let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
77    let data = scalars[lhs].data / scalars[rhs].data;
78
79    let idx = scalars.len();
80    scalars.push(Scalar::new(data, idx, &[lhs, rhs], "/", requires_grad));
81    Value {
82      idx,
83      graph: self,
84    }
85  }
86
87  fn exp(&self, lhs_val: Value) -> Value<'_> {
88    let lhs = lhs_val.idx;
89
90    let mut scalars = self.scalars.borrow_mut();
91
92    let requires_grad = scalars[lhs].requires_grad;
93    let data = scalars[lhs].data.exp();
94    let idx = scalars.len();
95    scalars.push(Scalar::new(data, idx, &[lhs], "exp", requires_grad));
96    Value {
97      idx,
98      graph: self,
99    }
100  }
101
102  fn pow(&self, lhs_val: Value, rhs_val: Value) -> Value<'_> {
103    let lhs = lhs_val.idx;
104    let rhs = rhs_val.idx;
105    let mut scalars = self.scalars.borrow_mut();
106
107    let requires_grad = scalars[lhs].requires_grad || scalars[rhs].requires_grad;
108    let data = f32::powf(scalars[lhs].data, scalars[rhs].data);
109
110    let idx = scalars.len();
111    scalars.push(Scalar::new(data, idx, &[lhs, rhs], "pow", requires_grad));
112    Value {
113      idx,
114      graph: self,
115    }
116  }
117}
118
119
// ========================
// Add overloads: Value + Value, Value + f32, f32 + Value
// ========================
123
124impl<'a> Add<Value<'a>> for Value<'a> {
125  type Output = Value<'a>;
126
127  fn add(self, rhs: Self) -> Self::Output {
128    self.graph.add(self, rhs)
129  }
130}
131
132impl<'a> Add<f32> for Value<'a> {
133  type Output = Value<'a>;
134
135  fn add(self, rhs: f32) -> Self::Output {
136    let scalar = self.graph.scalar(rhs, false);
137    self.graph.add(self, scalar)
138  }
139}
140
141impl<'a> Add<Value<'a>> for f32 {
142  type Output = Value<'a>;
143
144  fn add(self, rhs: Value<'a>) -> Self::Output {
145    let scalar = rhs.graph.scalar(self, false);
146    rhs.graph.add(scalar, rhs)
147  }
148}
149
// ========================
// Sub overloads: Value - Value, Value - f32, f32 - Value
// ========================
153
154impl<'a> Sub<Value<'a>> for Value<'a> {
155  type Output = Value<'a>;
156
157  fn sub(self, rhs: Self) -> Self::Output {
158    self.graph.sub(self, rhs)
159  }
160}
161
162impl<'a> Sub<f32> for Value<'a> {
163  type Output = Value<'a>;
164
165  fn sub(self, rhs: f32) -> Self::Output {
166    let scalar = self.graph.scalar(rhs, false);
167    self.graph.sub(self, scalar)
168  }
169}
170
171impl<'a> Sub<Value<'a>> for f32 {
172  type Output = Value<'a>;
173
174  fn sub(self, rhs: Value<'a>) -> Self::Output {
175    let scalar = rhs.graph.scalar(self, false);
176    rhs.graph.sub(scalar, rhs)
177  }
178}
179
180
// ========================
// Mul overloads: Value * Value, Value * f32, f32 * Value
// ========================
184
185impl<'a> Mul<Value<'a>> for Value<'a> {
186  type Output = Value<'a>;
187
188  fn mul(self, rhs: Self) -> Self::Output {
189    self.graph.mul(self, rhs)
190  }
191}
192
193impl<'a> Mul<f32> for Value<'a> {
194  type Output = Value<'a>;
195
196  fn mul(self, rhs: f32) -> Self::Output {
197    let scalar = self.graph.scalar(rhs, false);
198    self.graph.mul(self, scalar)
199  }
200}
201
202impl<'a> Mul<Value<'a>> for f32 {
203  type Output = Value<'a>;
204
205  fn mul(self, rhs: Value<'a>) -> Self::Output {
206    let scalar = rhs.graph.scalar(self, false);
207    rhs.graph.mul(scalar, rhs)
208  }
209}
210
211
212
// ========================
// Div overloads: Value / Value, Value / f32, f32 / Value
// ========================
216
217impl<'a> Div<Value<'a>> for Value<'a> {
218  type Output = Value<'a>;
219
220  fn div(self, rhs: Self) -> Self::Output {
221    self.graph.div(self, rhs)
222  }
223}
224
225impl<'a> Div<f32> for Value<'a> {
226  type Output = Value<'a>;
227
228  fn div(self, rhs: f32) -> Self::Output {
229    let scalar = self.graph.scalar(rhs, false);
230    self.graph.div(self, scalar)
231  }
232}
233
234impl<'a> Div<Value<'a>> for f32 {
235  type Output = Value<'a>;
236
237  fn div(self, rhs: Value<'a>) -> Self::Output {
238    let scalar = rhs.graph.scalar(self, false);
239    rhs.graph.div(scalar, rhs)
240  }
241}