concision_core/nn/layer/
impl_layer_repr.rs1use crate::activate::{Activator, HyperbolicTangent, Linear, ReLU, Sigmoid};
6use crate::nn::layer::LayerBase;
7use concision_params::{ParamsBase, RawParams};
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
9
10impl<F, S, D, A> LayerBase<F, ArrayBase<S, D, A>>
11where
12 F: Activator<A, Output = A>,
13 D: Dimension,
14 S: RawData<Elem = A>,
15{
16 pub fn from_rho_with_shape<Sh>(rho: F, shape: Sh) -> Self
18 where
19 A: Clone + Default,
20 S: DataOwned,
21 D: RemoveAxis,
22 Sh: ShapeBuilder<Dim = D>,
23 {
24 Self {
25 rho,
26 params: ArrayBase::default(shape),
27 }
28 }
29
30 pub fn dim(&self) -> D::Pattern {
31 self.params().dim()
32 }
33
34 pub fn raw_dim(&self) -> D {
35 self.params().raw_dim()
36 }
37
38 pub fn shape(&self) -> &[usize] {
39 self.params().shape()
40 }
41}
42
43impl<F, S, D, E, A> LayerBase<F, ParamsBase<S, D, A>>
44where
45 F: Activator<A, Output = A>,
46 D: Dimension<Smaller = E>,
47 E: Dimension<Larger = D>,
48 S: RawData<Elem = A>,
49{
50 pub fn from_rho_with_shape<Sh>(rho: F, shape: Sh) -> Self
52 where
53 A: Clone + Default,
54 S: DataOwned,
55 D: RemoveAxis,
56 Sh: ShapeBuilder<Dim = D>,
57 {
58 Self {
59 rho,
60 params: ParamsBase::default(shape),
61 }
62 }
63
64 pub const fn bias(&self) -> &ArrayBase<S, E, A> {
65 self.params().bias()
66 }
67
68 pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, E, A> {
69 self.params_mut().bias_mut()
70 }
71
72 pub const fn weights(&self) -> &ArrayBase<S, D, A> {
73 self.params().weights()
74 }
75
76 pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
77 self.params_mut().weights_mut()
78 }
79
80 pub fn dim(&self) -> D::Pattern {
81 self.params().dim()
82 }
83
84 pub fn raw_dim(&self) -> D {
85 self.params().raw_dim()
86 }
87
88 pub fn shape(&self) -> &[usize] {
89 self.params().shape()
90 }
91}
92
93impl<F, P, A> LayerBase<F, P>
94where
95 F: Fn(A) -> A,
96 P: RawParams<Elem = A>,
97{
98}
99
100impl<A, P> LayerBase<Linear, P>
101where
102 P: RawParams<Elem = A>,
103{
104 pub const fn linear(params: P) -> Self {
106 Self {
107 rho: Linear,
108 params,
109 }
110 }
111}
112
113impl<A, P> LayerBase<Sigmoid, P>
114where
115 P: RawParams<Elem = A>,
116{
117 pub const fn sigmoid(params: P) -> Self {
119 Self {
120 rho: Sigmoid,
121 params,
122 }
123 }
124}
125
126impl<A, P> LayerBase<HyperbolicTangent, P>
127where
128 P: RawParams<Elem = A>,
129{
130 pub const fn tanh(params: P) -> Self {
133 Self {
134 rho: HyperbolicTangent,
135 params,
136 }
137 }
138}
139
140impl<A, P> LayerBase<ReLU, P>
141where
142 P: RawParams<Elem = A>,
143{
144 pub const fn relu(params: P) -> Self {
146 Self { rho: ReLU, params }
147 }
148}