1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only
use zyx::Tensor;
/// Selects an activation function applied by [`Activation::forward`].
///
/// Payload-free variants are parameter-less activations; [`Elu`](Self::Elu)
/// and [`LeakyRelu`](Self::LeakyRelu) carry their scalar coefficient.
/// `Copy` is cheap here: the largest variant is a single `f64`.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum Activation {
    //#[serde(alias = "gelu")]
    /// Gaussian error linear unit (the default).
    #[default]
    Gelu,
    //#[serde(alias = "gelu_new")]
    //NewGelu,
    /// Rectified linear unit: `max(x, 0)`.
    Relu,
    /// Squared ReLU: `relu(x)^2`.
    Relu2,
    /// ReLU capped at 6: `clamp(x, 0, 6)`.
    Relu6,
    //Silu,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hard (piecewise-linear) sigmoid.
    HardSigmoid,
    //Swiglu,
    /// Swish activation.
    Swish,
    //HardSwish,
    /// Exponential linear unit; the payload is the `alpha` coefficient.
    Elu(f64),
    /// Leaky ReLU; the payload is the negative slope.
    LeakyRelu(f64),
    //#[serde(alias = "gelu_pytorch_tanh")]
    //GeluPytorchTanh,
}
impl Activation {
    /// Applies the selected activation function to `xs`.
    ///
    /// Accepts anything convertible into a [`Tensor`] and returns the
    /// activated tensor. For [`Elu`](Self::Elu) and
    /// [`LeakyRelu`](Self::LeakyRelu) the stored coefficient is forwarded
    /// to the corresponding tensor op.
    pub fn forward(&self, xs: impl Into<Tensor>) -> Tensor {
        let xs: Tensor = xs.into();
        match self {
            Self::Gelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => {
                // Squared ReLU; pow returns a Result in zyx — presumably
                // infallible for integer exponent 2, hence the unwrap.
                let rectified = xs.relu();
                rectified.pow(2).unwrap()
            }
            // clamp is fallible in zyx; bounds 0..=6 are assumed valid.
            Self::Relu6 => xs.clamp(0f32, 6f32).unwrap(),
            //Self::Silu => xs * xs.silu(),
            Self::Sigmoid => xs.sigmoid(),
            Self::HardSigmoid => xs.hard_sigmoid(),
            //Self::Swiglu => xs.swiglu(),
            Self::Swish => xs.swish(),
            //Self::HardSwish => xs * xs.hard_swish(),
            Self::Elu(alpha) => xs.elu(*alpha),
            Self::LeakyRelu(slope) => xs.leaky_relu(*slope),
            //Self::GeluPytorchTanh => xs.gelu(),
        }
    }
}