#![cfg(target_arch = "x86_64")]
use crate::int8::QuantParams;
use std::arch::x86_64::*;
/// AVX2-gated entry point for the int8 matrix multiply.
///
/// The body forwards every argument unchanged to
/// `super::scalar::matmul_int8_scalar` and returns whatever that returns; no
/// AVX2 intrinsics are issued in this function itself.
///
/// NOTE(review): `m`, `n` and `k` are presumably the matrix dimensions (the
/// test in this file multiplies two 16-element all-ones buffers with
/// `4, 4, 4` and expects every output to be 4) — confirm against the scalar
/// kernel, which defines the actual layout and the use of `params`.
///
/// # Safety
///
/// This function is compiled with `#[target_feature(enable = "avx2")]`, so
/// the caller must make sure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` first.
#[target_feature(enable = "avx2")]
pub unsafe fn matmul_int8_simd(
    a: &[i8],
    b: &[i8],
    params: QuantParams,
    m: usize,
    n: usize,
    k: usize,
) -> Vec<i32> {
    // Pure delegation: the scalar kernel is the single source of truth.
    super::scalar::matmul_int8_scalar(a, b, params, m, n, k)
}
/// AVX2-gated entry point for the int8 2-D convolution.
///
/// The body forwards every argument unchanged to
/// `super::scalar::conv2d_int8_scalar` and returns whatever that returns; no
/// AVX2 intrinsics are issued in this function itself.
///
/// NOTE(review): `h`, `w`, `c` look like input height, width and channel
/// count, and `k` like the square kernel size (the test in this file passes a
/// `5 * 5 * 1` input with a `3 * 3 * 1` kernel, stride 1, and expects 9
/// outputs) — confirm against the scalar kernel, which defines the actual
/// layout and the use of `params`.
///
/// # Safety
///
/// This function is compiled with `#[target_feature(enable = "avx2")]`, so
/// the caller must make sure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` first.
#[target_feature(enable = "avx2")]
pub unsafe fn conv2d_int8_simd(
    input: &[i8],
    kernel: &[i8],
    params: QuantParams,
    h: usize,
    w: usize,
    c: usize,
    k: usize,
    stride: usize,
) -> Vec<i32> {
    // Pure delegation: the scalar kernel is the single source of truth.
    super::scalar::conv2d_int8_scalar(input, kernel, params, h, w, c, k, stride)
}
/// Computes the dot product of two equal-length `i8` slices, widening every
/// product to `i32`.
///
/// Sixteen elements are consumed per vector step: 16 bytes of each operand
/// are loaded, sign-extended to sixteen `i16` lanes, and multiplied and
/// pairwise-added into eight `i32` lanes of a running accumulator. The
/// accumulator is reduced once after the loop, and any remaining (< 16)
/// elements are handled by a scalar tail.
///
/// The vector lanes accumulate with wrapping arithmetic; the final value is
/// exact whenever the true dot product fits in an `i32`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// This function is compiled with `#[target_feature(enable = "avx2")]`, so
/// the caller must make sure the running CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(dead_code)] // Not called by the public kernels in this file.
unsafe fn dot_product_int8_avx2(a: &[i8], b: &[i8]) -> i32 {
    /// Number of `i8` elements consumed per vector step.
    const LANES: usize = 16;

    assert_eq!(a.len(), b.len());
    let len = a.len();

    let mut acc = _mm256_setzero_si256();
    let mut i = 0;
    while i + LANES <= len {
        // SAFETY: `i + 16 <= len` and both slices have length `len`, so each
        // unaligned 128-bit load reads exactly 16 in-bounds bytes.
        let a_bytes = _mm_loadu_si128(a.as_ptr().add(i) as *const __m128i);
        let b_bytes = _mm_loadu_si128(b.as_ptr().add(i) as *const __m128i);

        let a_wide = _mm256_cvtepi8_epi16(a_bytes);
        let b_wide = _mm256_cvtepi8_epi16(b_bytes);

        // `madd` multiplies the i16 lanes and sums adjacent pairs into i32
        // lanes. Inputs are sign-extended i8 values, so each pair sum is at
        // most 2 * 128 * 128 and cannot overflow.
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_wide, b_wide));
        i += LANES;
    }

    // Single horizontal reduction of the eight i32 lanes.
    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    let mut sum = lanes
        .iter()
        .fold(0i32, |total, &lane| total.wrapping_add(lane));

    // Scalar tail for the final `len % 16` elements.
    for (&x, &y) in a[i..].iter().zip(&b[i..]) {
        sum += i32::from(x) * i32::from(y);
    }
    sum
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantization parameters used by every test here: scale 1.0, zero point 0.
    fn unit_params() -> QuantParams {
        QuantParams {
            scale: 1.0,
            zero_point: 0,
        }
    }

    /// All-ones inputs with dimensions 4, 4, 4 must yield 4 in every output.
    #[test]
    fn test_matmul_simd_available() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping SIMD test");
            return;
        }
        let a = vec![1i8; 16];
        let b = vec![1i8; 16];
        // SAFETY: AVX2 support was verified at runtime just above.
        let c = unsafe { matmul_int8_simd(&a, &b, unit_params(), 4, 4, 4) };
        // Guard against `all` passing vacuously on an empty result.
        assert!(!c.is_empty(), "matmul returned no output elements");
        assert!(c.iter().all(|&x| x == 4));
    }

    /// A 5x5x1 all-ones input with a 3x3x1 all-ones kernel at stride 1 must
    /// yield nine outputs, each equal to 9.
    #[test]
    fn test_conv2d_simd_available() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping SIMD test");
            return;
        }
        let input = vec![1i8; 5 * 5 * 1];
        let kernel = vec![1i8; 3 * 3 * 1];
        // SAFETY: AVX2 support was verified at runtime just above.
        let output =
            unsafe { conv2d_int8_simd(&input, &kernel, unit_params(), 5, 5, 1, 3, 1) };
        assert_eq!(output.len(), 3 * 3);
        assert!(output.iter().all(|&x| x == 9));
    }

    /// The AVX2 dot product must agree with a plain scalar reference for
    /// lengths on both sides of the vector-width boundaries.
    #[test]
    fn test_dot_product_matches_scalar() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping SIMD test");
            return;
        }
        for &len in &[0usize, 1, 15, 16, 17, 31, 32, 33, 64, 100] {
            let a: Vec<i8> = (0..len).map(|i| ((i * 37 + 11) % 256) as u8 as i8).collect();
            let b: Vec<i8> = (0..len).map(|i| ((i * 91 + 200) % 256) as u8 as i8).collect();
            let expected: i32 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| i32::from(x) * i32::from(y))
                .sum();
            // SAFETY: AVX2 support was verified at runtime just above.
            let got = unsafe { dot_product_int8_avx2(&a, &b) };
            assert_eq!(got, expected, "mismatch for len = {}", len);
        }
    }
}