concision_core/nn/traits/
layer.rs

/*
    Appellation: layer <module>
    Created At: 2025.12.10:16:50:03
    Contrib: @FL03
*/
6use crate::activate::{Activator, ActivatorGradient};
7use concision_params::RawParams;
8use concision_traits::{Backward, Forward};
9
10/// The [`RawLayer`] trait establishes a common interface for all _layers_ within a given
11/// model. Implementors will need to define the type of parameters they utilize, as well as
12/// provide methods to access both the activation function and the parameters of the layer.
13pub trait RawLayer<F, A>
14where
15    F: Activator<A>,
16    Self::Params<A>: RawParams<Elem = A>,
17{
18    type Params<_T>;
19    /// the activation function of the layer
20    fn rho(&self) -> &F;
21    /// returns an immutable reference to the parameters of the layer
22    fn params(&self) -> &Self::Params<A>;
23    /// complete a forward pass through the layer
24    fn forward<X, Y, Z>(&self, input: &X) -> Z
25    where
26        F: Activator<Y, Output = Z>,
27        Self::Params<A>: Forward<X, Output = Y>,
28    {
29        let y = self.params().forward(input);
30        self.rho().activate(y)
31    }
32}
33/// The [`RawLayerMut`] trait extends the [`RawLayer`] trait by providing mutable access to the
34/// layer's parameters and additional methods for training the layer, such as backward
35/// propagation and parameter updates.
36pub trait RawLayerMut<F, A>: RawLayer<F, A>
37where
38    F: Activator<A>,
39    Self::Params<A>: RawParams<Elem = A>,
40{
41    /// returns a mutable reference to the parameters of the layer
42    fn params_mut(&mut self) -> &mut Self::Params<A>;
43    /// backward propagate error through the layer
44    fn backward<X, Y, Z, Dt>(&mut self, input: X, error: Y, gamma: A)
45    where
46        A: Clone,
47        F: ActivatorGradient<Y, Rel = F, Delta = Dt>,
48        Self::Params<A>: Backward<X, Dt, Elem = A>,
49    {
50        let delta = self.rho().activate_gradient(error);
51        self.params_mut().backward(&input, &delta, gamma)
52    }
53    /// update the layer parameters
54    fn set_params(&mut self, params: Self::Params<A>) {
55        *self.params_mut() = params;
56    }
57    /// [`replace`](core::mem::replace) the params of the layer, returning the previous value
58    fn replace_params(&mut self, params: Self::Params<A>) -> Self::Params<A> {
59        core::mem::replace(self.params_mut(), params)
60    }
61    /// [`swap`](core::mem::swap) the params of the layer with another
62    fn swap_params(&mut self, other: &mut Self::Params<A>) {
63        core::mem::swap(self.params_mut(), other);
64    }
65}