use crate::activate::{ReLUActivation, SigmoidActivation};
use crate::config::StandardModelConfig;
use crate::error::Error;
use crate::models::{DeepModelParams, ModelFeatures};
use crate::nn::Model;
#[cfg(feature = "rand")]
use concision_init::{
NdRandom,
rand_distr::{Distribution, StandardNormal},
};
use concision_params::Params;
use concision_traits::{Forward, Norm, Train};
use ndarray::prelude::*;
use ndarray::{Data, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign, Zero};
#[cfg(not(feature = "tracing"))]
use eprintln as error;
#[cfg(feature = "tracing")]
use tracing::error;
/// A small fully-connected test network: an input layer, a stack of hidden
/// layers, and an output layer, with the parameters of every layer kept in a
/// single store. `T` is the scalar type of all weights/biases (default `f64`).
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TestModel<T = f64> {
// hyper-parameters; the learning rate is read from here during training
pub config: StandardModelConfig<T>,
// the network layout (input/hidden/output widths and hidden-layer count)
pub features: ModelFeatures,
// the parameters (weights and biases) for every layer of the network
pub store: DeepModelParams<T>,
}
impl<T> TestModel<T> {
    /// Build a model from a configuration and a layout; every parameter in the
    /// store starts at zero.
    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
    where
        T: Clone + Zero,
    {
        Self {
            store: DeepModelParams::zeros(features),
            config,
            features,
        }
    }
    /// Immutable access to the model configuration.
    pub const fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }
    /// The model layout, returned by value.
    pub const fn features(&self) -> ModelFeatures {
        self.features
    }
    /// Immutable access to the parameter store.
    pub const fn store(&self) -> &DeepModelParams<T> {
        &self.store
    }
    /// Mutable access to the parameter store.
    pub const fn store_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.store
    }
    /// Without the `rand` feature there is no initializer; the model is handed
    /// back untouched (all-zero parameters).
    #[cfg(not(feature = "rand"))]
    pub fn init(self) -> Self {
        self
    }
    /// Initialize every layer with LeCun-normal random weights sized from the
    /// model layout, consuming and returning the model.
    #[cfg(feature = "rand")]
    pub fn init(mut self) -> Self
    where
        StandardNormal: Distribution<T>,
        T: Float,
    {
        let features = self.features;
        // input layer maps (input -> hidden)
        self.store.set_input(Params::<T>::lecun_normal((
            features.input(),
            features.hidden(),
        )));
        // every hidden layer is square: (hidden -> hidden)
        for layer in self.store.hidden_mut() {
            *layer = Params::<T>::lecun_normal((features.hidden(), features.hidden()));
        }
        // output layer maps (hidden -> output)
        self.store.set_output(Params::<T>::lecun_normal((
            features.hidden(),
            features.output(),
        )));
        self
    }
}
/// Wire `TestModel` into the framework's `Model` abstraction; every method is
/// a trivial delegation to one of the struct's own fields.
impl<T> Model<T> for TestModel<T> {
type Config = StandardModelConfig<T>;
type Layout = ModelFeatures;
// borrow the configuration
fn config(&self) -> &StandardModelConfig<T> {
&self.config
}
// mutably borrow the configuration
fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
&mut self.config
}
// borrow the layout (`features` field)
fn layout(&self) -> &ModelFeatures {
&self.features
}
// borrow the parameter store
fn params(&self) -> &DeepModelParams<T> {
&self.store
}
// mutably borrow the parameter store
fn params_mut(&mut self) -> &mut DeepModelParams<T> {
&mut self.store
}
}
/// Forward pass over any array dimension the layer parameters support:
/// input layer + ReLU, each hidden layer + ReLU, then the output layer
/// squashed with a sigmoid.
///
/// Fix: the original wrote `ArrayBase<S, D, A>` in the trait bounds, but
/// `ndarray::ArrayBase` takes exactly two type parameters (`S: RawData<Elem = A>`
/// and `D`), as the method signature below already does — the three-argument
/// form cannot compile.
impl<A, S, D> Forward<ArrayBase<S, D>> for TestModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
    Params<A>: Forward<ArrayBase<S, D>, Output = Array<A, D>>
        + Forward<Array<A, D>, Output = Array<A, D>>,
{
    type Output = Array<A, D>;
    fn forward(&self, input: &ArrayBase<S, D>) -> Self::Output {
        // input layer followed by ReLU
        let mut output = self.store().input().forward(input).relu();
        // every hidden layer is also followed by ReLU
        for layer in self.store().hidden() {
            output = layer.forward(&output).relu();
        }
        // output layer squashed with a sigmoid
        self.store().output().forward(&output).sigmoid()
    }
}
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for TestModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Error = Error;
    type Output = A;
    /// Run one SGD step on a single (input, target) sample and return the MSE loss.
    ///
    /// Both vectors are L2-normalized before the forward pass, and every delta is
    /// renormalized during backpropagation to keep the updates bounded.
    ///
    /// # Errors
    /// Returns `InvalidInputFeatures` / `InvalidTargetFeatures` when a vector's
    /// length disagrees with the configured layout.
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, Error> {
        // validate both vectors against the configured layout before doing any work
        if input.len() != self.layout().input() {
            return Err(Error::InvalidInputFeatures(
                input.len(),
                self.layout().input(),
            ));
        }
        if target.len() != self.layout().output() {
            return Err(Error::InvalidTargetFeatures(
                target.len(),
                self.layout().output(),
            ));
        }
        // fall back to 0.01 when no learning rate is configured; `unwrap_or_else`
        // keeps the conversion from running when a rate is present
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or_else(|| A::from_f32(0.01).unwrap());
        // NOTE(review): an all-zero input or target divides by a zero norm here
        // (yielding inf/NaN for floats) — confirm callers never pass zero vectors.
        let input = input / input.l2_norm();
        let target = target / target.l2_norm();
        // record the post-activation output of every layer; activations[0] is the
        // normalized input itself, activations.last() the model output
        let mut activations = Vec::new();
        activations.push(input.to_owned());
        let mut output = self.store().input().forward_then(&input, |y| y.relu());
        activations.push(output.to_owned());
        for layer in self.store().hidden() {
            output = layer.forward(&output).relu();
            activations.push(output.to_owned());
        }
        output = self.store().output().forward(&output).sigmoid();
        activations.push(output.to_owned());
        // mean-squared-error loss on the normalized pair
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        // output-layer delta
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm();
        // NOTE(review): this passes the model *output* (`last`) to `backward`,
        // while the hidden-layer updates below pass the layer *input*
        // (`activations[i + 1]`) — confirm which activation `Params::backward`
        // expects; the two call sites disagree.
        self.store_mut()
            .output_mut()
            .backward(activations.last().unwrap(), &delta, lr);
        // propagate the delta backwards through the hidden stack
        let num_hidden = self.layout().layers();
        for i in (0..num_hidden).rev() {
            delta = if i == num_hidden - 1 {
                self.store().output().weights().dot(&delta) * activations[i + 1].relu_derivative()
            } else {
                // weights are stored (in, out), so propagation is `W.dot(delta)`;
                // the former `.t()` here was inconsistent with the output-layer
                // branch above and the input-layer step below (both of which must
                // be un-transposed for the non-square output matrix to fit)
                self.store().hidden()[i + 1].weights().dot(&delta)
                    * activations[i + 1].relu_derivative()
            };
            delta /= delta.l2_norm();
            self.store_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr);
        }
        // finally update the input layer; with no hidden layers the delta comes
        // straight from the output weights (indexing `hidden()[0]` would panic
        // on an empty stack)
        delta = if num_hidden == 0 {
            self.store().output().weights().dot(&delta) * activations[1].relu_derivative()
        } else {
            self.store().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative()
        };
        delta /= delta.l2_norm();
        self.store_mut()
            .input_mut()
            .backward(&activations[1], &delta, lr);
        Ok(loss)
    }
}
impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for TestModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Error = Error;
    type Output = A;
    /// Train on a batch: one SGD step per (input row, target row) pair,
    /// returning the *sum* of the per-sample losses. Fails fast on the first
    /// row that errors, logging which row failed.
    ///
    /// # Errors
    /// - empty batches or mismatched row counts are rejected up front;
    /// - `InvalidInputFeatures` / `InvalidTargetFeatures` when the column
    ///   widths disagree with the layout.
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<Self::Output, Self::Error> {
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(anyhow::anyhow!("Input and target batches must be non-empty").into());
        }
        // a row-count mismatch is a batching error, not a feature-width error;
        // the original folded it into the column check below and reported
        // `InvalidTargetFeatures` with *column* counts, which was misleading
        if input.nrows() != target.nrows() {
            return Err(anyhow::anyhow!(
                "Input and target batches must contain the same number of samples ({} != {})",
                input.nrows(),
                target.nrows()
            )
            .into());
        }
        if input.ncols() != self.layout().input() {
            return Err(Error::InvalidInputFeatures(
                input.ncols(),
                self.layout().input(),
            ));
        }
        if target.ncols() != self.layout().output() {
            return Err(Error::InvalidTargetFeatures(
                target.ncols(),
                self.layout().output(),
            ));
        }
        let batch_size = input.nrows();
        let mut loss = A::zero();
        // delegate each row pair to the Ix1 trainer, accumulating the loss
        for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
                Ok(l) => l,
                Err(err) => {
                    error!(
                        "Training failed for batch {}/{}: {:?}",
                        i + 1,
                        batch_size,
                        err
                    );
                    return Err(err);
                }
            };
        }
        Ok(loss)
    }
}