use crate::activations::Activation;
use crate::error::Result;
use scirs2_core::ndarray::{Array, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Marker type for the Mish activation function.
///
/// The type carries no fields, so every value is interchangeable and it is
/// freely copyable.
#[derive(Clone, Copy, Debug)]
pub struct Mish;
impl Mish {
pub fn new() -> Self {
Self
}
}
impl Default for Mish {
fn default() -> Self {
Self::new()
}
}
/// Softplus, `ln(1 + e^x)`, shared by the forward and backward passes.
///
/// Inputs strictly greater than `linear_above` are returned unchanged: in that
/// range `ln(1 + e^x)` is numerically indistinguishable from `x`, and skipping
/// the exponential avoids overflow for very large inputs.
fn stable_softplus<F: Float>(x: F, linear_above: F) -> F {
    if x > linear_above {
        x
    } else {
        // `ln_1p` keeps the tiny `e^x` contribution for negative `x`.
        // Computing `(1 + e^x).ln()` instead rounds `1 + e^x` to exactly 1 once
        // `e^x` falls below the type's epsilon, which collapses the result to 0.
        x.exp().ln_1p()
    }
}

/// Mish activation: `mish(x) = x * tanh(softplus(x))`, applied element-wise.
impl<F: Float + Debug + NumAssign> Activation<F> for Mish {
    /// Applies Mish to every element of `input`.
    ///
    /// Returns a new array with the same shape as `input`; `input` itself is
    /// not modified. This implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the constant `20.0` cannot be represented in `F`.
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Converted once here instead of once per element inside the closure.
        let linear_above = F::from(20.0).expect("Failed to convert constant to float");
        let mut output = input.clone();
        Zip::from(&mut output).for_each(|x| {
            let sp = stable_softplus(*x, linear_above);
            *x *= sp.tanh();
        });
        Ok(output)
    }

    /// Computes the gradient of the loss with respect to `input`.
    ///
    /// For each element the upstream gradient is multiplied by
    /// `d mish / dx = tanh(sp) + x * (1 - tanh(sp)^2) * sigmoid(x)`,
    /// where `sp = softplus(x)` and `sigmoid(x) = 1 / (1 + e^(-x))` is the
    /// derivative of softplus.
    ///
    /// Returns a new array shaped like `grad_output`. This implementation
    /// always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the constant `20.0` cannot be represented in `F`.
    ///
    /// NOTE(review): per the ndarray documentation, `Zip::and` panics when the
    /// operands have different shapes, so a `grad_output`/`input` shape
    /// mismatch panics here rather than producing an `Err` - confirm callers
    /// guarantee matching shapes.
    fn backward(
        &self,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Converted once here instead of once per element inside the closure.
        let linear_above = F::from(20.0).expect("Failed to convert constant to float");
        let mut grad_input = Array::zeros(grad_output.raw_dim());
        Zip::from(&mut grad_input)
            .and(grad_output)
            .and(input)
            .for_each(|grad_in, &grad_out, &x| {
                let tanh_sp = stable_softplus(x, linear_above).tanh();
                // sech^2(sp) written as 1 - tanh^2(sp) to reuse `tanh_sp`.
                let sech_sp_sq = F::one() - tanh_sp * tanh_sp;
                let sigmoid_x = F::one() / (F::one() + (-x).exp());
                let dmish_dx = tanh_sp + x * sech_sp_sq * sigmoid_x;
                *grad_in = grad_out * dmish_dx;
            });
        Ok(grad_input)
    }
}