use core::arch::asm;
#[inline]
#[target_feature(enable = "sve2-bitperm")]
pub unsafe fn bdep_u64(data: u64, mask: u64) -> u64 {
let result: u64;
asm!(
"fmov d0, {data}", "fmov d1, {mask}", "bdep z0.d, z0.d, z1.d", "fmov {result}, d0", data = in(reg) data,
mask = in(reg) mask,
result = out(reg) result,
options(pure, nomem, nostack)
);
result
}
#[inline]
#[allow(dead_code)]
#[target_feature(enable = "sve2-bitperm")]
pub unsafe fn bext_u64(data: u64, mask: u64) -> u64 {
let result: u64;
asm!(
"fmov d0, {data}", "fmov d1, {mask}", "bext z0.d, z0.d, z1.d", "fmov {result}, d0", data = in(reg) data,
mask = in(reg) mask,
result = out(reg) result,
options(pure, nomem, nostack)
);
result
}
#[inline]
#[target_feature(enable = "sve2-bitperm")]
pub unsafe fn toggle64_sve2(carry: u64, quote_mask: u64) -> (u64, u64) {
const ODDS_MASK: u64 = 0x5555_5555_5555_5555;
let c = carry & 0x1;
let addend = bdep_u64(ODDS_MASK << c, quote_mask);
let comp_w = !quote_mask;
let shifted = (addend << 1) | c;
let (result, overflow) = shifted.overflowing_add(comp_w);
let new_carry = if overflow { 1 } else { 0 };
(result, new_carry)
}
#[inline]
#[target_feature(enable = "sve2-bitperm")]
pub unsafe fn select_in_word_bdep(x: u64, k: u32) -> u32 {
if x == 0 {
return 64;
}
let pop = x.count_ones();
if k >= pop {
return 64;
}
let mask = if k >= 63 {
u64::MAX
} else {
(1u64 << (k + 1)) - 1
};
let scattered = bdep_u64(mask, x);
if scattered == 0 {
return 64;
}
63 - scattered.leading_zeros()
}
#[cfg(feature = "std")]
#[inline]
#[allow(dead_code)]
pub fn has_sve2_bitperm() -> bool {
std::arch::is_aarch64_feature_detected!("sve2-bitperm")
}
#[cfg(not(feature = "std"))]
#[inline]
#[allow(dead_code)]
pub const fn has_sve2_bitperm() -> bool {
false
}
#[cfg(test)]
mod tests {
use super::*;
fn has_sve2() -> bool {
#[cfg(feature = "std")]
{
std::arch::is_aarch64_feature_detected!("sve2-bitperm")
}
#[cfg(not(feature = "std"))]
{
false
}
}
#[test]
fn test_bdep_basic() {
if !has_sve2() {
eprintln!("Skipping SVE2 test: CPU doesn't support sve2-bitperm");
return;
}
unsafe {
assert_eq!(bdep_u64(0xFFFF, 0), 0);
assert_eq!(bdep_u64(0xF, 0xF), 0xF);
assert_eq!(bdep_u64(0b1010, 0b01010101), 0b01000100);
}
}
#[test]
fn test_bext_basic() {
if !has_sve2() {
eprintln!("Skipping SVE2 test: CPU doesn't support sve2-bitperm");
return;
}
unsafe {
assert_eq!(bext_u64(0xFFFF, 0), 0);
assert_eq!(bext_u64(0b01010101, 0b01010101), 0b1111);
assert_eq!(bext_u64(0b11110000, 0b11110000), 0b1111);
}
}
#[test]
fn test_bdep_bext_roundtrip() {
if !has_sve2() {
return;
}
unsafe {
let mask = 0b10101010_10101010u64;
let data = 0b11110000_11110000u64;
let extracted = bext_u64(data, mask);
let deposited = bdep_u64(extracted, mask);
assert_eq!(deposited, data & mask);
}
}
#[test]
fn test_toggle64_no_quotes() {
if !has_sve2() {
return;
}
unsafe {
let (mask, carry) = toggle64_sve2(0, 0);
assert_eq!(carry, 0, "No quotes should not change carry");
assert_eq!(mask, !0u64, "No quotes means all outside");
}
}
#[test]
fn test_toggle64_single_quote() {
if !has_sve2() {
return;
}
unsafe {
let (_mask, carry) = toggle64_sve2(0, 1);
assert_eq!(carry, 1, "Odd quotes should set carry");
}
}
#[test]
fn test_toggle64_matches_prefix_xor() {
if !has_sve2() {
return;
}
fn toggle64_reference(carry: u64, quote_mask: u64) -> (u64, u64) {
const ODDS_MASK: u64 = 0x5555_5555_5555_5555;
let c = carry & 0x1;
let mut addend = 0u64;
let mut src_bit = 0;
for i in 0..64 {
if (quote_mask >> i) & 1 == 1 {
let bit = (ODDS_MASK << c >> src_bit) & 1;
addend |= bit << i;
src_bit += 1;
}
}
let comp_w = !quote_mask;
let shifted = (addend << 1) | c;
let (result, overflow) = shifted.overflowing_add(comp_w);
let new_carry = if overflow { 1 } else { 0 };
(result, new_carry)
}
let patterns = [
0u64,
1,
0b11,
0b101,
0b1001,
0x8000_0000_0000_0000,
0xAAAA_AAAA_AAAA_AAAA,
0x5555_5555_5555_5555,
0xFF00_FF00_FF00_FF00,
];
for "e_mask in &patterns {
for carry in [0u64, 1] {
unsafe {
let (sve2_mask, sve2_carry) = toggle64_sve2(carry, quote_mask);
let (ref_mask, ref_carry) = toggle64_reference(carry, quote_mask);
assert_eq!(
sve2_mask, ref_mask,
"Mask mismatch for quote_mask={:#x}, carry={}",
quote_mask, carry
);
assert_eq!(
sve2_carry, ref_carry,
"Carry mismatch for quote_mask={:#x}, carry={}",
quote_mask, carry
);
}
}
}
}
#[test]
fn test_select_in_word_bdep_basic() {
if !has_sve2() {
eprintln!("Skipping SVE2 test: CPU doesn't support sve2-bitperm");
return;
}
unsafe {
assert_eq!(select_in_word_bdep(0, 0), 64);
assert_eq!(select_in_word_bdep(1, 0), 0);
assert_eq!(select_in_word_bdep(1, 1), 64);
assert_eq!(select_in_word_bdep(1 << 63, 0), 63);
let word = 0b1010_1010u64;
assert_eq!(select_in_word_bdep(word, 0), 1);
assert_eq!(select_in_word_bdep(word, 1), 3);
assert_eq!(select_in_word_bdep(word, 2), 5);
assert_eq!(select_in_word_bdep(word, 3), 7);
assert_eq!(select_in_word_bdep(word, 4), 64);
}
}
#[test]
fn test_select_in_word_bdep_all_ones() {
if !has_sve2() {
return;
}
unsafe {
let word = u64::MAX;
for k in 0..64 {
assert_eq!(select_in_word_bdep(word, k), k, "k={}", k);
}
assert_eq!(select_in_word_bdep(word, 64), 64);
}
}
#[test]
fn test_select_in_word_bdep_matches_ctz() {
if !has_sve2() {
return;
}
fn select_ctz(x: u64, k: u32) -> u32 {
let mut val = x;
let mut remaining = k;
loop {
if val == 0 {
return 64;
}
let t = val.trailing_zeros();
if remaining == 0 {
return t;
}
remaining -= 1;
val &= val - 1;
}
}
let patterns = [
0u64,
1,
0xFF,
0x8000_0000_0000_0000,
u64::MAX,
0xAAAA_AAAA_AAAA_AAAA,
0x5555_5555_5555_5555,
0x1234_5678_9ABC_DEF0,
];
for &word in &patterns {
let pop = word.count_ones();
for k in 0..=pop {
unsafe {
let bdep_result = select_in_word_bdep(word, k);
let ctz_result = select_ctz(word, k);
assert_eq!(
bdep_result, ctz_result,
"Mismatch for word={:#x}, k={}",
word, k
);
}
}
}
}
}