concision_core/models/impls/
impl_model_params.rs

1/*
2    appellation: impl_model_params <module>
3    authors: @FL03
4*/
5use crate::models::ModelParamsBase;
6
7use crate::{DeepModelRepr, RawHidden};
8use concision_params::ParamsBase;
9use ndarray::{ArrayBase, Data, Dimension, RawData, RawDataClone};
10
11/// The base implementation for the [`ModelParamsBase`] type, which is generic over the
12/// storage type `S`, the dimension `D`, and the hidden layer type `H`. This implementation
13/// focuses on providing basic initialization routines and accessors for the various layers
14/// within the model.
15impl<S, D, H, A> ModelParamsBase<S, D, H, A>
16where
17    D: Dimension,
18    S: RawData<Elem = A>,
19    H: RawHidden<S, D>,
20{
21    /// create a new instance of the [`ModelParamsBase`] instance
22    pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
23        Self {
24            input,
25            hidden,
26            output,
27        }
28    }
29    /// returns an immutable reference to the input layer of the model
30    pub const fn input(&self) -> &ParamsBase<S, D> {
31        &self.input
32    }
33    /// returns a mutable reference to the input layer of the model
34    pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
35        &mut self.input
36    }
37    /// returns an immutable reference to the hidden layers of the model
38    pub const fn hidden(&self) -> &H {
39        &self.hidden
40    }
41    /// returns a mutable reference to the hidden layers of the model
42    pub const fn hidden_mut(&mut self) -> &mut H {
43        &mut self.hidden
44    }
45    /// returns an immutable reference to the output layer of the model
46    pub const fn output(&self) -> &ParamsBase<S, D> {
47        &self.output
48    }
49    /// returns a mutable reference to the output layer of the model
50    pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
51        &mut self.output
52    }
53    /// set the input layer of the model
54    #[inline]
55    pub fn set_input(&mut self, input: ParamsBase<S, D>) {
56        *self.input_mut() = input
57    }
58    /// set the hidden layers of the model
59    #[inline]
60    pub fn set_hidden(&mut self, hidden: H) {
61        *self.hidden_mut() = hidden
62    }
63    /// set the output layer of the model
64    #[inline]
65    pub fn set_output(&mut self, output: ParamsBase<S, D>) {
66        *self.output_mut() = output
67    }
68    /// consumes the current instance and returns another with the specified input layer
69    #[inline]
70    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
71        Self { input, ..self }
72    }
73    /// consumes the current instance and returns another with the specified hidden
74    /// layer(s)
75    #[inline]
76    pub fn with_hidden(self, hidden: H) -> Self {
77        Self { hidden, ..self }
78    }
79    /// consumes the current instance and returns another with the specified output layer
80    #[inline]
81    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
82        Self { output, ..self }
83    }
84    /// returns an immutable reference to the hidden layers of the model as a slice
85    #[inline]
86    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
87    where
88        H: DeepModelRepr<S, D>,
89    {
90        self.hidden().as_slice()
91    }
92    /// returns an immutable reference to the input bias
93    pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
94        self.input().bias()
95    }
96    /// returns a mutable reference to the input bias
97    pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
98        self.input_mut().bias_mut()
99    }
100    /// returns an immutable reference to the input weights
101    pub const fn input_weights(&self) -> &ArrayBase<S, D, A> {
102        self.input().weights()
103    }
104    /// returns an mutable reference to the input weights
105    pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
106        self.input_mut().weights_mut()
107    }
108    /// returns an immutable reference to the output bias
109    pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
110        self.output().bias()
111    }
112    /// returns a mutable reference to the output bias
113    pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
114        self.output_mut().bias_mut()
115    }
116    /// returns an immutable reference to the output weights
117    pub const fn output_weights(&self) -> &ArrayBase<S, D, A> {
118        self.output().weights()
119    }
120    /// returns an mutable reference to the output weights
121    pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
122        self.output_mut().weights_mut()
123    }
124    /// returns the total number of layers in the model, including input, hidden, and output
125    pub fn layers(&self) -> usize {
126        2 + self.count_hidden()
127    }
128    /// returns the number of hidden layers in the model
129    pub fn count_hidden(&self) -> usize {
130        self.hidden().count()
131    }
132    /// returns true if the stack is shallow; a neural network is considered to be _shallow_ if
133    /// it has at most one hidden layer (`n <= 1`).
134    #[inline]
135    pub fn is_shallow(&self) -> bool {
136        self.count_hidden() <= 1
137    }
138    /// returns true if the model stack of parameters is considered to be _deep_, meaning that
139    /// there the number of hidden layers is greater than one.
140    #[inline]
141    pub fn is_deep(&self) -> bool {
142        self.count_hidden() > 1
143    }
144}
145
146impl<A, S, D, H> Clone for ModelParamsBase<S, D, H, A>
147where
148    D: Dimension,
149    H: RawHidden<S, D> + Clone,
150    S: RawDataClone<Elem = A>,
151    A: Clone,
152{
153    fn clone(&self) -> Self {
154        Self {
155            input: self.input().clone(),
156            hidden: self.hidden().clone(),
157            output: self.output().clone(),
158        }
159    }
160}
161
162impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H, A>
163where
164    D: Dimension,
165    H: RawHidden<S, D> + core::fmt::Debug,
166    S: Data<Elem = A>,
167    A: core::fmt::Debug,
168{
169    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170        f.debug_struct("ModelParams")
171            .field("input", self.input())
172            .field("hidden", self.hidden())
173            .field("output", self.output())
174            .finish()
175    }
176}
177
178impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H, A>
179where
180    D: Dimension,
181    H: RawHidden<S, D> + core::fmt::Debug,
182    S: Data<Elem = A>,
183    A: core::fmt::Display,
184{
185    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
186        write!(
187            f,
188            "{{ input: {i}, hidden: {h:?}, output: {o} }}",
189            i = self.input(),
190            h = self.hidden(),
191            o = self.output()
192        )
193    }
194}
195
196impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H, A>
197where
198    D: Dimension,
199    S: Data<Elem = A>,
200    H: RawHidden<S, D> + core::ops::Index<usize, Output = ParamsBase<S, D>>,
201    A: Clone,
202{
203    type Output = ParamsBase<S, D>;
204
205    fn index(&self, index: usize) -> &Self::Output {
206        match index % self.layers() {
207            0 => self.input(),
208            i if i == self.count_hidden() + 1 => self.output(),
209            _ => &self.hidden()[index - 1],
210        }
211    }
212}
213
214impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H, A>
215where
216    D: Dimension,
217    S: Data<Elem = A>,
218    H: RawHidden<S, D> + core::ops::IndexMut<usize, Output = ParamsBase<S, D>>,
219    A: Clone,
220{
221    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
222        match index % self.layers() {
223            0 => self.input_mut(),
224            i if i == self.count_hidden() + 1 => self.output_mut(),
225            _ => &mut self.hidden_mut()[index - 1],
226        }
227    }
228}