concision_neural/model/layout/
features.rs1use super::ModelLayout;
6
7pub fn _verify_input_and_hidden_shape<D>(input: D, hidden: D) -> bool
13where
14 D: ndarray::Dimension,
15{
16 let mut valid = true;
17 if input.ndim() != hidden.ndim() {
23 valid = false;
24 }
25 valid
26}
27
28#[derive(
29 Clone,
30 Copy,
31 Debug,
32 Eq,
33 Hash,
34 Ord,
35 PartialEq,
36 PartialOrd,
37 scsys::VariantConstructors,
38 strum::EnumCount,
39 strum::EnumIs,
40)]
41#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
42pub enum ModelFormat {
43 Shallow { hidden: usize },
44 Deep { hidden: usize, layers: usize },
45}
46
47#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
51#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
52pub struct ModelFeatures {
53 pub(crate) input: usize,
55 pub(crate) inner: ModelFormat,
57 pub(crate) output: usize,
59}
60
61impl ModelFormat {
62 pub const fn hidden(&self) -> usize {
64 match self {
65 ModelFormat::Shallow { hidden } => *hidden,
66 ModelFormat::Deep { hidden, .. } => *hidden,
67 }
68 }
69 pub const fn hidden_mut(&mut self) -> &mut usize {
71 match self {
72 ModelFormat::Shallow { hidden } => hidden,
73 ModelFormat::Deep { hidden, .. } => hidden,
74 }
75 }
76 pub const fn layers(&self) -> usize {
80 match self {
81 ModelFormat::Shallow { .. } => 1,
82 ModelFormat::Deep { layers, .. } => *layers,
83 }
84 }
85 pub const fn layers_mut(&mut self) -> &mut usize {
88 match self {
89 ModelFormat::Shallow { .. } => panic!("Cannot mutate layers of a shallow model"),
90 ModelFormat::Deep { layers, .. } => layers,
91 }
92 }
93 pub fn set_hidden(&mut self, value: usize) -> &mut Self {
95 match self {
96 ModelFormat::Shallow { hidden } => {
97 *hidden = value;
98 }
99 ModelFormat::Deep { hidden, .. } => {
100 *hidden = value;
101 }
102 }
103 self
104 }
105 pub fn set_layers(&mut self, value: usize) -> &mut Self {
111 match self {
112 ModelFormat::Shallow { hidden } => {
113 if value > 1 {
114 *self = ModelFormat::Deep {
115 hidden: *hidden,
116 layers: value,
117 };
118 }
119 }
121 ModelFormat::Deep { layers, .. } => {
122 *layers = value;
123 }
124 }
125 self
126 }
127 pub fn with_hidden(self, hidden: usize) -> Self {
130 match self {
131 ModelFormat::Shallow { .. } => ModelFormat::Shallow { hidden },
132 ModelFormat::Deep { layers, .. } => ModelFormat::Deep { hidden, layers },
133 }
134 }
135 pub fn with_layers(self, layers: usize) -> Self {
142 match self {
143 ModelFormat::Shallow { hidden } => {
144 if layers > 1 {
145 ModelFormat::Deep { hidden, layers }
146 } else {
147 ModelFormat::Shallow { hidden }
148 }
149 }
150 ModelFormat::Deep { hidden, .. } => ModelFormat::Deep { hidden, layers },
151 }
152 }
153}
154
155impl ModelFeatures {
156 pub const fn deep(input: usize, hidden: usize, layers: usize, output: usize) -> Self {
157 Self {
158 input,
159 output,
160 inner: ModelFormat::Deep { hidden, layers },
161 }
162 }
163 pub const fn input(&self) -> usize {
165 self.input
166 }
167 pub const fn input_mut(&mut self) -> &mut usize {
169 &mut self.input
170 }
171 pub const fn inner(&self) -> ModelFormat {
173 self.inner
174 }
175 pub const fn inner_mut(&mut self) -> &mut ModelFormat {
177 &mut self.inner
178 }
179 pub const fn hidden(&self) -> usize {
181 self.inner().hidden()
182 }
183 pub const fn hidden_mut(&mut self) -> &mut usize {
185 self.inner_mut().hidden_mut()
186 }
187 pub const fn layers(&self) -> usize {
189 self.inner().layers()
190 }
191 pub const fn layers_mut(&mut self) -> &mut usize {
193 self.inner_mut().layers_mut()
194 }
195 pub const fn output(&self) -> usize {
197 self.output
198 }
199 pub const fn output_mut(&mut self) -> &mut usize {
201 &mut self.output
202 }
203 #[inline]
204 pub fn set_input(&mut self, input: usize) -> &mut Self {
206 self.input = input;
207 self
208 }
209 #[inline]
210 pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
212 self.inner_mut().set_hidden(hidden);
213 self
214 }
215 #[inline]
216 pub fn set_layers(&mut self, layers: usize) -> &mut Self {
218 self.inner_mut().set_layers(layers);
219 self
220 }
221 #[inline]
222 pub fn set_output(&mut self, output: usize) -> &mut Self {
224 self.output = output;
225 self
226 }
227 pub fn with_input(self, input: usize) -> Self {
229 Self { input, ..self }
230 }
231 pub fn with_hidden(self, hidden: usize) -> Self {
234 Self {
235 inner: self.inner.with_hidden(hidden),
236 ..self
237 }
238 }
239 pub fn with_layers(self, layers: usize) -> Self {
242 Self {
243 inner: self.inner.with_layers(layers),
244 ..self
245 }
246 }
247 pub fn with_output(self, output: usize) -> Self {
250 Self { output, ..self }
251 }
252 pub fn dim_input(&self) -> (usize, usize) {
254 (self.input(), self.hidden())
255 }
256 pub fn dim_hidden(&self) -> (usize, usize) {
258 (self.hidden(), self.hidden())
259 }
260 pub fn dim_output(&self) -> (usize, usize) {
262 (self.hidden(), self.output())
263 }
264 pub fn size(&self) -> usize {
266 self.size_input() + self.size_hidden() + self.size_output()
267 }
268 pub fn size_input(&self) -> usize {
270 self.input() * self.hidden()
271 }
272 pub fn size_hidden(&self) -> usize {
274 self.hidden() * self.hidden() * self.layers()
275 }
276 pub fn size_output(&self) -> usize {
278 self.hidden() * self.output()
279 }
280}
281
282impl ModelLayout for ModelFeatures {
283 fn input(&self) -> usize {
284 self.input()
285 }
286 fn input_mut(&mut self) -> &mut usize {
287 self.input_mut()
288 }
289 fn hidden(&self) -> usize {
290 self.hidden()
291 }
292 fn hidden_mut(&mut self) -> &mut usize {
293 self.hidden_mut()
294 }
295 fn layers(&self) -> usize {
296 self.layers()
297 }
298 fn layers_mut(&mut self) -> &mut usize {
299 self.layers_mut()
300 }
301 fn output(&self) -> usize {
302 self.output()
303 }
304 fn output_mut(&mut self) -> &mut usize {
305 self.output_mut()
306 }
307}
308
309impl Default for ModelFormat {
310 fn default() -> Self {
311 Self::Deep {
312 hidden: 16,
313 layers: 1,
314 }
315 }
316}
317
318impl Default for ModelFeatures {
319 fn default() -> Self {
320 Self {
321 input: 16,
322 inner: ModelFormat::Deep {
323 hidden: 16,
324 layers: 1,
325 },
326 output: 16,
327 }
328 }
329}
330
331impl core::fmt::Display for ModelFormat {
332 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
333 write!(
334 f,
335 "{{ hidden: {}, layers: {} }}",
336 self.hidden(),
337 self.layers()
338 )
339 }
340}
341
342impl core::fmt::Display for ModelFeatures {
343 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
344 write!(
345 f,
346 "{{ input: {i}, hidden: {h}, output: {o}, layers: {l} }}",
347 i = self.input(),
348 h = self.hidden(),
349 l = self.layers(),
350 o = self.output()
351 )
352 }
353}