concision_neural/layers/
layer.rs

1/*
2    Appellation: layer <module>
3    Contrib: @FL03
4*/
5//! this module defines the [`LayerBase`] struct, a generic representation of a neural network
6//! layer essentially wrapping a [`ParamsBase`] with some _activation function_, `F`.
7//!
8
9mod impl_layer;
10mod impl_layer_repr;
11
12#[allow(deprecated)]
13mod impl_layer_deprecated;
14
15use super::Activator;
16use cnc::{Forward, ParamsBase};
17use ndarray::{DataOwned, Dimension, Ix2, RawData, RemoveAxis, ShapeBuilder};
18
19/// The [`LayerBase`] struct is a base representation of a neural network layer, essentially
20/// binding an activation function, `F`, to a set of parameters, `ParamsBase<S, D>`.
/// The [`LayerBase`] struct is a base representation of a neural network layer, essentially
/// binding an activation function, `F`, to a set of parameters, `ParamsBase<S, D>`.
///
/// Type parameters:
/// - `F`: the activation function applied to the parameters' output; deliberately left
///   unconstrained here so bounds live on the `impl` blocks that need them.
/// - `S`: the ndarray storage backend (`RawData`) holding the parameter elements.
/// - `D`: the dimensionality of the parameter tensors; defaults to [`Ix2`] (a dense layer).
pub struct LayerBase<F, S, D = Ix2>
where
    D: Dimension,
    S: RawData,
{
    /// the activation function of the layer
    pub(crate) rho: F,
    /// the parameters of the layer is an object consisting of both a weight and a bias tensor.
    pub(crate) params: ParamsBase<S, D>,
}
31
32impl<F, S, A, D> LayerBase<F, S, D>
33where
34    D: Dimension,
35    S: RawData<Elem = A>,
36{
37    /// create a new [`LayerBase`] from the given activation function and parameters.
38    pub const fn new(rho: F, params: ParamsBase<S, D>) -> Self {
39        Self { rho, params }
40    }
41    /// create a new [`LayerBase`] from the given parameters assuming the logical default for
42    /// the activation of type `F`.
43    pub fn from_params(params: ParamsBase<S, D>) -> Self
44    where
45        F: Default,
46    {
47        Self {
48            rho: F::default(),
49            params,
50        }
51    }
52    /// create a new [`LayerBase`] from the given activation function and shape.
53    pub fn from_rho<Sh>(rho: F, shape: Sh) -> Self
54    where
55        A: Clone + Default,
56        S: DataOwned,
57        D: RemoveAxis,
58        Sh: ShapeBuilder<Dim = D>,
59    {
60        Self {
61            rho,
62            params: ParamsBase::default(shape),
63        }
64    }
65    /// returns an immutable reference to the layer's parameters
66    pub const fn params(&self) -> &ParamsBase<S, D> {
67        &self.params
68    }
69    /// returns a mutable reference to the layer's parameters
70    pub const fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
71        &mut self.params
72    }
73    /// returns an immutable reference to the activation function of the layer
74    pub const fn rho(&self) -> &F {
75        &self.rho
76    }
77    /// returns a mutable reference to the activation function of the layer
78    pub const fn rho_mut(&mut self) -> &mut F {
79        &mut self.rho
80    }
81    /// consumes the current instance and returns another with the given parameters.
82    pub fn with_params<S2, D2>(self, params: ParamsBase<S2, D2>) -> LayerBase<F, S2, D2>
83    where
84        S2: RawData<Elem = S::Elem>,
85        D2: Dimension,
86    {
87        LayerBase {
88            rho: self.rho,
89            params,
90        }
91    }
92    /// consumes the current instance and returns another with the given activation function.
93    /// This is useful during the creation of the model, when the activation function is not known yet.
94    pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
95    where
96        G: Activator<S::Elem>,
97        F: Activator<S::Elem>,
98        S: RawData<Elem = A>,
99    {
100        LayerBase {
101            rho,
102            params: self.params,
103        }
104    }
105    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
106    where
107        F: Activator<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
108        ParamsBase<S, D>: Forward<X, Output = Y>,
109        X: Clone,
110        Y: Clone,
111    {
112        Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
113    }
114}