provable_contracts/kernels/
activation.rs1use std::f32::consts::PI;
11
/// Rectified linear unit applied element-wise: `output[i] = max(input[i], 0.0)`.
///
/// Implemented with [`f32::max`], which returns the non-NaN operand, so a NaN
/// input element produces `0.0` rather than propagating.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
pub fn relu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    output
        .iter_mut()
        .zip(input)
        .for_each(|(dst, &src)| *dst = src.max(0.0));
}
26
/// GELU activation (tanh approximation) applied element-wise:
///
/// `y = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
pub fn gelu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    // √(2/π), the scale applied to the tanh argument.
    let k = (2.0f32 / PI).sqrt();
    for (dst, &x) in output.iter_mut().zip(input) {
        // Operations are kept in the same order as the reference formula so
        // results are bit-for-bit stable.
        let cubic_term = 0.044_715 * x * x * x;
        let tanh_arg = k * (x + cubic_term);
        *dst = 0.5 * x * (1.0 + tanh_arg.tanh());
    }
}
39
/// SiLU (swish) activation applied element-wise:
/// `y = x · σ(x) = x / (1 + e^(-x))`.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
pub fn silu_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    for (dst, &x) in output.iter_mut().zip(input) {
        let denom = 1.0 + (-x).exp();
        *dst = x / denom;
    }
}
50
51#[cfg(target_arch = "x86_64")]
56use std::arch::x86_64::{_mm256_loadu_ps, _mm256_max_ps, _mm256_setzero_ps, _mm256_storeu_ps};
57
/// AVX2 ReLU: element-wise `output[i] = max(input[i], 0.0)`, eight lanes per step.
///
/// Full 8-element chunks are processed with `_mm256_max_ps(v, 0)`. The
/// remaining `len % 8` elements fall back to scalar [`f32::max`]. Both paths
/// return `0.0` for a NaN input, as [`relu_scalar`] does.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "avx2")]`. The
/// caller must ensure the running CPU supports AVX2 (for example via
/// `is_x86_feature_detected!("avx2")`). Calling it on a CPU without AVX2 is
/// undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    const LANES: usize = 8;
    assert_eq!(input.len(), output.len());

    let mut src_chunks = input.chunks_exact(LANES);
    let mut dst_chunks = output.chunks_exact_mut(LANES);

    // SAFETY: AVX2 availability is the caller's obligation (see `# Safety`).
    // `chunks_exact` / `chunks_exact_mut` yield slices of exactly `LANES` f32s,
    // so every unaligned 256-bit load/store stays within its own chunk.
    unsafe {
        let zero = _mm256_setzero_ps();
        for (src, dst) in (&mut src_chunks).zip(&mut dst_chunks) {
            let v = _mm256_loadu_ps(src.as_ptr());
            _mm256_storeu_ps(dst.as_mut_ptr(), _mm256_max_ps(v, zero));
        }
    }

    // Fewer than LANES elements remain; handle them with the scalar rule.
    let src_tail = src_chunks.remainder();
    let dst_tail = dst_chunks.into_remainder();
    for (dst, &x) in dst_tail.iter_mut().zip(src_tail) {
        *dst = x.max(0.0);
    }
}
86
87#[cfg(target_arch = "x86_64")]
95#[target_feature(enable = "avx2")]
96pub unsafe fn gelu_avx2(input: &[f32], output: &mut [f32]) {
97 gelu_scalar(input, output);
98}
99
100#[cfg(target_arch = "x86_64")]
108#[target_feature(enable = "avx2")]
109pub unsafe fn silu_avx2(input: &[f32], output: &mut [f32]) {
110 silu_scalar(input, output);
111}
112
// Textually includes `activation_ptx.rs` into this module. The tests in this
// file call `relu_ptx()`, `gelu_ptx()` and `silu_ptx()`, which are not defined
// in the visible code. They presumably come from this included file.
// NOTE(review): confirm.
include!("activation_ptx.rs");
114
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // Output buffers are pre-filled with NaN rather than 0.0. Many expected
    // results here are 0.0, so a zero-filled buffer would let a kernel that
    // never writes its output pass these tests unnoticed.

    /// Asserts the structural invariants shared by every generated PTX kernel:
    /// ISA version, target architecture, the named `.entry`, a `ret;`
    /// instruction, and balanced curly braces.
    fn assert_ptx_structure(ptx: &str, entry: &str) {
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(
            ptx.contains(&format!(".entry {entry}")),
            "missing entry point"
        );
        assert!(ptx.contains("ret;"), "missing ret instruction");
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    // ---- ReLU (scalar) ----

    #[test]
    fn test_relu_negative_to_zero() {
        let input = [-3.0f32, -1.0, -0.5, -1e-6];
        let mut output = [f32::NAN; 4];
        relu_scalar(&input, &mut output);
        for &y in &output {
            assert_eq!(y, 0.0);
        }
    }

    #[test]
    fn test_relu_positive_identity() {
        let input = [0.5f32, 1.0, 3.0, 100.0];
        let mut output = [f32::NAN; 4];
        relu_scalar(&input, &mut output);
        for (x, y) in input.iter().zip(output.iter()) {
            assert_eq!(x, y);
        }
    }

    #[test]
    fn test_relu_zero() {
        let input = [0.0f32];
        let mut output = [f32::NAN];
        relu_scalar(&input, &mut output);
        assert_eq!(output[0], 0.0);
    }

    // ---- GELU (scalar) ----

    #[test]
    fn test_gelu_zero() {
        let input = [0.0f32];
        let mut output = [f32::NAN];
        gelu_scalar(&input, &mut output);
        assert!(
            output[0].abs() < 1e-7,
            "GELU(0) should be 0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_gelu_large_positive() {
        let input = [10.0f32];
        let mut output = [f32::NAN];
        gelu_scalar(&input, &mut output);
        assert!(
            (output[0] - 10.0).abs() < 1e-4,
            "GELU(10) should be ~10, got {}",
            output[0]
        );
    }

    #[test]
    fn test_gelu_large_negative() {
        let input = [-10.0f32];
        let mut output = [f32::NAN];
        gelu_scalar(&input, &mut output);
        assert!(
            output[0].abs() < 1e-4,
            "GELU(-10) should be ~0, got {}",
            output[0]
        );
    }

    // ---- SiLU (scalar) ----

    #[test]
    fn test_silu_zero() {
        let input = [0.0f32];
        let mut output = [f32::NAN];
        silu_scalar(&input, &mut output);
        assert!(
            output[0].abs() < 1e-7,
            "SiLU(0) should be 0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_silu_positive() {
        let input = [1.0f32];
        let mut output = [f32::NAN];
        silu_scalar(&input, &mut output);
        // SiLU(1) = 1 * sigmoid(1).
        let expected = 1.0 / (1.0 + (-1.0f32).exp());
        assert!(
            (output[0] - expected).abs() < 1e-6,
            "SiLU(1) should be ~{expected}, got {}",
            output[0]
        );
    }

    #[test]
    fn test_silu_negative() {
        let input = [-1.0f32];
        let mut output = [f32::NAN];
        silu_scalar(&input, &mut output);
        // SiLU(-1) = -1 * sigmoid(-1) = -1 / (1 + e).
        let expected = -1.0 / (1.0 + 1.0f32.exp());
        assert!(
            (output[0] - expected).abs() < 1e-6,
            "SiLU(-1) should be ~{expected}, got {}",
            output[0]
        );
    }

    // ---- Property tests ----

    proptest! {
        #[test]
        fn prop_relu_nonnegative(x in proptest::num::f32::NORMAL) {
            let input = [x];
            let mut output = [f32::NAN];
            relu_scalar(&input, &mut output);
            prop_assert!(output[0] >= 0.0, "ReLU output must be >= 0, got {}", output[0]);
        }

        #[test]
        fn prop_gelu_zero_at_zero(scale in -1e-10f32..1e-10f32) {
            let input = [scale];
            let mut output = [f32::NAN];
            gelu_scalar(&input, &mut output);
            prop_assert!(
                output[0].abs() < 1e-6,
                "GELU({scale}) should be ~0, got {}",
                output[0]
            );
        }

        #[test]
        fn prop_silu_sign_preserving(x in proptest::num::f32::NORMAL) {
            let input = [x];
            let mut output = [f32::NAN];
            silu_scalar(&input, &mut output);
            if x > 0.0 {
                prop_assert!(output[0] >= 0.0, "SiLU({x}) should be >= 0, got {}", output[0]);
            } else if x < 0.0 {
                prop_assert!(output[0] <= 0.0, "SiLU({x}) should be <= 0, got {}", output[0]);
            }
        }
    }

    // ---- AVX2 parity (skipped when the CPU lacks AVX2) ----

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_relu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.5).collect();
        let mut scalar_out = vec![f32::NAN; input.len()];
        let mut avx2_out = vec![f32::NAN; input.len()];

        relu_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { relu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_gelu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.25).collect();
        let mut scalar_out = vec![f32::NAN; input.len()];
        let mut avx2_out = vec![f32::NAN; input.len()];

        gelu_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { gelu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_silu_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (-20..20).map(|i| i as f32 * 0.3).collect();
        let mut scalar_out = vec![f32::NAN; input.len()];
        let mut avx2_out = vec![f32::NAN; input.len()];

        silu_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { silu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_relu_avx2_non_aligned_length() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // 11 elements: one full 8-lane vector plus a 3-element scalar tail.
        let input: Vec<f32> = (-5..6).map(|i| i as f32).collect();
        let mut scalar_out = vec![f32::NAN; input.len()];
        let mut avx2_out = vec![f32::NAN; input.len()];

        relu_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe { relu_avx2(&input, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }

    // ---- PTX structure ----

    #[test]
    fn test_relu_ptx_structure() {
        assert_ptx_structure(&relu_ptx(), "relu_kernel");
    }

    #[test]
    fn test_gelu_ptx_structure() {
        let ptx = gelu_ptx();
        assert_ptx_structure(&ptx, "gelu_kernel");
        assert!(ptx.contains("ex2.approx.f32"), "missing ex2.approx for exp");
    }

    #[test]
    fn test_silu_ptx_structure() {
        let ptx = silu_ptx();
        assert_ptx_structure(&ptx, "silu_kernel");
        assert!(ptx.contains("ex2.approx.f32"), "missing ex2.approx for exp");
    }

    #[test]
    fn test_ptx_kernels_are_nonempty() {
        assert!(!relu_ptx().is_empty());
        assert!(!gelu_ptx().is_empty());
        assert!(!silu_ptx().is_empty());
    }
}