concision_neural/traits/
hidden.rs

1/*
2    appellation: hidden <module>
3    authors: @FL03
4*/
5use cnc::ParamsBase;
6use ndarray::{Data, Dimension, RawData};
7
8/// The [`RawHidden`] trait for compatible representations of hidden layers
9pub trait RawHidden<S, D>
10where
11    S: RawData,
12    D: Dimension,
13{
14    private!();
15
16    fn count(&self) -> usize;
17}
18
19/// The [`ShallowNeuralStore`] trait for shallow neural networks
20pub trait ShallowNeuralStore<S, D>: RawHidden<S, D>
21where
22    S: RawData,
23    D: Dimension,
24{
25    private!();
26}
27/// The [`DeepNeuralStore`] trait for deep neural networks
28pub trait DeepNeuralStore<S, D>: RawHidden<S, D>
29where
30    S: RawData,
31    D: Dimension,
32    Self: IntoIterator<Item = ParamsBase<S, D>>,
33{
34    private!();
35
36    /// returns the hidden layers as a slice
37    fn as_slice(&self) -> &[ParamsBase<S, D>];
38
39    /// returns the hidden layers as a mutable slice
40    fn as_mut_slice(&mut self) -> &mut [ParamsBase<S, D>];
41}
42
43/*
44 ************* Implementations *************
45*/
46
47impl<X, A, S, D> DeepNeuralStore<S, D> for X
48where
49    S: RawData<Elem = A>,
50    D: Dimension,
51    X: RawHidden<S, D>
52        + IntoIterator<Item = ParamsBase<S, D>>
53        + AsRef<[ParamsBase<S, D>]>
54        + AsMut<[ParamsBase<S, D>]>,
55{
56    seal!();
57
58    fn as_slice(&self) -> &[ParamsBase<S, D>] {
59        self.as_ref()
60    }
61
62    fn as_mut_slice(&mut self) -> &mut [ParamsBase<S, D>] {
63        self.as_mut()
64    }
65}
66
67impl<S, D, T> RawHidden<S, D> for &T
68where
69    D: Dimension,
70    S: RawData,
71    T: RawHidden<S, D>,
72{
73    seal!();
74
75    fn count(&self) -> usize {
76        RawHidden::count(*self)
77    }
78}
79
80impl<S, D, T> RawHidden<S, D> for &mut T
81where
82    D: Dimension,
83    S: RawData,
84    T: RawHidden<S, D>,
85{
86    seal!();
87
88    fn count(&self) -> usize {
89        RawHidden::count(*self)
90    }
91}
92
93impl<A, S, D, const N: usize> RawHidden<S, D> for [ParamsBase<S, D>; N]
94where
95    D: Dimension,
96    S: Data<Elem = A>,
97{
98    seal!();
99
100    fn count(&self) -> usize {
101        N
102    }
103}
104
/// Generates [`RawHidden`] implementations for parameter containers.
///
/// Two forms are accepted, selected by the leading pseudo-attribute:
///
/// - `#[count = 1] Type` — `Type` holds a single set of parameters. `count` always returns `1`,
///   and [`ShallowNeuralStore`] is implemented as well.
/// - `#[count = len] Type` — `Type` is a collection. `count` forwards to its `len()` method.
///
/// `Type` may mention the generic parameters `S` and `D`. The generated impls declare them
/// with the bounds `S: RawData` and `D: Dimension`.
macro_rules! impl_raw_hidden_params {
    (#[count = 1] $($target:tt)*) => {
        impl<S: RawData, D: Dimension> RawHidden<S, D> for $($target)* {
            seal!();

            fn count(&self) -> usize {
                1
            }
        }

        impl<S: RawData, D: Dimension> ShallowNeuralStore<S, D> for $($target)* {
            seal!();
        }
    };
    (#[count = len] $($target:tt)*) => {
        impl<S: RawData, D: Dimension> RawHidden<S, D> for $($target)* {
            seal!();

            fn count(&self) -> usize {
                self.len()
            }
        }
    };
}
141
142impl_raw_hidden_params! {
143    #[count = 1]
144    ParamsBase<S, D>
145}
146
147impl_raw_hidden_params! {
148    #[count = len]
149    [ParamsBase<S, D>]
150}
151
152impl_raw_hidden_params! {
153    #[count = len]
154    Vec<ParamsBase<S, D>>
155}
156
157impl_raw_hidden_params! {
158    #[count = len]
159    std::collections::HashSet<ParamsBase<S, D>>
160}