use crate::tensor::Shape;
use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};
use crate as burn;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Strategy used to fill a freshly created parameter tensor.
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills every element with `value`.
    Constant {
        /// Value written to every element.
        value: f64,
    },
    /// Fills every element with one.
    Ones,
    /// Fills every element with zero.
    Zeros,
    /// Samples each element from `Distribution::Uniform(min, max)`.
    Uniform {
        /// Lower bound of the distribution.
        min: f64,
        /// Upper bound of the distribution.
        max: f64,
    },
    /// Samples each element from `Distribution::Normal(mean, std)`.
    Normal {
        /// Mean of the distribution.
        mean: f64,
        /// Standard deviation of the distribution.
        std: f64,
    },
    /// Kaiming (He) uniform: samples from `Uniform(-a, a)` with `a = gain * sqrt(3 / fan)`.
    ///
    /// `fan` is `fan_in`, or `fan_out` when `fan_out_only` is set. It must be supplied through
    /// `Initializer::init_with`; otherwise initialization panics.
    KaimingUniform {
        /// Multiplier applied to the sampling bound.
        gain: f64,
        /// Use `fan_out` instead of `fan_in` when computing the bound.
        fan_out_only: bool,
    },
    /// Kaiming (He) normal: samples from `Normal(0, std)` with `std = gain / sqrt(fan)`.
    ///
    /// `fan` is `fan_in`, or `fan_out` when `fan_out_only` is set. It must be supplied through
    /// `Initializer::init_with`; otherwise initialization panics.
    KaimingNormal {
        /// Multiplier applied to the standard deviation.
        gain: f64,
        /// Use `fan_out` instead of `fan_in` when computing the standard deviation.
        fan_out_only: bool,
    },
    /// Xavier (Glorot) uniform: samples from `Uniform(-a, a)` with
    /// `a = gain * sqrt(3) * sqrt(2 / (fan_in + fan_out))`.
    ///
    /// Both `fan_in` and `fan_out` must be supplied through `Initializer::init_with`;
    /// otherwise initialization panics.
    XavierUniform {
        /// Multiplier applied to the sampling bound.
        gain: f64,
    },
    /// Xavier (Glorot) normal: samples from `Normal(0, std)` with
    /// `std = gain * sqrt(2 / (fan_in + fan_out))`.
    ///
    /// Both `fan_in` and `fan_out` must be supplied through `Initializer::init_with`;
    /// otherwise initialization panics.
    XavierNormal {
        /// Multiplier applied to the standard deviation.
        gain: f64,
    },
    /// Orthogonal initialization.
    ///
    /// The tensor is viewed as a `rows x (numel / rows)` matrix, where `rows` is its first
    /// dimension. A standard-normal matrix of that shape is orthogonalized through a QR
    /// decomposition, reshaped back and scaled by `gain`. Panics for tensors with fewer than
    /// two dimensions.
    Orthogonal {
        /// Multiplier applied to the orthogonal matrix.
        gain: f64,
    },
}
impl Initializer {
    /// Creates a parameter of the given `shape` on `device`, initialized with this strategy.
    ///
    /// This is [`Initializer::init_with`] with neither `fan_in` nor `fan_out`. The Kaiming and
    /// Xavier variants therefore cannot be used through this method: materializing the
    /// parameter panics because their fan information is missing.
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Creates a parameter of the given `shape` on `device`, using `fan_in` / `fan_out` for the
    /// strategies that need them (Kaiming and Xavier).
    ///
    /// The tensor is produced by a closure handed to `Param::uninitialized`. Any panic raised by
    /// the strategy (missing fan, rank < 2 for `Orthogonal`) may therefore only surface once the
    /// parameter value is materialized.
    /// NOTE(review): exact laziness semantics depend on `Param`; confirm there.
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();
        // `shape` is moved into the init closure; this copy is handed to `Param::uninitialized`
        // directly as the parameter's shape.
        let param_shape = shape.clone();

        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                B::memory_persistent_allocations(device, (), move |_| {
                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);
                    if require_grad {
                        tensor = tensor.require_grad();
                    }
                    tensor
                })
            },
            device,
            true,
            param_shape,
        )
    }

    /// Eagerly builds the initialized tensor for this strategy.
    ///
    /// # Panics
    ///
    /// - Kaiming variants: when the required fan (`fan_in`, or `fan_out` if `fan_out_only`) is
    ///   `None`.
    /// - Xavier variants: when either `fan_in` or `fan_out` is `None`.
    /// - `Orthogonal`: when `D < 2`.
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape: Shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                // Uniform(-a, a) has std a / sqrt(3), so scale the target std by sqrt(3).
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater or equal 2; (D >= 2)"
                );
                // A zero-element tensor has nothing to orthogonalize. Without this guard a
                // leading dimension of 0 would divide by zero when computing `cols` below.
                if shape.num_elements() == 0 {
                    return Tensor::<B, D>::zeros(shape, device);
                }

                // Flatten to a 2-D `rows x cols` matrix keyed on the first dimension.
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;
                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                // `qr_decomposition` works column by column and needs at least as many rows as
                // columns, so wide matrices are decomposed through their transpose.
                let wide = rows < cols;
                if wide {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.dims();
                // diag(R) as a `1 x r_cols` row: ones(1, r_rows) @ (I * R) keeps only the
                // diagonal, then sums each column down to that single entry.
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r));
                // Flip each column of Q by the sign of the matching diagonal entry of R.
                // NOTE(review): this mirrors the sign correction in PyTorch's `orthogonal_`.
                let ph = diag_r.sign();
                let mut q = q.mul(ph);

                if wide {
                    q = q.transpose();
                }
                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

    /// Returns `1 / sqrt(fan)` for Kaiming initialization.
    ///
    /// `fan` is `fan_out` when `fan_out_only` is set, otherwise `fan_in`.
    ///
    /// # Panics
    ///
    /// When the selected fan is `None`.
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

    /// Returns `sqrt(2 / (fan_in + fan_out))` for Xavier initialization.
    ///
    /// # Panics
    ///
    /// When either `fan_in` or `fan_out` is `None`.
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        // Sum in f64 so huge fans cannot overflow `usize`.
        (2.0 / (fan_in as f64 + fan_out as f64)).sqrt()
    }
}
/// Samples a tensor of `shape` on `device` from `Distribution::Uniform(low, high)`.
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    Tensor::<B, D>::random(shape, Distribution::Uniform(low, high), device)
}
/// Samples a tensor of `shape` on `device` from `Distribution::Normal(mean, std)`.
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    Tensor::<B, D>::random(shape, Distribution::Normal(mean, std), device)
}
/// Computes a reduced QR decomposition of `a` using Gram-Schmidt orthogonalization.
///
/// For an `m x n` input, returns `(q, r)`:
/// - `q` is `m x n` with orthonormal columns.
/// - `r` is `n x n` and upper triangular.
///
/// Together they satisfy `q.matmul(r) ≈ a`.
///
/// Each projection is subtracted from the running residual `v` rather than from the original
/// column. This is the modified Gram-Schmidt variant.
///
/// NOTE(review): assumes `m >= n` and linearly independent columns. A zero residual norm would
/// make the division by `r_jj` produce non-finite values.
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let [m, n] = a.dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

    for j in 0..n {
        // Residual of column j after removing its components along q_0..q_{j-1}.
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);
        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
            // r_ij = <q_i, v>, kept as a 1-element tensor so it broadcasts against q_i below.
            let r_ij = q_i.clone().mul(v.clone()).sum();
            r = r.slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());
            v = v - q_i.mul(r_ij);
        }

        // r_jj = ||v||_2. Square via `mul` rather than `powf` with an exponent tensor.
        let r_jj = v.clone().mul(v.clone()).sum().sqrt();
        r = r.slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        let q_j = v / r_jj;
        q = q.slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}
#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    /// Asserts that every row of `tensor` has a sample mean and variance within 0.1 of
    /// `expected_mean` / `expected_var`.
    ///
    /// Statistics are reduced over dim 1, so each of the `dims[0]` rows yields one
    /// `(variance, mean)` pair estimated from `dims[1]` samples.
    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let rows = tensor.dims()[0];
        let (actual_vars, actual_means) = tensor.clone().var_mean(1);

        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();
        assert_eq!(actual_vars.len(), rows);
        assert_eq!(actual_means.len(), rows);

        for (&actual_var, &actual_mean) in actual_vars.iter().zip(actual_means) {
            let actual_var = actual_var as f64;
            let actual_mean = actual_mean as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} +/- 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} +/- 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &device).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &device)
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 +/- 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 +/- 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &device)
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &device)
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &device,
        );
        let qr = qr_decomposition(a.clone(), &device);

        // Q @ R should reconstruct input `a`
        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        // assert that the difference between input (`a`) and Q @ R is (almost) zero
        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &device)
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &device);

        // Q.T @ Q should be close to identity matrix
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );

        // test 3D tensor
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 1D tensor
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
    }
}