1#![allow(dead_code)]
30use rayon::prelude::*;
31use std::arch::x86_64::*;
32use std::mem;
33
34#[cfg(target_arch = "x86_64")]
36use std::arch::x86_64 as arch;
37
38#[cfg(target_arch = "aarch64")]
39use std::arch::aarch64 as arch;
40
/// Reports whether the running machine can execute AVX2 instructions.
///
/// On `x86_64` this defers to the standard library's runtime feature detection,
/// which checks the CPU's AVX2 feature bit as well as operating-system support
/// for the wide register state. Reading XCR0 directly is not sufficient: it
/// only describes which register state the OS has enabled, says nothing about
/// AVX2 itself, and the `xgetbv` instruction faults when OSXSAVE is disabled.
///
/// AVX2 is an x86 extension, so every other architecture reports `false`.
pub fn supports_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
51pub fn max_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
75 let output_width = width / factor;
77 let output_height = input.len() / (width * factor);
78 let mut output = vec![0; output_width * output_height];
79
80 output
82 .par_chunks_mut(output_width)
83 .enumerate()
84 .for_each(|(oy, row)| {
85 let start_y = oy * factor;
87 let end_y = start_y + factor;
88
89 (0..output_width).for_each(|ox| {
91 let start_x = ox * factor;
92
93 let mut simd_max = unsafe { _mm256_setzero_si256() };
95
96 for y in start_y..end_y {
98 let row_start = y * width + start_x;
99 let chunk = &input[row_start..row_start + factor];
100
101 for chunk32 in chunk.chunks_exact(32) {
103 let data =
104 unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
105 simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
106 }
107
108 let remainder = chunk.chunks_exact(32).remainder();
110 if !remainder.is_empty() {
111 let mut buffer = [0u8; 32];
113 buffer[..remainder.len()].copy_from_slice(remainder);
114 let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
115 simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
116 }
117 }
118
119 let mut max_val = 0;
121 let max_arr: &[u8; 32] = unsafe { mem::transmute(&simd_max) };
122 for &val in max_arr {
123 if val > max_val {
124 max_val = val;
125 }
126 }
127 row[ox] = max_val; });
129 });
130 (output_width, output_height, output)
131}
132
133pub fn min_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
157 let output_width = width / factor;
159 let output_height = input.len() / (width * factor);
160 let mut output = vec![255; output_width * output_height]; output
164 .par_chunks_mut(output_width)
165 .enumerate()
166 .for_each(|(oy, row)| {
167 let start_y = oy * factor;
169 let end_y = start_y + factor;
170
171 (0..output_width).for_each(|ox| {
173 let start_x = ox * factor;
174
175 let mut simd_min = unsafe { _mm256_set1_epi8(255u8 as i8) };
177
178 for y in start_y..end_y {
180 let row_start = y * width + start_x;
181 let chunk = &input[row_start..row_start + factor];
182
183 for chunk32 in chunk.chunks_exact(32) {
185 let data =
186 unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
187 simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
188 }
189
190 let remainder = chunk.chunks_exact(32).remainder();
192 if !remainder.is_empty() {
193 let mut buffer = [255u8; 32];
195 buffer[..remainder.len()].copy_from_slice(remainder);
196 let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
197 simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
198 }
199 }
200
201 let mut min_val = 255;
203 let min_arr: &[u8; 32] = unsafe { mem::transmute(&simd_min) };
204 for &val in min_arr {
205 if val < min_val {
206 min_val = val;
207 }
208 }
209 row[ox] = min_val; });
211 });
212 (output_width, output_height, output)
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// One 3x3 window spanning the whole image keeps only the largest byte.
    #[test]
    fn test_max_pooling_simple() {
        let pixels: Vec<u8> = (1..=9).collect();

        let (out_w, out_h, pooled) = max_pooling_simd(&pixels, 3, 3);

        assert_eq!((out_w, out_h), (1, 1));
        assert_eq!(pooled, vec![9u8]);
    }

    /// One 3x3 window spanning the whole image keeps only the smallest byte.
    #[test]
    fn test_min_pooling_simple() {
        let pixels: Vec<u8> = (1..=9).collect();

        let (out_w, out_h, pooled) = min_pooling_simd(&pixels, 3, 3);

        assert_eq!((out_w, out_h), (1, 1));
        assert_eq!(pooled, vec![1u8]);
    }

    /// A 4x4 ascending ramp pooled 2x2 yields the bottom-right byte of each window.
    #[test]
    fn test_max_pooling_2x2() {
        let pixels: Vec<u8> = (1..=16).collect();

        let (out_w, out_h, pooled) = max_pooling_simd(&pixels, 4, 2);

        assert_eq!((out_w, out_h), (2, 2));
        assert_eq!(pooled, vec![6u8, 8, 14, 16]);
    }

    /// A 4x4 descending ramp pooled 2x2 yields the bottom-right byte of each window.
    #[test]
    fn test_min_pooling_2x2() {
        let pixels: Vec<u8> = (1..=16).rev().collect();

        let (out_w, out_h, pooled) = min_pooling_simd(&pixels, 4, 2);

        assert_eq!((out_w, out_h), (2, 2));
        assert_eq!(pooled, vec![11u8, 9, 3, 1]);
    }

    /// A 0/255 checkerboard puts both extremes in every 2x2 window.
    #[test]
    fn test_pooling_edge_values() {
        let checkerboard: [u8; 16] = [
            0, 255, 0, 255, //
            255, 0, 255, 0, //
            0, 255, 0, 255, //
            255, 0, 255, 0,
        ];

        let (_, _, brightest) = max_pooling_simd(&checkerboard, 4, 2);
        for &value in &brightest {
            assert_eq!(value, 255);
        }

        let (_, _, darkest) = min_pooling_simd(&checkerboard, 4, 2);
        for &value in &darkest {
            assert_eq!(value, 0);
        }
    }

    /// On a constant image both poolings agree on shape and on every value.
    #[test]
    fn test_pooling_identical_values() {
        let flat = [100u8; 36];

        let (max_w, max_h, max_output) = max_pooling_simd(&flat, 6, 3);
        let (min_w, min_h, min_output) = min_pooling_simd(&flat, 6, 3);

        assert_eq!((max_w, max_h), (min_w, min_h));
        assert_eq!(max_output.len(), min_output.len());

        for &value in max_output.iter().chain(min_output.iter()) {
            assert_eq!(value, 100);
        }
    }
}