Skip to main content

entrenar/autograd/ops/
activations.rs

//! Activation function autograd operations: relu, gelu, swish, softmax.
//! The GELU and Swish forward passes delegate to `trueno` scalar kernels.
2
3use crate::autograd::{BackwardOp, Tensor};
4use ndarray::Array1;
5use provable_contracts_macros::contract;
6use std::cell::RefCell;
7use std::rc::Rc;
8
9/// ReLU activation
10pub fn relu(a: &Tensor) -> Tensor {
11    contract_pre_relu!(a.data().as_slice().unwrap_or(&[]));
12    let data = a.data().mapv(|x| x.max(0.0));
13    let requires_grad = a.requires_grad();
14
15    let mut result = Tensor::new(data, requires_grad);
16
17    if requires_grad {
18        let a_clone = a.clone();
19        let backward_op = Rc::new(ReluBackward { a: a_clone, result_grad: result.grad_cell() });
20        result.set_backward_op(backward_op);
21    }
22
23    result
24}
25
26struct ReluBackward {
27    a: Tensor,
28    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
29}
30
31impl BackwardOp for ReluBackward {
32    fn backward(&self) {
33        if let Some(grad) = self.result_grad.borrow().as_ref() {
34            if self.a.requires_grad() {
35                // ∂L/∂a = ∂L/∂out * (a > 0)
36                let grad_a = grad * &self.a.data().mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
37                self.a.accumulate_grad(grad_a);
38            }
39
40            if let Some(op) = self.a.backward_op() {
41                op.backward();
42            }
43        }
44    }
45}
46
47/// GELU activation (Gaussian Error Linear Unit)
48///
49/// GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
50///
51/// ONE PATH: Forward math delegates to `trueno::gelu_scalar` (UCBD §4).
52pub fn gelu(a: &Tensor) -> Tensor {
53    contract_pre_gelu!(a.data().as_slice().unwrap_or(&[]));
54    let data = a.data().mapv(trueno::gelu_scalar);
55
56    let requires_grad = a.requires_grad();
57    let mut result = Tensor::new(data, requires_grad);
58
59    if requires_grad {
60        let a_clone = a.clone();
61        let backward_op = Rc::new(GeluBackward { a: a_clone, result_grad: result.grad_cell() });
62        result.set_backward_op(backward_op);
63    }
64
65    contract_post_gelu!(result.data().as_slice().unwrap_or(&[]));
66    result
67}
68
69struct GeluBackward {
70    a: Tensor,
71    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
72}
73
74impl BackwardOp for GeluBackward {
75    fn backward(&self) {
76        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
77            if self.a.requires_grad() {
78                const SQRT_2_OVER_PI: f32 = 0.797_884_6;
79                const COEFF: f32 = 0.044_715;
80
81                // ∂GELU/∂x = 0.5 * (1 + tanh(z)) + 0.5 * x * sech²(z) * dz/dx
82                // where z = √(2/π) * (x + 0.044715 * x³)
83                // and dz/dx = √(2/π) * (1 + 3 * 0.044715 * x²)
84                let grad_a: Vec<f32> = self
85                    .a
86                    .data()
87                    .iter()
88                    .zip(grad_output.iter())
89                    .map(|(&x, &grad)| {
90                        let x2 = x * x;
91                        let x3 = x2 * x;
92                        let z = SQRT_2_OVER_PI * (x + COEFF * x3);
93                        let tanh_z = z.tanh();
94                        let sech2_z = 1.0 - tanh_z * tanh_z;
95                        let dz_dx = SQRT_2_OVER_PI * (1.0 + 3.0 * COEFF * x2);
96
97                        let gelu_grad = 0.5 * (1.0 + tanh_z) + 0.5 * x * sech2_z * dz_dx;
98                        grad * gelu_grad
99                    })
100                    .collect();
101
102                self.a.accumulate_grad(Array1::from(grad_a));
103            }
104
105            if let Some(op) = self.a.backward_op() {
106                op.backward();
107            }
108        }
109    }
110}
111
112/// Swish activation (also known as SiLU - Sigmoid Linear Unit)
113///
114/// Swish(x) = x * sigmoid(x) = x / (1 + e^(-x))
115///
116/// ONE PATH: Forward math delegates to `trueno::silu_scalar` (UCBD §4).
117pub fn swish(a: &Tensor) -> Tensor {
118    let data = a.data().mapv(trueno::silu_scalar);
119
120    let requires_grad = a.requires_grad();
121    let mut result = Tensor::new(data, requires_grad);
122
123    if requires_grad {
124        let a_clone = a.clone();
125        let output_clone = result.clone();
126        let backward_op = Rc::new(SwishBackward {
127            a: a_clone,
128            output: output_clone,
129            result_grad: result.grad_cell(),
130        });
131        result.set_backward_op(backward_op);
132    }
133
134    result
135}
136
137struct SwishBackward {
138    a: Tensor,
139    output: Tensor,
140    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
141}
142
143impl BackwardOp for SwishBackward {
144    fn backward(&self) {
145        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
146            if self.a.requires_grad() {
147                // ∂Swish/∂x = Swish(x) + sigmoid(x) * (1 - Swish(x))
148                // This can be simplified to: sigmoid(x) * (1 + x * (1 - sigmoid(x)))
149                let grad_a: Vec<f32> = self
150                    .a
151                    .data()
152                    .iter()
153                    .zip(self.output.data().iter())
154                    .zip(grad_output.iter())
155                    .map(|((&x, &swish_x), &grad)| {
156                        let sigmoid = 1.0 / (1.0 + (-x).exp());
157                        let swish_grad = swish_x + sigmoid * (1.0 - swish_x);
158                        grad * swish_grad
159                    })
160                    .collect();
161
162                self.a.accumulate_grad(Array1::from(grad_a));
163            }
164
165            if let Some(op) = self.a.backward_op() {
166                op.backward();
167            }
168        }
169    }
170}
171
172/// Softmax activation
173#[contract("softmax-v1", equation = "softmax")]
174pub fn softmax(a: &Tensor) -> Tensor {
175    contract_pre_softmax!(a.data().as_slice().unwrap_or(&[]));
176    let max_val = a.data().iter().copied().fold(f32::NEG_INFINITY, f32::max);
177    let exp_vals = a.data().mapv(|x| (x - max_val).exp());
178    let sum_exp = exp_vals.sum();
179    let data = exp_vals / sum_exp;
180
181    let requires_grad = a.requires_grad();
182    let mut result = Tensor::new(data, requires_grad);
183
184    if requires_grad {
185        let a_clone = a.clone();
186        let output_clone = result.clone();
187        let backward_op = Rc::new(SoftmaxBackward {
188            a: a_clone,
189            output: output_clone,
190            result_grad: result.grad_cell(),
191        });
192        result.set_backward_op(backward_op);
193    }
194
195    contract_post_softmax!(result.data().as_slice().unwrap_or(&[]));
196    result
197}
198
199struct SoftmaxBackward {
200    a: Tensor,
201    output: Tensor,
202    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
203}
204
205impl BackwardOp for SoftmaxBackward {
206    fn backward(&self) {
207        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
208            if self.a.requires_grad() {
209                // ∂L/∂x = y ⊙ (∂L/∂y - (y · ∂L/∂y))
210                let y = self.output.data();
211                let dot = (y * grad_output).sum();
212                let grad_a = y * &(grad_output - dot);
213                self.a.accumulate_grad(grad_a);
214            }
215
216            if let Some(op) = self.a.backward_op() {
217                op.backward();
218            }
219        }
220    }
221}
222
// =========================================================================
// FALSIFY-SI: silu-kernel-v1.yaml contract (entrenar via trueno::silu_scalar)
//
// Five-Whys (PMAT-354, Phase 11):
//   Why 1: entrenar had zero FALSIFY-SI-* tests despite SiLU in CUDA forward
//   Why 2: CUDA tests verify backward correctness, not mathematical invariants
//   Why 3: no mapping from silu-kernel-v1.yaml to entrenar test names
//   Why 4: entrenar predates the provable-contracts YAML convention
//   Why 5: SiLU CUDA forward delegates to cuBLAS (assumed correct)
//
// Note: entrenar's SiLU has CUDA kernels (silu_forward/silu_backward), and the
// ndarray-based autograd `swish` op in this file delegates its forward pass to
// trueno::silu_scalar. These tests therefore exercise trueno::silu_scalar, the
// canonical reference implementation.
//
// References:
//   - provable-contracts/contracts/silu-kernel-v1.yaml
//   - Ramachandran et al. (2017) "Searching for Activation Functions"
// =========================================================================
240
#[cfg(test)]
mod silu_contract_tests {
    /// Global lower bound asserted by FALSIFY-SI-002.
    ///
    /// The true minimum of SiLU is ≈ -0.278465, reached at x ≈ -1.278465, so
    /// -0.279 holds with ~5e-4 of headroom for f32 rounding. Asserting exactly
    /// this value keeps the check consistent with the contract text and the
    /// failure messages; a looser threshold would let violations slip through.
    const SILU_LOWER_BOUND: f32 = -0.279;

