use super::{A64Bytes, A8Bytes, ArrayLength};
use core::arch::asm;
/// Overwrites `*dest` with `*src` when `condition` is true, using a `cmovnz`
/// instruction instead of a branch.
///
/// When `condition` is false, `*dest` keeps the value it already had.
#[inline]
pub fn cmov_u32(condition: bool, src: &u32, dest: &mut u32) {
    let flag = condition as u8;
    let replacement = *src;
    // SAFETY: the assembly only operates on the three registers bound below
    // (and the flags register); it performs no memory access of its own.
    unsafe {
        asm!(
            // ZF = 1 exactly when the flag byte is zero.
            "test {flag}, {flag}",
            // Take the replacement only when ZF = 0. `:e` formats the operand
            // as its 32-bit register name.
            "cmovnz {current:e}, {replacement:e}",
            flag = in(reg_byte) flag,
            replacement = in(reg) replacement,
            // `inout`: the previous value must survive when no move happens.
            current = inout(reg) *dest,
        );
    }
}
/// Overwrites `*dest` with `*src` when `condition` is true, using a `cmovnz`
/// instruction instead of a branch.
///
/// When `condition` is false, `*dest` keeps the value it already had.
///
/// NOTE(review): binding a `u64` to the `reg` class presumably restricts this
/// to 64-bit x86 targets — confirm the module is only built there.
#[inline]
pub fn cmov_u64(condition: bool, src: &u64, dest: &mut u64) {
    let flag = condition as u8;
    let replacement = *src;
    // SAFETY: the assembly only operates on the three registers bound below
    // (and the flags register); it performs no memory access of its own.
    unsafe {
        asm!(
            // ZF = 1 exactly when the flag byte is zero.
            "test {flag}, {flag}",
            // Take the replacement only when ZF = 0.
            "cmovnz {current}, {replacement}",
            flag = in(reg_byte) flag,
            replacement = in(reg) replacement,
            // `inout`: the previous value must survive when no move happens.
            current = inout(reg) *dest,
        );
    }
}
/// Overwrites `*dest` with `*src` when `condition` is true by delegating to
/// [`cmov_u64`]; otherwise `*dest` is left holding its previous value.
///
/// Only compiled for targets with a 64-bit pointer width.
#[inline]
#[cfg(target_pointer_width = "64")]
pub fn cmov_usize(condition: bool, src: &usize, dest: &mut usize) {
    // Work on copies instead of transmuting the references: under this cfg
    // `usize` is 64 bits wide, so both `as` conversions are lossless and no
    // `unsafe` is required.
    let mut word = *dest as u64;
    cmov_u64(condition, &(*src as u64), &mut word);
    // Written back unconditionally; `word` still equals the old value when
    // `condition` was false.
    *dest = word as usize;
}
/// Overwrites `*dest` with `*src` when `condition` is true by delegating to
/// [`cmov_u32`]; otherwise `*dest` is left holding its previous value.
///
/// Only compiled for targets with a 32-bit pointer width.
#[inline]
#[cfg(target_pointer_width = "32")]
pub fn cmov_usize(condition: bool, src: &usize, dest: &mut usize) {
    // Work on copies instead of transmuting the references: under this cfg
    // `usize` is 32 bits wide, so both `as` conversions are lossless and no
    // `unsafe` is required.
    let mut word = *dest as u32;
    cmov_u32(condition, &(*src as u32), &mut word);
    // Written back unconditionally; `word` still equals the old value when
    // `condition` was false.
    *dest = word as usize;
}
/// Overwrites `*dest` with `*src` when `condition` is true by delegating to
/// [`cmov_u32`]; otherwise `*dest` is left holding its previous value.
#[inline]
pub fn cmov_i32(condition: bool, src: &i32, dest: &mut i32) {
    // `i32 as u32` and back reinterpret the same 32 bits, so the round trip
    // is exact and avoids an `unsafe` transmute between reference types.
    let mut bits = *dest as u32;
    cmov_u32(condition, &(*src as u32), &mut bits);
    // Written back unconditionally; `bits` still equals the old value when
    // `condition` was false.
    *dest = bits as i32;
}
/// Overwrites `*dest` with `*src` when `condition` is true by delegating to
/// [`cmov_u64`]; otherwise `*dest` is left holding its previous value.
#[inline]
pub fn cmov_i64(condition: bool, src: &i64, dest: &mut i64) {
    // `i64 as u64` and back reinterpret the same 64 bits, so the round trip
    // is exact and avoids an `unsafe` transmute between reference types.
    let mut bits = *dest as u64;
    cmov_u64(condition, &(*src as u64), &mut bits);
    // Written back unconditionally; `bits` still equals the old value when
    // `condition` was false.
    *dest = bits as i64;
}
/// Conditionally copies `src` over `dest`, handing the buffers to
/// `cmov_byte_slice_a8` as a sequence of 8-byte words.
///
/// Does nothing when `N` is zero. Otherwise the word count is `N / 8`,
/// rounded up.
#[inline]
pub fn cmov_a8_bytes<N: ArrayLength<u8>>(condition: bool, src: &A8Bytes<N>, dest: &mut A8Bytes<N>) {
    if N::USIZE == 0 {
        return;
    }
    // Number of 8-byte words needed to cover N bytes (ceiling division).
    let words = N::USIZE / 8 + usize::from(N::USIZE % 8 != 0);
    let src_words = (src as *const A8Bytes<N>).cast::<u64>();
    let dest_words = (dest as *mut A8Bytes<N>).cast::<u64>();
    // SAFETY: both pointers come from live references and `words` is nonzero.
    // NOTE(review): this presumably relies on `A8Bytes<N>` being 8-byte
    // aligned with its size padded to a multiple of 8, so that rounding the
    // word count up stays inside the object — confirm against the type
    // definition.
    unsafe { cmov_byte_slice_a8(condition, src_words, dest_words, words) };
}
/// Conditionally copies `src` over `dest`, handing the buffers to
/// `cmov_byte_slice_a8` as a sequence of 8-byte words.
///
/// This variant is compiled when the `avx2` target feature is not enabled.
/// Does nothing when `N` is zero. Otherwise the word count is `N / 8`,
/// rounded up.
#[cfg(not(target_feature = "avx2"))]
#[inline]
pub fn cmov_a64_bytes<N: ArrayLength<u8>>(
    condition: bool,
    src: &A64Bytes<N>,
    dest: &mut A64Bytes<N>,
) {
    if N::USIZE == 0 {
        return;
    }
    // Number of 8-byte words needed to cover N bytes (ceiling division).
    let words = N::USIZE / 8 + usize::from(N::USIZE % 8 != 0);
    let src_words = (src as *const A64Bytes<N>).cast::<u64>();
    let dest_words = (dest as *mut A64Bytes<N>).cast::<u64>();
    // SAFETY: both pointers come from live references and `words` is nonzero.
    // NOTE(review): this presumably relies on `A64Bytes<N>` being at least
    // 8-byte aligned with its size padded so that rounding the word count up
    // stays inside the object — confirm against the type definition.
    unsafe { cmov_byte_slice_a8(condition, src_words, dest_words, words) };
}
/// Conditionally copies `src` over `dest`, handing the buffers to
/// `cmov_byte_slice_a64` in 64-byte chunks.
///
/// This variant is compiled when the `avx2` target feature is enabled.
/// Does nothing when `N` is zero. Otherwise the byte length passed on is `N`
/// rounded up to the next multiple of 64.
#[cfg(target_feature = "avx2")]
#[inline]
pub fn cmov_a64_bytes<N: ArrayLength<u8>>(
    condition: bool,
    src: &A64Bytes<N>,
    dest: &mut A64Bytes<N>,
) {
    if N::USIZE == 0 {
        return;
    }
    // Number of 64-byte chunks needed to cover N bytes (ceiling division).
    let chunks = N::USIZE / 64 + usize::from(N::USIZE % 64 != 0);
    let padded_len = chunks * 64;
    let src_words = (src as *const A64Bytes<N>).cast::<u64>();
    let dest_words = (dest as *mut A64Bytes<N>).cast::<u64>();
    // SAFETY: both pointers come from live references, and `padded_len` is a
    // nonzero multiple of 64.
    // NOTE(review): this presumably relies on `A64Bytes<N>` being 64-byte
    // aligned with its size padded to a multiple of 64, so that the rounded-up
    // length stays inside the object — confirm against the type definition.
    unsafe { cmov_byte_slice_a64(condition, src_words, dest_words, padded_len) };
}
/// Conditionally copies `count` 8-byte words from `src` to `dest` using
/// `cmovc`, with no branch on `condition`.
///
/// Every word of `dest` is loaded and stored back regardless of `condition`;
/// only the value stored differs. Words are visited from the last one down to
/// the first.
///
/// # Safety
///
/// * `src` must be valid for reading `count` consecutive `u64` values.
/// * `dest` must be valid for reading and writing `count` consecutive `u64`
///   values.
/// * `count` must be nonzero. This is only checked by a `debug_assert!`; the
///   loop decrements before testing for zero, so a zero count would wrap.
#[inline]
unsafe fn cmov_byte_slice_a8(condition: bool, src: *const u64, dest: *mut u64, count: usize) {
    debug_assert!(count > 0, "count cannot be 0");
    let cond_word = condition as u64;
    asm!(
        // `neg` sets CF exactly when its operand was nonzero, i.e. when the
        // condition is true. The register is reused as scratch afterwards.
        "neg {tmp}",
        // Numeric local label, so the block can be inlined more than once.
        "42:",
        // Load the current destination word.
        "mov {tmp}, [{dst} + 8*{idx} - 8]",
        // Replace it with the source word only when CF is set.
        "cmovc {tmp}, [{src} + 8*{idx} - 8]",
        // Store the result back unconditionally.
        "mov [{dst} + 8*{idx} - 8], {tmp}",
        // `dec` leaves CF untouched and sets ZF once the index reaches zero.
        "dec {idx}",
        "jnz 42b",
        // Both registers are modified by the loop; their final values are discarded.
        tmp = inout(reg) cond_word => _,
        idx = inout(reg) count => _,
        src = in(reg) src,
        dst = in(reg) dest,
    );
}
/// Conditionally copies `num_bytes` bytes from `src` to `dest` with AVX2
/// masked stores, 64 bytes per loop iteration and no branch on `condition`.
///
/// Chunks are visited from the end of the buffers towards the start. Only
/// compiled when the `avx2` target feature is enabled.
///
/// # Safety
///
/// * `src` must be valid for reading `num_bytes` bytes, and `dest` for writing
///   `num_bytes` bytes.
/// * `src` is read with `vmovdqa`, which requires 32-byte-aligned addresses.
/// * `num_bytes` must be a nonzero multiple of 64. This is only checked by
///   `debug_assert!`; the loop subtracts 64 before testing for zero.
#[cfg(target_feature = "avx2")]
#[inline]
unsafe fn cmov_byte_slice_a64(condition: bool, src: *const u64, dest: *mut u64, num_bytes: usize) {
    debug_assert!(num_bytes > 0, "num_bytes cannot be 0");
    debug_assert!(num_bytes % 64 == 0, "num_bytes must be divisible by 64");
    let cond_word = condition as u64;
    asm!(
        // 0 stays 0 and 1 becomes all ones, giving a 64-bit lane mask.
        "neg {off}",
        // Broadcast the mask into all four 64-bit lanes of ymm1.
        "vmovq xmm2, {off}",
        "vbroadcastsd ymm1, xmm2",
        // Reuse the same register as the byte offset, starting at the end.
        "mov {off}, {len}",
        // Numeric local label, so the block can be inlined more than once.
        "42:",
        // First 32 bytes of the chunk: load from src, masked store to dest.
        "vmovdqa ymm2, [{src} + {off} - 64]",
        "vpmaskmovq [{dst} + {off} - 64], ymm1, ymm2",
        // Second 32 bytes of the chunk.
        "vmovdqa ymm3, [{src} + {off} - 32]",
        "vpmaskmovq [{dst} + {off} - 32], ymm1, ymm3",
        // Step back one chunk; ZF is set once the offset reaches zero.
        "sub {off}, 64",
        "jnz 42b",
        // The register is modified by the loop; its final value is discarded.
        off = inout(reg) cond_word => _,
        src = in(reg) src,
        dst = in(reg) dest,
        len = in(reg) num_bytes,
        // Vector registers used as scratch by the template.
        out("ymm1") _,
        out("ymm2") _,
        out("ymm3") _,
    );
}