1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use crate::Tensor;
use std::borrow::Borrow;
/// Options used to build a [`BatchNorm`] layer and to drive its forward pass.
///
/// The value is `Copy`, so it can be handed to several constructors without
/// cloning.
#[derive(Clone, Copy, Debug)]
pub struct BatchNormConfig {
    /// Flag forwarded unchanged as the final argument of `Tensor::batch_norm`.
    pub cudnn_enabled: bool,
    /// Epsilon forwarded to `Tensor::batch_norm`.
    /// NOTE(review): presumably the numerical-stability term added to the
    /// variance; confirm against the backend documentation.
    pub eps: f64,
    /// Momentum forwarded to `Tensor::batch_norm`.
    /// NOTE(review): presumably governs how the running statistics are
    /// updated; confirm against the backend documentation.
    pub momentum: f64,
    /// Initialisation used for the variable registered under the name `weight`.
    pub ws_init: super::Init,
    /// Initialisation used for the variable registered under the name `bias`.
    pub bs_init: super::Init,
}
impl Default for BatchNormConfig {
    /// Returns the default configuration:
    ///
    /// * `cudnn_enabled`: `true`
    /// * `eps`: `1e-5`
    /// * `momentum`: `0.1`
    /// * `ws_init`: uniform between `0` and `1`
    /// * `bs_init`: constant `0`
    fn default() -> Self {
        let ws_init = super::Init::Uniform { lo: 0.0, up: 1.0 };
        let bs_init = super::Init::Const(0.0);
        Self {
            cudnn_enabled: true,
            eps: 1e-5,
            momentum: 0.1,
            ws_init,
            bs_init,
        }
    }
}
/// A batch-normalisation layer.
///
/// Instances are created with `batch_norm1d`, `batch_norm2d` or
/// `batch_norm3d`; the layer is applied through its `ModuleT` implementation.
#[derive(Debug)]
pub struct BatchNorm {
    /// Configuration captured at construction time; its `momentum`, `eps` and
    /// `cudnn_enabled` values are read on every forward pass.
    config: BatchNormConfig,
    /// Tensor registered as `running_mean` through `zeros_no_train`, with
    /// shape `[out_dim]`.
    pub running_mean: Tensor,
    /// Tensor registered as `running_var` through `ones_no_train`, with
    /// shape `[out_dim]`.
    pub running_var: Tensor,
    /// Variable registered as `weight`, with shape `[out_dim]`.
    pub ws: Tensor,
    /// Variable registered as `bias`, with shape `[out_dim]`.
    pub bs: Tensor,
    /// Dimensionality tag (1, 2 or 3 for the public constructors) used to
    /// validate the number of dimensions of the input in the forward pass.
    pub nd: usize,
}
/// Shared constructor behind the `batch_norm{1,2,3}d` helpers.
///
/// Registers four tensors of shape `[out_dim]` under the given path, in this
/// order: `running_mean`, `running_var`, `weight` and `bias`. The last two
/// are initialised from `config.ws_init` and `config.bs_init` respectively.
/// `nd` is stored as-is and only consulted when the layer is applied.
fn batch_norm<'a, T>(vs: T, nd: usize, out_dim: i64, config: BatchNormConfig) -> BatchNorm
where
    T: Borrow<super::Path<'a>>,
{
    let path = vs.borrow();
    let shape = [out_dim];

    // Keep the registration order stable: statistics first, then parameters.
    let running_mean = path.zeros_no_train("running_mean", &shape);
    let running_var = path.ones_no_train("running_var", &shape);
    let ws = path.var("weight", &shape, config.ws_init);
    let bs = path.var("bias", &shape, config.bs_init);

    BatchNorm {
        config,
        running_mean,
        running_var,
        ws,
        bs,
        nd,
    }
}
/// Builds a [`BatchNorm`] layer tagged with `nd = 1`.
///
/// The resulting layer accepts inputs with 2 or 3 dimensions and panics on
/// any other rank when applied. All per-channel tensors have shape
/// `[out_dim]`.
pub fn batch_norm1d<'a, T>(vs: T, out_dim: i64, config: BatchNormConfig) -> BatchNorm
where
    T: Borrow<super::Path<'a>>,
{
    const ND: usize = 1;
    batch_norm(vs, ND, out_dim, config)
}
/// Builds a [`BatchNorm`] layer tagged with `nd = 2`.
///
/// The resulting layer accepts inputs with exactly 4 dimensions (`nd + 2`)
/// and panics on any other rank when applied. All per-channel tensors have
/// shape `[out_dim]`.
pub fn batch_norm2d<'a, T>(vs: T, out_dim: i64, config: BatchNormConfig) -> BatchNorm
where
    T: Borrow<super::Path<'a>>,
{
    const ND: usize = 2;
    batch_norm(vs, ND, out_dim, config)
}
/// Builds a [`BatchNorm`] layer tagged with `nd = 3`.
///
/// The resulting layer accepts inputs with exactly 5 dimensions (`nd + 2`)
/// and panics on any other rank when applied. All per-channel tensors have
/// shape `[out_dim]`.
pub fn batch_norm3d<'a, T>(vs: T, out_dim: i64, config: BatchNormConfig) -> BatchNorm
where
    T: Borrow<super::Path<'a>>,
{
    const ND: usize = 3;
    batch_norm(vs, ND, out_dim, config)
}
impl super::module::ModuleT for BatchNorm {
    /// Applies batch normalisation to `xs`.
    ///
    /// The layer's weight, bias, running mean and running variance are passed
    /// to `Tensor::batch_norm` together with the `train` flag and the
    /// `momentum`, `eps` and `cudnn_enabled` values from the stored
    /// configuration.
    ///
    /// # Panics
    ///
    /// * If `nd == 1` and `xs` has neither 2 nor 3 dimensions.
    /// * If `nd > 1` and `xs` does not have exactly `nd + 2` dimensions.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        let rank = xs.dim();
        match self.nd {
            // No rank constraint is enforced when `nd` is zero.
            0 => {}
            1 => assert!(
                rank == 2 || rank == 3,
                "expected an input tensor with 2 or 3 dims, got {:?}",
                xs.size()
            ),
            nd => {
                let expected = nd + 2;
                assert!(
                    rank == expected,
                    "expected an input tensor with {} dims, got {:?}",
                    expected,
                    xs.size()
                )
            }
        }

        let cfg = &self.config;
        Tensor::batch_norm(
            xs,
            Some(&self.ws),
            Some(&self.bs),
            Some(&self.running_mean),
            Some(&self.running_var),
            train,
            cfg.momentum,
            cfg.eps,
            cfg.cudnn_enabled,
        )
    }
}