use crate::error::TruenoError;
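
/// Applies ReLU (`max(x, 0.0)`) to `input`, writing the result into `output`.
///
/// On x86_64 the call dispatches at runtime to an AVX-512F or AVX2 kernel when
/// the CPU supports it; otherwise (and for inputs above 4096 elements) it falls
/// back to an autovectorized scalar loop. Returns `InvalidInput` if the slice
/// lengths differ.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let input = [-1.0f32, 2.0, -3.0];
/// let mut output = [0.0f32; 3];
/// relu(&input, &mut output).unwrap();
/// assert_eq!(output, [0.0, 2.0, 0.0]);
/// ```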
pub fn relu(input: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    contract_pre_relu!(input);
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "relu size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Inputs larger than 4096 elements go straight to the autovectorized loop.
        if n > 4096 {
            relu_autovec(input, output);
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx512f") {
            unsafe {
                relu_avx512(input, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx2") {
            unsafe {
                relu_avx2(input, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
    }

    relu_autovec(input, output);
    contract_post_elementwise_parity!(output);
    Ok(())
}

#[inline]
fn relu_autovec(input: &[f32], output: &mut [f32]) {
    for i in 0..input.len() {
        output[i] = input[i].max(0.0);
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn relu_avx512(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;
    unsafe {
        let n = input.len();
        let ip = input.as_ptr();
        let op = output.as_mut_ptr();
        let zero = _mm512_setzero_ps();
        let mut i = 0;

        let data_bytes = n * 4;
        let op_aligned = (op as usize) % 64 == 0;
        if data_bytes > NT_STORE_THRESHOLD_BYTES && op_aligned {
            while i + 64 <= n {
                if i + 128 <= n {
                    _mm_prefetch(ip.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                }

                _mm512_stream_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                _mm512_stream_ps(
                    op.add(i + 16),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 16)), zero),
                );
                _mm512_stream_ps(
                    op.add(i + 32),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 32)), zero),
                );
                _mm512_stream_ps(
                    op.add(i + 48),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 48)), zero),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_stream_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                i += 16;
            }
            _mm_sfence();
        } else {
            while i + 64 <= n {
                _mm512_storeu_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                _mm512_storeu_ps(
                    op.add(i + 16),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 16)), zero),
                );
                _mm512_storeu_ps(
                    op.add(i + 32),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 32)), zero),
                );
                _mm512_storeu_ps(
                    op.add(i + 48),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 48)), zero),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_storeu_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                i += 16;
            }
        }
        for j in i..n {
            output[j] = input[j].max(0.0);
        }
    }
}

/// How far ahead (in bytes) the streaming kernels prefetch; divided by 4 to get f32 elements.
const PREFETCH_DISTANCE: usize = 512;

/// Output sizes above this many bytes use non-temporal (streaming) stores so the
/// result does not evict hot cache lines.
const NT_STORE_THRESHOLD_BYTES: usize = 512 * 1024;

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let data_bytes = n * 4;

    let out_aligned = (output.as_ptr() as usize) % 32 == 0;
    if data_bytes > NT_STORE_THRESHOLD_BYTES && out_aligned {
        unsafe { relu_avx2_nt(input, output) }
        return;
    }

    let chunks = n / 64;
    let remainder_64 = chunks * 64;

    unsafe {
        let zero = _mm256_setzero_ps();
        let inp = input.as_ptr();
        let out = output.as_mut_ptr();

        for i in 0..chunks {
            let base = i * 64;
            let v0 = _mm256_loadu_ps(inp.add(base));
            let v1 = _mm256_loadu_ps(inp.add(base + 8));
            let v2 = _mm256_loadu_ps(inp.add(base + 16));
            let v3 = _mm256_loadu_ps(inp.add(base + 24));
            let v4 = _mm256_loadu_ps(inp.add(base + 32));
            let v5 = _mm256_loadu_ps(inp.add(base + 40));
            let v6 = _mm256_loadu_ps(inp.add(base + 48));
            let v7 = _mm256_loadu_ps(inp.add(base + 56));
            _mm256_storeu_ps(out.add(base), _mm256_max_ps(v0, zero));
            _mm256_storeu_ps(out.add(base + 8), _mm256_max_ps(v1, zero));
            _mm256_storeu_ps(out.add(base + 16), _mm256_max_ps(v2, zero));
            _mm256_storeu_ps(out.add(base + 24), _mm256_max_ps(v3, zero));
            _mm256_storeu_ps(out.add(base + 32), _mm256_max_ps(v4, zero));
            _mm256_storeu_ps(out.add(base + 40), _mm256_max_ps(v5, zero));
            _mm256_storeu_ps(out.add(base + 48), _mm256_max_ps(v6, zero));
            _mm256_storeu_ps(out.add(base + 56), _mm256_max_ps(v7, zero));
        }

        let mut i = remainder_64;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(inp.add(i));
            _mm256_storeu_ps(out.add(i), _mm256_max_ps(v, zero));
            i += 8;
        }

        while i < n {
            *out.add(i) = (*inp.add(i)).max(0.0);
            i += 1;
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2_nt(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        let zero = _mm256_setzero_ps();

        for i in 0..chunks {
            let base = i * 32;
            if base + PREFETCH_DISTANCE / 4 <= n {
                _mm_prefetch(
                    input.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8,
                    _MM_HINT_T0,
                );
            }
            let v0 = _mm256_loadu_ps(input.as_ptr().add(base));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(base + 8));
            let v2 = _mm256_loadu_ps(input.as_ptr().add(base + 16));
            let v3 = _mm256_loadu_ps(input.as_ptr().add(base + 24));
            _mm256_stream_ps(output.as_mut_ptr().add(base), _mm256_max_ps(v0, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 8), _mm256_max_ps(v1, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 16), _mm256_max_ps(v2, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 24), _mm256_max_ps(v3, zero));
        }

        _mm_sfence();

        let mut i = remainder_32;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_max_ps(v, zero));
            i += 8;
        }
        while i < n {
            output[i] = input[i].max(0.0);
            i += 1;
        }
    }
}

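/// Computes `output[i] = a[i] + b[i]` elementwise.
///
/// On x86_64 the call dispatches at runtime to an AVX-512F or AVX2 kernel when
/// available; large inputs and non-x86_64 targets use the autovectorized loop.
/// Returns `InvalidInput` if the three slices differ in length.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let a = [1.0f32, 2.0];
/// let b = [10.0f32, 20.0];
/// let mut out = [0.0f32; 2];
/// add(&a, &b, &mut out).unwrap();
/// assert_eq!(out, [11.0, 22.0]);
/// ```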
pub fn add(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "add size mismatch: a[{}], b[{}], output[{}]",
            n,
            b.len(),
            output.len()
        )));
    }
    contract_pre_add!(a, b);

    #[cfg(target_arch = "x86_64")]
    {
        // Inputs larger than 4096 elements go straight to the autovectorized loop.
        if n > 4096 {
            add_autovec(a, b, output);
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx512f") {
            unsafe {
                add_avx512(a, b, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx2") {
            unsafe {
                add_avx2(a, b, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
    }

    add_autovec(a, b, output);
    contract_post_elementwise_parity!(output);
    Ok(())
}

#[inline]
fn add_autovec(a: &[f32], b: &[f32], output: &mut [f32]) {
    for i in 0..a.len() {
        output[i] = a[i] + b[i];
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn add_avx512(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;
    unsafe {
        let n = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = output.as_mut_ptr();
        let mut i = 0;

        let data_bytes = n * 4;
        let rp_aligned = (rp as usize) % 64 == 0;
        if data_bytes > NT_STORE_THRESHOLD_BYTES && rp_aligned {
            while i + 64 <= n {
                if i + 128 <= n {
                    _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                    _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                }

                _mm512_stream_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                _mm512_stream_ps(
                    rp.add(i + 16),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 16)), _mm512_loadu_ps(bp.add(i + 16))),
                );
                _mm512_stream_ps(
                    rp.add(i + 32),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 32)), _mm512_loadu_ps(bp.add(i + 32))),
                );
                _mm512_stream_ps(
                    rp.add(i + 48),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 48)), _mm512_loadu_ps(bp.add(i + 48))),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_stream_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                i += 16;
            }
            _mm_sfence();
        } else {
            while i + 64 <= n {
                _mm512_storeu_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 16),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 16)), _mm512_loadu_ps(bp.add(i + 16))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 32),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 32)), _mm512_loadu_ps(bp.add(i + 32))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 48),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 48)), _mm512_loadu_ps(bp.add(i + 48))),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_storeu_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                i += 16;
            }
        }
        for j in i..n {
            output[j] = a[j] + b[j];
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = a.len();
    let data_bytes = n * 4;

    let out_aligned = (output.as_ptr() as usize) % 32 == 0;
    if data_bytes > NT_STORE_THRESHOLD_BYTES && out_aligned {
        unsafe { add_avx2_nt(a, b, output) }
        return;
    }

    let chunks = n / 64;
    let remainder_64 = chunks * 64;

    unsafe {
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let op = output.as_mut_ptr();

        for i in 0..chunks {
            let base = i * 64;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            let a2 = _mm256_loadu_ps(ap.add(base + 16));
            let b2 = _mm256_loadu_ps(bp.add(base + 16));
            let a3 = _mm256_loadu_ps(ap.add(base + 24));
            let b3 = _mm256_loadu_ps(bp.add(base + 24));
            let a4 = _mm256_loadu_ps(ap.add(base + 32));
            let b4 = _mm256_loadu_ps(bp.add(base + 32));
            let a5 = _mm256_loadu_ps(ap.add(base + 40));
            let b5 = _mm256_loadu_ps(bp.add(base + 40));
            let a6 = _mm256_loadu_ps(ap.add(base + 48));
            let b6 = _mm256_loadu_ps(bp.add(base + 48));
            let a7 = _mm256_loadu_ps(ap.add(base + 56));
            let b7 = _mm256_loadu_ps(bp.add(base + 56));
            _mm256_storeu_ps(op.add(base), _mm256_add_ps(a0, b0));
            _mm256_storeu_ps(op.add(base + 8), _mm256_add_ps(a1, b1));
            _mm256_storeu_ps(op.add(base + 16), _mm256_add_ps(a2, b2));
            _mm256_storeu_ps(op.add(base + 24), _mm256_add_ps(a3, b3));
            _mm256_storeu_ps(op.add(base + 32), _mm256_add_ps(a4, b4));
            _mm256_storeu_ps(op.add(base + 40), _mm256_add_ps(a5, b5));
            _mm256_storeu_ps(op.add(base + 48), _mm256_add_ps(a6, b6));
            _mm256_storeu_ps(op.add(base + 56), _mm256_add_ps(a7, b7));
        }

        let mut i = remainder_64;
        while i + 8 <= n {
            let av = _mm256_loadu_ps(ap.add(i));
            let bv = _mm256_loadu_ps(bp.add(i));
            _mm256_storeu_ps(op.add(i), _mm256_add_ps(av, bv));
            i += 8;
        }

        while i < n {
            *op.add(i) = *ap.add(i) + *bp.add(i);
            i += 1;
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2_nt(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = a.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        for i in 0..chunks {
            let base = i * 32;
            if base + PREFETCH_DISTANCE / 4 <= n {
                _mm_prefetch(a.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8, _MM_HINT_T0);
            }
            let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
            let a1 = _mm256_loadu_ps(a.as_ptr().add(base + 8));
            let a2 = _mm256_loadu_ps(a.as_ptr().add(base + 16));
            let a3 = _mm256_loadu_ps(a.as_ptr().add(base + 24));
            let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
            let b1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
            let b2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
            let b3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
            _mm256_stream_ps(output.as_mut_ptr().add(base), _mm256_add_ps(a0, b0));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 8), _mm256_add_ps(a1, b1));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 16), _mm256_add_ps(a2, b2));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 24), _mm256_add_ps(a3, b3));
        }

        _mm_sfence();

        let mut i = remainder_32;
        while i + 8 <= n {
            let av = _mm256_loadu_ps(a.as_ptr().add(i));
            let bv = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_add_ps(av, bv));
            i += 8;
        }
        while i < n {
            output[i] = a[i] + b[i];
            i += 1;
        }
    }
}

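/// Multiplies every element of `input` by `scalar`, writing into `output`.
///
/// Uses an AVX2 kernel when available on x86_64, otherwise a plain loop.
/// Returns `InvalidInput` if the slice lengths differ; in debug builds it also
/// asserts that `input` is non-empty and `scalar` is finite.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let input = [1.0f32, 2.0, 3.0];
/// let mut out = [0.0f32; 3];
/// mul_scalar(&input, 2.5, &mut out).unwrap();
/// assert_eq!(out, [2.5, 5.0, 7.5]);
/// ```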
pub fn mul_scalar(input: &[f32], scalar: f32, output: &mut [f32]) -> Result<(), TruenoError> {
    debug_assert!(!input.is_empty(), "Contract mul_scalar: input is empty");
    debug_assert!(scalar.is_finite(), "Contract mul_scalar: scalar is not finite");
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "mul_scalar size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                mul_scalar_avx2(input, scalar, output);
            }
            return Ok(());
        }
    }

    for i in 0..n {
        output[i] = input[i] * scalar;
    }
    Ok(())
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_scalar_avx2(input: &[f32], scalar: f32, output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        let s = _mm256_set1_ps(scalar);

        for i in 0..chunks {
            let base = i * 32;
            let v0 = _mm256_loadu_ps(input.as_ptr().add(base));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(base + 8));
            let v2 = _mm256_loadu_ps(input.as_ptr().add(base + 16));
            let v3 = _mm256_loadu_ps(input.as_ptr().add(base + 24));
            _mm256_storeu_ps(output.as_mut_ptr().add(base), _mm256_mul_ps(v0, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 8), _mm256_mul_ps(v1, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 16), _mm256_mul_ps(v2, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 24), _mm256_mul_ps(v3, s));
        }

        let mut i = remainder_32;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_mul_ps(v, s));
            i += 8;
        }

        while i < n {
            output[i] = input[i] * scalar;
            i += 1;
        }
    }
}

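/// Allocating convenience wrapper around [`relu`]: returns a new `Vec<f32>`
/// with ReLU applied to `input`.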
#[must_use]
pub fn relu_alloc(input: &[f32]) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    // Lengths match by construction, so the Result cannot be an error.
    let _ = relu(input, &mut output);
    output
}

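/// Allocating convenience wrapper around [`add`]: returns `a + b` elementwise
/// as a new `Vec<f32>`. Panics if the lengths differ.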
#[must_use]
pub fn add_alloc(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "add_alloc: length mismatch");
    let n = a.len();
    let mut output = vec![0.0f32; n];
    let _ = add(a, b, &mut output);
    output
}

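/// Allocating convenience wrapper around [`mul_scalar`]: returns `input * scalar`
/// elementwise as a new `Vec<f32>`.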
#[must_use]
pub fn mul_scalar_alloc(input: &[f32], scalar: f32) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    let _ = mul_scalar(input, scalar, &mut output);
    output
}

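/// Fused `max(a[i] + b[i], 0.0)`: adds two slices and applies ReLU in a single
/// pass, avoiding an intermediate buffer. Returns `InvalidInput` on length mismatch.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let a = [1.0f32, -2.0];
/// let b = [0.5f32, 1.0];
/// let mut out = [0.0f32; 2];
/// fused_add_relu(&a, &b, &mut out).unwrap();
/// assert_eq!(out, [1.5, 0.0]);
/// ```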
pub fn fused_add_relu(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_add_relu size mismatch: a[{}], b[{}], output[{}]",
            n,
            b.len(),
            output.len()
        )));
    }
    for i in 0..n {
        output[i] = (a[i] + b[i]).max(0.0);
    }
    Ok(())
}

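/// Fused multiply-add: `output[i] = a[i] * b[i] + c[i]`, computed with
/// `f32::mul_add` (a single rounding step). Returns `InvalidInput` on length mismatch.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let a = [2.0f32, 3.0];
/// let b = [4.0f32, 5.0];
/// let c = [1.0f32, 1.0];
/// let mut out = [0.0f32; 2];
/// fused_mul_add(&a, &b, &c, &mut out).unwrap();
/// assert_eq!(out, [9.0, 16.0]);
/// ```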
pub fn fused_mul_add(
    a: &[f32],
    b: &[f32],
    c: &[f32],
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != c.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_mul_add size mismatch: a[{}], b[{}], c[{}], output[{}]",
            n,
            b.len(),
            c.len(),
            output.len()
        )));
    }
    for i in 0..n {
        output[i] = a[i].mul_add(b[i], c[i]);
    }
    Ok(())
}

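/// Fused scale, bias, and ReLU: `output[i] = max(input[i] * scale + bias, 0.0)`.
/// Returns `InvalidInput` on length mismatch.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let input = [-1.0f32, 2.0];
/// let mut out = [0.0f32; 2];
/// fused_scale_bias_relu(&input, 2.0, 1.0, &mut out).unwrap();
/// assert_eq!(out, [0.0, 5.0]);
/// ```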
pub fn fused_scale_bias_relu(
    input: &[f32],
    scale: f32,
    bias: f32,
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_scale_bias_relu size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }
    for i in 0..n {
        output[i] = input[i].mul_add(scale, bias).max(0.0);
    }
    Ok(())
}

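/// Applies ReLU to `data` in place, with no temporary allocation.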
#[inline]
pub fn relu_inplace(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = x.max(0.0);
    }
}

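/// Adds `b` into `a` elementwise (`a[i] += b[i]`). Returns `InvalidInput` if the
/// lengths differ.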
pub fn add_inplace(a: &mut [f32], b: &[f32]) -> Result<(), TruenoError> {
    if a.len() != b.len() {
        return Err(TruenoError::InvalidInput(format!(
            "add_inplace size mismatch: a[{}], b[{}]",
            a.len(),
            b.len()
        )));
    }
    for i in 0..a.len() {
        a[i] += b[i];
    }
    Ok(())
}

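/// Multiplies every element of `data` by `scalar` in place.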
#[inline]
pub fn scale_inplace(data: &mut [f32], scalar: f32) {
    for x in data.iter_mut() {
        *x *= scalar;
    }
}

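/// In-place fused add + ReLU: `a[i] = max(a[i] + b[i], 0.0)`. Returns
/// `InvalidInput` if the lengths differ.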
pub fn fused_add_relu_inplace(a: &mut [f32], b: &[f32]) -> Result<(), TruenoError> {
    if a.len() != b.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_add_relu_inplace size mismatch: a[{}], b[{}]",
            a.len(),
            b.len()
        )));
    }
    for i in 0..a.len() {
        a[i] = (a[i] + b[i]).max(0.0);
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_basic() {
        let input = [-1.0, 0.0, 1.0, -0.5, 2.0, -3.0, 0.1, -0.1];
        let expected = [0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.1, 0.0];
        let mut output = vec![0.0f32; 8];
        relu(&input, &mut output).unwrap();
        assert_eq!(output, expected);
    }

    #[test]
    fn test_relu_large() {
        let n = 11008;
        let input: Vec<f32> =
            (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut output = vec![0.0f32; n];
        relu(&input, &mut output).unwrap();
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            assert_eq!(out, inp.max(0.0), "ReLU mismatch at {i}");
        }
    }

    #[test]
    fn test_relu_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            relu(&input, &mut output).unwrap();
            for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
                assert_eq!(out, inp.max(0.0), "ReLU parity at [{i}] n={n}");
            }
        }
    }

    #[test]
    fn test_relu_error_mismatch() {
        let input = vec![1.0f32; 4];
        let mut output = vec![0.0f32; 3];
        assert!(relu(&input, &mut output).is_err());
    }

    #[test]
    fn test_add_basic() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [10.0, 20.0, 30.0, 40.0];
        let mut output = vec![0.0f32; 4];
        add(&a, &b, &mut output).unwrap();
        assert_eq!(output, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_add_large() {
        let n = 4096;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (i * 2) as f32).collect();
        let mut output = vec![0.0f32; n];
        add(&a, &b, &mut output).unwrap();
        for i in 0..n {
            assert_eq!(output[i], a[i] + b[i], "Add mismatch at {i}");
        }
    }

    #[test]
    fn test_add_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let a: Vec<f32> = (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i * 13 + 7) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            add(&a, &b, &mut output).unwrap();
            for i in 0..n {
                assert_eq!(output[i], a[i] + b[i], "Add parity at [{i}] n={n}");
            }
        }
    }

    #[test]
    fn test_add_error_mismatch() {
        let a = vec![1.0f32; 4];
        let b = vec![1.0f32; 3];
        let mut output = vec![0.0f32; 4];
        assert!(add(&a, &b, &mut output).is_err());
    }

    #[test]
    fn test_mul_scalar_basic() {
        let input = [1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 4];
        mul_scalar(&input, 2.5, &mut output).unwrap();
        assert_eq!(output, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_mul_scalar_large() {
        let n = 4096;
        let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; n];
        mul_scalar(&input, std::f32::consts::PI, &mut output).unwrap();
        for i in 0..n {
            assert!(
                (output[i] - input[i] * std::f32::consts::PI).abs() < 1e-5,
                "Mul scalar mismatch at {i}"
            );
        }
    }

    #[test]
    fn test_mul_scalar_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            mul_scalar(&input, std::f32::consts::E, &mut output).unwrap();
            for i in 0..n {
                assert!(
                    (output[i] - input[i] * std::f32::consts::E).abs() < 1e-4,
                    "Mul scalar parity at [{i}] n={n}",
                );
            }
        }
    }

    #[test]
    fn test_mul_scalar_error_mismatch() {
        let input = vec![1.0f32; 4];
        let mut output = vec![0.0f32; 3];
        assert!(mul_scalar(&input, 1.0, &mut output).is_err());
    }

    #[test]
    fn test_fused_add_relu_basic() {
        let a = vec![-2.0, -1.0, 0.0, 1.0, 2.0, -0.5, 0.5, 3.0];
        let b = vec![1.0, 0.5, -1.0, -2.0, 0.0, 1.0, -1.0, -4.0];
        let mut out = vec![0.0f32; 8];
        fused_add_relu(&a, &b, &mut out).unwrap();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(a, b)| (a + b).max(0.0)).collect();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_fused_add_relu_large() {
        let n = 10_000;
        let a: Vec<f32> = (0..n).map(|i| (i as f32 - 5000.0) / 100.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 1500.0).collect();
        let mut out = vec![0.0f32; n];
        fused_add_relu(&a, &b, &mut out).unwrap();
        for i in 0..n {
            assert_eq!(out[i], (a[i] + b[i]).max(0.0), "mismatch at {i}");
        }
    }

    #[test]
    fn test_fused_mul_add_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![0.5, 0.5, 0.5, 0.5];
        let mut out = vec![0.0f32; 4];
        fused_mul_add(&a, &b, &c, &mut out).unwrap();
        let expected: Vec<f32> = (0..4).map(|i| a[i].mul_add(b[i], c[i])).collect();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_fused_scale_bias_relu_basic() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut out = vec![0.0f32; 5];
        fused_scale_bias_relu(&input, 2.0, 1.0, &mut out).unwrap();
        assert_eq!(out, vec![0.0, 0.0, 1.0, 3.0, 5.0]);
    }
}