acme_tensor/impls/grad.rs

/*
    Appellation: grad <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::actions::grad::TensorGrad;
use crate::prelude::{ScalarExt, TensorExpr, TensorId, TensorResult};
use crate::TensorBase;
use acme::prelude::{Arithmetic, BinaryOp, Store, UnaryOp};

/// Maps each node id to a flag indicating whether that node (or any of its
/// children) requires a gradient.
pub(crate) type Visited<K = TensorId> = std::collections::HashMap<K, bool>;

// Fetches the gradient accumulator for `$entry` from the store `$ctx`, inserting a
// default on first access; the two-argument form defaults to `$entry.zeros_like()`.
macro_rules! entry {
    ($ctx:expr, $entry:expr) => {
        entry!($ctx, $entry, $entry.zeros_like())
    };
    ($ctx:expr, $entry:expr, $default:expr) => {
        $ctx.entry($entry.id()).or_insert($default)
    };
}

impl<T> TensorBase<T>
where
    T: ScalarExt,
{
    /// Sorts the nodes of the operation graph in topological order, keeping
    /// only the nodes that require a gradient.
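    ///
    /// For example, for `d = (a * b) + c` with `a`, `b`, and `c` variables, the
    /// reversed ordering yields the `+` node first and the leaves last, which is
    /// exactly the order in which `grad` propagates gradients from the output
    /// back to the inputs.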
    fn toposort(&self, reverse: bool) -> Vec<&TensorBase<T>> {
        // Here, the sorted nodes are passed as an owned value rather than as a mutable reference to work around some lifetime limitations.
        fn walk<'a, T>(
            scope: &'a TensorBase<T>,
            nodes: Vec<&'a TensorBase<T>>,
            visited: &mut Visited<TensorId>,
        ) -> (bool, Vec<&'a TensorBase<T>>) {
            if let Some(&tg) = visited.get(&scope.id()) {
                return (tg, nodes);
            }
            // tracks whether the current node requires a gradient
            let mut track = false;
            // recursively call on the child nodes
            let mut nodes = if scope.is_variable() {
                // Do not recurse into the "leaf" nodes.
                track = true;
                nodes
            } else if let Some(op) = scope.op().op() {
                match op {
                    TensorExpr::Binary(lhs, rhs, _kind) => {
                        let (tg, nodes) = walk(lhs, nodes, visited);
                        track |= tg;
                        let (tg, nodes) = walk(rhs, nodes, visited);
                        track |= tg;
                        nodes
                    }
                    TensorExpr::Unary(a, _kind) => {
                        let (tg, nodes) = walk(a, nodes, visited);
                        track |= tg;
                        nodes
                    }
                    _ => nodes,
                }
            } else {
                nodes
            };
            visited.insert(scope.id(), track);
            if track {
                nodes.push(scope);
            }
            (track, nodes)
        }
        // walk through the DAG
        let (_tg, mut nodes) = walk(self, Vec::new(), &mut Visited::new());
        // reverse the nodes, if needed
        if reverse {
            nodes.reverse();
        }
        // return the sorted nodes
        nodes
    }
    /// Computes the gradient of the tensor with respect to each variable (leaf node) in its graph.
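    ///
    /// A hedged usage sketch (the constructors and indexing below are assumptions
    /// for illustration, not necessarily this crate's API):
    ///
    /// ```ignore
    /// // c = a * b, where `a` and `b` are variables
    /// let c = a.mul(&b);
    /// // back-propagate from `c`; the store maps each variable's id to its gradient
    /// let grads = c.grad()?;
    /// let dc_da = &grads[&a.id()]; // equal to `b`
    /// ```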
    pub fn grad(&self) -> TensorResult<TensorGrad<T>> {
        // get the sorted nodes, output first
        let sorted = self.toposort(true);
        // initialize a new gradient store
        let mut store = TensorGrad::new();
        // seed the store with the gradient of the output w.r.t. itself (i.e. ones)
        store.insert(self.id(), self.ones_like());

        for node in sorted.iter() {
            if node.is_variable() {
                continue;
            }
            // take the accumulated gradient of the node out of the store
            let grad = store.remove(&node.id()).expect("Gradient not found");
            // detach the gradient from the graph
            let grad = grad.detach();
            // handle the different types of operations
            if let Some(op) = &*node.op {
                match op {
                    TensorExpr::Binary(lhs, rhs, kind) => match kind {
                        BinaryOp::Arith(inner) => match inner {
                            Arithmetic::Add(_) => {
                                // d(l + r)/dl = 1, d(l + r)/dr = 1
                                *entry!(store, lhs) += &grad;
                                *entry!(store, rhs) += &grad;
                            }
                            Arithmetic::Div(_) => {
                                // d(l / r)/dl = 1 / r, d(l / r)/dr = -l / r^2
                                *entry!(store, lhs) += &grad / rhs.as_ref();
                                *entry!(store, rhs) -= &grad * lhs.as_ref() / rhs.sqr();
                            }
                            Arithmetic::Mul(_) => {
                                // d(l * r)/dl = r, d(l * r)/dr = l
                                *entry!(store, lhs) += &grad * rhs.as_ref();
                                *entry!(store, rhs) += &grad * lhs.as_ref();
                            }
                            Arithmetic::Sub(_) => {
                                // d(l - r)/dl = 1, d(l - r)/dr = -1
                                *entry!(store, lhs) += &grad;
                                *entry!(store, rhs) -= &grad;
                            }
                            _ => todo!(),
                        },
                        _ => todo!(),
                    },
                    // the scalar operand is a constant, so only the tensor side receives a gradient
                    TensorExpr::BinaryScalar(lhs, rhs, kind) => match kind {
                        BinaryOp::Arith(inner) => match inner {
                            Arithmetic::Add(_) => {
                                *entry!(store, lhs) += &grad;
                            }
                            Arithmetic::Div(_) => {
                                *entry!(store, lhs) += &grad / *rhs;
                            }
                            Arithmetic::Mul(_) => {
                                *entry!(store, lhs) += &grad * *rhs;
                            }
                            Arithmetic::Pow(_) => {
                                // power rule: d(l^r)/dl = r * l^(r - 1)
                                *entry!(store, lhs) += &grad * *rhs * lhs.pow(*rhs - T::one());
                            }
                            Arithmetic::Sub(_) => {
                                *entry!(store, lhs) += &grad;
                            }
                            _ => todo!(),
                        },
                        _ => todo!(),
                    },
                    TensorExpr::Unary(val, kind) => match kind {
                        UnaryOp::Cos => {
                            // d(cos x)/dx = -sin x
                            *entry!(store, val) -= &grad * val.sin();
                        }
                        UnaryOp::Cosh => {
                            *entry!(store, val) += &grad * val.sinh();
                        }
                        UnaryOp::Exp => {
                            *entry!(store, val) += &grad * val.exp();
                        }
                        UnaryOp::Neg => {
                            *entry!(store, val) -= &grad;
                        }
                        UnaryOp::Recip => {
                            // d(1 / x)/dx = -1 / x^2
                            *entry!(store, val) -= &grad / val.sqr();
                        }
                        UnaryOp::Sin => {
                            *entry!(store, val) += &grad * val.cos();
                        }
                        UnaryOp::Sinh => {
                            *entry!(store, val) += &grad * val.cosh();
                        }
                        UnaryOp::Sqrt => {
                            // d(sqrt x)/dx = 1 / (2 * sqrt x)
                            *entry!(store, val) +=
                                &grad / (val.clone().sqrt() * T::from(2).unwrap());
                        }
                        UnaryOp::Tan => {
                            // d(tan x)/dx = 1 / cos^2 x
                            *entry!(store, val) += &grad / val.clone().cos().sqr();
                        }
                        _ => {}
                    },
                    TensorExpr::Sigmoid(val) => {
                        // d(sigmoid x)/dx = sigmoid(x) * (1 - sigmoid(x))
                        let tmp = val.detach();
                        *entry!(store, val) +=
                            &grad * tmp.sigmoid() * (tmp.ones_like() - tmp.sigmoid());
                    }
                    _ => {}
                }
            }
        }

        Ok(store)
    }
}
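
// ---------------------------------------------------------------------------
// Illustrative sketch (assumption-heavy, not part of this crate's API): the
// reverse-mode pattern used by `grad` above, reduced to scalar `f64` nodes in
// an index-ordered arena so it is self-contained and runnable. The arena index
// stands in for `toposort`'s ordering, and the `HashMap` of accumulated
// gradients mirrors the `entry!`-based store. Names such as `Node`, `forward`,
// and `backward` are hypothetical.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod reverse_mode_sketch {
    use std::collections::HashMap;

    /// A node of a tiny expression DAG; children are referenced by arena index.
    #[derive(Clone, Copy)]
    enum Node {
        Leaf(f64),
        Add(usize, usize),
        Mul(usize, usize),
    }

    /// Evaluates every node in index (i.e. topological) order.
    fn forward(nodes: &[Node]) -> Vec<f64> {
        let mut values = vec![0.0; nodes.len()];
        for (i, node) in nodes.iter().enumerate() {
            values[i] = match *node {
                Node::Leaf(v) => v,
                Node::Add(a, b) => values[a] + values[b],
                Node::Mul(a, b) => values[a] * values[b],
            };
        }
        values
    }

    /// Accumulates gradients by walking the nodes in reverse topological order,
    /// seeding the output with 1 and applying the sum/product rules per node.
    fn backward(nodes: &[Node], values: &[f64], output: usize) -> HashMap<usize, f64> {
        let mut store = HashMap::new();
        store.insert(output, 1.0); // d(output)/d(output) = 1
        for i in (0..=output).rev() {
            let grad = match store.get(&i) {
                Some(&g) => g,
                None => continue, // node does not influence the output
            };
            match nodes[i] {
                Node::Leaf(_) => {}
                Node::Add(a, b) => {
                    *store.entry(a).or_insert(0.0) += grad;
                    *store.entry(b).or_insert(0.0) += grad;
                }
                Node::Mul(a, b) => {
                    *store.entry(a).or_insert(0.0) += grad * values[b];
                    *store.entry(b).or_insert(0.0) += grad * values[a];
                }
            }
        }
        store
    }

    #[test]
    fn grad_of_x_times_x_plus_y() {
        // z = x * (x + y) with x = 3, y = 2  =>  dz/dx = 2x + y = 8, dz/dy = x = 3
        let nodes = [
            Node::Leaf(3.0),
            Node::Leaf(2.0),
            Node::Add(0, 1),
            Node::Mul(0, 2),
        ];
        let values = forward(&nodes);
        let grads = backward(&nodes, &values, 3);
        assert_eq!(grads[&0], 8.0);
        assert_eq!(grads[&1], 3.0);
    }
}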