ferrite/autograd/scalar/
graph.rs1use std::cell::{Cell, RefCell};
4use std::fmt;
5
6#[derive(Debug, Clone, Copy)]
7pub struct Value<'a> {
8 pub idx: usize,
9 pub graph: &'a Graph,
10}
11
12impl<'a> Value<'a> {
13 pub fn new(idx: usize, graph: &'a Graph) -> Self {
14 Value {
15 idx,
16 graph,
17 }
18 }
19
20 pub fn grad(&self) -> f32 {
21 self.graph.scalars.borrow()[self.idx].grad.get()
22 }
23
24 pub fn value(&self) -> f32 {
25 self.graph.scalars.borrow()[self.idx].data
26 }
27}
28
/// One node record stored in a graph's `scalars` vector.
#[derive(Debug)]
pub struct Scalar {
    /// Numeric payload of the node.
    pub data: f32,
    /// Index recorded for this node at construction time.
    pub idx: usize,
    /// Indices copied from the `prev` slice given to [`Scalar::new`]
    /// (presumably the operand nodes; verify against the code that builds
    /// non-leaf nodes).
    pub prev: Vec<usize>,
    /// Operation label copied from the `op` argument of [`Scalar::new`].
    pub op: String,
    /// Gradient slot; starts at zero and can be updated through a shared
    /// reference because it is a `Cell`.
    pub grad: Cell<f32>,
    /// Flag stored verbatim from the constructor argument.
    pub requires_grad: bool,
}

impl Scalar {
    /// Creates a node record with its gradient initialised to `0.0`.
    ///
    /// `prev` and `op` are copied into owned storage, so the caller keeps
    /// ownership of both arguments.
    pub fn new(data: f32, idx: usize, prev: &[usize], op: &str, requires_grad: bool) -> Self {
        let prev = Vec::from(prev);
        let op = op.to_owned();
        let grad = Cell::new(0.0);

        Self {
            data,
            idx,
            prev,
            op,
            grad,
            requires_grad,
        }
    }
}
51
52impl fmt::Display for Scalar {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "Scalar(data={}, grad={})", self.data, self.grad.get())
55 }
56}
57
58
59#[derive(Debug)]
60pub struct Graph {
61 pub scalars: RefCell<Vec<Scalar>>,
62}
63
64impl Graph {
65 pub fn new() -> Self {
66 Graph {
67 scalars: RefCell::new(Vec::new()),
68 }
69 }
70
71 pub fn scalar(&self, data: f32, requires_grad: bool) -> Value {
72 let mut scalars = self.scalars.borrow_mut();
73 let idx = scalars.len();
74 scalars.push(Scalar::new(data, idx, &[], "", requires_grad));
75
76 Value {
77 idx: idx,
78 graph: self,
79 }
80 }
81
82 pub fn get_value(&self, value: &Value) -> f32 {
83 let id = value.idx;
84 self.scalars.borrow()[id].data
85 }
86
87 pub fn get_op(&self, value: &Value) -> String {
88 let id = value.idx;
89 self.scalars.borrow()[id].op.clone()
90 }
91
92 pub fn get_grad(&self, value: &Value) -> f32 {
93 let id = value.idx;
94 self.scalars.borrow()[id].grad.get()
95 }
96}