use crate::{DeepParamsBase, ModelParamsBase};
use crate::ModelFeatures;
use crate::models::traits::DeepModelRepr;
use concision_params::ParamsBase;
use concision_traits::Forward;
use ndarray::{Data, DataOwned, Dimension, Ix2, RawData};
use num_traits::{One, Zero};
impl<S, D, H, A> ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
    H: DeepModelRepr<S, D>,
{
    /// Assembles a deep model from its three parts: the `input` layer, the `hidden`
    /// representation (any `H` implementing [`DeepModelRepr`]), and the `output` layer.
    ///
    /// The parts are moved in as-is; no check is made that their dimensions line up.
    pub const fn deep(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
        Self {
            output,
            hidden,
            input,
        }
    }
}
impl<A, S, D> DeepParamsBase<S, D, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// Returns the total number of weights: the input layer's, plus every hidden
    /// layer's, plus the output layer's (as reported by `count_weights`).
    #[inline]
    pub fn size(&self) -> usize {
        let hidden: usize = self
            .hidden()
            .iter()
            .map(|layer| layer.count_weights())
            .sum();
        self.input().count_weights() + hidden + self.output().count_weights()
    }

    /// Replaces the hidden layer at position `idx` with `layer`, returning `self` so
    /// calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - the model has no hidden layers, or its hidden layers disagree on their
    ///   dimension (see [`dim_hidden`](Self::dim_hidden));
    /// - `layer`'s dimension differs from that of the existing hidden layers;
    /// - `idx` is out of bounds.
    #[inline]
    pub fn set_hidden_layer(&mut self, idx: usize, layer: ParamsBase<S, D>) -> &mut Self {
        let found = layer.dim();
        let expected = self.dim_hidden();
        if found != expected {
            panic!(
                "the dimension of the layer ({:?}) does not match the dimension of the hidden layers ({:?})",
                found, expected
            );
        }
        // Capture the length up front: the mutable borrow below would otherwise
        // prevent reading it inside the panic message.
        let count = self.hidden().len();
        match self.hidden_mut().get_mut(idx) {
            Some(slot) => *slot = layer,
            None => panic!(
                "hidden layer index ({}) is out of bounds for a model with {} hidden layers",
                idx, count
            ),
        }
        self
    }

    /// Returns the dimension pattern of the input layer.
    #[inline]
    pub fn dim_input(&self) -> <D as Dimension>::Pattern {
        self.input().dim()
    }

    /// Returns the dimension pattern shared by all hidden layers.
    ///
    /// # Panics
    ///
    /// Panics if there are no hidden layers, or if the hidden layers do not all have
    /// the same dimension.
    #[inline]
    pub fn dim_hidden(&self) -> <D as Dimension>::Pattern {
        let first = self
            .hidden()
            .first()
            .expect("cannot determine the hidden dimension: the model has no hidden layers")
            .dim();
        assert!(
            self.hidden().iter().all(|p| p.dim() == first),
            "the hidden layers do not all share the same dimension"
        );
        first
    }

    /// Returns the dimension pattern of the output layer.
    #[inline]
    pub fn dim_output(&self) -> <D as Dimension>::Pattern {
        self.output().dim()
    }

    /// Returns the hidden layer(s) selected by `idx` (a position or a range), or
    /// `None` if `idx` is out of bounds.
    #[inline]
    pub fn get_hidden_layer<I>(&self, idx: I) -> Option<&I::Output>
    where
        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
    {
        self.hidden().get(idx)
    }

    /// Mutable counterpart of [`get_hidden_layer`](Self::get_hidden_layer); returns
    /// `None` if `idx` is out of bounds.
    #[inline]
    pub fn get_hidden_layer_mut<I>(&mut self, idx: I) -> Option<&mut I::Output>
    where
        I: core::slice::SliceIndex<[ParamsBase<S, D>]>,
    {
        self.hidden_mut().get_mut(idx)
    }

    /// Runs `input` through the input layer, then each hidden layer in order, then
    /// the output layer, feeding each layer's result into the next.
    #[inline]
    pub fn forward<X, Y>(&self, input: &X) -> Y
    where
        A: Clone,
        S: Data,
        ParamsBase<S, D>: Forward<X, Output = Y> + Forward<Y, Output = Y>,
    {
        let mut state = self.input().forward(input);
        for layer in self.hidden() {
            state = layer.forward(&state);
        }
        self.output().forward(&state)
    }
}
impl<A, S> DeepParamsBase<S, Ix2, A>
where
    S: RawData<Elem = A>,
{
    /// Builds every layer with `ParamsBase::default`, sized from `features`: one input
    /// layer, `features.layers()` hidden layers, and one output layer.
    // `Default::default` takes no arguments, so the trait cannot express a constructor
    // that needs `features`; this inherent method deliberately reuses the name.
    #[allow(clippy::should_implement_trait)]
    pub fn default(features: ModelFeatures) -> Self
    where
        A: Clone + Default,
        S: DataOwned,
    {
        // Arguments are evaluated left to right: input, hidden stack, output.
        Self::new(
            ParamsBase::default(features.dim_input()),
            (0..features.layers())
                .map(|_| ParamsBase::default(features.dim_hidden()))
                .collect::<Vec<_>>(),
            ParamsBase::default(features.dim_output()),
        )
    }

    /// Builds every layer with `ParamsBase::ones`, sized from `features`: one input
    /// layer, `features.layers()` hidden layers, and one output layer.
    pub fn ones(features: ModelFeatures) -> Self
    where
        A: Clone + One,
        S: DataOwned,
    {
        Self::new(
            ParamsBase::ones(features.dim_input()),
            (0..features.layers())
                .map(|_| ParamsBase::ones(features.dim_hidden()))
                .collect::<Vec<_>>(),
            ParamsBase::ones(features.dim_output()),
        )
    }

    /// Builds every layer with `ParamsBase::zeros`, sized from `features`: one input
    /// layer, `features.layers()` hidden layers, and one output layer.
    pub fn zeros(features: ModelFeatures) -> Self
    where
        A: Clone + Zero,
        S: DataOwned,
    {
        Self::new(
            ParamsBase::zeros(features.dim_input()),
            (0..features.layers())
                .map(|_| ParamsBase::zeros(features.dim_hidden()))
                .collect::<Vec<_>>(),
            ParamsBase::zeros(features.dim_output()),
        )
    }
}