/// Read-only view of a model's dimensions.
///
/// Implementors supply the four scalar dimensions (`input`, `hidden`,
/// `output`, `depth`); the `(rows, cols)` pairs and element counts are
/// provided methods derived from them.
pub trait RawModelLayout {
    /// The input dimension.
    fn input(&self) -> usize;
    /// The hidden dimension.
    fn hidden(&self) -> usize;
    /// The output dimension.
    fn output(&self) -> usize;
    /// Multiplier applied to the `hidden × hidden` count in
    /// [`size_hidden`](Self::size_hidden).
    fn depth(&self) -> usize;

    /// Returns the pair `(input, hidden)`.
    fn dim_input(&self) -> (usize, usize) {
        let rows = self.input();
        let cols = self.hidden();
        (rows, cols)
    }
    /// Returns the pair `(hidden, hidden)`.
    fn dim_hidden(&self) -> (usize, usize) {
        let rows = self.hidden();
        let cols = self.hidden();
        (rows, cols)
    }
    /// Returns the pair `(hidden, output)`.
    fn dim_output(&self) -> (usize, usize) {
        let rows = self.hidden();
        let cols = self.output();
        (rows, cols)
    }

    /// Sum of [`size_input`](Self::size_input),
    /// [`size_hidden`](Self::size_hidden) and
    /// [`size_output`](Self::size_output), evaluated in that order.
    fn size(&self) -> usize {
        let head = self.size_input();
        let body = self.size_hidden();
        let tail = self.size_output();
        head + body + tail
    }
    /// `input * hidden`.
    fn size_input(&self) -> usize {
        let rows = self.input();
        let cols = self.hidden();
        rows * cols
    }
    /// `hidden * hidden * depth`.
    fn size_hidden(&self) -> usize {
        let rows = self.hidden();
        let cols = self.hidden();
        rows * cols * self.depth()
    }
    /// `hidden * output`.
    fn size_output(&self) -> usize {
        let rows = self.hidden();
        let cols = self.output();
        rows * cols
    }
}
/// Mutable counterpart of [`RawModelLayout`]: each dimension is exposed as a
/// `&mut usize`, and chainable setters are provided on top of those accessors.
///
/// NOTE(review): `layers_mut` is presumably the slot that
/// [`RawModelLayout::depth`] reports; implementors must keep the two in
/// agreement — confirm.
pub trait RawModelLayoutMut: RawModelLayout {
    /// Mutable reference to the input dimension.
    fn input_mut(&mut self) -> &mut usize;
    /// Mutable reference to the hidden dimension.
    fn hidden_mut(&mut self) -> &mut usize;
    /// Mutable reference to the `layers` slot.
    fn layers_mut(&mut self) -> &mut usize;
    /// Mutable reference to the output dimension.
    fn output_mut(&mut self) -> &mut usize;

    /// Writes `input` through [`input_mut`](Self::input_mut) and returns
    /// `self` so calls can be chained.
    #[inline]
    fn set_input(&mut self, input: usize) -> &mut Self {
        let slot = self.input_mut();
        *slot = input;
        self
    }

    /// Writes `hidden` through [`hidden_mut`](Self::hidden_mut) and returns
    /// `self` so calls can be chained.
    #[inline]
    fn set_hidden(&mut self, hidden: usize) -> &mut Self {
        let slot = self.hidden_mut();
        *slot = hidden;
        self
    }

    /// Writes `layers` through [`layers_mut`](Self::layers_mut) and returns
    /// `self` so calls can be chained.
    #[inline]
    fn set_layers(&mut self, layers: usize) -> &mut Self {
        let slot = self.layers_mut();
        *slot = layers;
        self
    }

