use crate::store::BitStore;
use core::{
marker::PhantomData,
ops::Deref,
};
#[cfg(feature = "serde")]
use core::convert::TryFrom;
/// Index of one bit inside a single `T` element.
///
/// The checked constructor `BitIdx::new` only admits values strictly below
/// `T::BITS`; `BitIdx::new_unchecked` verifies that only in debug builds.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitIdx<T: BitStore> {
    /// The raw bit index.
    idx: u8,
    /// Ties the index to its element type without storing a `T`.
    _ty: PhantomData<T>,
}
impl<T: BitStore> BitIdx<T> {
    /// Builds a range-checked bit index.
    ///
    /// Returns `None` when `idx` is not strictly less than `T::BITS`.
    pub fn new(idx: u8) -> Option<Self> {
        if idx < T::BITS {
            // SAFETY: the condition above establishes `idx < T::BITS`.
            Some(unsafe { Self::new_unchecked(idx) })
        } else {
            None
        }
    }

    /// Builds a bit index without a release-mode range check.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `idx < T::BITS`; only debug builds check it.
    #[doc(hidden)]
    #[inline]
    pub unsafe fn new_unchecked(idx: u8) -> Self {
        debug_assert!(
            idx < T::BITS,
            "Bit index {} cannot exceed type width {}",
            idx,
            T::BITS,
        );
        Self { idx, _ty: PhantomData }
    }

    /// Moves this index by `by` bits, possibly across element boundaries.
    ///
    /// Returns `(elements, bit)`: the signed count of elements moved and the
    /// bit index inside the destination element.
    pub(crate) fn offset(self, by: isize) -> (isize, Self) {
        let start = *self as isize;
        match by.overflowing_add(start) {
            // Destination is still inside the current element.
            (dest, false) if (0 .. T::BITS as isize).contains(&dest) => {
                (0, (dest as u8).idx())
            },
            // Signed `>>` is an arithmetic shift (it floors), so negative
            // destinations yield negative element counts; masking with
            // `T::MASK` yields the in-element bit.
            (dest, false) => (dest >> T::INDX, (dest as u8 & T::MASK).idx()),
            // The signed sum wrapped. `start` is non-negative, so wrapping can
            // only happen upward; reading the wrapped bits as unsigned
            // recovers the true (positive) sum.
            (wrapped, true) => {
                let dest = wrapped as usize;
                ((dest >> T::INDX) as isize, (dest as u8 & T::MASK).idx())
            },
        }
    }

    /// Counts the elements touched by a run of `len` bits that begins at this
    /// index, and returns the run's tail in its last element.
    ///
    /// This reinterprets the index as a `BitTail` and defers to its `span`.
    #[inline]
    pub(crate) fn span(self, len: usize) -> (usize, BitTail<T>) {
        // SAFETY: assuming `self` upholds its `< T::BITS` invariant, it also
        // satisfies the tail requirement `<= T::BITS`.
        let as_tail: BitTail<T> = unsafe { BitTail::new_unchecked(*self) };
        as_tail.span(len)
    }
}
impl<T> Deref for BitIdx<T>
where T: BitStore {
type Target = u8;
fn deref(&self) -> &Self::Target {
&self.idx
}
}
#[cfg(feature = "serde")]
impl<T: BitStore> TryFrom<u8> for BitIdx<T> {
    type Error = &'static str;

    /// Fallible conversion that applies the same `idx < T::BITS` check as
    /// `BitIdx::new`, reporting a static message on failure.
    fn try_from(idx: u8) -> Result<Self, Self::Error> {
        Self::new(idx)
            .ok_or("Attempted to construct a `BitIdx` with an index out of range")
    }
}
/// Exclusive end position of a bit run inside a `T` element.
///
/// Unlike `BitIdx`, a tail may equal `T::BITS`, which marks a run that fills
/// its final element up to the last bit.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) struct BitTail<T: BitStore> {
    /// The raw end position.
    end: u8,
    /// Ties the tail to its element type without storing a `T`.
    _ty: PhantomData<T>,
}
impl<T: BitStore> BitTail<T> {
    /// Builds a tail without a release-mode range check.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `end <= T::BITS`; only debug builds check it.
    pub(crate) unsafe fn new_unchecked(end: u8) -> Self {
        debug_assert!(
            end <= T::BITS,
            "Bit tail {} cannot surpass type width {}",
            end,
            T::BITS,
        );
        Self { end, _ty: PhantomData }
    }

