1use std::fmt::Display;
2
3use crate::{error::Error, sede::Sede};
4
/// Scheme used by [`calc_initialization`] to draw initial weight values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Initialization {
    /// Xavier/Glorot uniform: samples from `[-b, b)` with
    /// `b = sqrt(6 / (prev_layer_size + next_layer_size))`.
    Xavier,
    /// He/Kaiming uniform: samples from `[-b, b)` with `b = sqrt(6 / prev_layer_size)`.
    He,
    /// Always yields the given constant; no randomness involved.
    Fixed(f32),
    /// Plain uniform sample from `[-1, 1)`, independent of the layer sizes.
    Random,
}
22
23pub fn calc_initialization(
24 typ: Initialization,
25 prev_layer_size: usize,
26 next_layer_size: usize,
27) -> f32 {
28 match typ {
29 Initialization::Random => fastrand::f32() * 2. - 1.,
30 Initialization::He => (fastrand::f32() * 2. - 1.) * (6.0 / prev_layer_size as f32).sqrt(),
31 Initialization::Xavier => {
32 (fastrand::f32() * 2. - 1.) * (6.0 / (prev_layer_size + next_layer_size) as f32).sqrt()
33 }
34 Initialization::Fixed(val) => val,
35 }
36}
37
38impl Display for Initialization {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 Initialization::Random => write!(f, "Random"),
42 Initialization::He => write!(f, "He"),
43 Initialization::Xavier => write!(f, "Xavier"),
44 Initialization::Fixed(val) => write!(f, "Fixed({})", val),
45 }
46 }
47}
48
49impl Sede for Initialization {
50 fn serialize(&self) -> String {
51 format!("{}", self)
52 }
53
54 fn deserialize(s: &str) -> Result<Self, Error> {
55 if s == "Random" {
56 Ok(Initialization::Random)
57 } else if s == "He" {
58 Ok(Initialization::He)
59 } else if s == "Xavier" {
60 Ok(Initialization::Xavier)
61 } else if let Some(val) = s.strip_prefix("Fixed(").and_then(|s| s.strip_suffix(')')) {
62 val.parse::<f32>()
63 .map(Initialization::Fixed)
64 .map_err(|_| Error::SerializationError(format!("Invalid Fixed value: {}", val)))
65 } else {
66 Err(Error::SerializationError(format!(
67 "Unknown initialization type: {}",
68 s
69 )))
70 }
71 }
72}
73
#[cfg(test)]
mod tests {
    use crate::initialization::Initialization;
    use crate::nn::NN;
    use crate::sede::Sede;

    /// `Fixed(v)` must set every weight to `v`.
    #[test]
    fn test_fixed_initialization() {
        let fixed = 0.5f32;
        let nn = NN::new(&[4, 3, 2]).with_initialization(Initialization::Fixed(fixed));
        let weights = nn.get_weights();
        // `all` is vacuously true on an empty iterator, so make sure there is something to check.
        assert!(!weights.is_empty(), "network has no weights to check");
        assert!(weights.iter().all(|&w| (w - fixed).abs() < 1e-6));
    }

    /// `Random` samples U(-1, 1): mean ≈ 0, variance ≈ 1/3, every value within [-1, 1].
    #[test]
    fn test_random_distribution_stats() {
        // Seed the thread-local RNG so the statistics are reproducible.
        fastrand::seed(12345);
        let nn = NN::new(&[100, 50, 20]).with_initialization(Initialization::Random);
        let vals: Vec<f32> = nn.get_weights();
        assert!(!vals.is_empty(), "network has no weights to check");

        let n = vals.len() as f32;
        let mean: f32 = vals.iter().copied().sum::<f32>() / n;
        let var: f32 = vals.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;

        assert!(mean.abs() < 0.05, "mean too far from 0: {}", mean);
        // Variance of U(a, b) is (b - a)^2 / 12, i.e. 4 / 12 = 1/3 for U(-1, 1).
        assert!((var - 1.0 / 3.0).abs() < 0.05, "variance off: {}", var);

        let bound = 1.0 + f32::EPSILON;
        assert!(
            vals.iter().all(|&v| (-bound..=bound).contains(&v)),
            "value outside [-1, 1]"
        );
    }

    /// Every variant must survive `serialize` -> `deserialize` unchanged.
    #[test]
    fn test_sede_round_trip() {
        for &init in &[
            Initialization::Random,
            Initialization::He,
            Initialization::Xavier,
            Initialization::Fixed(0.5),
            Initialization::Fixed(-1.25),
        ] {
            let text = init.serialize();
            assert_eq!(Initialization::deserialize(&text).ok(), Some(init), "text: {}", text);
        }
    }

    /// Unknown names and malformed `Fixed` payloads are rejected.
    #[test]
    fn test_sede_rejects_invalid_input() {
        assert!(Initialization::deserialize("Fixed(abc)").is_err());
        assert!(Initialization::deserialize("Gaussian").is_err());
        assert!(Initialization::deserialize("").is_err());
    }
}