    /// FALSIFY-SI-001: Zero preservation — SiLU(0) = 0
    #[test]
    fn falsify_si_001_zero_preservation() {
        let y = trueno::silu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED SI-001: SiLU(0) = {y}, expected 0");
    }

    /// FALSIFY-SI-002: Global lower bound — SiLU(x) > -0.279 for all x
    #[test]
    fn falsify_si_002_global_lower_bound() {
        let test_values: [f32; 13] =
            [-100.0, -50.0, -10.0, -5.0, -2.0, -1.278, -1.0, -0.5, 0.0, 0.5, 1.0, 5.0, 100.0];
        for &x in &test_values {
            let y = trueno::silu_scalar(x);
            assert!(
                y > SILU_LOWER_BOUND,
                "FALSIFIED SI-002: SiLU({x}) = {y}, expected > {SILU_LOWER_BOUND}"
            );
        }
    }

    /// FALSIFY-SI-003: Monotonic for positive inputs
    #[test]
    fn falsify_si_003_monotonic_positive() {
        let values: [f32; 9] = [0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];
        for pair in values.windows(2) {
            let (prev, curr) = (pair[0], pair[1]);
            let y_prev = trueno::silu_scalar(prev);
            let y_curr = trueno::silu_scalar(curr);
            assert!(
                y_curr > y_prev,
                "FALSIFIED SI-003: SiLU({curr}) = {y_curr} not > SiLU({prev}) = {y_prev}"
            );
        }
    }

