Trait dfdx::nn::ZeroGrads

pub trait ZeroGrads<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
    // Provided methods
    fn alloc_grads(&self) -> Gradients<E, D> { ... }
    fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err> { ... }
    fn zero_grads(&self, gradients: &mut Gradients<E, D>) { ... }
    fn try_zero_grads(
        &self,
        gradients: &mut Gradients<E, D>
    ) -> Result<(), D::Err> { ... }
}

Zeros any gradients associated with self.

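use dfdx::prelude::*;
let dev: Cpu = Default::default(); // CPU device that owns the tensors
let model = dev.build_module::<Linear<2, 5>, f32>();
let mut grads: Gradients<f32, _> = model.alloc_grads(); // one gradient buffer per parameter
model.zero_grads(&mut grads); // reset the buffers before the next backward pass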

Provided Methods

fn alloc_grads(&self) -> Gradients<E, D>

Allocates gradients for this tensor collection. This marks all other gradients as temporary, so they are dropped after .backward().

fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err>

Fallible version of alloc_grads. Allocates gradients for this tensor collection and returns any device error as D::Err. This marks all other gradients as temporary, so they are dropped after .backward().

fn zero_grads(&self, gradients: &mut Gradients<E, D>)

Zeros any gradients associated with self.

fn try_zero_grads(&self, gradients: &mut Gradients<E, D>) -> Result<(), D::Err>

Fallible version of zero_grads. Zeros any gradients associated with self and returns any device error as D::Err.
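The four methods above come in pairs: a plain form and a try_ form that returns a Result. The sketch below illustrates that pairing using only the standard library. Every name in it (ToyModel, ToyGrads, ToyDeviceError, capacity) is hypothetical, and the assumption that the plain form simply unwraps the fallible one is an illustration, not a statement about how dfdx implements these methods.

```rust
// Toy stand-ins only; none of these types are dfdx's real definitions.

/// Hypothetical error standing in for a device error such as `D::Err`.
#[derive(Debug, PartialEq)]
struct ToyDeviceError;

/// Hypothetical gradient storage: one f32 slot per parameter.
#[derive(Debug, PartialEq)]
struct ToyGrads(Vec<f32>);

struct ToyModel {
    num_params: usize,
    /// Hypothetical limit used to simulate a device that runs out of memory.
    capacity: usize,
}

impl ToyModel {
    /// Fallible form: reports allocation failure through `Result`.
    fn try_alloc_grads(&self) -> Result<ToyGrads, ToyDeviceError> {
        if self.num_params > self.capacity {
            return Err(ToyDeviceError);
        }
        Ok(ToyGrads(vec![0.0; self.num_params]))
    }

    /// Plain form: assumed here to unwrap the fallible form, so it panics on error.
    fn alloc_grads(&self) -> ToyGrads {
        self.try_alloc_grads().unwrap()
    }

    /// Fallible form: overwrites every gradient slot with zero.
    fn try_zero_grads(&self, grads: &mut ToyGrads) -> Result<(), ToyDeviceError> {
        grads.0.fill(0.0);
        Ok(())
    }

    /// Plain form of zeroing, again assumed to unwrap the fallible form.
    fn zero_grads(&self, grads: &mut ToyGrads) {
        self.try_zero_grads(grads).unwrap()
    }
}

fn main() {
    let model = ToyModel { num_params: 3, capacity: 8 };
    let mut grads = model.alloc_grads();
    grads.0[1] = 2.5; // pretend a backward pass wrote a gradient here
    model.zero_grads(&mut grads);
    assert_eq!(grads, ToyGrads(vec![0.0; 3]));

    // The fallible form lets the caller handle the error instead of panicking.
    let too_big = ToyModel { num_params: 16, capacity: 8 };
    assert_eq!(too_big.try_alloc_grads(), Err(ToyDeviceError));
}
```

In code that must not panic, such as a long-running training service, the try_ forms are usually the safer choice because the caller decides how to recover.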

Implementors

impl<E: Dtype, D: Device<E>, M: TensorCollection<E, D>> ZeroGrads<E, D> for M
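The impl above is a blanket implementation: every type M that implements TensorCollection<E, D> receives ZeroGrads<E, D> automatically, so no model needs to implement it by hand. A minimal sketch of the same pattern follows, using only the standard library. The traits ToyCollection and ToyZeroGrads and the struct ToyLinear are hypothetical stand-ins, not dfdx types, and the method bodies are simplified assumptions.

```rust
/// Hypothetical stand-in for `TensorCollection`: a type that knows its parameter count.
trait ToyCollection {
    fn num_params(&self) -> usize;
}

/// Hypothetical stand-in for `ZeroGrads`, with provided (default) methods only.
trait ToyZeroGrads: ToyCollection {
    /// Allocates one zeroed gradient slot per parameter.
    fn alloc_grads(&self) -> Vec<f32> {
        vec![0.0; self.num_params()]
    }

    /// Resets every gradient slot to zero in place.
    fn zero_grads(&self, grads: &mut [f32]) {
        grads.fill(0.0);
    }
}

// Blanket implementation: any `ToyCollection` gets `ToyZeroGrads` for free.
impl<M: ToyCollection> ToyZeroGrads for M {}

/// Hypothetical layer with a weight matrix and a bias vector.
struct ToyLinear {
    inputs: usize,
    outputs: usize,
}

impl ToyCollection for ToyLinear {
    fn num_params(&self) -> usize {
        // weights (inputs * outputs) plus one bias per output
        self.inputs * self.outputs + self.outputs
    }
}

fn main() {
    let layer = ToyLinear { inputs: 2, outputs: 5 };

    // `alloc_grads` is available although `ToyLinear` never implemented `ToyZeroGrads`.
    let mut grads = layer.alloc_grads();
    assert_eq!(grads.len(), 15);

    grads[0] = 1.0; // pretend a backward pass wrote a gradient here
    layer.zero_grads(&mut grads);
    assert!(grads.iter().all(|g| *g == 0.0));
}
```

Because all methods are provided and the implementation is blanket, a new model type only has to satisfy the collection trait to gain gradient allocation and zeroing.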