use crate::{
mem::BitMemory,
order::BitOrder,
store::BitStore,
};
use core::{
any::type_name,
fmt::{
self,
Binary,
Debug,
Display,
Formatter,
},
iter::{
FusedIterator,
Sum,
},
marker::PhantomData,
ops::{
BitAnd,
BitOr,
Not,
},
};
use radium::marker::BitOps;
#[cfg(feature = "serde")]
use core::convert::TryFrom;
/// Internal shorthand that builds the index wrapper types directly from a raw
/// value, bypassing their checked constructors.
///
/// The leading tag word selects the type that is built:
///
/// - `idx`  builds `BitIdx { idx, _ty: PhantomData }`
/// - `tail` builds `BitTail { end, _ty: PhantomData }`
/// - `pos`  builds `BitPos { pos, _ty: PhantomData }`
/// - `sel`  builds `BitSel { sel }`
/// - `mask` builds `BitMask { mask }`
///
/// No validation happens here; callers must supply values that respect each
/// type's invariants.
macro_rules! make {
    (idx $val:expr) => { BitIdx { idx: $val, _ty: PhantomData } };
    (tail $val:expr) => { BitTail { end: $val, _ty: PhantomData } };
    (pos $val:expr) => { BitPos { pos: $val, _ty: PhantomData } };
    (sel $val:expr) => { BitSel { sel: $val } };
    (mask $val:expr) => { BitMask { mask: $val } };
}
/// Bundles the `BitMemory`, `BitOps`, and `BitStore` requirements into one
/// bound so the index types in this module can name them with a single trait.
///
/// The trait declares no items of its own.
pub trait BitRegister
where Self: BitMemory + BitOps + BitStore
{
}
/// Emits an empty `BitRegister` implementation for every listed type.
macro_rules! register {
    ($($reg:ty),+ $(,)?) => {
        $(
            impl BitRegister for $reg {}
        )+
    };
}

register!(u8, u16, u32, usize);

// `u64` is only registered on targets whose pointers are 64 bits wide.
#[cfg(target_pointer_width = "64")]
register!(u64);
/// A semantic bit index inside a register of type `R`.
///
/// The checked constructor only admits values below `R::BITS`. The index is
/// ordering-agnostic; `position::<O>()` translates it through a `BitOrder`.
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitIdx<R: BitRegister> {
    /// The raw index value.
    idx: u8,
    /// Associates the index with `R` without storing an `R`.
    _ty: PhantomData<R>,
}
impl<R: BitRegister> BitIdx<R> {
    /// The highest index, whose raw value is `R::MASK`.
    pub(crate) const LAST: Self = make!(idx R::MASK);
    /// The lowest index, zero.
    pub(crate) const ZERO: Self = make!(idx 0);

    /// Checked constructor: `None` unless `idx` is below `R::BITS`.
    #[inline]
    #[doc(hidden)]
    pub(crate) fn new(idx: u8) -> Option<Self> {
        if idx < R::BITS {
            Some(make!(idx idx))
        }
        else {
            None
        }
    }

    /// Wraps `idx` without validating it.
    ///
    /// # Safety
    ///
    /// `idx` must be below `R::BITS`; only debug builds check this.
    #[inline]
    #[doc(hidden)]
    pub unsafe fn new_unchecked(idx: u8) -> Self {
        debug_assert!(
            idx < R::BITS,
            "Bit index {} cannot exceed type width {}",
            idx,
            R::BITS
        );
        make!(idx idx)
    }

    /// Steps forward by one.
    ///
    /// Returns the successor masked by `R::MASK`, paired with `true` when the
    /// unmasked successor equals `R::BITS` (i.e. the step wrapped).
    #[inline]
    pub(crate) fn incr(self) -> (Self, bool) {
        let raw = self.idx + 1;
        let wrapped = raw == R::BITS;
        (make!(idx raw & R::MASK), wrapped)
    }

