concision_neural/model/
model_params.rs

1/*
2    Appellation: store <module>
3    Contrib: @FL03
4*/
5use crate::ModelFeatures;
6use cnc::params::ParamsBase;
7use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
8use num_traits::{Float, FromPrimitive, One, Zero};
9
10pub type ModelParams<A = f64, D = Ix2> = ModelParamsBase<ndarray::OwnedRepr<A>, D>;
11
12/// This object is an abstraction over the parameters of a deep neural network model. This is
13/// done to isolate the necessary parameters from the specific logic within a model allowing us
14/// to easily create additional stores for tracking velocities, gradients, and other metrics
15/// we may need.
16///
17/// Additionally, this provides us with a way to introduce common creation routines for
18/// initializing neural networks.
19pub struct ModelParamsBase<S, D = Ix2>
20where
21    D: Dimension,
22    S: RawData,
23{
24    pub(crate) input: ParamsBase<S, D>,
25    pub(crate) hidden: Vec<ParamsBase<S, D>>,
26    pub(crate) output: ParamsBase<S, D>,
27}
28
29impl<A, S, D> ModelParamsBase<S, D>
30where
31    D: Dimension,
32    S: RawData<Elem = A>,
33{
34    pub fn new(
35        input: ParamsBase<S, D>,
36        hidden: Vec<ParamsBase<S, D>>,
37        output: ParamsBase<S, D>,
38    ) -> Self {
39        Self {
40            input,
41            hidden,
42            output,
43        }
44    }
45    /// returns true if the stack is shallow
46    pub fn is_shallow(&self) -> bool {
47        self.hidden.is_empty() || self.hidden.len() == 1
48    }
49    /// returns an immutable reference to the input layer of the model
50    pub const fn input(&self) -> &ParamsBase<S, D> {
51        &self.input
52    }
53    /// returns a mutable reference to the input layer of the model
54    #[inline]
55    pub fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
56        &mut self.input
57    }
58    /// returns an immutable reference to the hidden layers of the model
59    pub const fn hidden(&self) -> &Vec<ParamsBase<S, D>> {
60        &self.hidden
61    }
62    /// returns an immutable reference to the hidden layers of the model as a slice
63    #[inline]
64    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>] {
65        self.hidden.as_slice()
66    }
67    /// returns a mutable reference to the hidden layers of the model
68    #[inline]
69    pub fn hidden_mut(&mut self) -> &mut Vec<ParamsBase<S, D>> {
70        &mut self.hidden
71    }
72    /// returns an immutable reference to the output layer of the model
73    pub const fn output(&self) -> &ParamsBase<S, D> {
74        &self.output
75    }
76    /// returns a mutable reference to the output layer of the model
77    #[inline]
78    pub fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
79        &mut self.output
80    }
81    /// set the input layer of the model
82    pub fn set_input(&mut self, input: ParamsBase<S, D>) {
83        *self.input_mut() = input;
84    }
85    /// set the hidden layers of the model
86    pub fn set_hidden<I>(&mut self, iter: I)
87    where
88        I: IntoIterator<Item = ParamsBase<S, D>>,
89    {
90        *self.hidden_mut() = Vec::from_iter(iter);
91    }
92    /// set the output layer of the model
93    pub fn set_output(&mut self, output: ParamsBase<S, D>) {
94        self.output = output;
95    }
96    /// consumes the current instance and returns another with the specified input layer
97    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
98        Self { input, ..self }
99    }
100    /// consumes the current instance and returns another with the specified hidden layers
101    pub fn with_hidden<I>(self, iter: I) -> Self
102    where
103        I: IntoIterator<Item = ParamsBase<S, D>>,
104    {
105        Self {
106            hidden: Vec::from_iter(iter),
107            ..self
108        }
109    }
110    /// consumes the current instance and returns another with the specified output layer
111    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
112        Self { output, ..self }
113    }
114    /// returns the dimension of the input layer
115    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
116        self.input().dim()
117    }
118    /// returns the dimension of the hidden layers
119    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
120        assert!(self.hidden.iter().all(|p| p.dim() == self.hidden[0].dim()));
121        self.hidden()[0].dim()
122    }
123    /// returns the dimension of the output layer
124    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
125        self.output.dim()
126    }
127}
128
129impl<A, S> ModelParamsBase<S>
130where
131    S: RawData<Elem = A>,
132{
133    /// create a new instance of the model;
134    /// all parameters are initialized to their defaults (i.e., zero)
135    pub fn default(features: ModelFeatures) -> Self
136    where
137        A: Clone + Default,
138        S: DataOwned,
139    {
140        let input = ParamsBase::default(features.d_input());
141        let hidden = (0..features.layers())
142            .map(|_| ParamsBase::default(features.d_hidden()))
143            .collect::<Vec<_>>();
144        let output = ParamsBase::default(features.d_output());
145        Self::new(input, hidden, output)
146    }
147    /// create a new instance of the model;
148    /// all parameters are initialized to zero
149    pub fn ones(features: ModelFeatures) -> Self
150    where
151        A: Clone + One,
152        S: DataOwned,
153    {
154        let input = ParamsBase::ones(features.d_input());
155        let hidden = (0..features.layers())
156            .map(|_| ParamsBase::ones(features.d_hidden()))
157            .collect::<Vec<_>>();
158        let output = ParamsBase::ones(features.d_output());
159        Self::new(input, hidden, output)
160    }
161    /// create a new instance of the model;
162    /// all parameters are initialized to zero
163    pub fn zeros(features: ModelFeatures) -> Self
164    where
165        A: Clone + Zero,
166        S: DataOwned,
167    {
168        let input = ParamsBase::zeros(features.d_input());
169        let hidden = (0..features.layers())
170            .map(|_| ParamsBase::zeros(features.d_hidden()))
171            .collect::<Vec<_>>();
172        let output = ParamsBase::zeros(features.d_output());
173        Self::new(input, hidden, output)
174    }
175
176    #[cfg(feature = "rand")]
177    pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
178    where
179        G: Fn((usize, usize)) -> Ds,
180        Ds: Clone + cnc::init::rand_distr::Distribution<A>,
181        S: DataOwned,
182    {
183        use cnc::init::Initialize;
184        let input = ParamsBase::rand(features.d_input(), distr(features.d_input()));
185        let hidden = (0..features.layers())
186            .map(|_| ParamsBase::rand(features.d_hidden(), distr(features.d_hidden())))
187            .collect::<Vec<_>>();
188
189        let output = ParamsBase::rand(features.d_output(), distr(features.d_output()));
190
191        Self::new(input, hidden, output)
192    }
193
194    #[cfg(feature = "rand")]
195    pub fn glorot_normal(features: ModelFeatures) -> Self
196    where
197        cnc::init::rand_distr::StandardNormal: cnc::init::rand_distr::Distribution<A>,
198        S: DataOwned,
199        S::Elem: Float + FromPrimitive,
200    {
201        Self::init_rand(features, |(rows, cols)| {
202            cnc::init::XavierNormal::new(rows, cols)
203        })
204    }
205
206    #[cfg(feature = "rand")]
207    pub fn glorot_uniform(features: ModelFeatures) -> Self
208    where
209        S: ndarray::DataOwned,
210        A: Clone
211            + num_traits::Float
212            + num_traits::FromPrimitive
213            + cnc::init::rand_distr::uniform::SampleUniform,
214        <S::Elem as cnc::init::rand_distr::uniform::SampleUniform>::Sampler: Clone,
215        cnc::init::rand_distr::Uniform<S::Elem>: cnc::init::rand_distr::Distribution<S::Elem>,
216    {
217        Self::init_rand(features, |(rows, cols)| {
218            cnc::init::XavierUniform::new(rows, cols).expect("failed to create distribution")
219        })
220    }
221
222    
223
224    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
225    where
226        A: Clone,
227        S: Data,
228        ParamsBase<S, Ix2>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
229    {
230        let mut output = self.input.forward(input)?;
231        for layer in &self.hidden {
232            output = layer.forward(&output)?;
233        }
234        self.output.forward(&output)
235    }
236}
237
238impl<A, S, D> Clone for ModelParamsBase<S, D>
239where
240    A: Clone,
241    D: Dimension,
242    S: ndarray::RawDataClone<Elem = A>,
243{
244    fn clone(&self) -> Self {
245        Self {
246            input: self.input.clone(),
247            hidden: self.hidden.iter().map(|p| p.clone()).collect(),
248            output: self.output.clone(),
249        }
250    }
251}
252
253impl<A, S, D> core::fmt::Debug for ModelParamsBase<S, D>
254where
255    A: core::fmt::Debug,
256    D: Dimension,
257    S: ndarray::Data<Elem = A>,
258{
259    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
260        f.debug_struct("ModelParams")
261            .field("input", &self.input)
262            .field("hidden", &self.hidden)
263            .field("output", &self.output)
264            .finish()
265    }
266}
267
268impl<A, S, D> core::fmt::Display for ModelParamsBase<S, D>
269where
270    A: core::fmt::Debug,
271    D: Dimension,
272    S: ndarray::Data<Elem = A>,
273{
274    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
275        write!(
276            f,
277            "{{ input: {:?}, hidden: {:?}, output: {:?} }}",
278            self.input, self.hidden, self.output
279        )
280    }
281}
282
283impl<A, S, D> core::ops::Index<usize> for ModelParamsBase<S, D>
284where
285    A: Clone,
286    D: Dimension,
287    S: ndarray::Data<Elem = A>,
288{
289    type Output = ParamsBase<S, D>;
290
291    fn index(&self, index: usize) -> &Self::Output {
292        if index == 0 {
293            &self.input
294        } else if index == self.hidden.len() + 1 {
295            &self.output
296        } else {
297            &self.hidden[index - 1]
298        }
299    }
300}
301
302impl<A, S, D> core::ops::IndexMut<usize> for ModelParamsBase<S, D>
303where
304    A: Clone,
305    D: Dimension,
306    S: ndarray::Data<Elem = A>,
307{
308    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
309        if index == 0 {
310            &mut self.input
311        } else if index == self.hidden.len() + 1 {
312            &mut self.output
313        } else {
314            &mut self.hidden[index - 1]
315        }
316    }
317}