use super::ModelFeatures;
use crate::models::{ModelFormat, RawModelLayout, RawModelLayoutMut};
/// Compares an input shape against a hidden-layer shape.
///
/// Returns `true` only when all of the following hold:
/// - `input` and `hidden` have a different number of axes;
/// - the last axis length of `input` differs from the first axis length of `hidden`;
/// - every axis of `hidden` has the same length as its first axis.
///
/// A shape with no axes yields `false` instead of panicking. Previously
/// `input.ndim() - 1` underflowed and `rhs[0]` indexed out of bounds in that case.
///
/// NOTE(review): the `!=` comparisons look inverted relative to the name
/// ("verify") — confirm the intended semantics before relying on this helper.
fn _verify_input_and_hidden_shape<D>(input: D, hidden: D) -> bool
where
    D: ndarray::Dimension,
{
    if input.ndim() == hidden.ndim() {
        return false;
    }
    let lhs = input.as_array_view();
    let rhs = hidden.as_array_view();
    // Guard against zero-dimensional shapes rather than indexing blindly.
    let (Some(&last_input), Some(&first_hidden)) = (lhs.iter().last(), rhs.iter().next()) else {
        return false;
    };
    last_input != first_hidden && rhs.iter().all(|&v| v == first_hidden)
}
impl ModelFeatures {
    /// Builds a [`ModelFeatures`] from a layer-width `shape` and a total `size`.
    ///
    /// The first element of `shape` is the input width and the last element is
    /// the output width. The rest depends on the length of `shape`:
    /// - More than two entries: `shape[1]` is the hidden width and the layer
    ///   count is `shape.len() - 2`.
    /// - Two or fewer entries: the hidden width is derived as
    ///   `(size - input * output) / (input + output)` and one layer is used.
    ///
    /// # Panics
    ///
    /// Panics if `shape` is empty. When the hidden width must be derived, it also
    /// panics if `input * output` overflows or exceeds `size`, or if
    /// `input + output` overflows or is zero.
    pub fn from_shape_and_size(shape: &[usize], size: usize) -> Self {
        let (Some(&input), Some(&output)) = (shape.first(), shape.last()) else {
            panic!("ModelFeatures::from_shape_and_size: `shape` must not be empty");
        };
        if shape.len() > 2 {
            return Self::new(input, shape[1], output, shape.len() - 2);
        }
        // Checked arithmetic: the unchecked form wrapped silently in release builds,
        // producing a meaningless hidden width instead of reporting the bad input.
        let remainder = input
            .checked_mul(output)
            .and_then(|product| size.checked_sub(product))
            .expect("ModelFeatures::from_shape_and_size: `input * output` overflows or exceeds `size`");
        let hidden = input
            .checked_add(output)
            .and_then(|sum| remainder.checked_div(sum))
            .expect("ModelFeatures::from_shape_and_size: `input + output` overflows or is zero");
        Self::new(input, hidden, output, 1)
    }

    /// Creates features from explicit widths.
    ///
    /// The hidden/layer format is chosen by [`ModelFormat::new`].
    pub const fn new(input: usize, hidden: usize, output: usize, layers: usize) -> Self {
        let inner = ModelFormat::new(hidden, layers);
        Self {
            input,
            output,
            inner,
        }
    }

    /// Creates features whose inner format is always [`ModelFormat::Deep`].
    pub const fn deep(input: usize, hidden: usize, output: usize, layers: usize) -> Self {
        Self {
            input,
            output,
            inner: ModelFormat::Deep { hidden, layers },
        }
    }

    /// Creates features whose inner format is always [`ModelFormat::Shallow`].
    pub const fn shallow(input: usize, hidden: usize, output: usize) -> Self {
        Self {
            input,
            output,
            inner: ModelFormat::Shallow { hidden },
        }
    }

    /// Copies the input, hidden, depth and output values out of any [`RawModelLayout`].
    pub fn from_layout<L>(layout: L) -> Self
    where
        L: RawModelLayout,
    {
        Self {
            input: layout.input(),
            inner: ModelFormat::new(layout.hidden(), layout.depth()),
            output: layout.output(),
        }
    }

    /// Returns the input width.
    pub const fn input(&self) -> usize {
        self.input
    }

    /// Returns a mutable reference to the input width.
    pub const fn input_mut(&mut self) -> &mut usize {
        &mut self.input
    }

    /// Returns a copy of the inner [`ModelFormat`].
    pub const fn inner(&self) -> ModelFormat {
        self.inner
    }

    /// Returns a mutable reference to the inner [`ModelFormat`].
    pub const fn inner_mut(&mut self) -> &mut ModelFormat {
        &mut self.inner
    }

    /// Returns the hidden width reported by the inner format.
    pub const fn hidden(&self) -> usize {
        self.inner().hidden()
    }

    /// Returns a mutable reference to the hidden width stored in the inner format.
    pub const fn hidden_mut(&mut self) -> &mut usize {
        self.inner_mut().hidden_mut()
    }