    /// Extends a run of `len` bits starting from this position.
    ///
    /// Returns the number of elements the run occupies and the run's tail in
    /// its last element. A zero-length run occupies no elements and returns
    /// `self` unchanged. The starting bit is this tail masked with `T::MASK`.
    pub(crate) fn span(self, len: usize) -> (usize, Self) {
        let end = *self;
        debug_assert!(
            end <= T::BITS,
            "Tail out of range: {} overflows type width {}",
            end,
            T::BITS,
        );
        if len == 0 {
            return (0, self);
        }

        let start = end & T::MASK;
        let room_in_first = (T::BITS - start) as usize;
        if len <= room_in_first {
            // The whole run fits inside the first element.
            return (1, (start + len as u8).tail());
        }

        let overflow = len - room_in_first;
        let whole = overflow >> T::INDX;
        let partial = overflow as u8 & T::MASK;
        if partial == 0 {
            // Run ends exactly on an element boundary: the first element plus
            // `whole` full elements, with the tail at the far edge.
            (whole + 1, (1u8 << T::INDX).tail())
        } else {
            // First element, `whole` full elements, and a partially filled
            // last element.
            (whole + 2, partial.tail())
        }
    }
}
impl<T> Deref for BitTail<T>
where T: BitStore {
type Target = u8;
fn deref(&self) -> &Self::Target {
&self.end
}
}
/// A bit position inside a `T` element.
///
/// `BitPos::new` panics unless the value is strictly below `T::BITS`.
/// NOTE(review): how this differs semantically from `BitIdx` is not visible in
/// this file (presumably a position after bit-ordering translation) — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitPos<T: BitStore> {
    /// The raw bit position.
    pos: u8,
    /// Ties the position to its element type without storing a `T`.
    _ty: PhantomData<T>,
}
impl<T: BitStore> BitPos<T> {
    /// Builds a bit position, checking its range in every build.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= T::BITS`.
    #[inline]
    pub fn new(pos: u8) -> Self {
        if pos >= T::BITS {
            panic!(
                "Bit position {} cannot exceed type width {}",
                pos,
                T::BITS,
            );
        }
        // SAFETY: the guard above rejected every `pos >= T::BITS`.
        unsafe { Self::new_unchecked(pos) }
    }

    /// Builds a bit position without a release-mode range check.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `pos < T::BITS`; only debug builds check it.
    #[cfg_attr(debug_assertions, inline)]
    #[cfg_attr(not(debug_assertions), inline(always))]
    pub unsafe fn new_unchecked(pos: u8) -> Self {
        debug_assert!(
            pos < T::BITS,
            "Bit position {} cannot exceed type width {}",
            pos,
            T::BITS,
        );
        Self { pos, _ty: PhantomData }
    }
}
impl<T> Deref for BitPos<T>
where T: BitStore {
type Target = u8;
fn deref(&self) -> &Self::Target {
&self.pos
}
}
/// A `T` value used as a mask with exactly one set bit.
///
/// `BitMask::new` enforces the single-bit requirement in every build;
/// `BitMask::new_unchecked` enforces it only in debug builds.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitMask<T: BitStore> {
    /// The raw mask value.
    mask: T,
}
impl<T: BitStore> BitMask<T> {
    /// Wraps `mask`, checking in every build that exactly one bit is set.
    ///
    /// # Panics
    ///
    /// Panics if `mask.count_ones()` is not `1`.
    #[inline]
    pub fn new(mask: T) -> Self {
        if mask.count_ones() != 1 {
            panic!(
                "Masks are required to have exactly one set bit: {:0>1$b}",
                mask,
                T::BITS as usize,
            );
        }
        // SAFETY: the guard above guarantees exactly one set bit.
        unsafe { Self::new_unchecked(mask) }
    }

    /// Wraps `mask` without a release-mode check.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `mask` has exactly one set bit; only debug
    /// builds check it.
    #[cfg_attr(debug_assertions, inline)]
    #[cfg_attr(not(debug_assertions), inline(always))]
    pub unsafe fn new_unchecked(mask: T) -> Self {
        debug_assert!(
            mask.count_ones() == 1,
            "Masks are required to have exactly one set bit: {:0>1$b}",
            mask,
            T::BITS as usize,
        );
        Self { mask }
    }
}
impl<T> Deref for BitMask<T>
where T: BitStore {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.mask
}
}
/// Crate-internal conversions from a raw value into this module's bit-index
/// newtypes.
///
/// NOTE(review): the `u8` implementation goes through the `new_unchecked`
/// constructors, so ranges are only checked in debug builds; callers must
/// pass in-range values.
pub(crate) trait Indexable {
    /// Reinterprets `self` as a `BitIdx<T>`.
    fn idx<T: BitStore>(self) -> BitIdx<T>;

    /// Reinterprets `self` as a `BitTail<T>`.
    fn tail<T: BitStore>(self) -> BitTail<T>;

