use ndarray::{ArrayD, ArrayViewD, IxDyn};
use serde::{Deserialize, Serialize};
use crate::{
core::{NNMode, NNResult},
layers::{Layer, TrainLayer},
utils::{ActivationFunction, MSGPackFormatting, Optimizer},
};
use mininn_derive::Layer;
/// Activation layer: applies a nonlinear activation function element-wise.
///
/// The layer caches the input of the most recent forward pass so that
/// `backward` can evaluate the activation's derivative at the same points.
#[derive(Layer, Serialize, Deserialize, Clone, Debug)]
pub struct Activation {
// Input cached by the last `forward` call; starts as an empty array (shape [0]).
input: ArrayD<f32>,
// Boxed trait object so any activation (built-in or user-defined) can be stored.
activation: Box<dyn ActivationFunction>,
}
impl Activation {
    /// Builds an activation layer around the given activation function.
    ///
    /// The cached input starts out as an empty array (shape `[0]`) and is
    /// replaced on the first call to `forward`.
    #[inline]
    pub fn new(activation: impl ActivationFunction + 'static) -> Self {
        let empty_input = ArrayD::zeros(IxDyn(&[0]));
        Self {
            input: empty_input,
            activation: Box::new(activation),
        }
    }

    /// Returns the name of the wrapped activation function.
    #[inline]
    pub fn activation(&self) -> &str {
        self.activation.name()
    }

    /// Replaces the wrapped activation function; the cached input is kept.
    #[inline]
    pub fn set_activation(&mut self, activation: impl ActivationFunction + 'static) {
        self.activation = Box::new(activation);
    }
}
impl TrainLayer for Activation {
    /// Forward pass: caches `input` and returns the activation applied
    /// element-wise to it. The cache is needed later by `backward`.
    fn forward(&mut self, input: ArrayViewD<f32>, _mode: &NNMode) -> NNResult<ArrayD<f32>> {
        self.input = input.to_owned();
        let activated = self.activation.function(&self.input.view());
        Ok(activated)
    }

    /// Backward pass (chain rule): `dL/dx = dL/dy * f'(x)`, where `f'` is
    /// evaluated at the input cached by the last `forward` call.
    /// Learning rate, optimizer, and mode are unused — this layer has no
    /// trainable parameters.
    #[inline]
    fn backward(
        &mut self,
        output_gradient: ArrayViewD<f32>,
        _learning_rate: f32,
        _optimizer: &Optimizer,
        _mode: &NNMode,
    ) -> NNResult<ArrayD<f32>> {
        let local_gradient = self.activation.derivate(&self.input.view());
        Ok(output_gradient.to_owned() * local_gradient)
    }
}
#[cfg(test)]
mod tests {
    use mininn_derive::ActivationFunction;
    use ndarray::ArrayViewD;
    use crate::utils::{Act, ActCore, NNUtil};
    use super::*;

    // Helper: build a 1-D dynamic-dimension array from a slice.
    fn arr1d(data: &[f32]) -> ArrayD<f32> {
        ArrayD::from_shape_vec(IxDyn(&[data.len()]), data.to_vec()).unwrap()
    }

    #[test]
    fn test_activation_creation() {
        let layer = Activation::new(Act::Tanh);
        assert_eq!(layer.activation(), "Tanh");
    }

    #[test]
    fn test_forward_pass() {
        let mut layer = Activation::new(Act::ReLU);
        let input = arr1d(&[0.5, -0.3, 0.8]);
        let output = layer.forward(input.view(), &NNMode::Test).unwrap();
        // ReLU clamps the negative entry to zero.
        assert_eq!(output, arr1d(&[0.5, 0.0, 0.8]));
    }

    #[test]
    fn test_backward_pass() {
        let mut layer = Activation::new(Act::ReLU);
        let input = arr1d(&[0.5, -0.3, 0.8]);
        layer.forward(input.view(), &NNMode::Test).unwrap();
        let upstream = arr1d(&[1.0, 1.0, 1.0]);
        let result = layer
            .backward(upstream.view(), 0.1, &Optimizer::GD, &NNMode::Test)
            .unwrap();
        // ReLU derivative is 1 for positive inputs, 0 otherwise.
        assert_eq!(result, arr1d(&[1.0, 0.0, 1.0]));
    }

    #[test]
    fn test_activation_msg_pack() {
        let layer = Activation::new(Act::ReLU);
        let bytes = layer.to_msgpack().unwrap();
        assert!(!bytes.is_empty());
        let restored: Box<dyn Layer> = Activation::from_msgpack(&bytes).unwrap();
        assert_eq!(layer.layer_type(), restored.layer_type());
    }

    #[test]
    fn test_activation_layer_custom_activation() {
        // Squares each element; derivative is 2x.
        #[derive(ActivationFunction, Debug, Clone)]
        struct CustomActivation;

        impl ActCore for CustomActivation {
            fn function(&self, z: &ArrayViewD<f32>) -> ArrayD<f32> {
                z.mapv(|x| x.powi(2))
            }

            fn derivate(&self, z: &ArrayViewD<f32>) -> ArrayD<f32> {
                z.mapv(|x| 2. * x)
            }
        }

        let layer = Activation::new(CustomActivation);
        assert_eq!(layer.activation(), "CustomActivation");
    }
}