concision_core/models/impls/
impl_params_deep.rs

/*
    appellation: impl_model_params <module>
    authors: @FL03
*/
use crate::{DeepParamsBase, ModelParamsBase};

use crate::ModelFeatures;
use crate::models::traits::DeepModelRepr;
use concision_params::ParamsBase;
use concision_traits::Forward;
use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
use num_traits::{One, Zero};

impl<S, D, H, A> ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
    H: DeepModelRepr<S, D>,
{
    /// create a new [`ModelParamsBase`] instance from the given input, hidden, and output layers
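    ///
    /// ## Example
    ///
    /// A rough sketch (marked `ignore`; the `OwnedRepr<f64>` storage, the `(rows, cols)`
    /// shapes, and the use of a `Vec` for the hidden stack are illustrative assumptions,
    /// not requirements of this constructor):
    ///
    /// ```ignore
    /// use ndarray::{Ix2, OwnedRepr};
    ///
    /// // one input layer, two hidden layers, one output layer
    /// let input = ParamsBase::<OwnedRepr<f64>, Ix2>::zeros((4, 8));
    /// let hidden = vec![ParamsBase::zeros((8, 8)), ParamsBase::zeros((8, 8))];
    /// let output = ParamsBase::zeros((8, 2));
    /// let params = ModelParamsBase::deep(input, hidden, output);
    /// ```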
    pub const fn deep(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
        Self {
            input,
            hidden,
            output,
        }
    }
}

impl<A, S, D> DeepParamsBase<S, D, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// returns the total number of parameters within the model, including the input and output layers
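    ///
    /// For instance, if the input layer holds 32 weights, each of two hidden layers holds 64,
    /// and the output layer holds 16, then `size` reports `32 + 64 + 64 + 16 = 176`
    /// (where the per-layer count is whatever `count_weights` returns for a single layer).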
    #[inline]
    pub fn size(&self) -> usize {
        let mut size = self.input().count_weights();
        for layer in self.hidden() {
            size += layer.count_weights();
        }
        size + self.output().count_weights()
    }

    /// replace the hidden layer at the specified index with the given parameters
    ///
    /// ## Panics
    ///
    /// Panics if the index is out of bounds or if the dimension of the provided layer is
    /// inconsistent with the others in the stack.
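    ///
    /// ## Example
    ///
    /// A sketch (marked `ignore`; assumes `params` is an existing, owned [`DeepParamsBase`]
    /// with at least one hidden layer):
    ///
    /// ```ignore
    /// // overwrite the first hidden layer with a freshly zeroed layer of the same shape
    /// params.set_hidden_layer(0, ParamsBase::zeros(params.dim_hidden()));
    /// ```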
    #[inline]
    pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
        if layer.dim() != self.dim_hidden() {
            panic!(
                "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
                layer.dim(),
                self.dim_hidden()
            );
        }
        self.hidden_mut()[idx] = layer;
        self
    }
    /// returns the dimension of the input layer
    #[inline]
    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
        self.input().dim()
    }
    /// returns the dimension of the hidden layers
    ///
    /// ## Panics
    ///
    /// Panics if there are no hidden layers or if they do not all share the same dimension.
    #[inline]
    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
        // verify that all hidden layers have the same dimension
        assert!(
            self.hidden()
                .iter()
                .all(|p| p.dim() == self.hidden()[0].dim())
        );
        // use the first hidden layer's dimension as the representative
        // dimension for all hidden layers
        self.hidden()[0].dim()
    }
    /// returns the dimension of the output layer
    #[inline]
    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
        self.output().dim()
    }
    /// returns the hidden layer associated with the given index
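    ///
    /// Because the index may be any [`SliceIndex`](core::slice::SliceIndex), this accepts both
    /// a single position and a sub-range of the hidden stack. A sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let first = params.get_hidden_layer(0);   // Option<&ParamsBase<S, D>>
    /// let tail = params.get_hidden_layer(1..);  // Option<&[ParamsBase<S, D>]>
    /// ```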
    #[inline]
    pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
    where
        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
    {
        self.hidden().get(idx)
    }
    /// returns a mutable reference to the hidden layer associated with the given index
    #[inline]
    pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
    where
        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
    {
        self.hidden_mut().get_mut(idx)
    }
    /// sequentially forwards the input through the model without applying any activations or
    /// other transformations in between; not overly useful on its own, but included for completeness
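    ///
    /// A sketch (marked `ignore`; assumes the required [`Forward`] impls exist for the chosen
    /// array types):
    ///
    /// ```ignore
    /// // x flows through input -> hidden[0] -> ... -> hidden[n - 1] -> output
    /// let y: Array2<f64> = params.forward(&x);
    /// ```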
    #[inline]
    pub fn forward<X, Y>(&self, input: &X) -> Y
    where
        A: Clone,
        S: Data,
        ParamsBase<S, D>: Forward<X, Output = Y> + Forward<Y, Output = Y>,
    {
        let mut output = self.input().forward(input);
        for layer in self.hidden() {
            output = layer.forward(&output);
        }
        self.output().forward(&output)
    }
}

impl<A, S> DeepParamsBase<S, Ix2, A>
where
    S: RawData<Elem = A>,
{
    #[allow(clippy::should_implement_trait)]
    /// create a new instance of the model;
    /// all parameters are initialized to their defaults (i.e., zero)
    pub fn default(features: ModelFeatures) -> Self
    where
        A: Clone + Default,
        S: DataOwned,
    {
        let input = ParamsBase::default(features.dim_input());
        let hidden = (0..features.layers())
            .map(|_| ParamsBase::default(features.dim_hidden()))
            .collect::<Vec<_>>();
        let output = ParamsBase::default(features.dim_output());
        Self::new(input, hidden, output)
    }
    /// create a new instance of the model;
    /// all parameters are initialized to one
    pub fn ones(features: ModelFeatures) -> Self
    where
        A: Clone + One,
        S: DataOwned,
    {
        let input = ParamsBase::ones(features.dim_input());
        let hidden = (0..features.layers())
            .map(|_| ParamsBase::ones(features.dim_hidden()))
            .collect::<Vec<_>>();
        let output = ParamsBase::ones(features.dim_output());
        Self::new(input, hidden, output)
    }
    /// create a new instance of the model;
    /// all parameters are initialized to zero
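    ///
    /// ## Example
    ///
    /// A sketch (marked `ignore`; assumes a `features: ModelFeatures` value describing the
    /// input, hidden, and output shapes has already been built elsewhere):
    ///
    /// ```ignore
    /// let params = DeepParamsBase::<OwnedRepr<f64>, Ix2>::zeros(features);
    /// // every requested layer exists; `size` reports the total number of weights
    /// assert!(params.get_hidden_layer(0).is_some());
    /// let _total = params.size();
    /// ```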
    pub fn zeros(features: ModelFeatures) -> Self
    where
        A: Clone + Zero,
        S: DataOwned,
    {
        let input = ParamsBase::zeros(features.dim_input());
        let hidden = (0..features.layers())
            .map(|_| ParamsBase::zeros(features.dim_hidden()))
            .collect::<Vec<_>>();
        let output = ParamsBase::zeros(features.dim_output());
        Self::new(input, hidden, output)
    }
}