    /// FALSIFY-SI-005: Asymptotic linearity — |SiLU(x) - x| < 0.01 for x > 10
    #[test]
    fn falsify_si_005_asymptotic_linearity() {
        for &x in &[10.0f32, 20.0, 50.0, 100.0, 500.0] {
            let y = trueno::silu_scalar(x);
            assert!(
                (y - x).abs() < 0.01,
                "FALSIFIED SI-005: |SiLU({x}) - {x}| = {} >= 0.01",
                (y - x).abs()
            );
        }
    }

    /// FALSIFY-SI-006: Large negative → 0
    #[test]
    fn falsify_si_006_large_negative_vanishes() {
        for &x in &[-10.0f32, -20.0, -50.0, -100.0] {
            let y = trueno::silu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED SI-006: SiLU({x}) = {y}, expected ≈ 0");
        }
    }

    mod si_proptest_falsify {
        use super::SILU_LOWER_BOUND;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_si_002_prop_lower_bound(x in -1000.0_f32..1000.0) {
                let y = trueno::silu_scalar(x);
                prop_assert!(
                    y > SILU_LOWER_BOUND,
                    "FALSIFIED SI-002-prop: SiLU({x}) = {y}, expected > {SILU_LOWER_BOUND}"
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_si_003_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    prop_assert!(
                        trueno::silu_scalar(hi) > trueno::silu_scalar(lo),
                        "FALSIFIED SI-003-prop: SiLU({hi}) not > SiLU({lo})"
                    );
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_si_005_prop_asymptotic(x in 10.0_f32..500.0) {
                let y = trueno::silu_scalar(x);
                prop_assert!(
                    (y - x).abs() < 0.01,
                    "FALSIFIED SI-005-prop: |SiLU({x}) - {x}| = {}",
                    (y - x).abs()
                );
            }
        }
    }
}
342
// =========================================================================
// FALSIFY-SG: swiglu-kernel-v1.yaml contract (entrenar SwiGLU via swish)
//
// Five-Whys (PMAT-354, Phase 11):
//   Why 1: entrenar had FFN dimension tests but zero FALSIFY-SG-* activation tests
//   Why 2: FFN tests verify shapes, not mathematical SwiGLU invariants
//   Why 3: no mapping from swiglu-kernel-v1.yaml to entrenar test names
//   Why 4: SwiGLU in entrenar is decomposed (swish + mul), not a standalone fn
//   Why 5: SwiGLU was "obviously correct" (x * SiLU(gate))
//
// Note: entrenar decomposes SwiGLU into crate::autograd::swish + mul.
// These tests exercise the mathematical properties via trueno::silu_scalar.
//
// References:
//   - provable-contracts/contracts/swiglu-kernel-v1.yaml
//   - Shazeer (2020) "GLU Variants Improve Transformer"
// =========================================================================
360
#[cfg(test)]
mod swiglu_contract_tests {
    /// Global lower bound of SiLU (true minimum ≈ -0.278465 at x ≈ -1.278465).
    const SILU_LOWER_BOUND: f32 = -0.279;

