autograd/gradient.rs

//! Types and functions related to gradient computation.
use crate::op::{GradientContext, InputArray};
use crate::tensor::Tensor;
use crate::Float;
use crate::FxHashMap;
use crate::Graph;
use std::cmp::Ordering;
use std::collections::binary_heap::BinaryHeap;

// Gradient information associated with a single `Tensor` during backprop.
struct GradInfo<'graph, T: Float> {
    has_gradient: bool,                            // on a differentiable path between `ys` and `wrt`
    grad_called: bool,                             // already scheduled on the backprop heap
    computed_grads: InputArray<Tensor<'graph, T>>, // gradient contributions pushed by consumers
    accumulated_grad: Option<Tensor<'graph, T>>,   // cached sum of `computed_grads`
    default_grad: Option<usize>,                   // id of a user-provided output gradient (gy)
}

impl<'g, T: Float> GradInfo<'g, T> {
    #[inline]
    fn new(has_gradient: bool) -> GradInfo<'g, T> {
        GradInfo {
            has_gradient,
            computed_grads: InputArray::new(),
            grad_called: false,
            accumulated_grad: None,
            default_grad: None,
        }
    }

    #[inline]
    fn push_grad(&mut self, g: Tensor<'g, T>) {
        self.computed_grads.push(g);
    }

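    // Illustrative note (hypothetical expression, not crate-specific API): a tensor
    // that feeds several consumers receives one gradient contribution per consumer.
    // For z = x * a + x * b with an output gradient of one, backprop pushes `a`
    // (from the `x * a` term) and `b` (from the `x * b` term) into `computed_grads`
    // of `x`, and `accumulate_then_get` sums them with `add_n`, giving dz/dx = a + b.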
    #[inline]
    fn accumulate_then_get(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
        if let Some(acc) = self.accumulated_grad {
            return acc;
        }
        if self.computed_grads.len() == 1 {
            self.computed_grads[0]
        } else {
            // accumulation is required
            let accumulated = g.add_n(self.computed_grads.as_slice());
            self.accumulated_grad = Some(accumulated);
            accumulated
        }
    }

    // Returns the user-provided default gradient (set for `ys`) if any,
    // otherwise the accumulated gradient contributions.
    #[inline]
    fn get_grad(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
        if let Some(def) = self.default_grad {
            g.tensor(def)
        } else {
            self.accumulate_then_get(g)
        }
    }
}

#[inline]
fn has_marked_child<T: Float>(parent: Tensor<T>, path: &FxHashMap<usize, GradInfo<T>>) -> bool {
    for i in 0..parent.num_backprop_inputs() {
        let child = parent.get_backprop_input(i);
        if path.get(&child.id).unwrap().has_gradient {
            return true;
        }
    }
    false
}

#[inline]
fn is_wrt(node: usize, wrt: &[usize]) -> bool {
    wrt.contains(&node)
}

// Goes backward from `ys` and collects the nodes needed for backprop, stopping at `wrt`.
//
// Strategy
//   1. Record all nodes that are reachable from `ys` into `ret`.
//   2. Mark the nodes on the paths between `ys` and `wrt` with `has_gradient`.
fn get_between_nodes<'t, 'g, T: Float>(
    g: &'g Graph<T>,
    ys: &[usize],
    wrt: &[usize],
) -> FxHashMap<usize, GradInfo<'g, T>> {
    // Keyed by node id for random access.
    let mut ret = FxHashMap::<usize, GradInfo<T>>::default();

    // Builds GradInfo while performing a depth-first search;
    // the `has_gradient` flags are filled in at the same time.

    // dfs_stack: (node_id, should_visit)
    let mut dfs_stack: Vec<(usize, bool)> = ys.iter().map(|&y| (y, false)).collect();
    while let Some((node_id, should_visit)) = dfs_stack.pop() {
        let node = g.tensor(node_id);
        if should_visit {
            let has_gradient =
                node.is_differentiable() && (is_wrt(node_id, wrt) || has_marked_child(node, &ret));
            ret.insert(node_id, GradInfo::new(has_gradient));
        } else {
            // Re-push self; it is visited again after its children have been handled.
            dfs_stack.push((node_id, true));
            // Push children as necessary
            for i in 0..node.num_backprop_inputs() {
                let child = node.get_backprop_input(i).as_tensor(g);
                if ret.get(&child.id).is_none() {
                    if child.is_source() || !child.is_differentiable() {
                        // Record the child, but stop searching in this direction:
                        // no differentiable path to `wrt` can continue past this node.
                        ret.insert(
                            child.id,
                            GradInfo::new(child.is_differentiable() && is_wrt(child.id, wrt)),
                        );
                    } else {
                        // Recurse
                        dfs_stack.push((child.id, false));
                    }
                }
            }
        }
    }
    ret
}
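// Illustration (hypothetical graph; the node names are illustrative): for
// y = (x * w) + b with wrt = [w], the DFS above records {y, mul, x, w, b}.
// `has_gradient` ends up true for {y, mul, w}, which lie on a differentiable path
// from `y` to `w`, while the leaves `x` and `b` are recorded with
// `has_gradient == false`, so no gradient ops will be built for their branches.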

/// Returns symbolic gradient tensors of the tensors in `wrt`.
///
/// This computes the partial derivatives of `ys` with respect to `wrt` and returns
/// the gradients. It does so by building a gradient subgraph between `ys` and
/// `wrt`, traversing the user's graph definition in reverse.
/// `gys` are the already-known gradients of the outputs `ys`.
///
/// NOTE: Nodes that have no gradient are not included in the subgraph, to avoid
/// unnecessary computation.
pub(crate) fn symbolic_gradients<'t, 'g, T: Float>(
    ys: &[usize],
    wrt: &[usize],
    gys: &[usize],
    g: &'g Graph<T>,
) -> Vec<Tensor<'g, T>> {
    assert_eq!(ys.len(), gys.len(), "`ys.len()` must match `gys.len()`");

    // Set up the gradient path, looked up by tensor id.
    let mut between_nodes = get_between_nodes(g, ys, wrt);

    // Set the default (seed) gradients for `ys`.
    for (y, gy) in ys.iter().zip(gys) {
        between_nodes.get_mut(y).unwrap().default_grad = Some(*gy);
    }

    // Prepare a max-heap seeded with `ys`.
    let mut heap = ys
        .iter()
        .map(|&y| g.tensor(y).to_node())
        .collect::<BinaryHeap<Node>>();

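    // Invariant: because the heap is ordered by topological rank (see `Node` below),
    // a node is popped only after every consumer that contributes to its gradient has
    // been processed, so `get_grad` always sees the complete set of contributions.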
    // Backprop, starting from `ys`.
    while let Some(y) = heap.pop() {
        let gxs = {
            let info = between_nodes.get_mut(&y.id).unwrap();

            let gy = info.get_grad(g);

            // Call Op::grad (this mutates the graph)
            let y_tensor = g.tensor(y.id);
            let gxs = GradientContext::new(gy, y_tensor, g).extract_input_grads();
            debug_assert_eq!(y_tensor.num_backprop_inputs(), gxs.len());
            gxs
        };

        // Register the computed gradients with this node's inputs
        let y = g.tensor(y.id);
        for i in 0..gxs.len() {
            let x = y.get_backprop_input(i).as_tensor(g);
            let x_info = between_nodes.get_mut(&x.id).unwrap();
            if x_info.has_gradient {
                if let Some(gx) = gxs[i] {
                    x_info.push_grad(gx);
                    // Update the heap (each node is pushed at most once)
                    if !x.is_source() && !x_info.grad_called {
                        x_info.grad_called = true;
                        heap.push(x.to_node());
                    }
                }
            }
        }
    }

    // Aggregate and return the gradients of `wrt`
    let mut ret = Vec::with_capacity(wrt.len());
    for x in wrt {
        let msg1: &str = "Not differentiable with given tensor(s).";
        let info = between_nodes.get_mut(x).expect(msg1);
        if !info.has_gradient {
            panic!("{}", msg1);
        }
        assert!(
            info.default_grad.is_none(),
            "Can't differentiate the objective with respect to itself"
        );
        ret.push(info.accumulate_then_get(g));
    }
    ret
}
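// Hedged usage sketch (the tensor ids are hypothetical; this is a crate-internal
// helper). Given the id of an output `y`, the ids of variables `a` and `b` in `wrt`,
// and the id of a seed gradient `gy` (typically a ones tensor shaped like `y`):
//
//     let grads = symbolic_gradients(&[y], &[a, b], &[gy], graph);
//     // grads[0] is dy/da and grads[1] is dy/db, returned as symbolic
//     // tensors that can be evaluated later.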

// An entry in the backprop heap: a tensor id paired with its topological rank.
struct Node {
    id: usize,
    rank: usize,
}

impl Ord for Node {
    // Compares the ranks in topological ordering (the max-heap pops larger ranks first)
    fn cmp(&self, other: &Self) -> Ordering {
        self.rank.cmp(&other.rank)
    }
}

impl PartialOrd for Node {
    // Delegates to `Ord`, which compares the ranks in topological ordering
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for Node {}

impl PartialEq for Node {
    // Note that equality is defined by node id, while ordering is by rank.
    #[inline]
    fn eq(&self, other: &Node) -> bool {
        self.id == other.id
    }
}

impl<'t, T: Float> Tensor<'t, T> {
    // Wraps this tensor as a heap entry carrying its topological rank.
    #[inline]
    fn to_node(&'t self) -> Node {
        Node {
            id: self.id,
            rank: unsafe { self.inner().top_rank },
        }
    }
}
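
// An illustrative sanity check (a minimal sketch; assumes only std and the `Node`
// type above): the max-heap used in `symbolic_gradients` yields nodes with larger
// topological ranks first, so a node's gradient is propagated to its inputs only
// after its own contributions have been gathered.
#[cfg(test)]
mod node_ordering_sketch {
    use super::Node;
    use std::collections::BinaryHeap;

    #[test]
    fn heap_pops_larger_rank_first() {
        let mut heap = BinaryHeap::new();
        heap.push(Node { id: 0, rank: 1 });
        heap.push(Node { id: 1, rank: 3 });
        heap.push(Node { id: 2, rank: 2 });
        // `BinaryHeap` is a max-heap, so the node with the largest rank comes out first.
        assert_eq!(heap.pop().unwrap().id, 1);
        assert_eq!(heap.pop().unwrap().id, 2);
        assert_eq!(heap.pop().unwrap().id, 0);
        assert!(heap.pop().is_none());
    }
}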