use crate::models::ModelParamsBase;
use crate::{DeepModelRepr, RawHidden};
use concision_params::ParamsBase;
use ndarray::{ArrayBase, Data, Dimension, RawData, RawDataClone};
impl<S, D, H, A> ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
    H: RawHidden<S, D>,
{
    /// Builds the parameters from an input layer, the hidden representation and an output layer.
    pub const fn new(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
        Self { input, hidden, output }
    }

    /// Borrows the input-layer parameters.
    pub const fn input(&self) -> &ParamsBase<S, D> {
        &self.input
    }

    /// Mutably borrows the input-layer parameters.
    pub const fn input_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.input
    }

    /// Borrows the hidden-layer representation.
    pub const fn hidden(&self) -> &H {
        &self.hidden
    }

    /// Mutably borrows the hidden-layer representation.
    pub const fn hidden_mut(&mut self) -> &mut H {
        &mut self.hidden
    }

    /// Borrows the output-layer parameters.
    pub const fn output(&self) -> &ParamsBase<S, D> {
        &self.output
    }

    /// Mutably borrows the output-layer parameters.
    pub const fn output_mut(&mut self) -> &mut ParamsBase<S, D> {
        &mut self.output
    }

    /// Replaces the input-layer parameters in place.
    #[inline]
    pub fn set_input(&mut self, input: ParamsBase<S, D>) {
        self.input = input;
    }

    /// Replaces the hidden-layer representation in place.
    #[inline]
    pub fn set_hidden(&mut self, hidden: H) {
        self.hidden = hidden;
    }

    /// Replaces the output-layer parameters in place.
    #[inline]
    pub fn set_output(&mut self, output: ParamsBase<S, D>) {
        self.output = output;
    }

    /// Consumes `self`, returning it with the input layer swapped for `input`.
    #[inline]
    pub fn with_input(mut self, input: ParamsBase<S, D>) -> Self {
        self.input = input;
        self
    }

    /// Consumes `self`, returning it with the hidden representation swapped for `hidden`.
    #[inline]
    pub fn with_hidden(mut self, hidden: H) -> Self {
        self.hidden = hidden;
        self
    }

    /// Consumes `self`, returning it with the output layer swapped for `output`.
    #[inline]
    pub fn with_output(mut self, output: ParamsBase<S, D>) -> Self {
        self.output = output;
        self
    }

    /// Views the hidden layers as a slice, for hidden representations implementing
    /// [`DeepModelRepr`].
    #[inline]
    pub fn hidden_as_slice(&self) -> &[ParamsBase<S, D>]
    where
        H: DeepModelRepr<S, D>,
    {
        self.hidden().as_slice()
    }

    /// Borrows the bias of the input layer.
    pub const fn input_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
        self.input.bias()
    }

    /// Mutably borrows the bias of the input layer.
    pub const fn input_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
        self.input.bias_mut()
    }

    /// Borrows the weights of the input layer.
    pub const fn input_weights(&self) -> &ArrayBase<S, D, A> {
        self.input.weights()
    }

    /// Mutably borrows the weights of the input layer.
    pub const fn input_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
        self.input.weights_mut()
    }

    /// Borrows the bias of the output layer.
    pub const fn output_bias(&self) -> &ArrayBase<S, D::Smaller, A> {
        self.output.bias()
    }

    /// Mutably borrows the bias of the output layer.
    pub const fn output_bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
        self.output.bias_mut()
    }

    /// Borrows the weights of the output layer.
    pub const fn output_weights(&self) -> &ArrayBase<S, D, A> {
        self.output.weights()
    }

    /// Mutably borrows the weights of the output layer.
    pub const fn output_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
        self.output.weights_mut()
    }

    /// Total number of layers: the hidden count plus the input and output layers.
    pub fn layers(&self) -> usize {
        self.count_hidden() + 2
    }

    /// Number of hidden layers as reported by the hidden representation's `count`.
    pub fn count_hidden(&self) -> usize {
        self.hidden().count()
    }

    /// `true` when there is at most one hidden layer (the complement of [`is_deep`](Self::is_deep)).
    #[inline]
    pub fn is_shallow(&self) -> bool {
        !self.is_deep()
    }

    /// `true` when there are two or more hidden layers.
    #[inline]
    pub fn is_deep(&self) -> bool {
        self.count_hidden() > 1
    }
}
impl<A, S, D, H> Clone for ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    H: RawHidden<S, D> + Clone,
    S: RawDataClone<Elem = A>,
    A: Clone,
{
    /// Deep-clones the input layer, the hidden representation and the output layer, in that
    /// order, and reassembles them through [`ModelParamsBase::new`].
    fn clone(&self) -> Self {
        Self::new(self.input.clone(), self.hidden.clone(), self.output.clone())
    }
}
impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    H: RawHidden<S, D> + core::fmt::Debug,
    S: Data<Elem = A>,
    A: core::fmt::Debug,
{
    /// Formats as a struct named `ModelParams` with `input`, `hidden` and `output` fields.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            input,
            hidden,
            output,
        } = self;
        let mut builder = f.debug_struct("ModelParams");
        builder.field("input", input);
        builder.field("hidden", hidden);
        builder.field("output", output);
        builder.finish()
    }
}
impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    H: RawHidden<S, D> + core::fmt::Debug,
    S: Data<Elem = A>,
    A: core::fmt::Display,
{
    /// Writes `{ input: <input>, hidden: <hidden>, output: <output> }`, using `Display` for the
    /// input/output layers and `Debug` for the hidden representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            input,
            hidden,
            output,
        } = self;
        write!(
            f,
            "{{ input: {input}, hidden: {hidden:?}, output: {output} }}"
        )
    }
}
impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    S: Data<Elem = A>,
    H: RawHidden<S, D> + core::ops::Index<usize, Output = ParamsBase<S, D>>,
    A: Clone,
{
    type Output = ParamsBase<S, D>;

    /// Returns the layer at `index`, where `0` is the input layer, `1..=count_hidden()` are the
    /// hidden layers and `layers() - 1` is the output layer.
    ///
    /// Indices wrap around modulo `layers()`, so `index == layers()` refers to the input layer
    /// again.
    ///
    /// # Panics
    ///
    /// Panics if the hidden representation's own `Index` impl panics. This may happen if its
    /// `count()` disagrees with the number of indexable hidden layers (not verified here).
    fn index(&self, index: usize) -> &Self::Output {
        let layers = self.layers();
        // `layers()` is `2 + count_hidden()`, so it is never zero and the modulus cannot panic.
        let position = index % layers;
        match position {
            0 => self.input(),
            p if p == layers - 1 => self.output(),
            // Index with the *wrapped* position. Using the raw `index` here reads the wrong
            // hidden layer (or goes out of bounds) whenever `index >= layers()`.
            p => &self.hidden()[p - 1],
        }
    }
}
impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H, A>
where
    D: Dimension,
    S: Data<Elem = A>,
    H: RawHidden<S, D> + core::ops::IndexMut<usize, Output = ParamsBase<S, D>>,
    A: Clone,
{
    /// Mutably borrows the layer at `index`, using the same layout as `Index`: `0` is the
    /// input layer, `1..=count_hidden()` are the hidden layers and `layers() - 1` is the output
    /// layer.
    ///
    /// Indices wrap around modulo `layers()`.
    ///
    /// # Panics
    ///
    /// Panics if the hidden representation's own `IndexMut` impl panics. This may happen if its
    /// `count()` disagrees with the number of indexable hidden layers (not verified here).
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        // Compute the layout up front so the immutable borrow ends before any `&mut` is taken.
        let layers = self.layers();
        // `layers()` is `2 + count_hidden()`, so it is never zero and the modulus cannot panic.
        let position = index % layers;
        match position {
            0 => self.input_mut(),
            p if p == layers - 1 => self.output_mut(),
            // Index with the *wrapped* position. Using the raw `index` here reaches the wrong
            // hidden layer (or goes out of bounds) whenever `index >= layers()`.
            p => &mut self.hidden_mut()[p - 1],
        }
    }
}