// trueno/activations.rs
//! Canonical scalar activation functions.
//!
//! # One Path Rule (UCBD §4)
//!
//! These are THE canonical implementations for scalar activation functions.
//! All downstream crates (aprender, realizar, entrenar, whisper-apr) MUST
//! import from here instead of re-implementing.
//!
//! For SIMD-vectorized slice operations, see `backends::*/ops/activations`.
//! For `Vector`-level operations, see `vector::ops::activations`.

/// SiLU (Sigmoid Linear Unit) / Swish activation: x * σ(x).
///
/// # Equation
/// ```text
/// SiLU(x) = x * σ(x) = x / (1 + exp(-x))
/// ```
///
/// # Contract
/// - Domain: x ∈ ℝ
/// - Codomain: SiLU(x) ∈ (-0.278..., ∞)
/// - SiLU(0) = 0
/// - limₓ→∞ SiLU(x) = x (the sigmoid factor tends to 1)
/// - limₓ→-∞ SiLU(x) = 0 (the sigmoid factor vanishes faster than x grows)
#[inline]
#[must_use]
pub fn silu_scalar(x: f32) -> f32 {
    // Single f32 division keeps the result bit-identical to x * σ(x)
    // computed as x / (1 + e^{-x}).
    let denom = 1.0 + (-x).exp();
    x / denom
}

