use crate::models::{ModelParamsBase, ShallowParamsBase};
use crate::ModelFeatures;
use crate::activate::{ReLUActivation, SigmoidActivation};
use crate::models::traits::ShallowModelRepr;
use concision_params::ParamsBase;
use ndarray::{
Array1, ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, RemoveAxis, ScalarOperand,
};
use num_traits::Float;
/// Constructors available to any model-parameter container whose hidden
/// section implements [`ShallowModelRepr`].
impl<S, D, H, A> ModelParamsBase<S, D, H, A>
where
    S: RawData<Elem = A>,
    D: Dimension,
    H: ShallowModelRepr<S, D>,
{
    /// Assembles model parameters from three pre-built parts.
    ///
    /// * `input` – parameters of the input layer
    /// * `hidden` – the hidden representation (any [`ShallowModelRepr`])
    /// * `output` – parameters of the output layer
    ///
    /// The parts are stored exactly as given; no shape validation is performed.
    pub const fn shallow(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
        Self { input, hidden, output }
    }
}
impl<S, D, A> ShallowParamsBase<S, D, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// Builds a shallow network whose three layers are each created with
    /// [`ParamsBase::default`] from the corresponding dimension.
    ///
    /// NOTE(review): layer contents are whatever `ParamsBase::default` yields
    /// (presumably default-valued elements of `A`) — confirm in `concision_params`.
    ///
    /// This is an inherent function taking explicit shapes. It is distinct from
    /// [`Default::default`], which is implemented separately for the `Ix2` case.
    // The name `default` is kept on purpose. Because the function takes shape
    // arguments, it cannot be the `Default` trait method, so the clippy lint
    // suggesting that is silenced.
    #[allow(clippy::should_implement_trait)]
    pub fn default(input: D, hidden: D, output: D) -> Self
    where
        A: Clone + Default,
        D: RemoveAxis,
        S: DataOwned,
    {
        // Layers are constructed in the order hidden -> input -> output.
        let hidden = ParamsBase::default(hidden);
        let input = ParamsBase::default(input);
        let output = ParamsBase::default(output);
        Self { input, hidden, output }
    }

    /// Returns the total number of weights across the input, hidden, and
    /// output layers, i.e. the sum of each layer's `count_weights()`.
    #[inline]
    pub fn size(&self) -> usize {
        self.input().count_weights()
            + self.hidden().count_weights()
            + self.output().count_weights()
    }

    /// Borrows the weight array of the hidden layer.
    pub const fn hidden_weights(&self) -> &ArrayBase<S, D, A> {
        self.hidden().weights()
    }

    /// Mutably borrows the weight array of the hidden layer.
    pub const fn hidden_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
        self.hidden_mut().weights_mut()
    }
}
impl<S, A> ShallowParamsBase<S, Ix2, A>
where
    S: RawData<Elem = A>,
{
    /// Creates a two-dimensional shallow network sized by `features`.
    ///
    /// Each layer is built with [`ParamsBase::default`] using the shapes
    /// reported by `features.dim_input()`, `features.dim_hidden()` and
    /// `features.dim_output()` respectively.
    pub fn from_features(features: ModelFeatures) -> Self
    where
        A: Clone + Default,
        S: DataOwned,
    {
        // Layers are constructed in the order hidden -> input -> output.
        let hidden = ParamsBase::default(features.dim_hidden());
        let input = ParamsBase::default(features.dim_input());
        let output = ParamsBase::default(features.dim_output());
        Self { input, hidden, output }
    }

    /// Runs a single sample through the network and returns the output
    /// layer's result.
    ///
    /// The input and hidden layers are evaluated with `Forward::forward_then`
    /// using a ReLU activation. The output layer uses a sigmoid activation.
    pub fn forward(&self, input: &Array1<A>) -> Array1<A>
    where
        A: Float + ScalarOperand,
        S: Data,
    {
        use concision_traits::Forward;

        let activations = self.input().forward_then(input, |z| z.relu());
        let activations = self.hidden().forward_then(&activations, |z| z.relu());
        self.output().forward_then(&activations, |z| z.sigmoid())
    }
}
impl<A, S> Default for ShallowParamsBase<S, Ix2, A>
where
    A: Clone + Default,
    S: DataOwned<Elem = A>,
{
    /// Builds a shallow network sized by `ModelFeatures::default()`,
    /// delegating to [`ShallowParamsBase::from_features`].
    fn default() -> Self {
        let features = ModelFeatures::default();
        Self::from_features(features)
    }
}