eryon_surface/model/impls/
impl_init.rs1use crate::model::SurfaceModel;
7use cnc::prelude::{Init, InitInplace, ModelParams, Params};
8#[cfg(feature = "rand")]
9use cnc::{init::Initialize, rand_distr};
10use ndarray::ScalarOperand;
11use num_traits::{Float, FromPrimitive};
12
13impl<A> SurfaceModel<A> {
14 pub fn init(self) -> Self
16 where
17 Self: Init,
18 {
19 <Self as Init>::init(self)
20 }
21 pub fn init_inplace(&mut self) -> &mut Self
23 where
24 Self: InitInplace,
25 {
26 <Self as InitInplace>::init_inplace(self)
27 }
28}
29
#[cfg(feature = "rand")]
impl<A> Init for SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    rand_distr::StandardNormal: rand_distr::Distribution<A>,
{
    /// Consumes the model and returns it with freshly (randomly) initialized parameters.
    ///
    /// Delegates to the `rand`-enabled [`InitInplace`] impl, which carries exactly
    /// the same bounds. The owned and in-place entry points therefore share one
    /// implementation instead of two hand-maintained copies that could drift apart.
    fn init(mut self) -> Self {
        // Fully-qualified call: method syntax would pick the inherent
        // `SurfaceModel::init_inplace` wrapper first.
        <Self as InitInplace>::init_inplace(&mut self);
        self
    }
}
50
51#[cfg(not(feature = "rand"))]
52impl<A> Init for SurfaceModel<A>
53where
54 A: Float + FromPrimitive + ScalarOperand,
55{
56 fn init(self) -> Self {
57 let inputs = self.features().input() as f64;
58 let outputs = self.features().output() as f64;
59
60 let scale_input = A::from_f64(inputs.sqrt().recip()).unwrap_or(A::one());
62 let scale_hidden =
63 A::from_f64(((inputs + outputs) / 2.0).sqrt().recip()).unwrap_or(A::one());
64 let scale_output = A::from_f64(outputs.sqrt().recip()).unwrap_or(A::one());
65
66 let init_params = |dims: (usize, usize), scale: A| -> Params<A, ndarray::Ix2> {
68 let mut params = Params::zeros(dims);
69 let weights = params.weights_mut();
70
71 for i in 0..weights.nrows() {
72 for j in 0..weights.ncols() {
73 let value = if (i + j) % 2 == 0 { scale } else { -scale };
75 weights[[i, j]] = value;
76 }
77 }
78 params
79 };
80
81 let input = init_params(self.features().dim_input(), scale_input);
83 let output = init_params(self.features().dim_output(), scale_output);
84
85 let hidden = (0..self.features().layers())
86 .map(|_| init_params(self.features().dim_hidden(), scale_hidden))
87 .collect::<Vec<_>>();
88
89 Self {
90 params: ModelParams::new(input, hidden, output),
91 ..self
92 }
93 }
94}
95
#[cfg(feature = "rand")]
impl<A> InitInplace for SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    rand_distr::StandardNormal: rand_distr::Distribution<A>,
{
    /// Replaces the model's parameters with Glorot-normal samples, in place.
    ///
    /// The input and output layers are sampled first. Then `features().layers()`
    /// hidden layers are sampled from `dim_hidden()`. The result is installed via
    /// `set_params`, and `self` is returned for chaining.
    fn init_inplace(&mut self) -> &mut Self {
        let params = {
            let feats = self.features();
            // Sampling order is input, output, then hidden layers; keep it stable
            // in case the underlying RNG is ever seeded.
            let input_layer = Params::glorot_normal(feats.dim_input());
            let output_layer = Params::glorot_normal(feats.dim_output());
            let layer_count = feats.layers();
            let mut hidden_layers = Vec::with_capacity(layer_count);
            for _ in 0..layer_count {
                hidden_layers.push(Params::glorot_normal(feats.dim_hidden()));
            }
            ModelParams::new(input_layer, hidden_layers, output_layer)
        };
        self.set_params(params);
        self
    }
}
116
117#[cfg(not(feature = "rand"))]
118impl<A> InitInplace for SurfaceModel<A>
119where
120 A: Float + FromPrimitive + ScalarOperand,
121{
122 fn init_inplace(&mut self) -> &mut Self {
123 let inputs = self.features().input() as f64;
124 let outputs = self.features().output() as f64;
125
126 let scale_input = A::from_f64(inputs.sqrt().recip()).unwrap_or(A::one());
127 let scale_hidden =
128 A::from_f64(((inputs + outputs) / 2.0).sqrt().recip()).unwrap_or(A::one());
129 let scale_output = A::from_f64(outputs.sqrt().recip()).unwrap_or(A::one());
130
131 let init_params = |dims: (usize, usize), scale: A| -> Params<A, ndarray::Ix2> {
132 let mut params = Params::zeros(dims);
133 let weights = params.weights_mut();
134
135 for i in 0..weights.nrows() {
136 for j in 0..weights.ncols() {
137 let value = if (i + j) % 2 == 0 { scale } else { -scale };
138 weights[[i, j]] = value;
139 }
140 }
141 params
142 };
143 let i = init_params(self.features().dim_input(), scale_input);
145 let o = init_params(self.features().dim_output(), scale_output);
147 let h = (0..self.features().layers())
149 .map(|_| init_params(self.features().dim_hidden(), scale_hidden))
150 .collect::<Vec<_>>();
151 self.set_params(ModelParams::new(i, h, o));
153 self
154 }
155}