concision_neural/layers/traits/layers.rs

/*
    appellation: layers <module>
    authors: @FL03
*/
use super::{Activator, ActivatorGradient};

use cnc::params::ParamsBase;
use cnc::tensor::NdTensor;
use cnc::{Backward, Forward};
use ndarray::{Data, Dimension, RawData};
/// A layer within a neural network, containing a set of parameters and an activation function.
/// Here, this manifests as a wrapper around the layer's parameters with a generic activation
/// function and corresponding traits denoting the desired behaviors.
///
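/// A minimal sketch of a concrete type implementing this trait; the `DenseLayer` struct below
/// is an illustrative assumption rather than an item provided by the crate.
///
/// ```ignore
/// use cnc::params::ParamsBase;
/// use ndarray::{Dimension, RawData};
///
/// /// a hypothetical layer that simply wraps its parameters
/// pub struct DenseLayer<S, D>
/// where
///     D: Dimension,
///     S: RawData,
/// {
///     params: ParamsBase<S, D>,
/// }
///
/// impl<S, D> Layer<S, D> for DenseLayer<S, D>
/// where
///     D: Dimension,
///     S: RawData,
/// {
///     type Scalar = S::Elem;
///
///     fn params(&self) -> &ParamsBase<S, D> {
///         &self.params
///     }
///
///     fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
///         &mut self.params
///     }
/// }
/// ```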
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Scalar>,
{
    type Scalar;

    /// returns an immutable reference to the parameters of the layer
    fn params(&self) -> &ParamsBase<S, D>;
    /// returns a mutable reference to the parameters of the layer
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
    /// update the layer parameters
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// backpropagate the error through the layer
    fn backward<X, Y, Z, Delta>(
        &mut self,
        input: X,
        error: Y,
        gamma: Self::Scalar,
    ) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivatorGradient<X, Input = Y, Delta = Delta>,
        Self::Scalar: Clone,
        ParamsBase<S, D>: Backward<X, Delta, Elem = Self::Scalar, Output = Z>,
    {
        // compute the delta using the gradient of the activation function
        let delta = self.activate_gradient(error);
        // apply the backward pass of the underlying parameters
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// complete a forward pass through the layer
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: NdTensor<S::Elem, D, Repr = S>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activator<Y, Output = Y>,
    {
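        // run the parameters' forward pass, then apply the activation function to its output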
        self.params().forward_then(input, |y| self.activate(y))
    }
}