use crate::op::{GradientContext, InputArray};
use crate::tensor::Tensor;
use crate::Float;
use crate::FxHashMap;
use crate::Graph;
use std::cmp::Ordering;
use std::collections::binary_heap::BinaryHeap;

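// Per-node bookkeeping used while building the gradient subgraph.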
struct GradInfo<'graph, T: Float> {
    // True if this node lies on a differentiable path between `ys` and `wrt`.
    has_gradient: bool,
    // True once this node has been scheduled on the propagation heap.
    grad_called: bool,
    // Gradients propagated from this node's outputs, summed lazily.
    computed_grads: InputArray<Tensor<'graph, T>>,
    // Cached sum of `computed_grads`.
    accumulated_grad: Option<Tensor<'graph, T>>,
    // Id of the seed gradient (`gy`) when this node is one of the `ys`.
    default_grad: Option<usize>,
}

impl<'g, T: Float> GradInfo<'g, T> {
    #[inline]
    fn new(has_gradient: bool) -> GradInfo<'g, T> {
        GradInfo {
            has_gradient,
            computed_grads: InputArray::new(),
            grad_called: false,
            accumulated_grad: None,
            default_grad: None,
        }
    }

    #[inline]
    fn push_grad(&mut self, g: Tensor<'g, T>) {
        self.computed_grads.push(g);
    }

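    // Sums the gradients pushed so far (or returns the single one), caching the
    // combined result for reuse.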
    #[inline]
    fn accumulate_then_get(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
        if let Some(acc) = self.accumulated_grad {
            return acc;
        }
        if self.computed_grads.len() == 1 {
            self.computed_grads[0]
        } else {
            let accumulated = g.add_n(self.computed_grads.as_slice());
            self.accumulated_grad = Some(accumulated);
            accumulated
        }
    }

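    // Returns the seed gradient if this node is one of `ys`, otherwise the
    // gradient accumulated from its outputs.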
    #[inline]
    fn get_grad(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
        if let Some(def) = self.default_grad {
            g.tensor(def)
        } else {
            self.accumulate_then_get(g)
        }
    }
}

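// Returns true if any backprop input of `parent` is already marked as having a gradient.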
#[inline]
fn has_marked_child<T: Float>(parent: Tensor<T>, path: &FxHashMap<usize, GradInfo<T>>) -> bool {
    for i in 0..parent.num_backprop_inputs() {
        let child = parent.get_backprop_input(i);
        if path.get(&child.id).unwrap().has_gradient {
            return true;
        }
    }
    false
}

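// Returns true if `node` is one of the differentiation targets in `wrt`.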
#[inline]
fn is_wrt(node: usize, wrt: &[usize]) -> bool {
    wrt.contains(&node)
}

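// Performs a depth-first search from `ys` toward the graph's sources and records,
// for every node encountered, whether a gradient should flow through it.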
fn get_between_nodes<'t, 'g, T: Float>(
    g: &'g Graph<T>,
    ys: &[usize],
    wrt: &[usize],
) -> FxHashMap<usize, GradInfo<'g, T>> {
    let mut ret = FxHashMap::<usize, GradInfo<T>>::default();

    // Iterative DFS; the bool flag marks whether the node's children were already pushed.
    let mut dfs_stack: Vec<(usize, bool)> = ys.iter().map(|&y| (y, false)).collect();
    while let Some((node_id, should_visit)) = dfs_stack.pop() {
        let node = g.tensor(node_id);
        if should_visit {
            // All children are processed; this node carries a gradient if it is
            // differentiable and is either a `wrt` node or has a marked child.
            let has_gradient =
                node.is_differentiable() && (is_wrt(node_id, wrt) || has_marked_child(node, &ret));
            ret.insert(node_id, GradInfo::new(has_gradient));
        } else {
            // Revisit this node after its children.
            dfs_stack.push((node_id, true));
            for i in 0..node.num_backprop_inputs() {
                let child = node.get_backprop_input(i).as_tensor(g);
                if ret.get(&node_id).is_none() {
                    if child.is_source() || !child.is_differentiable() {
                        // Record the child but don't recurse into it; gradients
                        // can't flow past a source or non-differentiable node.
                        ret.insert(
                            child.id,
                            GradInfo::new(child.is_differentiable() && is_wrt(child.id, wrt)),
                        );
                    } else {
                        // Recurse into the child.
                        dfs_stack.push((child.id, false));
                    }
                }
            }
        }
    }
    ret
}

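/// Computes symbolic gradients of `ys` with respect to `wrt` by reverse-mode
/// automatic differentiation.
///
/// Each `ys[i]` is seeded with the gradient tensor whose id is `gys[i]`; the
/// returned vector contains one gradient tensor per entry of `wrt`.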
pub(crate) fn symbolic_gradients<'t, 'g, T: Float>(
    ys: &[usize],
    wrt: &[usize],
    gys: &[usize],
    g: &'g Graph<T>,
) -> Vec<Tensor<'g, T>> {
    assert_eq!(ys.len(), gys.len(), "`ys.len()` must match `gys.len()`");

    // Mark the nodes lying between `ys` and `wrt`.
    let mut between_nodes = get_between_nodes(g, ys, wrt);

    // Seed the output nodes with the given gradients.
    for (y, gy) in ys.iter().zip(gys) {
        between_nodes.get_mut(y).unwrap().default_grad = Some(*gy);
    }

    // Process nodes in reverse topological order: the max-heap pops the node
    // with the highest rank first.
    let mut heap = ys
        .iter()
        .map(|&y| g.tensor(y).to_node())
        .collect::<BinaryHeap<Node>>();

    while let Some(y) = heap.pop() {
        // Compute the gradients of `y`'s inputs from `y`'s own gradient.
        let gxs = {
            let info = between_nodes.get_mut(&y.id).unwrap();

            let gy = info.get_grad(g);

            let y_tensor = g.tensor(y.id);
            let gxs = GradientContext::new(gy, y_tensor, g).extract_input_grads();
            debug_assert_eq!(y_tensor.num_backprop_inputs(), gxs.len());
            gxs
        };

        // Propagate each input gradient to the corresponding child node.
        let y = g.tensor(y.id);
        for i in 0..gxs.len() {
            let x = y.get_backprop_input(i).as_tensor(g);
            let x_info = between_nodes.get_mut(&x.id).unwrap();
            if x_info.has_gradient {
                if let Some(gx) = gxs[i] {
                    x_info.push_grad(gx);
                    // Schedule `x` at most once; its gradients are summed when it is popped.
                    if !x.is_source() && !x_info.grad_called {
                        x_info.grad_called = true;
                        heap.push(x.to_node());
                    }
                }
            }
        }
    }

    // Collect the accumulated gradients for the requested `wrt` nodes.
    let mut ret = Vec::with_capacity(wrt.len());
    for x in wrt {
        let msg1: &str = "Not differentiable with given tensor(s).";
        let info = between_nodes.get_mut(x).expect(msg1);
        if !info.has_gradient {
            panic!("{}", msg1);
        }
        assert!(
            info.default_grad.is_none(),
            "Can't differentiate the objective with respect to itself"
        );
        ret.push(info.accumulate_then_get(g));
    }
    ret
}

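// Heap entry pairing a node id with its topological rank; ordering by rank makes
// the max-heap yield nodes furthest from the sources first, i.e. in reverse
// topological order for backprop.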
struct Node {
    id: usize,
    rank: usize,
}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        self.rank.cmp(&other.rank)
    }
}

impl PartialOrd for Node {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for Node {}

impl PartialEq for Node {
    #[inline]
    fn eq(&self, other: &Node) -> bool {
        self.id == other.id
    }
}

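// Converts a tensor into a heap entry using its precomputed topological rank.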
impl<'t, T: Float> Tensor<'t, T> {
    #[inline]
    fn to_node(&'t self) -> Node {
        Node {
            id: self.id,
            rank: unsafe { self.inner().top_rank },
        }
    }
}