use axonml_autograd::Variable;
use crate::module::Module;
/// Rectified Linear Unit activation: `max(0, x)` element-wise.
///
/// Stateless and parameter-free; `forward` delegates to [`Variable::relu`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ReLU;

impl ReLU {
    /// Constructs a new `ReLU` module.
    pub fn new() -> Self {
        ReLU
    }
}

impl Module for ReLU {
    /// Applies ReLU element-wise to `input`.
    fn forward(&self, input: &Variable) -> Variable {
        input.relu()
    }

    fn name(&self) -> &'static str {
        "ReLU"
    }
}
/// Leaky ReLU activation: identity for positive inputs, a small linear
/// slope (`negative_slope * x`) for negative inputs.
#[derive(Debug, Clone, Copy)]
pub struct LeakyReLU {
    /// Multiplier applied to negative inputs (0.01 by default).
    negative_slope: f32,
}

impl LeakyReLU {
    /// Creates a `LeakyReLU` with the conventional slope of `0.01`.
    pub fn new() -> Self {
        Self::with_slope(0.01)
    }

    /// Creates a `LeakyReLU` with a caller-chosen `negative_slope`.
    pub fn with_slope(negative_slope: f32) -> Self {
        Self { negative_slope }
    }
}

impl Default for LeakyReLU {
    fn default() -> Self {
        LeakyReLU::new()
    }
}

impl Module for LeakyReLU {
    /// Applies the leaky ReLU element-wise, using the stored slope.
    fn forward(&self, input: &Variable) -> Variable {
        input.leaky_relu(self.negative_slope)
    }

    fn name(&self) -> &'static str {
        "LeakyReLU"
    }
}
/// Logistic sigmoid activation: `1 / (1 + exp(-x))` element-wise.
///
/// Stateless; `forward` delegates to [`Variable::sigmoid`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Sigmoid;

impl Sigmoid {
    /// Constructs a new `Sigmoid` module.
    pub fn new() -> Self {
        Sigmoid
    }
}

impl Module for Sigmoid {
    /// Applies the sigmoid element-wise to `input`.
    fn forward(&self, input: &Variable) -> Variable {
        input.sigmoid()
    }

    fn name(&self) -> &'static str {
        "Sigmoid"
    }
}
/// Hyperbolic tangent activation, applied element-wise.
///
/// Stateless; `forward` delegates to [`Variable::tanh`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh;

impl Tanh {
    /// Constructs a new `Tanh` module.
    pub fn new() -> Self {
        Tanh
    }
}

impl Module for Tanh {
    /// Applies tanh element-wise to `input`.
    fn forward(&self, input: &Variable) -> Variable {
        input.tanh()
    }

    fn name(&self) -> &'static str {
        "Tanh"
    }
}
/// Softmax activation: normalizes values along one dimension so they sum to 1.
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    /// Dimension to normalize over; negative values index from the end
    /// (e.g. `-1` is the last dimension).
    dim: i64,
}

impl Softmax {
    /// Creates a `Softmax` normalizing over dimension `dim`.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for Softmax {
    /// Defaults to normalizing over the last dimension (`-1`).
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for Softmax {
    /// Applies softmax along the configured dimension.
    ///
    /// # Panics
    /// Panics if the stored `i64` dimension does not fit in `i32`, rather
    /// than silently truncating (the previous `as i32` cast would have
    /// mapped an out-of-range dim to an arbitrary valid-looking value).
    fn forward(&self, input: &Variable) -> Variable {
        let dim = i32::try_from(self.dim).expect("Softmax dim out of i32 range");
        input.softmax(dim)
    }

    fn name(&self) -> &'static str {
        "Softmax"
    }
}
/// Log-softmax activation: the logarithm of softmax along one dimension.
/// Numerically preferable to `softmax(..).log()` when feeding NLL-style losses.
#[derive(Debug, Clone, Copy)]
pub struct LogSoftmax {
    /// Dimension to normalize over; negative values index from the end
    /// (e.g. `-1` is the last dimension).
    dim: i64,
}

impl LogSoftmax {
    /// Creates a `LogSoftmax` normalizing over dimension `dim`.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for LogSoftmax {
    /// Defaults to normalizing over the last dimension (`-1`).
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for LogSoftmax {
    /// Applies log-softmax along the configured dimension.
    ///
    /// # Panics
    /// Panics if the stored `i64` dimension does not fit in `i32`, rather
    /// than silently truncating (the previous `as i32` cast would have
    /// mapped an out-of-range dim to an arbitrary valid-looking value).
    fn forward(&self, input: &Variable) -> Variable {
        let dim = i32::try_from(self.dim).expect("LogSoftmax dim out of i32 range");
        input.log_softmax(dim)
    }

    fn name(&self) -> &'static str {
        "LogSoftmax"
    }
}
/// Gaussian Error Linear Unit activation, applied element-wise.
///
/// Stateless; `forward` delegates to [`Variable::gelu`].
#[derive(Debug, Clone, Copy, Default)]
pub struct GELU;

