/// Zeroes any gradients associated with `self`.
///
/// Implemented for tensor collections (the supertrait is `TensorCollection<E, D>`);
/// `E` is the element dtype and `D` the device.
///
/// # Examples
///
/// ```ignore
/// // NOTE(review): `dev` is assumed to be an already-constructed device.
/// let model = dev.build_module::<Linear<2, 5>, f32>();
/// let mut grads: Gradients<f32, _> = model.alloc_grads();
/// model.zero_grads(&mut grads);
/// ```
pub trait ZeroGrads<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
// Provided methods: every method has a default body (shown elided as `{ ... }`).
/// Allocates gradients for this tensor collection.
///
/// This marks all other gradients as temporary, so they are dropped after `.backward()`.
fn alloc_grads(&self) -> Gradients<E, D> { ... }
/// Allocates gradients for this tensor collection, returning the device's
/// error type `D::Err` on failure.
///
/// This marks all other gradients as temporary, so they are dropped after `.backward()`.
/// NOTE(review): presumably the fallible counterpart of `alloc_grads` — confirm.
fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err> { ... }
/// Zeroes any gradients associated with `self`, modifying `gradients` in place.
fn zero_grads(&self, gradients: &mut Gradients<E, D>) { ... }
/// Zeroes any gradients associated with `self`, modifying `gradients` in place and
/// returning the device's error type `D::Err` on failure.
///
/// NOTE(review): presumably the fallible counterpart of `zero_grads` — confirm.
fn try_zero_grads(
&self,
gradients: &mut Gradients<E, D>
) -> Result<(), D::Err> { ... }
}
Expand description
Zeroes any gradients associated with `self`.
let model = dev.build_module::<Linear<2, 5>, f32>();
let mut grads: Gradients<f32, _> = model.alloc_grads();
model.zero_grads(&mut grads);
Provided Methods
sourcefn alloc_grads(&self) -> Gradients<E, D>
fn alloc_grads(&self) -> Gradients<E, D>
Allocates gradients for this tensor collection. This marks all other gradients as temporary, so they are dropped after `.backward()`.
sourcefn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err>
fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err>
Allocates gradients for this tensor collection. This marks all other gradients as temporary, so they are dropped after `.backward()`. Returns a `Result` whose error type is the device's `D::Err`.
sourcefn zero_grads(&self, gradients: &mut Gradients<E, D>)
fn zero_grads(&self, gradients: &mut Gradients<E, D>)
Zeroes any gradients associated with `self`.