/// GELU (Gaussian Error Linear Unit) activation.
///
/// Uses the fast tanh approximation (same as PyTorch `gelu('tanh')`).
///
/// # Equation
/// ```text
/// GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
/// ```
///
/// # Contract
/// - Domain: x ∈ ℝ
/// - Codomain: GELU(x) ∈ (-0.170..., ∞)
/// - GELU(0) = 0
/// - limₓ→∞ GELU(x) = x
/// - limₓ→-∞ GELU(x) = 0
#[inline]
#[must_use]
pub fn gelu_scalar(x: f32) -> f32 {
    // √(2/π), evaluated in f32 so results match the reference kernel exactly.
    let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    // Argument of the tanh: √(2/π) · (x + 0.044715·x³).
    let inner = sqrt_2_over_pi * (x + 0.044_715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}

/// Sigmoid activation: σ(x) = 1 / (1 + exp(-x)).
///
/// # Equation
/// ```text
/// σ(x) = 1 / (1 + exp(-x))
/// ```
///
/// # Contract
/// - Domain: x ∈ ℝ
/// - Codomain: σ(x) ∈ (0, 1)
/// - σ(0) = 0.5
/// - σ(-x) = 1 - σ(x) (symmetry)
#[inline]
#[must_use]
pub fn sigmoid_scalar(x: f32) -> f32 {
    // exp(-x) saturates to 0 (large x) or +inf (large -x), so the
    // division never produces NaN for finite inputs.
    let denom = 1.0 + (-x).exp();
    1.0 / denom
}

/// ReLU (Rectified Linear Unit) activation.
///
/// # Equation
/// ```text
/// ReLU(x) = max(0, x)
/// ```
///
/// # Contract
/// - Domain: x ∈ ℝ
/// - Codomain: ReLU(x) ∈ [0, ∞)
/// - ReLU(x) = 0 for x ≤ 0
/// - ReLU(x) = x for x > 0
#[inline]
#[must_use]
pub fn relu_scalar(x: f32) -> f32 {
    // f32::max ignores a NaN operand, so relu(NaN) = 0.0 — same as the
    // method-call form `x.max(0.0)`.
    f32::max(x, 0.0)
}

/// Tanh activation.
///
/// # Equation
/// ```text
/// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
/// ```
///
/// # Contract
/// - Domain: x ∈ ℝ
/// - Codomain: tanh(x) ∈ (-1, 1)
/// - tanh(0) = 0
/// - tanh(-x) = -tanh(x) (odd function)
#[inline]
#[must_use]
pub fn tanh_scalar(x: f32) -> f32 {
    // Thin wrapper so downstream crates have one canonical import point
    // (One Path Rule) rather than calling the std method ad hoc.
    f32::tanh(x)
}

/// f16 → f32 conversion (IEEE 754 half-precision).
///
/// Manual bit-manipulation implementation (no `half` crate dependency).
/// This is a standalone implementation — it does not delegate to any other
/// module.
///
/// # Contract
/// - Domain: any u16 (interpreted as IEEE 754 binary16)
/// - Codomain: f32 (exact representation; every f16 value is representable
///   in f32, so no precision is lost)
/// - Subnormals, ±0, ±inf, NaN handled correctly
#[inline]
#[must_use]
pub fn f16_to_f32(bits: u16) -> f32 {
    let sign = (bits >> 15) & 0x1;
    let exponent = (bits >> 10) & 0x1F;
    let mantissa = bits & 0x3FF;

    // Fast path: normal numbers (exponent field in 1..=30).
    if exponent != 0 && exponent != 31 {
        // Rebias exponent: f32 bias 127 − f16 bias 15 = 112.
        let f32_exp = exponent as u32 + 112;
        // Widen mantissa: 10 explicit bits → 23 explicit bits.
        let f32_mant = (mantissa as u32) << 13;
        let f32_bits = ((sign as u32) << 31) | (f32_exp << 23) | f32_mant;
        return f32::from_bits(f32_bits);
    }

    // exponent == 0: signed zero or subnormal.
    if exponent == 0 {
        if mantissa == 0 {
            return if sign == 1 { -0.0 } else { 0.0 };
        }
        // Subnormal: value = ±(mantissa / 2^10) · 2^-14. Both factors and
        // the product are exact in f32, so this is lossless.
        const TWO_POW_NEG_14: f32 = 6.103_515_625e-5; // 2^-14
        let m = mantissa as f32 * (1.0 / 1024.0);
        let result = m * TWO_POW_NEG_14;
        return if sign == 1 { -result } else { result };
    }

    // exponent == 31: Inf (mantissa == 0) or NaN (mantissa != 0).
    if mantissa == 0 {
        if sign == 1 {
            f32::NEG_INFINITY
        } else {
            f32::INFINITY
        }
    } else {
        f32::NAN
    }
}

/// f32 → f16 conversion (IEEE 754 half-precision).
///
/// Manual bit-manipulation implementation. Rounds to nearest even.
///
/// # Contract
/// - Domain: f32
/// - Codomain: u16 (IEEE 754 binary16 bits)
/// - Rounds to nearest, ties to even (normal AND subnormal results)
/// - Overflow saturates to ±Inf; NaN keeps its high payload bits
#[inline]
#[must_use]
pub fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // Special cases: Inf / NaN.
    if exponent == 255 {
        if mantissa == 0 {
            return sign | 0x7C00; // ±Inf
        }
        // NaN: keep the high payload bits, but never collapse to the
        // all-zero-mantissa Inf pattern.
        return sign | 0x7C00 | ((mantissa >> 13) as u16).max(1);
    }

    // Rebias exponent: f32 bias = 127, f16 bias = 15.
    let new_exp = exponent - 112;

    if new_exp >= 31 {
        return sign | 0x7C00; // overflow → ±Inf
    }

    // Build one word containing exponent-above-mantissa so a rounding carry
    // out of the mantissa propagates into the exponent naturally (fixes the
    // former bug where `(mant + 1) & 0x3FF` silently dropped the carry and
    // e.g. the f32 just below 2.0 converted to 1.0). If the carry pushes the
    // exponent to 31 the result is ±Inf, as IEEE 754 requires.
    let (value, shift) = if new_exp <= 0 {
        if new_exp < -10 {
            return sign; // below half the smallest subnormal → ±0
        }
        // Subnormal result: make the hidden bit explicit; exponent field is 0.
        (mantissa | 0x0080_0000, (14 - new_exp) as u32)
    } else {
        (((new_exp as u32) << 23) | mantissa, 13)
    };

    // Round to nearest, ties to even: round up when the guard bit is set
    // AND (any sticky bit is set OR the kept lsb is set).
    let base = value >> shift;
    let guard = (value >> (shift - 1)) & 1;
    let sticky = value & ((1u32 << (shift - 1)) - 1);
    let rounded = base + (guard & (((sticky != 0) as u32) | (base & 1)));
    sign | rounded as u16
}

