concision_neural/layer/
layer.rs

1/*
2    Appellation: layer <module>
3    Contrib: @FL03
4*/
5#![allow(unused)]
6use super::{Activate, Layer};
7use cnc::{Forward, ParamsBase, activate};
8use ndarray::{Dimension, Ix2, RawData};
9
10pub struct LayerBase<F, S, D = Ix2>
11where
12    D: Dimension,
13    S: RawData,
14{
15    pub(crate) rho: F,
16    pub(crate) params: ParamsBase<S, D>,
17}
18
19pub struct Linear;
20
21impl<U> Activate<U> for Linear {
22    type Output = U;
23    fn activate(&self, x: U) -> Self::Output {
24        x
25    }
26}
27
28impl<S, D> LayerBase<Linear, S, D>
29where
30    D: Dimension,
31    S: RawData<Elem = f32>,
32{
33    pub fn linear(params: ParamsBase<S, D>) -> Self {
34        Self {
35            rho: Linear,
36            params,
37        }
38    }
39}
40
41impl<F, S, A, D> LayerBase<F, S, D>
42where
43    D: Dimension,
44    S: RawData<Elem = A>,
45{
46    pub fn new(rho: F, params: ParamsBase<S, D>) -> Self {
47        Self { rho, params }
48    }
49    /// returns an immutable reference to the layer's parameters
50    pub fn params(&self) -> &ParamsBase<S, D> {
51        &self.params
52    }
53    /// returns a mutable reference to the layer's parameters
54    pub fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
55        &mut self.params
56    }
57    /// returns an immutable reference to the activation function of the layer
58    pub fn rho(&self) -> &F {
59        &self.rho
60    }
61    /// returns a mutable reference to the activation function of the layer
62    pub fn rho_mut(&mut self) -> &mut F {
63        &mut self.rho
64    }
65    /// consumes the current instance and returns another with the given activation function.
66    /// This is useful during the creation of the model, when the activation function is not known yet.
67    pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
68    where
69        G: Activate<S::Elem>,
70        F: Activate<S::Elem>,
71        S: RawData<Elem = A>,
72    {
73        LayerBase {
74            rho,
75            params: self.params,
76        }
77    }
78    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
79    where
80        F: Activate<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
81        ParamsBase<S, D>: Forward<X, Output = Y>,
82        X: Clone,
83        Y: Clone,
84    {
85        Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
86    }
87}
88
89impl<F, S, D> core::ops::Deref for LayerBase<F, S, D>
90where
91    D: Dimension,
92    S: RawData,
93{
94    type Target = ParamsBase<S, D>;
95
96    fn deref(&self) -> &Self::Target {
97        &self.params
98    }
99}
100
101impl<F, S, D> core::ops::DerefMut for LayerBase<F, S, D>
102where
103    D: Dimension,
104    S: RawData,
105{
106    fn deref_mut(&mut self) -> &mut Self::Target {
107        &mut self.params
108    }
109}