// concision_neural/model/model_params.rs

1/*
2    Appellation: store <module>
3    Contrib: @FL03
4*/
5use cnc::params::ParamsBase;
6use ndarray::{Data, Dimension, Ix2, RawData};
7
/// A type alias for a [`ModelParamsBase`] backed by owned `ndarray` storage
/// ([`ndarray::OwnedRepr`]), defaulting to `f64` elements and two-dimensional layers.
pub type ModelParams<A = f64, D = Ix2> = ModelParamsBase<ndarray::OwnedRepr<A>, D>;
9
/// This object is an abstraction over the parameters of a deep neural network model. This is
/// done to isolate the necessary parameters from the specific logic within a model, allowing us
/// to easily create additional stores for tracking velocities, gradients, and other metrics
/// we may need.
///
/// Additionally, this provides us with a way to introduce common creation routines for
/// initializing neural networks.
///
/// The store is generic over the data representation `S` (any [`RawData`] implementor) and the
/// dimensionality `D` of each layer's parameters (defaulting to [`Ix2`], i.e. matrices).
pub struct ModelParamsBase<S, D = Ix2>
where
    D: Dimension,
    S: RawData,
{
    /// the input layer of the model
    pub(crate) input: ParamsBase<S, D>,
    /// a sequential stack of params for the model's hidden layers
    pub(crate) hidden: Vec<ParamsBase<S, D>>,
    /// the output layer of the model
    pub(crate) output: ParamsBase<S, D>,
}
29
30impl<A, S, D> ModelParamsBase<S, D>
31where
32    D: Dimension,
33    S: RawData<Elem = A>,
34{
35    /// returns a new instance of the [`ModelParamsBase`] with the specified input, hidden, and
36    /// output layers.
37    pub const fn new(
38        input: ParamsBase<S, D>,
39        hidden: Vec<ParamsBase<S, D>>,
40        output: ParamsBase<S, D>,
41    ) -> Self {
42        Self {
43            input,
44            hidden,
45            output,
46        }
47    }
48    /// returns an immutable reference to the input layer of the model
49    pub const fn input(&self) -> &ParamsBase<S, D> {
50        &self.input
51    }
52    /// returns a mutable reference to the input layer of the model
53    pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
54        &mut self.input
55    }
56    /// returns an immutable reference to the hidden layers of the model
57    pub const fn hidden(&self) -> &Vec<ParamsBase<S, D>> {
58        &self.hidden
59    }
60    /// returns an immutable reference to the hidden layers of the model as a slice
61    #[inline]
62    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>] {
63        self.hidden.as_slice()
64    }
65    /// returns a mutable reference to the hidden layers of the model
66    pub const fn hidden_mut(&mut self) -> &mut Vec<ParamsBase<S, D>> {
67        &mut self.hidden
68    }
69    /// returns an immutable reference to the output layer of the model
70    pub const fn output(&self) -> &ParamsBase<S, D> {
71        &self.output
72    }
73    /// returns a mutable reference to the output layer of the model
74    pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
75        &mut self.output
76    }
77    /// set the input layer of the model
78    #[inline]
79    pub fn set_input(&mut self, input: ParamsBase<S, D>) -> &mut Self {
80        *self.input_mut() = input;
81        self
82    }
83    /// set the hidden layers of the model
84    #[inline]
85    pub fn set_hidden(&mut self, hidden: Vec<ParamsBase<S, D>>) -> &mut Self {
86        *self.hidden_mut() = hidden;
87        self
88    }
89    /// set the layer at the specified index in the hidden layers of the model
90    ///
91    /// ## Panics
92    ///
93    /// Panics if the index is out of bounds or if the dimension of the provided layer is
94    /// inconsistent with the others in the stack.
95    #[inline]
96    pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
97        if layer.dim() != self.dim_hidden() {
98            panic!(
99                "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
100                layer.dim(),
101                self.dim_hidden()
102            );
103        }
104        self.hidden_mut()[idx] = layer;
105        self
106    }
107    /// set the output layer of the model
108    #[inline]
109    pub fn set_output(&mut self, output: ParamsBase<S, D>) -> &mut Self {
110        *self.output_mut() = output;
111        self
112    }
113    /// consumes the current instance and returns another with the specified input layer
114    #[inline]
115    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
116        Self { input, ..self }
117    }
118    /// consumes the current instance and returns another with the specified hidden layers
119    #[inline]
120    pub fn with_hidden<I>(self, iter: I) -> Self
121    where
122        I: IntoIterator<Item = ParamsBase<S, D>>,
123    {
124        Self {
125            hidden: Vec::from_iter(iter),
126            ..self
127        }
128    }
129    /// consumes the current instance and returns another with the specified output layer
130    #[inline]
131    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
132        Self { output, ..self }
133    }
134    /// returns the dimension of the input layer
135    #[inline]
136    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
137        self.input().dim()
138    }
139    /// returns the dimension of the hidden layers
140    #[inline]
141    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
142        // verify that all hidden layers have the same dimension
143        assert!(
144            self.hidden()
145                .iter()
146                .all(|p| p.dim() == self.hidden()[0].dim())
147        );
148        // use the first hidden layer's dimension as the representative
149        // dimension for all hidden layers
150        self.hidden()[0].dim()
151    }
152    /// returns the dimension of the output layer
153    #[inline]
154    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
155        self.output().dim()
156    }
157    /// returns the hidden layer associated with the given index
158    #[inline]
159    pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
160    where
161        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
162    {
163        self.hidden().get(idx)
164    }
165    /// returns a mutable reference to the hidden layer associated with the given index
166    #[inline]
167    pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
168    where
169        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
170    {
171        self.hidden_mut().get_mut(idx)
172    }
173    /// sequentially forwards the input through the model without any activations or other
174    /// complexities in-between. not overly usefuly, but it is here for completeness
175    #[inline]
176    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
177    where
178        A: Clone,
179        S: Data,
180        ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
181    {
182        // forward the input through the input layer
183        let mut output = self.input().forward(input)?;
184        // forward the input through each of the hidden layers
185        for layer in self.hidden() {
186            output = layer.forward(&output)?;
187        }
188        // finally, forward the output through the output layer
189        self.output().forward(&output)
190    }
191    /// returns true if the stack is shallow; a neural network is considered to be _shallow_ if
192    /// it has at most one hidden layer (`n <= 1`).
193    #[inline]
194    pub fn is_shallow(&self) -> bool {
195        self.count_hidden() <= 1 || self.hidden().is_empty()
196    }
197    /// returns true if the model stack of parameters is considered to be _deep_, meaning that
198    /// there the number of hidden layers is greater than one.
199    #[inline]
200    pub fn is_deep(&self) -> bool {
201        self.count_hidden() > 1
202    }
203    /// returns the total number of hidden layers within the model
204    #[inline]
205    pub fn count_hidden(&self) -> usize {
206        self.hidden().len()
207    }
208    /// returns the total number of layers within the model, including the input and output layers
209    #[inline]
210    pub fn len(&self) -> usize {
211        self.count_hidden() + 2 // +2 for input and output layers
212    }
213    /// returns the total number parameters within the model, including the input and output layers
214    #[inline]
215    pub fn size(&self) -> usize {
216        let mut size = self.input().count_weight();
217        for layer in self.hidden() {
218            size += layer.count_weight();
219        }
220        size + self.output().count_weight()
221    }
222}
223
224impl<A, S, D> Clone for ModelParamsBase<S, D>
225where
226    A: Clone,
227    D: Dimension,
228    S: ndarray::RawDataClone<Elem = A>,
229{
230    fn clone(&self) -> Self {
231        Self {
232            input: self.input().clone(),
233            hidden: self.hidden().to_vec(),
234            output: self.output().clone(),
235        }
236    }
237}
238
239impl<A, S, D> core::fmt::Debug for ModelParamsBase<S, D>
240where
241    A: core::fmt::Debug,
242    D: Dimension,
243    S: ndarray::Data<Elem = A>,
244{
245    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246        f.debug_struct("ModelParams")
247            .field("input", &self.input)
248            .field("hidden", &self.hidden)
249            .field("output", &self.output)
250            .finish()
251    }
252}
253
254impl<A, S, D> core::fmt::Display for ModelParamsBase<S, D>
255where
256    A: core::fmt::Debug,
257    D: Dimension,
258    S: ndarray::Data<Elem = A>,
259{
260    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
261        write!(
262            f,
263            "{{ input: {:?}, hidden: {:?}, output: {:?} }}",
264            self.input, self.hidden, self.output
265        )
266    }
267}