#[cfg(test)]
mod tests {
    use super::*;

    // Point-value sanity checks: each activation is probed at zero and at
    // its asymptotic extremes. Tolerances follow the contracts documented
    // on the functions above.

    #[test]
    fn test_silu_zero() {
        assert!((silu_scalar(0.0)).abs() < 1e-7);
    }

    #[test]
    fn test_silu_positive() {
        // SiLU(x) → x for large positive x
        let x = 10.0;
        assert!((silu_scalar(x) - x).abs() < 0.01);
    }

    #[test]
    fn test_silu_negative() {
        // SiLU(x) → 0 for large negative x
        assert!(silu_scalar(-10.0).abs() < 0.01);
    }

    #[test]
    fn test_gelu_zero() {
        assert!((gelu_scalar(0.0)).abs() < 1e-7);
    }

    #[test]
    fn test_gelu_positive() {
        // GELU(x) → x for large positive x
        let x = 10.0;
        assert!((gelu_scalar(x) - x).abs() < 0.01);
    }

    #[test]
    fn test_sigmoid_zero() {
        assert!((sigmoid_scalar(0.0) - 0.5).abs() < 1e-7);
    }

    #[test]
    fn test_sigmoid_symmetry() {
        // σ(x) + σ(-x) = 1
        let x = 2.5;
        assert!((sigmoid_scalar(x) + sigmoid_scalar(-x) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_relu_positive() {
        assert!((relu_scalar(3.0) - 3.0).abs() < 1e-7);
    }

    #[test]
    fn test_relu_negative() {
        assert!((relu_scalar(-3.0)).abs() < 1e-7);
    }

    #[test]
    fn test_tanh_zero() {
        assert!((tanh_scalar(0.0)).abs() < 1e-7);
    }

    #[test]
    fn test_tanh_odd() {
        // tanh(x) + tanh(-x) = 0
        let x = 1.5;
        assert!((tanh_scalar(x) + tanh_scalar(-x)).abs() < 1e-6);
    }

    #[test]
    fn test_f16_roundtrip() {
        // 1.5 is exactly representable in f16, so the roundtrip is lossless;
        // the tolerance guards against accidental bit-level regressions.
        let val = 1.5_f32;
        let bits = f32_to_f16(val);
        let back = f16_to_f32(bits);
        assert!((val - back).abs() < 1e-3);
    }

    #[test]
    fn test_f16_zero() {
        assert_eq!(f16_to_f32(0), 0.0);
    }

    // =========================================================================
    // FALSIFY-GE: gelu-kernel-v1.yaml contract (trueno gelu_scalar)
    //
    // Five-Whys (PMAT-354):
    //   Why 1: trueno had basic gelu tests but zero FALSIFY-GE-* tests
    //   Why 2: tests checked 2 values (zero, large), not mathematical invariants
    //   Why 3: no mapping from gelu-kernel-v1.yaml to trueno test names
    //   Why 4: trueno predates the provable-contracts YAML convention
    //   Why 5: GELU was "obviously correct" (tanh approximation is textbook)
    //
    // References:
    //   - provable-contracts/contracts/gelu-kernel-v1.yaml
    //   - Hendrycks & Gimpel (2016) "Gaussian Error Linear Units (GELUs)"
    // =========================================================================

    /// FALSIFY-GE-001: Non-negativity — GELU(x) >= 0 for all x > 0
    #[test]
    fn falsify_ge_001_non_negativity() {
        let test_values = [0.001, 0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0, 1e6];
        for &x in &test_values {
            let y = gelu_scalar(x);
            assert!(y >= 0.0, "FALSIFIED GE-001: GELU({x}) = {y} < 0 for positive input");
        }
    }

    /// FALSIFY-GE-002: Monotonicity — GELU(x) > GELU(y) when x > y > 0
    #[test]
    fn falsify_ge_002_positive_monotonicity() {
        let values: Vec<f32> = vec![0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0];
        for window in values.windows(2) {
            let (y_lo, y_hi) = (gelu_scalar(window[0]), gelu_scalar(window[1]));
            assert!(
                y_hi > y_lo,
                "FALSIFIED GE-002: GELU({}) = {} not > GELU({}) = {}",
                window[1],
                y_hi,
                window[0],
                y_lo
            );
        }
    }

    /// FALSIFY-GE-003: Zero preservation — GELU(0) = 0
    #[test]
    fn falsify_ge_003_zero_preservation() {
        let y = gelu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED GE-003: GELU(0) = {y}, expected 0");
    }

    /// FALSIFY-GE-005: Tanh approximation vs exact CDF — |diff| < 0.005
    ///
    /// Exact GELU: x * Phi(x) where Phi is the standard normal CDF.
    /// We approximate Phi via Abramowitz & Stegun erf formula (max error 1.5e-7).
    #[test]
    fn falsify_ge_005_tanh_approx_accuracy() {
        // Abramowitz & Stegun erf approximation (7.1.26), max |error| < 1.5e-7
        fn erf_approx(x: f32) -> f32 {
            let sign = x.signum();
            let x = x.abs();
            let t = 1.0 / (1.0 + 0.327_591_1 * x);
            let t2 = t * t;
            let t3 = t2 * t;
            let t4 = t3 * t;
            let t5 = t4 * t;
            let poly = 0.254_829_592 * t - 0.284_496_736 * t2 + 1.421_413_741 * t3
                - 1.453_152_027 * t4
                + 1.061_405_429 * t5;
            sign * (1.0 - poly * (-x * x).exp())
        }

        // Reference GELU via the exact normal CDF, Phi(x) = (1 + erf(x/√2)) / 2.
        fn gelu_exact(x: f32) -> f32 {
            let phi = 0.5 * (1.0 + erf_approx(x / std::f32::consts::SQRT_2));
            x * phi
        }

        // Sweep [-10, 10] in steps of 0.1.
        let test_values: Vec<f32> = (-100..=100).map(|i| i as f32 * 0.1).collect();
        for &x in &test_values {
            let approx = gelu_scalar(x);
            let exact = gelu_exact(x);
            let diff = (approx - exact).abs();
            assert!(
                diff < 0.005,
                "FALSIFIED GE-005: |GELU_approx({x}) - GELU_exact({x})| = {diff} >= 0.005"
            );
        }
    }

    /// FALSIFY-GE-006: Large input stability — GELU(x) ≈ x for large x, ≈ 0 for large -x
    #[test]
    fn falsify_ge_006_large_input_stability() {
        for &x in &[10.0_f32, 50.0, 100.0, 1000.0] {
            let y = gelu_scalar(x);
            assert!((y - x).abs() < 0.01, "FALSIFIED GE-006: GELU({x}) = {y}, expected ≈ {x}");
        }
        for &x in &[-10.0_f32, -50.0, -100.0, -1000.0] {
            let y = gelu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED GE-006: GELU({x}) = {y}, expected ≈ 0");
        }
    }

    // Property-based falsification of the same GE invariants over random inputs.
    mod ge_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // GE-001-prop: non-negativity for positive x
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_ge_001_prop_non_negativity(x in 0.0_f32..1000.0) {
                let y = gelu_scalar(x);
                prop_assert!(y >= 0.0, "FALSIFIED GE-001-prop: gelu({x}) = {y} < 0");
            }
        }

        // GE-002-prop: monotonicity for positive pairs
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_ge_002_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let y_lo = gelu_scalar(lo);
                    let y_hi = gelu_scalar(hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED GE-002-prop: gelu({hi})={y_hi} not > gelu({lo})={y_lo}"
                    );
                }
            }
        }

        // GE-006-prop: large input stability
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ge_006_prop_large_positive(x in 10.0_f32..500.0) {
                let y = gelu_scalar(x);
                prop_assert!(
                    (y - x).abs() < 0.01,
                    "FALSIFIED GE-006-prop: |gelu({x}) - {x}| = {}",
                    (y - x).abs()
                );
            }
        }
    }
}

