concision_neural/model/
model_params.rs

1/*
2    Appellation: store <module>
3    Contrib: @FL03
4*/
5use super::layout::ModelFeatures;
6use cnc::params::ParamsBase;
7use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
8use num_traits::{One, Zero};
9
10pub type ModelParams<A = f64, D = Ix2> = ModelParamsBase<ndarray::OwnedRepr<A>, D>;
11
12/// This object is an abstraction over the parameters of a deep neural network model. This is
13/// done to isolate the necessary parameters from the specific logic within a model allowing us
14/// to easily create additional stores for tracking velocities, gradients, and other metrics
15/// we may need.
16///
17/// Additionally, this provides us with a way to introduce common creation routines for
18/// initializing neural networks.
19pub struct ModelParamsBase<S, D = Ix2>
20where
21    D: Dimension,
22    S: RawData,
23{
24    /// the input layer of the model
25    pub input: ParamsBase<S, D>,
26    /// a sequential stack of params for the model's hidden layers
27    pub hidden: Vec<ParamsBase<S, D>>,
28    /// the output layer of the model
29    pub output: ParamsBase<S, D>,
30}
31
32impl<A, S, D> ModelParamsBase<S, D>
33where
34    D: Dimension,
35    S: RawData<Elem = A>,
36{
37    pub fn new(
38        input: ParamsBase<S, D>,
39        hidden: Vec<ParamsBase<S, D>>,
40        output: ParamsBase<S, D>,
41    ) -> Self {
42        Self {
43            input,
44            hidden,
45            output,
46        }
47    }
48    /// returns true if the stack is shallow
49    pub fn is_shallow(&self) -> bool {
50        self.hidden.is_empty() || self.hidden.len() == 1
51    }
52    /// returns an immutable reference to the input layer of the model
53    pub const fn input(&self) -> &ParamsBase<S, D> {
54        &self.input
55    }
56    /// returns a mutable reference to the input layer of the model
57    #[inline]
58    pub fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
59        &mut self.input
60    }
61    /// returns an immutable reference to the hidden layers of the model
62    pub const fn hidden(&self) -> &Vec<ParamsBase<S, D>> {
63        &self.hidden
64    }
65    /// returns an immutable reference to the hidden layers of the model as a slice
66    #[inline]
67    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>] {
68        self.hidden.as_slice()
69    }
70    /// returns a mutable reference to the hidden layers of the model
71    #[inline]
72    pub fn hidden_mut(&mut self) -> &mut Vec<ParamsBase<S, D>> {
73        &mut self.hidden
74    }
75    /// returns an immutable reference to the output layer of the model
76    pub const fn output(&self) -> &ParamsBase<S, D> {
77        &self.output
78    }
79    /// returns a mutable reference to the output layer of the model
80    #[inline]
81    pub fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
82        &mut self.output
83    }
84    /// set the input layer of the model
85    pub fn set_input(&mut self, input: ParamsBase<S, D>) {
86        *self.input_mut() = input;
87    }
88    /// set the hidden layers of the model
89    pub fn set_hidden<I>(&mut self, iter: I)
90    where
91        I: IntoIterator<Item = ParamsBase<S, D>>,
92    {
93        *self.hidden_mut() = Vec::from_iter(iter);
94    }
95    /// set the output layer of the model
96    pub fn set_output(&mut self, output: ParamsBase<S, D>) {
97        self.output = output;
98    }
99    /// consumes the current instance and returns another with the specified input layer
100    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
101        Self { input, ..self }
102    }
103    /// consumes the current instance and returns another with the specified hidden layers
104    pub fn with_hidden<I>(self, iter: I) -> Self
105    where
106        I: IntoIterator<Item = ParamsBase<S, D>>,
107    {
108        Self {
109            hidden: Vec::from_iter(iter),
110            ..self
111        }
112    }
113    /// consumes the current instance and returns another with the specified output layer
114    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
115        Self { output, ..self }
116    }
117    /// returns the dimension of the input layer
118    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
119        self.input().dim()
120    }
121    /// returns the dimension of the hidden layers
122    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
123        assert!(self.hidden.iter().all(|p| p.dim() == self.hidden[0].dim()));
124        self.hidden()[0].dim()
125    }
126    /// returns the dimension of the output layer
127    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
128        self.output.dim()
129    }
130    /// sequentially forwards the input through the model without any activations or other
131    /// complexities in-between. not overly usefuly, but it is here for completeness
132    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
133    where
134        A: Clone,
135        S: Data,
136        ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
137    {
138        let mut output = self.input().forward(input)?;
139        for layer in self.hidden() {
140            output = layer.forward(&output)?;
141        }
142        self.output().forward(&output)
143    }
144}
145
146impl<A, S> ModelParamsBase<S>
147where
148    S: RawData<Elem = A>,
149{
150    /// create a new instance of the model;
151    /// all parameters are initialized to their defaults (i.e., zero)
152    pub fn default(features: ModelFeatures) -> Self
153    where
154        A: Clone + Default,
155        S: DataOwned,
156    {
157        let input = ParamsBase::default(features.dim_input());
158        let hidden = (0..features.layers())
159            .map(|_| ParamsBase::default(features.dim_hidden()))
160            .collect::<Vec<_>>();
161        let output = ParamsBase::default(features.dim_output());
162        Self::new(input, hidden, output)
163    }
164    /// create a new instance of the model;
165    /// all parameters are initialized to zero
166    pub fn ones(features: ModelFeatures) -> Self
167    where
168        A: Clone + One,
169        S: DataOwned,
170    {
171        let input = ParamsBase::ones(features.dim_input());
172        let hidden = (0..features.layers())
173            .map(|_| ParamsBase::ones(features.dim_hidden()))
174            .collect::<Vec<_>>();
175        let output = ParamsBase::ones(features.dim_output());
176        Self::new(input, hidden, output)
177    }
178    /// create a new instance of the model;
179    /// all parameters are initialized to zero
180    pub fn zeros(features: ModelFeatures) -> Self
181    where
182        A: Clone + Zero,
183        S: DataOwned,
184    {
185        let input = ParamsBase::zeros(features.dim_input());
186        let hidden = (0..features.layers())
187            .map(|_| ParamsBase::zeros(features.dim_hidden()))
188            .collect::<Vec<_>>();
189        let output = ParamsBase::zeros(features.dim_output());
190        Self::new(input, hidden, output)
191    }
192
193    #[cfg(feature = "rand")]
194    pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
195    where
196        G: Fn((usize, usize)) -> Ds,
197        Ds: Clone + cnc::init::rand_distr::Distribution<A>,
198        S: DataOwned,
199    {
200        use cnc::init::Initialize;
201        let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
202        let hidden = (0..features.layers())
203            .map(|_| ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())))
204            .collect::<Vec<_>>();
205
206        let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
207
208        Self::new(input, hidden, output)
209    }
210    /// initialize the model parameters using a glorot normal distribution
211    #[cfg(feature = "rand")]
212    pub fn glorot_normal(features: ModelFeatures) -> Self
213    where
214        S: DataOwned,
215        A: num_traits::Float + num_traits::FromPrimitive,
216        cnc::init::rand_distr::StandardNormal: cnc::init::rand_distr::Distribution<A>,
217    {
218        Self::init_rand(features, |(rows, cols)| {
219            cnc::init::XavierNormal::new(rows, cols)
220        })
221    }
222    /// initialize the model parameters using a glorot uniform distribution
223    #[cfg(feature = "rand")]
224    pub fn glorot_uniform(features: ModelFeatures) -> Self
225    where
226        S: ndarray::DataOwned,
227        A: Clone
228            + num_traits::Float
229            + num_traits::FromPrimitive
230            + cnc::init::rand_distr::uniform::SampleUniform,
231        <S::Elem as cnc::init::rand_distr::uniform::SampleUniform>::Sampler: Clone,
232        cnc::init::rand_distr::Uniform<S::Elem>: cnc::init::rand_distr::Distribution<S::Elem>,
233    {
234        Self::init_rand(features, |(rows, cols)| {
235            cnc::init::XavierUniform::new(rows, cols).expect("failed to create distribution")
236        })
237    }
238}
239
240impl<A, S, D> Clone for ModelParamsBase<S, D>
241where
242    A: Clone,
243    D: Dimension,
244    S: ndarray::RawDataClone<Elem = A>,
245{
246    fn clone(&self) -> Self {
247        Self {
248            input: self.input.clone(),
249            hidden: self.hidden.to_vec(),
250            output: self.output.clone(),
251        }
252    }
253}
254
255impl<A, S, D> core::fmt::Debug for ModelParamsBase<S, D>
256where
257    A: core::fmt::Debug,
258    D: Dimension,
259    S: ndarray::Data<Elem = A>,
260{
261    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
262        f.debug_struct("ModelParams")
263            .field("input", &self.input)
264            .field("hidden", &self.hidden)
265            .field("output", &self.output)
266            .finish()
267    }
268}
269
270impl<A, S, D> core::fmt::Display for ModelParamsBase<S, D>
271where
272    A: core::fmt::Debug,
273    D: Dimension,
274    S: ndarray::Data<Elem = A>,
275{
276    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
277        write!(
278            f,
279            "{{ input: {:?}, hidden: {:?}, output: {:?} }}",
280            self.input, self.hidden, self.output
281        )
282    }
283}
284
285impl<A, S, D> core::ops::Index<usize> for ModelParamsBase<S, D>
286where
287    A: Clone,
288    D: Dimension,
289    S: ndarray::Data<Elem = A>,
290{
291    type Output = ParamsBase<S, D>;
292
293    fn index(&self, index: usize) -> &Self::Output {
294        if index == 0 {
295            &self.input
296        } else if index == self.hidden.len() + 1 {
297            &self.output
298        } else {
299            &self.hidden[index - 1]
300        }
301    }
302}
303
304impl<A, S, D> core::ops::IndexMut<usize> for ModelParamsBase<S, D>
305where
306    A: Clone,
307    D: Dimension,
308    S: ndarray::Data<Elem = A>,
309{
310    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
311        if index == 0 {
312            &mut self.input
313        } else if index == self.hidden.len() + 1 {
314            &mut self.output
315        } else {
316            &mut self.hidden[index - 1]
317        }
318    }
319}