1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
use tch::nn::VarStore; pub trait ZeroGrad { /// Zero out gradients. fn zero_grad(&self); } impl ZeroGrad for VarStore { fn zero_grad(&self) { for (_, mut tensor) in self.variables() { if tensor.requires_grad() { tensor.zero_grad() } } } }