    /// Writes `output` through [`output_mut`](Self::output_mut) and returns
    /// `self` so calls can be chained.
    #[inline]
    fn set_output(&mut self, output: usize) -> &mut Self {
        let slot = self.output_mut();
        *slot = output;
        self
    }
}
pub trait LayoutExt: RawModelLayout + RawModelLayoutMut + Clone + core::fmt::Debug {}
/// Type-level tag reporting whether a network kind is deep.
///
/// `is_deep` defaults to `false`; the `Deep` marker generated by
/// `impl_network_depth!` overrides it to return `true`.
///
/// NOTE(review): `private!()` is a macro defined elsewhere. The impls generated
/// in this file invoke `seal!()`, so this presumably seals the trait against
/// outside implementations — confirm.
pub trait NetworkDepth {
private!();
/// Returns whether this network kind is deep. Defaults to `false`.
fn is_deep(&self) -> bool {
false
}
}
/// Declares one variant-less marker enum per listed name and implements the
/// given trait for each of them.
///
/// Invocation shape:
/// `impl_network_depth! { #[Trait] vis enum { Name { items }, Name2, } }`.
/// The `#[Trait]` group is not a real attribute; it only carries the name of
/// the trait to implement. An optional `{ ... }` after a name is pasted
/// verbatim into that type's trait impl (e.g. to override provided methods).
macro_rules! impl_network_depth {
// Entry arm: `$s` captures the item keyword (e.g. `enum`), and every listed
// name, together with its optional body, is forwarded to the `@impl` arm.
( #[$tgt:ident] $vis:vis $s:ident {$($name:ident $({$($rest:tt)*})?),* $(,)?}) => {
$(
impl_network_depth!(@impl #[$tgt] $vis $s $name $({$($rest)*})?);
)*
};
// Internal arm: only the `enum` keyword is matched here. The generated enum
// has no variants, so no value of it can be constructed; it can only be used
// as a type. `seal!()` is defined elsewhere (presumably the counterpart of the
// trait's `private!()` sealing macro — confirm).
(@impl #[$tgt:ident] $vis:vis enum $name:ident $({$($rest:tt)*})?) => {
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
$vis enum $name {}
impl $tgt for $name {
seal!();
$($($rest)*)?
}
};
}
// Generates the public `Deep` and `Shallow` marker enums and their
// `NetworkDepth` impls. `Deep` overrides `is_deep` to return `true`;
// `Shallow` keeps the trait's default (`false`).
impl_network_depth! {
#[NetworkDepth]
pub enum {
Deep {
fn is_deep(&self) -> bool {
true
}
},
Shallow,
}
}
/// Forwards every [`RawModelLayout`] method to the referenced layout.
///
/// All provided methods (`dim_*`, `size*`) are forwarded, not just the four
/// required ones. Without this, an implementor's override of, e.g., `size` or
/// `size_hidden` would be silently replaced by the trait defaults whenever the
/// layout is passed by reference to generic code. `T: ?Sized` additionally
/// allows `&dyn RawModelLayout` to be used as a layout.
impl<T> RawModelLayout for &T
where
    T: RawModelLayout + ?Sized,
{
    fn input(&self) -> usize {
        <T as RawModelLayout>::input(&**self)
    }
    fn hidden(&self) -> usize {
        <T as RawModelLayout>::hidden(&**self)
    }
    fn depth(&self) -> usize {
        <T as RawModelLayout>::depth(&**self)
    }
    fn output(&self) -> usize {
        <T as RawModelLayout>::output(&**self)
    }
    fn dim_input(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_input(&**self)
    }
    fn dim_hidden(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_hidden(&**self)
    }
    fn dim_output(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_output(&**self)
    }
    fn size(&self) -> usize {
        <T as RawModelLayout>::size(&**self)
    }
    fn size_input(&self) -> usize {
        <T as RawModelLayout>::size_input(&**self)
    }
    fn size_hidden(&self) -> usize {
        <T as RawModelLayout>::size_hidden(&**self)
    }
    fn size_output(&self) -> usize {
        <T as RawModelLayout>::size_output(&**self)
    }
}
/// Forwards every [`RawModelLayout`] method to the mutably-referenced layout.
///
/// All provided methods (`dim_*`, `size*`) are forwarded, not just the four
/// required ones. Without this, an implementor's overrides would be silently
/// replaced by the trait defaults whenever the layout is passed as `&mut T` to
/// generic code. `T: ?Sized` additionally allows `&mut dyn RawModelLayout`.
impl<T> RawModelLayout for &mut T
where
    T: RawModelLayout + ?Sized,
{
    fn input(&self) -> usize {
        <T as RawModelLayout>::input(&**self)
    }
    fn hidden(&self) -> usize {
        <T as RawModelLayout>::hidden(&**self)
    }
    fn depth(&self) -> usize {
        <T as RawModelLayout>::depth(&**self)
    }
    fn output(&self) -> usize {
        <T as RawModelLayout>::output(&**self)
    }
    fn dim_input(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_input(&**self)
    }
    fn dim_hidden(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_hidden(&**self)
    }
    fn dim_output(&self) -> (usize, usize) {
        <T as RawModelLayout>::dim_output(&**self)
    }
    fn size(&self) -> usize {
        <T as RawModelLayout>::size(&**self)
    }
    fn size_input(&self) -> usize {
        <T as RawModelLayout>::size_input(&**self)
    }
    fn size_hidden(&self) -> usize {
        <T as RawModelLayout>::size_hidden(&**self)
    }
    fn size_output(&self) -> usize {
        <T as RawModelLayout>::size_output(&**self)
    }
}
/// Every `Copy + Debug` type with mutable layout accessors is a [`LayoutExt`].
impl<T: RawModelLayoutMut + Copy + core::fmt::Debug> LayoutExt for T {}
/// A 3-tuple is read as `(input, hidden, output)`; it has no depth slot, so
/// `depth()` is fixed at `1`.
impl RawModelLayout for (usize, usize, usize) {
    fn input(&self) -> usize {
        let &(input, ..) = self;
        input
    }
    fn hidden(&self) -> usize {
        let &(_, hidden, _) = self;
        hidden
    }
    /// Always `1`: the 3-tuple form cannot express any other depth.
    fn depth(&self) -> usize {
        1
    }
    fn output(&self) -> usize {
        let &(.., output) = self;
        output
    }
}
/// A 4-tuple is read as `(input, hidden, output, depth)`.
impl RawModelLayout for (usize, usize, usize, usize) {
    fn input(&self) -> usize {
        let &(input, ..) = self;
        input
    }
    fn hidden(&self) -> usize {
        let &(_, hidden, ..) = self;
        hidden
    }
    fn output(&self) -> usize {
        let &(.., output, _) = self;
        output
    }
    fn depth(&self) -> usize {
        let &(.., depth) = self;
        depth
    }
}
/// Mutable accessors for the 4-tuple layout `(input, hidden, output, depth)`.
///
/// Each accessor hands out the same slot that the corresponding
/// [`RawModelLayout`] getter reads, so a value written through a setter is the
/// value the getter reports back.
impl RawModelLayoutMut for (usize, usize, usize, usize) {
    fn input_mut(&mut self) -> &mut usize {
        &mut self.0
    }
    fn hidden_mut(&mut self) -> &mut usize {
        &mut self.1
    }
    // Slot 3 is the one `depth()` reads; slot 2 holds the output dimension.
    fn layers_mut(&mut self) -> &mut usize {
        &mut self.3
    }
    // Slot 2 is the one `output()` reads.
    fn output_mut(&mut self) -> &mut usize {
        &mut self.2
    }
}
/// A 3-element array is read as `[input, hidden, output]`; it has no depth
/// slot, so `depth()` is fixed at `1`.
impl RawModelLayout for [usize; 3] {
    fn input(&self) -> usize {
        let [input, ..] = *self;
        input
    }
    fn hidden(&self) -> usize {
        let [_, hidden, _] = *self;
        hidden
    }
    fn output(&self) -> usize {
        let [.., output] = *self;
        output
    }
    /// Always `1`: the 3-element form cannot express any other depth.
    fn depth(&self) -> usize {
        1
    }
}
/// A 4-element array is read as `[input, hidden, output, depth]`.
impl RawModelLayout for [usize; 4] {
    fn input(&self) -> usize {
        let [input, ..] = *self;
        input
    }
    fn hidden(&self) -> usize {
        let [_, hidden, ..] = *self;
        hidden
    }
    fn output(&self) -> usize {
        let [.., output, _] = *self;
        output
    }
    fn depth(&self) -> usize {
        let [.., depth] = *self;
        depth
    }
}
/// Mutable accessors for the array layout `[input, hidden, output, depth]`.
///
/// Each accessor hands out the same element that the corresponding
/// [`RawModelLayout`] getter reads, so a value written through a setter is the
/// value the getter reports back.
impl RawModelLayoutMut for [usize; 4] {
    fn input_mut(&mut self) -> &mut usize {
        &mut self[0]
    }
    fn hidden_mut(&mut self) -> &mut usize {
        &mut self[1]
    }
    // Index 3 is the one `depth()` reads; index 2 holds the output dimension.
    fn layers_mut(&mut self) -> &mut usize {
        &mut self[3]
    }
    // Index 2 is the one `output()` reads.
    fn output_mut(&mut self) -> &mut usize {
        &mut self[2]
    }
}