use core::{
mem,
ptr,
};
use funty::IsNumber;
use tap::pipe::Pipe;
use crate::{
access::BitAccess,
array::BitArray,
domain::{
Domain,
DomainMut,
},
index::BitMask,
mem::BitMemory,
order::{
BitOrder,
Lsb0,
Msb0,
},
slice::BitSlice,
store::BitStore,
view::BitView,
};
#[cfg(feature = "alloc")]
use crate::{
boxed::BitBox,
vec::BitVec,
};
/// Transfers plain integers into and out of bit-slice regions.
///
/// The four explicitly-ordered methods are required; `load` and `store` are
/// provided and forward to whichever ordered variant matches the target's
/// byte order. In this module's `BitSlice` implementations, the `_le` methods
/// treat the lowest-addressed memory element as the least significant, the
/// `_be` methods treat it as the most significant, and every method panics if
/// the region is empty or longer than `M::BITS`.
pub trait BitField {
    /// Loads the region into an `M` using the target's native element order:
    /// `load_le` on little-endian targets, `load_be` on big-endian ones.
    fn load<M>(&self) -> M
    where M: BitMemory {
        if cfg!(target_endian = "little") {
            self.load_le::<M>()
        }
        else {
            self.load_be::<M>()
        }
    }

    /// Stores `value` into the region using the target's native element
    /// order: `store_le` on little-endian targets, `store_be` on big-endian
    /// ones.
    fn store<M>(&mut self, value: M)
    where M: BitMemory {
        if cfg!(target_endian = "little") {
            self.store_le(value);
        }
        else {
            self.store_be(value);
        }
    }

    /// Loads the region, treating the lowest-addressed element as least
    /// significant.
    fn load_le<M>(&self) -> M
    where M: BitMemory;

    /// Loads the region, treating the lowest-addressed element as most
    /// significant.
    fn load_be<M>(&self) -> M
    where M: BitMemory;

    /// Stores `value`, placing its least significant bits in the
    /// lowest-addressed element.
    fn store_le<M>(&mut self, value: M)
    where M: BitMemory;

