//! rnn/initializers/initializers.rs — parameter initializers for dense layers.
/// Strategy used to fill dense-layer weights and biases.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum InitKind {
    /// All weights and all biases set to 0.0.
    Zeros,
    /// Weights drawn uniformly from ±sqrt(6 / (fan_in + fan_out)); biases zeroed.
    XavierUniform,
    /// Weights drawn uniformly from ±sqrt(6 / fan_in); biases zeroed.
    HeUniform,
    /// Every weight set to this value (must be finite); biases are still zeroed.
    Constant(f32),
}
8
/// Errors returned by `initialize_dense_parameters`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InitError {
    /// `layers` has fewer than two entries, contains a zero size, or the
    /// total parameter count overflows `usize`.
    InvalidShape,
    /// `weights` / `biases` slice lengths do not match what `layers` requires.
    ShapeMismatch,
    /// A non-finite constant was supplied, or a computed limit was not finite.
    NonFinite,
}
15
/// Returns the total `(weight_count, bias_count)` a dense stack described by
/// `layers` requires, where each consecutive pair `(fan_in, fan_out)` needs
/// `fan_in * fan_out` weights and `fan_out` biases.
///
/// Returns `None` when `layers` has fewer than two entries, when any layer
/// size is zero, or when either total would overflow `usize`.
pub fn expected_parameter_counts(layers: &[usize]) -> Option<(usize, usize)> {
    // windows(2) yields nothing for len < 2, which would silently produce
    // (0, 0); reject short inputs explicitly instead.
    if layers.len() < 2 {
        return None;
    }

    layers.windows(2).try_fold((0usize, 0usize), |(w_total, b_total), pair| {
        let (fan_in, fan_out) = (pair[0], pair[1]);
        if fan_in == 0 || fan_out == 0 {
            return None;
        }
        // Checked arithmetic throughout: overflow aborts the fold with None.
        let w_total = w_total.checked_add(fan_in.checked_mul(fan_out)?)?;
        let b_total = b_total.checked_add(fan_out)?;
        Some((w_total, b_total))
    })
}
34
35pub fn initialize_dense_parameters(
36 layers: &[usize],
37 weights: &mut [f32],
38 biases: &mut [f32],
39 init_kind: InitKind,
40 seed: u64,
41) -> Result<(), InitError> {
42 let (expected_w, expected_b) = expected_parameter_counts(layers).ok_or(InitError::InvalidShape)?;
43 if weights.len() != expected_w || biases.len() != expected_b {
44 return Err(InitError::ShapeMismatch);
45 }
46
47 let mut rng = SplitMix64::new(seed);
48 let mut w_off = 0usize;
49 let mut b_off = 0usize;
50
51 for i in 0..layers.len() - 1 {
52 let in_size = layers[i];
53 let out_size = layers[i + 1];
54 let w_len = in_size * out_size;
55 let (w_slice, b_slice) = (
56 &mut weights[w_off..w_off + w_len],
57 &mut biases[b_off..b_off + out_size],
58 );
59
60 match init_kind {
61 InitKind::Zeros => {
62 for w in w_slice {
63 *w = 0.0;
64 }
65 for b in b_slice {
66 *b = 0.0;
67 }
68 }
69 InitKind::Constant(value) => {
70 if !value.is_finite() {
71 return Err(InitError::NonFinite);
72 }
73 for w in w_slice {
74 *w = value;
75 }
76 for b in b_slice {
77 *b = 0.0;
78 }
79 }
80 InitKind::XavierUniform => {
81 let denom = (in_size + out_size) as f32;
82 if denom <= 0.0 || !denom.is_finite() {
83 return Err(InitError::NonFinite);
84 }
85 let limit = crate::math::sqrtf(6.0 / denom);
86 if !limit.is_finite() {
87 return Err(InitError::NonFinite);
88 }
89 for w in w_slice {
90 *w = random_uniform_symmetric(&mut rng, limit);
91 }
92 for b in b_slice {
93 *b = 0.0;
94 }
95 }
96 InitKind::HeUniform => {
97 let denom = in_size as f32;
98 if denom <= 0.0 || !denom.is_finite() {
99 return Err(InitError::NonFinite);
100 }
101 let limit = crate::math::sqrtf(6.0 / denom);
102 if !limit.is_finite() {
103 return Err(InitError::NonFinite);
104 }
105 for w in w_slice {
106 *w = random_uniform_symmetric(&mut rng, limit);
107 }
108 for b in b_slice {
109 *b = 0.0;
110 }
111 }
112 }
113
114 w_off += w_len;
115 b_off += out_size;
116 }
117
118 Ok(())
119}
120
121fn random_uniform_symmetric(rng: &mut SplitMix64, limit: f32) -> f32 {
122 let u = rng.next_f32_01();
123 (u * 2.0 - 1.0) * limit
124}
125
/// Small deterministic PRNG (SplitMix64). Used only for reproducible
/// parameter initialization; not suitable for cryptographic purposes.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Creates a generator from `seed`; every seed (including 0) is valid.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state by the golden-ratio increment, then finalizes it
    /// with two xor-shift-multiply rounds (the standard SplitMix64 output
    /// function).
    fn next_u64(&mut self) -> u64 {
        const GOLDEN_GAMMA: u64 = 0x9E3779B97F4A7C15;
        self.state = self.state.wrapping_add(GOLDEN_GAMMA);
        let mut x = self.state;
        x = (x ^ (x >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        x = (x ^ (x >> 27)).wrapping_mul(0x94D049BB133111EB);
        x ^ (x >> 31)
    }

    /// Uniform `f32` in `[0, 1)` built from the top 24 bits of the next
    /// 64-bit draw (24 bits = f32 mantissa precision, so every value is
    /// exactly representable).
    fn next_f32_01(&mut self) -> f32 {
        let top24 = (self.next_u64() >> 40) as u32;
        // Multiplying by 2^-24 (exactly representable) equals dividing by 2^24.
        top24 as f32 * (1.0 / 16777216.0)
    }
}
147}