1use cnc::params::ParamsBase;
6use ndarray::{Data, Dimension, Ix2, RawData};
7
/// An owned [`ModelParamsBase`], i.e. one whose layers store their data in
/// [`ndarray::OwnedRepr<A>`]. Defaults to `f64` elements and [`Ix2`] dimensionality.
pub type ModelParams<A = f64, D = Ix2> = ModelParamsBase<ndarray::OwnedRepr<A>, D>;
9
/// The parameters of a layered model, made up of exactly one input layer, zero or more
/// hidden layers, and exactly one output layer. Each layer is a [`ParamsBase`] that shares
/// the storage type `S` and the dimensionality `D`.
///
/// The `where` bounds mirror those required by [`ParamsBase`]; they cannot be moved to the
/// `impl` blocks without also changing that type.
pub struct ModelParamsBase<S, D = Ix2>
where
    D: Dimension,
    S: RawData,
{
    /// Parameters of the input layer (the first layer applied by `forward`).
    pub(crate) input: ParamsBase<S, D>,
    /// Parameters of the hidden layers, in the order `forward` applies them; may be empty.
    pub(crate) hidden: Vec<ParamsBase<S, D>>,
    /// Parameters of the output layer (the last layer applied by `forward`).
    pub(crate) output: ParamsBase<S, D>,
}
29
30impl<A, S, D> ModelParamsBase<S, D>
31where
32 D: Dimension,
33 S: RawData<Elem = A>,
34{
35 pub const fn new(
38 input: ParamsBase<S, D>,
39 hidden: Vec<ParamsBase<S, D>>,
40 output: ParamsBase<S, D>,
41 ) -> Self {
42 Self {
43 input,
44 hidden,
45 output,
46 }
47 }
48 pub const fn input(&self) -> &ParamsBase<S, D> {
50 &self.input
51 }
52 pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
54 &mut self.input
55 }
56 pub const fn hidden(&self) -> &Vec<ParamsBase<S, D>> {
58 &self.hidden
59 }
60 #[inline]
62 pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>] {
63 self.hidden.as_slice()
64 }
65 pub const fn hidden_mut(&mut self) -> &mut Vec<ParamsBase<S, D>> {
67 &mut self.hidden
68 }
69 pub const fn output(&self) -> &ParamsBase<S, D> {
71 &self.output
72 }
73 pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
75 &mut self.output
76 }
77 #[inline]
79 pub fn set_input(&mut self, input: ParamsBase<S, D>) -> &mut Self {
80 *self.input_mut() = input;
81 self
82 }
83 #[inline]
85 pub fn set_hidden(&mut self, hidden: Vec<ParamsBase<S, D>>) -> &mut Self {
86 *self.hidden_mut() = hidden;
87 self
88 }
89 #[inline]
96 pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
97 if layer.dim() != self.dim_hidden() {
98 panic!(
99 "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
100 layer.dim(),
101 self.dim_hidden()
102 );
103 }
104 self.hidden_mut()[idx] = layer;
105 self
106 }
107 #[inline]
109 pub fn set_output(&mut self, output: ParamsBase<S, D>) -> &mut Self {
110 *self.output_mut() = output;
111 self
112 }
113 #[inline]
115 pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
116 Self { input, ..self }
117 }
118 #[inline]
120 pub fn with_hidden<I>(self, iter: I) -> Self
121 where
122 I: IntoIterator<Item = ParamsBase<S, D>>,
123 {
124 Self {
125 hidden: Vec::from_iter(iter),
126 ..self
127 }
128 }
129 #[inline]
131 pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
132 Self { output, ..self }
133 }
134 #[inline]
136 pub fn dim_input(&self) -> <D as Dimension>::Pattern {
137 self.input().dim()
138 }
139 #[inline]
141 pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
142 assert!(
144 self.hidden()
145 .iter()
146 .all(|p| p.dim() == self.hidden()[0].dim())
147 );
148 self.hidden()[0].dim()
151 }
152 #[inline]
154 pub fn dim_output(&self) -> <D as Dimension>::Pattern {
155 self.output().dim()
156 }
157 #[inline]
159 pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
160 where
161 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
162 {
163 self.hidden().get(idx)
164 }
165 #[inline]
167 pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
168 where
169 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
170 {
171 self.hidden_mut().get_mut(idx)
172 }
173 #[inline]
176 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
177 where
178 A: Clone,
179 S: Data,
180 ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
181 {
182 let mut output = self.input().forward(input)?;
184 for layer in self.hidden() {
186 output = layer.forward(&output)?;
187 }
188 self.output().forward(&output)
190 }
191 #[inline]
194 pub fn is_shallow(&self) -> bool {
195 self.count_hidden() <= 1 || self.hidden().is_empty()
196 }
197 #[inline]
200 pub fn is_deep(&self) -> bool {
201 self.count_hidden() > 1
202 }
203 #[inline]
205 pub fn count_hidden(&self) -> usize {
206 self.hidden().len()
207 }
208 #[inline]
210 pub fn len(&self) -> usize {
211 self.count_hidden() + 2 }
213 #[inline]
215 pub fn size(&self) -> usize {
216 let mut size = self.input().count_weight();
217 for layer in self.hidden() {
218 size += layer.count_weight();
219 }
220 size + self.output().count_weight()
221 }
222}
223
224impl<A, S, D> Clone for ModelParamsBase<S, D>
225where
226 A: Clone,
227 D: Dimension,
228 S: ndarray::RawDataClone<Elem = A>,
229{
230 fn clone(&self) -> Self {
231 Self {
232 input: self.input().clone(),
233 hidden: self.hidden().to_vec(),
234 output: self.output().clone(),
235 }
236 }
237}
238
239impl<A, S, D> core::fmt::Debug for ModelParamsBase<S, D>
240where
241 A: core::fmt::Debug,
242 D: Dimension,
243 S: ndarray::Data<Elem = A>,
244{
245 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246 f.debug_struct("ModelParams")
247 .field("input", &self.input)
248 .field("hidden", &self.hidden)
249 .field("output", &self.output)
250 .finish()
251 }
252}
253
254impl<A, S, D> core::fmt::Display for ModelParamsBase<S, D>
255where
256 A: core::fmt::Debug,
257 D: Dimension,
258 S: ndarray::Data<Elem = A>,
259{
260 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
261 write!(
262 f,
263 "{{ input: {:?}, hidden: {:?}, output: {:?} }}",
264 self.input, self.hidden, self.output
265 )
266 }
267}