use crate::params_base::ParamsBase;
use ndarray::{ArrayBase, DataOwned, Dimension, Ix1, Ix2, RawData};
// Dimension-generic inherent impl for `ParamsBase`, covering any `D: Dimension`
// whose storage `S` holds elements of type `A`.
// NOTE(review): this block is currently empty. Either add methods that apply to every
// dimensionality here, or remove the block; as written it contributes nothing.
impl<A, S, D> ParamsBase<S, D, A>
where
D: Dimension,
S: RawData<Elem = A>,
{
}
impl<A, S> ParamsBase<S, Ix1>
where
    S: RawData<Elem = A>,
{
    /// Returns the number of elements in the one-dimensional weight array.
    ///
    /// For 1-D parameters the "row count" is simply the length of the weight vector.
    pub fn nrows(&self) -> usize {
        let weights = self.weights();
        weights.len()
    }

    /// Builds one-dimensional parameters from a scalar `bias` and a `weights` vector.
    ///
    /// The scalar is wrapped in a zero-dimensional array (shape `()`) to serve as the
    /// bias. `weights` is stored unchanged.
    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
    where
        S: DataOwned,
        A: Clone,
    {
        // Materialize the 0-D bias array first, then assemble the struct.
        let bias = ArrayBase::from_elem((), bias);
        Self { weights, bias }
    }
}
impl<A, S> ParamsBase<S, Ix2>
where
    S: RawData<Elem = A>,
{
    /// Returns the number of rows (length of axis 0) of the two-dimensional weights.
    pub fn nrows(&self) -> usize {
        let (rows, _) = self.weights().dim();
        rows
    }

    /// Returns the number of columns (length of axis 1) of the two-dimensional weights.
    pub fn ncols(&self) -> usize {
        let (_, cols) = self.weights().dim();
        cols
    }
}