use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;
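
/// Specifies how a tensor parameter should be initialized.
///
/// A minimal usage sketch (illustrative only; `MyBackend` and `device` are
/// placeholders standing in for a concrete backend and device, not items
/// defined in this file):
///
/// ```ignore
/// // A [d_output, d_input] weight: fan_in = d_input, fan_out = d_output.
/// let init = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
/// let weight: Param<Tensor<MyBackend, 2>> =
///     init.init_with([64, 32], Some(32), Some(64), &device);
/// ```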
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills the tensor with the specified value everywhere.
    Constant {
        /// The value to fill the tensor with.
        value: f64,
    },
    /// Fills the tensor with ones.
    Ones,
    /// Fills the tensor with zeros.
    Zeros,
    /// Fills the tensor with values drawn uniformly between `min` and `max`.
    Uniform {
        /// The minimum value to draw from.
        min: f64,
        /// The maximum value to draw from.
        max: f64,
    },
    /// Fills the tensor with values drawn from a normal distribution with the
    /// given mean and standard deviation.
    Normal {
        /// The mean of the normal distribution.
        mean: f64,
        /// The standard deviation of the normal distribution.
        std: f64,
    },
    /// Fills the tensor with values from the uniform variant of Kaiming
    /// initialization; requires a fan value (see `init_with`).
    KaimingUniform {
        /// The gain to use in the initialization formula.
        gain: f64,
        /// Whether to scale by fan out instead of fan in.
        fan_out_only: bool,
    },
    /// Fills the tensor with values from the normal variant of Kaiming
    /// initialization; requires a fan value (see `init_with`).
    KaimingNormal {
        /// The gain to use in the initialization formula.
        gain: f64,
        /// Whether to scale by fan out instead of fan in.
        fan_out_only: bool,
    },
    /// Fills the tensor with values from the uniform variant of Xavier
    /// (Glorot) initialization; requires both fan values (see `init_with`).
    XavierUniform {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor with values from the normal variant of Xavier
    /// (Glorot) initialization; requires both fan values (see `init_with`).
    XavierNormal {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
}

impl Initializer {
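    /// Initializes a tensor parameter of the given shape on the given device,
    /// deferring the actual draw until the parameter is first used.
    ///
    /// The fan-based variants (Kaiming, Xavier) need fan values and panic when
    /// initialized through this method; use [`Self::init_with`] for those.
    ///
    /// Usage sketch (illustrative; `MyBackend` and `device` are placeholders):
    ///
    /// ```ignore
    /// let bias: Param<Tensor<MyBackend, 1>> = Initializer::Zeros.init([64], &device);
    /// ```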
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

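    /// Initializes a tensor parameter of the given shape, using the provided
    /// `fan_in`/`fan_out` values where the initializer formula requires them.
    ///
    /// Usage sketch (illustrative; `MyBackend` is a placeholder backend):
    ///
    /// ```ignore
    /// // A [fan_out, fan_in] = [6, 5] weight matrix.
    /// let weight: Param<Tensor<MyBackend, 2>> = Initializer::XavierUniform { gain: 1.0 }
    ///     .init_with([6, 5], Some(5), Some(6), &Default::default());
    /// ```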
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

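    // Note on the `3.0f64.sqrt()` factor in the uniform branches below: a draw
    // from U(-a, a) has variance a^2 / 3, so choosing a = sqrt(3) * std gives a
    // uniform draw the same variance as a normal draw N(0, std^2).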
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
        }
    }

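    /// Returns the Kaiming scale factor `1 / sqrt(fan)`; callers multiply it by
    /// `gain` (and by `sqrt(3)` for the uniform bound).
    ///
    /// Worked example (illustrative numbers): with `fan_in = 400` the factor is
    /// `1 / 20 = 0.05`, so `KaimingNormal` with `gain = 2.0` draws from
    /// `N(0, 0.1^2)`.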
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

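    /// Returns the Xavier (Glorot) scale factor `sqrt(2 / (fan_in + fan_out))`.
    ///
    /// Worked example (illustrative numbers): with `fan_in = 30` and
    /// `fan_out = 70` the factor is `sqrt(2 / 100) ≈ 0.141`, so `XavierNormal`
    /// with `gain = 1.0` draws from `N(0, 0.02)`.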
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

/// Draws a tensor of the given shape from `Uniform(low, high)` on the device.
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

/// Draws a tensor of the given shape from `Normal(mean, std)` on the device.
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use burn_tensor::{Tolerance, ops::FloatElem};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor.into_data().assert_within_range(min..max);
    }

    #[test]
    fn initializer_normal_init() {
        TB::seed(0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_out as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = gain * (6. / (fan_in + fan_out) as f64).sqrt();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }
}