concision_core/models/impls/
impl_model_params.rs1use crate::models::ModelParamsBase;
6
7use crate::{DeepModelRepr, RawHidden};
8use concision_params::ParamsBase;
9use ndarray::{ArrayBase, Data, Dimension, RawData, RawDataClone};
10
11impl<S, D, H, A> ModelParamsBase<S, D, H, A>
16where
17 D: Dimension,
18 S: RawData<Elem = A>,
19 H: RawHidden<S, D>,
20{
21 pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
23 Self {
24 input,
25 hidden,
26 output,
27 }
28 }
29 pub const fn input(&self) -> &ParamsBase<S, D> {
31 &self.input
32 }
33 pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
35 &mut self.input
36 }
37 pub const fn hidden(&self) -> &H {
39 &self.hidden
40 }
41 pub const fn hidden_mut(&mut self) -> &mut H {
43 &mut self.hidden
44 }
45 pub const fn output(&self) -> &ParamsBase<S, D> {
47 &self.output
48 }
49 pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
51 &mut self.output
52 }
53 #[inline]
55 pub fn set_input(&mut self, input: ParamsBase<S, D>) {
56 *self.input_mut() = input
57 }
58 #[inline]
60 pub fn set_hidden(&mut self, hidden: H) {
61 *self.hidden_mut() = hidden
62 }
63 #[inline]
65 pub fn set_output(&mut self, output: ParamsBase<S, D>) {
66 *self.output_mut() = output
67 }
68 #[inline]
70 pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
71 Self { input, ..self }
72 }
73 #[inline]
76 pub fn with_hidden(self, hidden: H) -> Self {
77 Self { hidden, ..self }
78 }
79 #[inline]
81 pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
82 Self { output, ..self }
83 }
84 #[inline]
86 pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
87 where
88 H: DeepModelRepr<S, D>,
89 {
90 self.hidden().as_slice()
91 }
92 pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
94 self.input().bias()
95 }
96 pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
98 self.input_mut().bias_mut()
99 }
100 pub const fn input_weights(&self) -> &ArrayBase<S, D, A> {
102 self.input().weights()
103 }
104 pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
106 self.input_mut().weights_mut()
107 }
108 pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
110 self.output().bias()
111 }
112 pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
114 self.output_mut().bias_mut()
115 }
116 pub const fn output_weights(&self) -> &ArrayBase<S, D, A> {
118 self.output().weights()
119 }
120 pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
122 self.output_mut().weights_mut()
123 }
124 pub fn layers(&self) -> usize {
126 2 + self.count_hidden()
127 }
128 pub fn count_hidden(&self) -> usize {
130 self.hidden().count()
131 }
132 #[inline]
135 pub fn is_shallow(&self) -> bool {
136 self.count_hidden() <= 1
137 }
138 #[inline]
141 pub fn is_deep(&self) -> bool {
142 self.count_hidden() > 1
143 }
144}
145
146impl<A, S, D, H> Clone for ModelParamsBase<S, D, H, A>
147where
148 D: Dimension,
149 H: RawHidden<S, D> + Clone,
150 S: RawDataClone<Elem = A>,
151 A: Clone,
152{
153 fn clone(&self) -> Self {
154 Self {
155 input: self.input().clone(),
156 hidden: self.hidden().clone(),
157 output: self.output().clone(),
158 }
159 }
160}
161
162impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H, A>
163where
164 D: Dimension,
165 H: RawHidden<S, D> + core::fmt::Debug,
166 S: Data<Elem = A>,
167 A: core::fmt::Debug,
168{
169 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170 f.debug_struct("ModelParams")
171 .field("input", self.input())
172 .field("hidden", self.hidden())
173 .field("output", self.output())
174 .finish()
175 }
176}
177
178impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H, A>
179where
180 D: Dimension,
181 H: RawHidden<S, D> + core::fmt::Debug,
182 S: Data<Elem = A>,
183 A: core::fmt::Display,
184{
185 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
186 write!(
187 f,
188 "{{ input: {i}, hidden: {h:?}, output: {o} }}",
189 i = self.input(),
190 h = self.hidden(),
191 o = self.output()
192 )
193 }
194}
195
196impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H, A>
197where
198 D: Dimension,
199 S: Data<Elem = A>,
200 H: RawHidden<S, D> + core::ops::Index<usize, Output = ParamsBase<S, D>>,
201 A: Clone,
202{
203 type Output = ParamsBase<S, D>;
204
205 fn index(&self, index: usize) -> &Self::Output {
206 match index % self.layers() {
207 0 => self.input(),
208 i if i == self.count_hidden() + 1 => self.output(),
209 _ => &self.hidden()[index - 1],
210 }
211 }
212}
213
214impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H, A>
215where
216 D: Dimension,
217 S: Data<Elem = A>,
218 H: RawHidden<S, D> + core::ops::IndexMut<usize, Output = ParamsBase<S, D>>,
219 A: Clone,
220{
221 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
222 match index % self.layers() {
223 0 => self.input_mut(),
224 i if i == self.count_hidden() + 1 => self.output_mut(),
225 _ => &mut self.hidden_mut()[index - 1],
226 }
227 }
228}