use super::*;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, BinRead)]
pub struct Layer<
W: BinRead<Args = ()> + Debug,
B: BinRead<Args = ()>,
const NUM_INPUTS: usize,
const NUM_OUTPUTS: usize,
> {
biases: Box<MathVec<B, NUM_OUTPUTS>>,
#[br(count = NUM_OUTPUTS, map = |v: Vec<MathVec<W, NUM_INPUTS>>| SerdeWrapper::from_boxed_value(v.try_into().unwrap()))]
weights_transpose: Box<SerdeWrapper<[MathVec<W, NUM_INPUTS>; NUM_OUTPUTS]>>,
}
impl<
W: BinRead<Args = ()> + Debug,
B: BinRead<Args = ()>,
const NUM_INPUTS: usize,
const NUM_OUTPUTS: usize,
> Layer<W, B, NUM_INPUTS, NUM_OUTPUTS>
{
#[inline]
pub fn get_weights_transpose(&self) -> &[MathVec<W, NUM_INPUTS>; NUM_OUTPUTS] {
&self.weights_transpose
}
#[inline]
pub fn get_biases(&self) -> &MathVec<B, NUM_OUTPUTS> {
&self.biases
}
}
impl<
W: BinRead<Args = ()> + Clone + Debug,
B: BinRead<Args = ()> + Clone + AddAssign + From<W> + Mul + Sum<<B as Mul>::Output>,
const NUM_INPUTS: usize,
const NUM_OUTPUTS: usize,
> Layer<W, B, NUM_INPUTS, NUM_OUTPUTS>
{
pub fn forward(&self, inputs: MathVec<W, NUM_INPUTS>) -> MathVec<B, NUM_OUTPUTS> {
let mut outputs = self.get_biases().clone();
outputs
.iter_mut()
.zip(self.weights_transpose.iter())
.for_each(|(o, w)| *o += inputs.dot(w));
outputs
}
}