concision_neural/model/layout/
features.rs

1/*
2    Appellation: layout <module>
3    Contrib: @FL03
4*/
5use super::ModelLayout;
6
/// A compact description of a model's shape.
///
/// [`ModelFeatures`] records how many features enter the model, how wide each hidden
/// layer is, how many hidden layers there are, and how many features the model emits.
/// It offers one shared way of expressing a model's layout.
// NOTE: keep the field order stable — the derived `PartialOrd`/`Ord` compare the fields
// lexicographically in declaration order, and the derived `Hash` feeds them in that order.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ModelFeatures {
    /// how many features the model accepts as input
    pub(crate) input: usize,
    /// the width (number of features) of each hidden layer
    pub(crate) hidden: usize,
    /// how many hidden layers the model contains
    pub(crate) layers: usize,
    /// how many features the model produces as output
    pub(crate) output: usize,
}
22
23impl ModelFeatures {
24    pub fn new(input: usize, hidden: usize, layers: usize, output: usize) -> Self {
25        Self {
26            input,
27            hidden,
28            layers,
29            output,
30        }
31    }
32    /// returns a copy of the input features for the model
33    pub const fn input(&self) -> usize {
34        self.input
35    }
36    /// returns a mutable reference to the input features for the model
37    #[inline]
38    pub const fn input_mut(&mut self) -> &mut usize {
39        &mut self.input
40    }
41    /// returns a copy of the hidden features for the model
42    pub const fn hidden(&self) -> usize {
43        self.hidden
44    }
45    /// returns a mutable reference to the hidden features for the model
46    #[inline]
47    pub const fn hidden_mut(&mut self) -> &mut usize {
48        &mut self.hidden
49    }
50    /// returns a copy of the number of hidden layers for the model
51    pub const fn layers(&self) -> usize {
52        self.layers
53    }
54    /// returns a mutable reference to the number of hidden layers for the model
55    #[inline]
56    pub const fn layers_mut(&mut self) -> &mut usize {
57        &mut self.layers
58    }
59    /// returns a copy of the output features for the model
60    pub const fn output(&self) -> usize {
61        self.output
62    }
63    /// returns a mutable reference to the output features for the model
64    #[inline]
65    pub const fn output_mut(&mut self) -> &mut usize {
66        &mut self.output
67    }
68    #[inline]
69    /// sets the input features for the model
70    pub fn set_input(&mut self, input: usize) -> &mut Self {
71        self.input = input;
72        self
73    }
74    #[inline]
75    /// sets the hidden features for the model
76    pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
77        self.hidden = hidden;
78        self
79    }
80    #[inline]
81    /// sets the number of hidden layers for the model
82    pub fn set_layers(&mut self, layers: usize) -> &mut Self {
83        self.layers = layers;
84        self
85    }
86    #[inline]
87    /// sets the output features for the model
88    pub fn set_output(&mut self, output: usize) -> &mut Self {
89        self.output = output;
90        self
91    }
92    /// consumes the current instance and returns a new instance with the given input
93    pub fn with_input(self, input: usize) -> Self {
94        Self { input, ..self }
95    }
96
97    /// consumes the current instance and returns a new instance with the given hidden
98    /// features
99    pub fn with_hidden(self, hidden: usize) -> Self {
100        Self { hidden, ..self }
101    }
102    /// consumes the current instance and returns a new instance with the given number of
103    /// hidden layers
104    pub fn with_layers(self, layers: usize) -> Self {
105        Self { layers, ..self }
106    }
107    /// consumes the current instance and returns a new instance with the given output
108    /// features
109    pub fn with_output(self, output: usize) -> Self {
110        Self { output, ..self }
111    }
112    /// the dimension of the input layer; (input, hidden)
113    pub fn dim_input(&self) -> (usize, usize) {
114        (self.input(), self.hidden())
115    }
116    /// the dimension of the hidden layers; (hidden, hidden)
117    pub fn dim_hidden(&self) -> (usize, usize) {
118        (self.hidden(), self.hidden())
119    }
120    /// the dimension of the output layer; (hidden, output)
121    pub fn dim_output(&self) -> (usize, usize) {
122        (self.hidden(), self.output())
123    }
124    /// the total number of parameters in the model
125    pub fn size(&self) -> usize {
126        self.size_input() + self.size_hidden() + self.size_output()
127    }
128    /// the total number of input parameters in the model
129    pub fn size_input(&self) -> usize {
130        self.input() * self.hidden()
131    }
132    /// the total number of hidden parameters in the model
133    pub fn size_hidden(&self) -> usize {
134        self.hidden() * self.hidden() * self.layers()
135    }
136    /// the total number of output parameters in the model
137    pub fn size_output(&self) -> usize {
138        self.hidden() * self.output()
139    }
140}
141
142impl ModelLayout for ModelFeatures {
143    fn input(&self) -> usize {
144        self.input()
145    }
146    fn input_mut(&mut self) -> &mut usize {
147        self.input_mut()
148    }
149    fn hidden(&self) -> usize {
150        self.hidden()
151    }
152    fn hidden_mut(&mut self) -> &mut usize {
153        self.hidden_mut()
154    }
155    fn layers(&self) -> usize {
156        self.layers()
157    }
158    fn layers_mut(&mut self) -> &mut usize {
159        self.layers_mut()
160    }
161    fn output(&self) -> usize {
162        self.output()
163    }
164    fn output_mut(&mut self) -> &mut usize {
165        self.output_mut()
166    }
167}
168
169impl Default for ModelFeatures {
170    fn default() -> Self {
171        Self {
172            input: 16,
173            hidden: 64,
174            layers: 3,
175            output: 16,
176        }
177    }
178}
179
180impl core::fmt::Display for ModelFeatures {
181    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
182        write!(
183            f,
184            "{{ input: {}, hidden: {}, layers: {}, output: {} }}",
185            self.input, self.hidden, self.layers, self.output
186        )
187    }
188}