concision_neural/layers/mod.rs

/*
    Appellation: layers <module>
    Contrib: @FL03
*/
//! This module implements various layers for a neural network, providing the [`Layer`]
//! trait alongside the [`Activate`] and [`ActivateGradient`] traits and several built-in
//! activation functions ([`Linear`], [`Sigmoid`], [`Tanh`], and [`ReLU`]).
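//!
//! A minimal sketch of applying one of the built-in activations; the import path is an
//! assumption, and the call only compiles where `cnc::Sigmoid` is implemented for `f64`:
//!
//! ```ignore
//! use concision_neural::layers::{Activate, Sigmoid};
//!
//! // the unit struct simply dispatches to the underlying `cnc::Sigmoid` implementation
//! let y = Sigmoid.activate(0.0_f64);
//! ```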
#[doc(inline)]
pub use self::layer::LayerBase;

pub(crate) mod layer;

#[cfg(feature = "attention")]
pub mod attention;

pub(crate) mod prelude {
    #[cfg(feature = "attention")]
    pub use super::attention::prelude::*;
    pub use super::layer::*;
}

use cnc::params::ParamsBase;
use cnc::{Backward, Forward, Tensor};

use ndarray::{Data, Dimension, RawData};
/// The [`Activate`] trait defines a method for applying an activation function to an
/// input tensor.
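///
/// # Examples
///
/// A minimal sketch of a custom activation; the `Step` type is hypothetical and not
/// part of this crate:
///
/// ```ignore
/// struct Step;
///
/// impl Activate<f64> for Step {
///     type Output = f64;
///
///     // the Heaviside step function: 1 for positive inputs, 0 otherwise
///     fn activate(&self, input: f64) -> Self::Output {
///         if input > 0.0 { 1.0 } else { 0.0 }
///     }
/// }
/// ```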
pub trait Activate<T> {
    type Output;

    /// Applies the activation function to the input tensor.
    fn activate(&self, input: T) -> Self::Output;
}
/// The [`ActivateGradient`] trait extends the [`Activate`] trait to include a method for
/// computing the gradient of the activation function.
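///
/// # Examples
///
/// A minimal sketch pairing an activation with its gradient; the `Square` type is
/// hypothetical and used purely for illustration:
///
/// ```ignore
/// struct Square;
///
/// impl Activate<f64> for Square {
///     type Output = f64;
///
///     fn activate(&self, x: f64) -> Self::Output {
///         x * x
///     }
/// }
///
/// impl ActivateGradient<f64> for Square {
///     type Input = f64;
///     type Delta = f64;
///
///     // d/dx x^2 = 2x
///     fn activate_gradient(&self, x: f64) -> Self::Delta {
///         2.0 * x
///     }
/// }
/// ```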
pub trait ActivateGradient<T>: Activate<T> {
    type Input;
    type Delta;

    /// Computes the gradient of the activation function for the given input.
    fn activate_gradient(&self, input: Self::Input) -> Self::Delta;
}

/// A layer within a neural network, pairing a set of parameters with an activation
/// function. Here, this manifests as a wrapper around the parameters of the layer with a
/// generic activation function and corresponding traits to denote desired behaviors.
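///
/// # Examples
///
/// A minimal sketch of a single training step; the `layer`, `input`, `error`, and
/// `gamma` bindings are assumed to exist and are not constructed here:
///
/// ```ignore
/// // forward pass: activate(params.forward(input))
/// let output = layer.forward(&input)?;
/// // backward pass: scale the error by the activation gradient before
/// // updating the parameters by a factor of `gamma`
/// let update = layer.backward(input, error, gamma)?;
/// ```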
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Scalar>,
{
    type Scalar;

    /// returns an immutable reference to the parameters of the layer
    fn params(&self) -> &ParamsBase<S, D>;
    /// returns a mutable reference to the parameters of the layer
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
    /// overwrites the parameters of the layer with the given set
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// backward-propagates the error through the layer, scaling the parameter update
    /// by `gamma`
    fn backward<X, Y, Z, Delta>(
        &mut self,
        input: X,
        error: Y,
        gamma: Self::Scalar,
    ) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivateGradient<X, Input = Y, Delta = Delta>,
        Self::Scalar: Clone,
        ParamsBase<S, D>: Backward<X, Delta, Elem = Self::Scalar, Output = Z>,
    {
        // scale the propagated error by the gradient of the activation function
        let delta = self.activate_gradient(error);
        // delegate the parameter update to the backward pass of the parameter store
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// completes a forward pass through the layer, applying the activation function to
    /// the output of the underlying parameters
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: Tensor<S, D, Scalar = Self::Scalar>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activate<Y, Output = Y>,
    {
        self.params().forward_then(input, |y| self.activate(y))
    }
}

/// The identity activation function; [`Linear`] returns its input unchanged.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Linear;

impl<U> Activate<U> for Linear {
    type Output = U;

    fn activate(&self, x: U) -> Self::Output {
        x
    }
}

impl<U> ActivateGradient<U> for Linear
where
    U: cnc::LinearActivation,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.linear_derivative()
    }
}

/// The sigmoid activation function, `1 / (1 + exp(-x))`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sigmoid;

impl<U> Activate<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        cnc::Sigmoid::sigmoid(x)
    }
}

impl<U> ActivateGradient<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, x: U) -> Self::Delta {
        cnc::Sigmoid::sigmoid_derivative(x)
    }
}

/// The hyperbolic tangent activation function.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Tanh;

impl<U> Activate<U> for Tanh
where
    U: cnc::Tanh,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.tanh()
    }
}

impl<U> ActivateGradient<U> for Tanh
where
    U: cnc::Tanh,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.tanh_derivative()
    }
}

/// The rectified linear unit activation function, `max(0, x)`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ReLU;

impl<U> Activate<U> for ReLU
where
    U: cnc::ReLU,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.relu()
    }
}

impl<U> ActivateGradient<U> for ReLU
where
    U: cnc::ReLU,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.relu_derivative()
    }
}