// concision_neural/layer/layer.rs
#![allow(unused)]
6use super::{Activate, Layer};
7use cnc::{Forward, ParamsBase, activate};
8use ndarray::{Dimension, Ix2, RawData};
9
10pub struct LayerBase<F, S, D = Ix2>
11where
12 D: Dimension,
13 S: RawData,
14{
15 pub(crate) rho: F,
16 pub(crate) params: ParamsBase<S, D>,
17}
18
/// Marker type for the linear (identity) activation function.
///
/// The type carries no state, so it is free to copy, compare, hash and
/// default-construct.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct Linear;
20
21impl<U> Activate<U> for Linear {
22 type Output = U;
23 fn activate(&self, x: U) -> Self::Output {
24 x
25 }
26}
27
28impl<S, D> LayerBase<Linear, S, D>
29where
30 D: Dimension,
31 S: RawData<Elem = f32>,
32{
33 pub fn linear(params: ParamsBase<S, D>) -> Self {
34 Self {
35 rho: Linear,
36 params,
37 }
38 }
39}
40
41impl<F, S, A, D> LayerBase<F, S, D>
42where
43 D: Dimension,
44 S: RawData<Elem = A>,
45{
46 pub fn new(rho: F, params: ParamsBase<S, D>) -> Self {
47 Self { rho, params }
48 }
49 pub fn params(&self) -> &ParamsBase<S, D> {
51 &self.params
52 }
53 pub fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
55 &mut self.params
56 }
57 pub fn rho(&self) -> &F {
59 &self.rho
60 }
61 pub fn rho_mut(&mut self) -> &mut F {
63 &mut self.rho
64 }
65 pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
68 where
69 G: Activate<S::Elem>,
70 F: Activate<S::Elem>,
71 S: RawData<Elem = A>,
72 {
73 LayerBase {
74 rho,
75 params: self.params,
76 }
77 }
78 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
79 where
80 F: Activate<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
81 ParamsBase<S, D>: Forward<X, Output = Y>,
82 X: Clone,
83 Y: Clone,
84 {
85 Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
86 }
87}
88
89impl<F, S, D> core::ops::Deref for LayerBase<F, S, D>
90where
91 D: Dimension,
92 S: RawData,
93{
94 type Target = ParamsBase<S, D>;
95
96 fn deref(&self) -> &Self::Target {
97 &self.params
98 }
99}
100
101impl<F, S, D> core::ops::DerefMut for LayerBase<F, S, D>
102where
103 D: Dimension,
104 S: RawData,
105{
106 fn deref_mut(&mut self) -> &mut Self::Target {
107 &mut self.params
108 }
109}