use crate::autograd::Variable;
use crate::tensor::{DType, Device, Result, Tensor, TensorOptions};
use super::{Module, Parameter};
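
/// Passes its input through unchanged. Handy as a placeholder wherever a
/// `Module` is required but no transformation is wanted.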
#[derive(Default)]
pub struct Identity;

impl Identity {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Identity {
    fn name(&self) -> &str { "identity" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        Ok(input.clone())
    }
}
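
/// Rectified linear unit, `max(0, x)`, applied element-wise.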
#[derive(Default)]
pub struct ReLU;

impl ReLU {
    pub fn new() -> Self {
        Self
    }
}

impl Module for ReLU {
    fn name(&self) -> &str { "relu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.relu()
    }
}
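
/// Logistic sigmoid, `1 / (1 + exp(-x))`, applied element-wise.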
#[derive(Default)]
pub struct Sigmoid;

impl Sigmoid {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Sigmoid {
    fn name(&self) -> &str { "sigmoid" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.sigmoid()
    }
}
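
/// Hyperbolic tangent applied element-wise.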
#[derive(Default)]
pub struct Tanh;

impl Tanh {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Tanh {
    fn name(&self) -> &str { "tanh" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.tanh()
    }
}
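
/// Gaussian error linear unit, `x * Φ(x)`, where `Φ` is the standard
/// normal CDF.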
#[derive(Default)]
pub struct GELU;

impl GELU {
    pub fn new() -> Self {
        Self
    }
}

impl Module for GELU {
    fn name(&self) -> &str { "gelu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.gelu()
    }
}
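
/// Sigmoid linear unit (a.k.a. swish), `x * sigmoid(x)`, applied element-wise.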
#[derive(Default)]
pub struct SiLU;

impl SiLU {
    pub fn new() -> Self {
        Self
    }
}

impl Module for SiLU {
    fn name(&self) -> &str { "silu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.silu()
    }
}
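
/// Like ReLU, but negative inputs are scaled by `negative_slope`
/// (0.01 by default) instead of being zeroed.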
pub struct LeakyReLU {
    negative_slope: f64,
}

impl Default for LeakyReLU {
    fn default() -> Self {
        Self { negative_slope: 0.01 }
    }
}

impl LeakyReLU {
    pub fn new(negative_slope: f64) -> Self {
        Self { negative_slope }
    }
}

impl Module for LeakyReLU {
    fn name(&self) -> &str { "leaky_relu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.leaky_relu(self.negative_slope)
    }
}
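
/// Exponential linear unit: `x` for `x > 0`, otherwise `alpha * (exp(x) - 1)`
/// (`alpha = 1.0` by default).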
pub struct ELU {
    alpha: f64,
}

impl Default for ELU {
    fn default() -> Self {
        Self { alpha: 1.0 }
    }
}

impl ELU {
    pub fn new(alpha: f64) -> Self {
        Self { alpha }
    }
}

impl Module for ELU {
    fn name(&self) -> &str { "elu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.elu(self.alpha)
    }
}
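
/// Smooth approximation of ReLU, `ln(1 + exp(beta * x)) / beta`. Following the
/// usual convention, inputs where `beta * x` exceeds `threshold` are expected
/// to revert to the identity for numerical stability.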
pub struct Softplus {
    beta: f64,
    threshold: f64,
}

impl Default for Softplus {
    fn default() -> Self {
        Self { beta: 1.0, threshold: 20.0 }
    }
}

impl Softplus {
    pub fn new(beta: f64, threshold: f64) -> Self {
        Self { beta, threshold }
    }
}

impl Module for Softplus {
    fn name(&self) -> &str { "softplus" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.softplus(self.beta, self.threshold)
    }
}
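
/// Mish activation, `x * tanh(softplus(x))`, applied element-wise.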
#[derive(Default)]
pub struct Mish;

impl Mish {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Mish {
    fn name(&self) -> &str { "mish" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.mish()
    }
}
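
/// Rescales values along `dim` into a probability distribution:
/// `exp(x_i) / Σ_j exp(x_j)`.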
pub struct Softmax {
    dim: i32,
}

impl Softmax {
    pub fn new(dim: i32) -> Self {
        Self { dim }
    }
}

impl Module for Softmax {
    fn name(&self) -> &str { "softmax" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.softmax(self.dim)
    }
}
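
/// Logarithm of the softmax along `dim`; more numerically stable than
/// composing `Softmax` with a log.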
pub struct LogSoftmax {
    dim: i32,
}

impl LogSoftmax {
    pub fn new(dim: i32) -> Self {
        Self { dim }
    }
}

impl Module for LogSoftmax {
    fn name(&self) -> &str { "log_softmax" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.log_softmax(self.dim)
    }
}
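
/// Collapses the dimensions from `start_dim` through `end_dim` (inclusive)
/// into a single dimension. The default (`start_dim = 1`, `end_dim = -1`)
/// keeps the batch dimension and flattens everything else.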
pub struct Flatten {
    start_dim: i32,
    end_dim: i32,
}

impl Default for Flatten {
    fn default() -> Self {
        Self { start_dim: 1, end_dim: -1 }
    }
}

impl Flatten {
    pub fn new(start_dim: i32, end_dim: i32) -> Self {
        Self { start_dim, end_dim }
    }
}

impl Module for Flatten {
    fn name(&self) -> &str { "flatten" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.flatten(self.start_dim, self.end_dim)
    }
}
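
/// Scaled exponential linear unit with the fixed self-normalizing constants
/// `scale ≈ 1.0507` and `alpha ≈ 1.6733`.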
#[derive(Default)]
pub struct SELU;

impl SELU {
    pub fn new() -> Self {
        Self
    }
}

impl Module for SELU {
    fn name(&self) -> &str { "selu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.selu()
    }
}
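
/// Piecewise-linear approximation of SiLU: `x * relu6(x + 3) / 6`.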
#[derive(Default)]
pub struct Hardswish;

impl Hardswish {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Hardswish {
    fn name(&self) -> &str { "hardswish" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.hardswish()
    }
}
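
/// Piecewise-linear approximation of the sigmoid:
/// `clamp(x / 6 + 1 / 2, 0, 1)`.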
#[derive(Default)]
pub struct Hardsigmoid;

impl Hardsigmoid {
    pub fn new() -> Self {
        Self
    }
}

impl Module for Hardsigmoid {
    fn name(&self) -> &str { "hardsigmoid" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.hardsigmoid()
    }
}
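
/// Parametric ReLU with a learnable negative slope (`num_parameters` of them,
/// one per channel), each initialized to 0.25.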
pub struct PReLU {
    weight: Parameter,
}

impl PReLU {
    pub fn new(num_parameters: i64, device: Device) -> Result<Self> {
        // PyTorch-style initialization: every slope starts at 0.25.
        let init = Tensor::full(
            &[num_parameters],
            0.25,
            TensorOptions {
                dtype: DType::Float32,
                device,
            },
        )?;
        Ok(Self {
            weight: Parameter::new(init, "weight"),
        })
    }

    /// Alias for [`PReLU::new`], kept for naming symmetry with other constructors.
    pub fn on_device(num_parameters: i64, device: Device) -> Result<Self> {
        Self::new(num_parameters, device)
    }
}

impl Module for PReLU {
    fn name(&self) -> &str { "prelu" }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        input.prelu(&self.weight.variable)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone()]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Tensor, test_device};

    #[test]
    fn test_leaky_relu_module() {
        let m = LeakyReLU::new(0.2);
        let t = Tensor::from_f32(&[-1.0, 0.0, 1.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[0] - (-0.2)).abs() < 1e-5);
        assert!((y[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_leaky_relu_default() {
        let m = LeakyReLU::default();
        let t = Tensor::from_f32(&[-1.0], &[1], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[0] - (-0.01)).abs() < 1e-5);
    }

    #[test]
    fn test_elu_module() {
        let m = ELU::default();
        let t = Tensor::from_f32(&[-1.0, 0.0, 1.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!(y[0] < 0.0);
        assert!((y[1] - 0.0).abs() < 1e-5);
        assert!((y[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_softplus_module() {
        let m = Softplus::default();
        let t = Tensor::from_f32(&[0.0], &[1], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        // softplus(0) = ln(2)
        assert!((y[0] - std::f32::consts::LN_2).abs() < 1e-3);
    }

    #[test]
    fn test_mish_module() {
        let m = Mish::new();
        let t = Tensor::from_f32(&[0.0, 1.0], &[2], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[0] - 0.0).abs() < 1e-5);
        // mish(1) = tanh(ln(1 + e)) ≈ 0.8651
        assert!((y[1] - 0.8651).abs() < 1e-3);
    }

    #[test]
    fn test_softmax_module() {
        let m = Softmax::new(-1);
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        let sum: f32 = y.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_log_softmax_module() {
        let m = LogSoftmax::new(-1);
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        // log-probabilities of a non-degenerate distribution are strictly negative
        assert!(y.iter().all(|&v| v < 0.0));
    }

    #[test]
    fn test_flatten_module() {
        let m = Flatten::default();
        let t = Tensor::from_f32(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[2, 2, 2],
            test_device(),
        )
        .unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap();
        // default Flatten keeps the batch dimension: [2, 2, 2] -> [2, 4]
        assert_eq!(y.data().shape(), vec![2, 4]);
    }

    #[test]
    fn test_flatten_all() {
        let m = Flatten::new(0, -1);
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap();
        assert_eq!(y.data().shape(), vec![4]);
    }

    #[test]
    fn test_selu_module() {
        let m = SELU::new();
        let t = Tensor::from_f32(&[-1.0, 0.0, 1.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[1] - 0.0).abs() < 1e-5);
        // selu(1) = scale ≈ 1.0507
        assert!((y[2] - 1.0507).abs() < 1e-3);
        assert!(y[0] < 0.0);
    }

    #[test]
    fn test_hardswish_module() {
        let m = Hardswish::new();
        let t = Tensor::from_f32(&[-4.0, 0.0, 4.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[0] - 0.0).abs() < 1e-5);
        assert!((y[1] - 0.0).abs() < 1e-5);
        assert!((y[2] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_hardsigmoid_module() {
        let m = Hardsigmoid::new();
        let t = Tensor::from_f32(&[-4.0, 0.0, 4.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        assert!((y[0] - 0.0).abs() < 1e-5);
        assert!((y[1] - 0.5).abs() < 1e-5);
        assert!((y[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_prelu_module() {
        let m = PReLU::new(1, test_device()).unwrap();
        let t = Tensor::from_f32(&[-2.0, 0.0, 1.0], &[3], test_device()).unwrap();
        let x = Variable::new(t, false);
        let y = m.forward(&x).unwrap().data().to_f32_vec().unwrap();
        // initial slope 0.25: prelu(-2) = -0.5
        assert!((y[0] - (-0.5)).abs() < 1e-5);
        assert!((y[1] - 0.0).abs() < 1e-5);
        assert!((y[2] - 1.0).abs() < 1e-5);
        assert_eq!(m.parameters().len(), 1);
    }
}