#![allow(dead_code)]
use rayon::prelude::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::mem;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64 as arch;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64 as arch;
/// Reports whether the executing CPU can run AVX2 instructions.
///
/// Delegates to the standard library's runtime feature detection. That check
/// accounts for both CPU support and operating-system support for the 256-bit
/// register state, and it is safe to call on any x86 CPU. AVX2 is an x86
/// extension, so this always returns `false` on other architectures.
pub fn supports_avx2() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        false
    }
}
pub fn max_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
let output_width = width / factor;
let output_height = input.len() / (width * factor);
let mut output = vec![0; output_width * output_height];
output
.par_chunks_mut(output_width)
.enumerate()
.for_each(|(oy, row)| {
let start_y = oy * factor;
let end_y = start_y + factor;
(0..output_width).for_each(|ox| {
let start_x = ox * factor;
let mut simd_max = unsafe { _mm256_setzero_si256() };
for y in start_y..end_y {
let row_start = y * width + start_x;
let chunk = &input[row_start..row_start + factor];
for chunk32 in chunk.chunks_exact(32) {
let data =
unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
}
let remainder = chunk.chunks_exact(32).remainder();
if !remainder.is_empty() {
let mut buffer = [0u8; 32];
buffer[..remainder.len()].copy_from_slice(remainder);
let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
simd_max = unsafe { _mm256_max_epu8(simd_max, data) };
}
}
let mut max_val = 0;
let max_arr: &[u8; 32] = unsafe { mem::transmute(&simd_max) };
for &val in max_arr {
if val > max_val {
max_val = val;
}
}
row[ox] = max_val; });
});
(output_width, output_height, output)
}
pub fn min_pooling_simd(input: &[u8], width: usize, factor: usize) -> (usize, usize, Vec<u8>) {
let output_width = width / factor;
let output_height = input.len() / (width * factor);
let mut output = vec![255; output_width * output_height];
output
.par_chunks_mut(output_width)
.enumerate()
.for_each(|(oy, row)| {
let start_y = oy * factor;
let end_y = start_y + factor;
(0..output_width).for_each(|ox| {
let start_x = ox * factor;
let mut simd_min = unsafe { _mm256_set1_epi8(255u8 as i8) };
for y in start_y..end_y {
let row_start = y * width + start_x;
let chunk = &input[row_start..row_start + factor];
for chunk32 in chunk.chunks_exact(32) {
let data =
unsafe { _mm256_loadu_si256(chunk32.as_ptr() as *const __m256i) };
simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
}
let remainder = chunk.chunks_exact(32).remainder();
if !remainder.is_empty() {
let mut buffer = [255u8; 32];
buffer[..remainder.len()].copy_from_slice(remainder);
let data = unsafe { _mm256_loadu_si256(buffer.as_ptr() as *const __m256i) };
simd_min = unsafe { _mm256_min_epu8(simd_min, data) };
}
}
let mut min_val = 255;
let min_arr: &[u8; 32] = unsafe { mem::transmute(&simd_min) };
for &val in min_arr {
if val < min_val {
min_val = val;
}
}
row[ox] = min_val; });
});
(output_width, output_height, output)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_pooling_simple() {
        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let width = 3;
        let factor = 3;
        let (output_width, output_height, output) = max_pooling_simd(&input, width, factor);
        assert_eq!(output_width, 1);
        assert_eq!(output_height, 1);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], 9);
    }

    #[test]
    fn test_min_pooling_simple() {
        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let width = 3;
        let factor = 3;
        let (output_width, output_height, output) = min_pooling_simd(&input, width, factor);
        assert_eq!(output_width, 1);
        assert_eq!(output_height, 1);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], 1);
    }

    #[test]
    fn test_max_pooling_2x2() {
        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let width = 4;
        let factor = 2;
        let (output_width, output_height, output) = max_pooling_simd(&input, width, factor);
        assert_eq!(output_width, 2);
        assert_eq!(output_height, 2);
        assert_eq!(output, vec![6, 8, 14, 16]);
    }

    #[test]
    fn test_min_pooling_2x2() {
        let input = vec![16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let width = 4;
        let factor = 2;
        let (output_width, output_height, output) = min_pooling_simd(&input, width, factor);
        assert_eq!(output_width, 2);
        assert_eq!(output_height, 2);
        assert_eq!(output, vec![11, 9, 3, 1]);
    }

    #[test]
    fn test_pooling_edge_values() {
        let input = vec![
            0, 255, 0, 255, 255, 0, 255, 0, 0, 255, 0, 255, 255, 0, 255, 0,
        ];
        let width = 4;
        let factor = 2;
        let (_, _, max_output) = max_pooling_simd(&input, width, factor);
        assert!(max_output.iter().all(|&x| x == 255));
        let (_, _, min_output) = min_pooling_simd(&input, width, factor);
        assert!(min_output.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_pooling_identical_values() {
        let input = vec![100; 36];
        let width = 6;
        let factor = 3;
        let (max_w, max_h, max_output) = max_pooling_simd(&input, width, factor);
        let (min_w, min_h, min_output) = min_pooling_simd(&input, width, factor);
        assert_eq!(max_w, min_w);
        assert_eq!(max_h, min_h);
        assert_eq!(max_output.len(), min_output.len());
        assert!(max_output.iter().all(|&x| x == 100));
        assert!(min_output.iter().all(|&x| x == 100));
    }

    /// factor = 40: each window row is one full 32-byte vector plus an 8-byte tail.
    /// The extreme values are placed in each part in turn, so both code paths are
    /// exercised for both reductions.
    #[test]
    fn test_pooling_wide_window_covers_vector_and_tail() {
        let width = 40;
        let factor = 40;

        let mut input = vec![50u8; width * factor];
        input[3 * width + 5] = 200; // inside the 32-byte vector part
        input[7 * width + 36] = 2; // inside the 8-byte tail
        let (w, h, max_out) = max_pooling_simd(&input, width, factor);
        assert_eq!((w, h), (1, 1));
        assert_eq!(max_out, vec![200]);
        assert_eq!(min_pooling_simd(&input, width, factor).2, vec![2]);

        let mut swapped = vec![50u8; width * factor];
        swapped[3 * width + 5] = 2;
        swapped[7 * width + 36] = 200;
        assert_eq!(max_pooling_simd(&swapped, width, factor).2, vec![200]);
        assert_eq!(min_pooling_simd(&swapped, width, factor).2, vec![2]);
    }

    /// A 5x5 image pooled by 2 must ignore the trailing column and row.
    /// They hold the largest values, so including them would change the max.
    #[test]
    fn test_pooling_ignores_partial_blocks() {
        let input: Vec<u8> = (0..25).collect();
        let (w, h, max_out) = max_pooling_simd(&input, 5, 2);
        assert_eq!((w, h), (2, 2));
        assert_eq!(max_out, vec![6, 8, 16, 18]);
        let (_, _, min_out) = min_pooling_simd(&input, 5, 2);
        assert_eq!(min_out, vec![0, 2, 10, 12]);
    }
}