Skip to main content

rnn/initializers/
initializers.rs

/// Strategy used to fill the weight buffer of a dense network.
///
/// The strategy only governs the weights: `initialize_dense_parameters` in this
/// file zeroes the biases for every variant.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitKind {
    /// Every weight is set to `0.0`.
    Zeros,
    /// Weights are sampled uniformly with bound `sqrt(6 / (fan_in + fan_out))`.
    XavierUniform,
    /// Weights are sampled uniformly with bound `sqrt(6 / fan_in)`.
    HeUniform,
    /// Every weight is set to the carried value, which must be finite.
    Constant(f32),
}
8
/// Reasons parameter initialization can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitError {
    /// The layer list does not describe a usable network: fewer than two
    /// entries, a zero-width layer, or parameter counts that overflow `usize`.
    InvalidShape,
    /// A caller-supplied buffer length differs from the count implied by the layers.
    ShapeMismatch,
    /// A requested constant, or a derived sampling bound, is NaN or infinite.
    NonFinite,
}
15
/// Returns the total `(weights, biases)` element counts for a dense network.
///
/// Each adjacent pair `(layers[i], layers[i + 1])` is one layer that contributes
/// `in * out` weights and `out` biases.
///
/// Returns `None` when `layers` has fewer than two entries, when any layer
/// width involved is zero, or when a running total would overflow `usize`.
pub fn expected_parameter_counts(layers: &[usize]) -> Option<(usize, usize)> {
    if layers.len() < 2 {
        return None;
    }

    layers
        .windows(2)
        .try_fold((0usize, 0usize), |(weight_total, bias_total), pair| {
            let (fan_in, fan_out) = (pair[0], pair[1]);
            if fan_in == 0 || fan_out == 0 {
                return None;
            }
            // Checked arithmetic so absurd shapes are rejected instead of wrapping.
            let layer_weights = fan_in.checked_mul(fan_out)?;
            Some((
                weight_total.checked_add(layer_weights)?,
                bias_total.checked_add(fan_out)?,
            ))
        })
}
34
35pub fn initialize_dense_parameters(
36    layers: &[usize],
37    weights: &mut [f32],
38    biases: &mut [f32],
39    init_kind: InitKind,
40    seed: u64,
41) -> Result<(), InitError> {
42    let (expected_w, expected_b) = expected_parameter_counts(layers).ok_or(InitError::InvalidShape)?;
43    if weights.len() != expected_w || biases.len() != expected_b {
44        return Err(InitError::ShapeMismatch);
45    }
46
47    let mut rng = SplitMix64::new(seed);
48    let mut w_off = 0usize;
49    let mut b_off = 0usize;
50
51    for i in 0..layers.len() - 1 {
52        let in_size = layers[i];
53        let out_size = layers[i + 1];
54        let w_len = in_size * out_size;
55        let (w_slice, b_slice) = (
56            &mut weights[w_off..w_off + w_len],
57            &mut biases[b_off..b_off + out_size],
58        );
59
60        match init_kind {
61            InitKind::Zeros => {
62                for w in w_slice {
63                    *w = 0.0;
64                }
65                for b in b_slice {
66                    *b = 0.0;
67                }
68            }
69            InitKind::Constant(value) => {
70                if !value.is_finite() {
71                    return Err(InitError::NonFinite);
72                }
73                for w in w_slice {
74                    *w = value;
75                }
76                for b in b_slice {
77                    *b = 0.0;
78                }
79            }
80            InitKind::XavierUniform => {
81                let denom = (in_size + out_size) as f32;
82                if denom <= 0.0 || !denom.is_finite() {
83                    return Err(InitError::NonFinite);
84                }
85                let limit = crate::math::sqrtf(6.0 / denom);
86                if !limit.is_finite() {
87                    return Err(InitError::NonFinite);
88                }
89                for w in w_slice {
90                    *w = random_uniform_symmetric(&mut rng, limit);
91                }
92                for b in b_slice {
93                    *b = 0.0;
94                }
95            }
96            InitKind::HeUniform => {
97                let denom = in_size as f32;
98                if denom <= 0.0 || !denom.is_finite() {
99                    return Err(InitError::NonFinite);
100                }
101                let limit = crate::math::sqrtf(6.0 / denom);
102                if !limit.is_finite() {
103                    return Err(InitError::NonFinite);
104                }
105                for w in w_slice {
106                    *w = random_uniform_symmetric(&mut rng, limit);
107                }
108                for b in b_slice {
109                    *b = 0.0;
110                }
111            }
112        }
113
114        w_off += w_len;
115        b_off += out_size;
116    }
117
118    Ok(())
119}
120
121fn random_uniform_symmetric(rng: &mut SplitMix64, limit: f32) -> f32 {
122    let u = rng.next_f32_01();
123    (u * 2.0 - 1.0) * limit
124}
125
/// Small deterministic SplitMix64 generator used for reproducible sampling.
struct SplitMix64 {
    /// Running counter; advanced by a fixed increment on every draw.
    state: u64,
}

impl SplitMix64 {
    /// Creates a generator whose first output is derived from `seed`.
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the counter and returns the next mixed 64-bit value.
    fn next_u64(&mut self) -> u64 {
        const INCREMENT: u64 = 0x9E3779B97F4A7C15;
        const MIX_A: u64 = 0xBF58476D1CE4E5B9;
        const MIX_B: u64 = 0x94D049BB133111EB;

        self.state = self.state.wrapping_add(INCREMENT);
        let stage0 = self.state;
        let stage1 = (stage0 ^ (stage0 >> 30)).wrapping_mul(MIX_A);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(MIX_B);
        stage2 ^ (stage2 >> 31)
    }

    /// Returns a value in `[0, 1)` built from the top 24 bits of the next draw.
    fn next_f32_01(&mut self) -> f32 {
        // 24 bits fit an f32 mantissa exactly, so the division below is exact.
        let top_bits = (self.next_u64() >> 40) as u32;
        top_bits as f32 / (1u32 << 24) as f32
    }
}