concision_neural/layout/
features.rs1use super::{ModelFormat, ModelLayout};
6
7pub fn _verify_input_and_hidden_shape<D>(input: D, hidden: D) -> bool
13where
14 D: ndarray::Dimension,
15{
16 let mut valid = true;
17 if input.ndim() != hidden.ndim() {
23 valid = false;
24 }
25 valid
26}
27
28#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
32#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
33pub struct ModelFeatures {
34 pub(crate) input: usize,
36 pub(crate) inner: ModelFormat,
38 pub(crate) output: usize,
40}
41
42impl ModelFeatures {
43 pub const fn deep(input: usize, hidden: usize, output: usize, layers: usize) -> Self {
46 Self {
47 input,
48 output,
49 inner: ModelFormat::deep(hidden, layers),
50 }
51 }
52 pub const fn shallow(input: usize, hidden: usize, output: usize) -> Self {
55 Self {
56 input,
57 output,
58 inner: ModelFormat::shallow(hidden),
59 }
60 }
61 pub const fn input(&self) -> usize {
63 self.input
64 }
65 pub const fn input_mut(&mut self) -> &mut usize {
67 &mut self.input
68 }
69 pub const fn inner(&self) -> ModelFormat {
71 self.inner
72 }
73 pub const fn inner_mut(&mut self) -> &mut ModelFormat {
75 &mut self.inner
76 }
77 pub const fn hidden(&self) -> usize {
79 self.inner().hidden()
80 }
81 pub const fn hidden_mut(&mut self) -> &mut usize {
83 self.inner_mut().hidden_mut()
84 }
85 pub const fn layers(&self) -> usize {
87 self.inner().layers()
88 }
89 pub const fn layers_mut(&mut self) -> &mut usize {
91 self.inner_mut().layers_mut()
92 }
93 pub const fn output(&self) -> usize {
95 self.output
96 }
97 pub const fn output_mut(&mut self) -> &mut usize {
99 &mut self.output
100 }
101 #[inline]
102 pub fn set_input(&mut self, input: usize) -> &mut Self {
104 self.input = input;
105 self
106 }
107 #[inline]
108 pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
110 self.inner_mut().set_hidden(hidden);
111 self
112 }
113 #[inline]
114 pub fn set_layers(&mut self, layers: usize) -> &mut Self {
116 self.inner_mut().set_layers(layers);
117 self
118 }
119 #[inline]
120 pub fn set_output(&mut self, output: usize) -> &mut Self {
122 self.output = output;
123 self
124 }
125 pub fn with_input(self, input: usize) -> Self {
127 Self { input, ..self }
128 }
129 pub fn with_hidden(self, hidden: usize) -> Self {
132 Self {
133 inner: self.inner.with_hidden(hidden),
134 ..self
135 }
136 }
137 pub fn with_layers(self, layers: usize) -> Self {
140 Self {
141 inner: self.inner.with_layers(layers),
142 ..self
143 }
144 }
145 pub fn with_output(self, output: usize) -> Self {
148 Self { output, ..self }
149 }
150 pub fn dim_input(&self) -> (usize, usize) {
152 (self.input(), self.hidden())
153 }
154 pub fn dim_hidden(&self) -> (usize, usize) {
156 (self.hidden(), self.hidden())
157 }
158 pub fn dim_output(&self) -> (usize, usize) {
160 (self.hidden(), self.output())
161 }
162 pub fn size(&self) -> usize {
164 self.size_input() + self.size_hidden() + self.size_output()
165 }
166 pub fn size_input(&self) -> usize {
168 self.input() * self.hidden()
169 }
170 pub fn size_hidden(&self) -> usize {
172 self.hidden() * self.hidden() * self.layers()
173 }
174 pub fn size_output(&self) -> usize {
176 self.hidden() * self.output()
177 }
178}
179
180impl ModelLayout for ModelFeatures {
181 fn input(&self) -> usize {
182 self.input()
183 }
184
185 fn input_mut(&mut self) -> &mut usize {
186 self.input_mut()
187 }
188
189 fn hidden(&self) -> usize {
190 self.hidden()
191 }
192
193 fn hidden_mut(&mut self) -> &mut usize {
194 self.hidden_mut()
195 }
196
197 fn layers(&self) -> usize {
198 self.layers()
199 }
200
201 fn layers_mut(&mut self) -> &mut usize {
202 self.layers_mut()
203 }
204
205 fn output(&self) -> usize {
206 self.output()
207 }
208
209 fn output_mut(&mut self) -> &mut usize {
210 self.output_mut()
211 }
212}
213
214impl Default for ModelFeatures {
215 fn default() -> Self {
216 Self {
217 input: 16,
218 inner: ModelFormat::Deep {
219 hidden: 16,
220 layers: 1,
221 },
222 output: 16,
223 }
224 }
225}
226impl core::fmt::Display for ModelFeatures {
227 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
228 write!(
229 f,
230 "{{ input: {i}, hidden: {h}, output: {o}, layers: {l} }}",
231 i = self.input(),
232 h = self.hidden(),
233 l = self.layers(),
234 o = self.output()
235 )
236 }
237}