ghostflow_core/ops/activation.rs

//! Activation functions

use crate::tensor::Tensor;
use rayon::prelude::*;

impl Tensor {
    /// ReLU activation: max(0, x)
    pub fn relu(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::relu_simd;
            let data = self.data_f32();
            let result = relu_simd(&data);
            return Tensor::from_slice(&result, self.dims()).unwrap();
        }

        #[cfg(not(feature = "simd"))]
        {
            let data: Vec<f32> = self.data_f32()
                .par_iter()
                .map(|&x| x.max(0.0))
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// Leaky ReLU: max(alpha * x, x)
    pub fn leaky_relu(&self, alpha: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| if x > 0.0 { x } else { alpha * x })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// ELU activation: x if x > 0, else alpha * (exp(x) - 1)
    pub fn elu(&self, alpha: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// SELU activation (self-normalizing)
    pub fn selu(&self) -> Tensor {
        const ALPHA: f32 = 1.6732632423543772;
        const SCALE: f32 = 1.0507009873554805;

        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| {
                SCALE * if x > 0.0 { x } else { ALPHA * (x.exp() - 1.0) }
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Sigmoid activation: 1 / (1 + exp(-x))
    pub fn sigmoid(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::sigmoid_simd;
            let data = self.data_f32();
            let result = sigmoid_simd(&data);
            return Tensor::from_slice(&result, self.dims()).unwrap();
        }

        #[cfg(not(feature = "simd"))]
        {
            let data: Vec<f32> = self.data_f32()
                .par_iter()
                .map(|&x| 1.0 / (1.0 + (-x).exp()))
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// Tanh activation
    pub fn tanh(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| x.tanh())
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// GELU activation (Gaussian Error Linear Unit)
    pub fn gelu(&self) -> Tensor {
        #[cfg(feature = "simd")]
        {
            use crate::ops::simd::gelu_simd;
            let data = self.data_f32();
            let result = gelu_simd(&data);
            return Tensor::from_slice(&result, self.dims()).unwrap();
        }

        #[cfg(not(feature = "simd"))]
        {
            const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
            const COEFF: f32 = 0.044715;

            let data: Vec<f32> = self.data_f32()
                .par_iter()
                .map(|&x| {
                    // Approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                    let inner = SQRT_2_OVER_PI * (x + COEFF * x.powi(3));
                    0.5 * x * (1.0 + inner.tanh())
                })
                .collect();
            Tensor::from_slice(&data, self.dims()).unwrap()
        }
    }

    /// SiLU / Swish activation: x * sigmoid(x)
    pub fn silu(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| x / (1.0 + (-x).exp()))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Mish activation: x * tanh(softplus(x))
    pub fn mish(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| {
                let softplus = (1.0 + x.exp()).ln();
                x * softplus.tanh()
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softplus: log(1 + exp(x))
    pub fn softplus(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| {
                // Numerically stable version
                if x > 20.0 {
                    x
                } else if x < -20.0 {
                    x.exp()
                } else {
                    (1.0 + x.exp()).ln()
                }
            })
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softsign: x / (1 + |x|)
    pub fn softsign(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| x / (1.0 + x.abs()))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Softmax along the given dimension (negative `dim` counts from the end)
    pub fn softmax(&self, dim: i32) -> Tensor {
        let dims = self.dims();
        let ndim = dims.len();
        let dim = if dim < 0 { (ndim as i32 + dim) as usize } else { dim as usize };

        let data = self.data_f32();
        let dim_size = dims[dim];

        // Compute stride for the softmax dimension
        let inner_size: usize = dims[dim + 1..].iter().product();
        let _outer_size: usize = dims[..dim].iter().product();

        let mut result = vec![0.0f32; data.len()];

        // Parallelize over the outer dimension for better performance
        result.par_chunks_mut(dim_size * inner_size)
            .enumerate()
            .for_each(|(outer, outer_chunk)| {
                for inner in 0..inner_size {
                    // Find max for numerical stability
                    let mut max_val = f32::NEG_INFINITY;
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        let val = data[outer * dim_size * inner_size + idx];
                        max_val = max_val.max(val);
                    }

                    // Compute exp and sum
                    let mut sum = 0.0f32;
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        let data_idx = outer * dim_size * inner_size + idx;
                        let exp_val = (data[data_idx] - max_val).exp();
                        outer_chunk[idx] = exp_val;
                        sum += exp_val;
                    }

                    // Normalize
                    for d in 0..dim_size {
                        let idx = d * inner_size + inner;
                        outer_chunk[idx] /= sum;
                    }
                }
            });

        Tensor::from_slice(&result, dims).unwrap()
    }

    /// Log softmax along the given dimension, computed as log(softmax(x))
    pub fn log_softmax(&self, dim: i32) -> Tensor {
        let softmax = self.softmax(dim);
        softmax.log()
    }

    /// Hardtanh: clamp(x, min, max)
    pub fn hardtanh(&self, min_val: f32, max_val: f32) -> Tensor {
        self.clamp(min_val, max_val)
    }

    /// Hard sigmoid: clamp((x + 3) / 6, 0, 1)
    pub fn hardsigmoid(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| ((x + 3.0) / 6.0).clamp(0.0, 1.0))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Hard swish: x * hardsigmoid(x)
    pub fn hardswish(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32()
            .par_iter()
            .map(|&x| x * ((x + 3.0) / 6.0).clamp(0.0, 1.0))
            .collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu() {
        let t = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.data_f32(), vec![0.0, 0.0, 1.0, 2.0]);
    }
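
    // Additional coverage sketch for the pointwise activations above. It reuses
    // the same Tensor::from_slice / data_f32 round-trip as test_relu; expected
    // values are hand-computed from the formulas in the doc comments, so treat
    // the tolerances as illustrative rather than definitive.
    #[test]
    fn test_leaky_relu_elu_selu() {
        let t = Tensor::from_slice(&[-2.0f32, 3.0], &[2]).unwrap();

        // leaky_relu: negative inputs scaled by alpha, positive passed through
        let lr = t.leaky_relu(0.1);
        assert!((lr.data_f32()[0] + 0.2).abs() < 1e-5);
        assert!((lr.data_f32()[1] - 3.0).abs() < 1e-5);

        // elu with alpha = 1: elu(-2) = exp(-2) - 1 ≈ -0.8647
        let e = t.elu(1.0);
        assert!((e.data_f32()[0] + 0.8647).abs() < 1e-3);

        // selu(1) = SCALE * 1 ≈ 1.0507
        let s = Tensor::from_slice(&[1.0f32], &[1]).unwrap().selu();
        assert!((s.data_f32()[0] - 1.0507).abs() < 1e-3);
    }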

    #[test]
    fn test_sigmoid() {
        let t = Tensor::from_slice(&[0.0f32], &[1]).unwrap();
        let s = t.sigmoid();
        assert!((s.data_f32()[0] - 0.5).abs() < 0.001);
    }
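
    // A hedged sketch checking SiLU, Mish, and both saturation branches of the
    // numerically stable softplus; the ±20 thresholds mirror the implementation
    // above, and the expected numbers are hand-computed.
    #[test]
    fn test_silu_mish_softplus() {
        // silu(0) = 0, silu(1) = 1 * sigmoid(1) ≈ 0.7311
        let t = Tensor::from_slice(&[0.0f32, 1.0], &[2]).unwrap();
        let s = t.silu();
        assert!(s.data_f32()[0].abs() < 1e-6);
        assert!((s.data_f32()[1] - 0.7311).abs() < 1e-3);

        // softplus: large positive inputs pass through, large negative ones go to ~0
        let big = Tensor::from_slice(&[30.0f32, -30.0], &[2]).unwrap().softplus();
        assert!((big.data_f32()[0] - 30.0).abs() < 1e-4);
        assert!(big.data_f32()[1] < 1e-6);

        // mish(0) = 0 * tanh(softplus(0)) = 0
        let m = Tensor::from_slice(&[0.0f32], &[1]).unwrap().mish();
        assert!(m.data_f32()[0].abs() < 1e-6);
    }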

    #[test]
    fn test_softmax() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let s = t.softmax(-1);
        let sum: f32 = s.data_f32().iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }
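
    // Sketches for softmax over a 2-D shape and for log_softmax. They assume
    // from_slice stores data contiguously in the order given (as the 1-D tests
    // above do) and that Tensor::log is the elementwise natural log used by
    // log_softmax; only row sums and exp(log_softmax) are checked.
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let s = t.softmax(-1);
        let data = s.data_f32();
        for row in data.chunks(3) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_log_softmax_exponentiates_to_one() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let ls = t.log_softmax(-1);
        let sum: f32 = ls.data_f32().iter().map(|v| v.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-4);
        // All log-probabilities are non-positive
        assert!(ls.data_f32().iter().all(|&v| v <= 0.0));
    }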

    #[test]
    fn test_gelu() {
        let t = Tensor::from_slice(&[0.0f32, 1.0, -1.0], &[3]).unwrap();
        let g = t.gelu();
        // GELU(0) ≈ 0
        assert!(g.data_f32()[0].abs() < 0.001);
    }
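
    // A minimal sketch for the "hard" variants and softsign, assuming the
    // clamp-based formulas in the doc comments above; the chosen inputs make
    // the expected outputs exact in f32.
    #[test]
    fn test_hard_variants_and_softsign() {
        let t = Tensor::from_slice(&[-4.0f32, 0.0, 4.0], &[3]).unwrap();

        // hardsigmoid saturates at 0 and 1, and is 0.5 at the origin
        let hs = t.hardsigmoid();
        assert_eq!(hs.data_f32(), vec![0.0, 0.5, 1.0]);

        // hardswish: x * hardsigmoid(x)
        let hw = t.hardswish();
        assert_eq!(hw.data_f32(), vec![0.0, 0.0, 4.0]);

        // softsign(1) = 1 / (1 + 1) = 0.5
        let ss = Tensor::from_slice(&[1.0f32], &[1]).unwrap().softsign();
        assert!((ss.data_f32()[0] - 0.5).abs() < 1e-6);
    }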
}