candle_nn/init.rs

//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::{DType, Device, Result, Shape, Tensor, Var};

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves the magnitude of the
/// variance of the weights in the forward pass, while choosing `FanOut`
/// preserves it in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
    /// Compute the fan-in or fan-out value for a weight tensor of
    /// the specified dimensions.
    /// <https://github.com/pytorch/pytorch/blob/dbeacf11820e336e803bb719b7aaaf2125ae4d9c/torch/nn/init.py#L284>
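    ///
    /// A small usage sketch (the `(out_c, in_c, k_h, k_w)` convolution-weight
    /// shape below, and the `candle_nn::init` import path, are illustrative
    /// assumptions):
    ///
    /// ```
    /// use candle::Shape;
    /// use candle_nn::init::FanInOut;
    /// // Conv-style weight: the fan ignores the first two dims' roles except
    /// // for picking in/out channels, and multiplies by the kernel size.
    /// let shape = Shape::from((64, 32, 3, 3));
    /// assert_eq!(FanInOut::FanIn.for_shape(&shape), 32 * 3 * 3);
    /// assert_eq!(FanInOut::FanOut.for_shape(&shape), 64 * 3 * 3);
    /// ```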
    pub fn for_shape(&self, shape: &Shape) -> usize {
        let dims = shape.dims();
        let receptive_field_size: usize = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}

#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linear function that follows this layer. ReLU is the
/// recommended value.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
    // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67
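    /// Gain associated with this non-linearity, following the PyTorch
    /// defaults linked above. A small sketch (the `candle_nn::init` import
    /// path is an assumption):
    ///
    /// ```
    /// use candle_nn::init::NonLinearity;
    /// assert_eq!(NonLinearity::ReLU.gain(), 2f64.sqrt());
    /// assert_eq!(NonLinearity::Tanh.gain(), 5. / 3.);
    /// assert_eq!(NonLinearity::ExplicitGain(0.5).gain(), 0.5);
    /// ```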
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}

/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),

    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },

    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },

    /// Kaiming initialization.
    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification"
    /// He, K. et al. (2015). The distribution, normal or uniform, is selected via `dist`.
    Kaiming {
        dist: NormalOrUniform,
        fan: FanInOut,
        non_linearity: NonLinearity,
    },
}

pub const ZERO: Init = Init::Const(0.);
pub const ONE: Init = Init::Const(1.);

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

impl Init {
    /// Creates a new variable with the specified shape, dtype, and device,
    /// filled according to this initialization.
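    ///
    /// A minimal usage sketch (the `(256, 128)` shape, `F32` dtype and CPU
    /// device are illustrative assumptions). With fan-in 128 and the ReLU
    /// gain of sqrt(2), `DEFAULT_KAIMING_UNIFORM` samples from U(-b, b)
    /// with b = sqrt(3) * sqrt(2) / sqrt(128), about 0.2165:
    ///
    /// ```
    /// use candle::{DType, Device};
    /// use candle_nn::init::DEFAULT_KAIMING_UNIFORM;
    /// let w = DEFAULT_KAIMING_UNIFORM.var((256, 128), DType::F32, &Device::Cpu)?;
    /// assert_eq!(w.dims(), &[256, 128]);
    /// # Ok::<(), candle::Error>(())
    /// ```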
    pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
        match self {
            Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
            Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
            Self::Const(cst) => {
                Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
            }
            Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
            Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
            Self::Kaiming {
                dist,
                fan,
                non_linearity,
            } => {
                let s = s.into();
                let fan = fan.for_shape(&s);
                let gain = non_linearity.gain();
                let std = gain / (fan as f64).sqrt();
                match dist {
                    NormalOrUniform::Uniform => {
                        let bound = 3f64.sqrt() * std;
                        Var::rand_f64(-bound, bound, s, dtype, device)
                    }
                    NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
                }
            }
        }
    }
}

impl Default for Init {
    fn default() -> Self {
        Self::Const(0.)
    }
}