use super::{Deep, NetworkDepth, RawModelLayout};
mod impl_model_features;
mod impl_model_format;
mod impl_model_layout;
/// Conversion of a compact dimension specification into [`ModelFeatures`].
///
/// The tuple and array implementations that accompany this trait accept
/// `usize` values ordered as `(input, hidden, output)`, which yields a shallow
/// format, or `(input, hidden, output, layers)`, which yields a deep format.
pub trait IntoModelFeatures {
/// Consumes `self` and returns the equivalent [`ModelFeatures`].
fn into_model_features(self) -> ModelFeatures;
}
/// Shape of the inner (hidden) part of a model. It is stored in the `inner`
/// field of [`ModelFeatures`], between the input and output sizes.
///
/// Variant declaration order is significant. The derived `PartialOrd`/`Ord`
/// compare variants first, in declaration order (`Layer < Shallow < Deep`),
/// and only then compare fields. Reordering the variants therefore changes
/// the ordering of values.
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, strum::EnumCount, strum::EnumIs,
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ModelFormat {
/// No hidden dimensions are recorded.
// NOTE(review): presumably a single layer mapping input straight to output — confirm.
Layer,
/// A single hidden size.
Shallow { hidden: usize },
/// A hidden size together with a layer count.
// NOTE(review): unclear whether `layers` counts hidden layers only or all layers — confirm.
Deep { hidden: usize, layers: usize },
}
/// Dimension description of a model: input size, inner format and output size.
///
/// Values can be built from tuples or arrays through [`IntoModelFeatures`].
/// The derived `PartialOrd`/`Ord` compare `input`, then `inner`, then
/// `output`, following field declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ModelFeatures {
/// Size of the model's input.
pub(crate) input: usize,
/// Layout of everything between input and output.
pub(crate) inner: ModelFormat,
/// Size of the model's output.
pub(crate) output: usize,
}
/// A raw layout `F` tagged at the type level with a network depth `D`,
/// which defaults to [`Deep`].
///
/// `D` is carried only through a `PhantomData` marker, so it adds no runtime
/// storage. However, the standard derives still bound every type parameter,
/// so `ModelLayout<F, D>` only gets `Clone`, `Copy`, `Ord`, etc. when both
/// `F` and `D` implement those traits.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ModelLayout<F, D = Deep>
where
D: NetworkDepth,
F: RawModelLayout,
{
/// The underlying raw layout.
pub(crate) features: F,
/// Zero-sized marker that records the depth type `D`.
pub(crate) _marker: core::marker::PhantomData<D>,
}
impl IntoModelFeatures for (usize, usize, usize) {
    /// Reads the tuple as `(input, hidden, output)` and produces a
    /// [`ModelFormat::Shallow`] description with that single hidden size.
    fn into_model_features(self) -> ModelFeatures {
        let (input, hidden, output) = self;
        let inner = ModelFormat::Shallow { hidden };
        ModelFeatures {
            input,
            inner,
            output,
        }
    }
}
impl IntoModelFeatures for (usize, usize, usize, usize) {
    /// Reads the tuple as `(input, hidden, output, layers)` and produces a
    /// [`ModelFormat::Deep`] description.
    ///
    /// The layer count is the *last* element, after the output size.
    fn into_model_features(self) -> ModelFeatures {
        let (input, hidden, output, layers) = self;
        let inner = ModelFormat::Deep { hidden, layers };
        ModelFeatures {
            input,
            inner,
            output,
        }
    }
}
impl IntoModelFeatures for [usize; 3] {
    /// Reads the array as `[input, hidden, output]` and produces a
    /// [`ModelFormat::Shallow`] description with that single hidden size.
    fn into_model_features(self) -> ModelFeatures {
        let [input, hidden, output] = self;
        ModelFeatures {
            input,
            inner: ModelFormat::Shallow { hidden },
            output,
        }
    }
}
impl IntoModelFeatures for [usize; 4] {
    /// Reads the array as `[input, hidden, output, layers]` and produces a
    /// [`ModelFormat::Deep`] description.
    ///
    /// The layer count is the *last* element, after the output size.
    fn into_model_features(self) -> ModelFeatures {
        let [input, hidden, output, layers] = self;
        ModelFeatures {
            input,
            inner: ModelFormat::Deep { hidden, layers },
            output,
        }
    }
}