use super::{softmax, Backward, Layer, Tensor, TensorFloat, WithGrad};

/// Builder for a [`SoftmaxLayer`].
///
/// `shape` and `data` describe the tensor the layer operates on; only
/// `temp` (the softmax temperature) is carried into the built layer.
pub struct Softmax<const RANK: usize, const SIZE: usize> {
    pub shape: [usize; RANK],
    pub data: [TensorFloat; SIZE],
    pub temp: f32,
}

impl<const RANK: usize, const SIZE: usize> Softmax<RANK, SIZE> {
    /// Consumes the builder and produces the layer.
    #[must_use]
    pub const fn build(self) -> SoftmaxLayer<RANK, SIZE> {
        SoftmaxLayer { temp: self.temp }
    }
}

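// A minimal usage sketch for the builder; the `RANK`/`SIZE` values and the
// field contents below are illustrative assumptions, not taken from tests:
//
//     let layer = Softmax::<1, 4> {
//         shape: [4],
//         data: [0.0; 4],
//         temp: 1.0,
//     }
//     .build();
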
/// A softmax activation layer with a configurable temperature.
///
/// The layer holds no trainable weights; `RANK` and `SIZE` only pin the
/// shape of the tensors it accepts.
pub struct SoftmaxLayer<const RANK: usize, const SIZE: usize> {
    temp: f32,
}

impl<const RANK: usize, const SIZE: usize> SoftmaxLayer<RANK, SIZE> {
    /// Returns the current temperature. The `f32` argument is currently
    /// ignored.
    #[must_use]
    pub const fn get_temp_relative(&self, _: f32) -> f32 {
        self.temp
    }

    /// Updates the temperature in place by applying `f` to its current value.
    pub fn update_temp<F>(&mut self, f: F)
    where
        F: Fn(f32) -> f32,
    {
        self.temp = f(self.temp);
    }
}

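// Temperature tweaking sketch (values are illustrative assumptions):
//
//     let mut layer =
//         Softmax::<1, 4> { shape: [4], data: [0.0; 4], temp: 2.0 }.build();
//     layer.update_temp(|t| t * 0.5); // halve the temperature to 1.0
//     assert_eq!(layer.get_temp_relative(0.0), 1.0);
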
#[cfg(feature = "dyntensor")]
impl<const RANK: usize, const SIZE: usize> SoftmaxLayer<RANK, SIZE> {
#[inline]
#[must_use]
pub fn forward<'a>(
&'a self,
input: &'a WithGrad<Tensor<RANK, 0>>,
) -> (Tensor<RANK, 0>, Backward<'a, RANK, 0, 0, 0>) {
let (out, f) = softmax(input);
(out, Backward::Unary(f))
}
#[inline]
#[must_use]
#[allow(clippy::unnecessary_wraps)]
pub fn backward(
&self,
grad_output: Tensor<RANK, 0>,
back: Backward<'_, RANK, 0, 0, 0>,
) -> (Tensor<RANK, 0>, Option<Tensor<RANK, 0>>) {
match back {
Backward::Unary(f) => {
let grad_in = f.invoke(grad_output);
(grad_in, None)
}
Backward::Binary(_) => {
unreachable!("Softmax never has a binary closure");
}
}
}
}
#[cfg(not(feature = "dyntensor"))]
impl<const RANK: usize, const SIZE: usize> SoftmaxLayer<RANK, SIZE> {
    /// Applies softmax to `input`, returning the output tensor and the
    /// backward closure needed to propagate gradients.
    #[inline]
    #[must_use]
    pub fn forward<'a, const OUT_SIZE: usize>(
        &'a self,
        input: &'a WithGrad<Tensor<RANK, SIZE>>,
    ) -> (
        Tensor<RANK, OUT_SIZE>,
        Backward<'a, RANK, OUT_SIZE, OUT_SIZE, SIZE>,
    ) {
        let (out, f) = softmax(input);
        (out, Backward::Unary(f))
    }

    /// Propagates `grad_output` through the closure captured in `forward`.
    /// Softmax is unary, so the second element of the pair is always `None`.
    #[inline]
    #[must_use]
    #[allow(clippy::unnecessary_wraps)]
    pub fn backward<const IN_SIZE: usize, const OUT_SIZE: usize>(
        &self,
        grad_output: Tensor<RANK, IN_SIZE>,
        back: Backward<'_, RANK, SIZE, IN_SIZE, OUT_SIZE>,
    ) -> (Tensor<RANK, OUT_SIZE>, Option<Tensor<RANK, OUT_SIZE>>) {
        match back {
            Backward::Unary(f) => {
                let grad_in = f.invoke(grad_output);
                (grad_in, None)
            }
            Backward::Binary(_) => {
                unreachable!("Softmax never has a binary closure");
            }
        }
    }
}

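// A hedged forward/backward round-trip sketch for the statically sized path.
// `WithGrad::new` and `Tensor::new` are hypothetical constructors here; the
// crate's real construction API may differ:
//
//     let layer =
//         Softmax::<1, 4> { shape: [4], data: [0.0; 4], temp: 1.0 }.build();
//     let input: WithGrad<Tensor<1, 4>> = WithGrad::new(Tensor::new([0.25; 4]));
//     let (out, back) = layer.forward::<4>(&input);
//     // `upstream` is the gradient flowing back from the next layer.
//     let upstream: Tensor<1, 4> = Tensor::new([1.0; 4]);
//     let (grad_in, none) = layer.backward::<4, 4>(upstream, back);
//     assert!(none.is_none()); // softmax is unary: no second gradient
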
impl<const RANK: usize, const IN_SIZE: usize> Layer<RANK, IN_SIZE, 0>
    for SoftmaxLayer<RANK, IN_SIZE>
{
    /// Softmax has no trainable parameters, so both weight arrays are empty.
    #[inline]
    fn weights(&self) -> [&WithGrad<Tensor<RANK, IN_SIZE>>; 0] {
        []
    }

    #[inline]
    fn weights_mut(&mut self) -> [&mut WithGrad<Tensor<RANK, IN_SIZE>>; 0] {
        []
    }
}
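
// Because `SoftmaxLayer` implements `Layer` with zero weights, generic
// training code can treat it uniformly. A sketch, assuming `Layer` exposes
// only the `weights`/`weights_mut` accessors shown above:
//
//     fn num_weight_tensors<const R: usize, const S: usize, const W: usize>(
//         layer: &impl Layer<R, S, W>,
//     ) -> usize {
//         layer.weights().len() // always `W`; 0 for softmax
//     }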