runnt/
initialization.rs

1use std::fmt::Display;
2
3use crate::{error::Error, sede::Sede};
4
/// Strategy used to choose the starting value of each network weight.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Initialization {
    /// Uniform weights in the range `+-sqrt(6 / (in + out))`.
    ///
    /// Best for Tanh, Sigmoid.
    ///
    /// This is the default initialization.
    Xavier,
    /// Uniform weights in the range `+-sqrt(6 / in)`.
    ///
    /// Best for Relu, Swish.
    He,
    /// Every weight is set to the same given value.
    Fixed(f32),
    /// Uniform weights between -1 and 1.
    Random,
}
22
23pub fn calc_initialization(
24    typ: Initialization,
25    prev_layer_size: usize,
26    next_layer_size: usize,
27) -> f32 {
28    match typ {
29        Initialization::Random => fastrand::f32() * 2. - 1.,
30        Initialization::He => (fastrand::f32() * 2. - 1.) * (6.0 / prev_layer_size as f32).sqrt(),
31        Initialization::Xavier => {
32            (fastrand::f32() * 2. - 1.) * (6.0 / (prev_layer_size + next_layer_size) as f32).sqrt()
33        }
34        Initialization::Fixed(val) => val,
35    }
36}
37
38impl Display for Initialization {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Initialization::Random => write!(f, "Random"),
42            Initialization::He => write!(f, "He"),
43            Initialization::Xavier => write!(f, "Xavier"),
44            Initialization::Fixed(val) => write!(f, "Fixed({})", val),
45        }
46    }
47}
48
49impl Sede for Initialization {
50    fn serialize(&self) -> String {
51        format!("{}", self)
52    }
53
54    fn deserialize(s: &str) -> Result<Self, Error> {
55        if s == "Random" {
56            Ok(Initialization::Random)
57        } else if s == "He" {
58            Ok(Initialization::He)
59        } else if s == "Xavier" {
60            Ok(Initialization::Xavier)
61        } else if let Some(val) = s.strip_prefix("Fixed(").and_then(|s| s.strip_suffix(')')) {
62            val.parse::<f32>()
63                .map(Initialization::Fixed)
64                .map_err(|_| Error::SerializationError(format!("Invalid Fixed value: {}", val)))
65        } else {
66            Err(Error::SerializationError(format!(
67                "Unknown initialization type: {}",
68                s
69            )))
70        }
71    }
72}
73
74// -------------------- Tests for initialization --------------------
75
#[cfg(test)]
mod tests {
    use crate::initialization::Initialization;
    use crate::nn::NN;

    /// `Fixed` must put the same value into every weight of the network.
    #[test]
    fn test_fixed_initialization() {
        let fixed = 0.5f32;
        let nn = NN::new(&[4, 3, 2]).with_initialization(Initialization::Fixed(fixed));
        let weights = nn.get_weights();
        // `all` is vacuously true on an empty collection, so make sure there is
        // actually something to check before comparing values.
        assert!(!weights.is_empty(), "network returned no weights");
        assert!(weights.iter().all(|&w| (w - fixed).abs() < 1e-6));
    }

    /// `Random` weights should look like uniform samples from [-1, 1]:
    /// mean near 0, variance near 1/3, and no value outside the range.
    #[test]
    fn test_random_distribution_stats() {
        // Seed the generator so the sampled weights are reproducible.
        fastrand::seed(12345);
        // Create a reasonably large network to sample many weights.
        let nn = NN::new(&[100, 50, 20]).with_initialization(Initialization::Random);
        let vals: Vec<f32> = nn.get_weights();
        let count = vals.len() as f32;

        // Mean should be near 0 for symmetric Random in [-1, 1].
        let mean: f32 = vals.iter().copied().sum::<f32>() / count;
        // Variance is approximately 1/3 for uniform [-1, 1].
        let var: f32 = vals.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;

        assert!(mean.abs() < 0.05, "mean too far from 0: {}", mean);
        assert!((var - 0.3333).abs() < 0.05, "variance off: {}", var);

        // Ensure values fall within [-1, 1], allowing one epsilon of slack.
        let allowed = (-1.0 - f32::EPSILON)..=(1.0 + f32::EPSILON);
        assert!(vals.iter().all(|v| allowed.contains(v)));
    }
}
111}