concision_neural/layers/
layer.rs1mod impl_layer;
10mod impl_layer_repr;
11
12#[allow(deprecated)]
13mod impl_layer_deprecated;
14
15use super::Activator;
16use cnc::{Forward, ParamsBase};
17use ndarray::{DataOwned, Dimension, Ix2, RawData, RemoveAxis, ShapeBuilder};
18
19pub struct LayerBase<F, S, D = Ix2>
22where
23 D: Dimension,
24 S: RawData,
25{
26 pub(crate) rho: F,
28 pub(crate) params: ParamsBase<S, D>,
30}
31
32impl<F, S, A, D> LayerBase<F, S, D>
33where
34 D: Dimension,
35 S: RawData<Elem = A>,
36{
37 pub const fn new(rho: F, params: ParamsBase<S, D>) -> Self {
39 Self { rho, params }
40 }
41 pub fn from_params(params: ParamsBase<S, D>) -> Self
44 where
45 F: Default,
46 {
47 Self {
48 rho: F::default(),
49 params,
50 }
51 }
52 pub fn from_rho<Sh>(rho: F, shape: Sh) -> Self
54 where
55 A: Clone + Default,
56 S: DataOwned,
57 D: RemoveAxis,
58 Sh: ShapeBuilder<Dim = D>,
59 {
60 Self {
61 rho,
62 params: ParamsBase::default(shape),
63 }
64 }
65 pub const fn params(&self) -> &ParamsBase<S, D> {
67 &self.params
68 }
69 pub const fn params_mut(&mut self) -> &mut ParamsBase<S, D> {
71 &mut self.params
72 }
73 pub const fn rho(&self) -> &F {
75 &self.rho
76 }
77 pub const fn rho_mut(&mut self) -> &mut F {
79 &mut self.rho
80 }
81 pub fn with_params<S2, D2>(self, params: ParamsBase<S2, D2>) -> LayerBase<F, S2, D2>
83 where
84 S2: RawData<Elem = S::Elem>,
85 D2: Dimension,
86 {
87 LayerBase {
88 rho: self.rho,
89 params,
90 }
91 }
92 pub fn with_rho<G>(self, rho: G) -> LayerBase<G, S, D>
95 where
96 G: Activator<S::Elem>,
97 F: Activator<S::Elem>,
98 S: RawData<Elem = A>,
99 {
100 LayerBase {
101 rho,
102 params: self.params,
103 }
104 }
105 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
106 where
107 F: Activator<<ParamsBase<S, D> as Forward<X>>::Output, Output = Y>,
108 ParamsBase<S, D>: Forward<X, Output = Y>,
109 X: Clone,
110 Y: Clone,
111 {
112 Forward::forward(&self.params, input).map(|x| self.rho.activate(x))
113 }
114}