    /// Stores `value`, placing its most significant bits in the
    /// lowest-addressed element.
    fn store_be<M>(&mut self, value: M)
    where M: BitMemory;
}
/// `BitField` for `Lsb0`-ordered slices.
///
/// In `Lsb0` order a bit index counts up from the least significant bit of an
/// element. A partial head element therefore keeps its live bits at the top
/// of the element (from `head.value()` upward), and a partial tail element
/// keeps its live bits at the bottom (below `tail.value()`).
impl<T> BitField for BitSlice<Lsb0, T>
where T: BitStore
{
    /// Loads with the lowest-addressed element as least significant.
    fn load_le<M>(&self) -> M
    where M: BitMemory {
        check::<M>("load", self.len());
        match self.domain() {
            // The region lives inside one element: mask it and shift the live
            // bits down so they start at bit 0.
            Domain::Enclave { head, elem, tail } => {
                get::<T, M>(elem, Lsb0::mask(head, tail), head.value())
            },
            Domain::Region { head, body, tail } => {
                let mut accum = M::ZERO;
                // Little-endian: the tail (highest address) holds the most
                // significant bits, so it is read first and shifted upward as
                // lower-order elements are appended beneath it.
                if let Some((elem, tail)) = tail {
                    accum = get::<T, M>(elem, Lsb0::mask(None, tail), 0);
                }
                // Fully-used elements, from high address to low.
                for elem in body.iter().rev().map(BitStore::load_value) {
                    // Only shift when `M` is wider than an element; shifting
                    // an integer by its full width would overflow. (`check`
                    // caps `len` at `M::BITS`, so with a narrower-or-equal `M`
                    // a body element can only be the sole contributor.)
                    if M::BITS > T::Mem::BITS {
                        accum <<= T::Mem::BITS;
                    }
                    accum |= resize::<T::Mem, M>(elem);
                }
                // The head (lowest address) supplies the least significant
                // bits.
                if let Some((head, elem)) = head {
                    let shamt = head.value();
                    // Number of live bits in the head element.
                    let rshamt = T::Mem::BITS as u8 - shamt;
                    // Avoid a full-width shift; in that case the head is the
                    // only contributor and `accum` is replaced outright.
                    if M::BITS as u8 > rshamt {
                        accum <<= rshamt;
                    }
                    else {
                        accum = M::ZERO;
                    }
                    accum |= get::<T, M>(elem, Lsb0::mask(head, None), shamt);
                }
                accum
            },
        }
    }

    /// Loads with the lowest-addressed element as most significant.
    fn load_be<M>(&self) -> M
    where M: BitMemory {
        check::<M>("load", self.len());
        match self.domain() {
            // A single element has no element order; same as `load_le`.
            Domain::Enclave { head, elem, tail } => {
                get::<T, M>(elem, Lsb0::mask(head, tail), head.value())
            },
            Domain::Region { head, body, tail } => {
                let mut accum = M::ZERO;
                // Big-endian: the head (lowest address) holds the most
                // significant bits, so it is read first.
                if let Some((head, elem)) = head {
                    accum =
                        get::<T, M>(elem, Lsb0::mask(head, None), head.value());
                }
                // Fully-used elements, from low address to high.
                for elem in body.iter().map(BitStore::load_value) {
                    // Skip the shift when it would be full-width (overflow).
                    if M::BITS > T::Mem::BITS {
                        accum <<= T::Mem::BITS;
                    }
                    accum |= resize::<T::Mem, M>(elem);
                }
                // The tail (highest address) supplies the least significant
                // bits: `tail.value()` live bits at the bottom of its element.
                if let Some((elem, tail)) = tail {
                    let shamt = tail.value();
                    // Avoid a full-width shift; see the head case in
                    // `load_le`.
                    if M::BITS as u8 > shamt {
                        accum <<= shamt;
                    }
                    else {
                        accum = M::ZERO;
                    }
                    accum |= get::<T, M>(elem, Lsb0::mask(None, tail), 0);
                }
                accum
            },
        }
    }

    /// Stores with the lowest-addressed element as least significant.
    fn store_le<M>(&mut self, mut value: M)
    where M: BitMemory {
        check::<M>("store", self.len());
        match self.domain_mut() {
            // Single element: shift `value` up to `head.value()` and write
            // it under the mask.
            DomainMut::Enclave { head, elem, tail } => {
                set::<T, M>(elem, value, Lsb0::mask(head, tail), head.value());
            },
            DomainMut::Region { head, body, tail } => {
                // The head (lowest address) takes the least significant bits
                // of `value`, shifted up to start at `head.value()`.
                if let Some((head, elem)) = head {
                    let shamt = head.value();
                    set::<T, M>(elem, value, Lsb0::mask(head, None), shamt);
                    // Discard the bits just written (the head's live width),
                    // avoiding a full-width shift.
                    let lshamt = T::Mem::BITS as u8 - shamt;
                    if M::BITS as u8 > lshamt {
                        value >>= lshamt;
                    }
                    else {
                        value = M::ZERO;
                    }
                }
                // Fully-used elements, from low address to high, each taking
                // the next element-width of `value`.
                for elem in body.iter_mut() {
                    elem.store_value(resize(value));
                    if M::BITS > T::Mem::BITS {
                        value >>= T::Mem::BITS;
                    }
                }
                // Whatever remains goes into the bottom of the tail element.
                if let Some((elem, tail)) = tail {
                    set::<T, M>(elem, value, Lsb0::mask(None, tail), 0);
                }
            },
        }
    }

    /// Stores with the lowest-addressed element as most significant.
    fn store_be<M>(&mut self, mut value: M)
    where M: BitMemory {
        check::<M>("store", self.len());
        match self.domain_mut() {
            // A single element has no element order; same as `store_le`.
            DomainMut::Enclave { head, elem, tail } => {
                set::<T, M>(elem, value, Lsb0::mask(head, tail), head.value());
            },
            DomainMut::Region { head, body, tail } => {
                // Writing starts from the least significant end: the tail
                // (highest address) takes the low bits of `value` at the
                // bottom of its element.
                if let Some((elem, tail)) = tail {
                    set::<T, M>(elem, value, Lsb0::mask(None, tail), 0);
                    let shamt = tail.value();
                    // Drop the bits just written; avoid a full-width shift.
                    if M::BITS as u8 > shamt {
                        value >>= shamt;
                    }
                    else {
                        value = M::ZERO;
                    }
                }
                // Fully-used elements, from high address to low.
                for elem in body.iter_mut().rev() {
                    elem.store_value(resize(value));
                    if M::BITS > T::Mem::BITS {
                        value >>= T::Mem::BITS;
                    }
                }
                // The head (lowest address) takes the most significant bits,
                // shifted up to start at `head.value()`.
                if let Some((head, elem)) = head {
                    set::<T, M>(
                        elem,
                        value,
                        Lsb0::mask(head, None),
                        head.value(),
                    );
                }
            },
        }
    }
}
/// `BitField` for `Msb0`-ordered slices.
///
/// In `Msb0` order a bit index counts down from the most significant bit of
/// an element. A partial head element therefore keeps its live bits at the
/// bottom of the element (the lowest `T::Mem::BITS - head.value()` bits), and
/// a partial tail element keeps its live bits at the top (from
/// `T::Mem::BITS - tail.value()` upward).
impl<T> BitField for BitSlice<Msb0, T>
where T: BitStore
{
    /// Loads with the lowest-addressed element as least significant.
    fn load_le<M>(&self) -> M
    where M: BitMemory {
        check::<M>("load", self.len());
        match self.domain() {
            // Single element: the lowest live bit sits at
            // `T::Mem::BITS - tail.value()`; shift it down to bit 0.
            Domain::Enclave { head, elem, tail } => get::<T, M>(
                elem,
                Msb0::mask(head, tail),
                T::Mem::BITS as u8 - tail.value(),
            ),
            Domain::Region { head, body, tail } => {
                let mut accum = M::ZERO;
                // Little-endian: the tail (highest address) is most
                // significant and is read first; its live bits sit at the top
                // of the element.
                if let Some((elem, tail)) = tail {
                    accum = get::<T, M>(
                        elem,
                        Msb0::mask(None, tail),
                        T::Mem::BITS as u8 - tail.value(),
                    );
                }
                // Fully-used elements, from high address to low.
                for elem in body.iter().rev().map(BitStore::load_value) {
                    // Only shift when `M` is wider than an element; a
                    // full-width shift would overflow.
                    if M::BITS > T::Mem::BITS {
                        accum <<= T::Mem::BITS;
                    }
                    accum |= resize::<T::Mem, M>(elem);
                }
                // The head (lowest address) is least significant. Its live
                // bits already sit at the bottom of the element, so `get`
                // needs no shift.
                if let Some((head, elem)) = head {
                    // Number of live bits in the head element.
                    let shamt = T::Mem::BITS as u8 - head.value();
                    // Avoid a full-width shift; then the head is the only
                    // contributor and `accum` is replaced outright.
                    if M::BITS as u8 > shamt {
                        accum <<= shamt;
                    }
                    else {
                        accum = M::ZERO;
                    }
                    accum |= get::<T, M>(elem, Msb0::mask(head, None), 0);
                }
                accum
            },
        }
    }

    /// Loads with the lowest-addressed element as most significant.
    fn load_be<M>(&self) -> M
    where M: BitMemory {
        check::<M>("load", self.len());
        match self.domain() {
            // A single element has no element order; same as `load_le`.
            Domain::Enclave { head, elem, tail } => get::<T, M>(
                elem,
                Msb0::mask(head, tail),
                T::Mem::BITS as u8 - tail.value(),
            ),
            Domain::Region { head, body, tail } => {
                let mut accum = M::ZERO;
                // Big-endian: the head (lowest address) is most significant
                // and is read first.
                if let Some((head, elem)) = head {
                    accum = get::<T, M>(elem, Msb0::mask(head, None), 0);
                }
                // Fully-used elements, from low address to high.
                for elem in body.iter().map(BitStore::load_value) {
                    if M::BITS > T::Mem::BITS {
                        accum <<= T::Mem::BITS;
                    }
                    accum |= resize::<T::Mem, M>(elem);
                }
                // The tail (highest address) is least significant: it has
                // `tail.value()` live bits at the top of its element.
                if let Some((elem, tail)) = tail {
                    let shamt = tail.value();
                    // Avoid a full-width shift.
                    if M::BITS as u8 > shamt {
                        accum <<= shamt;
                    }
                    else {
                        accum = M::ZERO;
                    }
                    accum |= get::<T, M>(
                        elem,
                        Msb0::mask(None, tail),
                        T::Mem::BITS as u8 - shamt,
                    );
                }
                accum
            },
        }
    }

    /// Stores with the lowest-addressed element as least significant.
    fn store_le<M>(&mut self, mut value: M)
    where M: BitMemory {
        check::<M>("store", self.len());
        match self.domain_mut() {
            // Single element: shift `value` up so its low bit lands at
            // `T::Mem::BITS - tail.value()`.
            DomainMut::Enclave { head, elem, tail } => set::<T, M>(
                elem,
                value,
                Msb0::mask(head, tail),
                T::Mem::BITS as u8 - tail.value(),
            ),
            DomainMut::Region { head, body, tail } => {
                // The head (lowest address) takes the least significant bits
                // of `value`, at the bottom of its element.
                if let Some((head, elem)) = head {
                    set::<T, M>(elem, value, Msb0::mask(head, None), 0);
                    // Drop the head's live width from `value`, avoiding a
                    // full-width shift.
                    let shamt = T::Mem::BITS as u8 - head.value();
                    if M::BITS as u8 > shamt {
                        value >>= shamt;
                    }
                    else {
                        value = M::ZERO;
                    }
                }
                // Fully-used elements, from low address to high.
                for elem in body.iter_mut() {
                    elem.store_value(resize(value));
                    if M::BITS > T::Mem::BITS {
                        value >>= T::Mem::BITS;
                    }
                }
                // The remainder goes into the top of the tail element.
                if let Some((elem, tail)) = tail {
                    set::<T, M>(
                        elem,
                        value,
                        Msb0::mask(None, tail),
                        T::Mem::BITS as u8 - tail.value(),
                    );
                }
            },
        }
    }

    /// Stores with the lowest-addressed element as most significant.
    fn store_be<M>(&mut self, mut value: M)
    where M: BitMemory {
        check::<M>("store", self.len());
        match self.domain_mut() {
            // A single element has no element order; same as `store_le`.
            DomainMut::Enclave { head, elem, tail } => set::<T, M>(
                elem,
                value,
                Msb0::mask(head, tail),
                T::Mem::BITS as u8 - tail.value(),
            ),
            DomainMut::Region { head, body, tail } => {
                // Writing starts from the least significant end: the tail
                // (highest address) takes the low bits of `value` at the top
                // of its element.
                if let Some((elem, tail)) = tail {
                    set::<T, M>(
                        elem,
                        value,
                        Msb0::mask(None, tail),
                        T::Mem::BITS as u8 - tail.value(),
                    );
                    // Drop the bits just written; avoid a full-width shift.
                    if M::BITS as u8 > tail.value() {
                        value >>= tail.value();
                    }
                    else {
                        value = M::ZERO;
                    }
                }
                // Fully-used elements, from high address to low.
                for elem in body.iter_mut().rev() {
                    elem.store_value(resize(value));
                    if M::BITS > T::Mem::BITS {
                        value >>= T::Mem::BITS;
                    }
                }
                // The head (lowest address) takes the most significant bits,
                // at the bottom of its element.
                if let Some((head, elem)) = head {
                    set::<T, M>(elem, value, Msb0::mask(head, None), 0);
                }
            },
        }
    }
}
/// Delegates every `BitField` operation to the array's underlying
/// `BitSlice<O, V::Store>`.
impl<O, V> BitField for BitArray<O, V>
where
    O: BitOrder,
    V: BitView,
    BitSlice<O, V::Store>: BitField,
{
    fn load_le<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, V::Store> as BitField>::load_le::<M>(bits)
    }

    fn load_be<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, V::Store> as BitField>::load_be::<M>(bits)
    }

    fn store_le<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, V::Store> as BitField>::store_le::<M>(bits, value);
    }

    fn store_be<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, V::Store> as BitField>::store_be::<M>(bits, value);
    }
}
/// Delegates every `BitField` operation to the boxed `BitSlice<O, T>`.
#[cfg(feature = "alloc")]
impl<O, T> BitField for BitBox<O, T>
where
    O: BitOrder,
    T: BitStore,
    BitSlice<O, T>: BitField,
{
    fn load_le<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, T> as BitField>::load_le::<M>(bits)
    }

    fn load_be<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, T> as BitField>::load_be::<M>(bits)
    }

    fn store_le<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, T> as BitField>::store_le::<M>(bits, value);
    }

    fn store_be<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, T> as BitField>::store_be::<M>(bits, value);
    }
}
/// Delegates every `BitField` operation to the vector's live
/// `BitSlice<O, T>`.
#[cfg(feature = "alloc")]
impl<O, T> BitField for BitVec<O, T>
where
    O: BitOrder,
    T: BitStore,
    BitSlice<O, T>: BitField,
{
    fn load_le<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, T> as BitField>::load_le::<M>(bits)
    }

    fn load_be<M>(&self) -> M
    where M: BitMemory {
        let bits = self.as_bitslice();
        <BitSlice<O, T> as BitField>::load_be::<M>(bits)
    }

    fn store_le<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, T> as BitField>::store_le::<M>(bits, value);
    }

    fn store_be<M>(&mut self, value: M)
    where M: BitMemory {
        let bits = self.as_mut_bitslice();
        <BitSlice<O, T> as BitField>::store_be::<M>(bits, value);
    }
}
/// Validates that a region of `len` bits can be transferred through an `M`.
///
/// # Panics
///
/// Panics if `len` is zero or greater than `M::BITS`. `action` names the
/// attempted operation in the panic message.
fn check<M>(action: &'static str, len: usize)
where M: BitMemory {
    let capacity = M::BITS as usize;
    let fits = len != 0 && len <= capacity;
    if fits {
        return;
    }
    panic!(
        "Cannot {} {} bits from a {}-bit region",
        action,
        M::BITS,
        len,
    );
}
#[allow(clippy::op_ref)]
fn get<T, M>(elem: &T, mask: BitMask<T::Mem>, shamt: u8) -> M
where
T: BitStore,
M: BitMemory,
{
elem.load_value()
.pipe(|val| val & &mask.value())
.pipe(|val| val >> &(shamt as usize))
.pipe(resize::<T::Mem, M>)
}
#[allow(clippy::op_ref)]
fn set<T, M>(elem: &T::Access, value: M, mask: BitMask<T::Mem>, shamt: u8)
where
T: BitStore,
M: BitMemory,
{
let mask = BitMask::new(mask.value());
let value = value
.pipe(resize::<M, T::Mem>)
.pipe(|val| val << &(shamt as usize))
.pipe(|val| mask & val);
elem.clear_bits(mask);
elem.set_bits(value);
}
/// Moves the low-order bytes of `value` into a fresh `U`.
///
/// The destination starts as `U::ZERO`. When `U` is wider than `T`, its extra
/// high-order bytes stay zero; when it is narrower, the high-order bytes of
/// `value` are dropped.
fn resize<T, U>(value: T) -> U
where
    T: BitMemory,
    U: BitMemory,
{
    let (from_len, to_len) = (mem::size_of::<T>(), mem::size_of::<U>());
    let mut resized = U::ZERO;
    // SAFETY: both pointers come from live, distinct locals (`value` and
    // `resized`), so they are valid and non-overlapping, and the lengths
    // passed are exactly their `size_of`. NOTE(review): this also assumes
    // `BitMemory` is implemented only for plain integers, where every byte
    // pattern is a valid value.
    unsafe {
        resize_inner::<T, U>(&value, &mut resized, from_len, to_len);
    }
    resized
}
/// Copies the low-order bytes of `*src` into the low-order bytes of `*dst`.
///
/// On little-endian targets the least significant byte is stored first, so
/// the low-order bytes of both values start at offset zero. The copy is
/// therefore a plain prefix copy of `min(size_t, size_u)` bytes. Bytes of
/// `*dst` beyond that prefix are left untouched.
///
/// # Safety
///
/// `size_t` and `size_u` must not exceed `size_of::<T>()` and
/// `size_of::<U>()` respectively (they are expected to equal them).
/// NOTE(review): callers are assumed to pass plain integer types, for which
/// every byte pattern is a valid value.
#[cfg(target_endian = "little")]
unsafe fn resize_inner<T, U>(
    src: &T,
    dst: &mut U,
    size_t: usize,
    size_u: usize,
) {
    let count = if size_t < size_u { size_t } else { size_u };
    let from = src as *const T as *const u8;
    let to = dst as *mut U as *mut u8;
    ptr::copy_nonoverlapping(from, to, count);
}
/// Copies the low-order bytes of `*src` into the low-order bytes of `*dst`.
///
/// On big-endian targets the least significant byte is stored last, so the
/// low-order bytes of each value are aligned at the end of its storage. When
/// narrowing, the source's leading high-order bytes are skipped. When
/// widening, the destination's leading high-order bytes are left untouched.
///
/// # Safety
///
/// `size_t` and `size_u` must not exceed `size_of::<T>()` and
/// `size_of::<U>()` respectively (they are expected to equal them).
/// NOTE(review): callers are assumed to pass plain integer types, for which
/// every byte pattern is a valid value.
#[cfg(target_endian = "big")]
unsafe fn resize_inner<T, U>(
    src: &T,
    dst: &mut U,
    size_t: usize,
    size_u: usize,
) {
    let from = src as *const T as *const u8;
    let to = dst as *mut U as *mut u8;
    match size_t.checked_sub(size_u) {
        // Narrowing or same width: skip the source's high-order prefix.
        Some(excess) => ptr::copy_nonoverlapping(from.add(excess), to, size_u),
        // Widening: write into the tail of the destination.
        None => ptr::copy_nonoverlapping(from, to.add(size_u - size_t), size_t),
    }
}
// The native-order `load`/`store` defaults and both `resize_inner` variants
// are gated on byte order, so refuse to build on a target that is neither
// big- nor little-endian. Uses the real `compile_error!` macro
// (`compile_fail!` does not exist), and `env!` requires a string-literal
// variable name.
#[cfg(not(any(target_endian = "big", target_endian = "little")))]
compile_error!(concat!(
    "This architecture is currently not supported. File an issue at ",
    env!("CARGO_PKG_REPOSITORY")
));
// Additional `BitField` support that is only compiled with the `std` feature.
// NOTE(review): presumably `std::io` integration — confirm in `io.rs`.
#[cfg(feature = "std")]
mod io;
// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod tests;
// Permutation tests: compiled only in test builds with the `std` feature, and
// skipped under Miri and tarpaulin (presumably because they are too slow
// there — confirm).
#[cfg(all(test, feature = "std", not(miri), not(tarpaulin)))]
mod permutation_tests;