concision_core/nn/layer/
impl_layer_repr.rs

/*
    appellation: impl_layer_repr <module>
    authors: @FL03
*/
use crate::activate::{Activator, HyperbolicTangent, Linear, ReLU, Sigmoid};
use crate::nn::layer::LayerBase;
use concision_params::{ParamsBase, RawParams};
use ndarray::{ArrayBase, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};

impl<F, S, D, A> LayerBase<F, ArrayBase<S, D>>
where
    F: Activator<A, Output = A>,
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// create a new instance from the given activation function and shape.
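    ///
    /// A minimal usage sketch (not from the crate's docs; the `ReLU`/`f64` pairing
    /// and the `(2, 3)` shape are illustrative assumptions):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// // builds an owned 2x3 layer filled with the element type's default value
    /// let layer: LayerBase<ReLU, Array2<f64>> = LayerBase::from_rho_with_shape(ReLU, (2, 3));
    /// assert_eq!(layer.shape(), &[2, 3]);
    /// ```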
    pub fn from_rho_with_shape<Sh>(rho: F, shape: Sh) -> Self
    where
        A: Clone + Default,
        S: DataOwned,
        D: RemoveAxis,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self {
            rho,
            params: ArrayBase::default(shape),
        }
    }

    /// returns the dimensions of the parameters in their pattern form.
    pub fn dim(&self) -> D::Pattern {
        self.params().dim()
    }

    /// returns the raw dimension of the parameters.
    pub fn raw_dim(&self) -> D {
        self.params().raw_dim()
    }

    /// returns the shape of the parameters as a slice of `usize`.
    pub fn shape(&self) -> &[usize] {
        self.params().shape()
    }
}

impl<F, S, D, E, A> LayerBase<F, ParamsBase<S, D>>
where
    F: Activator<A, Output = A>,
    D: Dimension<Smaller = E>,
    E: Dimension<Larger = D>,
    S: RawData<Elem = A>,
{
    /// create a new layer from the given activation function and shape.
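    ///
    /// A minimal usage sketch (not from the crate's docs; the `Sigmoid`/`f64` pairing,
    /// the `(2, 3)` shape, and the owned storage are illustrative assumptions):
    ///
    /// ```ignore
    /// use ndarray::{Ix2, OwnedRepr};
    ///
    /// let layer: LayerBase<Sigmoid, ParamsBase<OwnedRepr<f64>, Ix2>> =
    ///     LayerBase::from_rho_with_shape(Sigmoid, (2, 3));
    /// // the weights carry the full dimension `D`; the bias lives in the smaller `E`
    /// assert_eq!(layer.weights().shape(), &[2, 3]);
    /// let _bias = layer.bias();
    /// ```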
    pub fn from_rho_with_shape<Sh>(rho: F, shape: Sh) -> Self
    where
        A: Clone + Default,
        S: DataOwned,
        D: RemoveAxis,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self {
            rho,
            params: ParamsBase::default(shape),
        }
    }

    /// returns an immutable reference to the bias of the layer.
    pub const fn bias(&self) -> &ArrayBase<S, E> {
        self.params().bias()
    }

    /// returns a mutable reference to the bias of the layer.
    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, E> {
        self.params_mut().bias_mut()
    }

    /// returns an immutable reference to the weights of the layer.
    pub const fn weights(&self) -> &ArrayBase<S, D> {
        self.params().weights()
    }

    /// returns a mutable reference to the weights of the layer.
    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
        self.params_mut().weights_mut()
    }

    /// returns the dimensions of the parameters in their pattern form.
    pub fn dim(&self) -> D::Pattern {
        self.params().dim()
    }

    /// returns the raw dimension of the parameters.
    pub fn raw_dim(&self) -> D {
        self.params().raw_dim()
    }

    /// returns the shape of the parameters as a slice of `usize`.
    pub fn shape(&self) -> &[usize] {
        self.params().shape()
    }
}

impl<F, P, A> LayerBase<F, P>
where
    F: Fn(A) -> A,
    P: RawParams<Elem = A>,
{
}

impl<A, P> LayerBase<Linear, P>
where
    P: RawParams<Elem = A>,
{
    /// initialize a layer using the [`Linear`] activation function and the given params.
    pub const fn linear(params: P) -> Self {
        Self {
            rho: Linear,
            params,
        }
    }
}

impl<A, P> LayerBase<Sigmoid, P>
where
    P: RawParams<Elem = A>,
{
    /// initialize a layer using the [`Sigmoid`] activation function and the given params.
    pub const fn sigmoid(params: P) -> Self {
        Self {
            rho: Sigmoid,
            params,
        }
    }
}

impl<A, P> LayerBase<HyperbolicTangent, P>
where
    P: RawParams<Elem = A>,
{
    /// initialize a layer using the [`HyperbolicTangent`] activation function and the given
    /// params.
    pub const fn tanh(params: P) -> Self {
        Self {
            rho: HyperbolicTangent,
            params,
        }
    }
}

impl<A, P> LayerBase<ReLU, P>
where
    P: RawParams<Elem = A>,
{
    /// initialize a layer using the [`ReLU`] activation function and the given params.
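    ///
    /// A minimal usage sketch (not from the crate's docs; the owned `Array2<f64>`
    /// params are an illustrative assumption):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// let layer = LayerBase::relu(Array2::<f64>::zeros((2, 3)));
    /// ```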
    pub const fn relu(params: P) -> Self {
        Self { rho: ReLU, params }
    }
}