
trueno/vector/ops/activations/advanced/smooth.rs

//! Smooth self-gated activation functions: gelu, swish, hardswish, mish
//!
//! These activations are non-monotonic and self-gating; all are smooth except
//! hardswish, which is a piece-wise approximation of swish. They are among the
//! preferred activations in modern transformer and vision architectures.

use crate::backends::scalar::ScalarBackend;
use crate::backends::VectorBackend;
use crate::vector::Vector;
use crate::{Backend, Result, TruenoError};

use super::super::dispatch_unary_op;

impl Vector<f32> {
    /// GELU (Gaussian Error Linear Unit) activation function
    ///
    /// Computes the element-wise GELU activation using the tanh approximation.
    /// GELU is the activation function used in transformers (BERT, GPT, etc.).
    ///
    /// # Formula
    ///
    /// ```text
    /// gelu(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
    /// ```
    ///
    /// This is the tanh approximation which is faster than the exact form
    /// involving the error function (erf).
    ///
    /// # Properties
    ///
    /// - **Smooth**: Infinitely differentiable everywhere
    /// - **Non-monotonic**: unlike ReLU variants, it dips slightly below zero for small negative inputs
    /// - **Stochastic regularizer**: Can be viewed as adaptive dropout
    /// - **Zero-centered**: Mean activation close to zero
    /// - **Bounded below**: Approaches 0 as x → -∞
    /// - **Unbounded above**: Linear growth for large positive x
    ///
    /// # Applications
    ///
    /// - **Transformers**: BERT, GPT-2, GPT-3, GPT-4 (default activation)
    /// - **Vision transformers**: ViT, DINO, MAE
    /// - **Modern architectures**: State-of-the-art NLP and vision models
    /// - **Better than ReLU**: Empirically outperforms ReLU in many tasks
    ///
    /// # Performance
    ///
    /// This operation is compute-intensive (tanh, x³ calculations).
    /// More expensive than ReLU but comparable to ELU.
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.gelu()?;
    ///
    /// // GELU is smooth and non-monotonic near zero
    /// assert!(result.as_slice()[0] < 0.0); // Negative inputs → small negative outputs
    /// assert_eq!(result.as_slice()[2], 0.0); // gelu(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → ~linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    pub fn gelu(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // OpComplexity::Low - nominal GPU threshold: >100K elements (dispatch currently disabled below)
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = usize::MAX; // GPU DISABLED - 2-800x slower; see docs/performance-analysis.md

        // Try GPU first for large vectors
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            if self.data.len() >= GPU_THRESHOLD {
                use crate::backends::gpu::GpuDevice;
                if GpuDevice::is_available() {
                    let gpu = GpuDevice::new().map_err(TruenoError::InvalidInput)?;
                    let mut result = vec![0.0; self.data.len()];
                    if gpu.gelu(&self.data, &mut result).is_ok() {
                        return Ok(Vector::from_vec(result));
                    }
                }
            }
        }

        // Uninit: dispatch_unary_op writes every element before any read.
        let n = self.len();
        let mut result: Vec<f32> = Vec::with_capacity(n);
        // SAFETY: Backend activation writes all elements before any read.
        unsafe {
            result.set_len(n);
        }

        // Dispatch to appropriate backend
        dispatch_unary_op!(self.backend, gelu, &self.data, &mut result);

        Ok(Vector::from_vec(result))
    }
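
    // Illustrative sketch, not part of the original file: a scalar reference for the
    // tanh-approximation GELU documented above, handy for spot-checking backend output
    // on a few values. The name `gelu_scalar_reference` is hypothetical.
    #[allow(dead_code)]
    fn gelu_scalar_reference(x: f32) -> f32 {
        const SQRT_2_OVER_PI: f32 = 0.797_884_56; // √(2/π)
        0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044_715 * x * x * x)).tanh())
    }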

    /// Swish activation function (also known as SiLU - Sigmoid Linear Unit)
    ///
    /// Applies the Swish activation element-wise: swish(x) = x * sigmoid(x) = x / (1 + e^(-x)).
    ///
    /// Swish is a smooth, non-monotonic activation function that typically matches or
    /// outperforms ReLU in deep networks. It's used in EfficientNet, MobileNetV3, and
    /// many modern architectures. The function is self-gated: it adaptively gates the
    /// input based on its own value.
    ///
    /// Properties:
    /// - Smooth and differentiable everywhere
    /// - Non-monotonic: has a slight "dip" for negative values
    /// - swish(0) = 0
    /// - swish(x) ≈ x for large positive x (linear)
    /// - swish(x) ≈ 0 for large negative x
    /// - Unbounded above, bounded below by ≈ -0.278 at x ≈ -1.278
    ///
    /// # Performance
    ///
    /// Compute-bound operation requiring exponential and division.
    /// Future SIMD optimizations planned for Phase 9 (GPU backend).
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.swish()?;
    ///
    /// // Swish is smooth and self-gated
    /// assert!(result.as_slice()[0] < 0.0); // Negative inputs → small negative outputs
    /// assert_eq!(result.as_slice()[2], 0.0); // swish(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → ~linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Ramachandran et al. (2017): "Searching for Activation Functions"
    /// - Also known as SiLU (Sigmoid Linear Unit): Elfwing et al. (2018)
    pub fn swish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // OpComplexity::Low - nominal GPU threshold: >100K elements (dispatch currently disabled below)
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = usize::MAX; // GPU DISABLED - 2-800x slower; see docs/performance-analysis.md

        // Try GPU first for large vectors
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            if self.data.len() >= GPU_THRESHOLD {
                use crate::backends::gpu::GpuDevice;
                if GpuDevice::is_available() {
                    let gpu = GpuDevice::new().map_err(TruenoError::InvalidInput)?;
                    let mut result = vec![0.0; self.data.len()];
                    if gpu.swish(&self.data, &mut result).is_ok() {
                        return Ok(Vector::from_vec(result));
                    }
                }
            }
        }

        // Uninit: dispatch_unary_op writes every element before any read.
        let n = self.len();
        let mut result: Vec<f32> = Vec::with_capacity(n);
        // SAFETY: Backend activation writes all elements before any read.
        unsafe {
            result.set_len(n);
        }

        // Dispatch to appropriate SIMD backend
        dispatch_unary_op!(self.backend, swish, &self.data, &mut result);

        Ok(Vector::from_vec(result))
    }
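
    // Illustrative sketch, not part of the original file: a scalar reference for
    // swish/SiLU, x * sigmoid(x), for spot-checking backend output on a few values.
    // The name `swish_scalar_reference` is hypothetical.
    #[allow(dead_code)]
    fn swish_scalar_reference(x: f32) -> f32 {
        x / (1.0 + (-x).exp())
    }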

    /// Hard Swish activation function
    ///
    /// Applies the hardswish activation element-wise: hardswish(x) = x * relu6(x + 3) / 6
    ///
    /// Hardswish is a piece-wise approximation to swish, designed for efficient
    /// computation in mobile neural networks. It's used in MobileNetV3 and avoids the
    /// expensive sigmoid computation of standard swish.
    ///
    /// Properties:
    /// - Piece-wise (quadratic on the middle segment): cheap to compute
    /// - hardswish(x) = 0 for x ≤ -3
    /// - hardswish(x) = x for x ≥ 3
    /// - hardswish(x) = x * (x + 3) / 6 for -3 < x < 3
    /// - hardswish(0) = 0
    /// - Continuous at the boundaries (though not differentiable there)
    ///
    /// # Performance
    ///
    /// More efficient than swish as it uses only comparisons, multiplication, and
    /// division instead of expensive exponential functions. Ideal for inference on
    /// resource-constrained devices.
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-4.0, -3.0, 0.0, 3.0, 4.0]);
    /// let result = v.hardswish()?;
    ///
    /// // Piece-wise behavior
    /// assert_eq!(result.as_slice()[0], 0.0); // x ≤ -3 → 0
    /// assert_eq!(result.as_slice()[1], 0.0); // x = -3 → 0
    /// assert_eq!(result.as_slice()[2], 0.0); // x = 0 → 0
    /// assert_eq!(result.as_slice()[3], 3.0); // x = 3 → x
    /// assert_eq!(result.as_slice()[4], 4.0); // x ≥ 3 → x
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Howard et al. (2019): "Searching for MobileNetV3"
    pub fn hardswish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // Scalar implementation: hardswish(x) = x * relu6(x + 3) / 6
        // Simplified piece-wise:
        // - x <= -3: 0
        // - x >= 3: x
        // - else: x * (x + 3) / 6
        let data: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                if x <= -3.0 {
                    0.0
                } else if x >= 3.0 {
                    x
                } else {
                    x * (x + 3.0) / 6.0
                }
            })
            .collect();

        Ok(Vector::from_vec(data))
    }
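
    // Illustrative sketch, not part of the original file: the relu6 form from the
    // doc comment above, x * relu6(x + 3) / 6, which agrees with the branch-based
    // implementation for all finite x. The name `hardswish_relu6_form` is hypothetical.
    #[allow(dead_code)]
    fn hardswish_relu6_form(x: f32) -> f32 {
        let relu6 = (x + 3.0).clamp(0.0, 6.0); // relu6(y) = min(max(y, 0), 6)
        x * relu6 / 6.0
    }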

    /// Mish activation function
    ///
    /// Applies the mish activation element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
    ///
    /// Mish is a self-regularizing non-monotonic activation function that often outperforms
    /// ReLU and swish in computer vision tasks. It's used in YOLOv4 and many modern architectures.
    ///
    /// Properties:
    /// - Smooth and non-monotonic (similar to swish)
    /// - Self-regularizing: prevents dying neurons
    /// - mish(0) = 0
    /// - mish(x) ≈ x for large positive x (nearly linear)
    /// - mish(x) ≈ 0 for large negative x
    /// - Bounded below by ≈ -0.31 at x ≈ -1.19
    ///
    /// # Performance
    ///
    /// Compute-bound operation requiring exponential, logarithm, and tanh.
    /// More expensive than ReLU/swish but often provides better accuracy.
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.mish()?;
    ///
    /// // Mish is smooth and self-gated
    /// assert!(result.as_slice()[0] < 0.0); // Small negative output for negative inputs
    /// assert!(result.as_slice()[2].abs() < 1e-5); // mish(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → near linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Misra (2019): "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
    pub fn mish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // Scalar implementation: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
        let data: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                // Handle extreme values for numerical stability
                if x < -20.0 {
                    // For very negative x: softplus ≈ 0, tanh(0) ≈ 0, so mish ≈ 0
                    0.0
                } else if x > 20.0 {
                    // For very positive x: softplus ≈ x, tanh(x) ≈ 1, so mish ≈ x
                    x
                } else {
                    // Normal case: x * tanh(ln(1 + e^x))
                    let softplus = (1.0 + x.exp()).ln();
                    x * softplus.tanh()
                }
            })
            .collect();

        Ok(Vector::from_vec(data))
    }
}
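
// A minimal property-check sketch (illustrative, not part of the original file):
// spot-checks the fixed points and tail behaviour documented above, using only the
// public API shown in this file (Vector::from_slice, as_slice, gelu/swish/hardswish/mish).
// The module and test names are hypothetical.
#[cfg(test)]
mod smooth_activation_checks {
    use super::*;

    #[test]
    fn all_activations_are_zero_at_zero() {
        let v = Vector::from_slice(&[0.0f32]);
        assert_eq!(v.gelu().unwrap().as_slice()[0], 0.0);
        assert_eq!(v.swish().unwrap().as_slice()[0], 0.0);
        assert_eq!(v.hardswish().unwrap().as_slice()[0], 0.0);
        assert!(v.mish().unwrap().as_slice()[0].abs() < 1e-6);
    }

    #[test]
    fn large_positive_inputs_are_nearly_linear() {
        let v = Vector::from_slice(&[10.0f32]);
        assert!((v.gelu().unwrap().as_slice()[0] - 10.0).abs() < 1e-3);
        assert!((v.swish().unwrap().as_slice()[0] - 10.0).abs() < 1e-3);
        assert_eq!(v.hardswish().unwrap().as_slice()[0], 10.0);
        assert!((v.mish().unwrap().as_slice()[0] - 10.0).abs() < 1e-3);
    }
}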