/// Returns the position of the set bit of a given rank in `value`.
///
/// Bit positions are counted from the least significant bit (position 0). `rank_fn` is called
/// exactly once with the total number of set bits in `value` and must return the 1-based rank
/// of the wanted set bit among them, counting upwards from the least significant bit. The
/// return value is the position of that bit.
///
/// The search first computes, SWAR-style, the number of set bits in every aligned field of
/// width 1, 2, 4, … `usize::BITS / 2` bits. It then performs a binary descent that halves the
/// search window at each level, selecting the half with masks instead of data-dependent
/// branches.
///
/// If `rank_fn` returns 0 (the only possible rank when `value == 0`), the result is 0. Any other
/// rank outside `1..=count` does not designate a set bit, and the returned position is
/// meaningless.
// The const assertions below check compile-time properties of `usize`; the allow keeps clippy's
// `assertions_on_constants` lint from flagging them.
#[allow(clippy::assertions_on_constants)]
pub(crate) fn find_bit<F: FnOnce(usize) -> usize>(value: usize, rank_fn: F) -> usize {
    // `usize::BITS == 1 << LEVELS`, e.g. 6 levels on a 64-bit target.
    const LEVELS: usize = usize::BITS.trailing_zeros() as usize;
    const MASKS: [usize; LEVELS] = sum_masks();
    const _: () = assert!(usize::BITS.is_power_of_two());
    const _: () = assert!(LEVELS >= 2);

    // `counts[k]` stores, in each aligned 2^k-bit field, the number of set bits of `value`
    // within that same field; `counts[LEVELS]` is therefore the total popcount.
    let mut counts = [0usize; LEVELS + 1];
    counts[0] = value;
    // 2-bit fields: popcount(b1 b0) == (b1 b0) - b1.
    counts[1] = value - ((value >> 1) & MASKS[0]);
    // 4-bit fields: a sum of two 2-bit counts may need 3 bits, more than the 2-bit runs kept by
    // `MASKS[1]`, so both halves are masked *before* the addition.
    counts[2] = (counts[1] & MASKS[1]) + ((counts[1] >> 2) & MASKS[1]);
    // Wider fields: the sum of two neighbouring counts fits within one field without carrying
    // into the next, so a single mask *after* the addition is enough.
    for level in 2..LEVELS {
        let half: u32 = 1 << level;
        let prev = counts[level];
        counts[level + 1] = (prev + (prev >> half)) & MASKS[level];
    }

    let mut remaining = rank_fn(counts[LEVELS]);
    let mut pos = 0usize;
    // At each level, the current window starting at `pos` splits into a lower and an upper half
    // of `width` bits. The search moves to the upper half iff the lower half holds fewer than
    // `remaining` set bits, discounting those bits from the rank.
    for level in (0..LEVELS).rev() {
        let width: usize = 1 << level;
        let field_mask: usize = (1 << width) - 1;
        let lower_count = (counts[level] >> pos) & field_mask;
        // Arithmetic shift of the sign bit: all ones iff `lower_count < remaining`, else zero.
        let diff = lower_count as isize - remaining as isize;
        let take_upper = (diff >> (isize::BITS - 1)) as usize;
        remaining -= lower_count & take_upper;
        pos += width & take_upper;
    }
    pos
}
/// Builds the masks that [`find_bit`] uses to add neighbouring bit counts in parallel.
///
/// Entry `k` keeps every other aligned `2^k`-bit field of a `usize`, starting with the least
/// significant one: `0x…5555`, `0x…3333` and `0x…0f0f` for `k = 0, 1, 2`, up to a mask covering
/// only the lower half of the word.
// The const assertion below checks a compile-time property of `usize`; the allow keeps clippy's
// `assertions_on_constants` lint from flagging it.
#[allow(clippy::assertions_on_constants)]
const fn sum_masks() -> [usize; usize::BITS.trailing_zeros() as usize] {
    const P: usize = usize::BITS.trailing_zeros() as usize;
    const _: () = assert!(
        usize::BITS == 1 << P,
        "sum masks are only supported for `usize` with a power-of-two bit width"
    );

    let mut masks = [0usize; P];
    let mut level = 0;
    // `for` loops are not allowed in `const fn`, hence the manual `while`.
    while level < P {
        let width: u32 = 1 << level;
        // (2^BITS - 1) / (2^width + 1) is exact. It yields `width`-bit runs of ones
        // alternating with `width`-bit runs of zeros, with ones in the lowest run.
        masks[level] = usize::MAX / (1 + (1usize << width));
        level += 1;
    }
    masks
}
#[cfg(all(test, not(nexosim_loom), not(miri)))]
mod tests {
    use super::*;
    use crate::util::rng;

    #[test]
    fn find_bit_fuzz() {
        const SAMPLES: usize = 10_000;

        /// Asserts that `find_bit` reports the correct bit count to the rank callback and
        /// returns the correct position for every valid rank of `value`.
        #[inline(always)]
        fn check(value: usize) {
            // Zero-pad binary dumps to the real width of `usize` rather than a hard-coded 64,
            // so that failure messages are also accurate on 32-bit targets.
            let width = usize::BITS as usize;
            let bitsum = value.count_ones() as usize;
            for rank in 1..=bitsum {
                let pos = find_bit(value, |s| {
                    assert_eq!(s, bitsum);
                    rank
                });
                // The returned bit must be set...
                assert!(
                    value & (1 << pos) != 0,
                    "input value: {value:0width$b}\nrequested rank: {rank}\nreturned position: {pos}"
                );
                // ...and be preceded by exactly `rank - 1` set bits.
                assert_eq!(
                    rank,
                    (value & ((1 << pos) - 1)).count_ones() as usize + 1,
                    "input value: {value:0width$b}\nrequested rank: {rank}\nreturned position: {pos}"
                );
            }
        }

        // A zero value reports a zero count; a rank of 0 then maps to position 0.
        let pos = find_bit(0, |s| {
            assert_eq!(s, 0);
            0
        });
        assert_eq!(pos, 0);

        // Structured edge cases: single lowest/highest bit, all bits set, alternating bit
        // patterns, and values confined to the lower or the upper half-word.
        check(1);
        check(1 << (usize::BITS - 1));
        check(usize::MAX);
        check(usize::MAX / 3);
        check(!(usize::MAX / 3));
        check(usize::MAX >> (usize::BITS / 2));
        check(usize::MAX << (usize::BITS / 2));

        // Random values assembled from 64-bit chunks (assumes `Rng::rand` yields 64 random
        // bits — NOTE(review): confirm against `crate::util::rng`).
        let rng = rng::Rng::new(12345);
        for _ in 0..SAMPLES {
            let mut r = rng.rand() as usize;
            let mut shift = 64;
            while shift < usize::BITS {
                r |= (rng.rand() as usize) << shift;
                shift += 64;
            }
            check(r);
        }
    }
}