concision_core/models/layout/
impl_model_features.rs1use super::ModelFeatures;
6use crate::models::{ModelFormat, RawModelLayout, RawModelLayoutMut};
7
8fn _verify_input_and_hidden_shape<D>(input: D, hidden: D) -> bool
14where
15 D: ndarray::Dimension,
16{
17 let lhs = input.as_array_view();
18 let rhs = hidden.as_array_view();
19 input.ndim() != hidden.ndim()
21 && lhs[input.ndim() - 1] != rhs[0]
22 && rhs.iter().all(|&v| v == rhs[0])
23}
24
25impl ModelFeatures {
26 pub fn from_shape_and_size(shape: &[usize], size: usize) -> Self {
27 let input = shape[0];
28 let output = *shape.last().unwrap();
29 let hidden = if shape.len() > 2 {
30 shape[1]
31 } else {
32 (size - (input * output)) / (input + output)
33 };
34 let layers = if shape.len() > 2 { shape.len() - 2 } else { 1 };
35 Self::new(input, hidden, output, layers)
36 }
37 pub const fn new(input: usize, hidden: usize, output: usize, layers: usize) -> Self {
41 let inner = ModelFormat::new(hidden, layers);
42 Self {
43 input,
44 output,
45 inner,
46 }
47 }
48 pub const fn deep(input: usize, hidden: usize, output: usize, layers: usize) -> Self {
51 Self {
52 input,
53 output,
54 inner: ModelFormat::Deep { hidden, layers },
55 }
56 }
57 pub const fn shallow(input: usize, hidden: usize, output: usize) -> Self {
60 Self {
61 input,
62 output,
63 inner: ModelFormat::Shallow { hidden },
64 }
65 }
66 pub fn from_layout<L>(layout: L) -> Self
67 where
68 L: RawModelLayout,
69 {
70 Self {
71 input: layout.input(),
72 inner: ModelFormat::new(layout.hidden(), layout.depth()),
73 output: layout.output(),
74 }
75 }
76 pub const fn input(&self) -> usize {
78 self.input
79 }
80 pub const fn input_mut(&mut self) -> &mut usize {
82 &mut self.input
83 }
84 pub const fn inner(&self) -> ModelFormat {
86 self.inner
87 }
88 pub const fn inner_mut(&mut self) -> &mut ModelFormat {
90 &mut self.inner
91 }
92 pub const fn hidden(&self) -> usize {
94 self.inner().hidden()
95 }
96 pub const fn hidden_mut(&mut self) -> &mut usize {
98 self.inner_mut().hidden_mut()
99 }
100 pub const fn layers(&self) -> usize {
102 self.inner().layers()
103 }
104 pub const fn layers_mut(&mut self) -> &mut usize {
106 self.inner_mut().layers_mut()
107 }
108 pub const fn output(&self) -> usize {
110 self.output
111 }
112 pub const fn output_mut(&mut self) -> &mut usize {
114 &mut self.output
115 }
116 #[inline]
117 pub fn set_input(&mut self, input: usize) -> &mut Self {
119 self.input = input;
120 self
121 }
122 #[inline]
123 pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
125 self.inner_mut().set_hidden(hidden);
126 self
127 }
128 #[inline]
129 pub fn set_layers(&mut self, layers: usize) -> &mut Self {
131 self.inner_mut().set_layers(layers);
132 self
133 }
134 #[inline]
135 pub fn set_output(&mut self, output: usize) -> &mut Self {
137 self.output = output;
138 self
139 }
140 pub fn with_input(self, input: usize) -> Self {
142 Self { input, ..self }
143 }
144 pub fn with_hidden(self, hidden: usize) -> Self {
147 Self {
148 inner: self.inner.with_hidden(hidden),
149 ..self
150 }
151 }
152 pub fn with_layers(self, layers: usize) -> Self {
155 Self {
156 inner: self.inner.with_layers(layers),
157 ..self
158 }
159 }
160 pub fn with_output(self, output: usize) -> Self {
163 Self { output, ..self }
164 }
165 pub fn dim_input(&self) -> (usize, usize) {
167 (self.input(), self.hidden())
168 }
169 pub fn dim_hidden(&self) -> (usize, usize) {
171 (self.hidden(), self.hidden())
172 }
173 pub fn dim_output(&self) -> (usize, usize) {
175 (self.hidden(), self.output())
176 }
177 pub fn size(&self) -> usize {
179 self.size_input() + self.size_hidden() + self.size_output()
180 }
181 pub fn size_input(&self) -> usize {
183 self.input() * self.hidden()
184 }
185 pub fn size_hidden(&self) -> usize {
187 self.hidden() * self.hidden() * self.layers()
188 }
189 pub fn size_output(&self) -> usize {
191 self.hidden() * self.output()
192 }
193}
194
195impl RawModelLayout for ModelFeatures {
196 fn input(&self) -> usize {
197 self.input()
198 }
199
200 fn hidden(&self) -> usize {
201 self.hidden()
202 }
203
204 fn depth(&self) -> usize {
205 self.layers()
206 }
207
208 fn output(&self) -> usize {
209 self.output()
210 }
211}
212
213impl RawModelLayoutMut for ModelFeatures {
214 fn input_mut(&mut self) -> &mut usize {
215 self.input_mut()
216 }
217
218 fn hidden_mut(&mut self) -> &mut usize {
219 self.hidden_mut()
220 }
221
222 fn layers_mut(&mut self) -> &mut usize {
223 self.layers_mut()
224 }
225
226 fn output_mut(&mut self) -> &mut usize {
227 self.output_mut()
228 }
229}
230
231impl Default for ModelFeatures {
232 fn default() -> Self {
233 Self {
234 input: 16,
235 inner: ModelFormat::Deep {
236 hidden: 16,
237 layers: 1,
238 },
239 output: 16,
240 }
241 }
242}
243
244impl core::fmt::Display for ModelFeatures {
245 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246 f.write_str(&format!(
247 "{{ input: {}, hidden: {}, layers: {}, output: {} }}",
248 self.input(),
249 self.hidden(),
250 self.layers(),
251 self.output()
252 ))
253 }
254}