1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
use tch::Tensor; pub trait ZeroGrad { /// Zero out gradients. fn zero_grad(&mut self); } impl ZeroGrad for Vec<Tensor> { fn zero_grad(&mut self) { for tensor in self { if tensor.requires_grad() { tensor.zero_grad() } } } }