// concision_neural/model/params/impl_params_deep.rs
use crate::model::{DeepParamsBase, ModelParamsBase};

use crate::model::ModelFeatures;
use crate::traits::DeepNeuralStore;
use cnc::params::ParamsBase;
use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
use num_traits::{One, Zero};

13impl<S, D, H, A> ModelParamsBase<S, D, H>
14where
15 D: Dimension,
16 S: RawData<Elem = A>,
17 H: DeepNeuralStore<S, D>,
18{
19 pub const fn deep(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
21 Self {
22 input,
23 hidden,
24 output,
25 }
26 }
27}
28
29impl<A, S, D> DeepParamsBase<S, D>
30where
31 D: Dimension,
32 S: RawData<Elem = A>,
33{
34 #[inline]
36 pub fn size(&self) -> usize {
37 let mut size = self.input().count_weight();
38 for layer in self.hidden() {
39 size += layer.count_weight();
40 }
41 size + self.output().count_weight()
42 }
43
44 #[inline]
51 pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
52 if layer.dim() != self.dim_hidden() {
53 panic!(
54 "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
55 layer.dim(),
56 self.dim_hidden()
57 );
58 }
59 self.hidden_mut()[idx] = layer;
60 self
61 }
62 #[inline]
64 pub fn dim_input(&self) -> <D as Dimension>::Pattern {
65 self.input().dim()
66 }
67 #[inline]
69 pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
70 assert!(
72 self.hidden()
73 .iter()
74 .all(|p| p.dim() == self.hidden()[0].dim())
75 );
76 self.hidden()[0].dim()
79 }
80 #[inline]
82 pub fn dim_output(&self) -> <D as Dimension>::Pattern {
83 self.output().dim()
84 }
85 #[inline]
87 pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
88 where
89 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
90 {
91 self.hidden().get(idx)
92 }
93 #[inline]
95 pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
96 where
97 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
98 {
99 self.hidden_mut().get_mut(idx)
100 }
101 #[inline]
104 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
105 where
106 A: Clone,
107 S: Data,
108 ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
109 {
110 let mut output = self.input().forward(input)?;
112 for layer in self.hidden() {
114 output = layer.forward(&output)?;
115 }
116 self.output().forward(&output)
118 }
119}
120
121impl<A, S> DeepParamsBase<S, Ix2>
122where
123 S: RawData<Elem = A>,
124{
125 pub fn default(features: ModelFeatures) -> Self
128 where
129 A: Clone + Default,
130 S: DataOwned,
131 {
132 let input = ParamsBase::default(features.dim_input());
133 let hidden = (0..features.layers())
134 .map(|_| ParamsBase::default(features.dim_hidden()))
135 .collect::<Vec<_>>();
136 let output = ParamsBase::default(features.dim_output());
137 Self::new(input, hidden, output)
138 }
139 pub fn ones(features: ModelFeatures) -> Self
142 where
143 A: Clone + One,
144 S: DataOwned,
145 {
146 let input = ParamsBase::ones(features.dim_input());
147 let hidden = (0..features.layers())
148 .map(|_| ParamsBase::ones(features.dim_hidden()))
149 .collect::<Vec<_>>();
150 let output = ParamsBase::ones(features.dim_output());
151 Self::new(input, hidden, output)
152 }
153 pub fn zeros(features: ModelFeatures) -> Self
156 where
157 A: Clone + Zero,
158 S: DataOwned,
159 {
160 let input = ParamsBase::zeros(features.dim_input());
161 let hidden = (0..features.layers())
162 .map(|_| ParamsBase::zeros(features.dim_hidden()))
163 .collect::<Vec<_>>();
164 let output = ParamsBase::zeros(features.dim_output());
165 Self::new(input, hidden, output)
166 }
167}