concision_neural/layers/
layer.rs

1/*
2    Appellation: layer <module>
3    Contrib: @FL03
4*/
5use super::{Activator, ActivatorGradient, Layer};
6use cnc::{Forward, ParamsBase};
7use ndarray::{Dimension, Ix2, RawData};
8
/// Type alias for a [`LayerBase`] whose activation function is a boxed
/// [`Activator`] trait object, allowing the activation to be selected at runtime
/// rather than encoded in the type.
pub type LayerDyn<A, S, D> = LayerBase<Box<dyn Activator<A, Output = A> + 'static>, S, D>;
10
/// A layer pairing an activation function `rho` with a set of parameters
/// stored in a [`ParamsBase`] with storage `S` and dimension `D` (defaulting
/// to two-dimensional, [`Ix2`]).
///
/// `F` is intentionally left unconstrained here; activator bounds are applied
/// only on the impls that need them.
pub struct LayerBase<F, S, D = Ix2>
where
    D: Dimension,
    S: RawData,
{
    /// the activation function of the layer
    pub(crate) rho: F,
    /// the parameters (weights/bias container) of the layer
    pub(crate) params: ParamsBase<S, D>,
}
20
21impl<S, D> LayerBase<super::Linear, S, D>
22where
23    D: Dimension,
24    S: RawData<Elem = f32>,
25{
26    pub fn linear(params: ParamsBase<S, D>) -> Self {
27        Self {
28            rho: super::Linear,
29            params,
30        }
31    }
32}
33
34impl<S, D> LayerBase<super::Sigmoid, S, D>
35where
36    D: Dimension,
37    S: RawData<Elem = f32>,
38{
39    pub fn sigmoid(params: ParamsBase<S, D>) -> Self {
40        Self {
41            rho: super::Sigmoid,
42            params,
43        }
44    }
45}
46
47impl<S, D> LayerBase<super::Tanh, S, D>
48where
49    D: Dimension,
50    S: RawData<Elem = f32>,
51{
52    pub fn tanh(params: ParamsBase<S, D>) -> Self {
53        Self {
54            rho: super::Tanh,
55            params,
56        }
57    }
58}
59
60impl<S, D> LayerBase<super::ReLU, S, D>
61where
62    D: Dimension,
63    S: RawData<Elem = f32>,
64{
65    pub fn relu(params: ParamsBase<S, D>) -> Self {
66        Self {
67            rho: super::ReLU,
68            params,
69        }
70    }
71}
72
73impl<F, S, A, D> LayerBase<F, S, D>
74where
75    D: Dimension,
76    S: RawData<Elem = A>,
77{
78    pub fn new(rho: F, params: ParamsBase<S, D>) -> Self {
79        Self { rho, params }
80    }
81    /// returns an immutable reference to the layer's parameters
82    pub fn params(&self) -> &ParamsBase<S, D> {
83        &self.params
84    }
85    /// returns a mutable reference to the layer's parameters
86    pub fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
87        &mut self.params
88    }
89    /// returns an immutable reference to the activation function of the layer
90    pub fn rho(&self) -> &F {
91        &self.rho
92    }
93    /// returns a mutable reference to the activation function of the layer
94    pub fn rho_mut(&mut self) -> &mut F {
95        &mut self.rho
96    }
97    /// consumes the current instance and returns another with the given activation function.
98    /// This is useful during the creation of the model, when the activation function is not known yet.
99    pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
100    where
101        G: Activator<S::Elem>,
102        F: Activator<S::Elem>,
103        S: RawData<Elem = A>,
104    {
105        LayerBase {
106            rho,
107            params: self.params,
108        }
109    }
110    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
111    where
112        F: Activator<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
113        ParamsBase<S, D>: Forward<X, Output = Y>,
114        X: Clone,
115        Y: Clone,
116    {
117        Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
118    }
119}
120
/// Dereferencing a layer yields its parameters, so methods on [`ParamsBase`]
/// can be called directly on a [`LayerBase`].
impl<F, S, D> core::ops::Deref for LayerBase<F, S, D>
where
    D: Dimension,
    S: RawData,
{
    type Target = ParamsBase<S, D>;

    fn deref(&self) -> &Self::Target {
        &self.params
    }
}
132
/// Mutable counterpart of the [`Deref`](core::ops::Deref) impl: exposes the
/// layer's parameters for in-place mutation.
impl<F, S, D> core::ops::DerefMut for LayerBase<F, S, D>
where
    D: Dimension,
    S: RawData,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.params
    }
}
142
143impl<U, V, F, S, D> Activator<U> for LayerBase<F, S, D>
144where
145    F: Activator<U, Output = V>,
146    D: Dimension,
147    S: RawData,
148{
149    type Output = V;
150
151    fn activate(&self, x: U) -> Self::Output {
152        self.rho().activate(x)
153    }
154}
155
156impl<U, F, S, D> ActivatorGradient<U> for LayerBase<F, S, D>
157where
158    F: ActivatorGradient<U>,
159    D: Dimension,
160    S: RawData,
161{
162    type Input = F::Input;
163    type Delta = F::Delta;
164
165    fn activate_gradient(&self, inputs: F::Input) -> F::Delta {
166        self.rho().activate_gradient(inputs)
167    }
168}
169
/// Implements the [`Layer`] trait by exposing the stored parameters; the
/// `F: ActivatorGradient<A>` bound ensures the layer's activator supports
/// gradient computation as the trait presumably requires — TODO confirm
/// against the `Layer` trait definition.
impl<A, F, S, D> Layer<S, D> for LayerBase<F, S, D>
where
    F: ActivatorGradient<A>,
    D: Dimension,
    S: RawData<Elem = A>,
{
    type Scalar = A;

    fn params(&self) -> &ParamsBase<S, D> {
        &self.params
    }

    fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.params
    }
}