concision_neural/model/params/
impl_params_shallow.rs

1/*
2    Appellation: controller <module>
3    Contrib: @FL03
4*/
5use crate::model::params::{ModelParamsBase, ShallowParamsBase};
6
7use crate::model::ModelFeatures;
8use crate::traits::ShallowNeuralStore;
9use cnc::{ParamsBase, ReLU, Sigmoid};
10use ndarray::{
11    Array1, ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, RemoveAxis, ScalarOperand,
12};
13use num_traits::Float;
14
15impl<S, D, H, A> ModelParamsBase<S, D, H>
16where
17    D: Dimension,
18    S: RawData<Elem = A>,
19    H: ShallowNeuralStore<S, D>,
20{
21    /// create a new instance of the [`ModelParamsBase`] instance
22    pub const fn shallow(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
23        Self {
24            input,
25            hidden,
26            output,
27        }
28    }
29}
30
31impl<S, D, A> ShallowParamsBase<S, D>
32where
33    S: RawData<Elem = A>,
34    D: Dimension,
35{
36    #[allow(clippy::should_implement_trait)]
37    /// initialize a new instance of the [`ShallowParamsBase`] with the given input, hidden,
38    /// and output dimensions using the default values for the parameters
39    pub fn default(input: D, hidden: D, output: D) -> Self
40    where
41        A: Clone + Default,
42        S: DataOwned,
43        D: RemoveAxis,
44    {
45        Self {
46            hidden: ParamsBase::default(hidden),
47            input: ParamsBase::default(input),
48            output: ParamsBase::default(output),
49        }
50    }
51    /// returns the total number parameters within the model, including the input and output layers
52    #[inline]
53    pub fn size(&self) -> usize {
54        let mut size = self.input().count_weight();
55        size += self.hidden().count_weight();
56        size + self.output().count_weight()
57    }
58    /// returns an immutable reference to the hidden weights
59    pub const fn hidden_weights(&self) -> &ArrayBase<S, D> {
60        self.hidden().weights()
61    }
62    /// returns an mutable reference to the hidden weights
63    pub const fn hidden_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
64        self.hidden_mut().weights_mut()
65    }
66}
67
68impl<S, A> ShallowParamsBase<S, Ix2>
69where
70    S: RawData<Elem = A>,
71{
72    pub fn from_features(features: ModelFeatures) -> Self
73    where
74        A: Clone + Default,
75        S: DataOwned,
76    {
77        Self {
78            hidden: ParamsBase::default(features.dim_hidden()),
79            input: ParamsBase::default(features.dim_input()),
80            output: ParamsBase::default(features.dim_output()),
81        }
82    }
83    /// forward input through the controller network
84    pub fn forward(&self, input: &Array1<A>) -> cnc::Result<Array1<A>>
85    where
86        A: Float + ScalarOperand,
87        S: Data,
88    {
89        // forward the input through the input layer; activate using relu
90        let mut output = self.input().forward(input)?.relu();
91        // forward the input through the hidden layer(s); activate using relu
92        output = self.hidden().forward(&output)?.relu();
93        // forward the input through the output layer; activate using sigmoid
94        output = self.output().forward(&output)?.sigmoid();
95
96        Ok(output)
97    }
98}
99
100impl<A, S> Default for ShallowParamsBase<S, Ix2>
101where
102    S: DataOwned<Elem = A>,
103    A: Clone + Default,
104{
105    fn default() -> Self {
106        Self::from_features(ModelFeatures::default())
107    }
108}