mod impl_layer;
mod impl_layer_ext;
mod impl_layer_repr;
#[doc(inline)]
pub use self::types::*;
/// A layer pairing an activation component `rho` with its parameters `params`.
///
/// `F` is the activation component: the aliases in the `types` module of this file
/// instantiate it with `Linear`, `Sigmoid`, `HyperbolicTangent`, `ReLU`, `HeavySide`, or a
/// boxed activator / closure (behind the `alloc` feature). `P` holds the layer's parameters,
/// e.g. `ParamsBase` / `Params` via the `LayerParamsBase` / `LayerParams` aliases.
///
/// Note: the derived `PartialOrd`/`Ord` compare fields in declaration order (`rho` first,
/// then `params`), and the derived `Default` requires both `F: Default` and `P: Default`.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
// serde (de)serialization is only derived when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LayerBase<F, P> {
/// The activation component of the layer (type `F`).
pub rho: F,
/// The layer's parameters (type `P`).
pub params: P,
}
mod types {
use super::LayerBase;
use crate::activate::{HeavySide, HyperbolicTangent, Linear, ReLU, Sigmoid};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use concision_params::{Params, ParamsBase};
pub type LayerParamsBase<F, S, D = ndarray::Ix2, A = f32> = LayerBase<F, ParamsBase<S, D, A>>;
pub type LayerParams<F, A = f32, D = ndarray::Ix2> = LayerBase<F, Params<A, D>>;
pub type LinearLayer<T> = LayerBase<Linear, T>;
pub type SigmoidLayer<T> = LayerBase<Sigmoid, T>;
pub type TanhLayer<T> = LayerBase<HyperbolicTangent, T>;
pub type ReluLayer<T> = LayerBase<ReLU, T>;
pub type HeavySideLayer<T> = LayerBase<HeavySide, T>;
#[cfg(feature = "alloc")]
pub type LayerDyn<'a, T> =
LayerBase<Box<dyn crate::activate::Activator<T, Output = T> + 'a>, T>;
#[cfg(feature = "alloc")]
pub type FnLayer<'a, T> = LayerBase<Box<dyn Fn(T) -> T + 'a>, T>;
}
#[cfg(test)]
mod tests {
    //! Forward-pass checks for `LayerBase` using a 3x2 parameter set filled with `0.5`
    //! and the inputs `linspace(1.0, 2.0, 3) = [1.0, 1.5, 2.0]`.
    //!
    //! `test_linear_layer` pins the pre-activation output of every unit at
    //! [`PRE_ACTIVATION`]; every other test expects its activation to be applied exactly
    //! once to that value.
    use super::*;
    use concision_params::Params;
    use ndarray::Array1;

    /// Output of the linear (identity-activation) layer for the shared fixture.
    const PRE_ACTIVATION: f32 = 2.75;

    #[test]
    fn test_func_layer() {
        let params = Params::<f32>::from_elem((3, 2), 0.5);
        let layer = LayerBase::new(|x: Array1<f32>| x.pow2(), params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        // The closure squares the pre-activation once: 2.75^2 = 7.5625 (exact in f32).
        // The previous expectation squared 7.5625 a second time.
        let expected = Array1::from_elem(2, PRE_ACTIVATION * PRE_ACTIVATION);
        assert_eq!(layer.forward(&inputs), expected);
    }

    #[test]
    fn test_linear_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::linear(params);
        assert_eq!(layer.params().shape(), &[3, 2]);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.forward(&inputs), Array1::from_elem(2, PRE_ACTIVATION));
    }

    #[test]
    fn test_relu_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::relu(params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        // ReLU is the identity on the positive pre-activation.
        assert_eq!(layer.forward(&inputs), Array1::from_elem(2, PRE_ACTIVATION));
    }

    #[test]
    fn test_tanh_layer() {
        let params = Params::from_elem((3, 2), 0.5_f32);
        let layer = LayerBase::tanh(params);
        let inputs = Array1::<f32>::linspace(1.0, 2.0, 3);
        assert_eq!(layer.params().shape(), &[3, 2]);
        let y = layer.forward(&inputs);
        // tanh is applied once: tanh(2.75) ~= 0.99185973. The previous expectation
        // applied tanh to that already-activated value a second time.
        let expected = Array1::from_elem(2, PRE_ACTIVATION.tanh());
        assert!((y - expected).abs().iter().all(|&i| i < 1e-4));
    }
}