concision_core/models/
layout.rs1use super::{Deep, NetworkDepth, RawModelLayout};
6
7mod impl_model_features;
8mod impl_model_format;
9mod impl_model_layout;
10
11pub trait IntoModelFeatures {
13 fn into_model_features(self) -> ModelFeatures;
14}
15
16#[derive(
22 Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, strum::EnumCount, strum::EnumIs,
23)]
24#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
25pub enum ModelFormat {
26 Layer,
27 Shallow { hidden: usize },
28 Deep { hidden: usize, layers: usize },
29}
30
31#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
35#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
36pub struct ModelFeatures {
37 pub(crate) input: usize,
39 pub(crate) inner: ModelFormat,
41 pub(crate) output: usize,
43}
44
45#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
51#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
52pub struct ModelLayout<F, D = Deep>
53where
54 D: NetworkDepth,
55 F: RawModelLayout,
56{
57 pub(crate) features: F,
58 pub(crate) _marker: core::marker::PhantomData<D>,
59}
60
61impl IntoModelFeatures for (usize, usize, usize) {
66 fn into_model_features(self) -> ModelFeatures {
67 ModelFeatures {
68 input: self.0,
69 inner: ModelFormat::Shallow { hidden: self.1 },
70 output: self.2,
71 }
72 }
73}
74
75impl IntoModelFeatures for (usize, usize, usize, usize) {
76 fn into_model_features(self) -> ModelFeatures {
77 ModelFeatures {
78 input: self.0,
79 inner: ModelFormat::Deep {
80 hidden: self.1,
81 layers: self.3,
82 },
83 output: self.2,
84 }
85 }
86}
87
88impl IntoModelFeatures for [usize; 3] {
89 fn into_model_features(self) -> ModelFeatures {
90 ModelFeatures {
91 input: self[0],
92 inner: ModelFormat::Shallow { hidden: self[1] },
93 output: self[2],
94 }
95 }
96}
97
98impl IntoModelFeatures for [usize; 4] {
99 fn into_model_features(self) -> ModelFeatures {
100 ModelFeatures {
101 input: self[0],
102 inner: ModelFormat::Deep {
103 hidden: self[1],
104 layers: self[3],
105 },
106 output: self[2],
107 }
108 }
109}