1#![allow(dead_code)]
2use rayon::prelude::*;
3use std::arch::x86_64::*;
4use std::mem;
5
6#[cfg(target_arch = "x86_64")]
8use std::arch::x86_64 as arch;
9
10#[cfg(target_arch = "aarch64")]
11use std::arch::aarch64 as arch;
12
/// Reports whether AVX2 instructions can be executed on the running machine.
///
/// On x86 / x86_64 this delegates to the standard library's runtime detection
/// (`is_x86_feature_detected!`). That check covers both the CPUID AVX2
/// feature bit and operating-system support for saving the extended (YMM)
/// register state. Both must hold before AVX2 code may run. The previous
/// raw `XGETBV` probe checked neither the CPUID bit nor OSXSAVE, and
/// `XGETBV` itself faults when OSXSAVE is unavailable.
///
/// AVX2 is an x86-only extension, so every other architecture (including
/// aarch64) reports `false`.
pub fn supports_avx2() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        false
    }
}
23
24pub fn max_pooling_simd(input: &[u8], width: usize, factor: usize) -> Vec<u8> {
26 let output_width = width / factor;
27 let output_height = input.len() / (width * factor);
28 let mut output = vec![0; output_width * output_height];
29
30 output
32 .par_chunks_mut(output_width)
33 .enumerate()
34 .for_each(|(oy, row)| {
35 let start_y = oy * factor;
36 let end_y = start_y + factor;
37
38 (0..output_width).for_each(|ox| {
39 let start_x = ox * factor;
40 let mut simd_max = unsafe { _mm256_setzero_si256() };
44 for y in start_y..end_y {
45 let row_start = y * width + start_x;
46 let chunk = &input[row_start..row_start + factor];
47
48 for chunk32 in chunk.chunks_exact(32) {
50 let data =
51 unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
52 simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
53 }
54
55 let remainder = chunk.chunks_exact(32).remainder();
57 if !remainder.is_empty() {
58 let mut buffer = [0u8; 32];
59 buffer[..remainder.len()].copy_from_slice(remainder);
60 let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
61 simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
62 }
63 }
64
65 let mut max_val = 0;
67 let max_arr: &[u8; 32] = unsafe { mem::transmute(&simd_max) };
68 for &val in max_arr {
69 if val > max_val {
70 max_val = val;
71 }
72 }
73 row[ox] = max_val;
74 });
75 });
76 output
77}