concision_neural/model/params/
impl_model_params.rs1use crate::model::params::ModelParamsBase;
6
7use crate::{DeepNeuralStore, RawHidden};
8use cnc::params::ParamsBase;
9use ndarray::{ArrayBase, Data, Dimension, RawData, RawDataClone};
10
11impl<S, D, H, A> ModelParamsBase<S, D, H>
12where
13 D: Dimension,
14 S: RawData<Elem = A>,
15 H: RawHidden<S, D>,
16{
17 pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
19 Self {
20 input,
21 hidden,
22 output,
23 }
24 }
25 pub const fn input(&self) -> &ParamsBase<S, D> {
27 &self.input
28 }
29 pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
31 &mut self.input
32 }
33 pub const fn hidden(&self) -> &H {
35 &self.hidden
36 }
37 pub const fn hidden_mut(&mut self) -> &mut H {
39 &mut self.hidden
40 }
41 pub const fn output(&self) -> &ParamsBase<S, D> {
43 &self.output
44 }
45 pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
47 &mut self.output
48 }
49 #[inline]
51 pub fn set_input(&mut self, input: ParamsBase<S, D>) -> &mut Self {
52 *self.input_mut() = input;
53 self
54 }
55 #[inline]
57 pub fn set_hidden(&mut self, hidden: H) -> &mut Self {
58 *self.hidden_mut() = hidden;
59 self
60 }
61 #[inline]
63 pub fn set_output(&mut self, output: ParamsBase<S, D>) -> &mut Self {
64 *self.output_mut() = output;
65 self
66 }
67 #[inline]
69 pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
70 Self { input, ..self }
71 }
72 #[inline]
75 pub fn with_hidden(self, hidden: H) -> Self {
76 Self { hidden, ..self }
77 }
78 #[inline]
80 pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
81 Self { output, ..self }
82 }
83 #[inline]
85 pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
86 where
87 H: DeepNeuralStore<S, D>,
88 {
89 self.hidden().as_slice()
90 }
91 pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller> {
93 self.input().bias()
94 }
95 pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
97 self.input_mut().bias_mut()
98 }
99 pub const fn input_weights(&self) -> &ArrayBase<S, D> {
101 self.input().weights()
102 }
103 pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
105 self.input_mut().weights_mut()
106 }
107 pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller> {
109 self.output().bias()
110 }
111 pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
113 self.output_mut().bias_mut()
114 }
115 pub const fn output_weights(&self) -> &ArrayBase<S, D> {
117 self.output().weights()
118 }
119 pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
121 self.output_mut().weights_mut()
122 }
123 pub fn count_hidden(&self) -> usize {
125 self.hidden().count()
126 }
127 #[inline]
130 pub fn is_shallow(&self) -> bool {
131 self.count_hidden() <= 1
132 }
133 #[inline]
136 pub fn is_deep(&self) -> bool {
137 self.count_hidden() > 1
138 }
139 #[inline]
141 pub fn len(&self) -> usize {
142 self.count_hidden() + 2 }
144}
145
146impl<A, S, D, H> Clone for ModelParamsBase<S, D, H>
147where
148 D: Dimension,
149 H: RawHidden<S, D> + Clone,
150 S: RawDataClone<Elem = A>,
151 A: Clone,
152{
153 fn clone(&self) -> Self {
154 Self {
155 input: self.input().clone(),
156 hidden: self.hidden().clone(),
157 output: self.output().clone(),
158 }
159 }
160}
161
162impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H>
163where
164 D: Dimension,
165 H: RawHidden<S, D> + core::fmt::Debug,
166 S: Data<Elem = A>,
167 A: core::fmt::Debug,
168{
169 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170 f.debug_struct("ModelParams")
171 .field("input", self.input())
172 .field("hidden", self.hidden())
173 .field("output", self.output())
174 .finish()
175 }
176}
177
178impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H>
179where
180 D: Dimension,
181 H: RawHidden<S, D> + core::fmt::Debug,
182 S: Data<Elem = A>,
183 A: core::fmt::Display,
184{
185 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
186 write!(
187 f,
188 "{{ input: {i}, hidden: {h:?}, output: {o} }}",
189 i = self.input(),
190 h = self.hidden(),
191 o = self.output()
192 )
193 }
194}
195
196impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H>
197where
198 D: Dimension,
199 S: Data<Elem = A>,
200 H: DeepNeuralStore<S, D>,
201 A: Clone,
202{
203 type Output = ParamsBase<S, D>;
204
205 fn index(&self, index: usize) -> &Self::Output {
206 if index == 0 {
207 self.input()
208 } else if index == self.count_hidden() + 1 {
209 self.output()
210 } else {
211 &self.hidden().as_slice()[index - 1]
212 }
213 }
214}
215
216impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H>
217where
218 D: Dimension,
219 S: Data<Elem = A>,
220 H: DeepNeuralStore<S, D>,
221 A: Clone,
222{
223 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
224 if index == 0 {
225 self.input_mut()
226 } else if index == self.count_hidden() + 1 {
227 self.output_mut()
228 } else {
229 &mut self.hidden_mut().as_mut_slice()[index - 1]
230 }
231 }
232}