    /// Decomposed SwiGLU as used by entrenar: `x * SiLU(gate)`.
    fn swiglu_scalar(x: f32, gate: f32) -> f32 {
        x * trueno::silu_scalar(gate)
    }

    /// Independent f64 evaluation of SwiGLU from its definition,
    /// `x * gate * sigmoid(gate) = x * gate / (1 + e^(-gate))`.
    /// It shares no code with `trueno`, so SG-002 is not a self-comparison.
    fn swiglu_reference(x: f32, gate: f32) -> f64 {
        let (x, gate) = (f64::from(x), f64::from(gate));
        x * gate / (1.0 + (-gate).exp())
    }

    /// FALSIFY-SG-001: Zero preservation — SwiGLU(0, gate) = 0 for any gate
    #[test]
    fn falsify_sg_001_zero_x_preservation() {
        for &g in &[-10.0f32, -1.0, 0.0, 1.0, 10.0] {
            let y = swiglu_scalar(0.0, g);
            assert!(y.abs() < 1e-7, "FALSIFIED SG-001: SwiGLU(0, {g}) = {y}");
        }
    }

    /// FALSIFY-SG-002: Fused equivalence — x * SiLU(gate) matches an independent
    /// f64 evaluation of x * gate * sigmoid(gate) within f32 rounding.
    #[test]
    fn falsify_sg_002_fused_equivalence() {
        let cases: [(f32, f32); 5] =
            [(1.0, 1.0), (-2.0, 3.0), (5.0, -1.0), (0.5, 0.5), (100.0, 0.0)];
        for &(x, g) in &cases {
            let fused = f64::from(swiglu_scalar(x, g));
            let reference = swiglu_reference(x, g);
            // Relative tolerance with an absolute floor near zero.
            let tol = 1e-5 * reference.abs().max(1.0);
            assert!(
                (fused - reference).abs() < tol,
                "FALSIFIED SG-002: swiglu({x},{g})={fused} != reference={reference}"
            );
        }
    }

    /// FALSIFY-SG-003: SiLU lower bound preserved in gate — SiLU(z) > -0.279
    #[test]
    fn falsify_sg_003_silu_lower_bound() {
        for &g in &[-1000.0f32, -1.278, -1.0, 0.0, 1.0, 1000.0] {
            let silu_g = trueno::silu_scalar(g);
            assert!(
                silu_g > SILU_LOWER_BOUND,
                "FALSIFIED SG-003: SiLU({g}) = {silu_g}, expected > {SILU_LOWER_BOUND}"
            );
        }
    }

