concision_neural/layout/format.rs
1/*
2 appellation: format <module>
3 authors: @FL03
4*/
5
6/// The [`ModelFormat`] type enumerates the various formats a neural network may take, either
7/// shallow or deep, providing a unified interface for accessing the number of hidden features
8/// and layers in the model. This is done largely for simplicity, as it eliminates the need to
9/// define a particular _type_ of network as its composition has little impact on the actual
10/// requirements / algorithms used to train or evaluate the model (that is, outside of the
11/// obvious need to account for additional hidden layers in deep configurations). In other
12/// words, both shallow and deep networks are requried to implement the same traits and
13/// fulfill the same requirements, so it makes sense to treat them as a single type with
14/// different configurations. The differences between the networks are largely left to the
15/// developer and their choice of activation functions, optimizers, and other considerations.
16#[derive(
17 Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, strum::EnumCount, strum::EnumIs,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
20pub enum ModelFormat {
21 Shallow { hidden: usize },
22 Deep { hidden: usize, layers: usize },
23}
24
25impl ModelFormat {
26 /// initialize a new [`Deep`](ModelFormat::Deep) variant for a deep neural network with the
27 /// given number of hidden features and layers
28 pub const fn deep(hidden: usize, layers: usize) -> Self {
29 ModelFormat::Deep { hidden, layers }
30 }
31 /// create a new instance of [`ModelFormat`] for a shallow neural network, using the given
32 /// number of hidden features
33 pub const fn shallow(hidden: usize) -> Self {
34 ModelFormat::Shallow { hidden }
35 }
36 /// returns a copy of the number of hidden features
37 pub const fn hidden(&self) -> usize {
38 match self {
39 ModelFormat::Shallow { hidden } => *hidden,
40 ModelFormat::Deep { hidden, .. } => *hidden,
41 }
42 }
43 /// returns a mutable reference to the hidden features for the model
44 pub const fn hidden_mut(&mut self) -> &mut usize {
45 match self {
46 ModelFormat::Shallow { hidden } => hidden,
47 ModelFormat::Deep { hidden, .. } => hidden,
48 }
49 }
50 /// returns a copy of the number of layers for the model; if the variant is
51 /// [`Shallow`](ModelFormat::Shallow), it returns 1
52 /// returns `n` if the variant is [`Deep`](ModelFormat::Deep)
53 pub const fn layers(&self) -> usize {
54 match self {
55 ModelFormat::Shallow { .. } => 1,
56 ModelFormat::Deep { layers, .. } => *layers,
57 }
58 }
59 /// returns a mutable reference to the number of layers for the model; this will panic on
60 /// [`Shallow`](ModelFormat::Shallow) variants
61 pub const fn layers_mut(&mut self) -> &mut usize {
62 match self {
63 ModelFormat::Shallow { .. } => panic!("Cannot mutate layers of a shallow model"),
64 ModelFormat::Deep { layers, .. } => layers,
65 }
66 }
67 /// update the number of hidden features for the model
68 pub fn set_hidden(&mut self, value: usize) -> &mut Self {
69 match self {
70 ModelFormat::Shallow { hidden } => {
71 *hidden = value;
72 }
73 ModelFormat::Deep { hidden, .. } => {
74 *hidden = value;
75 }
76 }
77 self
78 }
79 /// update the number of layers for the model;
80 ///
81 /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
82 /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
83 /// of layers becomes greater than 1
84 pub fn set_layers(&mut self, value: usize) -> &mut Self {
85 match self {
86 ModelFormat::Shallow { hidden } => {
87 if value > 1 {
88 *self = ModelFormat::Deep {
89 hidden: *hidden,
90 layers: value,
91 };
92 }
93 // if the value is 1, we do not change the model format
94 }
95 ModelFormat::Deep { layers, .. } => {
96 *layers = value;
97 }
98 }
99 self
100 }
101 /// consumes the current instance and returns a new instance with the given hidden
102 /// features
103 pub fn with_hidden(self, hidden: usize) -> Self {
104 match self {
105 ModelFormat::Shallow { .. } => ModelFormat::Shallow { hidden },
106 ModelFormat::Deep { layers, .. } => ModelFormat::Deep { hidden, layers },
107 }
108 }
109 /// consumes the current instance and returns a new instance with the given number of
110 /// hidden layers
111 ///
112 /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
113 /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
114 /// of layers becomes greater than 1
115 pub fn with_layers(self, layers: usize) -> Self {
116 match self {
117 ModelFormat::Shallow { hidden } => {
118 if layers > 1 {
119 ModelFormat::Deep { hidden, layers }
120 } else {
121 ModelFormat::Shallow { hidden }
122 }
123 }
124 ModelFormat::Deep { hidden, .. } => ModelFormat::Deep { hidden, layers },
125 }
126 }
127}
128
129impl Default for ModelFormat {
130 fn default() -> Self {
131 Self::Deep {
132 hidden: 16,
133 layers: 1,
134 }
135 }
136}
137
138impl core::fmt::Display for ModelFormat {
139 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
140 write!(
141 f,
142 "{{ hidden: {}, layers: {} }}",
143 self.hidden(),
144 self.layers()
145 )
146 }
147}