use crate::{
bit::{Bit, BitMut, BitRef},
index::Index,
safety_markers::{Combines, SizeMarker, Smaller, Splits},
};
/// Number of addressable bits in the bitset type `T`: eight bits per byte of
/// `T::BYTE_SIZE` storage.
pub(crate) const fn bit_len<T: Bitset>() -> usize {
    T::BYTE_SIZE * u8::BITS as usize
}
pub trait Bitset: Sized + Clone + PartialEq + Eq {
type Repr: Sized + Clone + PartialEq + Eq;
type Size: SizeMarker;
const BYTE_SIZE: usize;
const NONE: Self;
const ALL: Self;
fn from_repr(repr: Self::Repr) -> Self;
#[inline(always)]
fn build(&mut self) -> Self {
self.clone()
}
fn from_index(index: &Index<Self>) -> Self;
fn expand<Res>(self) -> Res
where
Res: Bitset,
Self::Size: Smaller<Res::Size>,
{
let result = self
.ones()
.map(|Index(i, ..)| Index::<Res>::from_usize(i))
.fold(&mut Res::NONE.clone(), |acc, i| acc.set(i))
.build();
result
}
fn expand_optimized<Res>(self) -> Res
where
Self: LeftAligned,
Res: Bitset + LeftAligned,
Self::Size: Smaller<Res::Size>,
{
let mut result = Res::NONE.clone();
unsafe {
std::ptr::copy_nonoverlapping(
&self as *const _ as *const u8,
&mut result as *mut _ as *mut u8,
Self::BYTE_SIZE,
);
}
result
}
fn from_iterable<'a, I>(iterable: I) -> Self
where
Self: 'a,
I: IntoIterator<Item = &'a Bit>,
{
iterable
.into_iter()
.take(bit_len::<Self>())
.enumerate()
.filter(|(_, &b)| bool::from(b))
.fold(&mut Self::NONE.clone(), |acc, (i, _)| {
acc.set(Index::<_>::from_usize(i))
})
.build()
}
#[inline(always)]
fn count_ones(&self) -> usize {
self.ones().count()
}
#[inline(always)]
fn count_zeros(&self) -> usize {
self.zeros().count()
}
#[inline(always)]
fn replace(&mut self, index: Index<Self>, value: Bit) -> &mut Self {
if bool::from(value) {
self.set(index);
} else {
self.unset(index);
}
self
}
fn set(&mut self, index: Index<Self>) -> &mut Self;
fn unset(&mut self, index: Index<Self>) -> &mut Self;
fn flip(&mut self, index: Index<Self>) -> &mut Self;
fn include(&mut self, other: Self) -> &mut Self;
fn exclude(&mut self, other: Self) -> &mut Self;
fn bit(&self, index: Index<Self>) -> Bit;
fn bit_ref(&self, index: Index<Self>) -> BitRef<'_, Self>;
fn bit_mut(&mut self, index: Index<Self>) -> BitMut<'_, Self>;
fn complement(self) -> Self;
fn union(self, other: Self) -> Self;
fn intersection(self, other: Self) -> Self;
#[inline(always)]
fn difference(self, other: Self) -> Self {
self.intersection(other.complement())
}
fn sym_difference(self, other: Self) -> Self;
#[inline(always)]
fn includes(&self, other: &Self) -> bool {
for i in other.ones() {
if !bool::from(self.bit(i)) {
return false;
}
}
true
}
#[inline(always)]
fn intersects(&self, other: &Self) -> bool {
for i in other.ones() {
if bool::from(self.bit(i)) {
return true;
}
}
false
}
fn combine<Other, Res>(self, other: Other) -> Res
where
Other: Bitset,
Res: Bitset,
Self::Size: Combines<Other::Size, Res::Size> + Smaller<Res::Size>,
Other::Size: Smaller<Res::Size>,
{
let mut result = self
.ones()
.map(|Index(i, ..)| Index::<Res>::from_usize(i))
.fold(&mut Res::NONE.clone(), |acc, i| acc.set(i))
.build();
let result = other
.ones()
.map(|Index(i, ..)| Index::<Res>::from_usize(i + bit_len::<Self>()))
.fold(&mut result, |acc, i| acc.set(i))
.build();
result
}
fn combine_optimized<Other, Res>(self, other: Other) -> Res
where
Self: LeftAligned,
Other: Bitset + LeftAligned,
Res: Bitset + LeftAligned,
Self::Size: Combines<Other::Size, Res::Size> + Smaller<Res::Size>,
Other::Size: Smaller<Res::Size>,
{
let mut result = Res::NONE.clone();
unsafe {
std::ptr::copy_nonoverlapping(
&self as *const _ as *const u8,
&mut result as *mut _ as *mut u8,
Self::BYTE_SIZE,
);
std::ptr::copy_nonoverlapping(
&other as *const _ as *const u8,
(&mut result as *mut _ as *mut u8).add(Self::BYTE_SIZE),
Other::BYTE_SIZE,
);
}
result
}
fn split<Res1, Res2>(self) -> (Res1, Res2)
where
Res1: Bitset,
Res2: Bitset,
Self::Size: Splits<Res1::Size, Res2::Size>,
Res1::Size: Smaller<Self::Size>,
Res2::Size: Smaller<Self::Size>,
{
let result1 = self
.bits_ref()
.take(bit_len::<Res1>())
.enumerate()
.map(|(i, bit)| (Index::<Res1>::from_usize(i), bit))
.fold(&mut Res1::NONE.clone(), |acc, (i, bit)| {
acc.replace(i, *bit)
})
.build();
let result2 = self
.bits_ref()
.skip(bit_len::<Res1>())
.enumerate()
.map(|(i, bit)| (Index::<Res2>::from_usize(i), bit))
.fold(&mut Res2::NONE.clone(), |acc, (i, bit)| {
acc.replace(i, *bit)
})
.build();
(result1, result2)
}
fn split_optimized<Res1, Res2>(self) -> (Res1, Res2)
where
Self: LeftAligned,
Res1: Bitset + LeftAligned,
Res2: Bitset + LeftAligned,
Self::Size: Splits<Res1::Size, Res2::Size>,
Res1::Size: Smaller<Self::Size>,
Res2::Size: Smaller<Self::Size>,
{
let mut result1 = Res1::NONE.clone();
let mut result2 = Res2::NONE.clone();
unsafe {
std::ptr::copy_nonoverlapping(
&self as *const _ as *const u8,
&mut result1 as *mut _ as *mut u8,
Res1::BYTE_SIZE,
);
std::ptr::copy_nonoverlapping(
(&self as *const _ as *const u8).add(Res1::BYTE_SIZE),
&mut result2 as *mut _ as *mut u8,
Res2::BYTE_SIZE,
);
}
(result1, result2)
}
#[inline(always)]
fn bits(self) -> impl Iterator<Item = Bit> + DoubleEndedIterator {
(0..bit_len::<Self>())
.map(|i| Index::<Self>::from_usize(i))
.map(move |i| self.bit(i))
}
#[inline(always)]
fn bits_ref(&self) -> impl Iterator<Item = BitRef<'_, Self>> + DoubleEndedIterator {
(0..bit_len::<Self>())
.map(|i| Index::<Self>::from_usize(i))
.map(|i| self.bit_ref(i))
}
#[inline(always)]
fn bits_mut(&mut self) -> impl Iterator<Item = BitMut<'_, Self>> + DoubleEndedIterator {
let p = self as *mut Self;
(0..bit_len::<Self>())
.map(|i| Index::<Self>::from_usize(i))
.map(move |i| unsafe { p.as_mut().unwrap().bit_mut(i) })
}
#[inline(always)]
fn ones(&self) -> impl Iterator<Item = Index<Self>> + DoubleEndedIterator {
self.bits_ref().filter_map(|bit| {
if bool::from(*bit) {
Some(BitRef::index(&bit))
} else {
None
}
})
}
#[inline(always)]
fn zeros(&self) -> impl Iterator<Item = Index<Self>> + DoubleEndedIterator {
self.bits_ref().filter_map(|bit| {
if bool::from(!*bit) {
Some(BitRef::index(&bit))
} else {
None
}
})
}
}
/// Bitsets whose value can be manipulated as a flat run of `BYTE_SIZE` raw bytes.
///
/// Implementing this trait provides [`Bitset`] through the blanket impl, using the
/// `_`-prefixed items below. It also unlocks the byte-level shifts defined here.
///
/// # Safety
///
/// The default methods reinterpret `Self` as `BYTE_SIZE` bytes and read and write them
/// directly. NOTE(review): the exact contract is inferred from that usage: implementors
/// must make such access valid (no padding, no invariants beyond the bytes). Confirm
/// against the crate docs.
pub unsafe trait LeftAligned: Bitset + Sized + Clone + PartialEq + Eq {
    #[doc(hidden)]
    type _Repr: Sized + Clone + PartialEq + Eq;
    #[doc(hidden)]
    type _Size: SizeMarker;
    #[doc(hidden)]
    const _BYTE_SIZE: usize;
    #[doc(hidden)]
    const _NONE: Self;
    #[doc(hidden)]
    const _ALL: Self;
    #[doc(hidden)]
    fn _from_repr(value: Self::Repr) -> Self;

