use crate::{simd, CnnResult, Tensor};
use super::Layer;
/// Selects which element-wise non-linearity an [`Activation`] applies.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ActivationType {
    /// `max(x, 0)`.
    ReLU,
    /// `min(max(x, 0), 6)`.
    ReLU6,
    /// `x * sigmoid(x)`.
    Swish,
    /// `x * min(max(x + 3, 0), 6) / 6`.
    HardSwish,
    /// `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Leaves every value untouched.
    Identity,
}

/// A configurable activation layer whose behaviour is chosen by an [`ActivationType`].
#[derive(Clone, Debug)]
pub struct Activation {
    activation_type: ActivationType,
}

impl Activation {
    /// Builds an activation layer of the requested kind.
    pub fn new(activation_type: ActivationType) -> Self {
        Activation { activation_type }
    }

    /// Returns the kind of non-linearity this layer applies.
    pub fn activation_type(&self) -> ActivationType {
        self.activation_type
    }

    /// Overwrites every element of `data` with the activation applied to it.
    ///
    /// NaN handling follows `f32::max` / `f32::min`: for `ReLU` and `ReLU6`
    /// a NaN input is replaced by `0.0`. `Identity` leaves `data` untouched.
    pub fn apply_inplace(&self, data: &mut [f32]) {
        // Logistic function shared by the `Sigmoid` and `Swish` variants.
        fn logistic(v: f32) -> f32 {
            1.0 / (1.0 + (-v).exp())
        }

        // Replaces each element with `op(element)`. It is generic so that each
        // closure is monomorphised and inlined, avoiding an indirect call per element.
        fn map_each(values: &mut [f32], op: impl Fn(f32) -> f32) {
            values.iter_mut().for_each(|v| *v = op(*v));
        }

        match self.activation_type {
            ActivationType::Identity => {}
            ActivationType::ReLU => map_each(data, |v| v.max(0.0)),
            ActivationType::ReLU6 => map_each(data, |v| v.max(0.0).min(6.0)),
            ActivationType::Swish => map_each(data, |v| v * logistic(v)),
            ActivationType::HardSwish => {
                map_each(data, |v| v * ((v + 3.0).max(0.0).min(6.0) / 6.0))
            }
            ActivationType::Sigmoid => map_each(data, logistic),
        }
    }
}
impl Layer for Activation {
    /// Copies `input` and applies the configured activation to the copy in place.
    /// The input tensor itself is never modified.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut activated = input.clone();
        let values = activated.data_mut();
        self.apply_inplace(values);
        Ok(activated)
    }

    /// Human-readable name of the configured activation kind.
    fn name(&self) -> &'static str {
        match self.activation_type() {
            ActivationType::Identity => "Identity",
            ActivationType::Sigmoid => "Sigmoid",
            ActivationType::HardSwish => "HardSwish",
            ActivationType::Swish => "Swish",
            ActivationType::ReLU6 => "ReLU6",
            ActivationType::ReLU => "ReLU",
        }
    }
}
/// Stateless ReLU layer that delegates the element-wise work to `simd::relu_simd`.
#[derive(Clone, Debug, Default)]
pub struct ReLU;

impl ReLU {
    /// Creates the layer; equivalent to `ReLU::default()`.
    pub fn new() -> Self {
        ReLU
    }
}

impl Layer for ReLU {
    /// Allocates a zero-filled tensor shaped like `input`.
    /// `simd::relu_simd` then fills it from `input`'s data.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut out = Tensor::zeros(input.shape());
        let (src, dst) = (input.data(), out.data_mut());
        simd::relu_simd(src, dst);
        Ok(out)
    }

    fn name(&self) -> &'static str {
        "ReLU"
    }
}
/// Stateless ReLU6 layer that delegates the element-wise work to `simd::relu6_simd`.
#[derive(Clone, Debug, Default)]
pub struct ReLU6;

impl ReLU6 {
    /// Creates the layer; equivalent to `ReLU6::default()`.
    pub fn new() -> Self {
        ReLU6
    }
}

impl Layer for ReLU6 {
    /// Allocates a zero-filled tensor shaped like `input`.
    /// `simd::relu6_simd` then fills it from `input`'s data.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut clipped = Tensor::zeros(input.shape());
        let (src, dst) = (input.data(), clipped.data_mut());
        simd::relu6_simd(src, dst);
        Ok(clipped)
    }

    fn name(&self) -> &'static str {
        "ReLU6"
    }
}
/// Stateless Swish layer backed by the scalar kernel `simd::scalar::swish_scalar`.
#[derive(Clone, Debug, Default)]
pub struct Swish;

impl Swish {
    /// Creates the layer; equivalent to `Swish::default()`.
    pub fn new() -> Self {
        Swish
    }
}

impl Layer for Swish {
    /// Allocates a zero-filled tensor shaped like `input`.
    /// `simd::scalar::swish_scalar` then fills it from `input`'s data.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut out = Tensor::zeros(input.shape());
        let (src, dst) = (input.data(), out.data_mut());
        simd::scalar::swish_scalar(src, dst);
        Ok(out)
    }

    fn name(&self) -> &'static str {
        "Swish"
    }
}
/// Stateless HardSwish layer backed by the scalar kernel `simd::scalar::hard_swish_scalar`.
#[derive(Clone, Debug, Default)]
pub struct HardSwish;

impl HardSwish {
    /// Creates the layer; equivalent to `HardSwish::default()`.
    pub fn new() -> Self {
        HardSwish
    }
}

