use crate::CompactBitset;
use num_traits::{Num, NumCast, PrimInt, ToPrimitive, Unsigned};
use std::{fmt::Debug, mem::size_of, ops::Deref};
/// Converts `num` into the numeric type `T` via [`NumCast`].
///
/// Used in this module to materialise small literals (such as `0` and `1`)
/// in the bitset's backing integer type.
///
/// # Panics
///
/// Panics if `num` is not representable in `T` (e.g. a negative value when
/// `T` is unsigned).
#[inline]
fn numcast<T: NumCast>(num: impl Num + ToPrimitive) -> T {
    T::from(num).expect("numcast: value is not representable in the target type")
}
/// A read-only handle to a single bit of a [`CompactBitset`].
///
/// The handle only holds a shared reference and an index, so it is cheap to
/// copy; dereferencing it yields the bit's current value as a `bool`.
#[derive(Clone, Copy)]
pub struct BitRef<'parent, T: PrimInt + Unsigned> {
    /// The bitset this handle reads from.
    pub(crate) parent: &'parent CompactBitset<T>,
    /// Bit position within `parent.data`, counted from the least-significant bit.
    pub(crate) idx: usize,
}
impl<'parent, T> BitRef<'parent, T>
where
    T: PrimInt + Unsigned,
{
    /// Width of the backing integer type `T`, measured in bits.
    pub const BITS: usize = 8 * size_of::<T>();
}
impl<'parent, T: PrimInt + Unsigned> Debug for BitRef<'parent, T> {
    /// Renders the referenced bit as the text `true` or `false`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = if **self { "true" } else { "false" };
        f.write_str(text)
    }
}
impl<'parent, T: PrimInt + Unsigned> Deref for BitRef<'parent, T> {
    type Target = bool;

    /// Reads bit `idx` of the parent's `data`: shift it down to position 0,
    /// mask off every other bit, and test the remainder against zero.
    fn deref(&self) -> &Self::Target {
        let one: T = numcast(1);
        let zero: T = numcast(0);
        let lowest_bit = (self.parent.data >> self.idx) & one;
        if lowest_bit == zero {
            &false
        } else {
            &true
        }
    }
}
/// Consumes a [`BitRef`] and returns the value of the bit it refers to.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still
/// provides `Into<bool> for BitRef`, and `bool::from(bit_ref)` works as well.
impl<'parent, T: PrimInt + Unsigned> From<BitRef<'parent, T>> for bool {
    fn from(bit: BitRef<'parent, T>) -> bool {
        *bit
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq for BitRef<'parent, T> {
    /// Two handles compare equal when the bits they refer to hold the same
    /// value; which bitset or position they point at is not considered.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs): (bool, bool) = (**self, **other);
        lhs == rhs
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq<bool> for BitRef<'parent, T> {
    /// Compares the referenced bit's value with a plain `bool`.
    fn eq(&self, other: &bool) -> bool {
        let current: bool = **self;
        current == *other
    }
}
/// A writable handle to a single bit of a [`CompactBitset`].
///
/// The parent is borrowed mutably for `'parent`, so no other reference to the
/// bitset can coexist with this handle.
pub struct BitRefMut<'parent, T>
where
    T: PrimInt + Unsigned,
{
    /// The bitset this handle reads from and writes to.
    pub(crate) parent: &'parent mut CompactBitset<T>,
    /// Bit position within `parent.data`, counted from the least-significant bit.
    pub(crate) idx: usize,
}
impl<'parent, T: PrimInt + Unsigned> BitRefMut<'parent, T> {
    /// Width of the backing integer type `T`, in bits.
    pub const BITS: usize = size_of::<T>() * 8;

    /// Writes `value` into the referenced bit, leaving every other bit of the
    /// parent's `data` unchanged.
    ///
    /// The mask is built directly in `T` (not in an `i32`/`usize`
    /// intermediate), so every position of `T` is addressable. This includes
    /// bit 31 and above for `u64`, and bit 64 and above for `u128`.
    ///
    /// NOTE(review): assumes `self.idx < Self::BITS`, presumably guaranteed by
    /// whoever constructs the handle — confirm. An out-of-range index makes the
    /// shift overflow, which panics in debug builds.
    pub fn set(&mut self, value: impl Into<bool>) {
        let value: bool = value.into();
        let mask: T = T::one() << self.idx;
        let data = self.parent.data;
        // Set or clear just the masked bit; all other bits pass through untouched.
        self.parent.data = if value { data | mask } else { data & !mask };
    }

    /// Returns a read-only [`BitRef`] to the same bit, borrowing from `self`.
    fn as_const(&self) -> BitRef<'_, T> {
        BitRef {
            parent: &*self.parent,
            idx: self.idx,
        }
    }
}
impl<'parent, T: PrimInt + Unsigned> Debug for BitRefMut<'parent, T> {
    /// Renders the referenced bit as the text `true` or `false`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = if **self { "true" } else { "false" };
        f.write_str(text)
    }
}
impl<'parent, T: PrimInt + Unsigned> Deref for BitRefMut<'parent, T> {
    type Target = bool;

    /// Reads the bit through a temporary read-only view. That view is dropped
    /// when this call returns, so the result is re-expressed as a reference to
    /// one of the constants `true` / `false` rather than borrowed from it.
    fn deref(&self) -> &Self::Target {
        match *self.as_const() {
            true => &true,
            false => &false,
        }
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq for BitRefMut<'parent, T> {
    /// Two writable handles compare equal when their referenced bits hold the
    /// same value.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs): (bool, bool) = (**self, **other);
        lhs == rhs
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq<bool> for BitRefMut<'parent, T> {
    /// Compares the referenced bit's value with a plain `bool`.
    fn eq(&self, other: &bool) -> bool {
        let current: bool = **self;
        current == *other
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq<BitRef<'parent, T>> for BitRefMut<'parent, T> {
    /// A writable and a read-only handle compare equal when their referenced
    /// bits hold the same value.
    fn eq(&self, other: &BitRef<'parent, T>) -> bool {
        let (mine, theirs): (bool, bool) = (**self, **other);
        mine == theirs
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq<BitRef<'parent, T>> for bool {
    /// Mirror of `BitRef == bool`, so the comparison works with `bool` on the left.
    fn eq(&self, other: &BitRef<'parent, T>) -> bool {
        let bit: bool = **other;
        *self == bit
    }
}
impl<'parent, T: PrimInt + Unsigned> PartialEq<BitRefMut<'parent, T>> for bool {
    /// Mirror of `BitRefMut == bool`, so the comparison works with `bool` on the left.
    fn eq(&self, other: &BitRefMut<'parent, T>) -> bool {
        let bit: bool = **other;
        *self == bit
    }
}