    /// FALSIFY-SG-004: Finite output for all finite inputs
    #[test]
    fn falsify_sg_004_finite_output() {
        let vals: [f32; 7] = [-100.0, -10.0, -1.0, 0.0, 1.0, 10.0, 100.0];
        for &x in &vals {
            for &g in &vals {
                let y = swiglu_scalar(x, g);
                assert!(y.is_finite(), "FALSIFIED SG-004: SwiGLU({x},{g}) = {y}");
            }
        }
    }

    /// FALSIFY-SG-005: Empty input produces empty output
    #[test]
    fn falsify_sg_005_empty_input() {
        let empty: Vec<f32> = vec![];
        let result: Vec<f32> =
            empty.iter().zip(empty.iter()).map(|(&x, &g)| swiglu_scalar(x, g)).collect();
        assert!(result.is_empty(), "FALSIFIED SG-005: empty SwiGLU produced non-empty output");
    }

    mod sg_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_sg_001_prop_zero_x(gate in -100.0_f32..100.0) {
                let y = swiglu_scalar(0.0, gate);
                prop_assert!(y.abs() < 1e-6, "FALSIFIED SG-001-prop: SwiGLU(0, {gate}) = {y}");
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_sg_004_prop_finite(
                x in -100.0_f32..100.0,
                gate in -100.0_f32..100.0,
            ) {
                let y = swiglu_scalar(x, gate);
                prop_assert!(y.is_finite(), "FALSIFIED SG-004-prop: SwiGLU({x},{gate}) = {y}");
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_sg_006_prop_monotonic_gate(
                x in 1.0_f32..50.0,
                a in 0.1_f32..50.0,
                b in 0.1_f32..50.0,
            ) {
                // For positive x and positive gates, increasing gate should increase output
                // because SiLU is monotonically increasing for positive inputs
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let y_lo = swiglu_scalar(x, lo);
                    let y_hi = swiglu_scalar(x, hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED SG-006-prop: SwiGLU({x},{hi})={y_hi} not > SwiGLU({x},{lo})={y_lo}"
                    );
                }
            }
        }
    }
}
470
// =========================================================================
// FALSIFY-GE: gelu-kernel-v1.yaml contract (entrenar autograd gelu)
// =========================================================================
#[cfg(test)]
mod gelu_contract_tests {
    use super::*;
    use ndarray::Array1;

    /// Error function via Abramowitz & Stegun formula 7.1.26
    /// (absolute error ≤ 1.5e-7 for all real x).
    ///
    /// Rust's standard library has no `erf`. GE-005 needs the *exact* GELU as
    /// its reference, so this high-accuracy rational approximation supplies it.
    fn erf_reference(x: f64) -> f64 {
        const P: f64 = 0.327_591_1;
        // Coefficients a5..a1, highest order first for Horner evaluation.
        const A: [f64; 5] =
            [1.061_405_429, -1.453_152_027, 1.421_413_741, -0.284_496_736, 0.254_829_592];
        let ax = x.abs();
        let t = 1.0 / (1.0 + P * ax);
        // t * (a1 + a2 t + a3 t² + a4 t³ + a5 t⁴)
        let poly = t * A.iter().fold(0.0, |acc, &a| acc * t + a);
        // erf is odd: compute for |x| and restore the sign.
        x.signum() * (1.0 - poly * (-ax * ax).exp())
    }

    /// Exact GELU: x · Φ(x) = 0.5 · x · (1 + erf(x / √2)), evaluated in f64.
    fn gelu_exact(x: f64) -> f64 {
        0.5 * x * (1.0 + erf_reference(x * std::f64::consts::FRAC_1_SQRT_2))
    }

    /// Guards the reference itself against tabulated erf values.
    #[test]
    fn erf_reference_matches_known_values() {
        let table: [(f64, f64); 4] =
            [(0.0, 0.0), (0.5, 0.520_499_877_8), (1.0, 0.842_700_792_9), (2.0, 0.995_322_265_0)];
        for &(x, expected) in &table {
            assert!((erf_reference(x) - expected).abs() < 2e-7, "erf({x})");
            assert!((erf_reference(-x) + expected).abs() < 2e-7, "erf(-{x})");
        }
    }

