1use candle::{DType, Device, Result, Shape, Tensor, Var};
5
/// Selects which connection count ("fan") scales the Kaiming standard deviation.
///
/// The fan is computed from a tensor shape by [`FanInOut::for_shape`]. Equality and hashing are
/// derived so that initialization configurations can be compared or used as map keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FanInOut {
    /// Use the number of input connections: `dims[1]` times the receptive field size.
    FanIn,
    /// Use the number of output connections: `dims[0]` times the receptive field size.
    FanOut,
}
16
17impl FanInOut {
18 pub fn for_shape(&self, shape: &Shape) -> usize {
22 let dims = shape.dims();
23 let receptive_field_size: usize = dims.iter().skip(2).product();
24 match &self {
25 FanInOut::FanIn => {
26 if dims.len() < 2 {
27 1
28 } else {
29 dims[1] * receptive_field_size
30 }
31 }
32 FanInOut::FanOut => {
33 if dims.is_empty() {
34 1
35 } else {
36 dims[0] * receptive_field_size
37 }
38 }
39 }
40 }
41}
42
/// Distribution family that Kaiming initialization samples from.
///
/// Equality and hashing are derived so that initialization configurations can be compared or
/// used as map keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NormalOrUniform {
    /// Sample from a normal distribution centred on 0.
    Normal,
    /// Sample uniformly from a range symmetric around 0.
    Uniform,
}
48
/// Activation that follows the layer being initialized; it determines the Kaiming gain.
///
/// `PartialEq` is derived so configurations can be compared. `Eq`/`Hash` are not derived
/// because `ExplicitGain` carries an `f64`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum NonLinearity {
    /// Gain `sqrt(2)`.
    ReLU,
    /// Gain `1`.
    Linear,
    /// Gain `1`.
    Sigmoid,
    /// Gain `5/3`.
    Tanh,
    /// Gain `0.75`.
    SELU,
    /// Uses the wrapped value as the gain, unchanged.
    ExplicitGain(f64),
}

impl NonLinearity {
    /// Returns the gain that multiplies the Kaiming standard deviation for this activation.
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}
73
74#[derive(Debug, Copy, Clone)]
76pub enum Init {
77 Const(f64),
79
80 Randn { mean: f64, stdev: f64 },
82
83 Uniform { lo: f64, up: f64 },
85
86 Kaiming {
90 dist: NormalOrUniform,
91 fan: FanInOut,
92 non_linearity: NonLinearity,
93 },
94}
95
96pub const ZERO: Init = Init::Const(0.);
97pub const ONE: Init = Init::Const(1.);
98
99pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
100 dist: NormalOrUniform::Uniform,
101 fan: FanInOut::FanIn,
102 non_linearity: NonLinearity::ReLU,
103};
104
105pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
106 dist: NormalOrUniform::Normal,
107 fan: FanInOut::FanIn,
108 non_linearity: NonLinearity::ReLU,
109};
110
111impl Init {
112 pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
114 match self {
115 Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
116 Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
117 Self::Const(cst) => {
118 Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
119 }
120 Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
121 Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
122 Self::Kaiming {
123 dist,
124 fan,
125 non_linearity,
126 } => {
127 let s = s.into();
128 let fan = fan.for_shape(&s);
129 let gain = non_linearity.gain();
130 let std = gain / (fan as f64).sqrt();
131 match dist {
132 NormalOrUniform::Uniform => {
133 let bound = 3f64.sqrt() * std;
134 Var::rand_f64(-bound, bound, s, dtype, device)
135 }
136 NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
137 }
138 }
139 }
140 }
141}
142
143impl Default for Init {
144 fn default() -> Self {
145 Self::Const(0.)
146 }
147}