use super::{
gelu, relu, sigmoid, swish, tanh, ActivationFn, ActivationKind, Backward, Closure,
Layer, PhantomData, Tensor, TensorFloat, WithGrad,
};
/// Description of an activation layer whose activation function is chosen
/// at the type level through the third type parameter.
///
/// Call `build` to turn it into an [`ActivationLayer`].
pub struct Activation<const RANK: usize, const SIZE: usize, Activator> {
    /// `RANK` sizes (presumably one extent per tensor dimension — confirm
    /// against callers).
    pub shape: [usize; RANK],
    /// `SIZE` scalar values (presumably the flattened tensor contents —
    /// confirm against callers).
    pub data: [TensorFloat; SIZE],
    /// Marker that records the activation type without storing a value of it.
    pub _activation: PhantomData<Activator>,
}
impl<const RANK: usize, const SIZE: usize, Activator> Activation<RANK, SIZE, Activator>
where
    Activator: ActivationFn,
{
    /// Consumes this description and produces the runtime layer.
    ///
    /// The resulting [`ActivationLayer`] stores only the activation kind
    /// reported by `Activator::kind()`; the `shape` and `data` fields are
    /// dropped without being read.
    #[must_use]
    pub fn build(self) -> ActivationLayer<RANK, SIZE> {
        let actfn = Activator::kind();
        ActivationLayer { actfn }
    }
}
/// Runtime activation layer over tensors of rank `RANK` holding `SIZE`
/// elements.
///
/// The only state is the activation kind, which is set when the layer is
/// created and selects the function applied in the forward pass.
pub struct ActivationLayer<const RANK: usize, const SIZE: usize> {
/// Which activation function this layer applies.
actfn: ActivationKind,
}
impl<const RANK: usize, const SIZE: usize> ActivationLayer<RANK, SIZE> {
#[inline]
#[must_use]
pub fn forward<'a>(
&'a self,
input: &'a WithGrad<Tensor<RANK, SIZE>>,
) -> (Tensor<RANK, SIZE>, Backward<'a, RANK, SIZE, SIZE, SIZE>) {
match self.actfn {
ActivationKind::ReLU => {
let (out, back) = relu(input);
(out, Backward::Unary(back))
}
ActivationKind::Sigmoid => {
let (out, back) = sigmoid(input);
(out, Backward::Unary(back))
}
ActivationKind::GELU => {
let (out, back) = gelu(input);
(out, Backward::Unary(back))
}
ActivationKind::Swish => {
let (out, back) = swish(input);
(out, Backward::Unary(back))
}
ActivationKind::Tanh => {
let (out, back) = tanh(input);
(out, Backward::Unary(back))
}
}
}
#[inline]
#[must_use]
#[allow(clippy::unnecessary_wraps)]
pub fn backward<const IN_SIZE: usize, const OUT_SIZE: usize>(
&self,
grad_output: Tensor<RANK, IN_SIZE>,
back: Backward<'_, RANK, SIZE, IN_SIZE, OUT_SIZE>,
) -> (Tensor<RANK, OUT_SIZE>, Option<Tensor<RANK, OUT_SIZE>>) {
match back {
Backward::Unary(f) => {
let grad_in = f.invoke(grad_output);
(grad_in, None)
}
Backward::Binary(_) => {
unreachable!("activation never has a binary closure");
}
}
}
}
/// [`Layer`] implementation with a weight count of zero: the layer's only
/// field is its activation kind, so there are no weight tensors to expose.
impl<const RANK: usize, const SIZE: usize> Layer<RANK, SIZE, 0> for ActivationLayer<RANK, SIZE> {
    /// Returns an empty array; this layer has no weights.
    #[inline]
    fn weights(&self) -> [&WithGrad<Tensor<RANK, SIZE>>; 0] {
        []
    }

    /// Returns an empty array; this layer has no weights to mutate.
    #[inline]
    fn weights_mut(&mut self) -> [&mut WithGrad<Tensor<RANK, SIZE>>; 0] {
        []
    }
}