concision_neural/params/model_params.rs
/*
    Appellation: store <module>
    Contrib: @FL03
*/

use cnc::params::ParamsBase;
use ndarray::{ArrayBase, Dimension, RawData};

use crate::{DeepModelRepr, RawHidden};

/// The [`ModelParamsBase`] object is a generic container for storing the parameters of a
/// neural network, regardless of its layout (e.g. shallow or deep). This is made possible by
/// the generic hidden-layer type, `H`, which lets us define aliases and additional traits for
/// constraining the hidden layer type. That said, we don't recommend using this type
/// directly; prefer the provided type aliases such as [`DeepModelParams`] or
/// [`ShallowModelParams`], or their owned variants, which provide a much more straightforward
/// interface for typing the parameters of a neural network. We aren't too worried about
/// converting between the two layouts, since users who need that ability should simply stick
/// with a _deep_ representation and initialize only a single layer within the respective
/// container.
///
/// This type also enables us to define a set of common initialization routines and to
/// introduce other standards for dealing with the parameters of a neural network.
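///
/// A minimal usage sketch (not compiled as a doc-test); the `input`, `hidden`, and `output`
/// values are assumed to be layer parameters built elsewhere:
///
/// ```ignore
/// fn describe<S, D, H, A>(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>)
/// where
///     D: Dimension,
///     S: RawData<Elem = A>,
///     H: RawHidden<S, D>,
/// {
///     // assemble the full parameter stack from its three parts
///     let params = ModelParamsBase::new(input, hidden, output);
///     // the total layer count is the hidden layers plus the input and output layers
///     assert_eq!(params.len(), params.count_hidden() + 2);
///     // a stack is either shallow (at most one hidden layer) or deep (more than one)
///     assert!(params.is_shallow() != params.is_deep());
/// }
/// ```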
pub struct ModelParamsBase<S, D, H>
where
    D: Dimension,
    S: RawData,
    H: RawHidden<S, D>,
{
    /// the input layer of the model
    pub(crate) input: ParamsBase<S, D>,
    /// a sequential stack of params for the model's hidden layers
    pub(crate) hidden: H,
    /// the output layer of the model
    pub(crate) output: ParamsBase<S, D>,
}
/// The base implementation for the [`ModelParamsBase`] type, which is generic over the
/// storage type `S`, the dimension `D`, and the hidden layer type `H`. This implementation
/// focuses on providing basic initialization routines and accessors for the various layers
/// within the model.
impl<S, D, H, A> ModelParamsBase<S, D, H>
where
    D: Dimension,
    S: RawData<Elem = A>,
    H: RawHidden<S, D>,
{
    /// returns a new instance of [`ModelParamsBase`] assembled from the given input, hidden,
    /// and output layers
    pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
        Self {
            input,
            hidden,
            output,
        }
    }
    /// returns an immutable reference to the input layer of the model
    pub const fn input(&self) -> &ParamsBase<S, D> {
        &self.input
    }
    /// returns a mutable reference to the input layer of the model
    pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.input
    }
    /// returns an immutable reference to the hidden layers of the model
    pub const fn hidden(&self) -> &H {
        &self.hidden
    }
    /// returns a mutable reference to the hidden layers of the model
    pub const fn hidden_mut(&mut self) -> &mut H {
        &mut self.hidden
    }
    /// returns an immutable reference to the output layer of the model
    pub const fn output(&self) -> &ParamsBase<S, D> {
        &self.output
    }
    /// returns a mutable reference to the output layer of the model
    pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.output
    }
    /// set the input layer of the model
    #[inline]
    pub fn set_input(&mut self, input: ParamsBase<S, D>) -> &mut Self {
        *self.input_mut() = input;
        self
    }
    /// set the hidden layers of the model
    #[inline]
    pub fn set_hidden(&mut self, hidden: H) -> &mut Self {
        *self.hidden_mut() = hidden;
        self
    }
    /// set the output layer of the model
    #[inline]
    pub fn set_output(&mut self, output: ParamsBase<S, D>) -> &mut Self {
        *self.output_mut() = output;
        self
    }
    /// consumes the current instance and returns another with the specified input layer
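    ///
    /// A small sketch (not compiled as a doc-test) contrasting the consuming `with_*`
    /// builders with the in-place `set_*` setters; `params`, `new_input`, `other_input`, and
    /// `other_output` are assumed to be pre-built values:
    ///
    /// ```ignore
    /// // `with_input` consumes the stack and returns a new one ...
    /// let mut params = params.with_input(new_input);
    /// // ... while the `set_*` methods mutate in place and return `&mut Self` for chaining
    /// params.set_input(other_input).set_output(other_output);
    /// ```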
    #[inline]
    pub fn with_input(self, input: ParamsBase<S, D>) -> Self {
        Self { input, ..self }
    }
    /// consumes the current instance and returns another with the specified hidden
    /// layer(s)
    #[inline]
    pub fn with_hidden(self, hidden: H) -> Self {
        Self { hidden, ..self }
    }
    /// consumes the current instance and returns another with the specified output layer
    #[inline]
    pub fn with_output(self, output: ParamsBase<S, D>) -> Self {
        Self { output, ..self }
    }
    /// returns an immutable reference to the hidden layers of the model as a slice
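    ///
    /// A brief sketch (not compiled as a doc-test); it assumes `params` is a pre-built stack
    /// whose hidden representation implements [`DeepModelRepr`]:
    ///
    /// ```ignore
    /// // iterate the hidden layers and touch each layer's weights
    /// for layer in params.hidden_as_slice() {
    ///     let _ = layer.weights();
    /// }
    /// ```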
    #[inline]
    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
    where
        H: DeepModelRepr<S, D>,
    {
        self.hidden().as_slice()
    }
    /// returns an immutable reference to the input bias
    pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller> {
        self.input().bias()
    }
    /// returns a mutable reference to the input bias
    pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
        self.input_mut().bias_mut()
    }
    /// returns an immutable reference to the input weights
    pub const fn input_weights(&self) -> &ArrayBase<S, D> {
        self.input().weights()
    }
    /// returns a mutable reference to the input weights
    pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
        self.input_mut().weights_mut()
    }
    /// returns an immutable reference to the output bias
    pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller> {
        self.output().bias()
    }
    /// returns a mutable reference to the output bias
    pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
        self.output_mut().bias_mut()
    }
    /// returns an immutable reference to the output weights
    pub const fn output_weights(&self) -> &ArrayBase<S, D> {
        self.output().weights()
    }
    /// returns a mutable reference to the output weights
    pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
        self.output_mut().weights_mut()
    }
    /// returns the number of hidden layers in the model
    pub fn count_hidden(&self) -> usize {
        self.hidden().count()
    }
    /// returns true if the stack is shallow; a neural network is considered to be _shallow_
    /// if it has at most one hidden layer (`n <= 1`).
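    ///
    /// A quick sketch of the boundary case (not compiled as a doc-test); `params_one_hidden`
    /// is an assumed, pre-built stack containing exactly one hidden layer:
    ///
    /// ```ignore
    /// // a single hidden layer still counts as shallow, not deep
    /// assert!(params_one_hidden.is_shallow());
    /// assert!(!params_one_hidden.is_deep());
    /// ```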
    #[inline]
    pub fn is_shallow(&self) -> bool {
        self.count_hidden() <= 1
    }
    /// returns true if the model's stack of parameters is considered to be _deep_, meaning
    /// that the number of hidden layers is greater than one (`n > 1`).
    #[inline]
    pub fn is_deep(&self) -> bool {
        self.count_hidden() > 1
    }
    /// returns the total number of layers within the model, including the input and output
    /// layers
    #[inline]
    pub fn len(&self) -> usize {
        self.count_hidden() + 2 // +2 for the input and output layers
    }
}