use softcore_prelude::*;
/// Integer division of `n` by `m`, rounding the quotient toward zero.
///
/// # Panics
///
/// Panics when `m` is zero, or when the quotient does not fit in an `i128`
/// (`i128::MIN / -1`).
pub const fn quot_round_zero(n: i128, m: i128) -> i128 {
    // The native `/` operator on signed integers already truncates toward zero.
    let quotient = n / m;
    quotient
}
/// Remainder of dividing `n` by `m` with the quotient rounded toward zero.
///
/// A non-zero result carries the sign of `n`.
///
/// # Panics
///
/// Panics when `m` is zero, or when the underlying division overflows
/// (`i128::MIN % -1`).
pub const fn rem_round_zero(n: i128, m: i128) -> i128 {
    // The native `%` operator pairs with truncating division.
    let remainder = n % m;
    remainder
}
/// Sign-extends `input` from `M` bits to an `N`-bit vector.
///
/// This is a thin adapter over `sail_sign_extend`: it only swaps the argument
/// order, passing `input` first and `n` second.
///
/// NOTE(review): `n` is presumably the target width and expected to equal `N`;
/// that is not checked here -- confirm against `sail_sign_extend`.
pub fn sign_extend<const M: i128, const N: i128>(n: i128, input: BitVector<M>) -> BitVector<N> {
    let extended: BitVector<N> = sail_sign_extend(input, n);
    extended
}
/// Computes `a - b` on two 64-bit vectors, wrapping around on underflow.
///
/// # Panics
///
/// Panics unless `N == 64`; other widths are not supported yet.
pub fn sub_vec<const N: i128>(a: BitVector<N>, b: BitVector<N>) -> BitVector<N> {
    assert!(N == 64, "`sub_vec` only support 64 bits for now");

    // Two's-complement wrapping subtraction yields the same bit pattern whether
    // the operands are viewed as signed or unsigned, so stay in `u64`.
    let difference = a.bits().wrapping_sub(b.bits());
    bv(difference)
}
/// Shifts `value` left by the unsigned amount held in `shift`, filling the
/// low bits with zeros.
///
/// Shift amounts of 64 or more produce an all-zero result. A native `u64`
/// shift by such an amount would panic in overflow-checked builds and silently
/// wrap the shift amount otherwise, so that case is handled explicitly.
///
/// NOTE(review): for `N < 64`, dropping the bits shifted past position `N - 1`
/// is left to `bv::<N>` -- confirm that it truncates to `N` bits.
pub fn shift_bits_left<const N: i128, const M: i128>(
    value: BitVector<N>,
    shift: BitVector<M>,
) -> BitVector<N> {
    let amount = shift.bits();
    let shifted = if amount < u64::from(u64::BITS) {
        value.bits() << amount
    } else {
        // Every bit has been shifted out of the 64-bit backing word.
        0
    };
    bv::<N>(shifted)
}
/// Logically shifts `value` right by the unsigned amount held in `shift`,
/// filling the high bits with zeros.
///
/// Shift amounts of 64 or more produce an all-zero result. A native `u64`
/// shift by such an amount would panic in overflow-checked builds and silently
/// wrap the shift amount otherwise, so that case is handled explicitly.
pub fn shift_bits_right<const N: i128, const M: i128>(
    value: BitVector<N>,
    shift: BitVector<M>,
) -> BitVector<N> {
    let amount = shift.bits();
    let shifted = if amount < u64::from(u64::BITS) {
        value.bits() >> amount
    } else {
        // Every bit has been shifted out of the 64-bit backing word.
        0
    };
    bv::<N>(shifted)
}
/// Arithmetic (sign-preserving) right shift of an `N`-bit vector.
///
/// * `shift <= 0` returns `value` unchanged.
/// * `shift >= N` returns all zeros when bit `N - 1` is clear, and `N` one
///   bits when it is set.
/// * Otherwise the bits are shifted right by `shift` and the vacated top
///   `shift` bits are filled with copies of bit `N - 1`.
///
/// # Panics
///
/// Panics if `N > 64`.
pub fn shift_right_arith<const N: i128>(value: BitVector<N>, shift: i128) -> BitVector<N> {
    assert!(N <= 64, "Maximum supported size is 64 for now");

    // Zero or negative amounts are treated as "no shift".
    if shift <= 0 {
        return value;
    }

    let raw = value.bits();
    let negative = ((raw >> (N - 1)) & 1) != 0;
    let shifted_out_entirely = shift >= N;

    // Bits that must be forced to one to replicate the sign bit.
    let fill = if !negative {
        0
    } else if shifted_out_entirely {
        // The whole vector becomes copies of the sign bit.
        if N == 64 {
            u64::MAX
        } else {
            (1u64 << N) - 1
        }
    } else if N == 64 {
        u64::MAX << (64 - shift)
    } else {
        // `shift` ones, positioned in the top `shift` bits of the N-bit window.
        ((1u64 << shift) - 1) << (N - shift)
    };

    // The surviving low part; nothing survives once `shift` reaches the width.
    let kept = if shifted_out_entirely { 0 } else { raw >> shift };

    bv::<N>(kept | fill)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `shift_right_arith` on every `(input, shift, expected)` triple at
    /// width `N` and checks the resulting raw bits.
    fn check_all<const N: i128>(cases: &[(u64, i128, u64)]) {
        for &(input, shift, expected) in cases {
            assert_eq!(shift_right_arith(bv::<N>(input), shift).bits(), expected);
        }
    }

    #[test]
    fn shift() {
        // 8 bits: no shift, positive inputs, negative inputs, then amounts
        // at or beyond the width.
        check_all::<8>(&[
            (0b10110111, 0, 0b10110111),
            (0b01110111, 0, 0b01110111),
            (0b01110111, 1, 0b00111011),
            (0b01110111, 2, 0b00011101),
            (0b01110111, 3, 0b00001110),
            (0b01111111, 4, 0b00000111),
            (0b10110111, 1, 0b11011011),
            (0b10110111, 2, 0b11101101),
            (0b10110111, 3, 0b11110110),
            (0b11111111, 4, 0b11111111),
            (0b01110111, 8, 0b00000000),
            (0b01110111, 10, 0b00000000),
            (0b10110111, 8, 0b11111111),
            (0b10110111, 10, 0b11111111),
        ]);

        // 16 bits.
        check_all::<16>(&[
            (0b0111111111111111, 4, 0b0000011111111111),
            (0b1111111111111111, 4, 0b1111111111111111),
            (0b1000000000000000, 4, 0b1111100000000000),
        ]);

        // 32 bits.
        check_all::<32>(&[
            (0x7FFFFFFF, 16, 0x00007FFF),
            (0x80000000, 16, 0xFFFF8000),
            (0xFFFFFFFF, 16, 0xFFFFFFFF),
        ]);

        // 64 bits: the width that takes the dedicated `N == 64` mask path.
        check_all::<64>(&[
            (0x7FFFFFFFFFFFFFFF, 32, 0x000000007FFFFFFF),
            (0x8000000000000000, 32, 0xFFFFFFFF80000000),
            (0xFFFFFFFFFFFFFFFF, 32, 0xFFFFFFFFFFFFFFFF),
        ]);

        // 4 bits: a width that is not a multiple of eight.
        check_all::<4>(&[
            (0b0111, 1, 0b0011),
            (0b1111, 1, 0b1111),
            (0b1000, 2, 0b1110),
        ]);

        // Negative amounts leave the value untouched.
        check_all::<8>(&[
            (0b10110111, -1, 0b10110111),
            (0b01110111, -5, 0b01110111),
        ]);
    }
}