// eryon_surface/model/impls/impl_init.rs
/*
    Appellation: impl_init <module>
    Contrib: @FL03
*/

use crate::model::SurfaceModel;
use cnc::prelude::{Init, InitInplace, ModelParams, Params};
#[cfg(feature = "rand")]
use cnc::{init::Initialize, rand_distr};
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive};

13impl<A> SurfaceModel<A> {
14    /// consumes the current instance to initialize the model params
15    pub fn init(self) -> Self
16    where
17        Self: Init,
18    {
19        <Self as Init>::init(self)
20    }
21    /// initialize the model params in place
22    pub fn init_inplace(&mut self) -> &mut Self
23    where
24        Self: InitInplace,
25    {
26        <Self as InitInplace>::init_inplace(self)
27    }
28}
29
#[cfg(feature = "rand")]
impl<A> Init for SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    rand_distr::StandardNormal: rand_distr::Distribution<A>,
{
    /// Consumes the model and returns it with every layer re-sampled via
    /// `Params::glorot_normal`.
    ///
    /// Layers are sampled in the order: input layer, output layer, then each of
    /// the `layers()` hidden layers. Every field other than `params` is carried
    /// over from `self` unchanged.
    fn init(self) -> Self {
        // Build the new parameters in a nested scope so the borrow of the
        // model's features ends before `self` is moved below.
        let params = {
            let feats = self.features();
            let input_layer = Params::glorot_normal(feats.dim_input());
            let output_layer = Params::glorot_normal(feats.dim_output());
            let mut hidden_layers = Vec::new();
            for _ in 0..feats.layers() {
                hidden_layers.push(Params::glorot_normal(feats.dim_hidden()));
            }
            ModelParams::new(input_layer, hidden_layers, output_layer)
        };
        Self { params, ..self }
    }
}

51#[cfg(not(feature = "rand"))]
52impl<A> Init for SurfaceModel<A>
53where
54    A: Float + FromPrimitive + ScalarOperand,
55{
56    fn init(self) -> Self {
57        let inputs = self.features().input() as f64;
58        let outputs = self.features().output() as f64;
59
60        // Xavier/Glorot-inspired deterministic initialization
61        let scale_input = A::from_f64(inputs.sqrt().recip()).unwrap_or(A::one());
62        let scale_hidden =
63            A::from_f64(((inputs + outputs) / 2.0).sqrt().recip()).unwrap_or(A::one());
64        let scale_output = A::from_f64(outputs.sqrt().recip()).unwrap_or(A::one());
65
66        // Helper to create alternating fixed patterns
67        let init_params = |dims: (usize, usize), scale: A| -> Params<A, ndarray::Ix2> {
68            let mut params = Params::zeros(dims);
69            let weights = params.weights_mut();
70
71            for i in 0..weights.nrows() {
72                for j in 0..weights.ncols() {
73                    // Alternating pattern based on indices
74                    let value = if (i + j) % 2 == 0 { scale } else { -scale };
75                    weights[[i, j]] = value;
76                }
77            }
78            params
79        };
80
81        // Create params with alternating patterns
82        let input = init_params(self.features().dim_input(), scale_input);
83        let output = init_params(self.features().dim_output(), scale_output);
84
85        let hidden = (0..self.features().layers())
86            .map(|_| init_params(self.features().dim_hidden(), scale_hidden))
87            .collect::<Vec<_>>();
88
89        Self {
90            params: ModelParams::new(input, hidden, output),
91            ..self
92        }
93    }
94}
95
#[cfg(feature = "rand")]
impl<A> InitInplace for SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    rand_distr::StandardNormal: rand_distr::Distribution<A>,
{
    /// Re-samples every layer of the model in place via `Params::glorot_normal`
    /// and installs the result with `set_params`.
    ///
    /// Layers are sampled in the order: input layer, output layer, then each of
    /// the `layers()` hidden layers. Returns `self` so calls can be chained.
    fn init_inplace(&mut self) -> &mut Self {
        // Scope the features borrow so it ends before `set_params` needs `&mut self`.
        let fresh = {
            let feats = self.features();
            let draw_hidden = |_| Params::glorot_normal(feats.dim_hidden());
            let input_layer = Params::glorot_normal(feats.dim_input());
            let output_layer = Params::glorot_normal(feats.dim_output());
            let hidden_layers: Vec<_> = (0..feats.layers()).map(draw_hidden).collect();
            ModelParams::new(input_layer, hidden_layers, output_layer)
        };
        self.set_params(fresh);
        self
    }
}

117#[cfg(not(feature = "rand"))]
118impl<A> InitInplace for SurfaceModel<A>
119where
120    A: Float + FromPrimitive + ScalarOperand,
121{
122    fn init_inplace(&mut self) -> &mut Self {
123        let inputs = self.features().input() as f64;
124        let outputs = self.features().output() as f64;
125
126        let scale_input = A::from_f64(inputs.sqrt().recip()).unwrap_or(A::one());
127        let scale_hidden =
128            A::from_f64(((inputs + outputs) / 2.0).sqrt().recip()).unwrap_or(A::one());
129        let scale_output = A::from_f64(outputs.sqrt().recip()).unwrap_or(A::one());
130
131        let init_params = |dims: (usize, usize), scale: A| -> Params<A, ndarray::Ix2> {
132            let mut params = Params::zeros(dims);
133            let weights = params.weights_mut();
134
135            for i in 0..weights.nrows() {
136                for j in 0..weights.ncols() {
137                    let value = if (i + j) % 2 == 0 { scale } else { -scale };
138                    weights[[i, j]] = value;
139                }
140            }
141            params
142        };
143        // initialize the input layer
144        let i = init_params(self.features().dim_input(), scale_input);
145        // initialize the output layer
146        let o = init_params(self.features().dim_output(), scale_output);
147        // initialize the hidden layers
148        let h = (0..self.features().layers())
149            .map(|_| init_params(self.features().dim_hidden(), scale_hidden))
150            .collect::<Vec<_>>();
151        // override the current params
152        self.set_params(ModelParams::new(i, h, o));
153        self
154    }
155}