concision_neural/model/impls/
impl_model_params.rs

/*
    appellation: impl_model_params <module>
    authors: @FL03
*/
5use crate::model::{ModelFeatures, ModelParamsBase};
6
7use cnc::params::ParamsBase;
8use ndarray::{DataOwned, Dimension, RawData};
9use num_traits::{One, Zero};
10
11impl<A, S> ModelParamsBase<S>
12where
13    S: RawData<Elem = A>,
14{
15    /// create a new instance of the model;
16    /// all parameters are initialized to their defaults (i.e., zero)
17    pub fn default(features: ModelFeatures) -> Self
18    where
19        A: Clone + Default,
20        S: DataOwned,
21    {
22        let input = ParamsBase::default(features.dim_input());
23        let hidden = (0..features.layers())
24            .map(|_| ParamsBase::default(features.dim_hidden()))
25            .collect::<Vec<_>>();
26        let output = ParamsBase::default(features.dim_output());
27        Self::new(input, hidden, output)
28    }
29    /// create a new instance of the model;
30    /// all parameters are initialized to zero
31    pub fn ones(features: ModelFeatures) -> Self
32    where
33        A: Clone + One,
34        S: DataOwned,
35    {
36        let input = ParamsBase::ones(features.dim_input());
37        let hidden = (0..features.layers())
38            .map(|_| ParamsBase::ones(features.dim_hidden()))
39            .collect::<Vec<_>>();
40        let output = ParamsBase::ones(features.dim_output());
41        Self::new(input, hidden, output)
42    }
43    /// create a new instance of the model;
44    /// all parameters are initialized to zero
45    pub fn zeros(features: ModelFeatures) -> Self
46    where
47        A: Clone + Zero,
48        S: DataOwned,
49    {
50        let input = ParamsBase::zeros(features.dim_input());
51        let hidden = (0..features.layers())
52            .map(|_| ParamsBase::zeros(features.dim_hidden()))
53            .collect::<Vec<_>>();
54        let output = ParamsBase::zeros(features.dim_output());
55        Self::new(input, hidden, output)
56    }
57}
58
59impl<A, S, D> core::ops::Index<usize> for ModelParamsBase<S, D>
60where
61    A: Clone,
62    D: Dimension,
63    S: ndarray::Data<Elem = A>,
64{
65    type Output = ParamsBase<S, D>;
66
67    fn index(&self, index: usize) -> &Self::Output {
68        if index == 0 {
69            &self.input
70        } else if index == self.count_hidden() + 1 {
71            &self.output
72        } else {
73            &self.hidden[index - 1]
74        }
75    }
76}
77
78impl<A, S, D> core::ops::IndexMut<usize> for ModelParamsBase<S, D>
79where
80    A: Clone,
81    D: Dimension,
82    S: ndarray::Data<Elem = A>,
83{
84    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
85        if index == 0 {
86            &mut self.input
87        } else if index == self.count_hidden() + 1 {
88            &mut self.output
89        } else {
90            &mut self.hidden[index - 1]
91        }
92    }
93}