Skip to main content

provable_contracts/kernels/
activation.rs

1//! Activation kernels: `ReLU`, `GELU`, `SiLU`.
2//!
3//! Matches `activation-kernel-v1.yaml`.
4//!
5//! Each function provides one of three backends:
6//! - `fn {name}_scalar(...)` — Pure Rust scalar reference (ground truth)
7//! - `unsafe fn {name}_avx2(...)` — AVX2 SIMD implementation
8//! - `fn {name}_ptx() -> &'static str` — PTX assembly source string
9
10use std::f32::consts::PI;
11
12// ────────────────────────────────────────────────────────────────────────────
13// Scalar implementations
14// ────────────────────────────────────────────────────────────────────────────
15
/// `ReLU`: max(0, x)
///
/// Scalar reference: each element of `input` is clamped from below at zero
/// and stored in the matching slot of `output`. Because `f32::max` returns
/// the non-NaN operand, a NaN input produces `0.0`.
///
/// # Panics
/// Panics if `input.len() != output.len()`.
pub fn relu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    output
        .iter_mut()
        .zip(input)
        .for_each(|(dst, &src)| *dst = src.max(0.0));
}
26
/// `GELU`: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
///
/// Scalar reference using the tanh-based approximation shown above,
/// evaluated element-wise from `input` into `output`.
///
/// # Panics
/// Panics if `input.len() != output.len()`.
pub fn gelu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    // Loop-invariant scale factor sqrt(2/pi), computed once per call.
    let k = (2.0f32 / PI).sqrt();
    for (dst, &x) in output.iter_mut().zip(input) {
        // Same left-to-right evaluation order as the formula, so results are bit-exact.
        let cubic = 0.044_715 * x * x * x;
        let t = (k * (x + cubic)).tanh();
        *dst = 0.5 * x * (1.0 + t);
    }
}
39
/// `SiLU` (Swish): x / (1 + exp(-x))
///
/// Scalar reference, equivalent to `x * sigmoid(x)`, evaluated element-wise
/// from `input` into `output`. For `x` below roughly -88.72, `exp(-x)`
/// overflows `f32` to `+inf`, so the result is `-0.0` rather than the tiny
/// negative true value.
///
/// # Panics
/// Panics if `input.len() != output.len()`.
pub fn silu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    let pairs = input.iter().copied().zip(output.iter_mut());
    for (x, dst) in pairs {
        let denom = 1.0 + (-x).exp();
        *dst = x / denom;
    }
}
50
51// ────────────────────────────────────────────────────────────────────────────
52// AVX2 implementations
53// ────────────────────────────────────────────────────────────────────────────
54
55#[cfg(target_arch = "x86_64")]
56use std::arch::x86_64::{_mm256_loadu_ps, _mm256_max_ps, _mm256_setzero_ps, _mm256_storeu_ps};
57
/// AVX2 `ReLU`: `_mm256_max_ps(x, zero)` with scalar tail.
///
/// Processes `input` in chunks of eight `f32`s with unaligned 256-bit
/// loads and stores. The `len % 8` leftover elements are handled with
/// `f32::max(x, 0.0)`, the same rule used by [`relu_scalar`].
///
/// # Safety
/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
///
/// # Panics
/// Panics if `input.len() != output.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    // Number of `f32` lanes in one 256-bit AVX register.
    const LANES: usize = 8;

    assert_eq!(input.len(), output.len());

    let mut src_chunks = input.chunks_exact(LANES);
    let mut dst_chunks = output.chunks_exact_mut(LANES);

    // SAFETY: the caller guarantees AVX2 is available (see `# Safety`).
    // `chunks_exact` / `chunks_exact_mut` only yield slices of exactly
    // `LANES` elements, so every unaligned 8-wide load reads entirely
    // within `src` and every store writes entirely within `dst`.
    unsafe {
        let zero = _mm256_setzero_ps();
        for (src, dst) in src_chunks.by_ref().zip(dst_chunks.by_ref()) {
            let v = _mm256_loadu_ps(src.as_ptr());
            _mm256_storeu_ps(dst.as_mut_ptr(), _mm256_max_ps(v, zero));
        }
    }

    // Scalar tail: fewer than `LANES` elements remain (equal counts on both
    // sides because the lengths were asserted equal above).
    for (x, y) in src_chunks
        .remainder()
        .iter()
        .zip(dst_chunks.into_remainder())
    {
        *y = x.max(0.0);
    }
}
86
87/// AVX2 `GELU` — delegates to scalar (no hardware `tanh` in AVX2).
88///
89/// # Safety
90/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
91///
92/// # Panics
93/// Panics if `input.len() != output.len()`.
94#[cfg(target_arch = "x86_64")]
95#[target_feature(enable = "avx2")]
96pub unsafe fn gelu_avx2(input: &[f32], output: &mut [f32]) {
97    gelu_scalar(input, output);
98}
99
100/// AVX2 `SiLU` — delegates to scalar (no hardware `exp` in AVX2).
101///
102/// # Safety
103/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
104///
105/// # Panics
106/// Panics if `input.len() != output.len()`.
107#[cfg(target_arch = "x86_64")]
108#[target_feature(enable = "avx2")]
109pub unsafe fn silu_avx2(input: &[f32], output: &mut [f32]) {
110    silu_scalar(input, output);
111}
112
// Textually splices in the PTX backend source. The tests below call
// `relu_ptx`, `gelu_ptx` and `silu_ptx`, which are not defined elsewhere in
// this file and so are presumably defined in `activation_ptx.rs`.
include!("activation_ptx.rs");
114
115// ────────────────────────────────────────────────────────────────────────────
116// Tests
117// ────────────────────────────────────────────────────────────────────────────
118
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ── Shape-contract tests ─────────────────────────────────────────────
    // Every scalar kernel documents a panic on mismatched slice lengths and
    // must accept empty slices.

    #[test]
    #[should_panic]
    fn test_relu_length_mismatch_panics() {
        let mut output = [0.0f32; 1];
        relu_scalar(&[1.0, 2.0], &mut output);
    }

    #[test]
    #[should_panic]
    fn test_gelu_length_mismatch_panics() {
        let mut output = [0.0f32; 1];
        gelu_scalar(&[1.0, 2.0], &mut output);
    }

    #[test]
    #[should_panic]
    fn test_silu_length_mismatch_panics() {
        let mut output = [0.0f32; 1];
        silu_scalar(&[1.0, 2.0], &mut output);
    }

    #[test]
    fn test_empty_input_is_noop() {
        let mut output: [f32; 0] = [];
        relu_scalar(&[], &mut output);
        gelu_scalar(&[], &mut output);
        silu_scalar(&[], &mut output);
    }

    // ── ReLU known-answer tests ──────────────────────────────────────────

    #[test]
    fn test_relu_negative_to_zero() {
        let input = [-3.0f32, -1.0, -0.5, -1e-6];
        let mut output = vec![0.0f32; input.len()];
        relu_scalar(&input, &mut output);
        for &y in &output {
            assert_eq!(y, 0.0);
        }
    }

    #[test]
    fn test_relu_positive_identity() {
        let input = [0.5f32, 1.0, 3.0, 100.0];
        let mut output = vec![0.0f32; input.len()];
        relu_scalar(&input, &mut output);
        for (x, y) in input.iter().zip(output.iter()) {
            assert_eq!(x, y);
        }
    }

    #[test]
    fn test_relu_zero() {
        let input = [0.0f32];
        let mut output = vec![0.0f32; 1];
        relu_scalar(&input, &mut output);
        assert_eq!(output[0], 0.0);
    }

    // ── GELU known-answer tests ──────────────────────────────────────────

    #[test]
    fn test_gelu_zero() {
        let input = [0.0f32];
        let mut output = vec![0.0f32; 1];
        gelu_scalar(&input, &mut output);
        assert!(
            output[0].abs() < 1e-7,
            "GELU(0) should be 0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_gelu_large_positive() {
        let input = [10.0f32];
        let mut output = vec![0.0f32; 1];
        gelu_scalar(&input, &mut output);
        // For large positive x, GELU(x) ~ x
        assert!(
            (output[0] - 10.0).abs() < 1e-4,
            "GELU(10) should be ~10, got {}",
            output[0]
        );
    }

    #[test]
    fn test_gelu_large_negative() {
        let input = [-10.0f32];
        let mut output = vec![0.0f32; 1];
        gelu_scalar(&input, &mut output);
        // For large negative x, GELU(x) ~ 0
        assert!(
            output[0].abs() < 1e-4,
            "GELU(-10) should be ~0, got {}",
            output[0]
        );
    }

    // ── SiLU known-answer tests ──────────────────────────────────────────

    #[test]
    fn test_silu_zero() {
        let input = [0.0f32];
        let mut output = vec![0.0f32; 1];
        silu_scalar(&input, &mut output);
        assert!(
            output[0].abs() < 1e-7,
            "SiLU(0) should be 0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_silu_positive() {
        let input = [1.0f32];
        let mut output = vec![0.0f32; 1];
        silu_scalar(&input, &mut output);
        // SiLU(1) = 1 / (1 + exp(-1)) ~ 0.7310586
        let expected = 1.0 / (1.0 + (-1.0f32).exp());
        assert!(
            (output[0] - expected).abs() < 1e-6,
            "SiLU(1) should be ~{expected}, got {}",
            output[0]
        );
    }

    #[test]
    fn test_silu_negative() {
        let input = [-1.0f32];
        let mut output = vec![0.0f32; 1];
        silu_scalar(&input, &mut output);
        // SiLU(-1) = -1 / (1 + exp(1)) ~ -0.2689414
        let expected = -1.0 / (1.0 + 1.0f32.exp());
        assert!(
            (output[0] - expected).abs() < 1e-6,
            "SiLU(-1) should be ~{expected}, got {}",
            output[0]
        );
    }

    // ── Property-based tests ─────────────────────────────────────────────

    proptest! {
        #[test]
        fn prop_relu_nonnegative(x in proptest::num::f32::NORMAL) {
            let input = [x];
            let mut output = [0.0f32];
            relu_scalar(&input, &mut output);
            prop_assert!(output[0] >= 0.0, "ReLU output must be >= 0, got {}", output[0]);
        }

        #[test]
        fn prop_gelu_zero_at_zero(scale in -1e-10f32..1e-10f32) {
            // GELU near zero should be near zero
            let input = [scale];
            let mut output = [0.0f32];
            gelu_scalar(&input, &mut output);
            prop_assert!(
                output[0].abs() < 1e-6,
                "GELU({scale}) should be ~0, got {}",
                output[0]
            );
        }

        #[test]
        fn prop_silu_sign_preserving(x in proptest::num::f32::NORMAL) {
            // SiLU(x) has the same sign as x (or is zero)
            let input = [x];
            let mut output = [0.0f32];
            silu_scalar(&input, &mut output);
            if x > 0.0 {
                prop_assert!(output[0] >= 0.0, "SiLU({x}) should be >= 0, got {}", output[0]);
            } else if x < 0.0 {
                prop_assert!(output[0] <= 0.0, "SiLU({x}) should be <= 0, got {}", output[0]);
            }
        }
    }

    // ── AVX2 parity tests ────────────────────────────────────────────────

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_relu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.5).collect();
        let mut scalar_out = vec![0.0f32; input.len()];
        let mut avx2_out = vec![0.0f32; input.len()];

        relu_scalar(&input, &mut scalar_out);
        unsafe { relu_avx2(&input, &mut avx2_out) };

        // ReLU involves no rounding, so the SIMD path must match exactly.
        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gelu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.25).collect();
        let mut scalar_out = vec![0.0f32; input.len()];
        let mut avx2_out = vec![0.0f32; input.len()];

        gelu_scalar(&input, &mut scalar_out);
        unsafe { gelu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_silu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.3).collect();
        let mut scalar_out = vec![0.0f32; input.len()];
        let mut avx2_out = vec![0.0f32; input.len()];

        silu_scalar(&input, &mut scalar_out);
        unsafe { silu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_relu_avx2_non_aligned_length() {
        // Test with length not divisible by 8 to exercise the scalar tail
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-5..6).map(|i| i as f32).collect(); // 11 elements
        let mut scalar_out = vec![0.0f32; input.len()];
        let mut avx2_out = vec![0.0f32; input.len()];

        relu_scalar(&input, &mut scalar_out);
        unsafe { relu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_relu_avx2_every_tail_length() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // 0..=17 covers: empty, tail-only, exactly one vector, and one
        // vector plus every possible tail size.
        for len in 0..=17usize {
            let input: Vec<f32> = (0..len).map(|i| i as f32 - 8.5).collect();
            let mut scalar_out = vec![0.0f32; len];
            let mut avx2_out = vec![f32::NAN; len];

            relu_scalar(&input, &mut scalar_out);
            unsafe { relu_avx2(&input, &mut avx2_out) };

            assert_eq!(scalar_out, avx2_out, "mismatch at len={len}");
        }
    }

    // ── PTX structural tests ─────────────────────────────────────────────

    /// Shared structural checks for a generated PTX kernel string.
    fn assert_ptx_structure(ptx: &str, entry: &str, uses_exp: bool) {
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(
            ptx.contains(&format!(".entry {entry}")),
            "missing entry point"
        );
        assert!(ptx.contains("ret;"), "missing ret instruction");
        if uses_exp {
            assert!(ptx.contains("ex2.approx.f32"), "missing ex2.approx for exp");
        }
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    #[test]
    fn test_relu_ptx_structure() {
        assert_ptx_structure(relu_ptx(), "relu_kernel", false);
    }

    #[test]
    fn test_gelu_ptx_structure() {
        assert_ptx_structure(gelu_ptx(), "gelu_kernel", true);
    }

    #[test]
    fn test_silu_ptx_structure() {
        assert_ptx_structure(silu_ptx(), "silu_kernel", true);
    }

    #[test]
    fn test_ptx_kernels_are_nonempty() {
        assert!(!relu_ptx().is_empty());
        assert!(!gelu_ptx().is_empty());
        assert!(!silu_ptx().is_empty());
    }
}