Skip to main content

image_max_polling/
lib.rs

1//! # Image Max & Min Pooling with SIMD Acceleration
2//!
3//! A high-performance Rust library for maximum and minimum pooling operations on images,
4//! leveraging SIMD instructions (AVX2/NEON) and parallel processing for accelerated performance.
5//!
6//! ## Features
7//!
8//! - **SIMD Optimization**: Utilizes AVX2 (x86-64) or NEON (ARM) intrinsics
9//! - **Dual Pooling Operations**: Supports both maximum and minimum pooling
10//! - **Parallel Execution**: Multi-threaded processing via Rayon
11//! - **Dynamic CPU Detection**: Runtime checks for AVX2 support
12//!
13//! ## Quick Start
14//!
15//! ```rust
16//! use image_max_polling::{max_pooling_simd, min_pooling_simd};
17//!
18//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
19//! let width = 3;
20//! let factor = 3;
21//!
22//! // Maximum pooling - extracts brightest features
23//! let (_, _, max_result) = max_pooling_simd(&data, width, factor);
24//!
25//! // Minimum pooling - extracts darkest features
26//! let (_, _, min_result) = min_pooling_simd(&data, width, factor);
27//! ```
28
29#![allow(dead_code)]
30use rayon::prelude::*;
31use std::arch::x86_64::*;
32use std::mem;
33
34// 条件编译选择指令集
35#[cfg(target_arch = "x86_64")]
36use std::arch::x86_64 as arch;
37
38#[cfg(target_arch = "aarch64")]
39use std::arch::aarch64 as arch;
40
41// 动态检测 CPU 特性
42pub fn supports_avx2() -> bool {
43    #[cfg(target_arch = "x86_64")]
44    unsafe {
45        arch::_xgetbv(0) & 0x6 == 0x6
46    }
47
48    #[cfg(target_arch = "aarch64")]
49    true // ARM 默认启用 NEON
50}
51/// 使用 SIMD 指令优化的最大池化函数
52///
53/// # 详细说明
54/// 对输入的二维图像数据执行最大池化操作,使用 AVX2 SIMD 指令集并行处理以提升性能。
55/// 池化操作将输入图像按指定因子缩小,每个输出像素对应输入区域的最大值。
56///
57/// # 参数
58/// - `input`: 输入图像数据的一维数组表示(按行优先存储)
59/// - `width`: 输入图像的宽度(像素数)
60/// - `factor`: 池化因子,决定缩放比例(如 2 表示 2x2 池化)
61///
62/// # 返回值
63/// 返回元组 (output_width, output_height, output_data):
64/// - `output_width`: 输出图像宽度
65/// - `output_height`: 输出图像高度
66/// - `output_data`: 池化后的图像数据
67///
68/// # 示例
69/// ```rust
70/// use image_max_polling::max_pooling_simd;
71/// let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
72/// let (w, h, result) = max_pooling_simd(&input, 3, 1);
73/// ```
74pub fn max_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
75    // 计算输出图像尺寸
76    let output_width = width / factor;
77    let output_height = input.len() / (width * factor);
78    let mut output = vec![0; output_width * output_height];
79
80    // 使用 Rayon 进行分块并行处理,每个线程处理一行输出
81    output
82        .par_chunks_mut(output_width)
83        .enumerate()
84        .for_each(|(oy, row)| {
85            // 计算当前输出行对应的输入行范围
86            let start_y = oy * factor;
87            let end_y = start_y + factor;
88
89            // 遍历当前行的每个输出像素
90            (0..output_width).for_each(|ox| {
91                let start_x = ox * factor;
92
93                // 初始化 AVX2 SIMD 寄存器用于存储最大值
94                let mut simd_max = unsafe { _mm256_setzero_si256() };
95
96                // 遍历池化窗口内的所有行
97                for y in start_y..end_y {
98                    let row_start = y * width + start_x;
99                    let chunk = &input[row_start..row_start + factor];
100
101                    // 使用 SIMD 处理:每次加载 32 字节(256位)进行并行比较
102                    for chunk32 in chunk.chunks_exact(32) {
103                        let data =
104                            unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
105                        simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
106                    }
107
108                    // 处理不足 32 字节的剩余数据
109                    let remainder = chunk.chunks_exact(32).remainder();
110                    if !remainder.is_empty() {
111                        // 将剩余数据复制到对齐的缓冲区
112                        let mut buffer = [0u8; 32];
113                        buffer[..remainder.len()].copy_from_slice(remainder);
114                        let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
115                        simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
116                    }
117                }
118
119                // 从 SIMD 寄存器中提取最终的最大值
120                let mut max_val = 0;
121                let max_arr: &[u8; 32] = unsafe { mem::transmute(&simd_max) };
122                for &val in max_arr {
123                    if val > max_val {
124                        max_val = val;
125                    }
126                }
127                row[ox] = max_val; // 存储到输出数组
128            });
129        });
130    (output_width, output_height, output)
131}
132
133/// 使用 SIMD 指令优化的最小池化函数
134///
135/// # 详细说明
136/// 对输入的二维图像数据执行最小池化操作,使用 AVX2 SIMD 指令集并行处理以提升性能。
137/// 池化操作将输入图像按指定因子缩小,每个输出像素对应输入区域的最小值。
138///
139/// # 参数
140/// - `input`: 输入图像数据的一维数组表示(按行优先存储)
141/// - `width`: 输入图像的宽度(像素数)
142/// - `factor`: 池化因子,决定缩放比例(如 2 表示 2x2 池化)
143///
144/// # 返回值
145/// 返回元组 (output_width, output_height, output_data):
146/// - `output_width`: 输出图像宽度
147/// - `output_height`: 输出图像高度
148/// - `output_data`: 池化后的图像数据
149///
150/// # 示例
151/// ```rust
152/// use image_max_polling::min_pooling_simd;
153/// let input = vec![9, 8, 7, 6, 5, 4, 3, 2, 1];
154/// let (w, h, result) = min_pooling_simd(&input, 3, 1);
155/// ```
156pub fn min_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
157    // 计算输出图像尺寸
158    let output_width = width / factor;
159    let output_height = input.len() / (width * factor);
160    let mut output = vec![255; output_width * output_height]; // 初始化为最大值 255
161
162    // 使用 Rayon 进行分块并行处理,每个线程处理一行输出
163    output
164        .par_chunks_mut(output_width)
165        .enumerate()
166        .for_each(|(oy, row)| {
167            // 计算当前输出行对应的输入行范围
168            let start_y = oy * factor;
169            let end_y = start_y + factor;
170
171            // 遍历当前行的每个输出像素
172            (0..output_width).for_each(|ox| {
173                let start_x = ox * factor;
174
175                // 初始化 AVX2 SIMD 寄存器用于存储最小值(设为 255,即最大值)
176                let mut simd_min = unsafe { _mm256_set1_epi8(255u8 as i8) };
177
178                // 遍历池化窗口内的所有行
179                for y in start_y..end_y {
180                    let row_start = y * width + start_x;
181                    let chunk = &input[row_start..row_start + factor];
182
183                    // 使用 SIMD 处理:每次加载 32 字节(256位)进行并行比较
184                    for chunk32 in chunk.chunks_exact(32) {
185                        let data =
186                            unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
187                        simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
188                    }
189
190                    // 处理不足 32 字节的剩余数据
191                    let remainder = chunk.chunks_exact(32).remainder();
192                    if !remainder.is_empty() {
193                        // 将剩余数据复制到对齐的缓冲区,填充 255(最大值)
194                        let mut buffer = [255u8; 32];
195                        buffer[..remainder.len()].copy_from_slice(remainder);
196                        let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
197                        simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
198                    }
199                }
200
201                // 从 SIMD 寄存器中提取最终的最小值
202                let mut min_val = 255;
203                let min_arr: &[u8; 32] = unsafe { mem::transmute(&simd_min) };
204                for &val in min_arr {
205                    if val < min_val {
206                        min_val = val;
207                    }
208                }
209                row[ox] = min_val; // 存储到输出数组
210            });
211        });
212    (output_width, output_height, output)
213}
214
215#[cfg(test)]
216mod tests {
217    use super::*;
218
219    #[test]
220    fn test_max_pooling_simple() {
221        // 创建一个简单的 3x3 图像
222        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
223        let width = 3;
224        let factor = 3;
225
226        let (output_width, output_height, output) = max_pooling_simd(&input, width, factor);
227
228        assert_eq!(output_width, 1);
229        assert_eq!(output_height, 1);
230        assert_eq!(output.len(), 1);
231        assert_eq!(output[0], 9); // 最大值应该是 9
232    }
233
234    #[test]
235    fn test_min_pooling_simple() {
236        // 创建一个简单的 3x3 图像
237        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
238        let width = 3;
239        let factor = 3;
240
241        let (output_width, output_height, output) = min_pooling_simd(&input, width, factor);
242
243        assert_eq!(output_width, 1);
244        assert_eq!(output_height, 1);
245        assert_eq!(output.len(), 1);
246        assert_eq!(output[0], 1); // 最小值应该是 1
247    }
248
249    #[test]
250    fn test_max_pooling_2x2() {
251        // 创建一个 4x4 图像,进行 2x2 池化
252        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
253        let width = 4;
254        let factor = 2;
255
256        let (output_width, output_height, output) = max_pooling_simd(&input, width, factor);
257
258        assert_eq!(output_width, 2);
259        assert_eq!(output_height, 2);
260        assert_eq!(output.len(), 4);
261
262        // 每个 2x2 区域的最大值
263        assert_eq!(output[0], 6); // max(1,2,5,6)
264        assert_eq!(output[1], 8); // max(3,4,7,8)
265        assert_eq!(output[2], 14); // max(9,10,13,14)
266        assert_eq!(output[3], 16); // max(11,12,15,16)
267    }
268
269    #[test]
270    fn test_min_pooling_2x2() {
271        // 创建一个 4x4 图像,进行 2x2 池化
272        let input = vec![16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
273        let width = 4;
274        let factor = 2;
275
276        let (output_width, output_height, output) = min_pooling_simd(&input, width, factor);
277
278        assert_eq!(output_width, 2);
279        assert_eq!(output_height, 2);
280        assert_eq!(output.len(), 4);
281
282        // 每个 2x2 区域的最小值
283        assert_eq!(output[0], 11); // min(16,15,12,11)
284        assert_eq!(output[1], 9); // min(14,13,10,9)
285        assert_eq!(output[2], 3); // min(8,7,4,3)
286        assert_eq!(output[3], 1); // min(6,5,2,1)
287    }
288
289    #[test]
290    fn test_pooling_edge_values() {
291        // 测试边界值:0 和 255
292        let input = vec![
293            0, 255, 0, 255, 255, 0, 255, 0, 0, 255, 0, 255, 255, 0, 255, 0,
294        ];
295        let width = 4;
296        let factor = 2;
297
298        // 测试最大池化
299        let (_, _, max_output) = max_pooling_simd(&input, width, factor);
300        assert!(max_output.iter().all(|&x| x == 255));
301
302        // 测试最小池化
303        let (_, _, min_output) = min_pooling_simd(&input, width, factor);
304        assert!(min_output.iter().all(|&x| x == 0));
305    }
306
307    #[test]
308    fn test_pooling_identical_values() {
309        // 测试所有值相同的情况
310        let input = vec![100; 36]; // 6x6 全是 100
311        let width = 6;
312        let factor = 3;
313
314        let (max_w, max_h, max_output) = max_pooling_simd(&input, width, factor);
315        let (min_w, min_h, min_output) = min_pooling_simd(&input, width, factor);
316
317        assert_eq!(max_w, min_w);
318        assert_eq!(max_h, min_h);
319        assert_eq!(max_output.len(), min_output.len());
320
321        // 所有值都应该是 100
322        assert!(max_output.iter().all(|&x| x == 100));
323        assert!(min_output.iter().all(|&x| x == 100));
324    }
325}