impl Layer for HardSwish {
    /// Allocates a zero-filled tensor shaped like `input`.
    /// `simd::scalar::hard_swish_scalar` then fills it from `input`'s data.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut out = Tensor::zeros(input.shape());
        let (src, dst) = (input.data(), out.data_mut());
        simd::scalar::hard_swish_scalar(src, dst);
        Ok(out)
    }

    fn name(&self) -> &'static str {
        "HardSwish"
    }
}
/// Stateless logistic-sigmoid layer backed by the scalar kernel `simd::scalar::sigmoid_scalar`.
#[derive(Clone, Debug, Default)]
pub struct Sigmoid;

impl Sigmoid {
    /// Creates the layer; equivalent to `Sigmoid::default()`.
    pub fn new() -> Self {
        Sigmoid
    }
}

impl Layer for Sigmoid {
    /// Allocates a zero-filled tensor shaped like `input`.
    /// `simd::scalar::sigmoid_scalar` then fills it from `input`'s data.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let mut out = Tensor::zeros(input.shape());
        let (src, dst) = (input.data(), out.data_mut());
        simd::scalar::sigmoid_scalar(src, dst);
        Ok(out)
    }

    fn name(&self) -> &'static str {
        "Sigmoid"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Runs an `Activation` of the given kind over a copy of `input` and returns the result.
    fn apply(kind: ActivationType, input: &[f32]) -> Vec<f32> {
        let mut data = input.to_vec();
        Activation::new(kind).apply_inplace(&mut data);
        data
    }

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Tensor::from_data(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
        let output = relu.forward(&input).unwrap();
        assert_eq!(output.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu6() {
        let relu6 = ReLU6::new();
        let input = Tensor::from_data(vec![-2.0, 0.0, 3.0, 6.0, 10.0], &[5]).unwrap();
        let output = relu6.forward(&input).unwrap();
        assert_eq!(output.data(), &[0.0, 0.0, 3.0, 6.0, 6.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Tensor::from_data(vec![0.0], &[1]).unwrap();
        let output = sigmoid.forward(&input).unwrap();
        assert!((output.data()[0] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_swish() {
        let swish = Swish::new();
        let input = Tensor::from_data(vec![0.0, 1.0, -1.0], &[3]).unwrap();
        let output = swish.forward(&input).unwrap();
        assert!(output.data()[0].abs() < 0.001);
        assert!((output.data()[1] - 0.731).abs() < 0.01);
        // swish(-1) = -1 * sigmoid(-1) ≈ -0.269. This input was previously fed in but never checked.
        assert!((output.data()[2] + 0.269).abs() < 0.01);
    }

    #[test]
    fn test_hard_swish() {
        let hs = HardSwish::new();
        let input = Tensor::from_data(vec![-4.0, -3.0, 0.0, 3.0, 4.0], &[5]).unwrap();
        let output = hs.forward(&input).unwrap();
        assert!(output.data()[0].abs() < 0.001);
        assert!(output.data()[1].abs() < 0.001);
        assert!(output.data()[2].abs() < 0.001);
        assert!((output.data()[3] - 3.0).abs() < 0.001);
        // For x >= 3 hard-swish is the identity. This input was previously unchecked.
        assert!((output.data()[4] - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_activation_relu_variants() {
        assert_eq!(
            apply(ActivationType::ReLU, &[-2.0, -1.0, 0.0, 1.0, 2.0]),
            vec![0.0, 0.0, 0.0, 1.0, 2.0]
        );
        assert_eq!(
            apply(ActivationType::ReLU6, &[-2.0, 0.0, 3.0, 6.0, 10.0]),
            vec![0.0, 0.0, 3.0, 6.0, 6.0]
        );
    }

    #[test]
    fn test_activation_smooth_variants() {
        let s = apply(ActivationType::Sigmoid, &[0.0]);
        assert!((s[0] - 0.5).abs() < 0.001);

        let w = apply(ActivationType::Swish, &[0.0, 1.0, -1.0]);
        assert!(w[0].abs() < 0.001);
        assert!((w[1] - 0.731).abs() < 0.01);
        assert!((w[2] + 0.269).abs() < 0.01);

        // -0.0 compares equal to 0.0, so the negative inputs that map to -0.0 still match.
        assert_eq!(
            apply(ActivationType::HardSwish, &[-4.0, -3.0, 0.0, 3.0, 4.0]),
            vec![0.0, 0.0, 0.0, 3.0, 4.0]
        );
    }

    #[test]
    fn test_activation_identity_is_noop() {
        let input = [-1.5_f32, 0.0, 2.5];
        assert_eq!(apply(ActivationType::Identity, &input), input.to_vec());
    }

    #[test]
    fn test_activation_layer_matches_dedicated_relu() {
        let input = Tensor::from_data(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
        let generic = Activation::new(ActivationType::ReLU).forward(&input).unwrap();
        let dedicated = ReLU::new().forward(&input).unwrap();
        assert_eq!(generic.data(), dedicated.data());
    }

    #[test]
    fn test_activation_names_and_type() {
        let cases = [
            (ActivationType::ReLU, "ReLU"),
            (ActivationType::ReLU6, "ReLU6"),
            (ActivationType::Swish, "Swish"),
            (ActivationType::HardSwish, "HardSwish"),
            (ActivationType::Sigmoid, "Sigmoid"),
            (ActivationType::Identity, "Identity"),
        ];
        for &(kind, expected) in cases.iter() {
            let layer = Activation::new(kind);
            assert_eq!(layer.name(), expected);
            assert_eq!(layer.activation_type(), kind);
        }
    }
}