    /// Steps backward by one.
    ///
    /// Returns the predecessor masked by `R::MASK`, paired with `true` when
    /// this index was zero (i.e. the step wrapped).
    #[inline]
    pub(crate) fn decr(self) -> (Self, bool) {
        let wrapped = self.idx == 0;
        let raw = self.idx.wrapping_sub(1);
        (make!(idx raw & R::MASK), wrapped)
    }

    /// Translates this index into an electrical position via `O::at`.
    #[inline(always)]
    pub fn position<O: BitOrder>(self) -> BitPos<R> {
        O::at::<R>(self)
    }

    /// Translates this index into a one-hot selector via `O::select`.
    #[inline(always)]
    pub fn select<O: BitOrder>(self) -> BitSel<R> {
        O::select::<R>(self)
    }

    /// Translates this index into a single-bit mask under ordering `O`.
    #[inline]
    pub fn mask<O: BitOrder>(self) -> BitMask<R> {
        let selector = O::select::<R>(self);
        selector.mask()
    }

    /// Reads the raw index value.
    #[inline(always)]
    #[cfg(not(tarpaulin_include))]
    pub fn value(self) -> u8 {
        self.idx
    }

    /// Every index from `ZERO` through `LAST`, inclusive.
    #[inline]
    pub(crate) fn range_all()
    -> impl Iterator<Item = Self> + DoubleEndedIterator + ExactSizeIterator + FusedIterator
    {
        let (first, last) = (Self::ZERO.idx, Self::LAST.idx);
        (first ..= last).map(|raw| make!(idx raw))
    }

    /// Every index from `from` up to, but excluding, `upto`.
    ///
    /// Debug builds assert that `from` does not exceed `upto`.
    #[inline]
    pub fn range(
        from: Self,
        upto: BitTail<R>,
    ) -> impl Iterator<Item = Self> + DoubleEndedIterator + ExactSizeIterator + FusedIterator
    {
        let (lo, hi) = (from.value(), upto.value());
        debug_assert!(lo <= hi, "Ranges must run from low to high");
        (lo .. hi).map(|raw| make!(idx raw))
    }

    /// Moves `by` bits away from this index.
    ///
    /// Returns the number of whole elements crossed (negative when moving
    /// below the origin element) and the resulting index within the final
    /// element.
    #[inline]
    pub(crate) fn offset(self, by: isize) -> (isize, Self) {
        let (far, overflowed) = by.overflowing_add(self.value() as isize);
        if overflowed {
            // The index is never negative, so only a large positive `by` can
            // overflow; reinterpreting the wrapped sum as `usize` recovers
            // the true distance.
            let far = far as usize;
            return ((far >> R::INDX) as isize, make!(idx far as u8 & R::MASK));
        }
        if (0 .. R::BITS as isize).contains(&far) {
            // Still inside the origin element.
            (0, make!(idx far as u8))
        }
        else {
            // Arithmetic shift rounds toward negative infinity, so negative
            // distances land in the correct preceding element.
            (far >> R::INDX, make!(idx far as u8 & R::MASK))
        }
    }

