use crate::op::{GradientContext, InputArray};
use crate::tensor::Tensor;
use crate::Float;
use crate::FxHashMap;
use crate::Graph;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Per-node bookkeeping used during gradient propagation.
struct GradInfo<'graph, T: Float> {
    /// True if this node lies on a differentiable path between `ys` and `wrt`.
    has_gradient: bool,
    /// True once this node has been pushed onto the propagation heap.
    grad_called: bool,
    /// Gradient contributions, one per consumer of this node.
    computed_grads: InputArray<Tensor<'graph, T>>,
    /// Lazily built sum of `computed_grads`.
    accumulated_grad: Option<Tensor<'graph, T>>,
    /// For output nodes (`ys`): the id of the externally supplied gradient (`gys`).
    default_grad: Option<usize>,
}
impl<'g, T: Float> GradInfo<'g, T> {
#[inline]
fn new(has_gradient: bool) -> GradInfo<'g, T> {
GradInfo {
has_gradient,
computed_grads: InputArray::new(),
grad_called: false,
accumulated_grad: None,
default_grad: None,
}
}
#[inline]
fn push_grad(&mut self, g: Tensor<'g, T>) {
self.computed_grads.push(g);
}
    /// Returns this node's gradient, summing all consumer contributions with
    /// `add_n` on first use and caching the result for later calls.
    #[inline]
    fn accumulate_then_get(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
if let Some(acc) = self.accumulated_grad {
return acc;
}
if self.computed_grads.len() == 1 {
self.computed_grads[0]
} else {
let accumulated = g.add_n(self.computed_grads.as_slice());
self.accumulated_grad = Some(accumulated);
accumulated
}
}
    /// Returns the gradient for this node, preferring the externally supplied
    /// default gradient (used to seed the output nodes `ys`).
    #[inline]
    fn get_grad(&mut self, g: &'g Graph<T>) -> Tensor<'g, T> {
if let Some(def) = self.default_grad {
g.tensor(def)
} else {
self.accumulate_then_get(g)
}
}
}
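// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the crate's API: why `accumulate_then_get`
// sums over consumers. When a node fans out, the chain rule adds one
// contribution per consumer, and `add_n` builds that sum symbolically. The
// plain-arithmetic analogue below uses only `std`; all names in it are
// hypothetical.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod accumulation_sketch {
    /// y = a * b with a = b = x is the fan-out view of y = x * x:
    /// dy/dx = (dy/da) + (dy/db) = b + a = 2x.
    #[test]
    fn fan_out_gradients_sum() {
        let x = 3.0_f64;
        let (grad_via_a, grad_via_b) = (x, x); // d(a*b)/da = b, d(a*b)/db = a
        assert_eq!(grad_via_a + grad_via_b, 2.0 * x);
    }
}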
/// Returns true if any backprop input of `parent` is already marked as
/// carrying a gradient.
///
/// By construction of `get_between_nodes`, every child of a visited parent is
/// in `path` before this is called, so the `unwrap` below cannot fail.
#[inline]
fn has_marked_child<T: Float>(parent: Tensor<T>, path: &FxHashMap<usize, GradInfo<T>>) -> bool {
for i in 0..parent.num_backprop_inputs() {
let child = parent.get_backprop_input(i);
if path.get(&child.id).unwrap().has_gradient {
return true;
}
}
false
}
/// Returns true if `node` is one of the differentiation targets.
#[inline]
fn is_wrt(node: usize, wrt: &[usize]) -> bool {
    wrt.contains(&node)
}
/// Collects every node lying between `ys` and `wrt`.
///
/// Runs an iterative post-order DFS from `ys`. A node is marked
/// (`has_gradient == true`) iff it is differentiable and is either a `wrt`
/// node itself or has a marked child, i.e. it sits on a differentiable path
/// from some `wrt` node up to some `y`.
fn get_between_nodes<'g, T: Float>(
    g: &'g Graph<T>,
    ys: &[usize],
    wrt: &[usize],
) -> FxHashMap<usize, GradInfo<'g, T>> {
    let mut ret = FxHashMap::<usize, GradInfo<T>>::default();
    // Stack entries are (node id, children already visited?).
    let mut dfs_stack: Vec<(usize, bool)> = ys.iter().map(|&y| (y, false)).collect();
    while let Some((node_id, should_visit)) = dfs_stack.pop() {
        let node = g.tensor(node_id);
        if should_visit {
            // Post-order visit: all children have been classified by now.
            let has_gradient =
                node.is_differentiable() && (is_wrt(node_id, wrt) || has_marked_child(node, &ret));
            ret.insert(node_id, GradInfo::new(has_gradient));
        } else if ret.get(&node_id).is_none() {
            // Pre-order visit: revisit this node after its children are done.
            dfs_stack.push((node_id, true));
            for i in 0..node.num_backprop_inputs() {
                let child = node.get_backprop_input(i).as_tensor(g);
                if child.is_source() || !child.is_differentiable() {
                    // Terminal for the search: no `wrt` node can lie beyond it.
                    ret.insert(
                        child.id,
                        GradInfo::new(child.is_differentiable() && is_wrt(child.id, wrt)),
                    );
                } else {
                    dfs_stack.push((child.id, false));
                }
            }
        }
    }
    ret
}
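// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the crate's API: the marking rule of
// `get_between_nodes` re-implemented on a toy adjacency list, so the upward
// propagation of marks is easy to see in isolation. `ToyNode` and
// `mark_between` are hypothetical names that exist only in this test module.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod marking_sketch {
    /// Toy stand-in for a graph node; `inputs` mirrors `get_backprop_input`.
    struct ToyNode {
        inputs: Vec<usize>,
        differentiable: bool,
    }

    /// Applies the marking rule in topological order (inputs precede users),
    /// which is what the post-order DFS in `get_between_nodes` establishes.
    fn mark_between(nodes: &[ToyNode], wrt: &[usize]) -> Vec<bool> {
        let mut marked = vec![false; nodes.len()];
        for (id, n) in nodes.iter().enumerate() {
            let from_child = n.inputs.iter().any(|&c| marked[c]);
            marked[id] = n.differentiable && (wrt.contains(&id) || from_child);
        }
        marked
    }

    #[test]
    fn marks_only_paths_from_wrt_to_y() {
        // Graph: 0 (wrt) -> 2 -> 3 (y), and 1 (non-differentiable) -> 2.
        let nodes = vec![
            ToyNode { inputs: vec![], differentiable: true },
            ToyNode { inputs: vec![], differentiable: false },
            ToyNode { inputs: vec![0, 1], differentiable: true },
            ToyNode { inputs: vec![2], differentiable: true },
        ];
        assert_eq!(mark_between(&nodes, &[0]), vec![true, false, true, true]);
    }
}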
/// Computes symbolic gradients of `ys` with respect to `wrt`, seeded by `gys`
/// (the gradients flowing into each `y`).
///
/// Nodes are processed in reverse topological order: the max-heap is keyed by
/// `top_rank`, so the deepest node pops first and every consumer contributes
/// its gradient before a node's own gradient is read.
pub(crate) fn symbolic_gradients<'g, T: Float>(
    ys: &[usize],
    wrt: &[usize],
    gys: &[usize],
    g: &'g Graph<T>,
) -> Vec<Tensor<'g, T>> {
    assert_eq!(ys.len(), gys.len(), "`ys.len()` must match `gys.len()`");
    // Mark the subgraph between `ys` and `wrt`, then seed the output nodes
    // with their externally supplied gradients.
    let mut between_nodes = get_between_nodes(g, ys, wrt);
    for (y, gy) in ys.iter().zip(gys) {
        between_nodes.get_mut(y).unwrap().default_grad = Some(*gy);
    }
    let mut heap = ys
        .iter()
        .map(|&y| g.tensor(y).to_node())
        .collect::<BinaryHeap<Node>>();
    while let Some(y) = heap.pop() {
        // Ask `y`'s op for the gradients it propagates to each of its inputs.
        let gxs = {
            let info = between_nodes.get_mut(&y.id).unwrap();
            let gy = info.get_grad(g);
            let y_tensor = g.tensor(y.id);
            let gxs = GradientContext::new(gy, y_tensor, g).extract_input_grads();
            debug_assert_eq!(y_tensor.num_backprop_inputs(), gxs.len());
            gxs
        };
        // Shadow the heap entry with the actual tensor.
        let y = g.tensor(y.id);
        for i in 0..gxs.len() {
            let x = y.get_backprop_input(i).as_tensor(g);
            let x_info = between_nodes.get_mut(&x.id).unwrap();
            if x_info.has_gradient {
                if let Some(gx) = gxs[i] {
                    x_info.push_grad(gx);
                    // Schedule `x` exactly once; source nodes have no inputs
                    // of their own, so they never need to be popped.
                    if !x.is_source() && !x_info.grad_called {
                        x_info.grad_called = true;
                        heap.push(x.to_node());
                    }
                }
            }
        }
}
    // Collect the accumulated gradient for each requested `wrt` node.
    let mut ret = Vec::with_capacity(wrt.len());
    let msg = "Not differentiable with given tensor(s).";
    for x in wrt {
        let info = between_nodes.get_mut(x).expect(msg);
        if !info.has_gradient {
            panic!("{}", msg);
        }
        assert!(
            info.default_grad.is_none(),
            "Can't differentiate with objective itself"
        );
        ret.push(info.accumulate_then_get(g));
    }
ret
}
/// Heap entry: a node id ordered by its topological rank.
///
/// `BinaryHeap` is a max-heap, so the node with the largest rank (the one
/// deepest in the graph) is popped first.
struct Node {
    id: usize,
    rank: usize,
}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        self.rank.cmp(&other.rank)
    }
}

impl PartialOrd for Node {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for Node {}

// Equality compares `rank`, not `id`, to stay consistent with `Ord`
// (`a == b` must agree with `a.cmp(&b) == Ordering::Equal`).
impl PartialEq for Node {
    #[inline]
    fn eq(&self, other: &Node) -> bool {
        self.rank == other.rank
    }
}
impl<'t, T: Float> Tensor<'t, T> {
    /// Converts this tensor into a heap entry keyed by its precomputed
    /// topological rank.
    #[inline]
    fn to_node(&'t self) -> Node {
        Node {
            id: self.id,
            // `inner()` is unsafe in this crate; we only read the
            // precomputed `top_rank`.
            rank: unsafe { self.inner().top_rank },
        }
    }
}
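// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the crate's API: why keying the heap on
// `top_rank` suffices for backprop ordering. `BinaryHeap` is a max-heap, so
// the deepest node (largest rank) pops first, guaranteeing every consumer has
// contributed its gradient before the producer is popped. `RankedId` is a
// hypothetical stand-in for the private `Node` above.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod rank_order_sketch {
    use std::cmp::Ordering;
    use std::collections::BinaryHeap;

    struct RankedId {
        id: usize,
        rank: usize,
    }

    impl Ord for RankedId {
        fn cmp(&self, other: &Self) -> Ordering {
            self.rank.cmp(&other.rank)
        }
    }

    impl PartialOrd for RankedId {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Eq for RankedId {}

    impl PartialEq for RankedId {
        fn eq(&self, other: &Self) -> bool {
            self.rank == other.rank
        }
    }

    #[test]
    fn deepest_node_pops_first() {
        let mut heap = BinaryHeap::new();
        for &(id, rank) in &[(0usize, 0usize), (2, 2), (1, 1)] {
            heap.push(RankedId { id, rank });
        }
        let order: Vec<usize> = std::iter::from_fn(|| heap.pop().map(|n| n.id)).collect();
        // Max-heap on `rank`: the output (rank 2) precedes its inputs.
        assert_eq!(order, vec![2, 1, 0]);
    }
}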