use crate::bitvec::BitVec;
use crate::short_bit_vec::{ShortBitVec, ShortType, SHORT_BITS};
use crate::synth::VCDValue;
use num_bigint::BigUint;
use num_traits::ToPrimitive;
use std::cmp::Ordering;
use std::fmt::{Binary, Debug, Formatter, LowerHex, UpperHex};
use std::hash::Hasher;
use std::num::Wrapping;
/// Native integer type used for literal conversions into and out of [`Bits`].
pub type LiteralType = u64;
/// Width of [`LiteralType`] in bits.
///
/// Derived from the type itself so the two can never disagree if `LiteralType` changes.
pub const LITERAL_BITS: usize = LiteralType::BITS as usize;
/// Ceiling of the base-2 logarithm: the number of bits needed to encode `t` distinct values.
///
/// Returns the smallest `p` such that `2^p >= t`; `clog2(0)` and `clog2(1)` are both `0`.
/// Usable in const contexts (e.g. `Bits<{ clog2(250) }>`).
///
/// Computed from `leading_zeros` instead of a doubling loop. The loop overflowed its
/// accumulator once `t` exceeded `2^(usize::BITS - 1)`.
pub const fn clog2(t: usize) -> usize {
    if t <= 1 {
        0
    } else {
        // The bit length of `t - 1` is exactly ceil(log2(t)) for t >= 2.
        (usize::BITS - (t - 1).leading_zeros()) as usize
    }
}
/// 1024 values need exactly 10 bits.
#[test]
fn test_clog2_is_correct() {
    let needed = clog2(1024);
    assert_eq!(needed, 10);
}
/// A fixed-width bit vector of `N` bits.
///
/// The constructors in this file pick the storage by width: `Short` when `N <= SHORT_BITS`,
/// `Long` otherwise. Binary operations and comparisons assume both operands use the same
/// representation, and they panic or hit `unreachable!()` if they do not.
#[derive(Clone, Debug, Copy)]
pub enum Bits<const N: usize> {
#[doc(hidden)]
// Narrow storage, used when `N <= SHORT_BITS`.
Short(ShortBitVec<N>),
#[doc(hidden)]
// Wide storage, used when `N > SHORT_BITS`.
Long(BitVec<N>),
}
/// Converts an arbitrary-precision unsigned integer into `Bits<N>`.
///
/// # Panics
/// Panics if the value needs more than `N` bits.
impl<const N: usize> From<BigUint> for Bits<N> {
    fn from(x: BigUint) -> Self {
        let width = x.bits();
        assert!(
            width <= N as u64,
            "cannot fit value from BigUInt with {} bits into Bits<{}>",
            width,
            N
        );
        if N > SHORT_BITS {
            // Wide vectors are built bit by bit, least significant bit first.
            let mut flags = [false; N];
            for (i, slot) in flags.iter_mut().enumerate() {
                *slot = x.bit(i as u64);
            }
            return Bits::Long(flags.into());
        }
        // Short widths fit in a u64, so go through the literal conversion.
        Bits::<N>::from(x.to_u64().unwrap())
    }
}
/// Converts `Bits<N>` into an arbitrary-precision unsigned integer, bit by bit.
impl<const N: usize> From<Bits<N>> for BigUint {
    fn from(y: Bits<N>) -> Self {
        (0..N).fold(BigUint::default(), |mut acc, position| {
            acc.set_bit(position as u64, y.get_bit(position));
            acc
        })
    }
}
/// Test helper: a `Bits<N>` whose bits are each set independently with probability 1/2.
#[cfg(test)]
fn random_bits<const N: usize>() -> Bits<N> {
    use rand::random;
    (0..N).fold(Bits::default(), |acc, position| {
        if random::<bool>() {
            acc.replace_bit(position, true)
        } else {
            acc
        }
    })
}
/// Round-trips values between `Bits<N>` and `BigUint` in both directions for many widths.
#[test]
fn test_biguint_roundtrip() {
    use rand::random;
    use seq_macro::seq;
    // Bits -> BigUint -> Bits must reproduce the original vector.
    seq!(N in 5..150 {
        for _ in 0..10 {
            let original: Bits<N> = random_bits();
            let wide: BigUint = original.into();
            let back: Bits<N> = wide.into();
            assert_eq!(back, original);
        }
    });
    // BigUint -> Bits -> BigUint must reproduce the original integer.
    seq!(N in 5..150 {
        for _ in 0..10 {
            let digits: Vec<u8> = (0..N)
                .map(|_| if random::<bool>() { b'1' } else { b'0' })
                .collect();
            let original = BigUint::parse_bytes(&digits, 2).unwrap();
            let narrow: Bits<N> = original.clone().into();
            let back: BigUint = narrow.into();
            assert_eq!(back, original);
        }
    });
}
/// A `BigUint` parsed from binary converts to `Bits<16>` and prints as zero-padded hex.
#[test]
fn test_cast_from_biguint() {
    let source = BigUint::parse_bytes(b"1011000101", 2).unwrap();
    let y: Bits<16> = source.into();
    let rendered = format!("y = {:x}", y);
    assert_eq!(rendered, "y = 02c5");
    println!("y = {:x}", y);
}
/// Formats the vector as exactly `N` binary digits, most significant bit first.
///
/// Leading zeros are always emitted, so the digit count equals the bit width. Formatter flags
/// apply to the whole number, as for the primitive integers: `{:#b}` adds a `0b` prefix, and
/// width, fill, alignment and `0` padding are honoured through [`Formatter::pad_integral`].
impl<const N: usize> Binary for Bits<N> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let digits: String = (0..N)
            .rev()
            .map(|i| if self.get_bit(i) { '1' } else { '0' })
            .collect();
        f.pad_integral(true, "0b", &digits)
    }
}
/// `{:b}` prints every one of the 16 bits, MSB first.
#[test]
fn test_print_as_binary() {
    let value = Bits::<16>::from(0b_1011_0100_1000_0000);
    let rendered = format!("x = {:b}", value);
    assert_eq!(rendered, "x = 1011010010000000")
}
/// Formats the vector as `ceil(N / 4)` lowercase hex digits, most significant nibble first.
///
/// Leading zeros are always emitted; when `N` is not a multiple of 4, the top digit covers
/// only the remaining high bits. The digits are assembled first and then written once through
/// [`Formatter::pad_integral`], so flags such as `{:#x}` (`0x` prefix) and widths apply to the
/// whole number rather than to each nibble.
impl<const N: usize> LowerHex for Bits<N> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        let nibbles = (N + 3) / 4;
        let mut text = String::with_capacity(nibbles);
        for k in (0..nibbles).rev() {
            let lo = 4 * k;
            // The top nibble may be partial; never read past bit N - 1.
            let hi = (lo + 4).min(N);
            let value = (lo..hi)
                .rev()
                .fold(0usize, |acc, bit| (acc << 1) | usize::from(self.get_bit(bit)));
            text.push(char::from(DIGITS[value]));
        }
        f.pad_integral(true, "0x", &text)
    }
}
/// `{:x}` renders a 16-bit value as four lowercase hex digits.
#[test]
fn test_print_as_lowercase_hex() {
    let value = Bits::<16>::from(0xcafe);
    let rendered = format!("x = {:x}", value);
    assert_eq!(rendered, "x = cafe");
}
/// Formats the vector as `ceil(N / 4)` uppercase hex digits, most significant nibble first.
///
/// Leading zeros are always emitted; when `N` is not a multiple of 4, the top digit covers
/// only the remaining high bits. The digits are assembled first and then written once through
/// [`Formatter::pad_integral`], so flags such as `{:#X}` (`0x` prefix) and widths apply to the
/// whole number rather than to each nibble.
impl<const N: usize> UpperHex for Bits<N> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
        let nibbles = (N + 3) / 4;
        let mut text = String::with_capacity(nibbles);
        for k in (0..nibbles).rev() {
            let lo = 4 * k;
            // The top nibble may be partial; never read past bit N - 1.
            let hi = (lo + 4).min(N);
            let value = (lo..hi)
                .rev()
                .fold(0usize, |acc, bit| (acc << 1) | usize::from(self.get_bit(bit)));
            text.push(char::from(DIGITS[value]));
        }
        f.pad_integral(true, "0x", &text)
    }
}
/// `{:X}` renders a 16-bit value as four uppercase hex digits.
#[test]
fn test_print_as_uppercase_hex() {
    let value = Bits::<16>::from(0xcafe);
    let rendered = format!("x = {:X}", value);
    assert_eq!(rendered, "x = CAFE");
}
pub fn bits<const N: usize>(x: LiteralType) -> Bits<N> {
let t: Bits<N> = x.into();
t
}
pub trait ToBits {
fn to_bits<const N: usize>(self) -> Bits<N>;
}
impl ToBits for u8 {
fn to_bits<const N: usize>(self) -> Bits<N> {
(self as LiteralType).into()
}
}
impl ToBits for u16 {
fn to_bits<const N: usize>(self) -> Bits<N> {
(self as LiteralType).into()
}
}
impl ToBits for u32 {
fn to_bits<const N: usize>(self) -> Bits<N> {
(self as LiteralType).into()
}
}
impl ToBits for u64 {
fn to_bits<const N: usize>(self) -> Bits<N> {
(self as LiteralType).into()
}
}
impl ToBits for usize {
fn to_bits<const N: usize>(self) -> Bits<N> {
(self as LiteralType).into()
}
}
impl ToBits for u128 {
fn to_bits<const N: usize>(self) -> Bits<N> {
Bits::<N>::from(BigUint::from(self))
}
}
/// Reinterprets a `Bits<N>` as a `Bits<M>`.
///
/// Narrowing (`M < N`) keeps only the low `M` bits. Widening zero-extends. The result uses
/// the representation appropriate for width `M`.
pub fn bit_cast<const M: usize, const N: usize>(x: Bits<N>) -> Bits<M> {
    match x {
        Bits::Short(short) => {
            let raw: ShortType = short.into();
            // Only narrowing needs masking; widening a short value cannot set new bits.
            let narrowed = if M < N {
                raw & ShortBitVec::<M>::mask().short()
            } else {
                raw
            };
            Bits::<M>::from(narrowed as LiteralType)
        }
        Bits::Long(long) if M > SHORT_BITS => Bits::Long(long.resize()),
        Bits::Long(long) => {
            let raw: ShortType = long.into();
            Bits::Short(raw.into())
        }
    }
}
/// Converts a bit vector into a VCD trace value.
///
/// A 1-bit vector becomes `VCDValue::Single`. Any other width becomes `VCDValue::Vector`
/// ordered most significant bit first.
#[doc(hidden)]
impl<const N: usize> From<Bits<N>> for VCDValue {
    fn from(val: Bits<N>) -> Self {
        let level = |i: usize| {
            if val.get_bit(i) {
                vcd::Value::V1
            } else {
                vcd::Value::V0
            }
        };
        match N {
            1 => VCDValue::Single(level(0)),
            _ => VCDValue::Vector((0..N).rev().map(level).collect()),
        }
    }
}
/// A value built with `bits()` converts back to the same literal.
#[test]
fn test_bits_from_int_via_bits() {
    let value: Bits<23> = bits(23);
    let back = LiteralType::from(value);
    assert_eq!(back, 23);
}
/// Queries, bit/field access and narrowing conversions for `Bits<N>`.
impl<const N: usize> Bits<N> {
    /// Returns `true` if at least one bit is set.
    #[inline(always)]
    pub fn any(&self) -> bool {
        match self {
            Bits::Short(x) => x.any(),
            Bits::Long(x) => x.any(),
        }
    }

