use crate::config::ModelConfiguration;
use crate::{DeepModelParams, LayoutExt, RawModelLayout};
use concision_params::Params;
use concision_traits::Predict;
/// A model that owns a configuration, a layout, and a set of [`DeepModelParams`].
///
/// The element type `T` defaults to `f32`.
pub trait Model<T = f32> {
    /// The configuration type for this model; must implement [`ModelConfiguration`].
    type Config: ModelConfiguration<T>;
    /// The layout type for this model; must implement [`LayoutExt`].
    type Layout: LayoutExt;

    /// Borrows the model's configuration.
    fn config(&self) -> &Self::Config;
    /// Borrows the model's layout.
    fn layout(&self) -> &Self::Layout;
    /// Borrows the model's parameters.
    fn params(&self) -> &DeepModelParams<T>;

    /// Mutably borrows the model's configuration.
    fn config_mut(&mut self) -> &mut Self::Config;
    /// Mutably borrows the model's parameters.
    fn params_mut(&mut self) -> &mut DeepModelParams<T>;

    /// Runs the model on `inputs` by delegating to the model's own [`Predict`]
    /// implementation for input type `U`, returning its `Output` (`V`).
    ///
    /// Note: this method has the same name as `Predict::predict`. When a type
    /// implements both traits and both are in scope, method-call syntax
    /// (`model.predict(&x)`) may be ambiguous. Use fully qualified syntax in that case.
    fn predict<U, V>(&self, inputs: &U) -> V
    where
        Self: Predict<U, Output = V>,
    {
        <Self as Predict<U>>::predict(self, inputs)
    }
}
/// Convenience accessors and mutators available on every [`Model`].
///
/// The type parameter defaults to `f32`, mirroring [`Model`], so a bare
/// `ModelExt` bound means `ModelExt<f32>`. A blanket implementation in this
/// module provides this trait for every `M: Model<T>`, so it is never
/// implemented by hand.
pub trait ModelExt<T = f32>: Model<T> {
    /// Replaces the model's configuration with `config` and returns the previous one.
    #[inline]
    fn replace_config(&mut self, config: Self::Config) -> Self::Config {
        core::mem::replace(self.config_mut(), config)
    }

    /// Replaces the model's parameters with `params` and returns the previous ones.
    #[inline]
    fn replace_params(&mut self, params: DeepModelParams<T>) -> DeepModelParams<T> {
        core::mem::replace(self.params_mut(), params)
    }

    /// Overwrites the model's configuration, dropping the old value, and returns
    /// `self` for chaining.
    #[inline]
    fn set_config(&mut self, config: Self::Config) -> &mut Self {
        *self.config_mut() = config;
        self
    }

    /// Overwrites the model's parameters, dropping the old value, and returns
    /// `self` for chaining.
    #[inline]
    fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
        *self.params_mut() = params;
        self
    }

    /// Borrows the input-layer parameters (`params().input()`).
    #[inline]
    fn input_layer(&self) -> &Params<T> {
        self.params().input()
    }

    /// Mutably borrows the input-layer parameters (`params_mut().input_mut()`).
    #[inline]
    fn input_layer_mut(&mut self) -> &mut Params<T> {
        self.params_mut().input_mut()
    }

    /// Borrows the hidden-layer parameters (`params().hidden()`).
    #[inline]
    fn hidden_layers(&self) -> &Vec<Params<T>> {
        self.params().hidden()
    }

    /// Mutably borrows the hidden-layer parameters (`params_mut().hidden_mut()`).
    #[inline]
    fn hidden_layers_mut(&mut self) -> &mut Vec<Params<T>> {
        self.params_mut().hidden_mut()
    }

    /// Borrows the output-layer parameters (`params().output()`).
    #[inline]
    fn output_layer(&self) -> &Params<T> {
        self.params().output()
    }

    /// Mutably borrows the output-layer parameters (`params_mut().output_mut()`).
    #[inline]
    fn output_layer_mut(&mut self) -> &mut Params<T> {
        self.params_mut().output_mut()
    }

    /// Replaces the input layer via `DeepModelParams::set_input` and returns
    /// `self` for chaining.
    #[inline]
    fn set_input_layer(&mut self, layer: Params<T>) -> &mut Self {
        self.params_mut().set_input(layer);
        self
    }

    /// Replaces the hidden layers via `DeepModelParams::set_hidden` and returns
    /// `self` for chaining.
    #[inline]
    fn set_hidden_layers(&mut self, layers: Vec<Params<T>>) -> &mut Self {
        self.params_mut().set_hidden(layers);
        self
    }

    /// Replaces the output layer via `DeepModelParams::set_output` and returns
    /// `self` for chaining.
    #[inline]
    fn set_output_layer(&mut self, layer: Params<T>) -> &mut Self {
        self.params_mut().set_output(layer);
        self
    }

    /// Returns the input dimensions reported by the layout's `dim_input()`.
    #[inline]
    fn input_dim(&self) -> (usize, usize) {
        self.layout().dim_input()
    }

    /// Returns the hidden dimensions reported by the layout's `dim_hidden()`.
    #[inline]
    fn hidden_dim(&self) -> (usize, usize) {
        self.layout().dim_hidden()
    }

    /// Returns the layout's `depth()`.
    ///
    /// NOTE(review): presumably the number of hidden layers. This value comes
    /// from the layout, not from `hidden_layers().len()`, so confirm the two
    /// agree against `LayoutExt::depth`.
    #[inline]
    fn hidden_layers_count(&self) -> usize {
        self.layout().depth()
    }

    /// Returns the output dimensions reported by the layout's `dim_output()`.
    #[inline]
    fn output_dim(&self) -> (usize, usize) {
        self.layout().dim_output()
    }
}
/// Blanket implementation: every [`Model`] automatically gets the [`ModelExt`]
/// convenience methods.
///
/// No separate `M::Layout: LayoutExt` bound is required, because `Model`
/// already declares `type Layout: LayoutExt`, which makes that bound implied.
impl<M, T> ModelExt<T> for M
where
    M: Model<T>,
{
}