concision_neural/model/layout/
features.rs

1/*
2    Appellation: layout <module>
3    Contrib: @FL03
4*/
5use super::ModelLayout;
6
7/// verify if the input and hidden dimensions are compatible by checking:
8///
9/// 1. they have the same dimensionality
10/// 2. if the the number of dimensions is greater than one, the hidden layer should be square
11/// 3. the finaly dimension of the input is equal to one hidden dimension
12pub fn _verify_input_and_hidden_shape<D>(input: D, hidden: D) -> bool
13where
14    D: ndarray::Dimension,
15{
16    let mut valid = true;
17    // // check that the hidden dimension is square
18    // if hidden.ndim() > 1 && hidden.shape().iter().any(|&d| d != hidden.shape()[0]) {
19    //     valid = false;
20    // }
21    // check that the input and hidden dimensions are compatible
22    if input.ndim() != hidden.ndim() {
23        valid = false;
24    }
25    valid
26}
27
/// The [`ModelFormat`] describes the "inner" (hidden) portion of a model layout.
///
/// Both variants record the number of hidden features; only the
/// [`Deep`](ModelFormat::Deep) variant additionally records a number of hidden layers.
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    Hash,
    Ord,
    PartialEq,
    PartialOrd,
    scsys::VariantConstructors,
    strum::EnumCount,
    strum::EnumIs,
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ModelFormat {
    /// a format that only tracks the number of `hidden` features
    Shallow { hidden: usize },
    /// a format that tracks the number of `hidden` features and the number of `layers`
    Deep { hidden: usize, layers: usize },
}
46
/// The [`ModelFeatures`] provides a common way of defining the layout of a model. It stores
/// the number of input features, the format of the inner layers (a [`ModelFormat`], which
/// carries the number of hidden features and, for deep formats, the number of hidden layers),
/// and the number of output features.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ModelFeatures {
    /// the number of input features
    pub(crate) input: usize,
    /// the format of the "inner" layers; see [`ModelFormat`]
    pub(crate) inner: ModelFormat,
    /// the number of output features
    pub(crate) output: usize,
}
60
61impl ModelFormat {
62    /// returns a copy of the number of hidden features
63    pub const fn hidden(&self) -> usize {
64        match self {
65            ModelFormat::Shallow { hidden } => *hidden,
66            ModelFormat::Deep { hidden, .. } => *hidden,
67        }
68    }
69    /// returns a mutable reference to the hidden features for the model
70    pub const fn hidden_mut(&mut self) -> &mut usize {
71        match self {
72            ModelFormat::Shallow { hidden } => hidden,
73            ModelFormat::Deep { hidden, .. } => hidden,
74        }
75    }
76    /// returns a copy of the number of layers for the model; if the variant is
77    /// [`Shallow`](ModelFormat::Shallow), it returns 1
78    /// returns `n` if the variant is [`Deep`](ModelFormat::Deep)
79    pub const fn layers(&self) -> usize {
80        match self {
81            ModelFormat::Shallow { .. } => 1,
82            ModelFormat::Deep { layers, .. } => *layers,
83        }
84    }
85    /// returns a mutable reference to the number of layers for the model; this will panic on
86    /// [`Shallow`](ModelFormat::Shallow) variants
87    pub const fn layers_mut(&mut self) -> &mut usize {
88        match self {
89            ModelFormat::Shallow { .. } => panic!("Cannot mutate layers of a shallow model"),
90            ModelFormat::Deep { layers, .. } => layers,
91        }
92    }
93    /// update the number of hidden features for the model
94    pub fn set_hidden(&mut self, value: usize) -> &mut Self {
95        match self {
96            ModelFormat::Shallow { hidden } => {
97                *hidden = value;
98            }
99            ModelFormat::Deep { hidden, .. } => {
100                *hidden = value;
101            }
102        }
103        self
104    }
105    /// update the number of layers for the model;
106    ///
107    /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
108    /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
109    /// of layers becomes greater than 1
110    pub fn set_layers(&mut self, value: usize) -> &mut Self {
111        match self {
112            ModelFormat::Shallow { hidden } => {
113                if value > 1 {
114                    *self = ModelFormat::Deep {
115                        hidden: *hidden,
116                        layers: value,
117                    };
118                }
119                // if the value is 1, we do not change the model format
120            }
121            ModelFormat::Deep { layers, .. } => {
122                *layers = value;
123            }
124        }
125        self
126    }
127    /// consumes the current instance and returns a new instance with the given hidden
128    /// features
129    pub fn with_hidden(self, hidden: usize) -> Self {
130        match self {
131            ModelFormat::Shallow { .. } => ModelFormat::Shallow { hidden },
132            ModelFormat::Deep { layers, .. } => ModelFormat::Deep { hidden, layers },
133        }
134    }
135    /// consumes the current instance and returns a new instance with the given number of
136    /// hidden layers
137    ///
138    /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
139    /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
140    /// of layers becomes greater than 1
141    pub fn with_layers(self, layers: usize) -> Self {
142        match self {
143            ModelFormat::Shallow { hidden } => {
144                if layers > 1 {
145                    ModelFormat::Deep { hidden, layers }
146                } else {
147                    ModelFormat::Shallow { hidden }
148                }
149            }
150            ModelFormat::Deep { hidden, .. } => ModelFormat::Deep { hidden, layers },
151        }
152    }
153}
154
155impl ModelFeatures {
156    pub const fn deep(input: usize, hidden: usize, layers: usize, output: usize) -> Self {
157        Self {
158            input,
159            output,
160            inner: ModelFormat::Deep { hidden, layers },
161        }
162    }
163    /// returns a copy of the input features for the model
164    pub const fn input(&self) -> usize {
165        self.input
166    }
167    /// returns a mutable reference to the input features for the model
168    pub const fn input_mut(&mut self) -> &mut usize {
169        &mut self.input
170    }
171    /// returns a copy of the inner format for the model
172    pub const fn inner(&self) -> ModelFormat {
173        self.inner
174    }
175    /// returns a mutable reference to the inner format for the model
176    pub const fn inner_mut(&mut self) -> &mut ModelFormat {
177        &mut self.inner
178    }
179    /// returns a copy of the hidden features for the model
180    pub const fn hidden(&self) -> usize {
181        self.inner().hidden()
182    }
183    /// returns a mutable reference to the hidden features for the model
184    pub const fn hidden_mut(&mut self) -> &mut usize {
185        self.inner_mut().hidden_mut()
186    }
187    /// returns a copy of the number of hidden layers for the model
188    pub const fn layers(&self) -> usize {
189        self.inner().layers()
190    }
191    /// returns a mutable reference to the number of hidden layers for the model
192    pub const fn layers_mut(&mut self) -> &mut usize {
193        self.inner_mut().layers_mut()
194    }
195    /// returns a copy of the output features for the model
196    pub const fn output(&self) -> usize {
197        self.output
198    }
199    /// returns a mutable reference to the output features for the model
200    pub const fn output_mut(&mut self) -> &mut usize {
201        &mut self.output
202    }
203    #[inline]
204    /// sets the input features for the model
205    pub fn set_input(&mut self, input: usize) -> &mut Self {
206        self.input = input;
207        self
208    }
209    #[inline]
210    /// sets the hidden features for the model
211    pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
212        self.inner_mut().set_hidden(hidden);
213        self
214    }
215    #[inline]
216    /// sets the number of hidden layers for the model
217    pub fn set_layers(&mut self, layers: usize) -> &mut Self {
218        self.inner_mut().set_layers(layers);
219        self
220    }
221    #[inline]
222    /// sets the output features for the model
223    pub fn set_output(&mut self, output: usize) -> &mut Self {
224        self.output = output;
225        self
226    }
227    /// consumes the current instance and returns a new instance with the given input
228    pub fn with_input(self, input: usize) -> Self {
229        Self { input, ..self }
230    }
231    /// consumes the current instance and returns a new instance with the given hidden
232    /// features
233    pub fn with_hidden(self, hidden: usize) -> Self {
234        Self {
235            inner: self.inner.with_hidden(hidden),
236            ..self
237        }
238    }
239    /// consumes the current instance and returns a new instance with the given number of
240    /// hidden layers
241    pub fn with_layers(self, layers: usize) -> Self {
242        Self {
243            inner: self.inner.with_layers(layers),
244            ..self
245        }
246    }
247    /// consumes the current instance and returns a new instance with the given output
248    /// features
249    pub fn with_output(self, output: usize) -> Self {
250        Self { output, ..self }
251    }
252    /// the dimension of the input layer; (input, hidden)
253    pub fn dim_input(&self) -> (usize, usize) {
254        (self.input(), self.hidden())
255    }
256    /// the dimension of the hidden layers; (hidden, hidden)
257    pub fn dim_hidden(&self) -> (usize, usize) {
258        (self.hidden(), self.hidden())
259    }
260    /// the dimension of the output layer; (hidden, output)
261    pub fn dim_output(&self) -> (usize, usize) {
262        (self.hidden(), self.output())
263    }
264    /// the total number of parameters in the model
265    pub fn size(&self) -> usize {
266        self.size_input() + self.size_hidden() + self.size_output()
267    }
268    /// the total number of input parameters in the model
269    pub fn size_input(&self) -> usize {
270        self.input() * self.hidden()
271    }
272    /// the total number of hidden parameters in the model
273    pub fn size_hidden(&self) -> usize {
274        self.hidden() * self.hidden() * self.layers()
275    }
276    /// the total number of output parameters in the model
277    pub fn size_output(&self) -> usize {
278        self.hidden() * self.output()
279    }
280}
281
282impl ModelLayout for ModelFeatures {
283    fn input(&self) -> usize {
284        self.input()
285    }
286    fn input_mut(&mut self) -> &mut usize {
287        self.input_mut()
288    }
289    fn hidden(&self) -> usize {
290        self.hidden()
291    }
292    fn hidden_mut(&mut self) -> &mut usize {
293        self.hidden_mut()
294    }
295    fn layers(&self) -> usize {
296        self.layers()
297    }
298    fn layers_mut(&mut self) -> &mut usize {
299        self.layers_mut()
300    }
301    fn output(&self) -> usize {
302        self.output()
303    }
304    fn output_mut(&mut self) -> &mut usize {
305        self.output_mut()
306    }
307}
308
309impl Default for ModelFormat {
310    fn default() -> Self {
311        Self::Deep {
312            hidden: 16,
313            layers: 1,
314        }
315    }
316}
317
318impl Default for ModelFeatures {
319    fn default() -> Self {
320        Self {
321            input: 16,
322            inner: ModelFormat::Deep {
323                hidden: 16,
324                layers: 1,
325            },
326            output: 16,
327        }
328    }
329}
330
331impl core::fmt::Display for ModelFormat {
332    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
333        write!(
334            f,
335            "{{ hidden: {}, layers: {} }}",
336            self.hidden(),
337            self.layers()
338        )
339    }
340}
341
342impl core::fmt::Display for ModelFeatures {
343    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
344        write!(
345            f,
346            "{{ input: {i}, hidden: {h}, output: {o}, layers: {l} }}",
347            i = self.input(),
348            h = self.hidden(),
349            l = self.layers(),
350            o = self.output()
351        )
352    }
353}