1use super::layout::ModelFeatures;
6use cnc::params::ParamsBase;
7use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
8use num_traits::{One, Zero};
9
10pub type ModelParams<A = f64, D = Ix2> = ModelParamsBase<ndarray::OwnedRepr<A>, D>;
11
12pub struct ModelParamsBase<S, D = Ix2>
20where
21 D: Dimension,
22 S: RawData,
23{
24 pub input: ParamsBase<S, D>,
26 pub hidden: Vec<ParamsBase<S, D>>,
28 pub output: ParamsBase<S, D>,
30}
31
32impl<A, S, D> ModelParamsBase<S, D>
33where
34 D: Dimension,
35 S: RawData<Elem = A>,
36{
37 pub fn new(
38 input: ParamsBase<S, D>,
39 hidden: Vec<ParamsBase<S, D>>,
40 output: ParamsBase<S, D>,
41 ) -> Self {
42 Self {
43 input,
44 hidden,
45 output,
46 }
47 }
48 pub fn is_shallow(&self) -> bool {
50 self.hidden.is_empty() || self.hidden.len() == 1
51 }
52 pub const fn input(&self) -> &ParamsBase<S, D> {
54 &self.input
55 }
56 #[inline]
58 pub fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
59 &mut self.input
60 }
61 pub const fn hidden(&self) -> &Vec<ParamsBase<S, D>> {
63 &self.hidden
64 }
65 #[inline]
67 pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>] {
68 self.hidden.as_slice()
69 }
70 #[inline]
72 pub fn hidden_mut(&mut self) -> &mut Vec<ParamsBase<S, D>> {
73 &mut self.hidden
74 }
75 pub const fn output(&self) -> &ParamsBase<S, D> {
77 &self.output
78 }
79 #[inline]
81 pub fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
82 &mut self.output
83 }
84 pub fn set_input(&mut self, input: ParamsBase<S, D>) {
86 *self.input_mut() = input;
87 }
88 pub fn set_hidden<I>(&mut self, iter: I)
90 where
91 I: IntoIterator<Item = ParamsBase<S, D>>,
92 {
93 *self.hidden_mut() = Vec::from_iter(iter);
94 }
95 pub fn set_output(&mut self, output: ParamsBase<S, D>) {
97 self.output = output;
98 }
99 pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
101 Self { input, ..self }
102 }
103 pub fn with_hidden<I>(self, iter: I) -> Self
105 where
106 I: IntoIterator<Item = ParamsBase<S, D>>,
107 {
108 Self {
109 hidden: Vec::from_iter(iter),
110 ..self
111 }
112 }
113 pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
115 Self { output, ..self }
116 }
117 pub fn dim_input(&self) -> <D as Dimension>::Pattern {
119 self.input().dim()
120 }
121 pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
123 assert!(self.hidden.iter().all(|p| p.dim() == self.hidden[0].dim()));
124 self.hidden()[0].dim()
125 }
126 pub fn dim_output(&self) -> <D as Dimension>::Pattern {
128 self.output.dim()
129 }
130 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
133 where
134 A: Clone,
135 S: Data,
136 ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
137 {
138 let mut output = self.input().forward(input)?;
139 for layer in self.hidden() {
140 output = layer.forward(&output)?;
141 }
142 self.output().forward(&output)
143 }
144}
145
146impl<A, S> ModelParamsBase<S>
147where
148 S: RawData<Elem = A>,
149{
150 pub fn default(features: ModelFeatures) -> Self
153 where
154 A: Clone + Default,
155 S: DataOwned,
156 {
157 let input = ParamsBase::default(features.dim_input());
158 let hidden = (0..features.layers())
159 .map(|_| ParamsBase::default(features.dim_hidden()))
160 .collect::<Vec<_>>();
161 let output = ParamsBase::default(features.dim_output());
162 Self::new(input, hidden, output)
163 }
164 pub fn ones(features: ModelFeatures) -> Self
167 where
168 A: Clone + One,
169 S: DataOwned,
170 {
171 let input = ParamsBase::ones(features.dim_input());
172 let hidden = (0..features.layers())
173 .map(|_| ParamsBase::ones(features.dim_hidden()))
174 .collect::<Vec<_>>();
175 let output = ParamsBase::ones(features.dim_output());
176 Self::new(input, hidden, output)
177 }
178 pub fn zeros(features: ModelFeatures) -> Self
181 where
182 A: Clone + Zero,
183 S: DataOwned,
184 {
185 let input = ParamsBase::zeros(features.dim_input());
186 let hidden = (0..features.layers())
187 .map(|_| ParamsBase::zeros(features.dim_hidden()))
188 .collect::<Vec<_>>();
189 let output = ParamsBase::zeros(features.dim_output());
190 Self::new(input, hidden, output)
191 }
192
193 #[cfg(feature = "rand")]
194 pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
195 where
196 G: Fn((usize, usize)) -> Ds,
197 Ds: Clone + cnc::init::rand_distr::Distribution<A>,
198 S: DataOwned,
199 {
200 use cnc::init::Initialize;
201 let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
202 let hidden = (0..features.layers())
203 .map(|_| ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())))
204 .collect::<Vec<_>>();
205
206 let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
207
208 Self::new(input, hidden, output)
209 }
210 #[cfg(feature = "rand")]
212 pub fn glorot_normal(features: ModelFeatures) -> Self
213 where
214 S: DataOwned,
215 A: num_traits::Float + num_traits::FromPrimitive,
216 cnc::init::rand_distr::StandardNormal: cnc::init::rand_distr::Distribution<A>,
217 {
218 Self::init_rand(features, |(rows, cols)| {
219 cnc::init::XavierNormal::new(rows, cols)
220 })
221 }
222 #[cfg(feature = "rand")]
224 pub fn glorot_uniform(features: ModelFeatures) -> Self
225 where
226 S: ndarray::DataOwned,
227 A: Clone
228 + num_traits::Float
229 + num_traits::FromPrimitive
230 + cnc::init::rand_distr::uniform::SampleUniform,
231 <S::Elem as cnc::init::rand_distr::uniform::SampleUniform>::Sampler: Clone,
232 cnc::init::rand_distr::Uniform<S::Elem>: cnc::init::rand_distr::Distribution<S::Elem>,
233 {
234 Self::init_rand(features, |(rows, cols)| {
235 cnc::init::XavierUniform::new(rows, cols).expect("failed to create distribution")
236 })
237 }
238}
239
240impl<A, S, D> Clone for ModelParamsBase<S, D>
241where
242 A: Clone,
243 D: Dimension,
244 S: ndarray::RawDataClone<Elem = A>,
245{
246 fn clone(&self) -> Self {
247 Self {
248 input: self.input.clone(),
249 hidden: self.hidden.to_vec(),
250 output: self.output.clone(),
251 }
252 }
253}
254
255impl<A, S, D> core::fmt::Debug for ModelParamsBase<S, D>
256where
257 A: core::fmt::Debug,
258 D: Dimension,
259 S: ndarray::Data<Elem = A>,
260{
261 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
262 f.debug_struct("ModelParams")
263 .field("input", &self.input)
264 .field("hidden", &self.hidden)
265 .field("output", &self.output)
266 .finish()
267 }
268}
269
270impl<A, S, D> core::fmt::Display for ModelParamsBase<S, D>
271where
272 A: core::fmt::Debug,
273 D: Dimension,
274 S: ndarray::Data<Elem = A>,
275{
276 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
277 write!(
278 f,
279 "{{ input: {:?}, hidden: {:?}, output: {:?} }}",
280 self.input, self.hidden, self.output
281 )
282 }
283}
284
285impl<A, S, D> core::ops::Index<usize> for ModelParamsBase<S, D>
286where
287 A: Clone,
288 D: Dimension,
289 S: ndarray::Data<Elem = A>,
290{
291 type Output = ParamsBase<S, D>;
292
293 fn index(&self, index: usize) -> &Self::Output {
294 if index == 0 {
295 &self.input
296 } else if index == self.hidden.len() + 1 {
297 &self.output
298 } else {
299 &self.hidden[index - 1]
300 }
301 }
302}
303
304impl<A, S, D> core::ops::IndexMut<usize> for ModelParamsBase<S, D>
305where
306 A: Clone,
307 D: Dimension,
308 S: ndarray::Data<Elem = A>,
309{
310 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
311 if index == 0 {
312 &mut self.input
313 } else if index == self.hidden.len() + 1 {
314 &mut self.output
315 } else {
316 &mut self.hidden[index - 1]
317 }
318 }
319}