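// ferrite/autograd/scalar/graph.rs
//
// Scalar autograd graph: `Graph` stores `Scalar` nodes, and `Value` is a
// lightweight handle (index + graph reference) to one of them.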

use std::cell::{Cell, RefCell};
use std::fmt;

6#[derive(Debug, Clone, Copy)]
7pub struct Value<'a> {
8  pub idx: usize,
9  pub graph: &'a Graph,
10}
11
12impl<'a> Value<'a> {
13  pub fn new(idx: usize, graph: &'a Graph) -> Self {
14    Value { 
15      idx,
16      graph,
17    }
18  }
19
20  pub fn grad(&self) -> f32 {
21    self.graph.scalars.borrow()[self.idx].grad.get()
22  }
23
24  pub fn value(&self) -> f32 {
25    self.graph.scalars.borrow()[self.idx].data
26  }
27}
28
/// A single node of the scalar computation graph.
///
/// Nodes refer to their inputs by index (`prev`) rather than by reference.
/// The gradient sits in a [`Cell`] so it can be updated through a shared
/// `&Scalar`.
#[derive(Debug)]
pub struct Scalar {
  /// Forward value of this node.
  pub data: f32,
  /// This node's own position in the owning graph's scalar list.
  pub idx: usize,
  /// Indices of the nodes this one was computed from (empty for leaves).
  pub prev: Vec<usize>,
  /// Label of the operation that produced this node (`""` for leaves).
  pub op: String,
  /// Gradient stored on this node; starts at `0.0`.
  pub grad: Cell<f32>,
  /// Flag recorded at creation; presumably controls gradient tracking
  /// (it is not read within this block — TODO confirm against callers).
  pub requires_grad: bool,
}

impl Scalar {
  /// Builds a node with a zero gradient.
  ///
  /// `prev` is copied into an owned `Vec` and `op` into an owned `String`,
  /// so the node does not borrow from its arguments.
  pub fn new(data: f32, idx: usize, prev: &[usize], op: &str, requires_grad: bool) -> Self {
    Self {
      data,
      idx,
      prev: prev.to_vec(),
      op: op.to_owned(),
      grad: Cell::new(0.0),
      requires_grad,
    }
  }
}

impl fmt::Display for Scalar {
  /// Formats as `Scalar(data=<data>, grad=<grad>)`.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "Scalar(data={}, grad={})", self.data, self.grad.get())
  }
}

59#[derive(Debug)]
60pub struct Graph {
61  pub scalars: RefCell<Vec<Scalar>>,
62}
63
64impl Graph {
65  pub fn new() -> Self {
66    Graph {
67      scalars: RefCell::new(Vec::new()),
68    }
69  }
70
71  pub fn scalar(&self, data: f32, requires_grad: bool) -> Value {
72    let mut scalars = self.scalars.borrow_mut();
73    let idx = scalars.len();
74    scalars.push(Scalar::new(data, idx, &[], "", requires_grad));
75    
76    Value {
77      idx: idx,
78      graph: self,
79    }
80  }
81
82  pub fn get_value(&self, value: &Value) -> f32 {
83    let id = value.idx;
84    self.scalars.borrow()[id].data
85  }
86
87  pub fn get_op(&self, value: &Value) -> String {
88    let id = value.idx;
89    self.scalars.borrow()[id].op.clone()
90  }
91
92  pub fn get_grad(&self, value: &Value) -> f32 {
93    let id = value.idx;
94    self.scalars.borrow()[id].grad.get()
95  }
96}