use crate::ScalarF4E4;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
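/// Packs fifteen i16 values into fifteen consecutive 17-bit fields of a
/// 256-bit register (15 * 17 = 255 bits); fields may straddle 32-bit words.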
#[allow(dead_code)]
#[inline]
unsafe fn pack_15_i16_to_17bit(values: &[i16; 15]) -> __m256i {
    let mut packed = [0u32; 8];
    for (i, &v) in values.iter().enumerate() {
        let bits = (v as u32) & 0x1FFFF; // 17-bit two's-complement payload (sign-extended i16)
        let bit = i * 17; // each value occupies a 17-bit field at bit offset i * 17
        let (word, off) = (bit / 32, bit % 32);
        packed[word] |= bits << off;
        if off > 15 {
            // The field straddles the 32-bit word boundary; spill the high bits.
            packed[word + 1] |= bits >> (32 - off);
        }
    }
    _mm256_loadu_si256(packed.as_ptr() as *const __m256i)
}
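/// Inverse of `pack_15_i16_to_17bit`: extracts the fifteen 17-bit fields and
/// truncates each back to i16.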
#[allow(dead_code)]
#[inline]
unsafe fn unpack_17bit_to_15_i16(packed: __m256i) -> [i16; 15] {
    let mut unpacked = [0i16; 15];
    let words: [u32; 8] = core::mem::transmute(packed);
    for (i, out) in unpacked.iter_mut().enumerate() {
        let bit = i * 17;
        let (word, off) = (bit / 32, bit % 32);
        let mut field = words[word] >> off;
        if off > 15 {
            // The field straddles the word boundary; pull in the spilled high bits.
            field |= words[word + 1] << (32 - off);
        }
        // A 17-bit field holding a sign-extended i16 truncates back losslessly.
        *out = field as i16;
    }
    unpacked
}
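/// Subtracts eight `ScalarF4E4` values lane-wise with AVX2. Each 32-bit lane
/// holds the fraction in its low i16 and the exponent in its high i16; the
/// operand with the smaller exponent has its fraction shifted right to align
/// before the subtraction, and the result keeps the larger exponent. The
/// repacked fraction saturates to the i16 range.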
#[target_feature(enable = "avx2")]
pub unsafe fn scalar_subtract_batch_avx2_8op(
a: &[ScalarF4E4; 8],
b: &[ScalarF4E4; 8],
result: &mut [ScalarF4E4; 8],
) {
let a_vec = _mm256_loadu_si256(a.as_ptr() as *const __m256i);
let b_vec = _mm256_loadu_si256(b.as_ptr() as *const __m256i);
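    // Sign-extend the low i16 of each lane (fraction) and the high i16 (exponent).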
let a_frac = _mm256_srai_epi32(_mm256_slli_epi32(a_vec, 16), 16);
let b_frac = _mm256_srai_epi32(_mm256_slli_epi32(b_vec, 16), 16);
let a_exp = _mm256_srai_epi32(a_vec, 16);
let b_exp = _mm256_srai_epi32(b_vec, 16);
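    // Align radix points: the operand with the smaller exponent gets its
    // fraction shifted right by the exponent difference.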
let exp_diff = _mm256_sub_epi32(a_exp, b_exp);
let b_gt_a = _mm256_cmpgt_epi32(b_exp, a_exp);
let shift_amount = _mm256_abs_epi32(exp_diff);
let a_frac_shifted = _mm256_srav_epi32(a_frac, shift_amount);
let b_frac_shifted = _mm256_srav_epi32(b_frac, shift_amount);
let frac_a_final = _mm256_blendv_epi8(a_frac, a_frac_shifted, b_gt_a);
let frac_b_final = _mm256_blendv_epi8(b_frac_shifted, b_frac, b_gt_a);
let result_frac = _mm256_sub_epi32(frac_a_final, frac_b_final);
let result_exp = _mm256_blendv_epi8(a_exp, b_exp, b_gt_a);
let result_frac_i16 = _mm256_packs_epi32(result_frac, result_frac);
let result_exp_i16 = _mm256_packs_epi32(result_exp, result_exp);
    // packs(x, x) duplicates each 128-bit half, so a single unpacklo already
    // re-interleaves fraction (low i16) and exponent (high i16) for all eight
    // lanes in element order.
    let result_vec = _mm256_unpacklo_epi16(result_frac_i16, result_exp_i16);
    _mm256_storeu_si256(result.as_mut_ptr() as *mut __m256i, result_vec);
}
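/// Variant intended to process 15-element chunks via the 17-bit packed layout.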
#[target_feature(enable = "avx2")]
pub unsafe fn scalar_subtract_batch_avx2_15op_shift(
a: &[ScalarF4E4],
b: &[ScalarF4E4],
result: &mut [ScalarF4E4],
) {
    // NOTE: the 17-bit pack/unpack helpers above are not wired into this path
    // yet; each 15-element chunk is currently subtracted element-wise.
    let chunks = a.len() / 15;
    for i in 0..chunks {
        let idx = i * 15;
        for j in 0..15 {
            result[idx + j] = a[idx + j] - b[idx + j];
        }
    }
for i in (chunks * 15)..a.len() {
result[i] = a[i] - b[i];
}
}
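/// Variant reserved for a BMI2 (PDEP/PEXT) bit-packing strategy.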
#[target_feature(enable = "avx2,bmi2")]
pub unsafe fn scalar_subtract_batch_avx2_15op_bmi2(
a: &[ScalarF4E4],
b: &[ScalarF4E4],
result: &mut [ScalarF4E4],
) {
    // A dedicated BMI2 packing path is not implemented yet; delegate to the
    // shift-based variant.
    scalar_subtract_batch_avx2_15op_shift(a, b, result);
}
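/// Subtracts `a - b` element-wise, dispatching 8-element chunks to the AVX2
/// kernel and handling any remainder with scalar subtraction.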
#[target_feature(enable = "avx2")]
pub unsafe fn scalar_subtract_batch_avx2(
a: &[ScalarF4E4],
b: &[ScalarF4E4],
result: &mut [ScalarF4E4],
) {
let chunks = a.len() / 8;
for i in 0..chunks {
let idx = i * 8;
        // Reborrow each 8-element window as a fixed-size array; writes through
        // `result_chunk` must land in `result`, not in a temporary copy.
        let a_chunk: &[ScalarF4E4; 8] = a[idx..idx + 8].try_into().unwrap();
        let b_chunk: &[ScalarF4E4; 8] = b[idx..idx + 8].try_into().unwrap();
        let result_chunk: &mut [ScalarF4E4; 8] = (&mut result[idx..idx + 8]).try_into().unwrap();
scalar_subtract_batch_avx2_8op(a_chunk, b_chunk, result_chunk);
}
for i in (chunks * 8)..a.len() {
result[i] = a[i] - b[i];
}
}
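/// SSE4.2 entry point; currently identical to the scalar fallback.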
#[target_feature(enable = "sse4.2")]
pub unsafe fn scalar_subtract_batch_sse42(
a: &[ScalarF4E4],
b: &[ScalarF4E4],
result: &mut [ScalarF4E4],
) {
    // No SSE4.2-specific kernel yet; fall through to the scalar implementation.
    scalar_subtract_fallback(a, b, result);
}
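/// Portable scalar fallback: element-wise subtraction with no SIMD.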
fn scalar_subtract_fallback(a: &[ScalarF4E4], b: &[ScalarF4E4], result: &mut [ScalarF4E4]) {
for i in 0..a.len() {
result[i] = a[i] - b[i];
}
}
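// A minimal round-trip sketch for the 17-bit packing helpers. It assumes the
// tests run on an x86_64 host with `std` available (for runtime feature
// detection). Every i16 should survive pack -> unpack unchanged, since each
// 17-bit field stores a sign-extended 16-bit value.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    #[test]
    fn pack_unpack_17bit_round_trip() {
        if !is_x86_feature_detected!("avx2") {
            return; // host lacks AVX2; nothing to check
        }
        let values: [i16; 15] = [
            0, 1, -1, i16::MAX, i16::MIN, 42, -42, 12345, -12345, 7, -7, 255, -256, 100, -100,
        ];
        let round_tripped = unsafe { unpack_17bit_to_15_i16(pack_15_i16_to_17bit(&values)) };
        assert_eq!(values, round_tripped);
    }
}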