    /// Computes the span of `len` bits starting at this index by treating it
    /// as a `BitTail` and delegating to `BitTail::span`.
    #[inline]
    pub(crate) fn span(self, len: usize) -> (usize, BitTail<R>) {
        let start: BitTail<R> = make!(tail self.value());
        start.span(len)
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Binary for BitIdx<R> {
    /// Renders the index in binary, zero-padded to `R::INDX` digits.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        let width = R::INDX as usize;
        write!(fmt, "{:0>width$b}", self.idx, width = width)
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Debug for BitIdx<R> {
    /// Writes `BitIdx<R>(idx)`, with `R` spelled by `type_name`.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        fmt.write_str("BitIdx<")?;
        fmt.write_str(type_name::<R>())?;
        fmt.write_str(">(")?;
        Display::fmt(&self.idx, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitIdx<R> {
    /// Displays the raw index as a decimal number.
    #[inline(always)]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        <u8 as Display>::fmt(&self.idx, fmt)
    }
}
/// Error returned when a `u8` is too large to become a `BitIdx<R>`.
#[cfg(feature = "serde")]
#[repr(transparent)]
pub struct BitIdxErr<R: BitRegister> {
    /// The rejected value.
    err: u8,
    /// Associates the error with `R` without storing an `R`.
    _ty: PhantomData<R>,
}
#[cfg(feature = "serde")]
impl<R: BitRegister> TryFrom<u8> for BitIdx<R> {
    type Error = BitIdxErr<R>;

    /// Accepts `idx` when `BitIdx::new` does; otherwise hands the value back
    /// inside a `BitIdxErr`.
    #[inline]
    fn try_from(idx: u8) -> Result<Self, Self::Error> {
        match Self::new(idx) {
            Some(bit) => Ok(bit),
            None => Err(BitIdxErr {
                err: idx,
                _ty: PhantomData,
            }),
        }
    }
}
#[cfg(feature = "serde")]
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Debug for BitIdxErr<R> {
    /// Writes `BitIdxErr<R>(value)`, with `R` spelled by `type_name`.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        fmt.write_str("BitIdxErr<")?;
        fmt.write_str(type_name::<R>())?;
        fmt.write_str(">(")?;
        Display::fmt(&self.err, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(feature = "serde")]
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitIdxErr<R> {
    /// Explains which value was rejected and for which register type.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        let (value, target) = (self.err, type_name::<R>());
        write!(fmt, "The value {} is too large to index into {}", value, target)
    }
}
// `BitIdxErr` joins `std` error handling only when both `serde` and `std`
// are enabled; it relies on the `Debug` and `Display` impls for its output.
#[cfg(all(feature = "serde", feature = "std"))]
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> std::error::Error for BitIdxErr<R> {}
/// A one-past-the-end marker within a register of type `R`.
///
/// Unlike `BitIdx`, its value may equal `R::BITS` (see `BitTail::END`).
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitTail<R: BitRegister> {
    /// The raw tail value.
    end: u8,
    /// Associates the tail with `R` without storing an `R`.
    _ty: PhantomData<R>,
}
impl<R: BitRegister> BitTail<R> {
    /// The tail that sits after every bit of the register, `R::BITS`.
    pub(crate) const END: Self = make!(tail R::BITS);
    /// The empty tail, zero.
    pub(crate) const ZERO: Self = make!(tail 0);

    /// Wraps `end` without validating it.
    ///
    /// # Safety
    ///
    /// `end` must not exceed `R::BITS`; only debug builds check this.
    #[inline]
    pub(crate) unsafe fn new_unchecked(end: u8) -> Self {
        debug_assert!(
            end <= R::BITS,
            "Bit tail {} cannot exceed type width {}",
            end,
            R::BITS
        );
        make!(tail end)
    }

    /// Reads the raw tail value.
    #[inline]
    #[cfg(not(tarpaulin_include))]
    pub fn value(self) -> u8 {
        self.end
    }

    /// Test helper: every tail from `start` through `END`, inclusive.
    #[inline]
    #[cfg(test)]
    pub(crate) fn range_from(
        start: BitIdx<R>,
    ) -> impl Iterator<Item = Self> + DoubleEndedIterator + ExactSizeIterator + FusedIterator
    {
        let first = start.idx;
        (first ..= Self::END.end).map(|raw| make!(tail raw))
    }

    /// Measures a run of `len` bits that begins at this tail.
    ///
    /// Returns `(elements touched, tail of the run)`:
    ///
    /// - `len == 0` returns `(0, self)` unchanged.
    /// - The starting offset is `self.end & R::MASK`.
    /// - When the run ends exactly on an element boundary, the returned tail
    ///   is `1 << R::INDX` rather than zero.
    pub(crate) fn span(self, len: usize) -> (usize, Self) {
        if len == 0 {
            return (0, self);
        }

        let head = self.end & R::MASK;
        let room = (R::BITS - head) as usize;
        // Run fits in the starting element.
        if len <= room {
            return (1, make!(tail head + len as u8));
        }

        let spill = len - room;
        let whole = spill >> R::INDX;
        let partial = spill as u8 & R::MASK;
        // Branch-free: `exact` is 1 when the spill fills whole elements only.
        let exact = (partial == 0) as u8;
        let touched = whole + (2 - exact as usize);
        (touched, make!(tail (exact << R::INDX) | partial))
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Debug for BitTail<R> {
    /// Writes `BitTail<R>(end)`, with `R` spelled by `type_name`.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        fmt.write_str("BitTail<")?;
        fmt.write_str(type_name::<R>())?;
        fmt.write_str(">(")?;
        Display::fmt(&self.end, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitTail<R> {
    /// Displays the raw tail as a decimal number.
    #[inline(always)]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        <u8 as Display>::fmt(&self.end, fmt)
    }
}
/// An electrical bit position inside a register of type `R`.
///
/// `BitPos::select` shifts `R::ONE` left by this value.
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitPos<R: BitRegister> {
    /// The raw position value.
    pos: u8,
    /// Associates the position with `R` without storing an `R`.
    _ty: PhantomData<R>,
}
impl<R: BitRegister> BitPos<R> {
    /// Checked constructor: `None` unless `pos` is below `R::BITS`.
    ///
    /// # Safety
    ///
    /// The value is range-checked, yet the function is declared `unsafe`.
    /// NOTE(review): the caller contract is not visible here — presumably
    /// positions should only be minted by `BitOrder` implementations; confirm.
    #[inline]
    pub unsafe fn new(pos: u8) -> Option<Self> {
        if pos < R::BITS {
            Some(make!(pos pos))
        }
        else {
            None
        }
    }

    /// Wraps `pos` without validating it.
    ///
    /// # Safety
    ///
    /// `pos` must be below `R::BITS`; only debug builds check this.
    #[inline]
    pub unsafe fn new_unchecked(pos: u8) -> Self {
        debug_assert!(
            pos < R::BITS,
            "Bit position {} cannot exceed type width {}",
            pos,
            R::BITS
        );
        make!(pos pos)
    }

    /// Builds the one-hot selector `R::ONE << pos`.
    #[inline]
    pub fn select(self) -> BitSel<R> {
        let one = R::ONE;
        make!(sel one << self.pos)
    }

    /// Builds a mask with only this position set.
    #[inline]
    pub fn mask(self) -> BitMask<R> {
        self.select().mask()
    }

    /// Reads the raw position value.
    #[inline]
    pub fn value(self) -> u8 {
        self.pos
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Debug for BitPos<R> {
    /// Writes `BitPos<R>(pos)`, with `R` spelled by `type_name`.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        fmt.write_str("BitPos<")?;
        fmt.write_str(type_name::<R>())?;
        fmt.write_str(">(")?;
        Display::fmt(&self.pos, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitPos<R> {
    /// Displays the raw position as a decimal number.
    #[inline(always)]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        <u8 as Display>::fmt(&self.pos, fmt)
    }
}
/// A register value with exactly one bit set, used to select that bit.
///
/// The checked constructor only admits values whose `count_ones()` is 1.
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitSel<R: BitRegister> {
    /// The one-hot register value.
    sel: R,
}
impl<R> BitSel<R>
where R: BitRegister
{
#[inline]
pub unsafe fn new(sel: R) -> Option<Self> {
if sel.count_ones() != 1 {
return None;
}
Some(make!(sel sel))
}
#[inline]
pub unsafe fn new_unchecked(sel: R) -> Self {
debug_assert!(
sel.count_ones() == 1,
"Selections are required to have exactly one set bit: {:0>1$b}",
sel,
R::BITS as usize
);
make!(sel sel)
}
#[inline]
pub fn mask(self) -> BitMask<R>
where R: BitRegister {
make!(mask self.sel)
}
#[inline]
pub fn value(self) -> R {
self.sel
}
pub fn range_all() -> impl Iterator<Item = Self>
+ DoubleEndedIterator
+ ExactSizeIterator
+ FusedIterator {
BitIdx::<R>::range_all().map(|i| make!(pos i.idx).select())
}
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Binary for BitSel<R> {
    /// Renders the selector in binary, zero-padded to `R::BITS` digits.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        let width = R::BITS as usize;
        write!(fmt, "{:0>width$b}", self.sel, width = width)
    }
}
#[cfg(not(tarpaulin_include))]
impl<R> Debug for BitSel<R>
where R: BitRegister
{
    /// Writes `BitSel<R>(bits)`, where `bits` is this selector's `Binary`
    /// rendering (zero-padded to `R::BITS` digits).
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        write!(fmt, "BitSel<{}>(", type_name::<R>())?;
        // `self` is already `&Self`; pass it straight to this type's `Binary`
        // impl instead of borrowing again and routing through the blanket
        // `impl Binary for &T` forwarder.
        Binary::fmt(self, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitSel<R> {
    /// Displays the raw one-hot value using `R`'s own `Display`.
    #[inline(always)]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        <R as Display>::fmt(&self.sel, fmt)
    }
}
/// A register value in which any number of bits may be set.
///
/// Built by combining `BitSel` selectors or by bitwise operations with raw
/// register values.
#[repr(transparent)]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BitMask<R: BitRegister> {
    /// The raw mask value.
    mask: R,
}
impl<R: BitRegister> BitMask<R> {
    /// A mask with every bit set (`R::ALL`).
    pub const ALL: Self = make!(mask R::ALL);
    /// A mask with no bits set (`R::ZERO`).
    pub const ZERO: Self = make!(mask R::ZERO);

    /// Wraps an arbitrary register value as a mask.
    ///
    /// # Safety
    ///
    /// No check is performed. NOTE(review): the reason this is `unsafe` is
    /// not visible in this file; confirm the intended caller contract.
    #[inline]
    pub unsafe fn new(mask: R) -> Self {
        make!(mask mask)
    }

    /// Returns a copy of this mask with the selector's bit also set.
    #[inline]
    pub fn combine(self, sel: BitSel<R>) -> Self {
        let mut merged = self;
        merged.insert(sel);
        merged
    }

    /// Sets the selector's bit in place.
    #[inline]
    pub fn insert(&mut self, sel: BitSel<R>) {
        self.mask |= sel.sel;
    }

    /// Reports whether the selector's bit is set in this mask.
    #[inline]
    pub fn test(self, sel: BitSel<R>) -> bool {
        let overlap = self.mask & sel.sel;
        overlap != R::ZERO
    }

    /// Reads the raw mask value.
    #[inline]
    pub fn value(self) -> R {
        self.mask
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Binary for BitMask<R> {
    /// Renders the mask in binary, zero-padded to `R::BITS` digits.
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        let width = R::BITS as usize;
        write!(fmt, "{:0>width$b}", self.mask, width = width)
    }
}
#[cfg(not(tarpaulin_include))]
impl<R> Debug for BitMask<R>
where R: BitRegister
{
    /// Writes `BitMask<R>(bits)`, where `bits` is this mask's `Binary`
    /// rendering (zero-padded to `R::BITS` digits).
    #[inline]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        write!(fmt, "BitMask<{}>(", type_name::<R>())?;
        // `self` is already `&Self`; pass it straight to this type's `Binary`
        // impl instead of borrowing again and routing through the blanket
        // `impl Binary for &T` forwarder.
        Binary::fmt(self, fmt)?;
        fmt.write_str(")")
    }
}
#[cfg(not(tarpaulin_include))]
impl<R: BitRegister> Display for BitMask<R> {
    /// Displays the raw mask using `R`'s own `Display`.
    #[inline(always)]
    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
        <R as Display>::fmt(&self.mask, fmt)
    }
}
impl<R: BitRegister> Sum<BitSel<R>> for BitMask<R> {
    /// Sets every selected bit; an empty iterator yields `BitMask::ZERO`.
    fn sum<I: Iterator<Item = BitSel<R>>>(iter: I) -> Self {
        let mut total = Self::ZERO;
        for sel in iter {
            total.insert(sel);
        }
        total
    }
}
impl<R> BitAnd<R> for BitMask<R>
where R: BitRegister
{
type Output = Self;
fn bitand(self, rhs: R) -> Self {
make!(mask self.mask & rhs)
}
}
impl<R> BitOr<R> for BitMask<R>
where R: BitRegister
{
type Output = Self;
fn bitor(self, rhs: R) -> Self {
make!(mask self.mask | rhs)
}
}
impl<R> Not for BitMask<R>
where R: BitRegister
{
type Output = Self;
fn not(self) -> Self::Output {
make!(mask !self.mask)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::order::{
        Lsb0,
        Msb0,
    };

    /// Checked construction plus ordering translations for every `u8` index.
    #[test]
    fn index_fns() {
        assert!(BitIdx::<u8>::new(8).is_none());
        for n in 0 .. 8u8 {
            let idx = BitIdx::<u8>::new(n).unwrap();
            assert_eq!(idx.position::<Lsb0>().value(), n);
            assert_eq!(idx.position::<Msb0>().value(), 7 - n);
            assert_eq!(idx.mask::<Lsb0>().value(), 1 << n);
            assert_eq!(idx.mask::<Msb0>().value(), 128 >> n);
            assert_eq!(idx.value(), n);
        }
    }

    /// A tail reports back the value it was built from.
    #[test]
    fn tail_fns() {
        for n in 0 .. 8u8 {
            let tail: BitTail<u8> = make!(tail n);
            assert_eq!(tail.value(), n);
        }
    }

    /// Positions reject out-of-range values and mask to a single bit.
    #[test]
    fn position_fns() {
        assert!(unsafe { BitPos::<u8>::new(8) }.is_none());
        for n in 0 .. 8u8 {
            let pos: BitPos<u8> = make!(pos n);
            let expected: BitMask<u8> = make!(mask 1 << n);
            assert_eq!(pos.mask(), expected);
        }
    }

    /// Selectors must be one-hot, and `range_all` walks every bit in order.
    #[test]
    fn select_fns() {
        assert!(unsafe { BitSel::<u8>::new(1) }.is_some());
        assert!(unsafe { BitSel::<u8>::new(3) }.is_none());
        for (shift, sel) in BitSel::<u8>::range_all().enumerate() {
            let expected: BitSel<u8> = make!(sel (1 << shift) as u8);
            assert_eq!(sel, expected);
        }
    }

    /// OR-ing every selector yields a full mask, and negation empties it.
    #[test]
    fn fold_masks() {
        let folded = BitSel::<u8>::range_all()
            .map(BitSel::mask)
            .fold(BitMask::<u8>::ZERO, |acc, mask| acc | mask.value());
        assert_eq!(folded, BitMask::<u8>::ALL);
        assert_eq!(!BitMask::<u8>::ALL, BitMask::ZERO);
    }

    /// Offsetting past `isize::MAX` takes the overflow branch correctly.
    #[test]
    fn offset() {
        let start = BitIdx::<u32>::new(31).unwrap();
        let (elts, idx) = start.offset(isize::max_value());
        assert_eq!(elts, (isize::max_value() >> 5) + 1);
        assert_eq!(idx, BitIdx::new(30).unwrap());
    }

    /// Empty, in-element, and cross-element spans from a mid-register tail.
    #[test]
    fn span() {
        let start: BitTail<u8> = make!(tail 4);
        let filled: BitTail<u8> = make!(tail 8);
        assert_eq!(start.span(0), (0, start));
        assert_eq!(start.span(4), (1, filled));
        assert_eq!(start.span(8), (2, start));
    }

    /// Stepping from the last index wraps forward but not backward.
    #[test]
    fn walk() {
        let end: BitIdx<u8> = make!(idx 7);
        let (next, wrapped_up) = end.incr();
        assert_eq!((next, wrapped_up), (make!(idx 0), true));
        let (prev, wrapped_down) = end.decr();
        assert_eq!((prev, wrapped_down), (make!(idx 6), false));
    }
}