1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use super::tensor_collection::*;
use crate::{shapes::*, tensor::*, tensor_ops::Device};
/// Tensor visitor that re-initializes each tensor it visits by invoking the
/// `reset` callback carried in that tensor's `TensorOptions`.
struct Resetter;

impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for Resetter {
    // Tensors are reset in place, so the walker must hand out mutable views.
    type Viewer = ViewTensorMut;
    type Err = D::Err;
    // Resetting never changes the element type or the device.
    type E2 = E;
    type D2 = D;

    /// Applies `options.reset` to `tensor` in place.
    ///
    /// On success this always returns `Ok(None)`; the visitor does not build
    /// a replacement tensor. Any error from the reset callback is propagated.
    fn visit<S: Shape>(
        &mut self,
        options: TensorOptions<S, E, D>,
        tensor: &mut Tensor<S, E, D>,
    ) -> Result<Option<Tensor<S, E, D>>, Self::Err> {
        (options.reset)(tensor)?;
        Ok(None)
    }
}
/// Resets every tensor contained in a [`TensorCollection`].
///
/// The collection is walked recursively via `TensorCollection::iter_tensors`,
/// and each visited tensor is re-initialized in place by its own `reset`
/// callback.
pub trait ResetParams<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
    /// Infallible convenience wrapper around [`ResetParams::try_reset_params`].
    ///
    /// # Panics
    /// Panics if [`ResetParams::try_reset_params`] returns an error.
    fn reset_params(&mut self) {
        self.try_reset_params().unwrap();
    }

    /// Resets every tensor reachable through the collection, in place.
    ///
    /// # Errors
    /// Propagates any error raised by a tensor's `reset` callback during the
    /// walk.
    fn try_reset_params(&mut self) -> Result<(), D::Err> {
        let mut resetter = Resetter;
        let mut walker = RecursiveWalker {
            m: self,
            f: &mut resetter,
        };
        Self::iter_tensors(&mut walker)?;
        Ok(())
    }
}
/// Blanket implementation: every `TensorCollection` gets `ResetParams`
/// automatically, with all behavior coming from the trait's default methods.
impl<E, D, M> ResetParams<E, D> for M
where
    E: Dtype,
    D: Device<E>,
    M: TensorCollection<E, D>,
{
}