pub struct GradientTape { /* private fields */ }

Records gradient computations to execute later.

The only two things you can do with this are:

  1. Adding an operation (an operation is a FnOnce that acts on &mut Gradients)
  2. Executing all the operations to produce Gradients
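Below is a minimal sketch of those two operations. It is not dfdx's actual implementation. It assumes a toy `Gradients` that maps a hypothetical tensor id (`usize`) to a `Vec<f32>` buffer. The helpers `grad_mut` and `get`, the method name `execute`, and the newest-first execution order are all illustrative assumptions.

```rust
use std::collections::HashMap;

/// Toy stand-in for `Gradients` (hypothetical): tensor id -> gradient buffer.
#[derive(Debug, Default)]
pub struct Gradients {
    grads: HashMap<usize, Vec<f32>>,
}

impl Gradients {
    /// Returns the gradient for `id`, creating a zeroed buffer of length `len` if absent.
    pub fn grad_mut(&mut self, id: usize, len: usize) -> &mut Vec<f32> {
        self.grads.entry(id).or_insert_with(|| vec![0.0; len])
    }

    /// Returns the gradient for `id`, if one has been written.
    pub fn get(&self, id: usize) -> Option<&Vec<f32>> {
        self.grads.get(&id)
    }
}

/// Minimal tape sketch: just an ordered list of type-erased operations.
#[derive(Default)]
pub struct GradientTape {
    operations: Vec<Box<dyn FnOnce(&mut Gradients)>>,
}

impl GradientTape {
    /// 1. Add an operation (any `FnOnce` acting on `&mut Gradients`).
    pub fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, operation: F) {
        self.operations.push(Box::new(operation));
    }

    /// 2. Execute every operation on a fresh `Gradients` (name `execute` assumed).
    /// Runs newest-first, as reverse-mode differentiation requires (assumption).
    pub fn execute(self) -> Gradients {
        let mut grads = Gradients::default();
        for operation in self.operations.into_iter().rev() {
            operation(&mut grads);
        }
        grads
    }
}

fn main() {
    let mut tape = GradientTape::default();
    // Backward op for a forward step y = 3 * x, where x has id 1 and y has id 0.
    tape.add_backward_op(|g| {
        let dy = g.get(0).expect("y's gradient is seeded first")[0];
        g.grad_mut(1, 1)[0] += 3.0 * dy;
    });
    // Seed dy/dy = 1. Recorded last, so it runs first.
    tape.add_backward_op(|g| g.grad_mut(0, 1)[0] = 1.0);

    let grads = tape.execute();
    assert_eq!(grads.get(1), Some(&vec![3.0]));
    println!("dy/dx = {:?}", grads.get(1));
}
```

The tape stores `Box<dyn FnOnce(&mut Gradients)>`. Because of this, each operation can be arbitrary code while the tape itself stays a plain list.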

This design forces users to specify their own gradient computations instead of choosing from a fixed set of computation kinds. The reasons for it are:

  1. Different tensor sizes. A fixed set of computations would have to store each tensor's size information inside the operation somehow. Instead, the operations themselves query the gradients with sized tensors, so sizes are known at compile time rather than dynamically.
  2. Slightly different operations. A fixed set would have to support broadcasting operations, etc., which can get needlessly complex.
  3. Optimizations. A fixed set makes optimizations harder. When each operation has control over everything, it can be optimized by hand separately.

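The following sketch illustrates reason 1. It assumes a hypothetical `Tensor1D<const N: usize>` whose length is part of its type. Each backward closure captures its tensor by value, so `N` is a compile-time constant inside the closure. The tape only holds type-erased `Box<dyn FnOnce>` values and carries no shape information. The elementwise-square op and the `HashMap`-based `Gradients` alias are also assumptions made for illustration.

```rust
use std::collections::HashMap;

// Toy stand-in (hypothetical): gradients keyed by tensor id.
type Gradients = HashMap<usize, Vec<f32>>;
// What a tape would store: type-erased closures with no shape information.
type BackwardOp = Box<dyn FnOnce(&mut Gradients)>;

/// Hypothetical statically sized tensor: its length N is part of the type.
struct Tensor1D<const N: usize> {
    id: usize,
    data: [f32; N],
}

/// Backward op for elementwise square, y = t^2, so dy/dt = 2 * t.
/// N is a compile-time constant inside the closure.
fn square_backward<const N: usize>(t: Tensor1D<N>, result_id: usize) -> BackwardOp {
    Box::new(move |grads: &mut Gradients| {
        let result_grad = grads[&result_id].clone();
        assert_eq!(result_grad.len(), N, "result gradient must have length N");
        let t_grad = grads.entry(t.id).or_insert_with(|| vec![0.0; N]);
        for i in 0..N {
            t_grad[i] += 2.0 * t.data[i] * result_grad[i];
        }
    })
}

fn main() {
    let a = Tensor1D { id: 1, data: [1.0, 2.0] }; // N = 2
    let b = Tensor1D { id: 2, data: [3.0, 4.0, 5.0] }; // N = 3
    // Ops over different sizes share one Vec because the size is erased behind dyn FnOnce.
    let ops: Vec<BackwardOp> = vec![square_backward(a, 10), square_backward(b, 20)];

    let mut grads = Gradients::new();
    grads.insert(10, vec![1.0; 2]); // gradient flowing into a^2
    grads.insert(20, vec![1.0; 3]); // gradient flowing into b^2
    for op in ops {
        op(&mut grads);
    }
    assert_eq!(grads[&1], vec![2.0, 4.0]);
    assert_eq!(grads[&2], vec![6.0, 8.0, 10.0]);
}
```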
An example of how these two are used is the following, taken from the negate operation (i.e. multiply all values by -1):

tape.add_backward_op(move |grads| {
    let (t_grad, result_grad) = grads.mut_and_ref(&t, &_result);
    // addmul is equivalent to: t_grad += t.data() * result_grad;
    T::Device::addmul(t_grad, t.data(), result_grad);
});

This implements the chain rule, normally written as gradient(t) += deriv * gradient(result), with the following optimizations:

  1. Instead of allocating new data for the derivative (which is just -1 everywhere), we can reuse the t tensor, since the negate function owns it.
  2. We can combine computing the derivative and multiplying by gradient(result) by simply setting t to -gradient(result).

These optimizations would not be possible if the chain rule operations were built into GradientTape!
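The sketch below shows the negate backward pass on plain `f32` slices. It compares the general chain rule with the two optimizations listed above. The function names are hypothetical, and the mutable `t_data` slice stands in for the owned t tensor whose buffer gets reused.

```rust
/// General chain rule: allocate deriv (-1 everywhere), then t_grad += deriv * result_grad.
fn negate_backward_naive(t_grad: &mut [f32], result_grad: &[f32]) {
    let deriv = vec![-1.0f32; t_grad.len()]; // the extra allocation
    for ((g, d), r) in t_grad.iter_mut().zip(&deriv).zip(result_grad) {
        *g += d * r;
    }
}

/// Optimization 1: overwrite t's owned buffer with the derivative instead of allocating.
fn negate_backward_reuse(t_data: &mut [f32], t_grad: &mut [f32], result_grad: &[f32]) {
    t_data.iter_mut().for_each(|x| *x = -1.0);
    for ((g, d), r) in t_grad.iter_mut().zip(t_data.iter()).zip(result_grad) {
        *g += d * r; // same addmul as in the negate example
    }
}

/// Optimization 2: set t to -gradient(result), fusing the derivative and the multiply.
fn negate_backward_fused(t_data: &mut [f32], t_grad: &mut [f32], result_grad: &[f32]) {
    for (x, r) in t_data.iter_mut().zip(result_grad) {
        *x = -r; // t now holds deriv * gradient(result)
    }
    for (g, x) in t_grad.iter_mut().zip(t_data.iter()) {
        *g += x;
    }
}

fn main() {
    let result_grad = [0.5f32, -2.0, 3.0];

    let mut g1 = vec![0.0f32; 3];
    negate_backward_naive(&mut g1, &result_grad);

    let mut t_data = vec![7.0f32, 8.0, 9.0]; // forward values, no longer needed
    let mut g2 = vec![0.0f32; 3];
    negate_backward_reuse(&mut t_data, &mut g2, &result_grad);

    let mut t_data = vec![7.0f32, 8.0, 9.0];
    let mut g3 = vec![0.0f32; 3];
    negate_backward_fused(&mut t_data, &mut g3, &result_grad);

    // All three agree: gradient(t) = -gradient(result).
    assert_eq!(g1, vec![-0.5f32, 2.0, -3.0]);
    assert_eq!(g1, g2);
    assert_eq!(g1, g3);
}
```

All three variants produce the same gradient. They differ only in how much they allocate and compute.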

Implementations

Compute the Gradients! This just runs all the operations on a new Gradients struct.

Note that this method takes ownership of self, so it can’t be called twice!

Moves all the operations from other into self. Leaves other empty.
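The two methods described above can be sketched as follows. Their signatures are not shown in this section, so the names `execute` and `append` are assumptions. The sketch also assumes a hypothetical `Gradients` that just logs which operations ran, and a newest-first execution order. Backing the tape with a `Vec` makes "leaves other empty" fall out of `Vec::append`. Taking `self` by value is what turns a second call into a compile error.

```rust
/// Toy stand-in for `Gradients` (hypothetical): records which operations ran.
#[derive(Debug, Default)]
pub struct Gradients {
    log: Vec<&'static str>,
}

/// Minimal tape sketch backed by a Vec of boxed operations.
#[derive(Default)]
pub struct GradientTape {
    operations: Vec<Box<dyn FnOnce(&mut Gradients)>>,
}

impl GradientTape {
    pub fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, operation: F) {
        self.operations.push(Box::new(operation));
    }

    /// Hypothetical name. Takes `self` by value: each boxed FnOnce runs once,
    /// and the compiler rejects any later use of the moved tape.
    pub fn execute(self) -> Gradients {
        let mut grads = Gradients::default();
        // Newest-first order is an assumption of this sketch.
        for operation in self.operations.into_iter().rev() {
            operation(&mut grads);
        }
        grads
    }

    /// Hypothetical name. Moves all of `other`'s operations onto the end of
    /// `self`; `Vec::append` leaves `other.operations` empty.
    pub fn append(&mut self, other: &mut Self) {
        self.operations.append(&mut other.operations);
    }

    /// Number of recorded operations (helper for this sketch).
    pub fn num_ops(&self) -> usize {
        self.operations.len()
    }
}

fn main() {
    let mut a = GradientTape::default();
    a.add_backward_op(|g| g.log.push("a0"));
    let mut b = GradientTape::default();
    b.add_backward_op(|g| g.log.push("b0"));
    b.add_backward_op(|g| g.log.push("b1"));

    a.append(&mut b);
    assert_eq!(a.num_ops(), 3);
    assert_eq!(b.num_ops(), 0); // `other` is left empty

    let grads = a.execute(); // `a` is moved into `execute`
    // a.execute(); // would not compile: error[E0382], use of moved value `a`
    assert_eq!(grads.log, vec!["b1", "b0", "a0"]);
}
```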

Trait Implementations

Formats the value using the given formatter.
Returns the “default value” for a type.

Auto Trait Implementations

Blanket Implementations

Gets the TypeId of self.
Immutably borrows from an owned value.
Mutably borrows from an owned value.

Returns the argument unchanged.

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

The type returned in the event of a conversion error.
Performs the conversion.
The type returned in the event of a conversion error.
Performs the conversion.