use baa::*;
use proptest::prelude::*;
/// Checks `is_negative` against the textual form of the bit string `a`.
///
/// The value is expected to be negative exactly when `a` starts with `'1'`.
///
/// # Panics
/// Panics if `a` is rejected by `BitVecValue::from_bit_str` or if the
/// reported sign differs from the expectation.
fn do_test_is_negative(a: &str) {
    let parsed = BitVecValue::from_bit_str(a).unwrap();
    let leading_one = a.starts_with('1');
    assert_eq!(parsed.is_negative(), leading_one, "{a}");
}
/// Checks both concatenation entry points against plain string concatenation.
///
/// `a` supplies the leading characters and `b` the trailing characters of the
/// expected bit string. The allocating `concat` is checked first; its width is
/// then used to size the destination for `concat_in_place`.
///
/// # Panics
/// Panics if either input fails to parse or if a result's bit string differs
/// from `a` followed by `b`.
fn do_test_concat(a: &str, b: &str) {
    let expected = format!("{a}{b}");
    let upper = BitVecValue::from_bit_str(a).unwrap();
    let lower = BitVecValue::from_bit_str(b).unwrap();

    // Allocating variant.
    let joined = upper.concat(&lower);
    assert_eq!(joined.to_bit_str(), expected);

    // In-place variant, writing into a zeroed value of the same width.
    let mut joined_in_place = BitVecValue::zero(joined.width());
    joined_in_place.concat_in_place(&upper, &lower);
    assert_eq!(joined_in_place.to_bit_str(), expected);
}
/// Checks `slice` and `slice_in_place` for the inclusive bit range `hi..=lo`.
///
/// The expected result is taken from the text of `src`: skip the
/// `width - 1 - hi` leading characters, then keep `hi - lo + 1` characters.
///
/// # Panics
/// Panics if `src` is empty or fails to parse, if `hi < lo`, if `hi` is not
/// below the parsed width, or if either slicing result differs from the
/// expectation.
fn do_test_slice(src: &str, hi: WidthInt, lo: WidthInt) {
    assert!(
        !src.is_empty(),
        "slice only works with vectors that are at least 1-bit!"
    );
    let src_value = BitVecValue::from_bit_str(src).unwrap();
    assert!(hi >= lo);
    assert!(hi < src_value.width());

    // Allocating variant.
    let res = src_value.slice(hi, lo);
    assert_eq!(res.width(), hi - lo + 1);

    // String model: the first character of `src` corresponds to the highest bit index.
    let leading = (src_value.width() - 1 - hi) as usize;
    let kept = res.width() as usize;
    let mut expected = String::with_capacity(kept);
    expected.extend(src.chars().skip(leading).take(kept));
    assert_eq!(res.to_bit_str(), expected);

    // In-place variant, writing into a zeroed value of the result width.
    let mut in_place = BitVecValue::zero(hi - lo + 1);
    in_place.slice_in_place(&src_value, hi, lo);
    assert_eq!(in_place.to_bit_str(), expected);
}
/// Shared driver for the left, logical-right and arithmetic-right shift tests.
///
/// * `src` - bit string of the value being shifted.
/// * `by` - shift amount; it is converted into a bit vector of the same width as `src`.
/// * `right` - `true` for a right shift, `false` for a left shift.
/// * `signed` - `true` selects the arithmetic right shift; only valid together with `right`.
///
/// The expected value is modelled on the text of `src`: `min(by, width)`
/// padding characters are inserted on the side the bits move away from and
/// the same number of characters is dropped on the other side. The padding
/// character is the first character of `src` for a signed shift, `'0'` otherwise.
///
/// Both the allocating and the in-place shift operations are checked.
///
/// # Panics
/// Panics on a signed left shift, if `src` fails to parse, or if a shift
/// result differs from the string model.
fn do_test_shift(src: &str, by: WidthInt, right: bool, signed: bool) {
    assert!(right || !signed, "left shift is always unsigned!");

    let a = BitVecValue::from_bit_str(src).unwrap();
    let b = BitVecValue::from_u64(by as u64, a.width());
    assert_eq!(a.width(), b.width());

    // Allocating variants.
    let res = match (right, signed) {
        (true, true) => a.arithmetic_shift_right(&b),
        (true, false) => a.shift_right(&b),
        (false, _) => a.shift_left(&b),
    };

    // String model of the shift.
    let padding_len = std::cmp::min(by, a.width()) as usize;
    let pad_char = match signed {
        true => src.chars().next().unwrap(),
        false => '0',
    };
    let padding = std::iter::repeat(pad_char).take(padding_len);
    let expected: String = if right {
        // Padding enters at the front; the trailing characters fall off.
        let kept = a.width() as usize - padding_len;
        padding.chain(src.chars().take(kept)).collect()
    } else {
        // Leading characters fall off; padding enters at the back.
        src.chars().skip(padding_len).chain(padding).collect()
    };
    let expected = BitVecValue::from_bit_str(&expected).unwrap();
    assert_eq!(res, expected, "{src:?} {by} {res:?} {expected:?}");

    // In-place variants, writing into a zeroed value of the same width.
    let mut res = BitVecValue::zero(a.width());
    match (right, signed) {
        (true, true) => {
            res.arithmetic_shift_right_in_place(&a, &b);
        }
        (true, false) => {
            res.shift_right_in_place(&a, &b);
        }
        (false, _) => {
            res.shift_left_in_place(&a, &b);
        }
    }
    assert_eq!(res, expected, "{src:?} {by} {res:?} {expected:?}");
}
/// Checks the logical (unsigned) right shift of `src` by `by` via `do_test_shift`.
fn do_test_shift_right(src: &str, by: WidthInt) {
    let (right, signed) = (true, false);
    do_test_shift(src, by, right, signed);
}
/// Checks the left shift of `src` by `by` via `do_test_shift`.
fn do_test_shift_left(src: &str, by: WidthInt) {
    let (right, signed) = (false, false);
    do_test_shift(src, by, right, signed);
}
/// Checks the arithmetic (signed) right shift of `src` by `by` via `do_test_shift`.
fn do_test_arithmetic_shift_right(src: &str, by: WidthInt) {
    let (right, signed) = (true, true);
    do_test_shift(src, by, right, signed);
}
/// Checks `zero_extend` and `zero_extend_in_place` when widening `src` by `by` bits.
///
/// The expected value is `src` with `by` `'0'` characters prepended, and the
/// expected width is the original width plus `by`.
///
/// # Panics
/// Panics if `src` fails to parse or if a result's width or value differs
/// from the expectation.
fn do_test_zero_ext(src: &str, by: WidthInt) {
    let value = BitVecValue::from_bit_str(src).unwrap();
    let expected_res_width = value.width() + by;

    // Allocating variant.
    let actual = value.zero_extend(by);
    assert_eq!(expected_res_width, actual.width());

    // String model: zeros in front, original bits behind.
    let mut padded = "0".repeat(by as usize);
    padded.push_str(src);
    let expected = BitVecValue::from_bit_str(&padded).unwrap();
    assert_eq!(
        actual,
        expected,
        "{} vs. {}",
        actual.to_bit_str(),
        expected.to_bit_str()
    );

    // In-place variant, writing into a zeroed value of the extended width.
    let mut in_place = BitVecValue::zero(expected_res_width);
    in_place.zero_extend_in_place(&value, by);
    assert_eq!(in_place, expected);
}
/// Checks `sign_extend` and `sign_extend_in_place` when widening `src` by `by` bits.
///
/// The expected value is `src` with its first character repeated `by` times
/// in front, and the expected width is the original width plus `by`.
///
/// # Panics
/// Panics if `src` is empty or fails to parse, or if a result's width or
/// value differs from the expectation.
fn do_test_sign_ext(src: &str, by: WidthInt) {
    assert!(!src.is_empty(), "sign extend only works for non zero bits");
    let value = BitVecValue::from_bit_str(src).unwrap();
    let expected_res_width = value.width() + by;

    // Allocating variant.
    let actual = value.sign_extend(by);
    assert_eq!(expected_res_width, actual.width());

    // String model: copies of the leading character in front, original bits behind.
    let sign = src.chars().next().unwrap();
    let mut padded: String = std::iter::repeat(sign).take(by as usize).collect();
    padded.push_str(src);
    let expected = BitVecValue::from_bit_str(&padded).unwrap();
    assert_eq!(
        actual,
        expected,
        "{} vs. {}",
        actual.to_bit_str(),
        expected.to_bit_str()
    );

    // In-place variant, writing into a zeroed value of the extended width.
    let mut in_place = BitVecValue::zero(expected_res_width);
    in_place.sign_extend_in_place(&value, by);
    assert_eq!(in_place, expected);
}
/// Strategy producing strings that match the regex `[01]+`, i.e. one or more
/// `'0'`/`'1'` characters.
fn bit_str_arg() -> impl Strategy<Value = String> {
    // The string slice itself is used as the strategy.
    let bit_pattern: &'static str = "[01]+";
    bit_pattern
}
/// Strategy producing `(bits, msb, lsb)` triples for the slice test.
///
/// `msb` is drawn from `0..max(bits.len(), 1)` and `lsb` from `0..msb + 1`,
/// so `lsb <= msb` always holds.
fn slice_args() -> impl Strategy<Value = (String, WidthInt, WidthInt)> {
    // Stage 1: pair the bit string with an upper index.
    let pick_msb = |bits: String| {
        let len = std::cmp::max(bits.len(), 1);
        let msb_range = 0..(len as WidthInt);
        (Just(bits), msb_range)
    };
    // Stage 2: add a lower index that does not exceed the upper one.
    let pick_lsb = |(bits, msb): (String, WidthInt)| {
        let lsb_range = 0..(msb + 1);
        (Just(bits), Just(msb), lsb_range)
    };
    bit_str_arg().prop_flat_map(pick_msb).prop_flat_map(pick_lsb)
}
/// Strategy producing `(bits, by)` pairs for the shift tests.
///
/// `by` is drawn from one of two ranges:
/// * `0..max(bits.len(), 1)`
/// * `0..max_shift`, where `max_shift` is `mask(bits.len()) + 1` capped at
///   `WidthInt::MAX`.
///
/// NOTE(review): presumably `mask(w)` yields the largest value representable
/// in `w` bits, which would make the second range cover every amount that
/// fits into a bit vector as wide as `bits` — confirm against `baa::mask`.
fn shift_args() -> impl Strategy<Value = (String, WidthInt)> {
    let attach_shift = |bits: String| {
        let bit_count = bits.len();
        let len = std::cmp::max(bit_count, 1);
        let uncapped = mask(bit_count as WidthInt) + 1;
        let max_shift = std::cmp::min(uncapped, WidthInt::MAX as Word) as WidthInt;
        let by = prop_oneof![0..(len as WidthInt), 0..(max_shift)];
        (Just(bits), by)
    };
    bit_str_arg().prop_flat_map(attach_shift)
}
#[test]
fn do_test_is_negative_regressions() {
let a = "0101101001111010000111000001011010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
do_test_is_negative(a);
}
/// Fixed regression inputs for the arithmetic right shift.
///
/// Covers a 1-bit and a 2-bit value, then values made of a single `'1'`
/// followed by one or two `Word::BITS` worth of `'0'` characters, shifted by
/// one bit, by `Word::BITS` and by `2 * Word::BITS`.
#[test]
fn test_arithmetic_shift_right_regression() {
    let word_bits = Word::BITS;
    // A '1' followed by `zeros` '0' characters.
    let one_then_zeros = |zeros: WidthInt| format!("1{}", "0".repeat(zeros as usize));

    do_test_arithmetic_shift_right("1", 0);
    do_test_arithmetic_shift_right("10", 1);
    do_test_arithmetic_shift_right(&one_then_zeros(word_bits), 1);

    let two_words = one_then_zeros(word_bits * 2);
    for by in [1, word_bits, word_bits * 2] {
        do_test_arithmetic_shift_right(&two_words, by);
    }
}
// Property-based tests: each case below feeds generated inputs into the
// matching `do_test_*` helper, which performs the actual assertions.
proptest! {
// Run every property in this block with 5000 generated cases.
#![proptest_config(ProptestConfig::with_cases(5000))]
// Sign check on an arbitrary bit string.
#[test]
fn test_is_negative(a in bit_str_arg()) {
do_test_is_negative(&a);
}
// Concatenation of two independently generated bit strings.
#[test]
fn test_concat(a in bit_str_arg(), b in bit_str_arg()) {
do_test_concat(&a, &b);
}
// Slicing with indices produced by `slice_args`.
#[test]
fn test_slice((s, msb, lsb) in slice_args()) {
do_test_slice(&s, msb, lsb);
}
// Logical right shift with amounts produced by `shift_args`.
#[test]
fn test_shift_right((s, by) in shift_args()) {
do_test_shift_right(&s, by);
}
// Left shift with amounts produced by `shift_args`.
#[test]
fn test_shift_left((s, by) in shift_args()) {
do_test_shift_left(&s, by);
}
// Arithmetic right shift with amounts produced by `shift_args`.
#[test]
fn test_arithmetic_shift_right((s, by) in shift_args()) {
do_test_arithmetic_shift_right(&s, by);
}
// Zero extension by 0 to 999 additional bits.
#[test]
fn test_zero_extend(s in bit_str_arg(), by in 0..(1000 as WidthInt)) {
do_test_zero_ext(&s, by);
}
// Sign extension by 0 to 999 additional bits.
#[test]
fn test_sign_extend(s in bit_str_arg(), by in 0..(1000 as WidthInt)) {
do_test_sign_ext(&s, by);
}
}