ghostflow_nn/activation.rs

//! Activation function modules

use ghostflow_core::Tensor;
use crate::module::Module;

/// ReLU activation module
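/// f(x) = max(0, x)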
pub struct ReLU;

impl ReLU {
    pub fn new() -> Self { ReLU }
}

impl Default for ReLU {
    fn default() -> Self { Self::new() }
}

impl Module for ReLU {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.relu()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Leaky ReLU activation module
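/// f(x) = x for x > 0, alpha * x otherwise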
pub struct LeakyReLU {
    alpha: f32,
}

impl LeakyReLU {
    pub fn new(alpha: f32) -> Self {
        LeakyReLU { alpha }
    }
}

impl Default for LeakyReLU {
    fn default() -> Self { Self::new(0.01) }
}

impl Module for LeakyReLU {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.leaky_relu(self.alpha)
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// GELU activation module
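/// f(x) = x * Phi(x), where Phi is the standard normal CDF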
pub struct GELU;

impl GELU {
    pub fn new() -> Self { GELU }
}

impl Default for GELU {
    fn default() -> Self { Self::new() }
}

impl Module for GELU {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.gelu()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Sigmoid activation module
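/// f(x) = 1 / (1 + exp(-x))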
pub struct Sigmoid;

impl Sigmoid {
    pub fn new() -> Self { Sigmoid }
}

impl Default for Sigmoid {
    fn default() -> Self { Self::new() }
}

impl Module for Sigmoid {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.sigmoid()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Tanh activation module
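/// f(x) = tanh(x)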
pub struct Tanh;

impl Tanh {
    pub fn new() -> Self { Tanh }
}

impl Default for Tanh {
    fn default() -> Self { Self::new() }
}

impl Module for Tanh {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.tanh()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// SiLU/Swish activation module
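/// f(x) = x * sigmoid(x)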
pub struct SiLU;

impl SiLU {
    pub fn new() -> Self { SiLU }
}

impl Default for SiLU {
    fn default() -> Self { Self::new() }
}

impl Module for SiLU {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.silu()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Softmax activation module
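/// softmax(x)_i = exp(x_i) / sum_j exp(x_j), taken along `dim`;
/// the default `dim` of -1 follows the usual convention of the last dimension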
pub struct Softmax {
    dim: i32,
}

impl Softmax {
    pub fn new(dim: i32) -> Self {
        Softmax { dim }
    }
}

impl Default for Softmax {
    fn default() -> Self { Self::new(-1) }
}

impl Module for Softmax {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.softmax(self.dim)
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Swish activation module (parameterized version of SiLU)
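/// f(x) = x * sigmoid(beta * x); beta = 1.0 recovers SiLU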
pub struct Swish {
    beta: f32,
}

impl Swish {
    pub fn new(beta: f32) -> Self {
        Swish { beta }
    }
}

impl Default for Swish {
    fn default() -> Self { Self::new(1.0) }
}

impl Module for Swish {
    fn forward(&self, input: &Tensor) -> Tensor {
        let data = input.data_f32();
        let result: Vec<f32> = data.iter()
            .map(|&x| x / (1.0 + (-self.beta * x).exp()))
            .collect();
        Tensor::from_slice(&result, input.dims()).unwrap()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Mish activation module
/// f(x) = x * tanh(softplus(x))
pub struct Mish;

impl Mish {
    pub fn new() -> Self { Mish }
}

impl Default for Mish {
    fn default() -> Self { Self::new() }
}

impl Module for Mish {
    fn forward(&self, input: &Tensor) -> Tensor {
        let data = input.data_f32();
        let result: Vec<f32> = data.iter()
            .map(|&x| {
                let softplus = (1.0 + x.exp()).ln();
                x * softplus.tanh()
            })
            .collect();
        Tensor::from_slice(&result, input.dims()).unwrap()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// ELU (Exponential Linear Unit) activation module
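/// f(x) = x for x > 0, alpha * (exp(x) - 1) otherwise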
pub struct ELU {
    alpha: f32,
}

impl ELU {
    pub fn new(alpha: f32) -> Self {
        ELU { alpha }
    }
}

impl Default for ELU {
    fn default() -> Self { Self::new(1.0) }
}

impl Module for ELU {
    fn forward(&self, input: &Tensor) -> Tensor {
        let data = input.data_f32();
        let result: Vec<f32> = data.iter()
            .map(|&x| {
                if x > 0.0 {
                    x
                } else {
                    self.alpha * (x.exp() - 1.0)
                }
            })
            .collect();
        Tensor::from_slice(&result, input.dims()).unwrap()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// SELU (Scaled Exponential Linear Unit) activation module
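/// f(x) = scale * x for x > 0, scale * alpha * (exp(x) - 1) otherwise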
pub struct SELU {
    alpha: f32,
    scale: f32,
}

impl SELU {
    pub fn new() -> Self {
        // Standard SELU parameters
        SELU {
            alpha: 1.6732632423543772848170429916717,
            scale: 1.0507009873554804934193349852946,
        }
    }

    pub fn with_params(alpha: f32, scale: f32) -> Self {
        SELU { alpha, scale }
    }
}

impl Default for SELU {
    fn default() -> Self { Self::new() }
}

impl Module for SELU {
    fn forward(&self, input: &Tensor) -> Tensor {
        let data = input.data_f32();
        let result: Vec<f32> = data.iter()
            .map(|&x| {
                if x > 0.0 {
                    self.scale * x
                } else {
                    self.scale * self.alpha * (x.exp() - 1.0)
                }
            })
            .collect();
        Tensor::from_slice(&result, input.dims()).unwrap()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Softplus activation module
/// f(x) = (1/beta) * ln(1 + exp(beta * x)); once beta * x exceeds `threshold`,
/// the function falls back to the identity to avoid overflow
pub struct Softplus {
    beta: f32,
    threshold: f32,
}

impl Softplus {
    pub fn new(beta: f32, threshold: f32) -> Self {
        Softplus { beta, threshold }
    }
}

impl Default for Softplus {
    fn default() -> Self { Self::new(1.0, 20.0) }
}

impl Module for Softplus {
    fn forward(&self, input: &Tensor) -> Tensor {
        let data = input.data_f32();
        let result: Vec<f32> = data.iter()
            .map(|&x| {
                let beta_x = self.beta * x;
                if beta_x > self.threshold {
                    // For large values, use linear approximation to avoid overflow
                    x
                } else {
                    (1.0 + beta_x.exp()).ln() / self.beta
                }
            })
            .collect();
        Tensor::from_slice(&result, input.dims()).unwrap()
    }
    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}
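
// A minimal sanity-check sketch, not part of the original module. It assumes that
// `Tensor::from_slice(values, shape)` accepts a value slice plus a shape slice and that
// `data_f32()` returns the elements in order, mirroring how the forward implementations
// above already use those calls; adjust the shape argument if the real signature differs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::module::Module;
    use ghostflow_core::Tensor;

    #[test]
    fn relu_and_elu_agree_on_positive_inputs() {
        // Hypothetical 1-D tensor of three elements.
        let x = Tensor::from_slice(&[-1.0f32, 0.0, 2.0], &[3]).unwrap();

        // ReLU clamps negatives to zero; ELU is smooth below zero but is the identity above it.
        let relu_out = ReLU::new().forward(&x);
        let elu_out = ELU::new(1.0).forward(&x);

        assert_eq!(relu_out.data_f32()[2], 2.0);
        assert_eq!(elu_out.data_f32()[2], 2.0);
        assert_eq!(relu_out.data_f32()[0], 0.0);
        assert!(elu_out.data_f32()[0] < 0.0);
    }
}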