trueno/vector/ops/activations/advanced/smooth.rs
//! Smooth self-gated activation functions: gelu, swish, hardswish, mish
//!
//! These activations are smooth, non-monotonic, and use self-gating mechanisms.
//! They are the preferred activations in modern transformer and vision architectures.

use crate::backends::scalar::ScalarBackend;
use crate::backends::VectorBackend;
use crate::vector::Vector;
use crate::{Backend, Result, TruenoError};

use super::super::dispatch_unary_op;

impl Vector<f32> {
    /// GELU (Gaussian Error Linear Unit) activation function
    ///
    /// Computes the element-wise GELU activation using the tanh approximation.
    /// GELU is the activation function used in transformers (BERT, GPT, etc.).
    ///
    /// # Formula
    ///
    /// ```text
    /// gelu(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
    /// ```
    ///
    /// This is the tanh approximation, which is faster than the exact form
    /// involving the error function (erf).
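    ///
    /// For reference, the exact form being approximated is:
    ///
    /// ```text
    /// gelu(x) = 0.5 * x * (1 + erf(x / √2))
    /// ```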
    ///
    /// # Properties
    ///
    /// - **Smooth**: Infinitely differentiable everywhere
    /// - **Non-monotonic**: Unlike ReLU variants, dips slightly below zero for small negative inputs
    /// - **Stochastic regularizer**: Can be viewed as a smooth, input-adaptive form of dropout
    /// - **Zero-centered**: Mean activation close to zero
    /// - **Bounded below**: Approaches 0 as x → -∞
    /// - **Unbounded above**: Linear growth for large positive x
    ///
    /// # Applications
    ///
    /// - **Transformers**: BERT, GPT-2, GPT-3, and most GPT-style models (default activation)
    /// - **Vision transformers**: ViT, DINO, MAE
    /// - **Modern architectures**: State-of-the-art NLP and vision models
    /// - **Better than ReLU**: Empirically outperforms ReLU in many tasks
    ///
    /// # Performance
    ///
    /// This operation is compute-intensive (tanh, x³ calculations).
    /// More expensive than ReLU but comparable to ELU.
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.gelu()?;
    ///
    /// // GELU is smooth and non-monotonic near zero
    /// assert!(result.as_slice()[0] < 0.0); // Negative inputs → small negative outputs
    /// assert_eq!(result.as_slice()[2], 0.0); // gelu(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → ~linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    pub fn gelu(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // OpComplexity::Low. The nominal GPU threshold would be >100K elements,
        // but the GPU path is disabled: it measured 2-800x slower.
        // See docs/performance-analysis.md.
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = usize::MAX;

        // Try GPU first for large vectors
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            if self.data.len() >= GPU_THRESHOLD {
                use crate::backends::gpu::GpuDevice;
                if GpuDevice::is_available() {
                    let gpu = GpuDevice::new().map_err(TruenoError::InvalidInput)?;
                    let mut result = vec![0.0; self.data.len()];
                    if gpu.gelu(&self.data, &mut result).is_ok() {
                        return Ok(Vector::from_vec(result));
                    }
                }
            }
        }

        // Uninit: dispatch_unary_op writes every element before any read.
        let n = self.len();
        let mut result: Vec<f32> = Vec::with_capacity(n);
        // SAFETY: Backend activation writes all elements before any read.
        unsafe {
            result.set_len(n);
        }

        // Dispatch to appropriate backend
        dispatch_unary_op!(self.backend, gelu, &self.data, &mut result);

        Ok(Vector::from_vec(result))
    }

    /// Swish activation function (also known as SiLU - Sigmoid Linear Unit)
    ///
    /// Applies the Swish activation element-wise: swish(x) = x * sigmoid(x) = x / (1 + e^(-x)).
    ///
    /// Swish is a smooth, non-monotonic activation function that consistently matches or
    /// outperforms ReLU in deep networks. It's used in EfficientNet, MobileNetV3, and
    /// many modern architectures. The function is self-gated: it adaptively gates the
    /// input based on its value.
    ///
    /// Properties:
    /// - Smooth and differentiable everywhere
    /// - Non-monotonic: has a slight "dip" for negative values
    /// - swish(0) = 0
    /// - swish(x) ≈ x for large positive x (linear)
    /// - swish(x) ≈ 0 for large negative x
    /// - Unbounded above, bounded below by ≈ -0.278 at x ≈ -1.278 (see the derivative below)
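    ///
    /// The minimum can be located by setting the derivative to zero
    /// (writing σ for the sigmoid):
    ///
    /// ```text
    /// swish'(x) = σ(x) * (1 + x * (1 - σ(x))) = 0
    /// ```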
    ///
    /// # Performance
    ///
    /// Compute-bound operation requiring exponential and division.
    /// Further optimizations are planned for Phase 9 (GPU backend).
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.swish()?;
    ///
    /// // Swish is smooth and self-gated
    /// assert!(result.as_slice()[0] < 0.0); // Negative inputs → small negative outputs
    /// assert_eq!(result.as_slice()[2], 0.0); // swish(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → ~linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Ramachandran et al. (2017): "Searching for Activation Functions"
    /// - Also known as SiLU (Sigmoid Linear Unit): Elfwing et al. (2018)
    pub fn swish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // OpComplexity::Low. The nominal GPU threshold would be >100K elements,
        // but the GPU path is disabled: it measured 2-800x slower.
        // See docs/performance-analysis.md.
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = usize::MAX;

        // Try GPU first for large vectors
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            if self.data.len() >= GPU_THRESHOLD {
                use crate::backends::gpu::GpuDevice;
                if GpuDevice::is_available() {
                    let gpu = GpuDevice::new().map_err(TruenoError::InvalidInput)?;
                    let mut result = vec![0.0; self.data.len()];
                    if gpu.swish(&self.data, &mut result).is_ok() {
                        return Ok(Vector::from_vec(result));
                    }
                }
            }
        }

        // Uninit: dispatch_unary_op writes every element before any read.
        let n = self.len();
        let mut result: Vec<f32> = Vec::with_capacity(n);
        // SAFETY: Backend activation writes all elements before any read.
        unsafe {
            result.set_len(n);
        }

        // Dispatch to appropriate SIMD backend
        dispatch_unary_op!(self.backend, swish, &self.data, &mut result);

        Ok(Vector::from_vec(result))
    }

    /// Hard Swish activation function
    ///
    /// Applies the hardswish activation element-wise: hardswish(x) = x * relu6(x + 3) / 6
    ///
    /// Hardswish replaces the sigmoid gate of swish with a piece-wise linear
    /// approximation (relu6(x + 3) / 6), making it cheap to compute in mobile
    /// neural networks. It's used in MobileNetV3 and avoids the expensive
    /// sigmoid computation of standard swish.
    ///
    /// Properties:
    /// - Piece-wise: linear tails with a quadratic middle segment, efficient to compute
    /// - hardswish(x) = 0 for x ≤ -3
    /// - hardswish(x) = x for x ≥ 3
    /// - hardswish(x) = x * (x + 3) / 6 for -3 < x < 3
    /// - hardswish(0) = 0
    /// - Continuous at the boundaries (x = ±3), but not differentiable there
    ///
    /// # Performance
    ///
    /// More efficient than swish as it uses only multiply/divide operations
    /// instead of expensive exponential functions. Ideal for inference on
    /// resource-constrained devices.
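    ///
    /// An equivalent branch-free form, using relu6(y) = clamp(y, 0, 6), that
    /// maps well to SIMD:
    ///
    /// ```text
    /// hardswish(x) = x * clamp(x + 3, 0, 6) / 6
    /// ```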
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-4.0, -3.0, 0.0, 3.0, 4.0]);
    /// let result = v.hardswish()?;
    ///
    /// // Piece-wise behavior
    /// assert_eq!(result.as_slice()[0], 0.0); // x ≤ -3 → 0
    /// assert_eq!(result.as_slice()[1], 0.0); // x = -3 → 0
    /// assert_eq!(result.as_slice()[2], 0.0); // x = 0 → 0
    /// assert_eq!(result.as_slice()[3], 3.0); // x = 3 → x
    /// assert_eq!(result.as_slice()[4], 4.0); // x ≥ 3 → x
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Howard et al. (2019): "Searching for MobileNetV3"
    pub fn hardswish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // Scalar implementation: hardswish(x) = x * relu6(x + 3) / 6
        // Simplified piece-wise:
        // - x <= -3: 0
        // - x >= 3: x
        // - else: x * (x + 3) / 6
        let data: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                if x <= -3.0 {
                    0.0
                } else if x >= 3.0 {
                    x
                } else {
                    x * (x + 3.0) / 6.0
                }
            })
            .collect();

        Ok(Vector::from_vec(data))
    }

    /// Mish activation function
    ///
    /// Applies the mish activation element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
    ///
    /// Mish is a self-regularizing non-monotonic activation function that often outperforms
    /// ReLU and swish in computer vision tasks. It's used in YOLOv4 and many modern architectures.
    ///
    /// Properties:
    /// - Smooth and non-monotonic (similar to swish)
    /// - Self-regularizing: prevents dying neurons
    /// - mish(0) = 0
    /// - mish(x) ≈ x for large positive x (nearly linear)
    /// - mish(x) ≈ 0 for large negative x
    /// - Bounded below by ≈ -0.31 at x ≈ -1.19
    ///
    /// # Performance
    ///
    /// Compute-bound operation requiring exponential, logarithm, and tanh.
    /// More expensive than ReLU/swish but often provides better accuracy.
    ///
    /// # Examples
    ///
    /// ```
    /// use trueno::Vector;
    ///
    /// let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    /// let result = v.mish()?;
    ///
    /// // Mish is smooth and self-gated
    /// assert!(result.as_slice()[0] < 0.0); // Small negative output for negative inputs
    /// assert!(result.as_slice()[2].abs() < 1e-5); // mish(0) = 0
    /// assert!(result.as_slice()[4] > 1.5); // Large positive → near linear
    /// # Ok::<(), trueno::TruenoError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `EmptyVector` if the input vector is empty.
    ///
    /// # References
    ///
    /// - Misra (2019): "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
    pub fn mish(&self) -> Result<Self> {
        if self.data.is_empty() {
            return Err(TruenoError::EmptyVector);
        }

        // Scalar implementation: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
        let data: Vec<f32> = self
            .data
            .iter()
            .map(|&x| {
                // Handle extreme values for numerical stability
                if x < -20.0 {
                    // For very negative x: softplus ≈ 0, tanh(0) ≈ 0, so mish ≈ 0
                    0.0
                } else if x > 20.0 {
                    // For very positive x: softplus ≈ x, tanh(x) ≈ 1, so mish ≈ x
                    x
                } else {
                    // Normal case: x * tanh(ln(1 + e^x))
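                    // Note: x.exp().ln_1p() computes the same quantity with less
                    // rounding error when e^x is small; ln(1 + e^x) is kept here
                    // to mirror the formula in the doc comment.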
                    let softplus = (1.0 + x.exp()).ln();
                    x * softplus.tanh()
                }
            })
            .collect();

        Ok(Vector::from_vec(data))
    }
}
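
// A minimal property-test sketch for the activations above. This is an
// illustrative addition, not part of the original file: it assumes only the
// `Vector::from_slice`/`as_slice`/`Result` API already used in the doctests
// and the standard cargo test harness.
#[cfg(test)]
mod smooth_activation_tests {
    use super::*;

    #[test]
    fn zero_input_maps_to_zero() {
        let v = Vector::from_slice(&[0.0]);
        assert_eq!(v.gelu().unwrap().as_slice()[0], 0.0);
        assert_eq!(v.swish().unwrap().as_slice()[0], 0.0);
        assert_eq!(v.hardswish().unwrap().as_slice()[0], 0.0);
        assert!(v.mish().unwrap().as_slice()[0].abs() < 1e-5);
    }

    #[test]
    fn empty_input_is_rejected() {
        let v: Vector<f32> = Vector::from_slice(&[]);
        assert!(v.gelu().is_err());
        assert!(v.swish().is_err());
        assert!(v.hardswish().is_err());
        assert!(v.mish().is_err());
    }

    #[test]
    fn hardswish_matches_its_piece_wise_definition_at_the_tails() {
        let v = Vector::from_slice(&[-5.0, 5.0]);
        let r = v.hardswish().unwrap();
        assert_eq!(r.as_slice()[0], 0.0); // x <= -3 → 0
        assert_eq!(r.as_slice()[1], 5.0); // x >= 3 → x
    }
}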