ghostflow_core/ops/activation.rs

//! Activation functions

use crate::tensor::Tensor;
#[cfg(feature = "rayon")]
use rayon::prelude::*;

impl Tensor {
    /// ReLU activation: max(0, x)
    pub fn relu(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::relu_simd;
            let data = self.data_f32();
            let result = relu_simd(&data);
            Tensor::from_slice(&result, self.dims()).unwrap()
        }

        #[cfg(not(feature = "simd"))]
        {
            let data: Vec<f32> = self.data_f32()
                .iter()
                .map(|&x| x.max(0.0))
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// Leaky ReLU: max(alpha * x, x)
    pub fn leaky_relu(&self, alpha: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * x })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// ELU activation: x if x > 0, else alpha * (exp(x) - 1)
    pub fn elu(&self, alpha: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// SELU activation (self-normalizing)
    pub fn selu(&self) -> Tensor {
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| {
                SCALE * if x > 0.0 { x } else { ALPHA * (x.exp() - 1.0) }
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Sigmoid activation: 1 / (1 + exp(-x))
    pub fn sigmoid(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::sigmoid_simd;
            let data = self.data_f32();
            let result = sigmoid_simd(&data);
            Tensor::from_slice(&result, self.dims()).unwrap()
        }

        #[cfg(not(feature = "simd"))]
        {
            let data: Vec<f32> = self.data_f32()
                .iter()
                .map(|&x| 1.0 / (1.0 + (-x).exp()))
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// Tanh activation
    pub fn tanh(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| x.tanh())
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// GELU activation (Gaussian Error Linear Unit)
    pub fn gelu(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::gelu_simd;
            let data = self.data_f32();
            let result = gelu_simd(&data);
            Tensor::from_slice(&result, self.dims()).unwrap()
        }

        #[cfg(not(feature = "simd"))]
        {
            const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
            const COEFF: f32 = 0.044715;

            let data: Vec<f32> = self.data_f32()
                .iter()
                .map(|&x| {
                    // Approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                    let inner = SQRT_2_OVER_PI * (x + COEFF * x.powi(3));
                    0.5 * x * (1.0 + inner.tanh())
                })
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// SiLU / Swish activation: x * sigmoid(x)
    pub fn silu(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| x / (1.0 + (-x).exp()))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Mish activation: x * tanh(softplus(x))
    pub fn mish(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| {
                let softplus = (1.0 + x.exp()).ln();
                x * softplus.tanh()
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softplus: log(1 + exp(x))
    pub fn softplus(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| {
                // Numerically stable version
                if x > 20.0 {
                    x
                } else if x < -20.0 {
                    x.exp()
                } else {
                    (1.0 + x.exp()).ln()
                }
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softsign: x / (1 + |x|)
    pub fn softsign(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| x / (1.0 + x.abs()))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softmax along dimension `dim` (negative values count from the end)
    pub fn softmax(&self, dim: i32) -> Tensor {
        let dims = self.dims();
        let ndim = dims.len();
        let dim = if dim < 0 { (ndim as i32 + dim) as usize } else { dim as usize };

        let data = self.data_f32();
        let dim_size = dims[dim];

        // Stride of the softmax dimension in the contiguous layout (product of trailing dims)
        let inner_size: usize = dims[dim + 1..].iter().product();

        let mut result = vec![0.0f32; data.len()];

        // Process each outer chunk
        #[cfg(feature = "rayon")]
        let chunks_iter = result.par_chunks_mut(dim_size * inner_size);
        #[cfg(not(feature = "rayon"))]
        let chunks_iter = result.chunks_mut(dim_size * inner_size);

        chunks_iter
            .enumerate()
            .for_each(|(outer, outer_chunk)| {
                for inner in 0..inner_size {
                    // Find max for numerical stability
                    let mut max_val = f32::NEG_INFINITY;
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        let val = data[outer * dim_size * inner_size + idx];
                        max_val = max_val.max(val);
                    }

                    // Compute exp and sum
                    let mut sum = 0.0f32;
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        let data_idx = outer * dim_size * inner_size + idx;
                        let exp_val = (data[data_idx] - max_val).exp();
                        outer_chunk[idx] = exp_val;
                        sum += exp_val;
                    }

                    // Normalize
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        outer_chunk[idx] /= sum;
                    }
                }
            });

        Tensor::from_slice(&result, dims).unwrap()
    }

    /// Log softmax, computed as log(softmax(x)) along `dim`
    pub fn log_softmax(&self, dim: i32) -> Tensor {
        let softmax = self.softmax(dim);
        softmax.log()
    }

    /// Hardtanh: clamp(x, min, max)
    pub fn hardtanh(&self, min_val: f32, max_val: f32) -> Tensor {
        self.clamp(min_val, max_val)
    }

    /// Hard sigmoid: clamp((x + 3) / 6, 0, 1)
    pub fn hardsigmoid(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| ((x + 3.0) / 6.0).clamp(0.0, 1.0))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Hard swish: x * hardsigmoid(x)
    pub fn hardswish(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .iter()
            .map(|&x| x * ((x + 3.0) / 6.0).clamp(0.0, 1.0))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu() {
        let t = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.data_f32(), vec![0.0, 0.0, 1.0, 2.0]);
    }
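
    // Minimal check of the negative-input branches of leaky_relu and elu;
    // expected values follow directly from the formulas in their doc comments,
    // and the tolerance is arbitrary.
    #[test]
    fn test_leaky_relu_and_elu_negative_branch() {
        let t = Tensor::from_slice(&[-2.0f32, 0.0, 2.0], &[3]).unwrap();

        let lr = t.leaky_relu(0.1);
        let lr_data = lr.data_f32();
        assert!((lr_data[0] + 0.2).abs() < 1e-6); // 0.1 * -2.0
        assert_eq!(lr_data[2], 2.0);              // positive inputs pass through

        let e = t.elu(1.0);
        let e_data = e.data_f32();
        assert!((e_data[0] - ((-2.0f32).exp() - 1.0)).abs() < 1e-6); // alpha * (exp(x) - 1)
        assert_eq!(e_data[2], 2.0);
    }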

    #[test]
    fn test_sigmoid() {
        let t = Tensor::from_slice(&[0.0f32], &[1]).unwrap();
        let s = t.sigmoid();
        assert!((s.data_f32()[0] - 0.5).abs() < 0.001);
    }
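
    // Hedged sanity checks for silu and softplus. The silu check compares
    // against the tensor sigmoid, so it assumes the (possibly SIMD) sigmoid
    // stays within the ~1e-3 tolerance the existing tests already use.
    #[test]
    fn test_silu_matches_x_times_sigmoid() {
        let t = Tensor::from_slice(&[-1.0f32, 0.5, 3.0], &[3]).unwrap();
        let silu = t.silu();
        let sig = t.sigmoid();
        let x = t.data_f32();
        let s = silu.data_f32();
        let g = sig.data_f32();
        for i in 0..3 {
            assert!((s[i] - x[i] * g[i]).abs() < 1e-3);
        }
    }

    #[test]
    fn test_softplus_stable_branches() {
        // Exercises the three branches of the stable softplus:
        // ~exp(x) for x < -20, ln(2) at 0, and ~x for x > 20.
        let t = Tensor::from_slice(&[-30.0f32, 0.0, 30.0], &[3]).unwrap();
        let sp = t.softplus();
        let d = sp.data_f32();
        assert!((d[0] - (-30.0f32).exp()).abs() < 1e-6);
        assert!((d[1] - 2.0f32.ln()).abs() < 1e-3);
        assert!((d[2] - 30.0).abs() < 1e-3);
    }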

    #[test]
    fn test_softmax() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let s = t.softmax(-1);
        let sum: f32 = s.data_f32().iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }
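
    // A 2-D sketch of softmax's dim handling, assuming the same contiguous
    // row-major layout the softmax implementation itself relies on: every
    // slice along the chosen dimension should sum to 1, and
    // exp(log_softmax(x)) should agree with softmax(x).
    #[test]
    fn test_softmax_2d_dims() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        // dim = 1 (last): each of the 2 rows sums to 1
        let s1 = t.softmax(1);
        let d1 = s1.data_f32();
        let row0: f32 = d1[0..3].iter().sum();
        let row1: f32 = d1[3..6].iter().sum();
        assert!((row0 - 1.0).abs() < 1e-3);
        assert!((row1 - 1.0).abs() < 1e-3);

        // dim = 0: each of the 3 columns sums to 1
        let s0 = t.softmax(0);
        let d0 = s0.data_f32();
        for c in 0..3 {
            assert!((d0[c] + d0[c + 3] - 1.0).abs() < 1e-3);
        }
    }

    #[test]
    fn test_log_softmax_matches_softmax() {
        let t = Tensor::from_slice(&[0.5f32, 1.5, -0.5], &[3]).unwrap();
        let s = t.softmax(-1);
        let ls = t.log_softmax(-1);
        let sd = s.data_f32();
        let ld = ls.data_f32();
        for i in 0..3 {
            assert!((ld[i].exp() - sd[i]).abs() < 1e-3);
        }
    }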

    #[test]
    fn test_gelu() {
        let t = Tensor::from_slice(&[0.0f32, 1.0, -1.0], &[3]).unwrap();
        let g = t.gelu();
        // GELU(0) ≈ 0
        assert!(g.data_f32()[0].abs() < 0.001);
    }
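
    // Spot checks for the piecewise activations above: hardsigmoid saturates
    // at x <= -3 and x >= 3, hardswish is x * hardsigmoid(x), and selu(0) = 0.
    // Expected values follow the formulas in the doc comments.
    #[test]
    fn test_hard_activations_and_selu_zero() {
        let t = Tensor::from_slice(&[-4.0f32, 0.0, 3.0, 6.0], &[4]).unwrap();

        let hs = t.hardsigmoid();
        let hs_d = hs.data_f32();
        assert_eq!(hs_d[0], 0.0);              // clamped below
        assert!((hs_d[1] - 0.5).abs() < 1e-6); // (0 + 3) / 6
        assert_eq!(hs_d[2], 1.0);              // clamped above
        assert_eq!(hs_d[3], 1.0);

        let hw = t.hardswish();
        let hw_d = hw.data_f32();
        assert_eq!(hw_d[0], 0.0); // x * 0
        assert_eq!(hw_d[2], 3.0); // x * 1
        assert_eq!(hw_d[3], 6.0);

        let z = Tensor::from_slice(&[0.0f32], &[1]).unwrap();
        let s = z.selu();
        assert!(s.data_f32()[0].abs() < 1e-6);
    }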
}