use std::fmt::{Binary, Debug, Display, LowerHex, UpperHex};
use std::mem::size_of;
use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Not, Shl, ShlAssign, Shr, ShrAssign};
use bit_vec::BitVec;
use thiserror::Error;
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BitsError {
#[error("index out of range")]
IndexOutOfRange,
}
/// Fixed-width unsigned integers whose individual bits can be read and written.
///
/// Bit positions count from the least-significant bit (`0`) up to
/// [`Bits::MASK`] (`WIDTH - 1`). Every accessor rejects a larger position with
/// [`BitsError::IndexOutOfRange`] rather than performing an over-wide shift.
pub trait Bits:
    Binary
    + BitAnd<Self, Output = Self>
    + BitAndAssign<Self>
    + BitOrAssign<Self>
    + Copy
    + Debug
    + Display
    + Eq
    + From<u8>
    + LowerHex
    + Not<Output = Self>
    + Shl<u8, Output = Self>
    + ShlAssign<u8>
    + Shr<u8, Output = Self>
    + ShrAssign<u8>
    + Sized
    + UpperHex
{
    /// Width of the type in bits, derived from its size in bytes (`8` for `u8`).
    const WIDTH: u8 = size_of::<Self>() as u8 * 8;
    /// Supplied by each implementor; for the impls in this module it equals
    /// `log2(WIDTH)` (3 for `u8`, 4 for `u16`, 5 for `u32`).
    const BITS: u8;
    /// Highest valid bit position, `WIDTH - 1`.
    const MASK: u8 = Self::WIDTH - 1;
    /// `usize::MAX` shifted right by [`Bits::BITS`].
    ///
    /// NOTE(review): presumably the largest element count addressable when each
    /// element stores `WIDTH` bits — confirm against callers.
    const MAX_ELT: usize = usize::MAX >> Self::BITS;

    /// Sets the bit at `place` to `value`, leaving all other bits unchanged.
    ///
    /// # Errors
    /// Returns [`BitsError::IndexOutOfRange`] if `place > Self::MASK`; `self`
    /// is not modified in that case.
    fn set(&mut self, place: u8, value: bool) -> Result<(), BitsError> {
        if place > Self::MASK {
            return Err(BitsError::IndexOutOfRange);
        }
        // Clear the target bit first so OR-ing in `value` fully overwrites it.
        *self &= !(Self::from(1u8) << place);
        *self |= Self::from(u8::from(value)) << place;
        Ok(())
    }

    /// Returns whether the bit at `place` is set.
    ///
    /// # Errors
    /// Returns [`BitsError::IndexOutOfRange`] if `place > Self::MASK`.
    fn get(&self, place: u8) -> Result<bool, BitsError> {
        if place > Self::MASK {
            return Err(BitsError::IndexOutOfRange);
        }
        Ok((*self >> place) & Self::from(1u8) == Self::from(1u8))
    }

    /// Clears the bit at `place` and returns whether it was set beforehand.
    ///
    /// # Errors
    /// Returns [`BitsError::IndexOutOfRange`] if `place > Self::MASK`; `self`
    /// is not modified in that case.
    fn take(&mut self, place: u8) -> Result<bool, BitsError> {
        if place > Self::MASK {
            return Err(BitsError::IndexOutOfRange);
        }
        let mask = Self::from(1u8) << place;
        let was_set = *self & mask != Self::from(0u8);
        *self &= !mask;
        Ok(was_set)
    }

    /// Name of the implementing primitive type (e.g. `"u8"`).
    #[doc(hidden)]
    const TY: &'static str;
}
impl Bits for u8 {
const BITS: u8 = 3;
const TY: &'static str = "u8";
}
impl Bits for u16 {
const BITS: u8 = 4;
const TY: &'static str = "u16";
}
impl Bits for u32 {
const BITS: u8 = 5;
const TY: &'static str = "u32";
}
pub trait BitReverse {
fn reverse(&self) -> Self;
}
impl BitReverse for BitVec {
#[inline]
fn reverse(&self) -> BitVec {
let mut reversed = BitVec::new();
for bit in self.iter().rev() {
reversed.push(bit)
}
reversed
}
}
pub trait BitTrim {
fn trim_left(&self) -> Self;
}
impl BitTrim for BitVec {
fn trim_left(&self) -> BitVec {
let mut trimmed: BitVec = BitVec::new();
let mut notrim = false;
for bit in self.iter() {
if bit {
trimmed.push(bit);
notrim = true;
} else if notrim {
trimmed.push(bit);
}
}
trimmed
}
}
pub trait ToBytes {
fn to_byte_vec(&self) -> Vec<u8>;
}
impl ToBytes for BitVec {
fn to_byte_vec(&self) -> Vec<u8> {
let mut bytes = vec![];
let mut byte = 0;
let mut offset = 0;
for (idx_bit, bit) in self.iter().rev().enumerate() {
let idx_byte = (idx_bit % 8) as u8;
byte.set(idx_byte, bit)
.unwrap_or_else(|_| unreachable!("Byte should have 8-bit width"));
if idx_byte == 7 {
bytes.push(byte);
byte = 0;
}
offset = idx_byte;
}
if offset != 7 {
bytes.push(byte);
}
bytes.reverse();
bytes
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BitVec` from a string of `'0'`/`'1'` characters; the first
    /// character becomes index 0.
    fn bits(pattern: &str) -> BitVec {
        pattern.chars().map(|c| c == '1').collect()
    }

    #[test]
    fn reverse() {
        let input = bits("000111010");
        assert_eq!(vec![0, 184], input.reverse().to_byte_vec());
    }

    #[test]
    fn trim_left() {
        let input = bits("00011101");
        assert_eq!(vec![29], input.trim_left().to_byte_vec());
    }

    #[test]
    fn to_byte_vec() {
        let mut input = bits("10011101");
        assert_eq!(vec![157], input.to_byte_vec());

        input.push(false);
        assert_eq!(vec![1, 58], input.to_byte_vec());
    }

    #[test]
    fn set_out_of_range() {
        let mut narrow = 0_u8;
        let mut medium = 0_u16;
        let mut wide = 0_u32;
        assert_eq!(narrow.set(8, true), Err(BitsError::IndexOutOfRange));
        assert_eq!(medium.set(16, true), Err(BitsError::IndexOutOfRange));
        assert_eq!(wide.set(32, true), Err(BitsError::IndexOutOfRange));
    }

    #[test]
    fn get_out_of_range() {
        assert_eq!(0_u8.get(8), Err(BitsError::IndexOutOfRange));
        assert_eq!(0_u16.get(16), Err(BitsError::IndexOutOfRange));
        assert_eq!(0_u32.get(32), Err(BitsError::IndexOutOfRange));
    }

    #[test]
    fn test_take() {
        let mut sparse = 0u8;
        sparse.set(4, true).unwrap();
        assert_eq!(sparse.take(4), Ok(true));
        assert_eq!(sparse.get(4), Ok(false));

        let mut dense = 0xffu8;
        dense.set(4, false).unwrap();
        assert_eq!(dense.take(4), Ok(false));
        assert_eq!(dense.get(4), Ok(false));
        assert_eq!(dense, 0xefu8);
    }
}