concision_neural/params/
model_params.rs

1/*
2    Appellation: store <module>
3    Contrib: @FL03
4*/
5
6use cnc::params::ParamsBase;
7use ndarray::{ArrayBase, Dimension, RawData};
8
9use crate::{DeepModelRepr, RawHidden};
10
11/// The [`ModelParamsBase`] object is a generic container for storing the parameters of a
12/// neural network, regardless of the layout (e.g. shallow or deep). This is made possible
13/// through the introduction of a generic hidden layer type, `H`, that allows us to define
14/// aliases and additional traits for contraining the hidden layer type. That being said, we
15/// don't reccoment using this type directly, but rather use the provided type aliases such as
16/// [`DeepModelParams`] or [`ShallowModelParams`] or their owned variants. These provide a much
17/// more straighforward interface for typing the parameters of a neural network. We aren't too
18/// worried about the transmutation between the two since users desiring this ability should
19/// simply stick with a _deep_ representation, initializing only a single layer within the
20/// respective container.
21///
22/// This type also enables us to define a set of common initialization routines and introduce
23/// other standards for dealing with parameters in a neural network.
24pub struct ModelParamsBase<S, D, H>
25where
26    D: Dimension,
27    S: RawData,
28    H: RawHidden<S, D>,
29{
30    /// the input layer of the model
31    pub(crate) input: ParamsBase<S, D>,
32    /// a sequential stack of params for the model's hidden layers
33    pub(crate) hidden: H,
34    /// the output layer of the model
35    pub(crate) output: ParamsBase<S, D>,
36}
37/// The base implementation for the [`ModelParamsBase`] type, which is generic over the
38/// storage type `S`, the dimension `D`, and the hidden layer type `H`. This implementation
39/// focuses on providing basic initialization routines and accessors for the various layers
40/// within the model.
41impl<S, D, H, A> ModelParamsBase<S, D, H>
42where
43    D: Dimension,
44    S: RawData<Elem = A>,
45    H: RawHidden<S, D>,
46{
47    /// create a new instance of the [`ModelParamsBase`] instance
48    pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
49        Self {
50            input,
51            hidden,
52            output,
53        }
54    }
55    /// returns an immutable reference to the input layer of the model
56    pub const fn input(&self) -> &ParamsBase<S, D> {
57        &self.input
58    }
59    /// returns a mutable reference to the input layer of the model
60    pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
61        &mut self.input
62    }
63    /// returns an immutable reference to the hidden layers of the model
64    pub const fn hidden(&self) -> &H {
65        &self.hidden
66    }
67    /// returns a mutable reference to the hidden layers of the model
68    pub const fn hidden_mut(&mut self) -> &mut H {
69        &mut self.hidden
70    }
71    /// returns an immutable reference to the output layer of the model
72    pub const fn output(&self) -> &ParamsBase<S, D> {
73        &self.output
74    }
75    /// returns a mutable reference to the output layer of the model
76    pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
77        &mut self.output
78    }
79    /// set the input layer of the model
80    #[inline]
81    pub fn set_input(&mut self, input: ParamsBase<S, D>) -> &mut Self {
82        *self.input_mut() = input;
83        self
84    }
85    /// set the hidden layers of the model
86    #[inline]
87    pub fn set_hidden(&mut self, hidden: H) -> &mut Self {
88        *self.hidden_mut() = hidden;
89        self
90    }
91    /// set the output layer of the model
92    #[inline]
93    pub fn set_output(&mut self, output: ParamsBase<S, D>) -> &mut Self {
94        *self.output_mut() = output;
95        self
96    }
97    /// consumes the current instance and returns another with the specified input layer
98    #[inline]
99    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
100        Self { input, ..self }
101    }
102    /// consumes the current instance and returns another with the specified hidden
103    /// layer(s)
104    #[inline]
105    pub fn with_hidden(self, hidden: H) -> Self {
106        Self { hidden, ..self }
107    }
108    /// consumes the current instance and returns another with the specified output layer
109    #[inline]
110    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
111        Self { output, ..self }
112    }
113    /// returns an immutable reference to the hidden layers of the model as a slice
114    #[inline]
115    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
116    where
117        H: DeepModelRepr<S, D>,
118    {
119        self.hidden().as_slice()
120    }
121    /// returns an immutable reference to the input bias
122    pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller> {
123        self.input().bias()
124    }
125    /// returns a mutable reference to the input bias
126    pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
127        self.input_mut().bias_mut()
128    }
129    /// returns an immutable reference to the input weights
130    pub const fn input_weights(&self) -> &ArrayBase<S, D> {
131        self.input().weights()
132    }
133    /// returns an mutable reference to the input weights
134    pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
135        self.input_mut().weights_mut()
136    }
137    /// returns an immutable reference to the output bias
138    pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller> {
139        self.output().bias()
140    }
141    /// returns a mutable reference to the output bias
142    pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
143        self.output_mut().bias_mut()
144    }
145    /// returns an immutable reference to the output weights
146    pub const fn output_weights(&self) -> &ArrayBase<S, D> {
147        self.output().weights()
148    }
149    /// returns an mutable reference to the output weights
150    pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
151        self.output_mut().weights_mut()
152    }
153    /// returns the number of hidden layers in the model
154    pub fn count_hidden(&self) -> usize {
155        self.hidden().count()
156    }
157    /// returns true if the stack is shallow; a neural network is considered to be _shallow_ if
158    /// it has at most one hidden layer (`n <= 1`).
159    #[inline]
160    pub fn is_shallow(&self) -> bool {
161        self.count_hidden() <= 1
162    }
163    /// returns true if the model stack of parameters is considered to be _deep_, meaning that
164    /// there the number of hidden layers is greater than one.
165    #[inline]
166    pub fn is_deep(&self) -> bool {
167        self.count_hidden() > 1
168    }
169    /// returns the total number of layers within the model, including the input and output layers
170    #[inline]
171    pub fn len(&self) -> usize {
172        self.count_hidden() + 2 // +2 for input and output layers
173    }
174}