    /// Returns the layer count reported by the inner format.
    pub const fn layers(&self) -> usize {
        self.inner().layers()
    }

    /// Returns a mutable reference to the layer count, as provided by the inner format.
    pub const fn layers_mut(&mut self) -> &mut usize {
        self.inner_mut().layers_mut()
    }

    /// Returns the output width.
    pub const fn output(&self) -> usize {
        self.output
    }

    /// Returns a mutable reference to the output width.
    pub const fn output_mut(&mut self) -> &mut usize {
        &mut self.output
    }

    /// Sets the input width in place; returns `self` for chaining.
    #[inline]
    pub fn set_input(&mut self, input: usize) -> &mut Self {
        self.input = input;
        self
    }

    /// Sets the hidden width on the inner format; returns `self` for chaining.
    #[inline]
    pub fn set_hidden(&mut self, hidden: usize) -> &mut Self {
        self.inner_mut().set_hidden(hidden);
        self
    }

    /// Sets the layer count on the inner format; returns `self` for chaining.
    #[inline]
    pub fn set_layers(&mut self, layers: usize) -> &mut Self {
        self.inner_mut().set_layers(layers);
        self
    }

    /// Sets the output width in place; returns `self` for chaining.
    #[inline]
    pub fn set_output(&mut self, output: usize) -> &mut Self {
        self.output = output;
        self
    }

    /// Consumes `self` and returns a copy with the input width replaced.
    pub fn with_input(self, input: usize) -> Self {
        Self { input, ..self }
    }

    /// Consumes `self` and returns a copy with the hidden width replaced.
    pub fn with_hidden(self, hidden: usize) -> Self {
        Self {
            inner: self.inner.with_hidden(hidden),
            ..self
        }
    }

    /// Consumes `self` and returns a copy with the layer count replaced.
    pub fn with_layers(self, layers: usize) -> Self {
        Self {
            inner: self.inner.with_layers(layers),
            ..self
        }
    }

    /// Consumes `self` and returns a copy with the output width replaced.
    pub fn with_output(self, output: usize) -> Self {
        Self { output, ..self }
    }

    /// Returns `(input, hidden)`, presumably the input weight-matrix shape.
    pub fn dim_input(&self) -> (usize, usize) {
        (self.input(), self.hidden())
    }

    /// Returns `(hidden, hidden)`, presumably the shape of each hidden weight matrix.
    pub fn dim_hidden(&self) -> (usize, usize) {
        (self.hidden(), self.hidden())
    }

    /// Returns `(hidden, output)`, presumably the output weight-matrix shape.
    pub fn dim_output(&self) -> (usize, usize) {
        (self.hidden(), self.output())
    }

    /// Sum of [`Self::size_input`], [`Self::size_hidden`] and [`Self::size_output`].
    pub fn size(&self) -> usize {
        self.size_input() + self.size_hidden() + self.size_output()
    }

    /// `input * hidden`.
    pub fn size_input(&self) -> usize {
        self.input() * self.hidden()
    }

    /// `hidden * hidden * layers`.
    pub fn size_hidden(&self) -> usize {
        self.hidden() * self.hidden() * self.layers()
    }

    /// `hidden * output`.
    pub fn size_output(&self) -> usize {
        self.hidden() * self.output()
    }
}
/// Exposes the stored widths through the generic layout trait.
///
/// The input and output widths come straight from the struct fields. The
/// hidden width and depth are read from the inner [`ModelFormat`].
impl RawModelLayout for ModelFeatures {
    fn input(&self) -> usize {
        self.input
    }

    fn hidden(&self) -> usize {
        self.inner.hidden()
    }

    /// The trait's "depth" corresponds to the inner format's layer count.
    fn depth(&self) -> usize {
        self.inner.layers()
    }

    fn output(&self) -> usize {
        self.output
    }
}
/// Mutable access to the stored widths for the generic layout trait.
///
/// Input and output are borrowed directly from the struct. Hidden width and
/// layer count are borrowed through the inner [`ModelFormat`].
impl RawModelLayoutMut for ModelFeatures {
    fn input_mut(&mut self) -> &mut usize {
        &mut self.input
    }

    fn hidden_mut(&mut self) -> &mut usize {
        self.inner.hidden_mut()
    }

    fn layers_mut(&mut self) -> &mut usize {
        self.inner.layers_mut()
    }

    fn output_mut(&mut self) -> &mut usize {
        &mut self.output
    }
}
impl Default for ModelFeatures {
fn default() -> Self {
Self {
input: 16,
inner: ModelFormat::Deep {
hidden: 16,
layers: 1,
},
output: 16,
}
}
}
/// Renders as `{ input: I, hidden: H, layers: L, output: O }`.
impl core::fmt::Display for ModelFeatures {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter: no intermediate `String` allocation,
        // and no dependency on `alloc`'s `format!` in a `core`-only context.
        write!(
            f,
            "{{ input: {}, hidden: {}, layers: {}, output: {} }}",
            self.input(),
            self.hidden(),
            self.layers(),
            self.output()
        )
    }
}