concision_neural/layers/traits/layers.rs

/*
    appellation: layers <module>
    authors: @FL03
*/
use super::{Activator, ActivatorGradient};

use cnc::params::ParamsBase;
use cnc::tensor::NdTensor;
use cnc::{Backward, Forward};
use ndarray::{Data, Dimension, RawData};

/// A generic trait defining the composition of a _layer_ within a neural network.
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Elem>,
{
    /// the type of element used within the layer; typically a floating-point variant like
    /// [`f32`] or [`f64`].
    type Elem;
    /// The type of activator used by the layer; the type must implement [`Activator`]
    type Rho: Activator<Self::Elem>;

    /// returns an immutable reference to the activation function of the layer
    fn rho(&self) -> &Self::Rho;
    /// returns an immutable reference to the parameters of the layer
    fn params(&self) -> &ParamsBase<S, D>;
    /// returns a mutable reference to the parameters of the layer
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
}
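
// A minimal sketch of a [`Layer`] implementation. `DenseLayer` is a
// hypothetical type introduced here purely for illustration; it is not part
// of this module or of `cnc`. The activator is left generic so that no
// concrete activation function needs to be assumed.
pub struct DenseLayer<S, D, F>
where
    D: Dimension,
    S: RawData,
{
    /// the layer's trainable parameters
    params: ParamsBase<S, D>,
    /// the activation function applied to the layer's output
    rho: F,
}

impl<S, D, F> Layer<S, D> for DenseLayer<S, D, F>
where
    D: Dimension,
    S: RawData,
    F: Activator<S::Elem>,
{
    type Elem = S::Elem;
    type Rho = F;

    fn rho(&self) -> &Self::Rho {
        &self.rho
    }

    fn params(&self) -> &ParamsBase<S, D> {
        &self.params
    }

    fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.params
    }
}

// [`LayerExt`] (below) declares no additional required methods, so its
// defaulted helpers can be opted into with an empty impl:
impl<S, D, F> LayerExt<S, D> for DenseLayer<S, D, F>
where
    D: Dimension,
    S: RawData,
    F: Activator<S::Elem>,
{
}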
/// The [`LayerExt`] trait extends the base [`Layer`] trait with additional methods that
/// are commonly used in neural network layers. It provides methods for setting parameters,
/// performing backward propagation of errors, and completing a forward pass through the layer.
pub trait LayerExt<S, D>: Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Elem>,
{
    /// update the layer parameters
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// backward propagate the error through the layer
    fn backward<X, Y, Z, Dt>(&mut self, input: X, error: Y, gamma: Self::Elem) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivatorGradient<X, Input = Y, Output = Z, Delta = Dt>,
        Self::Elem: Clone,
        ParamsBase<S, D>: Backward<X, Dt, Elem = Self::Elem, Output = Z>,
    {
        // compute the delta by applying the gradient of the activation function to the error
        let delta = self.activate_gradient(error);
        // delegate to the `Backward` implementation of the underlying parameters
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// complete a forward pass through the layer; the output of the parameters is passed
    /// through the layer's activation function before being returned
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: NdTensor<S::Elem, D, Repr = S>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activator<Y, Output = Y>,
    {
        self.params().forward_then(input, |y| self.activate(y))
    }
}
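
// A hedged usage sketch (not compiled): `params`, `rho`, `input`, `error`,
// `gamma`, and `new_params` are hypothetical values built from concrete
// `cnc`/`ndarray` types; `DenseLayer` is the illustrative type defined above.
//
// let mut layer = DenseLayer { params, rho };
// // forward: the parameters transform the input, then `rho` activates it
// let output = layer.forward(&input)?;
// // backward: the activation gradient produces the delta, which the
// // parameters consume through their `Backward` implementation
// let grad = layer.backward(input, error, gamma)?;
// // replace the parameters wholesale
// layer.set_params(new_params);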