    /// Reinterprets `self` as a `BitPos<T>`.
    fn pos<T: BitStore>(self) -> BitPos<T>;
}
impl Indexable for u8 {
    fn idx<T: BitStore>(self) -> BitIdx<T> {
        // SAFETY: not enforced here; `new_unchecked` debug-asserts
        // `self < T::BITS`, and callers are trusted to uphold it in release.
        unsafe { BitIdx::new_unchecked(self) }
    }

    fn tail<T: BitStore>(self) -> BitTail<T> {
        // SAFETY: not enforced here; `new_unchecked` debug-asserts
        // `self <= T::BITS`, and callers are trusted to uphold it in release.
        unsafe { BitTail::new_unchecked(self) }
    }

    fn pos<T: BitStore>(self) -> BitPos<T> {
        // SAFETY: not enforced here; `new_unchecked` debug-asserts
        // `self < T::BITS`, and callers are trusted to uphold it in release.
        unsafe { BitPos::new_unchecked(self) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Jumping by `isize::MAX` overflows the signed sum; the unsigned
    /// reinterpretation must still produce the correct element/bit pair.
    #[test]
    fn jump_far_up() {
        for n in 1 .. 8 {
            let (elt, bit) = n.idx::<u8>().offset(isize::max_value());
            assert_eq!(elt, (isize::max_value() >> u8::INDX) + 1);
            assert_eq!(*bit, n - 1);
        }
        let (elt, bit) = 0u8.idx::<u8>().offset(isize::max_value());
        assert_eq!(elt, isize::max_value() >> u8::INDX);
        assert_eq!(*bit, 7);
    }

    /// Jumping by `isize::MIN` from any in-element index never overflows and
    /// preserves the bit index.
    #[test]
    fn jump_far_down() {
        for n in 0 .. 8 {
            let (elt, bit) = n.idx::<u8>().offset(isize::min_value());
            assert_eq!(elt, isize::min_value() >> u8::INDX);
            assert_eq!(*bit, n);
        }
    }

    /// Short jumps: staying in the element, and crossing one or more element
    /// boundaries in either direction.
    #[test]
    fn jump_near() {
        let (elt, bit) = 5u8.idx::<u8>().offset(0);
        assert_eq!((elt, *bit), (0, 5));
        let (elt, bit) = 3u8.idx::<u8>().offset(4);
        assert_eq!((elt, *bit), (0, 7));
        let (elt, bit) = 7u8.idx::<u8>().offset(1);
        assert_eq!((elt, *bit), (1, 0));
        let (elt, bit) = 0u8.idx::<u8>().offset(-1);
        assert_eq!((elt, *bit), (-1, 7));
        let (elt, bit) = 5u8.idx::<u8>().offset(20);
        assert_eq!((elt, *bit), (3, 1));
        // -9 floors to element -2, bit 7 (since -2 * 8 + 7 == -9).
        let (elt, bit) = 0u8.idx::<u8>().offset(-9);
        assert_eq!((elt, *bit), (-2, 7));
    }

    /// The checked constructor accepts exactly the indices below the width.
    #[test]
    fn idx_new_checks_range() {
        assert!(BitIdx::<u8>::new(0).is_some());
        assert!(BitIdx::<u8>::new(7).is_some());
        assert!(BitIdx::<u8>::new(8).is_none());
        assert!(BitIdx::<u8>::new(u8::max_value()).is_none());
    }

    /// Element counting and tail placement for runs starting at an index.
    #[test]
    fn span_counts_elements() {
        // Empty run: nothing touched, tail stays at the start.
        let (elts, tail) = 0u8.idx::<u8>().span(0);
        assert_eq!((elts, *tail), (0, 0));
        // Runs that fit in the first element.
        let (elts, tail) = 0u8.idx::<u8>().span(1);
        assert_eq!((elts, *tail), (1, 1));
        let (elts, tail) = 0u8.idx::<u8>().span(8);
        assert_eq!((elts, *tail), (1, 8));
        let (elts, tail) = 3u8.idx::<u8>().span(5);
        assert_eq!((elts, *tail), (1, 8));
        // Spilling one bit into a second element.
        let (elts, tail) = 3u8.idx::<u8>().span(6);
        assert_eq!((elts, *tail), (2, 1));
        // Ending exactly on an element boundary.
        let (elts, tail) = 3u8.idx::<u8>().span(13);
        assert_eq!((elts, *tail), (2, 8));
        // First element, one full element, one partial element.
        let (elts, tail) = 3u8.idx::<u8>().span(14);
        assert_eq!((elts, *tail), (3, 1));
    }
}