// =========================================================================
// FALSIFY-SI: silu-kernel-v1.yaml contract (trueno silu_scalar)
//
// Five-Whys (PMAT-354, Phase 11):
//   Why 1: trueno had basic silu unit tests but zero FALSIFY-SI-* tests
//   Why 2: unit tests verify point values, not mathematical invariants
//   Why 3: no mapping from silu-kernel-v1.yaml to trueno test names
//   Why 4: trueno predates the provable-contracts YAML convention
//   Why 5: SiLU was "obviously correct" (x * sigmoid(x))
//
// References:
//   - provable-contracts/contracts/silu-kernel-v1.yaml
//   - Ramachandran et al. (2017) "Searching for Activation Functions"
// =========================================================================

#[cfg(test)]
mod silu_contract_tests {
    use super::*;

    /// FALSIFY-SI-001: Zero preservation — SiLU(0) = 0
    #[test]
    fn falsify_si_001_zero_preservation() {
        let y = silu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED SI-001: SiLU(0) = {y}, expected 0");
    }

    /// FALSIFY-SI-002: Global lower bound — SiLU(x) > -0.279 for all x
    #[test]
    fn falsify_si_002_global_lower_bound() {
        // -1.278... is the argmin of SiLU, so the minimum itself is sampled.
        let test_values: Vec<f32> =
            vec![-100.0, -50.0, -10.0, -5.0, -2.0, -1.278, -1.0, -0.5, 0.0, 0.5, 1.0, 5.0, 100.0];
        for &x in &test_values {
            let y = silu_scalar(x);
            assert!(y > -0.28, "FALSIFIED SI-002: SiLU({x}) = {y}, expected > -0.279");
        }
    }

