concision_core/models/impls/
impl_params_deep.rs1use crate::{DeepParamsBase, ModelParamsBase};
6
7use crate::ModelFeatures;
8use crate::models::traits::DeepModelRepr;
9use concision_params::ParamsBase;
10use concision_traits::Forward;
11use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
12use num_traits::{One, Zero};
13
14impl<S, D, H, A> ModelParamsBase<S, D, H, A>
15where
16 D: Dimension,
17 S: RawData<Elem = A>,
18 H: DeepModelRepr<S, D>,
19{
20 pub const fn deep(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
22 Self {
23 input,
24 hidden,
25 output,
26 }
27 }
28}
29
30impl<A, S, D> DeepParamsBase<S, D, A>
31where
32 D: Dimension,
33 S: RawData<Elem = A>,
34{
35 #[inline]
37 pub fn size(&self) -> usize {
38 let mut size = self.input().count_weights();
39 for layer in self.hidden() {
40 size += layer.count_weights();
41 }
42 size + self.output().count_weights()
43 }
44
45 #[inline]
52 pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
53 if layer.dim() != self.dim_hidden() {
54 panic!(
55 "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
56 layer.dim(),
57 self.dim_hidden()
58 );
59 }
60 self.hidden_mut()[idx] = layer;
61 self
62 }
63 #[inline]
65 pub fn dim_input(&self) -> <D as Dimension>::Pattern {
66 self.input().dim()
67 }
68 #[inline]
70 pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
71 assert!(
73 self.hidden()
74 .iter()
75 .all(|p| p.dim() == self.hidden()[0].dim())
76 );
77 self.hidden()[0].dim()
80 }
81 #[inline]
83 pub fn dim_output(&self) -> <D as Dimension>::Pattern {
84 self.output().dim()
85 }
86 #[inline]
88 pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
89 where
90 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
91 {
92 self.hidden().get(idx)
93 }
94 #[inline]
96 pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
97 where
98 I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
99 {
100 self.hidden_mut().get_mut(idx)
101 }
102 #[inline]
105 pub fn forward<X, Y>(&self, input: &X) -> Y
106 where
107 A: Clone,
108 S: Data,
109 ParamsBase<S, D>: Forward<X, Output = Y> + Forward<Y, Output = Y>,
110 {
111 let mut output = self.input().forward(input);
112 self.hidden().iter().for_each(|layer| {
113 output = layer.forward(&output);
114 });
115 self.output().forward(&output)
116 }
117}
118
119impl<A, S> DeepParamsBase<S, Ix2, A>
120where
121 S: RawData<Elem = A>,
122{
123 #[allow(clippy::should_implement_trait)]
124 pub fn default(features: ModelFeatures) -> Self
127 where
128 A: Clone + Default,
129 S: DataOwned,
130 {
131 let input = ParamsBase::default(features.dim_input());
132 let hidden = (0..features.layers())
133 .map(|_| ParamsBase::default(features.dim_hidden()))
134 .collect::<Vec<_>>();
135 let output = ParamsBase::default(features.dim_output());
136 Self::new(input, hidden, output)
137 }
138 pub fn ones(features: ModelFeatures) -> Self
141 where
142 A: Clone + One,
143 S: DataOwned,
144 {
145 let input = ParamsBase::ones(features.dim_input());
146 let hidden = (0..features.layers())
147 .map(|_| ParamsBase::ones(features.dim_hidden()))
148 .collect::<Vec<_>>();
149 let output = ParamsBase::ones(features.dim_output());
150 Self::new(input, hidden, output)
151 }
152 pub fn zeros(features: ModelFeatures) -> Self
155 where
156 A: Clone + Zero,
157 S: DataOwned,
158 {
159 let input = ParamsBase::zeros(features.dim_input());
160 let hidden = (0..features.layers())
161 .map(|_| ParamsBase::zeros(features.dim_hidden()))
162 .collect::<Vec<_>>();
163 let output = ParamsBase::zeros(features.dim_output());
164 Self::new(input, hidden, output)
165 }
166}