    /// FALSIFY-GE-001: Non-negativity — gelu(x) >= 0 for positive x
    #[test]
    fn falsify_ge_001_non_negativity() {
        let x = Tensor::new(Array1::from(vec![0.001, 0.1, 1.0, 5.0, 10.0, 100.0]), false);
        let y = gelu(&x);
        for (i, &val) in y.data().iter().enumerate() {
            assert!(val >= 0.0, "FALSIFIED GE-001: gelu(positive)[{i}] = {val} < 0");
        }
    }

    /// FALSIFY-GE-002: Monotonicity — ordering preserved for positive inputs
    #[test]
    fn falsify_ge_002_positive_monotonicity() {
        let x = Tensor::new(Array1::from(vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0]), false);
        let y = gelu(&x);
        let data = y.data();
        for i in 1..data.len() {
            assert!(
                data[i] > data[i - 1],
                "FALSIFIED GE-002: gelu not monotonic: [{i}]={} not > [{}]={}",
                data[i],
                i - 1,
                data[i - 1]
            );
        }
    }

    /// FALSIFY-GE-003: Zero preservation — gelu(0) = 0
    #[test]
    fn falsify_ge_003_zero_preservation() {
        let x = Tensor::new(Array1::from(vec![0.0]), false);
        let y = gelu(&x);
        assert!(y.data()[0].abs() < 1e-7, "FALSIFIED GE-003: gelu(0) = {}", y.data()[0]);
    }

    /// FALSIFY-GE-005: Tanh approximation accuracy — |exact - approx| < 0.005
    ///
    /// Compares `trueno::gelu_scalar` against the exact erf-based GELU on
    /// [-10, 10] in steps of 0.1.
    #[test]
    fn falsify_ge_005_tanh_approx_accuracy() {
        for x_int in -100..=100 {
            let x = x_int as f32 * 0.1;
            let approx = f64::from(trueno::gelu_scalar(x));
            let exact = gelu_exact(f64::from(x));
            let err = (approx - exact).abs();
            assert!(
                err < 0.005,
                "FALSIFIED GE-005: |gelu_approx({x}) - gelu_exact({x})| = {err}"
            );
        }
    }

    /// FALSIFY-GE-006: Large input stability
    #[test]
    fn falsify_ge_006_large_input_stability() {
        let x = Tensor::new(Array1::from(vec![10.0, 50.0, -10.0, -50.0]), false);
        let y = gelu(&x);
        let d = y.data();
        assert!((d[0] - 10.0).abs() < 0.01, "FALSIFIED GE-006: gelu(10) = {}", d[0]);
        assert!((d[1] - 50.0).abs() < 0.01, "FALSIFIED GE-006: gelu(50) = {}", d[1]);
        assert!(d[2].abs() < 0.01, "FALSIFIED GE-006: gelu(-10) = {}", d[2]);
        assert!(d[3].abs() < 0.01, "FALSIFIED GE-006: gelu(-50) = {}", d[3]);
    }

    mod ge_proptest_falsify {
        use super::*;
        use ndarray::Array1;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_ge_001_prop_non_negativity(x in 0.0_f32..1000.0) {
                let t = Tensor::new(Array1::from(vec![x]), false);
                let y = gelu(&t);
                prop_assert!(y.data()[0] >= 0.0, "FALSIFIED GE-001-prop: gelu({x}) = {} < 0", y.data()[0]);
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_ge_002_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let t = Tensor::new(Array1::from(vec![lo, hi]), false);
                    let y = gelu(&t);
                    let d = y.data();
                    prop_assert!(d[1] > d[0], "FALSIFIED GE-002-prop: gelu({hi})={} not > gelu({lo})={}", d[1], d[0]);
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ge_006_prop_large_positive(x in 10.0_f32..500.0) {
                let t = Tensor::new(Array1::from(vec![x]), false);
                let y = gelu(&t);
                prop_assert!(
                    (y.data()[0] - x).abs() < 0.01,
                    "FALSIFIED GE-006-prop: |gelu({x}) - {x}| = {}",
                    (y.data()[0] - x).abs()
                );
            }
        }
    }
}