    /// FALSIFY-SI-003: Monotonic for positive inputs — x > y > 0 ⟹ SiLU(x) > SiLU(y)
    #[test]
    fn falsify_si_003_monotonic_positive() {
        let values: Vec<f32> = vec![0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];
        for i in 1..values.len() {
            let y_prev = silu_scalar(values[i - 1]);
            let y_curr = silu_scalar(values[i]);
            assert!(
                y_curr > y_prev,
                "FALSIFIED SI-003: SiLU({}) = {y_curr} not > SiLU({}) = {y_prev}",
                values[i],
                values[i - 1]
            );
        }
    }

    /// FALSIFY-SI-005: Asymptotic linearity — |SiLU(x) - x| < 0.01 for x > 10
    #[test]
    fn falsify_si_005_asymptotic_linearity() {
        for &x in &[10.0f32, 20.0, 50.0, 100.0, 500.0] {
            let y = silu_scalar(x);
            assert!(
                (y - x).abs() < 0.01,
                "FALSIFIED SI-005: |SiLU({x}) - {x}| = {} >= 0.01",
                (y - x).abs()
            );
        }
    }

    /// FALSIFY-SI-006: Large negative → 0 — |SiLU(x)| < 0.01 for x < -10
    #[test]
    fn falsify_si_006_large_negative_vanishes() {
        for &x in &[-10.0f32, -20.0, -50.0, -100.0, -500.0] {
            let y = silu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED SI-006: SiLU({x}) = {y}, expected ≈ 0");
        }
    }

    // Property-based falsification of the same SI invariants over random inputs.
    mod si_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // SI-002-prop: global lower bound
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_si_002_prop_lower_bound(x in -1000.0_f32..1000.0) {
                let y = silu_scalar(x);
                prop_assert!(
                    y > -0.28,
                    "FALSIFIED SI-002-prop: SiLU({x}) = {y} <= -0.279"
                );
            }
        }

        // SI-003-prop: monotonic for positive pairs
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_si_003_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let y_lo = silu_scalar(lo);
                    let y_hi = silu_scalar(hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED SI-003-prop: SiLU({hi})={y_hi} not > SiLU({lo})={y_lo}"
                    );
                }
            }
        }

        // SI-005-prop: asymptotic linearity for large positive x
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_si_005_prop_asymptotic(x in 10.0_f32..500.0) {
                let y = silu_scalar(x);
                prop_assert!(
                    (y - x).abs() < 0.01,
                    "FALSIFIED SI-005-prop: |SiLU({x}) - {x}| = {}",
                    (y - x).abs()
                );
            }
        }
    }
}