    /// Shifts the storage by `amount` bits toward byte 0, filling the vacated space with 0.
    ///
    /// Whole bytes move to lower addresses. Within a byte, bits move toward the most
    /// significant bit, and the bits pushed out of a byte's top enter the low end of the
    /// preceding byte.
    fn shift_left(mut self, amount: Index<Self>) -> Self {
        let (whole, partial) = (amount.byte_index(), amount.bit_index());
        // SAFETY: the trait's safety contract makes `Self` valid to view as `BYTE_SIZE` bytes.
        let storage: &mut [u8] = unsafe {
            std::slice::from_raw_parts_mut(&mut self as *mut Self as *mut u8, Self::BYTE_SIZE)
        };
        if whole > 0 {
            storage.copy_within(whole.., 0);
            let vacated = Self::BYTE_SIZE - whole;
            storage[vacated..].fill(0);
        }
        if partial > 0 {
            // Walk from the last byte toward the first, passing each byte's top bits down.
            let mut incoming = 0;
            for slot in storage.iter_mut().rev() {
                let outgoing = *slot >> (8 - partial);
                *slot = (*slot << partial) | incoming;
                incoming = outgoing;
            }
        }
        self
    }

    /// Shifts the storage by `amount` bits away from byte 0, filling the vacated space with 0.
    ///
    /// Whole bytes move to higher addresses. Within a byte, bits move toward the least
    /// significant bit, and the bits pushed out of a byte's bottom enter the high end of
    /// the following byte.
    fn shift_right(mut self, amount: Index<Self>) -> Self {
        let (whole, partial) = (amount.byte_index(), amount.bit_index());
        // SAFETY: the trait's safety contract makes `Self` valid to view as `BYTE_SIZE` bytes.
        let storage: &mut [u8] = unsafe {
            std::slice::from_raw_parts_mut(&mut self as *mut Self as *mut u8, Self::BYTE_SIZE)
        };
        if whole > 0 {
            storage.copy_within(..Self::BYTE_SIZE - whole, whole);
            storage[..whole].fill(0);
        }
        if partial > 0 {
            // Walk from the first byte toward the last, passing each byte's low bits up.
            let mut incoming = 0;
            for slot in storage.iter_mut() {
                let outgoing = *slot << (8 - partial);
                *slot = (*slot >> partial) | incoming;
                incoming = outgoing;
            }
        }
        self
    }
}
/// Byte-wise [`Bitset`] implementation for every [`LeftAligned`] type.
///
/// Whole-set operations work directly on the value's `BYTE_SIZE` raw bytes instead of
/// iterating bit by bit. Single-bit accessors address one byte through `Index`.
impl<T> Bitset for T
where
    T: LeftAligned + Sized + Clone + PartialEq + Eq,
{
    type Repr = <Self as LeftAligned>::_Repr;
    type Size = <Self as LeftAligned>::_Size;
    const BYTE_SIZE: usize = Self::_BYTE_SIZE;
    const NONE: Self = Self::_NONE;
    const ALL: Self = Self::_ALL;

    #[inline(always)]
    fn from_repr(value: Self::Repr) -> Self {
        Self::_from_repr(value)
    }

    /// Starts from `NONE` and sets the single bit at `index`.
    #[inline(always)]
    fn from_index(index: &Index<Self>) -> Self {
        let mut result = Self::NONE;
        result.set(*index);
        result
    }

    /// Sums per-byte popcounts in `usize`, so no intermediate `u32` can overflow.
    #[inline(always)]
    fn count_ones(&self) -> usize {
        left_aligned_bytes(self)
            .iter()
            .map(|byte| byte.count_ones() as usize)
            .sum()
    }

    /// Sums per-byte zero counts in `usize`.
    #[inline(always)]
    fn count_zeros(&self) -> usize {
        left_aligned_bytes(self)
            .iter()
            .map(|byte| byte.count_zeros() as usize)
            .sum()
    }

    #[inline(always)]
    fn set(&mut self, index: Index<Self>) -> &mut Self {
        let self_ptr = self as *mut _ as *mut u8;
        // SAFETY: assumes `index.byte_index() < BYTE_SIZE` (the `Index<Self>` invariant) and
        // the `LeftAligned` contract that `Self` is `BYTE_SIZE` plain bytes.
        unsafe {
            *self_ptr.add(index.byte_index()) |= index.bitmask();
        }
        self
    }

    #[inline(always)]
    fn unset(&mut self, index: Index<Self>) -> &mut Self {
        let self_ptr = self as *mut _ as *mut u8;
        // SAFETY: same preconditions as `set`.
        unsafe {
            *self_ptr.add(index.byte_index()) &= !index.bitmask();
        }
        self
    }

    #[inline(always)]
    fn flip(&mut self, index: Index<Self>) -> &mut Self {
        let self_ptr = self as *mut _ as *mut u8;
        // SAFETY: same preconditions as `set`.
        unsafe {
            *self_ptr.add(index.byte_index()) ^= index.bitmask();
        }
        self
    }

    #[inline(always)]
    fn include(&mut self, other: Self) -> &mut Self {
        merge_left_aligned_bytes(self, &other, |mine, theirs| mine | theirs);
        self
    }

    #[inline(always)]
    fn exclude(&mut self, other: Self) -> &mut Self {
        merge_left_aligned_bytes(self, &other, |mine, theirs| mine & !theirs);
        self
    }

    #[inline(always)]
    fn bit(&self, index: Index<Self>) -> Bit {
        let self_ptr = self as *const _ as *const u8;
        // SAFETY: same preconditions as `set` (read-only access).
        let byte = unsafe { *self_ptr.add(index.byte_index()) };
        Bit::from(byte & index.bitmask() != 0)
    }

    #[inline(always)]
    fn bit_ref(&self, index: Index<Self>) -> BitRef<'_, Self> {
        let self_ptr = self as *const _ as *const u8;
        // SAFETY: same preconditions as `set` (read-only access).
        let byte = unsafe { *self_ptr.add(index.byte_index()) };
        BitRef(Bit::from(byte & index.bitmask() != 0), index, self)
    }

    #[inline(always)]
    fn bit_mut(&mut self, index: Index<Self>) -> BitMut<'_, Self> {
        let self_ptr = self as *mut _ as *const u8;
        // SAFETY: same preconditions as `set` (read-only access before handing out `self`).
        let byte = unsafe { *self_ptr.add(index.byte_index()) };
        BitMut(Bit::from(byte & index.bitmask() != 0), index, self)
    }

    #[inline]
    fn complement(mut self) -> Self {
        for byte in left_aligned_bytes_mut(&mut self) {
            *byte = !*byte;
        }
        self
    }

    #[inline]
    fn union(mut self, other: Self) -> Self {
        merge_left_aligned_bytes(&mut self, &other, |mine, theirs| mine | theirs);
        self
    }

    #[inline]
    fn intersection(mut self, other: Self) -> Self {
        merge_left_aligned_bytes(&mut self, &other, |mine, theirs| mine & theirs);
        self
    }

    #[inline]
    fn difference(mut self, other: Self) -> Self {
        merge_left_aligned_bytes(&mut self, &other, |mine, theirs| mine & !theirs);
        self
    }

    #[inline]
    fn sym_difference(mut self, other: Self) -> Self {
        merge_left_aligned_bytes(&mut self, &other, |mine, theirs| mine ^ theirs);
        self
    }

    /// `true` when, byte for byte, every bit of `other` is also set in `self`.
    #[inline]
    fn includes(&self, other: &Self) -> bool {
        left_aligned_bytes(self)
            .iter()
            .zip(left_aligned_bytes(other))
            .all(|(&mine, &theirs)| (mine & theirs) == theirs)
    }

    /// `true` when some byte of `self` and `other` shares a set bit.
    #[inline]
    fn intersects(&self, other: &Self) -> bool {
        left_aligned_bytes(self)
            .iter()
            .zip(left_aligned_bytes(other))
            .any(|(&mine, &theirs)| (mine & theirs) != 0)
    }
}

