torsh-functional 0.1.2

Functional programming utilities for ToRSh tensors
//! # Activation Functions for Neural Networks
//!
//! This module provides a comprehensive collection of activation functions organized into
//! focused sub-modules for better maintainability and discoverability.
//!
//! ## Mathematical Foundation
//!
//! Activation functions introduce **non-linearity** into neural networks, enabling them to
//! learn complex patterns beyond linear transformations. Without activation functions, a
//! multi-layer network would be equivalent to a single-layer linear transformation.
//!
//! ### Role in Neural Networks
//! ```text
//! Layer output: y = σ(Wx + b)
//! ```
//! where:
//! - `W` is the weight matrix
//! - `x` is the input
//! - `b` is the bias
//! - `σ` is the activation function
//!
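//! A minimal sketch of this layer pattern, using the same `randn`/`matmul`/`relu` calls as the
//! examples further down (bias addition is omitted for brevity):
//! ```rust,no_run
//! # use torsh_functional::activations::relu;
//! # use torsh_functional::random_ops::randn;
//! # fn example() -> torsh_core::Result<()> {
//! let x = randn(&[8, 16], None, None, None)?;   // batch of 8 inputs, 16 features
//! let w = randn(&[16, 32], None, None, None)?;  // weight matrix W
//! let wx = x.matmul(&w)?;                       // Wx (bias b omitted)
//! let _y = relu(&wx, false)?;                   // y = σ(Wx) with σ = ReLU
//! # Ok(())
//! # }
//! ```
//!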
//! ### Key Properties
//!
//! #### Non-linearity
//! - **Essential**: compositions of linear maps are still linear, so stacking layers without activations adds no expressive power
//! - **Non-linear**: Enables learning of complex decision boundaries
//!
//! #### Differentiability
//! - Required for backpropagation and gradient descent
//! - Non-differentiability at isolated points (e.g., ReLU at 0) is acceptable; subgradients are used in practice
//!
//! #### Range and Saturation
//! - **Unbounded** (ReLU): Can cause exploding activations
//! - **Bounded** (Sigmoid, Tanh): May cause vanishing gradients in deep networks
//!
//! #### Zero-Centered
//! - **Zero-centered** (Tanh): Gradients can be positive or negative, faster convergence
//! - **Not zero-centered** (ReLU, Sigmoid): Can cause zig-zagging dynamics
//!
//! ## Activation Function Families
//!
//! ### ReLU Family (Piecewise Linear)
//! ```text
//! ReLU(x) = max(0, x)
//! LeakyReLU(x) = max(αx, x)  where α ∈ (0, 1)
//! ELU(x) = { x if x > 0, α(exp(x) - 1) if x ≤ 0 }
//! ```
//! **Best for**: Hidden layers in deep networks, CNNs
//! **Advantages**: Computationally efficient, sparse activation, no vanishing gradient for x > 0
//! **Disadvantages**: "Dying ReLU" problem, not zero-centered
//!
//! ### Sigmoid Family (Smooth Bounded)
//! ```text
//! Sigmoid(x) = 1 / (1 + exp(-x))  ∈ (0, 1)
//! SiLU(x) = x · Sigmoid(x)  (Swish)
//! ```
//! **Best for**: Binary classification (output layer), gates in LSTMs
//! **Advantages**: Smooth, interpretable as probability
//! **Disadvantages**: Vanishing gradient problem, not zero-centered, expensive exp()
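//!
//! A small sketch, using the same constructors as the tests below; note that `silu` composes
//! `x · sigmoid(x)` for you:
//! ```rust,no_run
//! # use torsh_functional::activations::{sigmoid, silu};
//! # use torsh_core::device::DeviceType;
//! # use torsh_tensor::creation::from_vec;
//! # fn example() -> torsh_core::Result<()> {
//! let x = from_vec(vec![0.0f32], &[1], DeviceType::Cpu)?;
//! let s = sigmoid(&x)?;                     // sigmoid(0) = 0.5
//! assert!((s.data()?[0] - 0.5).abs() < 1e-6);
//! let _swish = silu(&x, false)?;            // SiLU(0) = 0 · 0.5 = 0
//! # Ok(())
//! # }
//! ```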
//!
//! ### Tanh Family (Zero-Centered Bounded)
//! ```text
//! Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))  ∈ (-1, 1)
//! ```
//! **Best for**: Hidden layers when zero-centered output desired, RNNs
//! **Advantages**: Zero-centered, stronger gradients than sigmoid
//! **Disadvantages**: Still suffers from vanishing gradient, expensive computation
//!
//! ### Softmax Family (Normalization)
//! ```text
//! Softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
//! ```
//! **Best for**: Multi-class classification (output layer)
//! **Advantages**: Probabilistic interpretation, differentiable
//! **Disadvantages**: Only for output layer, sensitive to outliers
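//!
//! A minimal sketch (same call shape as the integration tests); the outputs form a probability
//! distribution over the chosen dimension:
//! ```rust,no_run
//! # use torsh_functional::activations::softmax;
//! # use torsh_core::device::DeviceType;
//! # use torsh_tensor::creation::from_vec;
//! # fn example() -> torsh_core::Result<()> {
//! let logits = from_vec(vec![1.0f32, 2.0, 3.0], &[3], DeviceType::Cpu)?;
//! let probs = softmax(&logits, 0, None)?;   // normalize along dimension 0
//! let sum: f32 = probs.data()?.iter().sum();
//! assert!((sum - 1.0).abs() < 1e-6);        // probabilities sum to 1
//! # Ok(())
//! # }
//! ```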
//!
//! ### Advanced Functions (Modern)
//! ```text
//! GELU(x) = x · Φ(x)  where Φ is Gaussian CDF
//! Mish(x) = x · tanh(softplus(x))
//! ```
//! **Best for**: Transformers (GELU), general deep learning (Mish)
//! **Advantages**: Smooth, non-monotonic, state-of-the-art performance
//! **Disadvantages**: More expensive computation
//!
//! ## Performance Characteristics
//!
//! ### Computational Complexity (per element)
//! - **ReLU family**: O(1) - simple comparison/multiplication
//! - **Sigmoid/Tanh**: O(1) but requires exp() - ~10-100x slower than ReLU
//! - **Softmax**: O(n) where n is the number of classes - reduction operation
//! - **GELU/Mish**: O(1) but more complex than ReLU - ~2-5x slower
//!
//! ### Memory Usage
//! - **Standard operations**: O(n) - same as input
//! - **In-place operations**: O(1) extra - modify the input directly
//! - **Softmax**: O(n) - requires temporary storage for normalization
//!
//! ### Gradient Computation
//! - **ReLU**: Fastest - binary gradient (0 or 1)
//! - **Sigmoid/Tanh**: Moderate - requires output value
//! - **Softmax**: Expensive - requires full Jacobian for multi-dimensional
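//!
//! For reference, the standard element-wise derivatives can be written in terms of the outputs,
//! which is why sigmoid/tanh backward passes reuse the forward result:
//! ```text
//! ReLU'(x)    = 1 if x > 0, else 0
//! Sigmoid'(x) = Sigmoid(x) · (1 - Sigmoid(x))
//! Tanh'(x)    = 1 - Tanh(x)²
//! ```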
//!
//! ## Choosing the Right Activation
//!
//! ### Decision Tree
//! 1. **Output Layer?**
//!    - Binary classification → **Sigmoid**
//!    - Multi-class classification → **Softmax**
//!    - Regression → **None** or **ReLU** (for non-negative outputs)
//!
//! 2. **Hidden Layer in CNN?**
//!    - Default → **ReLU**
//!    - Want smoothness → **GELU**
//!    - Concerned about dying ReLU → **Leaky ReLU** or **ELU**
//!
//! 3. **Hidden Layer in Transformer?**
//!    - Modern standard → **GELU**
//!    - Alternative → **SiLU/Swish**
//!
//! 4. **Hidden Layer in RNN/LSTM?**
//!    - Gates → **Sigmoid** (by design)
//!    - Hidden state → **Tanh** (by design)
//!
//! 5. **Memory Constrained?**
//!    - Use **in-place variants** (relu_, sigmoid_, etc.)
//!
//! ## Common Use Cases
//!
//! ### Convolutional Neural Network
//! ```rust,no_run
//! # use torsh_functional::activations::{relu, softmax};
//! # use torsh_functional::random_ops::randn;
//! # fn example() -> torsh_core::Result<()> {
//! // Conv → ReLU → Pool → Conv → ReLU → Pool → FC → Softmax
//! let conv1_out = randn(&[32, 64, 28, 28], None, None, None)?;
//! let relu1_out = relu(&conv1_out, false)?;
//!
//! // ... more layers ...
//!
//! let logits = randn(&[32, 10], None, None, None)?;
//! let predictions = softmax(&logits, 1, None)?;
//! # Ok(())
//! # }
//! ```
//!
//! ### Transformer Architecture
//! ```rust,no_run
//! # use torsh_functional::activations::gelu;
//! # use torsh_functional::random_ops::randn;
//! # fn example() -> torsh_core::Result<()> {
//! // Attention → LayerNorm → FFN(GELU) → LayerNorm
//! let ffn_input = randn(&[32, 512, 768], None, None, None)?;
//! let weights1 = randn(&[768, 3072], None, None, None)?;
//! let hidden = ffn_input.matmul(&weights1)?;
//! let activated = gelu(&hidden)?;
//! let weights2 = randn(&[3072, 768], None, None, None)?;
//! let output = activated.matmul(&weights2)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Organization
//!
//! The activation functions are organized into the following sub-modules:
//!
//! - [`relu_family`]: ReLU and related variants (ReLU, Leaky ReLU, ELU, SELU, etc.)
//! - [`sigmoid_family`]: Sigmoid and related variants (Sigmoid, Hard Sigmoid, SiLU, etc.)
//! - [`tanh_family`]: Tanh and related variants (Tanh, Softsign, Hardtanh, etc.)
//! - [`softmax_family`]: Softmax and related variants (Softmax, Log Softmax, Softmin, etc.)
//! - [`advanced`]: Advanced functions (GELU, GLU, Scaled Dot-Product Attention, etc.)
//! - [`inplace`]: In-place variants for memory-efficient operations
//!
//! ## Quick Reference
//!
//! | Function | Range | Zero-Centered | Computation | Best For |
//! |----------|-------|---------------|-------------|----------|
//! | ReLU | [0, ∞) | No | Fast | CNNs, hidden layers |
//! | Leaky ReLU | (-∞, ∞) | No | Fast | Avoid dying ReLU |
//! | ELU | (-α, ∞) | Nearly | Moderate | Smoother gradients |
//! | Sigmoid | (0, 1) | No | Slow | Binary output |
//! | Tanh | (-1, 1) | Yes | Slow | RNN hidden layers |
//! | Softmax | (0, 1), sum=1 | No | Moderate | Multi-class output |
//! | GELU | (≈ -0.17, ∞) | Nearly | Moderate | Transformers |
//! | SiLU/Swish | (≈ -0.28, ∞) | No | Moderate | General deep learning |

// Sub-modules
pub mod advanced;
pub mod inplace;
pub mod relu_family;
pub mod sigmoid_family;
pub mod softmax_family;
pub mod tanh_family;

// Helper functions for reducing code duplication
use torsh_core::dtype::FloatElement;
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;

/// Generic element-wise activation function helper
///
/// This function eliminates code duplication across activation functions by providing
/// a common pattern for element-wise transformations.
///
/// # Parameters
/// - `input`: Input tensor
/// - `operation`: Closure that transforms each element
///
/// # Returns
/// New tensor with the same shape and device as input, with transformed elements
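///
/// # Examples
///
/// A hypothetical custom activation built on this helper (`my_softsign` is an illustrative
/// name, not part of the crate's API):
///
/// ```rust,no_run
/// # use torsh_functional::activations::apply_elementwise;
/// # use torsh_tensor::Tensor;
/// fn my_softsign(input: &Tensor<f32>) -> torsh_core::Result<Tensor<f32>> {
///     // softsign(x) = x / (1 + |x|), applied element-wise
///     apply_elementwise(input, |x| x / (1.0 + x.abs()))
/// }
/// ```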
pub fn apply_elementwise<T, F>(input: &Tensor<T>, operation: F) -> TorshResult<Tensor<T>>
where
    T: FloatElement + Copy,
    F: Fn(T) -> T,
{
    let data = input.data()?;
    let result_data: Vec<T> = data.iter().map(|&x| operation(x)).collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}

/// Generic element-wise activation function with inplace support
///
/// This function provides a common pattern for activation functions that support
/// both inplace and non-inplace operations.
///
/// # Parameters
/// - `input`: Input tensor
/// - `inplace`: Whether to perform the operation in-place (currently a new tensor is created regardless)
/// - `operation`: Closure that transforms each element
///
/// # Returns
/// New tensor with transformed elements
pub fn apply_elementwise_inplace<T, F>(
    input: &Tensor<T>,
    _inplace: bool,
    operation: F,
) -> TorshResult<Tensor<T>>
where
    T: FloatElement + Copy,
    F: Fn(T) -> T,
{
    // Note: Currently always creates new tensor for simplicity
    // True inplace operations would require mutable tensor interface
    apply_elementwise(input, operation)
}

// Re-export all functions from sub-modules for backward compatibility and convenience

// ReLU family functions
pub use relu_family::{
    celu, elu, hardshrink, leaky_relu, prelu, relu, relu6, rrelu, selu, softshrink, threshold,
};

// Sigmoid family functions
pub use sigmoid_family::{
    hardsigmoid, hardsigmoid_v2, hardswish, log_sigmoid, mish, sigmoid, silu, softplus, swish,
};

// Tanh family functions
pub use tanh_family::{hardtanh, softsign, tanh, tanhshrink};

// Softmax family functions
pub use softmax_family::{gumbel_softmax, log_softmax, softmax, softmin};

// Advanced functions
pub use advanced::{gelu, glu, local_response_norm, scaled_dot_product_attention};

// In-place functions
pub use inplace::{gelu_, leaky_relu_, relu_, sigmoid_, silu_, tanh_};

// All functions are already re-exported above, no need for additional aliases

#[cfg(test)]
mod integration_tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// Integration test to ensure all activation functions work together
    #[test]
    fn test_activation_functions_integration() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;

        // Test input
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;

        // Test ReLU family
        let _relu_out = relu(&input, false)?;
        let _leaky_relu_out = leaky_relu(&input, 0.1, false)?;
        let _elu_out = elu(&input, 1.0, false)?;
        let _selu_out = selu(&input, false)?;

        // Test sigmoid family
        let _sigmoid_out = sigmoid(&input)?;
        let _silu_out = silu(&input, false)?;
        let _mish_out = mish(&input, false)?;

        // Test tanh family
        let _tanh_out = tanh(&input)?;
        let _softsign_out = softsign(&input)?;
        let _hardtanh_out = hardtanh(&input, -1.0, 1.0)?;

        // Test advanced functions
        let _gelu_out = gelu(&input)?;

        // Test softmax with different input
        let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], device)?;
        let _softmax_out = softmax(&logits, 0, None)?;
        let _log_softmax_out = log_softmax(&logits, 0, None)?;

        Ok(())
    }

    /// Test that all activation functions produce finite, non-NaN values
    #[test]
    fn test_activation_functions_numerical_stability() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;

        // Test with edge cases
        let extreme_input = from_vec(vec![-100.0, -1e-8, 0.0, 1e-8, 100.0], &[5], device)?;

        // Test functions that should handle extreme values
        let sigmoid_out = sigmoid(&extreme_input)?;
        let sigmoid_data = sigmoid_out.data()?;
        for &val in sigmoid_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "Sigmoid produced invalid value: {}",
                val
            );
            assert!(
                val >= 0.0 && val <= 1.0,
                "Sigmoid value {} not in [0,1]",
                val
            );
        }

        let tanh_out = tanh(&extreme_input)?;
        let tanh_data = tanh_out.data()?;
        for &val in tanh_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "Tanh produced invalid value: {}",
                val
            );
            assert!(
                val >= -1.0 && val <= 1.0,
                "Tanh value {} not in [-1,1]",
                val
            );
        }

        let gelu_out = gelu(&extreme_input)?;
        let gelu_data = gelu_out.data()?;
        for &val in gelu_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "GELU produced invalid value: {}",
                val
            );
        }

        Ok(())
    }

    /// Test in-place operations
    #[test]
    fn test_inplace_operations() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;

        // Test in-place ReLU
        let mut input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;
        relu_(&mut input)?;
        let data = input.data()?;

        // Verify ReLU behavior
        assert_eq!(data[0], 0.0); // -2.0 -> 0.0
        assert_eq!(data[1], 0.0); // -1.0 -> 0.0
        assert_eq!(data[2], 0.0); // 0.0 -> 0.0
        assert_eq!(data[3], 1.0); // 1.0 -> 1.0
        assert_eq!(data[4], 2.0); // 2.0 -> 2.0

        // Test in-place sigmoid
        let mut input2 = from_vec(vec![0.0], &[1], device)?;
        sigmoid_(&mut input2)?;
        let data2 = input2.data()?;
        assert!((data2[0] - 0.5_f32).abs() < 1e-6); // sigmoid(0) = 0.5

        Ok(())
    }

    /// Test that different activation families produce expected output ranges
    #[test]
    fn test_activation_output_ranges() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![-5.0, -1.0, 0.0, 1.0, 5.0], &[5], device)?;

        // ReLU family - non-negative outputs
        let relu_out = relu(&input, false)?;
        let relu_data = relu_out.data()?;
        for &val in relu_data.iter() {
            assert!(val >= 0.0, "ReLU output {} should be non-negative", val);
        }

        // Sigmoid family - (0, 1) range
        let sigmoid_out = sigmoid(&input)?;
        let sigmoid_data = sigmoid_out.data()?;
        for &val in sigmoid_data.iter() {
            assert!(
                val > 0.0 && val < 1.0,
                "Sigmoid output {} not in (0,1)",
                val
            );
        }

        // Tanh family - (-1, 1) range
        let tanh_out = tanh(&input)?;
        let tanh_data = tanh_out.data()?;
        for &val in tanh_data.iter() {
            assert!(val > -1.0 && val < 1.0, "Tanh output {} not in (-1,1)", val);
        }

        // Softmax - probability distribution
        let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], device)?;
        let softmax_out = softmax(&logits, 0, None)?;
        let softmax_data = softmax_out.data()?;

        // Check probability properties
        let sum: f32 = softmax_data.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "Softmax should sum to 1, got {}",
            sum
        );
        for &val in softmax_data.iter() {
            assert!(
                val >= 0.0 && val <= 1.0,
                "Softmax output {} not in [0,1]",
                val
            );
        }

        Ok(())
    }

    /// Test monotonicity properties where applicable
    #[test]
    fn test_activation_monotonicity() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;

        // Test monotonic functions (note: SiLU/Swish is NOT monotonic; it has a minimum near x ≈ -1.278)
        let monotonic_functions = vec![
            ("relu", relu(&input, false)?),
            ("sigmoid", sigmoid(&input)?),
            ("tanh", tanh(&input)?),
        ];

        for (name, output) in monotonic_functions {
            let data = output.data()?;
            // Check that outputs are non-decreasing (monotonic)
            for i in 1..data.len() {
                assert!(
                    data[i] >= data[i - 1],
                    "{} should be monotonic: {} < {} at indices {}, {}",
                    name,
                    data[i],
                    data[i - 1],
                    i,
                    i - 1
                );
            }
        }

        Ok(())
    }
}