concision_neural/layers/mod.rs

/*
    Appellation: layers <module>
    Contrib: @FL03
*/
//! This module implements various layers for a neural network.
#[doc(inline)]
pub use self::layer::LayerBase;

pub(crate) mod layer;

#[cfg(feature = "attention")]
pub mod attention;

pub(crate) mod prelude {
    #[cfg(feature = "attention")]
    pub use super::attention::prelude::*;
    pub use super::layer::*;
}

use cnc::params::ParamsBase;
use cnc::{Activate, ActivateGradient, Backward, Forward, Tensor};

use ndarray::{Data, Dimension, RawData};
/// A layer within a neural network, containing a set of parameters and an activation
/// function. Here, this manifests as a wrapper around the parameters of the layer with a
/// generic activation function and corresponding traits to denote the desired behaviors.
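///
/// # Examples
///
/// A minimal sketch of an implementor; the `DenseLayer` type below is illustrative and
/// not part of this crate. Only `Scalar`, `params`, and `params_mut` must be provided,
/// since the remaining methods have default implementations.
///
/// ```ignore
/// struct DenseLayer<S, D>
/// where
///     D: Dimension,
///     S: RawData,
/// {
///     params: ParamsBase<S, D>,
/// }
///
/// impl<S, D> Layer<S, D> for DenseLayer<S, D>
/// where
///     D: Dimension,
///     S: RawData,
/// {
///     type Scalar = S::Elem;
///
///     fn params(&self) -> &ParamsBase<S, D> {
///         &self.params
///     }
///
///     fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
///         &mut self.params
///     }
/// }
/// ```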
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Scalar>,
{
    type Scalar;

    /// returns an immutable reference to the parameters of the layer
    fn params(&self) -> &ParamsBase<S, D>;
    /// returns a mutable reference to the parameters of the layer
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
    /// overwrite the parameters of the layer with the given set
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// backward propagate an error signal through the layer using the scaling factor
    /// `gamma`
    fn backward<X, Y, Z, Delta>(
        &mut self,
        input: X,
        error: Y,
        gamma: Self::Scalar,
    ) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivateGradient<Y, Delta = Delta>,
        Self::Scalar: Clone,
        ParamsBase<S, D>: Backward<X, Delta, Elem = Self::Scalar, Output = Z>,
    {
        // compute the delta using the gradient of the activation function
        let delta = self.activate_gradient(error);
        // propagate the delta backward through the parameters of the layer
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// forward propagate the given input through the layer, applying the activation
    /// function to the output of the parameters
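    ///
    /// Schematically, the default implementation behaves like the following sketch, where
    /// `layer` and `x` are hypothetical stand-ins for an implementor and its input:
    ///
    /// ```ignore
    /// let y = layer.activate(layer.params().forward(&x)?);
    /// ```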
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: Tensor<S, D, Scalar = Self::Scalar>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activate<Y, Output = Y>,
    {
        self.params().forward_then(input, |y| self.activate(y))
    }
}

/// A linear activation function, simply forwarding its input as the output.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Linear;

impl<U> Activate<U> for Linear {
    type Output = U;

    fn activate(&self, x: U) -> Self::Output {
        // the identity function: return the input unchanged
        x
    }
}

impl<U> ActivateGradient<U> for Linear
where
    U: cnc::LinearActivation,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.linear_derivative()
    }
}

/// The sigmoid activation function, squashing each input into the open interval `(0, 1)`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sigmoid;

impl<U> Activate<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        cnc::Sigmoid::sigmoid(&x)
    }
}

impl<U> ActivateGradient<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, x: U) -> Self::Delta {
        // compute the derivative of the sigmoid: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
        cnc::Sigmoid::sigmoid_derivative(&x)
    }
}

/// The hyperbolic tangent activation function, squashing each input into `(-1, 1)`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Tanh;

impl<U> Activate<U> for Tanh
where
    U: cnc::Tanh,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.tanh()
    }
}

impl<U> ActivateGradient<U> for Tanh
where
    U: cnc::Tanh,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        // tanh'(x) = 1 - tanh(x)^2
        inputs.tanh_derivative()
    }
}

/// The rectified linear unit (ReLU) activation function, zeroing out negative inputs.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ReLU;

impl<U> Activate<U> for ReLU
where
    U: cnc::ReLU,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        // relu(x) = max(0, x)
        x.relu()
    }
}

impl<U> ActivateGradient<U> for ReLU
where
    U: cnc::ReLU,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        // relu'(x) = 1 for x > 0, else 0
        inputs.relu_derivative()
    }
}
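
// A small sanity check for the `Linear` functor defined above; it only exercises behavior
// fully determined by this module (the unbounded identity `Activate` impl), since the
// other activations depend on `cnc` trait implementations for concrete scalar types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn linear_activate_is_identity() {
        // `Activate<U> for Linear` has no bounds on `U`, so any type passes through
        assert_eq!(Linear.activate(3.0f64), 3.0f64);
        assert_eq!(Linear.activate(-2i32), -2i32);
        assert_eq!(Linear.activate("unchanged"), "unchanged");
    }
}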