concision_core/models/impls/
impl_params_shallow.rs

1/*
2    Appellation: controller <module>
3    Contrib: @FL03
4*/
5use crate::models::{ModelParamsBase, ShallowParamsBase};
6
7use crate::ModelFeatures;
8use crate::activate::{ReLUActivation, SigmoidActivation};
9use crate::models::traits::ShallowModelRepr;
10use concision_params::ParamsBase;
11use ndarray::{
12    Array1, ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, RemoveAxis, ScalarOperand,
13};
14use num_traits::Float;
15
16impl<S, D, H, A> ModelParamsBase<S, D, H, A>
17where
18    D: Dimension,
19    S: RawData<Elem = A>,
20    H: ShallowModelRepr<S, D>,
21{
22    /// create a new instance of the [`ModelParamsBase`] instance
23    pub const fn shallow(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
24        Self {
25            input,
26            hidden,
27            output,
28        }
29    }
30}
31
32impl<S, D, A> ShallowParamsBase<S, D, A>
33where
34    S: RawData<Elem = A>,
35    D: Dimension,
36{
37    #[allow(clippy::should_implement_trait)]
38    /// initialize a new instance of the [`ShallowParamsBase`] with the given input, hidden,
39    /// and output dimensions using the default values for the parameters
40    pub fn default(input: D, hidden: D, output: D) -> Self
41    where
42        A: Clone + Default,
43        S: DataOwned,
44        D: RemoveAxis,
45    {
46        Self {
47            hidden: ParamsBase::default(hidden),
48            input: ParamsBase::default(input),
49            output: ParamsBase::default(output),
50        }
51    }
52    /// returns the total number parameters within the model, including the input and output layers
53    #[inline]
54    pub fn size(&self) -> usize {
55        let mut size = self.input().count_weights();
56        size += self.hidden().count_weights();
57        size + self.output().count_weights()
58    }
59    /// returns an immutable reference to the hidden weights
60    pub const fn hidden_weights(&self) -> &ArrayBase<S, D, A> {
61        self.hidden().weights()
62    }
63    /// returns an mutable reference to the hidden weights
64    pub const fn hidden_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
65        self.hidden_mut().weights_mut()
66    }
67}
68
69impl<S, A> ShallowParamsBase<S, Ix2, A>
70where
71    S: RawData<Elem = A>,
72{
73    pub fn from_features(features: ModelFeatures) -> Self
74    where
75        A: Clone + Default,
76        S: DataOwned,
77    {
78        Self {
79            hidden: ParamsBase::default(features.dim_hidden()),
80            input: ParamsBase::default(features.dim_input()),
81            output: ParamsBase::default(features.dim_output()),
82        }
83    }
84    /// forward input through the controller network
85    pub fn forward(&self, input: &Array1<A>) -> Array1<A>
86    where
87        A: Float + ScalarOperand,
88        S: Data,
89    {
90        use concision_traits::Forward;
91        let mut output = self.input().forward_then(input, |x| x.relu());
92        output = self.hidden().forward_then(&output, |x| x.relu());
93        self.output().forward_then(&output, |x| x.sigmoid())
94    }
95}
96
97impl<A, S> Default for ShallowParamsBase<S, Ix2, A>
98where
99    S: DataOwned<Elem = A>,
100    A: Clone + Default,
101{
102    fn default() -> Self {
103        Self::from_features(ModelFeatures::default())
104    }
105}