1#[doc(inline)]
7pub use self::layer::LayerBase;
8
9pub(crate) mod layer;
10
11#[cfg(feature = "attention")]
12pub mod attention;
13
/// Crate-internal convenience re-exports; the `attention` exports are
/// feature-gated to mirror the gating of the module itself.
pub(crate) mod prelude {
    #[cfg(feature = "attention")]
    pub use super::attention::prelude::*;
    pub use super::layer::*;
}
19
20use cnc::params::ParamsBase;
21use cnc::{Activate, ActivateGradient, Backward, Forward, Tensor};
22
23use ndarray::{Data, Dimension, RawData};
24
/// A minimal interface for a neural-network layer backed by an
/// `ndarray`-style parameter store.
///
/// `S` is the raw storage of the parameters and `D` their dimensionality;
/// the associated [`Scalar`](Layer::Scalar) ties the element type of the
/// storage to the layer.
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Scalar>,
{
    /// Element type of the layer's parameters.
    type Scalar;

    /// Immutable access to the layer's parameters.
    fn params(&self) -> &ParamsBase<S, D>;
    /// Mutable access to the layer's parameters.
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
    /// Replaces the layer's parameters wholesale.
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// Backpropagates `error` through the layer.
    ///
    /// `error` is first mapped through the layer's activation gradient to
    /// obtain `delta`, which is then handed — together with the original
    /// `input` and the scale factor `gamma` (presumably a learning rate;
    /// confirm against `Backward` in `cnc`) — to the parameter store's own
    /// [`Backward`] implementation.
    fn backward<X, Y, Z, Delta>(
        &mut self,
        input: X,
        error: Y,
        gamma: Self::Scalar,
    ) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivateGradient<Y, Delta = Delta>,
        Self::Scalar: Clone,
        ParamsBase<S, D>: Backward<X, Delta, Elem = Self::Scalar, Output = Z>,
    {
        let delta = self.activate_gradient(error);
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// Runs the forward pass: the parameter store first, then the
    /// layer's activation applied to its raw output via `forward_then`.
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: Tensor<S, D, Scalar = Self::Scalar>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activate<Y, Output = Y>,
    {
        self.params().forward_then(input, |y| self.activate(y))
    }
}
72
/// The identity ("linear") activation: passes its input through untouched.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Linear;

impl<U> Activate<U> for Linear {
    // Identity activation: the output type is the input type.
    type Output = U;

    fn activate(&self, x: U) -> Self::Output {
        // No transformation — return the input unchanged.
        x
    }
}
84
85impl<U> ActivateGradient<U> for Linear
86where
87 U: cnc::LinearActivation,
88{
89 type Input = U;
90 type Delta = U::Output;
91
92 fn activate_gradient(&self, _inputs: U) -> Self::Delta {
93 _inputs.linear_derivative()
94 }
95}
96
/// The logistic sigmoid activation.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sigmoid;

impl<U> Activate<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        // Fully-qualified call, and it borrows `x` rather than consuming
        // it — presumably disambiguating from an inherent `sigmoid` on `U`;
        // NOTE(review): confirm against the `cnc::Sigmoid` trait definition.
        cnc::Sigmoid::sigmoid(&x)
    }
}
111
112impl<U> ActivateGradient<U> for Sigmoid
113where
114 U: cnc::Sigmoid,
115{
116 type Input = U;
117 type Delta = U::Output;
118
119 fn activate_gradient(&self, x: U) -> Self::Delta {
120 cnc::Sigmoid::sigmoid(&x)
121 }
122}
123
/// The hyperbolic-tangent activation.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Tanh;

impl<U> Activate<U> for Tanh
where
    U: cnc::Tanh,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        // Delegates to the `cnc::Tanh` trait method.
        x.tanh()
    }
}

/// Gradient of the tanh activation; delegates to `tanh_derivative`.
impl<U> ActivateGradient<U> for Tanh
where
    U: cnc::Tanh,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.tanh_derivative()
    }
}
149
/// The rectified linear unit (ReLU) activation.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ReLU;

impl<U> Activate<U> for ReLU
where
    U: cnc::ReLU,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        // Delegates to the `cnc::ReLU` trait method.
        x.relu()
    }
}

/// Gradient of the ReLU activation; delegates to `relu_derivative`.
impl<U> ActivateGradient<U> for ReLU
where
    U: cnc::ReLU,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.relu_derivative()
    }
}