nanograd/
value.rs

1use std::{
2    cell::{ Ref, RefCell },
3    collections::HashSet,
4    fmt::Debug,
5    hash::Hash,
6    iter::Sum,
7    ops::{ Add, Deref, Mul, Neg, Sub },
8    rc::Rc,
9};
10
/// The operation that produced a [`Value`] in the computation graph.
///
/// Equality, hashing and `Debug` are derived: two operations are equal exactly when they
/// are the same variant (every variant, including `Exp`, equals itself), and `Debug`
/// prints the bare variant name (e.g. `Exp`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Sum of two inputs (`previous = [lhs, rhs]`).
    Add,
    /// Product of two inputs (`previous = [lhs, rhs]`).
    Mul,
    /// Hyperbolic tangent of a single input (`previous = [input]`).
    Tanh,
    /// Power built by `Value::pow` (`previous = [base, exponent]`).
    Exp,
    /// Leaf value that was not produced by any operation (`previous` is empty).
    None,
}
57
58// wrapped struct code adopted from https://github.com/danielway/micrograd-rs
59#[derive(Clone, Eq, PartialEq, Debug)]
60pub struct Value(Rc<RefCell<ValueInternal>>);
61
62impl Value {
63    pub fn from<T>(t: T) -> Value where T: Into<Value> {
64        t.into()
65    }
66
67    fn new(value: ValueInternal) -> Value {
68        Value(Rc::new(RefCell::new(value)))
69    }
70
71    pub fn with_label(self, label: &str) -> Value {
72        self.borrow_mut().label = Some(label.to_string());
73        self
74    }
75
76    pub fn data(&self) -> f64 {
77        self.borrow().data
78    }
79
80    pub fn operation(&self) -> Operation {
81        self.borrow().operation
82    }
83
84    pub fn gradient(&self) -> f64 {
85        self.borrow().gradient
86    }
87
88    pub fn clear_gradient(&self) {
89        self.borrow_mut().gradient = 0.0;
90    }
91
92    pub fn adjust(&self, factor: f64) {
93        let mut value = self.borrow_mut();
94        value.data += factor * value.gradient;
95    }
96
97    pub fn pow(&self, other: &Value) -> Value {
98        let result = self.borrow().data.powf(other.borrow().data);
99
100        Value::new(
101            ValueInternal::new(result, None, Operation::Exp, vec![self.clone(), other.clone()])
102        )
103    }
104
105    pub fn tanh(&self) -> Value {
106        let result = self.borrow().data.tanh();
107
108        Value::new(ValueInternal::new(result, None, Operation::Tanh, vec![self.clone()]))
109    }
110
111    pub fn backward(&self) {
112        let mut visited: HashSet<Value> = HashSet::new();
113
114        self.borrow_mut().gradient = 1.0;
115        self.backward_internal(&mut visited, self);
116    }
117
118    fn backward_internal(&self, visited: &mut HashSet<Value>, value: &Value) {
119        if !visited.contains(&value) {
120            visited.insert(value.clone());
121
122            let borrowed_value = value.borrow();
123            if value.operation() != Operation::None {
124                backward_by_operation(&borrowed_value);
125            }
126
127            for child_id in &value.borrow().previous {
128                self.backward_internal(visited, child_id);
129            }
130        }
131    }
132
133    pub fn trace(&self) {
134        let mut visited: HashSet<Value> = HashSet::new();
135        println!("Tracing value...");
136        println!();
137        trace_internal(&mut visited, self);
138        println!();
139        println!("Done tracing value!")
140    }
141}
142
143impl Hash for Value {
144    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
145        self.0.borrow().hash(state);
146    }
147}
148
149impl Deref for Value {
150    type Target = Rc<RefCell<ValueInternal>>;
151
152    fn deref(&self) -> &Self::Target {
153        &self.0
154    }
155}
156
157impl<T: Into<f64>> From<T> for Value {
158    fn from(t: T) -> Value {
159        Value::new(ValueInternal::new(t.into(), None, Operation::None, Vec::new()))
160    }
161}
162
163impl Add<Value> for Value {
164    type Output = Value;
165
166    fn add(self, other: Value) -> Self::Output {
167        add(&self, &other)
168    }
169}
170
171impl<'a, 'b> Add<&'b Value> for &'a Value {
172    type Output = Value;
173
174    fn add(self, other: &'b Value) -> Self::Output {
175        add(self, other)
176    }
177}
178
179fn add(a: &Value, b: &Value) -> Value {
180    let result = a.borrow().data + b.borrow().data;
181
182    Value::new(ValueInternal::new(result, None, Operation::Add, vec![a.clone(), b.clone()]))
183}
184
185// defines addition for f64 and Value
186// in this case the f64 is on the right side of the addition
187impl Add<f64> for Value {
188    type Output = Value;
189
190    fn add(self, other: f64) -> Self::Output {
191        add(&self, &Value::from(other))
192    }
193}
194
195// in this case the f64 is on the right side of the addition
196impl Add<Value> for f64 {
197    type Output = Value;
198
199    fn add(self, other: Value) -> Self::Output {
200        add(&Value::from(self), &other)
201    }
202}
203
204impl Sub<Value> for Value {
205    type Output = Value;
206
207    fn sub(self, other: Value) -> Self::Output {
208        add(&self, &-other)
209    }
210}
211
212impl Mul<Value> for Value {
213    type Output = Value;
214
215    fn mul(self, other: Value) -> Self::Output {
216        mul(&self, &other)
217    }
218}
219
220impl<'a, 'b> Mul<&'b Value> for &'a Value {
221    type Output = Value;
222
223    fn mul(self, other: &'b Value) -> Self::Output {
224        mul(self, other)
225    }
226}
227
228// defines multiplication for f64 and Value
229// in this case the f64 is on the right side of the multiplication
230impl Mul<f64> for Value {
231    type Output = Value;
232
233    fn mul(self, other: f64) -> Self::Output {
234        mul(&self, &Value::from(other))
235    }
236}
237// in this case the f64 is on the right side of the multiplication
238impl Mul<Value> for f64 {
239    type Output = Value;
240
241    fn mul(self, other: Value) -> Self::Output {
242        mul(&Value::from(self), &other)
243    }
244}
245
246fn mul(a: &Value, b: &Value) -> Value {
247    let result = a.borrow().data * b.borrow().data;
248
249    Value::new(ValueInternal::new(result, None, Operation::Mul, vec![a.clone(), b.clone()]))
250}
251
252impl Neg for Value {
253    type Output = Value;
254
255    fn neg(self) -> Self::Output {
256        mul(&self, &Value::from(-1))
257    }
258}
259
260impl<'a> Neg for &'a Value {
261    type Output = Value;
262
263    fn neg(self) -> Self::Output {
264        mul(self, &Value::from(-1))
265    }
266}
267
268impl Sum for Value {
269    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
270        let mut sum = Value::from(0.0);
271        loop {
272            let val = iter.next();
273            if val.is_none() {
274                break;
275            }
276
277            sum = sum + val.unwrap();
278        }
279        sum
280    }
281}
282
283///
284/// Internal representation of a Value
285///
286/// # Fields
287/// * `data` - The data of the value
288/// * `gradient` - The gradient of the value
289/// * `label` - The label of the value
290/// * `operation` - The operation of the value
291/// * `previous` - The previous values of the value
292/// * `propagate` - The propagate function of the value
293///
294/// # Methods
295/// * `new` - Returns a new ValueInternal
296///
297pub struct ValueInternal {
298    data: f64,
299    gradient: f64,
300    label: Option<String>,
301    operation: Operation,
302    previous: Vec<Value>,
303}
304
305impl ValueInternal {
306    fn new(data: f64, label: Option<String>, op: Operation, prev: Vec<Value>) -> ValueInternal {
307        ValueInternal {
308            data,
309            gradient: 0.0,
310            label,
311            operation: op,
312            previous: prev,
313        }
314    }
315}
316
317impl PartialEq for ValueInternal {
318    fn eq(&self, other: &Self) -> bool {
319        self.data == other.data &&
320            self.gradient == other.gradient &&
321            self.label == other.label &&
322            self.operation == other.operation &&
323            self.previous == other.previous
324    }
325}
326
327impl Eq for ValueInternal {}
328
329impl Hash for ValueInternal {
330    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
331        self.data.to_bits().hash(state);
332        self.gradient.to_bits().hash(state);
333        self.label.hash(state);
334        self.operation.hash(state);
335        self.previous.hash(state);
336    }
337}
338
339impl Debug for ValueInternal {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.debug_struct("ValueInternal")
342            .field("data", &self.data)
343            .field("gradient", &self.gradient)
344            .field("label", &self.label)
345            .field("operation", &self.operation)
346            .field("previous", &self.previous)
347            .finish()
348    }
349}
350
351///
352/// Backward by operation
353///
354/// # Arguments
355/// * `val` - The value to apply the backward method to
356fn backward_by_operation(val: &Ref<ValueInternal>) {
357    match val.operation {
358        Operation::Tanh => {
359            let mut previous = val.previous[0].borrow_mut();
360            previous.gradient += (1.0 - val.data.powf(2.0)) * val.gradient;
361        }
362        Operation::Exp => {
363            let mut base = val.previous[0].borrow_mut();
364            let power = val.previous[1].borrow();
365            base.gradient += power.data * base.data.powf(power.data - 1.0) * val.gradient;
366        }
367        // gradient flows through plus signs
368        Operation::Add => {
369            let mut previous = val.previous[0].borrow_mut();
370            previous.gradient += 1.0 * val.gradient;
371            let mut previous = val.previous[1].borrow_mut();
372            previous.gradient += 1.0 * val.gradient;
373        }
374        Operation::Mul => {
375            let mut first = val.previous[0].borrow_mut();
376            let mut second = val.previous[1].borrow_mut();
377            first.gradient += second.data * val.gradient;
378            second.gradient += first.data * val.gradient;
379        }
380        Operation::None => {
381            println!("No operation when running backward method.");
382        }
383    }
384}
385
386fn trace_internal(visited: &mut HashSet<Value>, value: &Value) {
387    if !visited.contains(&value) {
388        visited.insert(value.clone());
389
390        let borrowed_value = value.borrow();
391        println!("{:?}", borrowed_value);
392        // separate values with a newline
393        println!();
394        for child_id in &value.borrow().previous {
395            trace_internal(visited, child_id);
396        }
397    }
398}