impl GELU {
    /// Constructs a new `GELU` module.
    pub fn new() -> Self {
        GELU
    }
}

impl Module for GELU {
    /// Applies GELU element-wise to `input`.
    fn forward(&self, input: &Variable) -> Variable {
        input.gelu()
    }

    fn name(&self) -> &'static str {
        "GELU"
    }
}
/// SiLU (a.k.a. swish) activation: `x * sigmoid(x)` element-wise.
///
/// Composed here from `sigmoid` and `mul_var` rather than a dedicated kernel.
#[derive(Debug, Clone, Copy, Default)]
pub struct SiLU;

impl SiLU {
    /// Constructs a new `SiLU` module.
    pub fn new() -> Self {
        SiLU
    }
}

impl Module for SiLU {
    /// Computes `input * sigmoid(input)` element-wise.
    fn forward(&self, input: &Variable) -> Variable {
        // x * sigma(x): the gate is the sigmoid of the input itself.
        let gate = input.sigmoid();
        input.mul_var(&gate)
    }

    fn name(&self) -> &'static str {
        "SiLU"
    }
}
/// Exponential Linear Unit activation: identity for positive inputs,
/// `alpha * (exp(x) - 1)` for negative inputs.
#[derive(Debug, Clone, Copy)]
pub struct ELU {
    /// Scale for the negative branch (1.0 by default).
    alpha: f32,
}

impl ELU {
    /// Creates an `ELU` with the conventional `alpha = 1.0`.
    pub fn new() -> Self {
        Self::with_alpha(1.0)
    }

    /// Creates an `ELU` with a caller-chosen `alpha`.
    pub fn with_alpha(alpha: f32) -> Self {
        Self { alpha }
    }
}

impl Default for ELU {
    fn default() -> Self {
        ELU::new()
    }
}

impl Module for ELU {
    /// Applies ELU element-wise, using the stored `alpha`.
    fn forward(&self, input: &Variable) -> Variable {
        input.elu(self.alpha)
    }

    fn name(&self) -> &'static str {
        "ELU"
    }
}
/// Identity module: passes its input through unchanged.
///
/// Useful as a placeholder where a `Module` is required but no
/// transformation is wanted.
#[derive(Debug, Clone, Copy, Default)]
pub struct Identity;

impl Identity {
    /// Constructs a new `Identity` module.
    pub fn new() -> Self {
        Identity
    }
}

impl Module for Identity {
    /// Returns a clone of `input` (no transformation).
    fn forward(&self, input: &Variable) -> Variable {
        input.clone()
    }

    fn name(&self) -> &'static str {
        "Identity"
    }
}
/// Flattens all dimensions of the input from `start_dim` onward into one.
#[derive(Debug, Clone, Copy)]
pub struct Flatten {
    /// First dimension to flatten (dimensions before it are kept).
    start_dim: usize,
}

impl Flatten {
    /// Creates a `Flatten` starting at dimension 1, preserving the batch
    /// dimension (dim 0).
    pub fn new() -> Self {
        Self::from(1)
    }

    /// Creates a `Flatten` starting at `start_dim`.
    // NOTE(review): an inherent `from` shadows the `From` trait convention
    // (clippy `should_implement_trait`); kept as-is for caller compatibility.
    pub fn from(start_dim: usize) -> Self {
        Self { start_dim }
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Flatten::new()
    }
}

impl Module for Flatten {
    /// Flattens `input` from the configured `start_dim`.
    fn forward(&self, input: &Variable) -> Variable {
        input.flatten(self.start_dim)
    }

    /// Flatten holds no trainable parameters.
    // NOTE(review): presumably this matches the trait's default — confirm
    // against `Module` before removing the override.
    fn parameters(&self) -> Vec<crate::Parameter> {
        Vec::new()
    }

    fn name(&self) -> &'static str {
        "Flatten"
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axonml_tensor::Tensor;
// ReLU zeroes negatives and passes non-negatives through unchanged.
#[test]
fn test_relu() {
let relu = ReLU::new();
let input = Variable::new(
Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
false,
);
let output = relu.forward(&input);
assert_eq!(output.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
}
// sigmoid(0) = 0.5; checked with a float tolerance.
#[test]
fn test_sigmoid() {
let sigmoid = Sigmoid::new();
let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
let output = sigmoid.forward(&input);
assert!((output.data().to_vec()[0] - 0.5).abs() < 1e-6);
}
// Softmax over the last dim must produce values summing to 1.
#[test]
fn test_softmax() {
let softmax = Softmax::new(-1);
let input = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
false,
);
let output = softmax.forward(&input);
let sum: f32 = output.data().to_vec().iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
// With slope 0.1, negatives are scaled by 0.1; non-negatives unchanged.
#[test]
fn test_leaky_relu() {
let leaky = LeakyReLU::with_slope(0.1);
let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
let output = leaky.forward(&input);
assert_eq!(output.data().to_vec(), vec![-0.1, 0.0, 1.0]);
}
// Identity must return the input values untouched.
#[test]
fn test_identity() {
let id = Identity::new();
let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
let output = id.forward(&input);
assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
}
}