/// Read-only view of the `T::BYTE_SIZE` storage bytes of a [`LeftAligned`] bitset.
#[inline]
fn left_aligned_bytes<T: LeftAligned>(value: &T) -> &[u8] {
    // SAFETY: relies on the `LeftAligned` contract (NOTE(review): inferred from how this
    // module uses it) that a `T` is `T::BYTE_SIZE` initialized, contiguous bytes.
    unsafe { std::slice::from_raw_parts(value as *const T as *const u8, T::BYTE_SIZE) }
}

/// Mutable view of the `T::BYTE_SIZE` storage bytes of a [`LeftAligned`] bitset.
#[inline]
fn left_aligned_bytes_mut<T: LeftAligned>(value: &mut T) -> &mut [u8] {
    // SAFETY: same contract as `left_aligned_bytes`. The slice's lifetime is tied to the
    // unique borrow of `value`.
    unsafe { std::slice::from_raw_parts_mut(value as *mut T as *mut u8, T::BYTE_SIZE) }
}

/// Replaces every byte `d` of `dst` with `op(d, s)`, where `s` is the matching byte of `src`.
#[inline]
fn merge_left_aligned_bytes<T: LeftAligned>(dst: &mut T, src: &T, op: impl Fn(u8, u8) -> u8) {
    for (d, &s) in left_aligned_bytes_mut(dst)
        .iter_mut()
        .zip(left_aligned_bytes(src))
    {
        *d = op(*d, s);
    }
}