Struct dfdx::gradients::GradientTape
pub struct GradientTape { /* private fields */ }
Records gradient computations to execute later.
The only two things you can do with this are:
- Adding an operation (an operation is a FnOnce that acts on &mut Gradients)
- Executing all the operations to produce Gradients
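The two capabilities above can be illustrated with a minimal, self-contained sketch. It does not use dfdx's actual types: `ToyTape`, `ToyGradients`, `entry`, `get`, and `scale_example` are hypothetical names invented for illustration, and running the operations in reverse recording order is an assumption made here because that is how reverse-mode differentiation is usually arranged.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `Gradients`: maps a tensor id to its gradient values.
#[derive(Debug, Default)]
pub struct ToyGradients {
    grads: HashMap<usize, Vec<f32>>,
}

impl ToyGradients {
    /// Returns the gradient buffer for `id`, zero-initialised on first access.
    pub fn entry(&mut self, id: usize, len: usize) -> &mut Vec<f32> {
        self.grads.entry(id).or_insert_with(|| vec![0.0; len])
    }

    /// Read-only access to a gradient, if one has been created.
    pub fn get(&self, id: usize) -> Option<&Vec<f32>> {
        self.grads.get(&id)
    }
}

/// Hypothetical stand-in for `GradientTape`: a list of deferred operations.
#[derive(Default)]
pub struct ToyTape {
    operations: Vec<Box<dyn FnOnce(&mut ToyGradients)>>,
}

impl ToyTape {
    /// Capability 1: record an operation that acts on `&mut ToyGradients`.
    pub fn add_backward_op<F>(&mut self, op: F)
    where
        F: FnOnce(&mut ToyGradients) + 'static,
    {
        self.operations.push(Box::new(op));
    }

    /// Capability 2: run every recorded operation (last recorded first,
    /// an assumption of this sketch) and return the resulting gradients.
    pub fn execute(self) -> ToyGradients {
        let mut grads = ToyGradients::default();
        for op in self.operations.into_iter().rev() {
            op(&mut grads);
        }
        grads
    }
}

/// Records the backward pass of `y = 3 * x` and returns the gradient of `x`.
pub fn scale_example() -> Vec<f32> {
    const X: usize = 0;
    const Y: usize = 1;
    let mut tape = ToyTape::default();
    // Chain rule for y = 3 * x: grad(x) += 3 * grad(y).
    tape.add_backward_op(move |grads| {
        let y_grad = grads.entry(Y, 2).clone();
        let x_grad = grads.entry(X, 2);
        for (gx, gy) in x_grad.iter_mut().zip(y_grad.iter()) {
            *gx += 3.0 * gy;
        }
    });
    // Seed grad(y) with ones; recorded last, so it runs first.
    tape.add_backward_op(move |grads| {
        grads.entry(Y, 2).iter_mut().for_each(|g| *g = 1.0);
    });
    let grads = tape.execute();
    grads.get(X).cloned().unwrap_or_default()
}
```

Calling `scale_example()` records two closures and then executes them; the tape itself knows nothing about scaling, only that it holds closures to run.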
The reasons for this design, which forces users to specify gradient computations as opposed to having a fixed set of kinds of computations, are:
- Different tensor sizes. The tensor's size information would have to be stored inside the operation somehow. Instead, the operations themselves must query with a sized tensor, so sizes are still known at compile time instead of dynamically.
- Slightly different operations. It would have to support broadcasting operations and similar variants, which can get needlessly complex.
- Optimizations are harder. With operations having control over everything, they can be optimized by hand separately.
An example of how these two are used is the following, taken from the negate operation (i.e. multiply all values by -1).
tape.add_backward_op(move |grads| {
    let (t_grad, result_grad) = grads.mut_and_ref(&t, &_result);
    // addmul is equivalent to: t_grad += t.data() * result_grad;
    T::Device::addmul(t_grad, t.data(), result_grad);
});
This implements the chain rule, which is normally defined as gradient(t) += deriv * gradient(result), with the following optimizations:
- Instead of allocating new data for the derivative (which is just -1 everywhere), we can reuse the t tensor since the negate function owns it.
- We can combine computing the derivative and multiplying by gradient(result) by just setting t to -gradient(result).
This would not be possible if these chain rule operations were inside of GradientTape!
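To make the trade-off concrete, here is a small sketch on plain slices rather than dfdx tensors. The function names (`negate_backward_naive`, `negate_backward_fused`, `negate_backward_demo`) are hypothetical and not part of the library; the sketch only contrasts the textbook chain rule, which materialises the derivative, with a hand-fused version that an operation can write for itself.

```rust
/// Textbook chain rule for negate: allocate the derivative (-1 everywhere),
/// then accumulate `t_grad += deriv * result_grad`.
pub fn negate_backward_naive(t_grad: &mut [f32], result_grad: &[f32]) {
    let deriv = vec![-1.0_f32; t_grad.len()];
    for i in 0..t_grad.len() {
        t_grad[i] += deriv[i] * result_grad[i];
    }
}

/// Hand-fused version: no derivative buffer and no multiplication,
/// because multiplying by -1 and adding is just a subtraction.
pub fn negate_backward_fused(t_grad: &mut [f32], result_grad: &[f32]) {
    for (g, r) in t_grad.iter_mut().zip(result_grad) {
        *g -= *r;
    }
}

/// Runs both versions on the same inputs so their results can be compared.
pub fn negate_backward_demo() -> (Vec<f32>, Vec<f32>) {
    let result_grad = [2.0_f32, -3.0];
    let mut naive = vec![0.0_f32, 1.0];
    let mut fused = naive.clone();
    negate_backward_naive(&mut naive, &result_grad);
    negate_backward_fused(&mut fused, &result_grad);
    (naive, fused)
}
```

Both functions produce the same gradients; the fused one is only available because the operation, not the tape, decides how its chain-rule step is computed.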
Implementations
impl GradientTape
Trait Implementations
impl Debug for GradientTape
impl Default for GradientTape
fn default() -> GradientTape
Returns the “default value” for a type.
Auto Trait Implementations
impl !RefUnwindSafe for GradientTape
impl !Send for GradientTape
impl !Sync for GradientTape
impl Unpin for GradientTape
impl !UnwindSafe for GradientTape
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.