// concision_neural/params/impls/impl_params_deep.rs

use crate::traits::DeepModelRepr;
use crate::ModelFeatures;
use crate::{DeepParamsBase, ModelParamsBase};
use cnc::params::ParamsBase;
use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
use num_traits::{One, Zero};
13impl<S, D, H, A> ModelParamsBase<S, D, H>
14where
15    D: Dimension,
16    S: RawData<Elem = A>,
17    H: DeepModelRepr<S, D>,
18{
19    pub const fn deep(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
21        Self {
22            input,
23            hidden,
24            output,
25        }
26    }
27}
28
29impl<A, S, D> DeepParamsBase<S, D>
30where
31    D: Dimension,
32    S: RawData<Elem = A>,
33{
34    #[inline]
36    pub fn size(&self) -> usize {
37        let mut size = self.input().count_weight();
38        for layer in self.hidden() {
39            size += layer.count_weight();
40        }
41        size + self.output().count_weight()
42    }
43
44    #[inline]
51    pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
52        if layer.dim() != self.dim_hidden() {
53            panic!(
54                "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
55                layer.dim(),
56                self.dim_hidden()
57            );
58        }
59        self.hidden_mut()[idx] = layer;
60        self
61    }
62    #[inline]
64    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
65        self.input().dim()
66    }
67    #[inline]
69    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
70        assert!(
72            self.hidden()
73                .iter()
74                .all(|p| p.dim() == self.hidden()[0].dim())
75        );
76        self.hidden()[0].dim()
79    }
80    #[inline]
82    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
83        self.output().dim()
84    }
85    #[inline]
87    pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
88    where
89        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
90    {
91        self.hidden().get(idx)
92    }
93    #[inline]
95    pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
96    where
97        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
98    {
99        self.hidden_mut().get_mut(idx)
100    }
101    #[inline]
104    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
105    where
106        A: Clone,
107        S: Data,
108        ParamsBase<S, D>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
109    {
110        let mut output = self.input().forward(input)?;
112        for layer in self.hidden() {
114            output = layer.forward(&output)?;
115        }
116        self.output().forward(&output)
118    }
119}
120
121impl<A, S> DeepParamsBase<S, Ix2>
122where
123    S: RawData<Elem = A>,
124{
125    pub fn default(features: ModelFeatures) -> Self
128    where
129        A: Clone + Default,
130        S: DataOwned,
131    {
132        let input = ParamsBase::default(features.dim_input());
133        let hidden = (0..features.layers())
134            .map(|_| ParamsBase::default(features.dim_hidden()))
135            .collect::<Vec<_>>();
136        let output = ParamsBase::default(features.dim_output());
137        Self::new(input, hidden, output)
138    }
139    pub fn ones(features: ModelFeatures) -> Self
142    where
143        A: Clone + One,
144        S: DataOwned,
145    {
146        let input = ParamsBase::ones(features.dim_input());
147        let hidden = (0..features.layers())
148            .map(|_| ParamsBase::ones(features.dim_hidden()))
149            .collect::<Vec<_>>();
150        let output = ParamsBase::ones(features.dim_output());
151        Self::new(input, hidden, output)
152    }
153    pub fn zeros(features: ModelFeatures) -> Self
156    where
157        A: Clone + Zero,
158        S: DataOwned,
159    {
160        let input = ParamsBase::zeros(features.dim_input());
161        let hidden = (0..features.layers())
162            .map(|_| ParamsBase::zeros(features.dim_hidden()))
163            .collect::<Vec<_>>();
164        let output = ParamsBase::zeros(features.dim_output());
165        Self::new(input, hidden, output)
166    }
167}