/// Strategy used to fill the weights of a dense network.
///
/// Every strategy leaves the biases at `0.0`; only the weight values differ.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitKind {
    /// Every weight is set to `0.0`.
    Zeros,
    /// Weights are sampled uniformly with bound `sqrt(6 / (fan_in + fan_out))`.
    XavierUniform,
    /// Weights are sampled uniformly with bound `sqrt(6 / fan_in)`.
    HeUniform,
    /// Every weight is set to the given value, which must be finite.
    Constant(f32),
}
/// Reasons parameter initialization can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitError {
    /// The layer list has fewer than two entries, contains a zero-sized
    /// layer, or its parameter counts overflow `usize`.
    InvalidShape,
    /// The supplied weight or bias buffer length differs from the count the
    /// layer list requires.
    ShapeMismatch,
    /// A constant fill value or a computed sampling bound was not finite.
    NonFinite,
}
/// Computes the total number of weights and biases a dense network needs.
///
/// `layers` lists the layer widths in order; each adjacent pair `(a, b)`
/// contributes `a * b` weights and `b` biases.
///
/// Returns `Some((weights, biases))`, or `None` when `layers` has fewer than
/// two entries, any width taking part in a pair is zero, or a count would
/// overflow `usize`.
pub fn expected_parameter_counts(layers: &[usize]) -> Option<(usize, usize)> {
    // `windows(2)` yields nothing for short inputs, so reject them explicitly
    // rather than reporting zero parameters.
    if layers.len() < 2 {
        return None;
    }

    layers
        .windows(2)
        .try_fold((0usize, 0usize), |(weight_total, bias_total), pair| {
            let (fan_in, fan_out) = (pair[0], pair[1]);
            if fan_in == 0 || fan_out == 0 {
                return None;
            }
            let layer_weights = fan_in.checked_mul(fan_out)?;
            Some((
                weight_total.checked_add(layer_weights)?,
                bias_total.checked_add(fan_out)?,
            ))
        })
}
pub fn initialize_dense_parameters(
layers: &[usize],
weights: &mut [f32],
biases: &mut [f32],
init_kind: InitKind,
seed: u64,
) -> Result<(), InitError> {
let (expected_w, expected_b) = expected_parameter_counts(layers).ok_or(InitError::InvalidShape)?;
if weights.len() != expected_w || biases.len() != expected_b {
return Err(InitError::ShapeMismatch);
}
let mut rng = SplitMix64::new(seed);
let mut w_off = 0usize;
let mut b_off = 0usize;
for i in 0..layers.len() - 1 {
let in_size = layers[i];
let out_size = layers[i + 1];
let w_len = in_size * out_size;
let (w_slice, b_slice) = (
&mut weights[w_off..w_off + w_len],
&mut biases[b_off..b_off + out_size],
);
match init_kind {
InitKind::Zeros => {
for w in w_slice {
*w = 0.0;
}
for b in b_slice {
*b = 0.0;
}
}
InitKind::Constant(value) => {
if !value.is_finite() {
return Err(InitError::NonFinite);
}
for w in w_slice {
*w = value;
}
for b in b_slice {
*b = 0.0;
}
}
InitKind::XavierUniform => {
let denom = (in_size + out_size) as f32;
if denom <= 0.0 || !denom.is_finite() {
return Err(InitError::NonFinite);
}
let limit = crate::math::sqrtf(6.0 / denom);
if !limit.is_finite() {
return Err(InitError::NonFinite);
}
for w in w_slice {
*w = random_uniform_symmetric(&mut rng, limit);
}
for b in b_slice {
*b = 0.0;
}
}
InitKind::HeUniform => {
let denom = in_size as f32;
if denom <= 0.0 || !denom.is_finite() {
return Err(InitError::NonFinite);
}
let limit = crate::math::sqrtf(6.0 / denom);
if !limit.is_finite() {
return Err(InitError::NonFinite);
}
for w in w_slice {
*w = random_uniform_symmetric(&mut rng, limit);
}
for b in b_slice {
*b = 0.0;
}
}
}
w_off += w_len;
b_off += out_size;
}
Ok(())
}
/// Draws one sample from `rng` and maps it onto an interval centred on zero.
///
/// The unit sample `u` from `next_f32_01` is rescaled as `(u * 2 - 1) * limit`,
/// so with `u` in `[0, 1)` and a positive `limit` the result lies in
/// `[-limit, limit)`.
fn random_uniform_symmetric(rng: &mut SplitMix64, limit: f32) -> f32 {
    let centred = rng.next_f32_01() * 2.0 - 1.0;
    centred * limit
}
/// Small deterministic pseudo-random generator driven by a single 64-bit state.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Creates a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64-bit output.
    ///
    /// The state is bumped by a fixed increment (wrapping), then scrambled by
    /// two xor-shift/multiply rounds and a final xor-shift.
    fn next_u64(&mut self) -> u64 {
        const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
        const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
        const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

        self.state = self.state.wrapping_add(INCREMENT);
        let round1 = (self.state ^ (self.state >> 30)).wrapping_mul(MIX_A);
        let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(MIX_B);
        round2 ^ (round2 >> 31)
    }

    /// Returns the next sample as an `f32` in `[0, 1)`.
    ///
    /// Only the top 24 bits of the 64-bit output are kept, so the integer is
    /// exactly representable in `f32` before dividing by 2^24.
    fn next_f32_01(&mut self) -> f32 {
        const TWO_POW_24: f32 = 16777216.0;
        let top_bits = (self.next_u64() >> 40) as u32;
        top_bits as f32 / TWO_POW_24
    }
}