    /// Returns `true` if every one of the `N` bits is set.
    #[inline(always)]
    pub fn all(&self) -> bool {
        match self {
            Bits::Short(x) => x.all(),
            Bits::Long(x) => x.all(),
        }
    }

    /// XOR-reduction of the vector, delegated to the storage's `xor()`.
    /// Presumably this is the parity of the set bits — TODO confirm in `ShortBitVec`/`BitVec`.
    #[inline(always)]
    pub fn xor(&self) -> bool {
        match self {
            Bits::Short(x) => x.xor(),
            Bits::Long(x) => x.xor(),
        }
    }

    /// Returns the value as a `usize`, e.g. for use as an index.
    ///
    /// # Panics
    /// Panics if the vector uses the wide (`Long`) representation.
    pub fn index(&self) -> usize {
        match self {
            Bits::Short(x) => x.short() as usize,
            Bits::Long(_x) => panic!("Cannot map long bit vector to index type"),
        }
    }

    /// Width of the vector in bits (always `N`).
    #[inline(always)]
    pub fn len(&self) -> usize {
        N
    }

    /// Returns `true` only for the zero-width vector `Bits<0>`.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        N == 0
    }

    /// Number of distinct values representable in `N` bits, i.e. `2^N`.
    ///
    /// # Panics
    /// Panics if `N >= 128`, because `2^N` does not fit in a `u128`. Previously the shift
    /// overflowed: debug builds panicked and release builds returned a wrong value.
    pub fn count() -> u128 {
        assert!(
            N < 128,
            "Bits::<{}>::count() overflows u128; at most 127 bits are supported",
            N
        );
        1 << N
    }

    /// Returns bit `index`, where bit 0 is the least significant.
    ///
    /// # Panics
    /// Panics if `index >= N`.
    #[inline(always)]
    pub fn get_bit(&self, index: usize) -> bool {
        assert!(index < N);
        match self {
            Bits::Short(x) => x.get_bit(index),
            Bits::Long(x) => x.get_bit(index),
        }
    }

    /// Returns a copy of `self` with bit `index` set to `val`.
    ///
    /// # Panics
    /// Panics if `index >= N`.
    pub fn replace_bit(&self, index: usize, val: bool) -> Self {
        assert!(index < N);
        match self {
            Bits::Short(x) => Bits::Short(x.replace_bit(index, val)),
            Bits::Long(x) => Bits::Long(x.replace_bit(index, val)),
        }
    }

    /// Extracts the `M`-bit field that starts at bit `index`.
    ///
    /// Bits above the top of `self` read as zero, so `M` may exceed `N - index`.
    ///
    /// # Panics
    /// Panics if `index > N`.
    #[inline(always)]
    pub fn get_bits<const M: usize>(&self, index: usize) -> Bits<M> {
        assert!(index <= N);
        bit_cast::<M, N>(*self >> index as LiteralType)
    }

    /// Overwrites bits `index .. index + M` with `rhs`, leaving all other bits untouched.
    ///
    /// # Panics
    /// Panics if the field does not fit, i.e. `index + M > N`.
    #[inline(always)]
    pub fn set_bits<const M: usize>(&mut self, index: usize, rhs: Bits<M>) {
        assert!(index <= N);
        assert!(index + M <= N);
        // Clear the destination field, then OR the shifted replacement into it.
        let mask = !(bit_cast::<N, M>(Bits::<M>::mask()) << index as LiteralType);
        let masked = *self & mask;
        let replace = bit_cast::<N, M>(rhs) << index as LiteralType;
        *self = masked | replace
    }

    /// The all-ones vector of width `N`.
    #[inline(always)]
    pub fn mask() -> Bits<N> {
        if N <= SHORT_BITS {
            Bits::Short(ShortBitVec::<N>::mask())
        } else {
            Bits::Long([true; N].into())
        }
    }

    /// Width of the vector in bits, usable in const contexts.
    pub const fn width() -> usize {
        N
    }

    /// Converts to `u8`.
    ///
    /// # Panics
    /// Panics if `N > 8`.
    pub fn to_u8(self) -> u8 {
        assert!(N <= 8, "Cannot convert Bits::<{}> to u8 - too many bits", N);
        let x: LiteralType = self.into();
        x as u8
    }

    /// Converts to `u16`.
    ///
    /// # Panics
    /// Panics if `N > 16`.
    pub fn to_u16(self) -> u16 {
        assert!(
            N <= 16,
            "Cannot convert Bits::<{}> to u16 - too many bits",
            N
        );
        let x: LiteralType = self.into();
        x as u16
    }

    /// Converts to `u32`.
    ///
    /// # Panics
    /// Panics if `N > 32`.
    pub fn to_u32(self) -> u32 {
        assert!(
            N <= 32,
            "Cannot convert Bits::<{}> to u32 - too many bits",
            N
        );
        let x: LiteralType = self.into();
        x as u32
    }

    /// Converts to `u64`.
    ///
    /// # Panics
    /// Panics if `N > 64`.
    pub fn to_u64(self) -> u64 {
        assert!(
            N <= 64,
            "Cannot convert Bits::<{}> to u64 - too many bits",
            N
        );
        let x: LiteralType = self.into();
        x
    }

    /// Converts to `u128` by delegating to the underlying storage.
    /// How widths above 128 bits are handled is decided there — TODO confirm.
    pub fn to_u128(self) -> u128 {
        match self {
            Bits::Short(x) => x.to_u128(),
            Bits::Long(x) => x.to_u128(),
        }
    }
}
impl From<bool> for Bits<1> {
#[inline(always)]
fn from(x: bool) -> Self {
if x {
1.into()
} else {
0.into()
}
}
}
impl From<Bits<1>> for bool {
#[inline(always)]
fn from(val: Bits<1>) -> Self {
val.get_bit(0)
}
}
/// Builds `Bits<N>` from a literal.
///
/// For short widths (`N <= SHORT_BITS`) the value must fit in `N` bits, or this panics.
/// Wider vectors are built by the `BitVec` conversion.
impl<const N: usize> From<LiteralType> for Bits<N> {
    fn from(x: LiteralType) -> Self {
        if N <= SHORT_BITS {
            assert!(
                x <= ShortBitVec::<N>::max_legal(),
                "Value 0x{:x} does not fit into bitvector of length {}",
                x,
                N
            );
            return Bits::Short((x as ShortType).into());
        }
        let wide: BitVec<N> = x.into();
        Bits::Long(wide)
    }
}
impl<const N: usize> From<Wrapping<LiteralType>> for Bits<N> {
fn from(x: Wrapping<LiteralType>) -> Self {
x.0.into()
}
}
/// Converts a vector of at most `LITERAL_BITS` bits into a literal.
///
/// # Panics
/// Panics if `N > LITERAL_BITS`.
impl<const N: usize> From<Bits<N>> for LiteralType {
    fn from(x: Bits<N>) -> Self {
        assert!(N <= LITERAL_BITS);
        match x {
            Bits::Long(wide) => wide.into(),
            Bits::Short(narrow) => {
                let raw: ShortType = narrow.into();
                raw as LiteralType
            }
        }
    }
}
/// Applies a binary operation to two equal-width vectors.
///
/// `short_op` is used when both operands are `Short` and `long_op` when both are `Long`.
/// Mixed representations are treated as unreachable.
#[inline(always)]
#[doc(hidden)]
fn binop<Tshort, TLong, const N: usize>(
    a: Bits<N>,
    b: Bits<N>,
    short_op: Tshort,
    long_op: TLong,
) -> Bits<N>
where
    Tshort: Fn(ShortBitVec<N>, ShortBitVec<N>) -> ShortBitVec<N>,
    TLong: Fn(BitVec<N>, BitVec<N>) -> BitVec<N>,
{
    match (a, b) {
        (Bits::Short(lhs), Bits::Short(rhs)) => Bits::Short(short_op(lhs, rhs)),
        (Bits::Long(lhs), Bits::Long(rhs)) => Bits::Long(long_op(lhs, rhs)),
        _ => unreachable!(),
    }
}
// Implements an arithmetic/bitwise operator trait (`$method`, method `$func`, token `$op`)
// for `Bits<N> op Bits<N>`, `Bits<N> op LiteralType` and `LiteralType op Bits<N>`.
// Literal operands are first converted to `Bits<N>`, and the operation runs on the
// storage types through `binop`.
macro_rules! op {
    ($func: ident, $method: ident, $op: tt) => {
        #[doc(hidden)]
        impl<const N: usize> std::ops::$method<Bits<N>> for Bits<N> {
            type Output = Bits<N>;
            #[inline(always)]
            fn $func(self, rhs: Bits<N>) -> Self::Output {
                let lhs = self;
                binop(lhs, rhs, |p, q| p $op q, |p, q| p $op q)
            }
        }
        impl<const N: usize> std::ops::$method<LiteralType> for Bits<N> {
            type Output = Bits<N>;
            fn $func(self, rhs: LiteralType) -> Self::Output {
                let rhs_bits = Bits::<N>::from(rhs);
                binop(self, rhs_bits, |p, q| p $op q, |p, q| p $op q)
            }
        }
        impl<const N: usize> std::ops::$method<Bits<N>> for LiteralType {
            type Output = Bits<N>;
            #[inline(always)]
            fn $func(self, rhs: Bits<N>) -> Self::Output {
                let lhs_bits = Bits::<N>::from(self);
                binop(lhs_bits, rhs, |p, q| p $op q, |p, q| p $op q)
            }
        }
    };
}
op!(add, Add, +);
op!(sub, Sub, -);
op!(bitor, BitOr, |);
op!(bitand, BitAnd, &);
op!(bitxor, BitXor, ^);
// Generates a shift operator trait (`$method` = Shr/Shl, method `$func`, token `$op`) for `Bits`:
//  * `Bits<N> op Bits<M>`: the shift amount is converted with `to_u64()`, which panics for M > 64;
//  * `Bits<N> op LiteralType`: delegates to the storage type's own shift by a literal;
//  * `LiteralType op Bits<N>`: yields a `Bits<LITERAL_BITS>`. The shift amount is `bit_cast`
//    to `LITERAL_BITS` bits (wider amounts are truncated to their low bits), and both sides
//    then go through `binop`.
macro_rules! op_shift {
($func: ident, $method: ident, $op: tt) => {
impl<const M: usize, const N: usize> std::ops::$method<Bits<M>> for Bits<N> {
type Output = Bits<N>;
#[inline(always)]
fn $func(self, rhs: Bits<M>) -> Self::Output {
// Reuse the literal-amount impl below.
self $op rhs.to_u64()
}
}
impl<const N: usize> std::ops::$method<LiteralType> for Bits<N> {
type Output = Bits<N>;
fn $func(self, rhs: LiteralType) -> Self::Output {
match self {
Bits::Short(x) => Bits::Short(x $op rhs),
Bits::Long(x) =>
Bits::Long(x $op rhs),
}
}
}
impl<const N: usize> std::ops::$method<Bits<N>> for LiteralType {
type Output = Bits<LITERAL_BITS>;
fn $func(self, rhs: Bits<N>) -> Self::Output {
// `bit_cast`'s target width is inferred from `binop`: both operands must be Bits<LITERAL_BITS>.
binop(self.into(), bit_cast(rhs), |a, b| a $op b, |a, b| a $op b)
}
}
}
}
op_shift!(shr, Shr, >>);
op_shift!(shl, Shl, <<);
impl<const N: usize> Default for Bits<N> {
fn default() -> Bits<N> {
bits::<N>(0)
}
}
/// Bitwise complement of all `N` bits, delegated to the storage type.
impl<const N: usize> std::ops::Not for Bits<N> {
    type Output = Bits<N>;
    fn not(self) -> Self::Output {
        match self {
            Bits::Long(wide) => Bits::Long(!wide),
            Bits::Short(narrow) => Bits::Short(!narrow),
        }
    }
}
/// Total order built on `partial_cmp`.
///
/// Panics if the comparison yields `None`, or (inside `partial_cmp`) if the two operands use
/// different representations.
#[doc(hidden)]
impl<const N: usize> Ord for Bits<N> {
    fn cmp(&self, other: &Bits<N>) -> Ordering {
        let ordering = self.partial_cmp(other);
        ordering.unwrap()
    }
}
#[doc(hidden)]
impl<const N: usize> PartialOrd<Bits<N>> for LiteralType {
fn partial_cmp(&self, other: &Bits<N>) -> Option<Ordering> {
let self_as_bits: Bits<N> = (*self).into();
self_as_bits.partial_cmp(other)
}
}
#[doc(hidden)]
impl<const N: usize> PartialOrd<LiteralType> for Bits<N> {
fn partial_cmp(&self, other: &LiteralType) -> Option<Ordering> {
let other_as_bits: Bits<N> = (*other).into();
self.partial_cmp(&other_as_bits)
}
}
/// Unsigned comparison of two equal-width vectors, delegated to the storage type.
///
/// # Panics
/// Panics if the operands use different representations.
#[doc(hidden)]
impl<const N: usize> PartialOrd<Bits<N>> for Bits<N> {
    #[inline(always)]
    fn partial_cmp(&self, other: &Bits<N>) -> Option<Ordering> {
        match (self, other) {
            (Bits::Short(lhs), Bits::Short(rhs)) => lhs.partial_cmp(rhs),
            (Bits::Long(lhs), Bits::Long(rhs)) => lhs.partial_cmp(rhs),
            (Bits::Short(_), Bits::Long(_)) => panic!("Short Long case"),
            (Bits::Long(_), Bits::Short(_)) => panic!("Long short case"),
        }
    }
}
/// Equality of two equal-width vectors, delegated to the storage type.
///
/// # Panics
/// Panics if the operands use different representations.
#[doc(hidden)]
impl<const N: usize> PartialEq<Bits<N>> for Bits<N> {
    #[inline(always)]
    fn eq(&self, other: &Bits<N>) -> bool {
        match (self, other) {
            (Bits::Short(lhs), Bits::Short(rhs)) => lhs == rhs,
            (Bits::Long(lhs), Bits::Long(rhs)) => lhs == rhs,
            (Bits::Short(_), Bits::Long(_)) => panic!("Short Long case"),
            (Bits::Long(_), Bits::Short(_)) => panic!("Long Short case"),
        }
    }
}
#[doc(hidden)]
impl<const N: usize> PartialEq<LiteralType> for Bits<N> {
fn eq(&self, other: &LiteralType) -> bool {
let other_as_bits: Bits<N> = (*other).into();
self.eq(&other_as_bits)
}
}
#[doc(hidden)]
impl<const N: usize> PartialEq<Bits<N>> for LiteralType {
fn eq(&self, other: &Bits<N>) -> bool {
let self_as_bits: Bits<N> = (*self).into();
self_as_bits.eq(other)
}
}
/// A single-bit vector equals `true` exactly when its bit is set.
#[doc(hidden)]
impl PartialEq<bool> for Bits<1> {
    #[inline(always)]
    fn eq(&self, other: &bool) -> bool {
        let bit = self.get_bit(0);
        bit == *other
    }
}
/// `true` equals a single-bit vector exactly when that bit is set.
#[doc(hidden)]
impl PartialEq<Bits<1>> for bool {
    fn eq(&self, other: &Bits<1>) -> bool {
        let bit = other.get_bit(0);
        *self == bit
    }
}
/// Marks `Bits` equality as a total equivalence relation. Together with the `Hash` impl,
/// this lets `Bits` be used as a key in hashed collections.
#[doc(hidden)]
impl<const N: usize> Eq for Bits<N> {}
/// Hashes the underlying storage, consistent with `PartialEq` (which compares the storage).
#[doc(hidden)]
impl<const N: usize> std::hash::Hash for Bits<N> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Bits::Short(short) => std::hash::Hash::hash(short, state),
            Bits::Long(long) => std::hash::Hash::hash(long, state),
        }
    }
}
/// Adds a carry-in bit: `x + true` is `x + 1` (wrapping as `Bits` addition does), and
/// `x + false` is `x`.
#[doc(hidden)]
impl<const N: usize> std::ops::Add<bool> for Bits<N> {
    type Output = Bits<N>;
    fn add(self, rhs: bool) -> Self::Output {
        if !rhs {
            return self;
        }
        self + Bits::<N>::from(1)
    }
}
// Unit tests for `Bits`: conversions, arithmetic/bitwise operators, shifts, comparisons,
// field access and formatting. Most are cross-checked against `BigUint` on random inputs.
#[cfg(test)]
mod tests {
use super::{bit_cast, clog2, Bits};
use crate::bits::random_bits;
use crate::bits::{LiteralType, ToBits};
use num_bigint::BigUint;
use num_traits::One;
use seq_macro::seq;
use std::num::Wrapping;
// Extracts a 32-bit field starting at bit 8 of a 40-bit value.
#[test]
fn test_get_bits_section() {
let x: Bits<40> = 0xD_ADBE_EFCA.into();
let y = x.get_bits::<32>(8).to_u32();
let answer = 0xDAD_BEEF;
assert_eq!(y, answer);
}
#[test]
fn test_short_from_u8() {
let x: Bits<4> = 15.into();
let y: LiteralType = x.into();
assert_eq!(y, 15 & (0x0F));
}
#[test]
fn test_short_from_u16() {
let x: Bits<12> = 1432.into();
let y: LiteralType = x.into();
assert_eq!(y, 1432 & (0x0FFF));
}
#[test]
fn test_short_from_u32() {
let x: Bits<64> = 12434234.into();
let y: LiteralType = x.into();
assert_eq!(y, 12434234);
}
#[test]
fn test_from_u32() {
let x: Bits<64> = 0xFFFF_FFFF.into();
let y: LiteralType = x.into();
assert_eq!(y, 0xFFFF_FFFF);
}
#[test]
fn or_test() {
let a: Bits<32> = 45.into();
let b: Bits<32> = 10395.into();
let c = a | b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 45 | 10395)
}
#[test]
fn and_test() {
let a: Bits<32> = 45.into();
let b: Bits<32> = 10395.into();
let c = a & b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 45 & 10395)
}
#[test]
fn xor_test() {
let a: Bits<32> = 45.into();
let b: Bits<32> = 10395.into();
let c = a ^ b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 45 ^ 10395)
}
// NOT must only flip the low 32 bits, matching `!u32` widened to the literal type.
#[test]
fn not_test() {
let a: Bits<32> = 45.into();
let c = !a;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, (!45_u32) as LiteralType);
}
#[test]
fn shr_test() {
let a: Bits<32> = 10395.into();
let c: Bits<32> = a >> 4;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 10395 >> 4);
}
#[test]
fn shr_test_pair() {
let a: Bits<32> = 10395.into();
let b: Bits<32> = 4.into();
let c = a >> b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 10395 >> 4);
}
// Bits shifted past the top of the 32-bit vector are discarded, as with `u32 << 24`.
#[test]
fn shl_test() {
let a: Bits<32> = 10395.into();
let c = a << 24;
let c_u32 = c.to_u32();
assert_eq!(c_u32, 10395 << 24);
}
#[test]
fn shl_test_pair() {
let a: Bits<32> = 10395.into();
let b: Bits<32> = 4.into();
let c = a << b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 10395 << 4);
}
#[test]
fn add_works() {
let a: Bits<32> = 10234.into();
let b: Bits<32> = 19423.into();
let c = a + b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 10234 + 19423);
}
#[test]
fn add_int_works() {
let a: Bits<32> = 10234.into();
let b = 19423;
let c: Bits<32> = a + b;
let c_u32: LiteralType = c.into();
assert_eq!(c_u32, 10234 + 19423);
}
// Addition wraps modulo 2^32, matching `Wrapping<u32>`.
#[test]
fn add_works_with_overflow() {
let x = 2_042_102_334_u32;
let y = 2_942_142_512_u32;
let a: Bits<32> = x.to_bits();
let b: Bits<32> = y.to_bits();
let c = a + b;
let c_u32 = c.to_u32();
assert_eq!(Wrapping(c_u32), Wrapping(x) + Wrapping(y));
}
// Subtraction wraps modulo 2^32, matching `Wrapping<u32>`.
#[test]
fn sub_works() {
let x = 2_042_102_334_u32;
let y = 2_942_142_512_u32;
let a: Bits<32> = x.to_bits();
let b: Bits<32> = y.to_bits();
let c = a - b;
let c_u32 = c.to_u32();
assert_eq!(Wrapping(c_u32), Wrapping(x) - Wrapping(y));
}
#[test]
fn sub_int_works() {
let x = 2_042_102_334_u32;
let y = 2_942_142_512;
let a: Bits<32> = x.to_bits();
let c = a - y;
let c_u32 = c.to_u32();
assert_eq!(Wrapping(c_u32), Wrapping(x) - Wrapping(y as u32));
}
#[test]
fn eq_works() {
let x = 2_032_142_351;
let y = 2_942_142_512;
let a: Bits<32> = x.into();
let b: Bits<32> = x.into();
let c: Bits<32> = y.into();
assert_eq!(a, b);
assert_ne!(a, c)
}
// `mask()` is all ones, for both a wide (48) and a narrow (16) width.
#[test]
fn mask_works() {
let a: Bits<48> = 0xFFFF_FFFF_FFFF.into();
let b = Bits::<48>::mask();
assert_eq!(a, b);
let a: Bits<16> = 0xFFFF.into();
let b = Bits::<16>::mask();
assert_eq!(a, b)
}
// Bit 0 is the least significant bit; `index()` turns a Bits value into a bit position.
#[test]
fn get_bit_works() {
let a: Bits<48> = 0xFFFF_FFFF_FFF5.into();
assert!(a.get_bit(0));
assert!(!a.get_bit(1));
assert!(a.get_bit(2));
assert!(!a.get_bit(3));
let c: Bits<5> = 3.into();
assert!(!a.get_bit(c.index()));
}
#[test]
fn test_bit_cast_short() {
let a: Bits<8> = 0xFF.into();
let b: Bits<16> = bit_cast(a);
assert_eq!(b, 0xFF);
let c: Bits<4> = bit_cast(a);
assert_eq!(c, 0xF);
}
#[test]
fn test_bit_cast_long() {
let a: Bits<48> = 0xdead_cafe_babe.into();
let b: Bits<44> = bit_cast(a);
assert_eq!(b, 0xead_cafe_babe);
let b: Bits<32> = bit_cast(a);
assert_eq!(b, 0xcafe_babe);
}
#[test]
fn test_bit_extract_long() {
let a: Bits<48> = 0xdead_cafe_babe.into();
let b: Bits<44> = a.get_bits(4);
assert_eq!(b, 0x0dea_dcaf_ebab);
let b: Bits<32> = a.get_bits(16);
assert_eq!(b, 0xdead_cafe);
}
// Clearing bits 4..8 one at a time with `replace_bit` zeroes the second-lowest nibble.
#[test]
fn test_set_bit() {
let a: Bits<48> = 0xdead_cafe_babe.into();
let mut b = a;
for i in 4..8 {
b = b.replace_bit(i, false)
}
assert_eq!(b, 0xdead_cafe_ba0e);
}
// `set_bits` overwrites only the targeted field, for both narrow and wide vectors.
#[test]
fn test_set_bits() {
let a: Bits<16> = 0xdead.into();
let b: Bits<4> = 0xf.into();
let mut c = a;
c.set_bits(4, b);
assert_eq!(c, 0xdefd);
let a: Bits<48> = 0xdead_cafe_babe.into();
let b: Bits<8> = 0xde.into();
let mut c = a;
c.set_bits(16, b);
assert_eq!(c, 0xdead_cade_babe);
}
// Literal-on-the-left and literal-on-the-right addition agree.
#[test]
fn test_constants_and_bits() {
let a: Bits<16> = 0xdead.into();
let b = a + 1;
let c = 1 + a;
println!("{:x}", b);
assert_eq!(b, 0xdeae);
assert_eq!(b, c);
}
// `clog2` is usable as a const-generic width.
#[test]
fn test_clog2() {
const A_WIDTH: usize = clog2(250);
let a: Bits<{ A_WIDTH }> = 153.into();
println!("{:x}", a);
assert_eq!(a.len(), 8);
assert_eq!(clog2(1024), 10);
}
#[test]
fn test_clog2_inline() {
const A_WIDTH: usize = clog2(1000);
let a: Bits<A_WIDTH> = 1023.into();
assert_eq!(a.len(), 10);
}
#[test]
fn test_default() {
const N: usize = 128;
let a = Bits::<N>::default();
assert_eq!(a, 0);
}
// For many (N, M) pairs, including M > N, `get_bits::<M>(offset)` must equal
// `(value >> offset) & (2^M - 1)`.
#[test]
fn test_get_bits() {
fn get_bits_test<const N: usize, const M: usize>() {
for offset in 0_usize..N {
let y: Bits<N> = random_bits();
let z = y.get_bits::<M>(offset);
let yb: BigUint = y.into();
let yb = (yb >> offset) & ((BigUint::one() << M) - BigUint::one());
let zb: BigUint = z.into();
assert_eq!(zb, yb);
}
}
seq!(N in 0..16 {
get_bits_test::<8, N>();
});
seq!(N in 0..64 {
get_bits_test::<32, N>();
});
seq!(N in 0..65 {
get_bits_test::<64, N>();
});
seq!(N in 0..300 {
get_bits_test::<256, N>();
});
seq!(N in 0..150 {
get_bits_test::<125, N>();
});
}
// `bit_cast::<M, N>` must equal the value truncated to its low M bits, for every
// combination of the listed widths (narrowing, widening and identity).
#[test]
fn test_bitcast() {
fn bitcast_test<const N: usize, const M: usize>() {
let y: Bits<N> = random_bits();
let z = bit_cast::<M, N>(y);
let yb: BigUint = y.into();
let zb = yb & ((BigUint::one() << M) - BigUint::one());
let zc: BigUint = z.into();
assert_eq!(zb, zc);
}
fn bitcast_test_set<const M: usize>() {
bitcast_test::<M, 1>();
bitcast_test::<M, 8>();
bitcast_test::<M, 16>();
bitcast_test::<M, 32>();
bitcast_test::<M, 64>();
bitcast_test::<M, 128>();
bitcast_test::<M, 256>();
}
bitcast_test_set::<1>();
bitcast_test_set::<8>();
bitcast_test_set::<16>();
bitcast_test_set::<32>();
bitcast_test_set::<64>();
bitcast_test_set::<128>();
bitcast_test_set::<256>();
}
// `any()` matches "at least one set bit" in the BigUint reference; all-zero gives false.
#[test]
fn test_any() {
seq!(N in 1..150 {
for _rep in 0..10 {
let y: Bits<N> = random_bits();
let z : BigUint = y.into();
let y_any = y.any();
let z_any = z.count_ones() != 0;
assert_eq!(y_any, z_any)
}
let y = Bits::<N>::default();
assert!(!y.any());
});
}
// `all()` matches "all N bits set" in the BigUint reference; `mask()` gives true.
#[test]
fn test_all() {
seq!(N in 1..150 {
for _rep in 0..10 {
let y: Bits<N> = random_bits();
let z : BigUint = y.into();
let y_all = y.all();
let z_all = z.count_ones() == N;
assert_eq!(y_all, z_all)
}
let y = Bits::<N>::mask();
assert!(y.all());
});
}
// Left shift by a `Bits<M>` amount of a different width, checked against BigUint masked to N bits.
#[test]
fn test_shl_var_bitwidths_driver() {
fn test_shl_var_bitwidths<const N: usize, const M: usize>() {
for _iter in 0..100 {
let y: Bits<N> = random_bits();
let z: Bits<M> = random_bits();
let r = y << z;
let y1: BigUint = y.into();
let mask: BigUint = (BigUint::one() << N) - BigUint::one();
let z1: u128 = z.to_u128();
let r1 = (y1 << z1) & mask;
let convert: BigUint = r.into();
assert_eq!(convert, r1);
}
}
fn test_shl_var_bitwidths_set<const N: usize>() {
test_shl_var_bitwidths::<N, 1>();
test_shl_var_bitwidths::<N, 2>();
test_shl_var_bitwidths::<N, 4>();
test_shl_var_bitwidths::<N, 8>();
test_shl_var_bitwidths::<N, 16>();
}
test_shl_var_bitwidths_set::<1>();
test_shl_var_bitwidths_set::<2>();
test_shl_var_bitwidths_set::<4>();
test_shl_var_bitwidths_set::<8>();
test_shl_var_bitwidths_set::<16>();
test_shl_var_bitwidths_set::<32>();
test_shl_var_bitwidths_set::<64>();
test_shl_var_bitwidths_set::<128>();
test_shl_var_bitwidths_set::<256>();
}
// Right shift by a `Bits<M>` amount of a different width, checked against BigUint.
#[test]
fn test_shr_var_bitwidths_driver() {
fn test_shr_var_bitwidths<const N: usize, const M: usize>() {
for _iter in 0..100 {
let y: Bits<N> = random_bits();
let z: Bits<M> = random_bits();
let r = y >> z;
let y1: BigUint = y.into();
let mask: BigUint = (BigUint::one() << N) - BigUint::one();
let z1: u128 = z.to_u128();
let r1 = (y1 >> z1) & mask;
let convert: BigUint = r.into();
assert_eq!(convert, r1);
}
}
fn test_shr_var_bitwidths_set<const N: usize>() {
test_shr_var_bitwidths::<N, 1>();
test_shr_var_bitwidths::<N, 2>();
test_shr_var_bitwidths::<N, 4>();
test_shr_var_bitwidths::<N, 8>();
test_shr_var_bitwidths::<N, 16>();
}
test_shr_var_bitwidths_set::<1>();
test_shr_var_bitwidths_set::<2>();
test_shr_var_bitwidths_set::<4>();
test_shr_var_bitwidths_set::<8>();
test_shr_var_bitwidths_set::<16>();
test_shr_var_bitwidths_set::<32>();
test_shr_var_bitwidths_set::<64>();
test_shr_var_bitwidths_set::<128>();
test_shr_var_bitwidths_set::<256>();
}
// Shifting by an amount far larger than the width yields zero rather than panicking.
#[test]
fn test_shl_to_zero() {
let y: Bits<1> = 1.into();
let z: Bits<8> = 128.into();
let r = y << z;
assert_eq!(r, 0);
}
#[test]
fn test_shl() {
seq!(N in 1..150 {
for l in 0..N {
let y: Bits<N> = random_bits();
let z: Bits<N> = y << l;
let y1 : BigUint = y.into();
let mask : BigUint = (BigUint::one() << N) - BigUint::one();
let z1 = (y1 << l) & mask;
let convert : BigUint = z.into();
assert_eq!(z1, convert)
}
});
}
#[test]
fn test_shr() {
seq!(N in 1..150 {
for l in 0..N {
let y: Bits<N> = random_bits();
let z: Bits<N> = y >> l;
let y1 : BigUint = y.into();
let mask : BigUint = (BigUint::one() << N) - BigUint::one();
let z1 = (y1 >> l) & mask;
let convert : BigUint = z.into();
assert_eq!(z1, convert)
}
});
}
// Runs a binary-operator check for every width 1..150 with 10 random operand pairs each.
// `$func` returns (library result, BigUint reference result), and the two must agree.
macro_rules! test_op_with_values {
($func: ident) => {
seq!(N in 1..150 {
for _iters in 0..10 {
let y: Bits<N> = random_bits();
let z: Bits<N> = random_bits();
let v1_as_bint : BigUint = y.into();
let v2_as_bint : BigUint = z.into();
let mask : BigUint = (BigUint::one() << N) - BigUint::one();
let (lib_answer, biguint_answer) = $func(y, z, v1_as_bint, v2_as_bint, mask);
let convert : BigUint = lib_answer.into();
assert_eq!(biguint_answer, convert)
}
});
}
}
#[test]
fn test_add() {
fn add<const N: usize>(
y: Bits<N>,
z: Bits<N>,
y1: BigUint,
z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
(y + z, (y1 + z1) & mask)
}
test_op_with_values!(add);
}
#[test]
fn test_sub() {
// Reference subtraction wraps modulo 2^N: when z > y, add 2^N (= mask + 1) first.
fn sub<const N: usize>(
y: Bits<N>,
z: Bits<N>,
y1: BigUint,
z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
if z1 <= y1 {
(y - z, (y1 - z1))
} else {
(y - z, mask + BigUint::one() + y1 - z1)
}
}
test_op_with_values!(sub);
}
#[test]
fn test_bitor() {
fn bor<const N: usize>(
y: Bits<N>,
z: Bits<N>,
y1: BigUint,
z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
(y | z, (y1 | z1) & mask)
}
test_op_with_values!(bor);
}
#[test]
fn test_bitand() {
fn band<const N: usize>(
y: Bits<N>,
z: Bits<N>,
y1: BigUint,
z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
(y & z, (y1 & z1) & mask)
}
test_op_with_values!(band);
}
#[test]
fn test_bitxor() {
fn bxor<const N: usize>(
y: Bits<N>,
z: Bits<N>,
y1: BigUint,
z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
(y ^ z, (y1 ^ z1) & mask)
}
test_op_with_values!(bxor);
}
#[test]
fn test_not() {
// The second operand is unused; NOT is checked against XOR with the all-ones mask.
fn not<const N: usize>(
y: Bits<N>,
_z: Bits<N>,
y1: BigUint,
_z1: BigUint,
mask: BigUint,
) -> (Bits<N>, BigUint) {
(!y, (y1 ^ mask))
}
test_op_with_values!(not);
}
// Runs a comparison check for every width 1..256 with 10 random operand pairs each.
// `$func` returns (library answer, BigUint answer), and the two must agree.
macro_rules! test_cmp_with_values {
($func: ident) => {
seq!(N in 1..256 {
for _iters in 0..10 {
let y: Bits<N> = random_bits();
let z: Bits<N> = random_bits();
let v1_as_bint : BigUint = y.into();
let v2_as_bint : BigUint = z.into();
let (lib_answer, biguint_answer) = $func(y, z, v1_as_bint, v2_as_bint);
assert_eq!(lib_answer, biguint_answer)
}
});
}
}
#[test]
fn test_lt() {
fn lt<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y < z, y1 < z1)
}
test_cmp_with_values!(lt);
}
#[test]
fn test_le() {
fn le<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y <= z, y1 <= z1)
}
test_cmp_with_values!(le);
}
#[test]
fn test_eq() {
fn eq<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y == z, y1 == z1)
}
test_cmp_with_values!(eq);
}
#[test]
fn test_neq() {
fn neq<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y != z, y1 != z1)
}
test_cmp_with_values!(neq);
}
#[test]
fn test_ge() {
fn ge<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y >= z, y1 >= z1)
}
test_cmp_with_values!(ge);
}
#[test]
fn test_gt() {
fn gt<const N: usize>(y: Bits<N>, z: Bits<N>, y1: BigUint, z1: BigUint) -> (bool, bool) {
(y > z, y1 > z1)
}
test_cmp_with_values!(gt);
}
}
/// A single bit, represented as a `bool`.
pub type Bit = bool;
/// Full-width 16 x 16 -> 32-bit unsigned multiplication.
///
/// Both operands are 16 bits wide and are expected to use the `Short` representation.
/// A `Long` operand is treated as impossible and panics.
impl std::ops::Mul<Bits<16>> for Bits<16> {
    type Output = Bits<32>;
    fn mul(self, rhs: Bits<16>) -> Self::Output {
        let raw = |operand: Bits<16>| match operand {
            Bits::Short(value) => value.short(),
            Bits::Long(_) => panic!("unreachable!"),
        };
        let product = raw(self) * raw(rhs);
        Bits::Short(ShortBitVec::from(product))
    }
}