// concision_core/nn/layer.rs
mod impl_layer;
mod impl_layer_ext;
mod impl_layer_repr;

#[doc(inline)]
pub use self::types::*;

/// A generic neural-network layer pairing an activation function with a
/// parameter store.
///
/// * `F` — the activation ("rho") applied to the layer's output; see the
///   aliases in this module (`LinearLayer`, `SigmoidLayer`, `TanhLayer`, …)
///   for common instantiations.
/// * `P` — the parameter store, typically a `Params`/`ParamsBase` holding
///   the layer's weights and bias.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LayerBase<F, P> {
    /// The activation function applied to the layer's output.
    pub rho: F,
    /// The trainable parameters of the layer.
    pub params: P,
}
23
mod types {
    //! Convenience type aliases for [`LayerBase`] paired with common
    //! parameter stores and activation functions.
    use super::LayerBase;
    use crate::activate::{HeavySide, HyperbolicTangent, Linear, ReLU, Sigmoid};
    #[cfg(feature = "alloc")]
    use alloc::boxed::Box;
    use concision_params::{Params, ParamsBase};

    /// A layer whose parameters live in a `ParamsBase` with storage `S`,
    /// dimension `D` (default 2-D) and element type `A` (default `f32`).
    pub type LayerParamsBase<F, S, D = ndarray::Ix2, A = f32> = LayerBase<F, ParamsBase<S, D, A>>;
    /// A layer backed by owned `Params` (element type `A`, dimension `D`).
    pub type LayerParams<F, A = f32, D = ndarray::Ix2> = LayerBase<F, Params<A, D>>;
    /// A layer using the identity (linear) activation.
    pub type LinearLayer<T> = LayerBase<Linear, T>;
    /// A layer using the sigmoid activation.
    pub type SigmoidLayer<T> = LayerBase<Sigmoid, T>;
    /// A layer using the hyperbolic-tangent activation.
    pub type TanhLayer<T> = LayerBase<HyperbolicTangent, T>;
    /// A layer using the rectified-linear (ReLU) activation.
    pub type ReluLayer<T> = LayerBase<ReLU, T>;
    /// A layer using the Heaviside step activation.
    pub type HeavySideLayer<T> = LayerBase<HeavySide, T>;

    /// A layer whose activation is a boxed, dynamically dispatched activator
    /// trait object; requires the `alloc` feature.
    #[cfg(feature = "alloc")]
    pub type LayerDyn<'a, T> =
        LayerBase<Box<dyn crate::activate::Activator<T, Output = T> + 'a>, T>;
    /// A layer whose activation is an arbitrary boxed closure `T -> T`;
    /// requires the `alloc` feature.
    #[cfg(feature = "alloc")]
    pub type FnLayer<'a, T> = LayerBase<Box<dyn Fn(T) -> T + 'a>, T>;
}
54
#[cfg(test)]
mod tests {
    use super::*;
    use concision_params::Params;
    use ndarray::Array1;

    // All tests use W, b filled with 0.5 and x = linspace(1, 2, 3) =
    // [1.0, 1.5, 2.0], so the affine output per unit is
    // 0.5 * (1.0 + 1.5 + 2.0) + 0.5 = 2.75 (confirmed by test_linear_layer).

    #[test]
    fn test_func_layer() {
        let params = Params::<f32>::from_elem((3, 2), 0.5);
        let layer = LayerBase::new(|x: Array1<f32>| x.pow2(), params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        // rho squares the affine output once: 2.75^2 = 7.5625. The old
        // expectation (`from_elem(2, 7.5625).pow2()`) squared it twice.
        assert_eq!(layer.forward(&inputs), Array1::from_elem(2, 2.75).pow2());
    }

    #[test]
    fn test_linear_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::linear(params);
        assert_eq!(layer.params().shape(), &[3, 2]);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        // Identity activation: forward returns the affine output unchanged.
        assert_eq!(layer.forward(&inputs), Array1::from_elem(2, 2.75));
    }

    #[test]
    fn test_relu_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::relu(params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        // 2.75 > 0, so ReLU passes the affine output through unchanged.
        assert_eq!(layer.forward(&inputs), Array1::from_elem(2, 2.75));
    }

    #[test]
    fn test_tanh_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::tanh(params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        let y = layer.forward(&inputs);
        // forward already applies tanh; compare against tanh(2.75) directly.
        // The old expectation (`from_elem(2, 0.99185973).tanh()`) applied
        // tanh a second time to tanh(2.75) ≈ 0.99185973.
        let exp = Array1::from_elem(2, 2.75_f32).tanh();
        assert!((y - exp).abs().iter().all(|&i| i < 1e-4));
    }
}