concision_neural/layout/
format.rs

1/*
2    appellation: format <module>
3    authors: @FL03
4*/
5
6/// The [`ModelFormat`] type enumerates the various formats a neural network may take, either
7/// shallow or deep, providing a unified interface for accessing the number of hidden features
8/// and layers in the model. This is done largely for simplicity, as it eliminates the need to
9/// define a particular _type_ of network as its composition has little impact on the actual
10/// requirements / algorithms used to train or evaluate the model (that is, outside of the
11/// obvious need to account for additional hidden layers in deep configurations). In other
12/// words, both shallow and deep networks are requried to implement the same traits and
13/// fulfill the same requirements, so it makes sense to treat them as a single type with
14/// different configurations. The differences between the networks are largely left to the
15/// developer and their choice of activation functions, optimizers, and other considerations.
16#[derive(
17    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, strum::EnumCount, strum::EnumIs,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
20pub enum ModelFormat {
21    Shallow { hidden: usize },
22    Deep { hidden: usize, layers: usize },
23}
24
25impl ModelFormat {
26    /// initialize a new [`Deep`](ModelFormat::Deep) variant for a deep neural network with the
27    /// given number of hidden features and layers
28    pub const fn deep(hidden: usize, layers: usize) -> Self {
29        ModelFormat::Deep { hidden, layers }
30    }
31    /// create a new instance of [`ModelFormat`] for a shallow neural network, using the given
32    /// number of hidden features
33    pub const fn shallow(hidden: usize) -> Self {
34        ModelFormat::Shallow { hidden }
35    }
36    /// returns a copy of the number of hidden features
37    pub const fn hidden(&self) -> usize {
38        match self {
39            ModelFormat::Shallow { hidden } => *hidden,
40            ModelFormat::Deep { hidden, .. } => *hidden,
41        }
42    }
43    /// returns a mutable reference to the hidden features for the model
44    pub const fn hidden_mut(&mut self) -> &mut usize {
45        match self {
46            ModelFormat::Shallow { hidden } => hidden,
47            ModelFormat::Deep { hidden, .. } => hidden,
48        }
49    }
50    /// returns a copy of the number of layers for the model; if the variant is
51    /// [`Shallow`](ModelFormat::Shallow), it returns 1
52    /// returns `n` if the variant is [`Deep`](ModelFormat::Deep)
53    pub const fn layers(&self) -> usize {
54        match self {
55            ModelFormat::Shallow { .. } => 1,
56            ModelFormat::Deep { layers, .. } => *layers,
57        }
58    }
59    /// returns a mutable reference to the number of layers for the model; this will panic on
60    /// [`Shallow`](ModelFormat::Shallow) variants
61    pub const fn layers_mut(&mut self) -> &mut usize {
62        match self {
63            ModelFormat::Shallow { .. } => panic!("Cannot mutate layers of a shallow model"),
64            ModelFormat::Deep { layers, .. } => layers,
65        }
66    }
67    /// update the number of hidden features for the model
68    pub fn set_hidden(&mut self, value: usize) -> &mut Self {
69        match self {
70            ModelFormat::Shallow { hidden } => {
71                *hidden = value;
72            }
73            ModelFormat::Deep { hidden, .. } => {
74                *hidden = value;
75            }
76        }
77        self
78    }
79    /// update the number of layers for the model;
80    ///
81    /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
82    /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
83    /// of layers becomes greater than 1
84    pub fn set_layers(&mut self, value: usize) -> &mut Self {
85        match self {
86            ModelFormat::Shallow { hidden } => {
87                if value > 1 {
88                    *self = ModelFormat::Deep {
89                        hidden: *hidden,
90                        layers: value,
91                    };
92                }
93                // if the value is 1, we do not change the model format
94            }
95            ModelFormat::Deep { layers, .. } => {
96                *layers = value;
97            }
98        }
99        self
100    }
101    /// consumes the current instance and returns a new instance with the given hidden
102    /// features
103    pub fn with_hidden(self, hidden: usize) -> Self {
104        match self {
105            ModelFormat::Shallow { .. } => ModelFormat::Shallow { hidden },
106            ModelFormat::Deep { layers, .. } => ModelFormat::Deep { hidden, layers },
107        }
108    }
109    /// consumes the current instance and returns a new instance with the given number of
110    /// hidden layers
111    ///
112    /// **note:** this method will automatically convert the model to a [`Deep`](ModelFormat::Deep)
113    /// variant if it is currently a [`Shallow`](ModelFormat::Shallow) variant and the number
114    /// of layers becomes greater than 1
115    pub fn with_layers(self, layers: usize) -> Self {
116        match self {
117            ModelFormat::Shallow { hidden } => {
118                if layers > 1 {
119                    ModelFormat::Deep { hidden, layers }
120                } else {
121                    ModelFormat::Shallow { hidden }
122                }
123            }
124            ModelFormat::Deep { hidden, .. } => ModelFormat::Deep { hidden, layers },
125        }
126    }
127}
128
129impl Default for ModelFormat {
130    fn default() -> Self {
131        Self::Deep {
132            hidden: 16,
133            layers: 1,
134        }
135    }
136}
137
138impl core::fmt::Display for ModelFormat {
139    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
140        write!(
141            f,
142            "{{ hidden: {}, layers: {} }}",
143            self.hidden(),
144            self.layers()
145        )
146    }
147}