1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use crate::{NdArray, TensorDataType};
use std::cell::RefCell;
use std::rc::Rc;
/// Interface implemented by every gradient function that can take part in
/// backpropagation.
///
/// Only [`backward`](GradientFuncTrait::backward) has to be implemented; the
/// remaining methods have defaults that implementors may override.
pub(crate) trait GradientFuncTrait<T>
where
    T: TensorDataType,
{
    /// Applies the chain rule to compute the gradient of this function with
    /// respect to its sources.
    ///
    /// # Parameters
    ///
    /// - `grad`: gradient of the function being differentiated, taken with
    ///   respect to `self`.
    fn backward(&mut self, grad: &NdArray<T>);

    /// Gradient of the function being differentiated with respect to `self`,
    /// available when this function is a leaf.
    ///
    /// The default implementation always returns `None`.
    fn gradient(&self) -> Option<NdArray<T>> {
        None
    }

    /// Sets the gradient of this tensor to zero.
    ///
    /// The default implementation does nothing.
    fn zero_gradient(&mut self) {}

    /// Reports whether this gradient function is `NoneBackwards`.
    ///
    /// The default implementation returns `false`.
    fn is_none(&self) -> bool {
        false
    }
}
/// Shared handle to a dynamically dispatched gradient function.
///
/// `Rc` lets several owners hold the same function, and `RefCell` provides the
/// runtime-checked mutable access that `backward` and `zero_gradient`
/// (both `&mut self`) require.
pub(crate) type GradientFunction<T> = Rc<RefCell<dyn GradientFuncTrait<T>>>;
/// Forwards a gradient to the next gradient function, unless that function
/// reports `is_none()`.
///
/// # Forms
///
/// - `call_next_backward!(grad, next)` calls `next.borrow_mut().backward(&grad)`.
/// - `call_next_backward!(grad, shape, next)` does the same when
///   `shape == grad.shape()`; otherwise it first passes the gradient through
///   `reduce_gradient(&grad, shape)` and forwards the result.
///
/// In both forms `grad` is only evaluated when `next` is not "none", and it is
/// evaluated at most once.
///
/// # Notes
///
/// - `reduce_gradient` is resolved at the call site, so it must be in scope
///   wherever the three-argument form is used.
/// - `next` and `shape` are expanded more than once; pass cheap,
///   side-effect-free expressions for them.
/// - `next` is accessed through `borrow()` / `borrow_mut()`; for a `RefCell`
///   this panics if it is already borrowed incompatibly.
#[macro_export]
macro_rules! call_next_backward {
    ($grad:expr, $next:expr) => {
        if !$next.borrow().is_none() {
            $next.borrow_mut().backward(&$grad);
        }
    };
    ($grad:expr, $shape:expr, $next:expr) => {
        if !$next.borrow().is_none() {
            // Bind the gradient once: `$grad` may be an arbitrary (possibly
            // expensive) expression, and expanding it for both the shape check
            // and the call would compute it twice. Borrowing also avoids
            // moving the caller's value.
            let grad = &$grad;
            if $shape == grad.shape() {
                $next.borrow_mut().backward(grad);
            } else {
                let next_grad = reduce_gradient(grad, $shape);
                $next.borrow_mut().backward(&next_grad);
            }
        }
    };
}