image_max_polling/
lib.rs

1#![allow(dead_code)]
2use rayon::prelude::*;
3use std::arch::x86_64::*;
4use std::mem;
5
6// 条件编译选择指令集
7#[cfg(target_arch = "x86_64")]
8use std::arch::x86_64 as arch;
9
10#[cfg(target_arch = "aarch64")]
11use std::arch::aarch64 as arch;
12
/// Reports whether the running CPU can execute the AVX2 code path.
///
/// On x86 / x86_64 this uses the standard library's runtime detection
/// (`is_x86_feature_detected!`). Unlike a bare `_xgetbv` probe, that macro
/// checks both the CPUID AVX2 feature bit and OS support for the YMM register
/// state. It also never executes `xgetbv` on CPUs lacking OSXSAVE, where that
/// instruction would fault.
///
/// On aarch64 this returns `true`, preserving the original behaviour. The
/// original premise was that NEON is enabled by default there, so on that
/// target the result means "SIMD available", not literally "AVX2 available".
/// On every other architecture it returns `false`. Previously those targets
/// failed to compile because the function had no return value for them.
pub fn supports_avx2() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn detect() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[cfg(target_arch = "aarch64")]
    fn detect() -> bool {
        // Kept from the original: NEON is assumed enabled by default on aarch64.
        true
    }

    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
    fn detect() -> bool {
        false
    }

    detect()
}
23
24/// SIMD 加速的最大值池化
25pub fn max_pooling_simd(input: &[u8], width: usize, factor: usize) -> Vec<u8> {
26    let output_width = width / factor;
27    let output_height = input.len() / (width * factor);
28    let mut output = vec![0; output_width * output_height];
29
30    // 分块并行处理
31    output
32        .par_chunks_mut(output_width)
33        .enumerate()
34        .for_each(|(oy, row)| {
35            let start_y = oy * factor;
36            let end_y = start_y + factor;
37
38            (0..output_width).for_each(|ox| {
39                let start_x = ox * factor;
40                // let end_x = start_x + factor;
41
42                // 使用 SIMD 计算局部最大值
43                let mut simd_max = unsafe { _mm256_setzero_si256() };
44                for y in start_y..end_y {
45                    let row_start = y * width + start_x;
46                    let chunk = &input[row_start..row_start + factor];
47
48                    // 每 32 字节一组处理
49                    for chunk32 in chunk.chunks_exact(32) {
50                        let data =
51                            unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
52                        simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
53                    }
54
55                    // 处理剩余不足 32 字节的部分
56                    let remainder = chunk.chunks_exact(32).remainder();
57                    if !remainder.is_empty() {
58                        let mut buffer = [0u8; 32];
59                        buffer[..remainder.len()].copy_from_slice(remainder);
60                        let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
61                        simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
62                    }
63                }
64
65                // 提取 SIMD 结果中的最大值
66                let mut max_val = 0;
67                let max_arr: &[u8; 32] = unsafe { mem::transmute(&simd_max) };
68                for &val in max_arr {
69                    if val > max_val {
70                        max_val = val;
71                    }
72                }
73                row[ox] = max_val;
74            });
75        });
76    output
77}