1use super::{Activator, ActivatorGradient, Layer};
6use cnc::{Forward, ParamsBase};
7use ndarray::{Dimension, Ix2, RawData};
8
9pub type LayerDyn<A, S, D> = LayerBase<Box<dyn Activator<A, Output = A> + 'static>, S, D>;
10
11pub struct LayerBase<F, S, D = Ix2>
12where
13 D: Dimension,
14 S: RawData,
15{
16 pub(crate) rho: F,
18 pub(crate) params: ParamsBase<S, D>,
19}
20
21impl<S, D> LayerBase<super::Linear, S, D>
22where
23 D: Dimension,
24 S: RawData<Elem = f32>,
25{
26 pub fn linear(params: ParamsBase<S, D>) -> Self {
27 Self {
28 rho: super::Linear,
29 params,
30 }
31 }
32}
33
34impl<S, D> LayerBase<super::Sigmoid, S, D>
35where
36 D: Dimension,
37 S: RawData<Elem = f32>,
38{
39 pub fn sigmoid(params: ParamsBase<S, D>) -> Self {
40 Self {
41 rho: super::Sigmoid,
42 params,
43 }
44 }
45}
46
47impl<S, D> LayerBase<super::Tanh, S, D>
48where
49 D: Dimension,
50 S: RawData<Elem = f32>,
51{
52 pub fn tanh(params: ParamsBase<S, D>) -> Self {
53 Self {
54 rho: super::Tanh,
55 params,
56 }
57 }
58}
59
60impl<S, D> LayerBase<super::ReLU, S, D>
61where
62 D: Dimension,
63 S: RawData<Elem = f32>,
64{
65 pub fn relu(params: ParamsBase<S, D>) -> Self {
66 Self {
67 rho: super::ReLU,
68 params,
69 }
70 }
71}
72
73impl<F, S, A, D> LayerBase<F, S, D>
74where
75 D: Dimension,
76 S: RawData<Elem = A>,
77{
78 pub fn new(rho: F, params: ParamsBase<S, D>) -> Self {
79 Self { rho, params }
80 }
81 pub fn params(&self) -> &ParamsBase<S, D> {
83 &self.params
84 }
85 pub fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
87 &mut self.params
88 }
89 pub fn rho(&self) -> &F {
91 &self.rho
92 }
93 pub fn rho_mut(&mut self) -> &mut F {
95 &mut self.rho
96 }
97 pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
100 where
101 G: Activator<S::Elem>,
102 F: Activator<S::Elem>,
103 S: RawData<Elem = A>,
104 {
105 LayerBase {
106 rho,
107 params: self.params,
108 }
109 }
110 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
111 where
112 F: Activator<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
113 ParamsBase<S, D>: Forward<X, Output = Y>,
114 X: Clone,
115 Y: Clone,
116 {
117 Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
118 }
119}
120
121impl<F, S, D> core::ops::Deref for LayerBase<F, S, D>
122where
123 D: Dimension,
124 S: RawData,
125{
126 type Target = ParamsBase<S, D>;
127
128 fn deref(&self) -> &Self::Target {
129 &self.params
130 }
131}
132
133impl<F, S, D> core::ops::DerefMut for LayerBase<F, S, D>
134where
135 D: Dimension,
136 S: RawData,
137{
138 fn deref_mut(&mut self) -> &mut Self::Target {
139 &mut self.params
140 }
141}
142
143impl<U, V, F, S, D> Activator<U> for LayerBase<F, S, D>
144where
145 F: Activator<U, Output = V>,
146 D: Dimension,
147 S: RawData,
148{
149 type Output = V;
150
151 fn activate(&self, x: U) -> Self::Output {
152 self.rho().activate(x)
153 }
154}
155
156impl<U, F, S, D> ActivatorGradient<U> for LayerBase<F, S, D>
157where
158 F: ActivatorGradient<U>,
159 D: Dimension,
160 S: RawData,
161{
162 type Input = F::Input;
163 type Delta = F::Delta;
164
165 fn activate_gradient(&self, inputs: F::Input) -> F::Delta {
166 self.rho().activate_gradient(inputs)
167 }
168}
169
170impl<A, F, S, D> Layer<S, D> for LayerBase<F, S, D>
171where
172 F: ActivatorGradient<A>,
173 D: Dimension,
174 S: RawData<Elem = A>,
175{
176 type Scalar = A;
177
178 fn params(&self) -> &ParamsBase<S, D> {
179 &self.params
180 }
181
182 fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
183 &mut self.params
184 }
185}