use std::mem::MaybeUninit;
use std::ptr::NonNull;
use crate::{
BitsIter, WordBits,
heap::{bit_in_word_index, last_word_mask, word_index, words_for},
};
/// Views a mutable slice expression as `&mut [_; $cnt]`.
///
/// Panics (via `unwrap`) when the slice length is not exactly `$cnt`.
macro_rules! arr {
    ($cnt:literal, mut $exp:expr) => {{
        let fixed: &mut [_; $cnt] = ::std::convert::TryFrom::try_from($exp).unwrap();
        fixed
    }};
}
/// Single-slice variant of `for_each_carry`: calls `func(&mut carry, item)` for every
/// element of `value` in order and returns the final carry.
///
/// It is implemented by pairing `value` with a slice of `()` of the same length, so
/// that it shares the length-specialized dispatch of `for_each_carry`.
pub fn unary_for_each_carry<A, T>(value: &mut [A], init: T, mut func: impl FnMut(&mut T, &mut A)) -> T {
    let len = value.len();
    let base: *const () = NonNull::<()>::dangling().as_ptr();
    // SAFETY: `()` is zero-sized, so a non-null, well-aligned dangling pointer is a valid
    // base for a slice of any length; the slice occupies zero bytes and nothing is read.
    let units: &[()] = unsafe { std::slice::from_raw_parts(base, len) };
    for_each_carry(value, units, init, move |carry, item, ()| func(carry, item))
}
pub fn for_each_carry<A, B: Copy, T>(
value: &mut [A],
other: &[B],
init: T,
mut func: impl FnMut(&mut T, &mut A, B),
) -> T {
assert_eq!(value.len(), other.len());
match value.len() {
0 => init,
1 => for_each_carry_const(arr!(1, mut value), other.try_into().unwrap(), init, func),
2 => for_each_carry_const(arr!(2, mut value), other.try_into().unwrap(), init, func),
3 => for_each_carry_const(arr!(3, mut value), other.try_into().unwrap(), init, func),
4 => for_each_carry_const(arr!(4, mut value), other.try_into().unwrap(), init, func),
5 => for_each_carry_const(arr!(5, mut value), other.try_into().unwrap(), init, func),
6 => for_each_carry_const(arr!(6, mut value), other.try_into().unwrap(), init, func),
7 => for_each_carry_const(arr!(7, mut value), other.try_into().unwrap(), init, func),
8 => for_each_carry_const(arr!(8, mut value), other.try_into().unwrap(), init, func),
9 => for_each_carry_const(arr!(9, mut value), other.try_into().unwrap(), init, func),
10 => for_each_carry_const(arr!(10, mut value), other.try_into().unwrap(), init, func),
_ => {
let mut carry = init;
for (a, b) in value.iter_mut().zip(other) {
func(&mut carry, a, *b);
}
carry
}
}
}
/// Const-length kernel behind `for_each_carry`: visits the `N` pairs in index order,
/// calling `func(&mut carry, &mut value[i], other[i])`, and returns the final carry.
fn for_each_carry_const<A, B: Copy, T, const N: usize>(
    value: &mut [A; N],
    other: &[B; N],
    init: T,
    mut func: impl FnMut(&mut T, &mut A, B),
) -> T {
    value
        .iter_mut()
        .zip(other.iter().copied())
        .fold(init, |mut carry, (slot, rhs)| {
            func(&mut carry, slot, rhs);
            carry
        })
}
/// Copies `bit_count` bits starting at bit `src_bit_offset` of the word array at `src`
/// into the word array at `dst`, starting at bit `dst_bit_offset`.
///
/// Offsets may exceed the word width. They are split into a whole-word pointer advance
/// plus an in-word bit offset. Destination bits outside
/// `dst_bit_offset .. dst_bit_offset + bit_count` are left unchanged. A `bit_count` of
/// zero is a no-op.
///
/// Three strategies are tried, fastest first:
/// 1. Both in-word offsets are 0: whole words are copied with `ptr::copy_nonoverlapping`,
///    and the trailing partial word is merged under a mask.
/// 2. Both in-word offsets are multiples of 8 (little-endian targets only): the same
///    approach at byte granularity.
/// 3. Otherwise: bits are streamed through `BitsIter` and destination words are assembled
///    from them.
///
/// # Safety
///
/// - `src` must be valid for reads of every word touched by the bit range
///   `src_bit_offset .. src_bit_offset + bit_count`.
/// - `dst` must be valid for reads and writes of every word touched by the bit range
///   `dst_bit_offset .. dst_bit_offset + bit_count`. The partially covered first and
///   last words are read back so their untouched bits can be preserved.
/// - The source and destination word ranges must not overlap.
pub const unsafe fn copy_bits_nonoverlapping(
    src: *const usize,
    src_bit_offset: usize,
    dst: *mut usize,
    dst_bit_offset: usize,
    bit_count: usize,
) {
    if bit_count == 0 {
        return;
    }
    unsafe {
        // Fold the whole-word part of each offset into its pointer so the remaining
        // offsets are strictly inside the first word.
        let src = src.add(word_index(src_bit_offset));
        let src_bit_offset = bit_in_word_index(src_bit_offset);
        let dst = dst.add(word_index(dst_bit_offset));
        let dst_bit_offset = bit_in_word_index(dst_bit_offset);

        if src_bit_offset == 0 && dst_bit_offset == 0 {
            // Fast path 1: both sides word-aligned.
            let whole_words_len = word_index(bit_count);
            std::ptr::copy_nonoverlapping(src, dst, whole_words_len);
            if bit_in_word_index(bit_count) != 0 {
                let mask = last_word_mask(bit_count);
                let last_word_masked = src.add(whole_words_len).read() & mask;
                let dst_word_ptr = dst.add(whole_words_len);
                let new_dst = (dst_word_ptr.read() & !mask) | last_word_masked;
                dst_word_ptr.write(new_dst);
            }
            return;
        }

        // Fast path 2: both sides byte-aligned. Viewing a word as bytes only puts word
        // bit `i` at bit `i % 8` of byte `i / 8` (the LSB-first numbering that
        // `last_byte_mask` assumes) on little-endian targets. Big-endian builds must
        // take the generic path instead.
        if cfg!(target_endian = "little")
            && bit_in_byte_index(src_bit_offset) == 0
            && bit_in_byte_index(dst_bit_offset) == 0
        {
            let src = src.cast::<u8>().add(byte_index(src_bit_offset));
            let dst = dst.cast::<u8>().add(byte_index(dst_bit_offset));
            let whole_bytes_len = byte_index(bit_count);
            std::ptr::copy_nonoverlapping(src, dst, whole_bytes_len);
            if bit_in_byte_index(bit_count) != 0 {
                let mask = last_byte_mask(bit_count);
                let last_byte_masked = src.add(whole_bytes_len).read() & mask;
                let dst_byte_ptr = dst.add(whole_bytes_len);
                let new_dst = (dst_byte_ptr.read() & !mask) | last_byte_masked;
                dst_byte_ptr.write(new_dst);
            }
            return;
        }

        // Generic path: arbitrary in-word offsets.
        debug_assert!(src_bit_offset < WordBits::BITS as _);
        debug_assert!(dst_bit_offset < WordBits::BITS as _);
        let mut iter = BitsIter::new_unchecked(src, src_bit_offset, bit_count);
        let mut dst = dst;
        if dst_bit_offset != 0 {
            // Top up the partially used first destination word with source bits.
            let first_word = dst.read();
            let mut dst_word = WordBits::new_unchecked(first_word, dst_bit_offset as _);
            let remaining = WordBits::BITS as usize - dst_word.len();
            let w = iter.next_bits(remaining);
            dst_word.push_bits(w);
            if !dst_word.is_full() {
                // The whole copy fit inside the first word. Restore the original bits
                // above the copied range before writing back.
                debug_assert!(iter.is_empty());
                dst.write(dst_word.raw() | (first_word & !last_word_mask(dst_word.len())));
                return;
            }
            dst.write(dst_word.raw());
            dst = dst.add(1);
        }
        // The destination is now word-aligned. Emit full words until the source runs
        // out, then merge the trailing partial word (if any) under a mask.
        loop {
            let next_word = iter.next_unaligned_word();
            if !next_word.is_full() {
                if !next_word.is_empty() {
                    let masked_last_word = dst.read() & !last_word_mask(next_word.len());
                    dst.write(next_word.raw() | masked_last_word);
                }
                break;
            }
            dst.write(next_word.raw());
            dst = dst.add(1);
        }
    }
}
/// Index of the byte that contains bit number `bit_index` (eight bits per byte).
#[inline]
const fn byte_index(bit_index: usize) -> usize {
    bit_index / (u8::BITS as usize)
}
/// Position (0..8) of bit number `bit_index` inside its byte.
#[inline]
const fn bit_in_byte_index(bit_index: usize) -> usize {
    bit_index % (u8::BITS as usize)
}
/// Mask selecting the bits that a run of `len` bits occupies in its final byte,
/// counting from the least significant bit.
///
/// Returns `0` for `len == 0`, `0xFF` when `len` is a non-zero multiple of 8, and
/// otherwise the low `len % 8` bits set.
#[inline]
const fn last_byte_mask(len: usize) -> u8 {
    let rem = (len % 8) as u32;
    match (len, rem) {
        (0, _) => 0,
        (_, 0) => u8::MAX,
        _ => u8::MAX >> (8 - rem),
    }
}
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use rand::prelude::*;

    use super::*;
    use crate::WordBits;

    /// Randomized round-trip check for `copy_bits_nonoverlapping`.
    ///
    /// Each iteration refills both buffers with random data and copies a random bit range
    /// from `src` to `dst`. It then checks two things: the copied range equals the source
    /// range, and every destination bit outside the range still matches a pre-copy
    /// snapshot.
    ///
    /// Offsets rotate through word-aligned, byte-aligned and arbitrary cases so that each
    /// copy strategy is exercised. Every fourth iteration uses a short length, so a copy
    /// can begin and end inside a single destination word.
    #[test]
    fn test_copy_bits_nonoverlapping_byte_aligned() {
        const MEM: usize = 1024;
        let word_bits = WordBits::BITS as usize;
        let src = &mut [0usize; MEM];
        let dst = &mut [0usize; MEM];
        let dst_snapshot = &mut [0usize; MEM];
        let rng = &mut StdRng::seed_from_u64(654321);
        unsafe {
            let mut time = Duration::ZERO;
            let mut total_bits = 0;
            let mut counter = 0u32;
            for i in 0..1000 {
                let src_off: usize = rng.random_range(0..500);
                let dst_off: usize = rng.random_range(0..500);
                // Rotate alignment classes so every fast path and the generic path run.
                let (src_off, dst_off) = match i % 3 {
                    0 => (src_off / word_bits * word_bits, dst_off / word_bits * word_bits),
                    1 => (src_off / 8 * 8, dst_off / 8 * 8),
                    _ => (src_off, dst_off),
                };
                let src_ptr_off = rng.random_range(0..500);
                let dst_ptr_off = rng.random_range(0..500);
                let bit_len = if i % 4 == 0 {
                    rng.random_range(1..300)
                } else {
                    rng.random_range(8000..1024 * 30)
                };
                rng.fill_bytes(src.align_to_mut::<u8>().1);
                rng.fill_bytes(dst.align_to_mut::<u8>().1);
                dst_snapshot.copy_from_slice(dst);
                // Derive the raw pointers only after the last mutable use of the buffers.
                // A later `&mut` reborrow (such as `align_to_mut` above) would invalidate
                // them under Rust's aliasing rules.
                let src_ptr = src.as_ptr().add(src_ptr_off);
                let dst_ptr = dst.as_mut_ptr().add(dst_ptr_off);
                let start = Instant::now();
                copy_bits_nonoverlapping(src_ptr, src_off, dst_ptr, dst_off, bit_len);
                time += start.elapsed();
                total_bits += bit_len;
                counter += 1;
                let src_region = BitsIter::new_unchecked(src_ptr, src_off, bit_len);
                let dst_region = BitsIter::new_unchecked(dst_ptr, dst_off, bit_len);
                let equal = src_region.eq(dst_region);
                assert!(
                    equal,
                    "{i} Bits are not equal - src_ptr_off: {src_ptr_off}, dst_ptr_off: {dst_ptr_off}, src_off: {src_off}, dst_off: {dst_off}, bit_len: {bit_len}"
                );
                // Bits before and after the destination range must be untouched.
                let dst_ptr_bit_off = dst_ptr_off * word_bits;
                let s1 = BitsIter::new_unchecked(dst.as_ptr(), 0, dst_ptr_bit_off + dst_off);
                let d1 = BitsIter::new_unchecked(dst_snapshot.as_ptr(), 0, dst_ptr_bit_off + dst_off);
                let soff = dst_ptr_bit_off + dst_off + bit_len;
                let s2 = BitsIter::new_unchecked_range(dst.as_ptr(), soff..(MEM * word_bits));
                let d2 = BitsIter::new_unchecked_range(dst_snapshot.as_ptr(), soff..(MEM * word_bits));
                assert!(s1.eq(d1), "{i} s1 != d1");
                assert!(s2.eq(d2), "{i} s2 != d2");
            }
            println!("Total time: {time:.03?}, total bits: {total_bits}, counter: {counter}");
            let avg_bits = total_bits as f64 / counter as f64;
            let words = avg_bits / word_bits as f64;
            let bytes = avg_bits / 8.0;
            println!(
                "Average time: {:.03?}, average bits: {avg_bits:.03} = {words:.03}W = {bytes:.03}B",
                time / counter
            );
        }
    }
}