#[doc(inline)]
pub use self::layer::LayerBase;

pub(crate) mod layer;

#[cfg(feature = "attention")]
pub mod attention;

pub(crate) mod prelude {
    #[cfg(feature = "attention")]
    pub use super::attention::prelude::*;
    pub use super::layer::*;
}

use cnc::params::ParamsBase;
use cnc::{Backward, Forward, Tensor};

use ndarray::{Data, Dimension, RawData};
/// [`Activate`] describes an activation function that can be applied to the
/// output of a layer.
pub trait Activate<T> {
    type Output;

    /// apply the activation to the given input
    fn activate(&self, input: T) -> Self::Output;
}

/// [`ActivateGradient`] extends an [`Activate`] implementation with the first
/// derivative of the activation, as required for backpropagation.
pub trait ActivateGradient<T>: Activate<T> {
    type Input;
    type Delta;

    /// compute the derivative of the activation at the given input
    fn activate_gradient(&self, input: Self::Input) -> Self::Delta;
}
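
// A minimal sketch of a custom activation wired through both traits; the
// `Swish` type, its closed-form derivative, and the hand-checked values are
// illustrative only and not part of this crate's API.
#[cfg(test)]
mod swish_example {
    use super::{Activate, ActivateGradient};

    struct Swish;

    impl Activate<f64> for Swish {
        type Output = f64;

        fn activate(&self, x: f64) -> f64 {
            // swish(x) = x * sigmoid(x)
            x / (1.0 + (-x).exp())
        }
    }

    impl ActivateGradient<f64> for Swish {
        type Input = f64;
        type Delta = f64;

        fn activate_gradient(&self, x: f64) -> f64 {
            // swish'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
            let s = 1.0 / (1.0 + (-x).exp());
            s + x * s * (1.0 - s)
        }
    }

    #[test]
    fn swish_at_zero() {
        assert_eq!(Swish.activate(0.0), 0.0);
        assert!((Swish.activate_gradient(0.0) - 0.5).abs() < 1e-12);
    }
}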

/// [`Layer`] associates a set of parameters with the activation behavior of
/// its implementor, providing default forward and backward passes.
pub trait Layer<S, D>
where
    D: Dimension,
    S: RawData<Elem = Self::Scalar>,
{
    type Scalar;

    /// returns an immutable reference to the layer's parameters
    fn params(&self) -> &ParamsBase<S, D>;
    /// returns a mutable reference to the layer's parameters
    fn params_mut(&mut self) -> &mut ParamsBase<S, D>;
    /// overwrites the layer's parameters with the given set
    fn set_params(&mut self, params: ParamsBase<S, D>) {
        *self.params_mut() = params;
    }
    /// backpropagate the given error through the layer: the activation
    /// gradient is evaluated on the error signal, then the parameters are
    /// updated against the input with learning rate `gamma`
    fn backward<X, Y, Z, Delta>(
        &mut self,
        input: X,
        error: Y,
        gamma: Self::Scalar,
    ) -> cnc::Result<Z>
    where
        S: Data,
        Self: ActivateGradient<X, Input = Y, Delta = Delta>,
        Self::Scalar: Clone,
        ParamsBase<S, D>: Backward<X, Delta, Elem = Self::Scalar, Output = Z>,
    {
        let delta = self.activate_gradient(error);
        self.params_mut().backward(&input, &delta, gamma)
    }
    /// complete a forward pass by applying the parameters and then the
    /// activation to the given input
    fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Y: Tensor<S, D, Scalar = Self::Scalar>,
        ParamsBase<S, D>: Forward<X, Output = Y>,
        Self: Activate<Y, Output = Y>,
    {
        self.params().forward_then(input, |y| self.activate(y))
    }
}
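
// A minimal sketch of a concrete `Layer` implementation: a layer is typically
// just its parameters plus an activation, and only the two accessors need to
// be written by hand, with `forward`/`backward` coming from the trait's
// defaults. The `Perceptron` name is hypothetical, not a type exported here.
#[cfg(test)]
mod layer_example {
    #![allow(dead_code)]

    use super::*;

    struct Perceptron<A, S, D>
    where
        D: Dimension,
        S: RawData,
    {
        activation: A,
        params: ParamsBase<S, D>,
    }

    impl<A, S, D> Layer<S, D> for Perceptron<A, S, D>
    where
        D: Dimension,
        S: RawData,
    {
        type Scalar = S::Elem;

        fn params(&self) -> &ParamsBase<S, D> {
            &self.params
        }

        fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
            &mut self.params
        }
    }
}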

/// a linear (identity) activation
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Linear;

impl<U> Activate<U> for Linear {
    type Output = U;

    fn activate(&self, x: U) -> Self::Output {
        x
    }
}

impl<U> ActivateGradient<U> for Linear
where
    U: cnc::LinearActivation,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.linear_derivative()
    }
}
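
// Since `Activate` for `Linear` is implemented for any `U` with no bounds, it
// is a pure identity pass-through; a quick sanity check:
#[cfg(test)]
mod linear_example {
    use super::{Activate, Linear};

    #[test]
    fn linear_is_identity() {
        assert_eq!(Linear.activate(3.0_f64), 3.0);
        assert_eq!(Linear.activate("unchanged"), "unchanged");
    }
}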

/// the sigmoid activation, squashing values into the interval `(0, 1)`
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sigmoid;

impl<U> Activate<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.sigmoid()
    }
}

impl<U> ActivateGradient<U> for Sigmoid
where
    U: cnc::Sigmoid,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, x: U) -> Self::Delta {
        x.sigmoid_derivative()
    }
}
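
// Assuming `cnc` implements its `Sigmoid` trait for the standard floats,
// usage looks like the sketch below; sigmoid(0) = 0.5 and, by the usual
// identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), the gradient at zero
// is 0.25:
//
//     let act = Sigmoid;
//     assert_eq!(act.activate(0.0_f64), 0.5);
//     assert_eq!(act.activate_gradient(0.0_f64), 0.25);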

/// the hyperbolic tangent activation, squashing values into `(-1, 1)`
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Tanh;

impl<U> Activate<U> for Tanh
where
    U: cnc::Tanh,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.tanh()
    }
}

impl<U> ActivateGradient<U> for Tanh
where
    U: cnc::Tanh,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.tanh_derivative()
    }
}
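
// The gradient here is expected to follow the identity
// tanh'(x) = 1 - tanh(x)^2, so (assuming `cnc::Tanh` is implemented for
// `f64`) `Tanh.activate_gradient(0.0_f64)` would evaluate to `1.0`.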

/// the rectified linear unit, mapping negative values to zero
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ReLU;

impl<U> Activate<U> for ReLU
where
    U: cnc::ReLU,
{
    type Output = U::Output;

    fn activate(&self, x: U) -> Self::Output {
        x.relu()
    }
}

impl<U> ActivateGradient<U> for ReLU
where
    U: cnc::ReLU,
{
    type Input = U;
    type Delta = U::Output;

    fn activate_gradient(&self, inputs: U) -> Self::Delta {
        inputs.relu_derivative()
    }
}
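
// ReLU maps x to max(0, x) and its derivative is the unit step; assuming
// `cnc::ReLU` is implemented for `f64`, a usage sketch:
//
//     assert_eq!(ReLU.activate(-2.0_f64), 0.0);
//     assert_eq!(ReLU.activate(3.0_f64), 3.0);
//     assert_eq!(ReLU.activate_gradient(3.0_f64), 1.0);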