1use std::{
2 cell::{ Ref, RefCell },
3 collections::HashSet,
4 fmt::Debug,
5 hash::Hash,
6 iter::Sum,
7 ops::{ Add, Deref, Mul, Neg, Sub },
8 rc::Rc,
9};
10
/// The primitive operation that produced a [`Value`] node in the computation graph.
///
/// `backward_by_operation` dispatches on this to apply the matching local derivative.
/// `PartialEq`, `Eq` and `Hash` are derived so that every variant (including `Exp`) is
/// equal to itself and equal variants always hash identically.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Sum of two input nodes.
    Add,
    /// Product of two input nodes.
    Mul,
    /// Hyperbolic tangent of a single input node.
    Tanh,
    /// Power: the first input raised to the second (produced by `Value::pow`; this is not `e^x`).
    Exp,
    /// Leaf node: a plain number that was not produced by any operation.
    None,
}
57
58#[derive(Clone, Eq, PartialEq, Debug)]
60pub struct Value(Rc<RefCell<ValueInternal>>);
61
62impl Value {
63 pub fn from<T>(t: T) -> Value where T: Into<Value> {
64 t.into()
65 }
66
67 fn new(value: ValueInternal) -> Value {
68 Value(Rc::new(RefCell::new(value)))
69 }
70
71 pub fn with_label(self, label: &str) -> Value {
72 self.borrow_mut().label = Some(label.to_string());
73 self
74 }
75
76 pub fn data(&self) -> f64 {
77 self.borrow().data
78 }
79
80 pub fn operation(&self) -> Operation {
81 self.borrow().operation
82 }
83
84 pub fn gradient(&self) -> f64 {
85 self.borrow().gradient
86 }
87
88 pub fn clear_gradient(&self) {
89 self.borrow_mut().gradient = 0.0;
90 }
91
92 pub fn adjust(&self, factor: f64) {
93 let mut value = self.borrow_mut();
94 value.data += factor * value.gradient;
95 }
96
97 pub fn pow(&self, other: &Value) -> Value {
98 let result = self.borrow().data.powf(other.borrow().data);
99
100 Value::new(
101 ValueInternal::new(result, None, Operation::Exp, vec![self.clone(), other.clone()])
102 )
103 }
104
105 pub fn tanh(&self) -> Value {
106 let result = self.borrow().data.tanh();
107
108 Value::new(ValueInternal::new(result, None, Operation::Tanh, vec![self.clone()]))
109 }
110
111 pub fn backward(&self) {
112 let mut visited: HashSet<Value> = HashSet::new();
113
114 self.borrow_mut().gradient = 1.0;
115 self.backward_internal(&mut visited, self);
116 }
117
118 fn backward_internal(&self, visited: &mut HashSet<Value>, value: &Value) {
119 if !visited.contains(&value) {
120 visited.insert(value.clone());
121
122 let borrowed_value = value.borrow();
123 if value.operation() != Operation::None {
124 backward_by_operation(&borrowed_value);
125 }
126
127 for child_id in &value.borrow().previous {
128 self.backward_internal(visited, child_id);
129 }
130 }
131 }
132
133 pub fn trace(&self) {
134 let mut visited: HashSet<Value> = HashSet::new();
135 println!("Tracing value...");
136 println!();
137 trace_internal(&mut visited, self);
138 println!();
139 println!("Done tracing value!")
140 }
141}
142
143impl Hash for Value {
144 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
145 self.0.borrow().hash(state);
146 }
147}
148
149impl Deref for Value {
150 type Target = Rc<RefCell<ValueInternal>>;
151
152 fn deref(&self) -> &Self::Target {
153 &self.0
154 }
155}
156
157impl<T: Into<f64>> From<T> for Value {
158 fn from(t: T) -> Value {
159 Value::new(ValueInternal::new(t.into(), None, Operation::None, Vec::new()))
160 }
161}
162
163impl Add<Value> for Value {
164 type Output = Value;
165
166 fn add(self, other: Value) -> Self::Output {
167 add(&self, &other)
168 }
169}
170
171impl<'a, 'b> Add<&'b Value> for &'a Value {
172 type Output = Value;
173
174 fn add(self, other: &'b Value) -> Self::Output {
175 add(self, other)
176 }
177}
178
179fn add(a: &Value, b: &Value) -> Value {
180 let result = a.borrow().data + b.borrow().data;
181
182 Value::new(ValueInternal::new(result, None, Operation::Add, vec![a.clone(), b.clone()]))
183}
184
185impl Add<f64> for Value {
188 type Output = Value;
189
190 fn add(self, other: f64) -> Self::Output {
191 add(&self, &Value::from(other))
192 }
193}
194
195impl Add<Value> for f64 {
197 type Output = Value;
198
199 fn add(self, other: Value) -> Self::Output {
200 add(&Value::from(self), &other)
201 }
202}
203
204impl Sub<Value> for Value {
205 type Output = Value;
206
207 fn sub(self, other: Value) -> Self::Output {
208 add(&self, &-other)
209 }
210}
211
212impl Mul<Value> for Value {
213 type Output = Value;
214
215 fn mul(self, other: Value) -> Self::Output {
216 mul(&self, &other)
217 }
218}
219
220impl<'a, 'b> Mul<&'b Value> for &'a Value {
221 type Output = Value;
222
223 fn mul(self, other: &'b Value) -> Self::Output {
224 mul(self, other)
225 }
226}
227
228impl Mul<f64> for Value {
231 type Output = Value;
232
233 fn mul(self, other: f64) -> Self::Output {
234 mul(&self, &Value::from(other))
235 }
236}
237impl Mul<Value> for f64 {
239 type Output = Value;
240
241 fn mul(self, other: Value) -> Self::Output {
242 mul(&Value::from(self), &other)
243 }
244}
245
246fn mul(a: &Value, b: &Value) -> Value {
247 let result = a.borrow().data * b.borrow().data;
248
249 Value::new(ValueInternal::new(result, None, Operation::Mul, vec![a.clone(), b.clone()]))
250}
251
252impl Neg for Value {
253 type Output = Value;
254
255 fn neg(self) -> Self::Output {
256 mul(&self, &Value::from(-1))
257 }
258}
259
260impl<'a> Neg for &'a Value {
261 type Output = Value;
262
263 fn neg(self) -> Self::Output {
264 mul(self, &Value::from(-1))
265 }
266}
267
268impl Sum for Value {
269 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
270 let mut sum = Value::from(0.0);
271 loop {
272 let val = iter.next();
273 if val.is_none() {
274 break;
275 }
276
277 sum = sum + val.unwrap();
278 }
279 sum
280 }
281}
282
283pub struct ValueInternal {
298 data: f64,
299 gradient: f64,
300 label: Option<String>,
301 operation: Operation,
302 previous: Vec<Value>,
303}
304
305impl ValueInternal {
306 fn new(data: f64, label: Option<String>, op: Operation, prev: Vec<Value>) -> ValueInternal {
307 ValueInternal {
308 data,
309 gradient: 0.0,
310 label,
311 operation: op,
312 previous: prev,
313 }
314 }
315}
316
317impl PartialEq for ValueInternal {
318 fn eq(&self, other: &Self) -> bool {
319 self.data == other.data &&
320 self.gradient == other.gradient &&
321 self.label == other.label &&
322 self.operation == other.operation &&
323 self.previous == other.previous
324 }
325}
326
327impl Eq for ValueInternal {}
328
329impl Hash for ValueInternal {
330 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
331 self.data.to_bits().hash(state);
332 self.gradient.to_bits().hash(state);
333 self.label.hash(state);
334 self.operation.hash(state);
335 self.previous.hash(state);
336 }
337}
338
339impl Debug for ValueInternal {
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.debug_struct("ValueInternal")
342 .field("data", &self.data)
343 .field("gradient", &self.gradient)
344 .field("label", &self.label)
345 .field("operation", &self.operation)
346 .field("previous", &self.previous)
347 .finish()
348 }
349}
350
351fn backward_by_operation(val: &Ref<ValueInternal>) {
357 match val.operation {
358 Operation::Tanh => {
359 let mut previous = val.previous[0].borrow_mut();
360 previous.gradient += (1.0 - val.data.powf(2.0)) * val.gradient;
361 }
362 Operation::Exp => {
363 let mut base = val.previous[0].borrow_mut();
364 let power = val.previous[1].borrow();
365 base.gradient += power.data * base.data.powf(power.data - 1.0) * val.gradient;
366 }
367 Operation::Add => {
369 let mut previous = val.previous[0].borrow_mut();
370 previous.gradient += 1.0 * val.gradient;
371 let mut previous = val.previous[1].borrow_mut();
372 previous.gradient += 1.0 * val.gradient;
373 }
374 Operation::Mul => {
375 let mut first = val.previous[0].borrow_mut();
376 let mut second = val.previous[1].borrow_mut();
377 first.gradient += second.data * val.gradient;
378 second.gradient += first.data * val.gradient;
379 }
380 Operation::None => {
381 println!("No operation when running backward method.");
382 }
383 }
384}
385
386fn trace_internal(visited: &mut HashSet<Value>, value: &Value) {
387 if !visited.contains(&value) {
388 visited.insert(value.clone());
389
390 let borrowed_value = value.borrow();
391 println!("{:?}", borrowed_value);
392 println!();
394 for child_id in &value.borrow().previous {
395 trace_internal(visited, child_id);
396 }
397 }
398}