concision_core/models/traits/
hidden.rs

/*
    appellation: hidden <module>
    authors: @FL03
*/
5use concision_params::ParamsBase;
6use ndarray::{Data, Dimension, RawData};
7
8/// The [`RawHidden`] trait for compatible representations of hidden layers
9pub trait RawHidden<S, D>
10where
11    S: RawData,
12    D: Dimension,
13{
14    private! {}
15
16    fn count(&self) -> usize;
17}
18
19/// The [`ShallowModelRepr`] trait for shallow neural networks
20pub trait ShallowModelRepr<S, D>: RawHidden<S, D>
21where
22    S: RawData,
23    D: Dimension,
24{
25    private! {}
26}
27/// The [`DeepModelRepr`] trait for deep neural networks
28pub trait DeepModelRepr<S, D>: RawHidden<S, D>
29where
30    S: RawData,
31    D: Dimension,
32    Self:
33        IntoIterator<Item = ParamsBase<S, D>> + core::ops::Index<usize, Output = ParamsBase<S, D>>,
34{
35    private! {}
36
37    /// returns the hidden layers as a slice
38    fn as_slice(&self) -> &[ParamsBase<S, D>];
39
40    /// returns the hidden layers as a mutable slice
41    fn as_mut_slice(&mut self) -> &mut [ParamsBase<S, D>];
42}
43
/*
 ************* Implementations *************
*/
47
48impl<X, A, S, D> DeepModelRepr<S, D> for X
49where
50    S: RawData<Elem = A>,
51    D: Dimension,
52    X: RawHidden<S, D>
53        + IntoIterator<Item = ParamsBase<S, D>>
54        + AsRef<[ParamsBase<S, D>]>
55        + AsMut<[ParamsBase<S, D>]>
56        + core::ops::Index<usize, Output = ParamsBase<S, D>>,
57{
58    seal!();
59
60    fn as_slice(&self) -> &[ParamsBase<S, D>] {
61        self.as_ref()
62    }
63
64    fn as_mut_slice(&mut self) -> &mut [ParamsBase<S, D>] {
65        self.as_mut()
66    }
67}
68
69impl<S, D, T> RawHidden<S, D> for &T
70where
71    D: Dimension,
72    S: RawData,
73    T: RawHidden<S, D>,
74{
75    seal!();
76
77    fn count(&self) -> usize {
78        RawHidden::count(*self)
79    }
80}
81
82impl<S, D, T> RawHidden<S, D> for &mut T
83where
84    D: Dimension,
85    S: RawData,
86    T: RawHidden<S, D>,
87{
88    seal!();
89
90    fn count(&self) -> usize {
91        RawHidden::count(*self)
92    }
93}
94
95impl<A, S, D, const N: usize> RawHidden<S, D> for [ParamsBase<S, D>; N]
96where
97    D: Dimension,
98    S: Data<Elem = A>,
99{
100    seal!();
101
102    fn count(&self) -> usize {
103        N
104    }
105}
106
/// Generates [`RawHidden`] implementations for parameter containers.
///
/// - `#[count = 1] Type` treats `Type` as a single hidden layer. It implements
///   [`RawHidden`] with a constant count of one, plus [`ShallowModelRepr`].
/// - `#[count = len] Type` treats `Type` as a collection of layers. It
///   implements [`RawHidden`] with `count` delegating to `Type::len`.
///
/// The type may refer to the generics `S` and `D` introduced by the generated
/// `impl<S, D>` headers.
macro_rules! impl_raw_hidden_params {
    (#[count = 1] $target:ty) => {
        impl<S, D> RawHidden<S, D> for $target
        where
            D: Dimension,
            S: RawData,
        {
            seal!();

            fn count(&self) -> usize {
                1
            }
        }

        impl<S, D> ShallowModelRepr<S, D> for $target
        where
            D: Dimension,
            S: RawData,
        {
            seal!();
        }
    };
    (#[count = len] $target:ty) => {
        impl<S, D> RawHidden<S, D> for $target
        where
            D: Dimension,
            S: RawData,
        {
            seal!();

            fn count(&self) -> usize {
                self.len()
            }
        }
    };
}
143
144impl_raw_hidden_params! {
145    #[count = 1]
146    ParamsBase<S, D>
147}
148
149impl_raw_hidden_params! {
150    #[count = len]
151    [ParamsBase<S, D>]
152}
153
154impl_raw_hidden_params! {
155    #[count = len]
156    Vec<ParamsBase<S, D>>
157}
158
159impl_raw_hidden_params! {
160    #[count = len]
161    std::collections::HashSet<ParamsBase<S, D>>
162}