concision_core/models/
layout.rs

1/*
2    appellation: layout <module>
3    authors: @FL03
4*/
5use super::{Deep, NetworkDepth, RawModelLayout};
6
7mod impl_model_features;
8mod impl_model_format;
9mod impl_model_layout;
10
/// A consuming conversion into a [`ModelFeatures`] instance.
///
/// Implementors take `self` by value and return the [`ModelFeatures`] it describes.
pub trait IntoModelFeatures {
    /// Consumes `self` and returns the corresponding [`ModelFeatures`].
    fn into_model_features(self) -> ModelFeatures;
}
15
/// The [`ModelFormat`] type enumerates the various formats a neural network may take, either
/// shallow or deep, providing a unified interface for accessing the number of hidden features
/// and layers in the model. This is primarily used to generalize the allowed formats of a
/// neural network without introducing any additional complexity with typing or other
/// constructs.
///
/// The enum is `Copy` and, when the `serde` feature is enabled, also derives
/// `Serialize` / `Deserialize`.
#[derive(
    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, strum::EnumCount, strum::EnumIs,
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ModelFormat {
    // NOTE(review): presumably denotes a single layer with no hidden features -- confirm
    // against the `impl_model_format` module.
    /// A format that carries neither a hidden-feature count nor a layer count.
    Layer,
    /// A shallow format described only by its number of `hidden` features.
    Shallow { hidden: usize },
    // NOTE(review): `layers` presumably counts the hidden layers -- confirm.
    /// A deep format described by its number of `hidden` features and its number of `layers`.
    Deep { hidden: usize, layers: usize },
}
30
/// The [`ModelFeatures`] provides a common way of defining the layout of a model. This is
/// used to define the number of input features, the number of hidden layers, the number of
/// hidden features, and the number of output features.
///
/// The hidden feature and layer counts are not stored directly; they are carried by the
/// [`ModelFormat`] held in the `inner` field.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ModelFeatures {
    /// the number of input features
    pub(crate) input: usize,
    /// the features of the "inner" layers, expressed as a [`ModelFormat`]
    pub(crate) inner: ModelFormat,
    /// the number of output features
    pub(crate) output: usize,
}
44
/// In contrast to the [`ModelFeatures`] type, the [`ModelLayout`] implementation aims to
/// provide a generic foundation for using type-based features / layouts within neural
/// networks. Our goal with this struct is to eventually push the implementation to the point
/// of being able to sufficiently describe everything about a model's layout (similar to what
/// the [`ndarray`] developers have attained with the [`LayoutRef`](ndarray::LayoutRef)).
///
/// # Type parameters
///
/// - `F`: the stored feature description; must implement [`RawModelLayout`].
/// - `D`: a type-level depth marker implementing [`NetworkDepth`]; defaults to [`Deep`].
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ModelLayout<F, D = Deep>
where
    D: NetworkDepth,
    F: RawModelLayout,
{
    /// the underlying feature description of the model
    pub(crate) features: F,
    /// zero-sized marker recording the depth type `D`; no value of `D` is stored
    pub(crate) _marker: core::marker::PhantomData<D>,
}
60
61/*
62 ************* Implementations *************
63*/
64
65impl IntoModelFeatures for (usize, usize, usize) {
66    fn into_model_features(self) -> ModelFeatures {
67        ModelFeatures {
68            input: self.0,
69            inner: ModelFormat::Shallow { hidden: self.1 },
70            output: self.2,
71        }
72    }
73}
74
75impl IntoModelFeatures for (usize, usize, usize, usize) {
76    fn into_model_features(self) -> ModelFeatures {
77        ModelFeatures {
78            input: self.0,
79            inner: ModelFormat::Deep {
80                hidden: self.1,
81                layers: self.3,
82            },
83            output: self.2,
84        }
85    }
86}
87
88impl IntoModelFeatures for [usize; 3] {
89    fn into_model_features(self) -> ModelFeatures {
90        ModelFeatures {
91            input: self[0],
92            inner: ModelFormat::Shallow { hidden: self[1] },
93            output: self[2],
94        }
95    }
96}
97
98impl IntoModelFeatures for [usize; 4] {
99    fn into_model_features(self) -> ModelFeatures {
100        ModelFeatures {
101            input: self[0],
102            inner: ModelFormat::Deep {
103                hidden: self[1],
104                layers: self[3],
105            },
106            output: self[2],
107        }
108    }
109}