acme_tensor/impls/
grad.rs1use crate::actions::grad::TensorGrad;
6use crate::prelude::{ScalarExt, TensorExpr, TensorId, TensorResult};
7use crate::TensorBase;
8use acme::prelude::{Arithmetic, BinaryOp, Store, UnaryOp};
9
/// Per-traversal memo mapping a tensor id to whether that node's subtree
/// contains a variable (i.e. whether it participates in gradient tracking).
pub(crate) type Visited<K = TensorId> = std::collections::HashMap<K, bool>;
11
/// Fetch-or-create the gradient slot for a tensor in a gradient store.
///
/// `entry!(store, tensor)` looks up `tensor.id()` in `store`, inserting a
/// zero gradient (`tensor.zeros_like()`) when the slot is vacant, and yields
/// a mutable reference to the stored gradient so callers can `+=`/`-=` into it.
///
/// Uses `or_insert_with` so the default expression — an allocating
/// `zeros_like()` at every call site — is only evaluated when the entry is
/// actually vacant; the previous `or_insert` evaluated it unconditionally
/// on every lookup.
macro_rules! entry {
    ($ctx:expr, $entry:expr) => {
        entry!($ctx, $entry, $entry.zeros_like())
    };
    ($ctx:expr, $entry:expr, $default:expr) => {
        $ctx.entry($entry.id()).or_insert_with(|| $default)
    };
}
20
impl<T> TensorBase<T>
where
    T: ScalarExt,
{
    /// Topologically sort the gradient-relevant nodes of the expression
    /// graph rooted at `self`.
    ///
    /// A depth-first, post-order walk collects only the nodes whose subtree
    /// contains a variable (anything else cannot affect a gradient). With
    /// `reverse == false` the list is leaves-first / root-last;
    /// `reverse == true` flips it to root-first, the order that reverse-mode
    /// autodiff (`grad`) consumes.
    fn toposort(&self, reverse: bool) -> Vec<&TensorBase<T>> {
        // Recursive DFS over the op graph. Returns
        // (subtree-contains-a-variable, accumulated nodes). `visited`
        // memoizes the flag per tensor id so shared subexpressions are
        // walked only once.
        fn walk<'a, T>(
            scope: &'a TensorBase<T>,
            nodes: Vec<&'a TensorBase<T>>,
            visited: &mut Visited<TensorId>,
        ) -> (bool, Vec<&'a TensorBase<T>>) {
            // Already visited: reuse the cached tracking flag.
            if let Some(&tg) = visited.get(&scope.id()) {
                return (tg, nodes);
            }
            let mut track = false;
            let mut nodes = if scope.is_variable() {
                // Variables are the gradient sources; they terminate the walk.
                track = true;
                nodes
            } else if let Some(op) = scope.op().op() {
                // A node tracks gradients iff any of its operands does.
                match op {
                    TensorExpr::Binary(lhs, rhs, _kind) => {
                        let (tg, nodes) = walk(lhs, nodes, visited);
                        track |= tg;
                        let (tg, nodes) = walk(rhs, nodes, visited);
                        track |= tg;
                        nodes
                    }
                    TensorExpr::Unary(a, _kind) => {
                        let (tg, nodes) = walk(a, nodes, visited);
                        track |= tg;
                        nodes
                    }
                    // NOTE(review): `BinaryScalar` and `Sigmoid` expressions
                    // are not descended into here, yet `grad` below has
                    // backward rules for them — confirm whether those
                    // variants should be walked as well.
                    _ => nodes,
                }
            } else {
                nodes
            };
            visited.insert(scope.id(), track);
            // Post-order push: operands are recorded before the node itself.
            if track {
                nodes.push(scope);
            }
            (track, nodes)
        }
        let (_tg, mut nodes) = walk(self, Vec::new(), &mut Visited::new());
        if reverse {
            nodes.reverse();
        }
        nodes
    }
    /// Reverse-mode automatic differentiation.
    ///
    /// Seeds `d(self)/d(self) = ones_like(self)`, then walks the graph
    /// root-first (`toposort(true)`), applying the chain rule for each
    /// operation and accumulating operand gradients in a [`TensorGrad`]
    /// store keyed by tensor id. Returns the populated store.
    ///
    /// # Panics
    /// Panics if a non-variable node in the sorted list has no gradient
    /// entry ("Gradient not found"), or (for `Sqrt`) if `2` is not
    /// representable in `T`.
    pub fn grad(&self) -> TensorResult<TensorGrad<T>> {
        // Root-first order guarantees a node's own gradient is complete
        // before it is propagated into its operands.
        let sorted = self.toposort(true);
        let mut store = TensorGrad::new();
        // Seed: the gradient of the output w.r.t. itself is all ones.
        store.insert(self.id(), self.ones_like());

        for node in sorted.iter() {
            // Leaf variables accumulate gradients but have no operands to
            // propagate into.
            if node.is_variable() {
                continue;
            }
            // NOTE(review): this panics instead of surfacing an error through
            // the `TensorResult` return type — consider `Err` instead.
            let grad = store.remove(&node.id()).expect("Gradient not found");
            // Detach so the gradient arithmetic below does not itself extend
            // the computation graph.
            let grad = grad.detach();
            if let Some(op) = &*node.op {
                match op {
                    // Tensor-tensor arithmetic: standard chain-rule adjoints.
                    TensorExpr::Binary(lhs, rhs, kind) => match kind {
                        BinaryOp::Arith(inner) => match inner {
                            Arithmetic::Add(_) => {
                                // d(l + r): dl += g, dr += g
                                *entry!(store, lhs) += &grad;
                                *entry!(store, rhs) += &grad;
                            }
                            Arithmetic::Div(_) => {
                                // d(l / r): dl += g / r, dr -= g * l / r^2
                                *entry!(store, lhs) += &grad / rhs.as_ref();
                                *entry!(store, rhs) -= &grad * lhs.as_ref() / rhs.sqr();
                            }
                            Arithmetic::Mul(_) => {
                                // d(l * r): dl += g * r, dr += g * l
                                *entry!(store, lhs) += &grad * rhs.as_ref();
                                *entry!(store, rhs) += &grad * lhs.as_ref();
                            }
                            Arithmetic::Sub(_) => {
                                // d(l - r): dl += g, dr -= g
                                *entry!(store, lhs) += &grad;
                                *entry!(store, rhs) -= &grad;
                            }
                            _ => todo!(),
                        },
                        _ => todo!(),
                    },
                    // Tensor-scalar arithmetic: only the tensor operand
                    // receives a gradient (the scalar is a constant).
                    TensorExpr::BinaryScalar(lhs, rhs, kind) => match kind {
                        BinaryOp::Arith(inner) => match inner {
                            Arithmetic::Add(_) => {
                                *entry!(store, lhs) += &grad;
                            }
                            Arithmetic::Div(_) => {
                                *entry!(store, lhs) += &grad / *rhs;
                            }
                            Arithmetic::Mul(_) => {
                                *entry!(store, lhs) += &grad * *rhs;
                            }
                            Arithmetic::Pow(_) => {
                                // Power rule: d(l^c) = c * l^(c-1) * g
                                *entry!(store, lhs) += &grad * *rhs * lhs.pow(*rhs - T::one());
                            }
                            Arithmetic::Sub(_) => {
                                *entry!(store, lhs) += &grad;
                            }
                            _ => todo!(),
                        },
                        _ => todo!(),
                    },
                    // Element-wise unary ops: multiply the incoming gradient
                    // by the op's derivative evaluated at the input.
                    TensorExpr::Unary(val, kind) => match kind {
                        UnaryOp::Cos => {
                            // d(cos x) = -sin x
                            *entry!(store, val) -= &grad * val.sin();
                        }
                        UnaryOp::Cosh => {
                            // d(cosh x) = sinh x
                            *entry!(store, val) += &grad * val.sinh();
                        }
                        UnaryOp::Exp => {
                            // d(exp x) = exp x
                            *entry!(store, val) += &grad * val.exp();
                        }
                        UnaryOp::Neg => {
                            // d(-x) = -1
                            *entry!(store, val) -= &grad;
                        }
                        UnaryOp::Recip => {
                            // d(1/x) = -1/x^2
                            *entry!(store, val) -= &grad / val.sqr();
                        }

                        UnaryOp::Sin => {
                            // d(sin x) = cos x
                            *entry!(store, val) += &grad * val.cos();
                        }
                        UnaryOp::Sinh => {
                            // d(sinh x) = cosh x
                            *entry!(store, val) += &grad * val.cosh();
                        }
                        UnaryOp::Sqrt => {
                            // d(sqrt x) = 1 / (2 * sqrt x)
                            *entry!(store, val) +=
                                &grad / (val.clone().sqrt() * T::from(2).unwrap());
                        }
                        UnaryOp::Tan => {
                            // d(tan x) = 1 / cos^2 x
                            *entry!(store, val) += &grad / val.clone().cos().sqr();
                        }

                        // NOTE(review): any other unary op silently propagates
                        // no gradient — confirm that is intentional.
                        _ => {}
                    },
                    TensorExpr::Sigmoid(val) => {
                        // d(sigmoid x) = sigmoid(x) * (1 - sigmoid(x));
                        // computed on a detached copy of the input.
                        let tmp = val.detach();
                        *entry!(store, val) +=
                            &grad * tmp.sigmoid() * (tmp.ones_like() - tmp.sigmoid());
                    }
                    // Remaining expression kinds propagate no gradient.
                    _ => {}
                }
            }
        }

        Ok(store)
    }
}