use crate::{arg_error_noloc, result::Result};
use alloc::string::String;
use awint::{Awi, Bits, SerdeError};
use core::num::NonZero;
/// Arbitrary-precision integer with a fixed, nonzero bitwidth.
///
/// The bits carry no signedness of their own: individual operations (for
/// example [`APInt::sdiv`] versus [`APInt::udiv`], or the `signed` flag of
/// [`APInt::to_string`]) decide whether they are read as unsigned or as
/// two's complement.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct APInt {
    /// Backing storage; its bitwidth is the integer's bitwidth.
    value: Awi,
}
impl From<SerdeError> for crate::result::Error {
fn from(value: SerdeError) -> Self {
arg_error_noloc!("APInt error: {}", value)
}
}
pub use awint::bw;
impl APInt {
pub fn bw(&self) -> usize {
self.value.bw()
}
pub fn zero(width: NonZero<usize>) -> APInt {
APInt {
value: Awi::zero(width),
}
}
pub fn is_zero(&self) -> bool {
self.value.is_zero()
}
pub fn add(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::add: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let mut value = self.value.clone();
value
.add_(&rhs.value)
.expect("APInt::add: bitwidth mismatch");
APInt { value }
}
pub fn add_overflow(&self, rhs: &APInt) -> (APInt, bool, bool) {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::add_overflow: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let mut value = Awi::zero(NonZero::new(self.bw()).expect("self has zero bitwidth"));
let (unsigned_overflow, signed_overflow) = value
.cin_sum_(false, &self.value, &rhs.value)
.expect("APInt::add_overflow: bitwidth mismatch");
(APInt { value }, unsigned_overflow, signed_overflow)
}
pub fn sub(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::sub: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let mut value = self.value.clone();
value
.sub_(&rhs.value)
.expect("APInt::sub: bitwidth mismatch");
APInt { value }
}
pub fn sub_overflow(&self, rhs: &APInt) -> (APInt, bool, bool) {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::sub_overflow: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let mut not_rhs = rhs.value.clone();
not_rhs.not_();
let mut value = Awi::zero(NonZero::new(self.bw()).expect("self has zero bitwidth"));
let (carry_out, signed_overflow) = value
.cin_sum_(true, &self.value, ¬_rhs)
.expect("APInt::sub_overflow: bitwidth mismatch");
(APInt { value }, !carry_out, signed_overflow)
}
pub fn mul(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::mul: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let mut value = Awi::zero(NonZero::new(self.bw()).expect("self has zero bitwidth"));
value
.mul_add_(&self.value, &rhs.value)
.expect("APInt::mul: bitwidth mismatch");
APInt { value }
}
pub fn mul_overflow(&self, rhs: &APInt) -> (APInt, bool, bool) {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::mul_overflow: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let bw = NonZero::new(self.bw()).expect("self has zero bitwidth");
let dbw = NonZero::new(self.bw() * 2).expect("self has zero bitwidth");
let mut ulhs = Awi::zero(dbw);
ulhs.zero_resize_(&self.value);
let mut urhs = Awi::zero(dbw);
urhs.zero_resize_(&rhs.value);
let mut uprod = Awi::zero(dbw);
uprod.arb_umul_add_(&ulhs, &urhs);
let mut utrunc = Awi::zero(bw);
let unsigned_overflow = utrunc.zero_resize_(&uprod);
let mut slhs = Awi::zero(dbw);
slhs.sign_resize_(&self.value);
let mut srhs = Awi::zero(dbw);
srhs.sign_resize_(&rhs.value);
let mut sprod = Awi::zero(dbw);
sprod.arb_imul_add_(&mut slhs, &mut srhs);
let mut strunc = Awi::zero(bw);
let signed_overflow = strunc.sign_resize_(&sprod);
(self.mul(rhs), unsigned_overflow, signed_overflow)
}
pub fn shl(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::shl: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let shamt = rhs.to_usize();
let mut value = self.value.clone();
if value.shl_(shamt).is_none() {
value.zero_();
}
APInt { value }
}
pub fn shl_overflow(&self, rhs: &APInt) -> (APInt, bool, bool) {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::shl_overflow: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
let shamt = rhs.to_usize();
assert!(
shamt < self.bw(),
"APInt::shl_overflow: shift amount {} >= bitwidth {}",
shamt,
self.bw()
);
let result = self.shl(rhs);
let mut ushifted_back = result.value.clone();
ushifted_back
.lshr_(shamt)
.expect("shift amount checked against bitwidth above");
let unsigned_overflow = ushifted_back != self.value;
let mut sshifted_back = result.value.clone();
sshifted_back
.ashr_(shamt)
.expect("shift amount checked against bitwidth above");
let signed_overflow = sshifted_back != self.value;
(result, unsigned_overflow, signed_overflow)
}
pub fn udiv(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::udiv: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
assert!(!rhs.is_zero(), "APInt::udiv: division by zero");
let width = NonZero::new(self.bw()).expect("self has zero bitwidth");
let mut quo = Awi::zero(width);
let mut rem = Awi::zero(width);
Bits::udivide(&mut quo, &mut rem, &self.value, &rhs.value).unwrap();
APInt { value: quo }
}
pub fn sdiv(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::sdiv: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
assert!(!rhs.is_zero(), "APInt::sdiv: division by zero");
let width = NonZero::new(self.bw()).expect("self has zero bitwidth");
let mut quo = Awi::zero(width);
let mut rem = Awi::zero(width);
let mut duo = self.value.clone();
let mut div = rhs.value.clone();
Bits::idivide(&mut quo, &mut rem, &mut duo, &mut div).unwrap();
APInt { value: quo }
}
pub fn urem(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::urem: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
assert!(!rhs.is_zero(), "APInt::urem: division by zero");
let width = NonZero::new(self.bw()).expect("self has zero bitwidth");
let mut quo = Awi::zero(width);
let mut rem = Awi::zero(width);
Bits::udivide(&mut quo, &mut rem, &self.value, &rhs.value).unwrap();
APInt { value: rem }
}
pub fn srem(&self, rhs: &APInt) -> APInt {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::srem: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
assert!(!rhs.is_zero(), "APInt::srem: division by zero");
let width = NonZero::new(self.bw()).expect("self has zero bitwidth");
let mut quo = Awi::zero(width);
let mut rem = Awi::zero(width);
let mut duo = self.value.clone();
let mut div = rhs.value.clone();
Bits::idivide(&mut quo, &mut rem, &mut duo, &mut div).unwrap();
APInt { value: rem }
}
pub fn ult(&self, rhs: &APInt) -> bool {
assert_eq!(
self.bw(),
rhs.bw(),
"APInt::ult: bitwidth mismatch ({} vs {})",
self.bw(),
rhs.bw()
);
self.value
.ult(&rhs.value)
.expect("APInt::ult: bitwidth mismatch")
}
pub fn umax(width: NonZero<usize>) -> APInt {
APInt {
value: Awi::umax(width),
}
}
pub fn imax(width: NonZero<usize>) -> APInt {
APInt {
value: Awi::imax(width),
}
}
pub fn imin(width: NonZero<usize>) -> APInt {
APInt {
value: Awi::imin(width),
}
}
pub fn uone(width: NonZero<usize>) -> APInt {
APInt {
value: Awi::uone(width),
}
}
pub fn from_str(value: &str, width: usize, radix: u8) -> Result<APInt> {
let sign_opt = value.chars().next().ok_or(SerdeError::Empty)?;
let neg = sign_opt == '-';
let value = if neg || sign_opt == '+' {
&value[1..]
} else {
value
};
let sign = if neg { Some(true) } else { None };
let value = Awi::from_str_radix(
sign,
value,
radix,
NonZero::new(width).ok_or(SerdeError::ZeroBitwidth)?,
)?;
Ok(APInt { value })
}
pub fn to_string(&self, radix: u8, signed: bool) -> String {
match Awi::bits_to_string_radix(&self.value, signed, radix, false, 1) {
Ok(mut s) => {
if signed && self.value.msb() {
s.insert(0, '-');
}
s
}
Err(e) => {
panic!("APInt error: {e}");
}
}
}
pub fn to_string_decimal(&self, signed: bool) -> String {
self.to_string(10, signed)
}
pub fn to_string_signed(&self, radix: u8) -> String {
self.to_string(radix, true)
}
pub fn to_string_unsigned(&self, radix: u8) -> String {
self.to_string(radix, false)
}
pub fn to_string_signed_decimal(&self) -> String {
self.to_string_signed(10)
}
pub fn to_string_unsigned_decimal(&self) -> String {
self.to_string_unsigned(10)
}
pub fn from_u8(value: u8, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.u8_(value);
APInt { value: awi_value }
}
pub fn to_u8(&self) -> u8 {
self.value.to_u8()
}
pub fn from_u16(value: u16, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.u16_(value);
APInt { value: awi_value }
}
pub fn to_u16(&self) -> u16 {
self.value.to_u16()
}
pub fn from_u32(value: u32, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.u32_(value);
APInt { value: awi_value }
}
pub fn to_u32(&self) -> u32 {
self.value.to_u32()
}
pub fn from_u64(value: u64, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.u64_(value);
APInt { value: awi_value }
}
pub fn to_u64(&self) -> u64 {
self.value.to_u64()
}
pub fn from_usize(value: usize, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.usize_(value);
APInt { value: awi_value }
}
pub fn to_usize(&self) -> usize {
self.value.to_usize()
}
pub fn from_u128(value: u128, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.u128_(value);
APInt { value: awi_value }
}
pub fn to_u128(&self) -> u128 {
self.value.to_u128()
}
pub fn from_i8(value: i8, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.i8_(value);
APInt { value: awi_value }
}
pub fn to_i8(&self) -> i8 {
self.value.to_i8()
}
pub fn from_i16(value: i16, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.i16_(value);
APInt { value: awi_value }
}
pub fn to_i16(&self) -> i16 {
self.value.to_i16()
}
pub fn from_i32(value: i32, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.i32_(value);
APInt { value: awi_value }
}
pub fn to_i32(&self) -> i32 {
self.value.to_i32()
}
pub fn from_i64(value: i64, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.i64_(value);
APInt { value: awi_value }
}
pub fn to_i64(&self) -> i64 {
self.value.to_i64()
}
pub fn from_i128(value: i128, width: NonZero<usize>) -> APInt {
let mut awi_value = Awi::zero_with_capacity(width, width);
awi_value.i128_(value);
APInt { value: awi_value }
}
pub fn to_i128(&self) -> i128 {
self.value.to_i128()
}
}
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use expect_test::expect;

    use super::*;

    /// One overflow-reporting case on 4-bit operands, written as unsigned bit
    /// patterns: `(lhs, rhs, wrapped result, unsigned overflow, signed overflow)`.
    type OverflowCase = (u8, u8, u8, bool, bool);

    /// Runs `op` on every case at width 4 and checks the result and both flags.
    fn check_overflow_op(
        op: fn(&APInt, &APInt) -> (APInt, bool, bool),
        name: &str,
        cases: &[OverflowCase],
    ) {
        let width = bw(4);
        for &(lhs, rhs, expected, unsigned, signed) in cases {
            let (res, u, s) = op(&APInt::from_u8(lhs, width), &APInt::from_u8(rhs, width));
            assert_eq!(
                (res.to_u8(), u, s),
                (expected, unsigned, signed),
                "{name}({lhs}, {rhs})"
            );
        }
    }

    #[test]
    fn test_zero() {
        let width = bw(4);
        let apint = APInt::zero(width);
        assert!(apint.is_zero());
    }

    #[test]
    fn test_limits() {
        let width = bw(4);
        let umax = APInt::umax(width);
        assert_eq!(umax.to_u8(), 15);
        assert_eq!(umax.to_i8(), -1);
        let imax = APInt::imax(width);
        assert_eq!(imax.to_i8(), 7);
        assert_eq!(imax.to_u8(), 7);
        let imin = APInt::imin(width);
        assert_eq!(imin.to_i8(), -8);
        assert_eq!(imin.to_u8(), 8);
    }

    #[test]
    fn test_from_str() {
        let width = 4;
        let apint = APInt::from_str("7", width, 10).unwrap();
        assert_eq!(apint.to_u8(), 7);
        let apint = APInt::from_str("-8", width, 10).unwrap();
        assert_eq!(apint.to_i8(), -8);
        assert_eq!(apint.to_string(10, true), "-8");
        let apint = APInt::from_str("+15", width, 10).unwrap();
        assert_eq!(apint.to_i8(), -1);
        assert_eq!(apint.to_u8(), 15);
        assert_eq!(apint.to_string(10, true), "-1");
        assert_eq!(apint.to_string(10, false), "15");
        let apint = APInt::from_str("-2", width, 10).unwrap();
        assert_eq!(apint.to_i8(), -2);
        assert_eq!(apint.to_u8(), 14);
        assert_eq!(apint.to_string(10, true), "-2");
        assert_eq!(apint.to_string(10, false), "14");
    }

    #[test]
    fn test_from_str_failure() {
        let width = 4;
        let result = APInt::from_str("invalid", width, 10);
        expect![[r#"
            Compilation error: invalid argument.
            APInt error: InvalidChar"#]]
        .assert_eq(&result.unwrap_err().to_string());
        let result = APInt::from_str("", width, 10);
        expect![[r#"
            Compilation error: invalid argument.
            APInt error: Empty"#]]
        .assert_eq(&result.unwrap_err().to_string());
        let result = APInt::from_str("16", width, 10);
        expect![[r#"
            Compilation error: invalid argument.
            APInt error: Overflow"#]]
        .assert_eq(&result.unwrap_err().to_string());
        let result = APInt::from_str("1", 0, 10);
        expect![[r#"
            Compilation error: invalid argument.
            APInt error: ZeroBitwidth"#]]
        .assert_eq(&result.unwrap_err().to_string());
    }

    #[test]
    fn test_from_u8() {
        let width = bw(4);
        for i in 0..16 {
            let apint = APInt::from_u8(i, width);
            assert_eq!(apint.to_u8(), i);
        }
    }

    #[test]
    fn test_from_u16() {
        let width = bw(4);
        for i in 0..16 {
            let apint = APInt::from_u16(i, width);
            assert_eq!(apint.to_u16(), i);
        }
    }

    #[test]
    fn test_from_u32() {
        let width = bw(4);
        for i in 0..16 {
            let apint = APInt::from_u32(i, width);
            assert_eq!(apint.to_u32(), i);
        }
    }

    #[test]
    fn test_from_u64() {
        let width = bw(4);
        for i in 0..16 {
            let apint = APInt::from_u64(i, width);
            assert_eq!(apint.to_u64(), i);
        }
    }

    #[test]
    fn test_from_u128() {
        let width = bw(4);
        for i in 0..16 {
            let apint = APInt::from_u128(i, width);
            assert_eq!(apint.to_u128(), i);
        }
    }

    #[test]
    fn test_from_i8() {
        let width = bw(4);
        for i in -8..8 {
            let apint = APInt::from_i8(i, width);
            assert_eq!(apint.to_i8(), i);
        }
    }

    #[test]
    fn test_from_i16() {
        let width = bw(4);
        for i in -8..8 {
            let apint = APInt::from_i16(i, width);
            assert_eq!(apint.to_i16(), i);
        }
    }

    #[test]
    fn test_from_i32() {
        let width = bw(4);
        for i in -8..8 {
            let apint = APInt::from_i32(i, width);
            assert_eq!(apint.to_i32(), i);
        }
    }

    #[test]
    fn test_from_i64() {
        let width = bw(4);
        for i in -8..8 {
            let apint = APInt::from_i64(i, width);
            assert_eq!(apint.to_i64(), i);
        }
    }

    #[test]
    fn test_from_i128() {
        let width = bw(4);
        for i in -8..8 {
            let apint = APInt::from_i128(i, width);
            assert_eq!(apint.to_i128(), i);
        }
    }

    #[test]
    fn test_add() {
        let width = bw(4);
        let sum = APInt::from_u8(3, width).add(&APInt::from_u8(4, width));
        assert_eq!(sum.to_u8(), 7);
        let sum = APInt::from_u8(15, width).add(&APInt::from_u8(1, width));
        assert_eq!(sum.to_u8(), 0);
        let sum = APInt::from_i8(-1, width).add(&APInt::from_i8(1, width));
        assert_eq!(sum.to_i8(), 0);
        let sum = APInt::from_u8(9, width).add(&APInt::zero(width));
        assert_eq!(sum.to_u8(), 9);
    }

    #[test]
    #[should_panic(expected = "APInt::add: bitwidth mismatch (4 vs 8)")]
    fn test_add_bitwidth_mismatch_panics() {
        let _ = APInt::zero(bw(4)).add(&APInt::zero(bw(8)));
    }

    #[test]
    fn test_add_overflow() {
        check_overflow_op(
            APInt::add_overflow,
            "add_overflow",
            &[
                (3, 4, 7, false, false),
                // 7 + 1 exceeds the signed maximum but not the unsigned one.
                (7, 1, 8, false, true),
                // -1 + 1 == 0 is fine signed, but 15 + 1 wraps unsigned.
                (15, 1, 0, true, false),
                // -8 + -8 overflows both ways.
                (8, 8, 0, true, true),
            ],
        );
    }

    #[test]
    fn test_sub() {
        let width = bw(4);
        let diff = APInt::from_u8(7, width).sub(&APInt::from_u8(4, width));
        assert_eq!(diff.to_u8(), 3);
        let diff = APInt::from_u8(0, width).sub(&APInt::from_u8(1, width));
        assert_eq!(diff.to_u8(), 15);
        assert_eq!(diff.to_i8(), -1);
        let diff = APInt::from_u8(5, width).sub(&APInt::from_u8(5, width));
        assert!(diff.is_zero());
    }

    #[test]
    fn test_sub_overflow() {
        check_overflow_op(
            APInt::sub_overflow,
            "sub_overflow",
            &[
                (7, 4, 3, false, false),
                (5, 5, 0, false, false),
                // 0 - 1 borrows unsigned; signed it is just -1.
                (0, 1, 15, true, false),
                // -8 - 1 underflows signed; unsigned 8 - 1 is fine.
                (8, 1, 7, false, true),
                // 0 - (-8) == 8 overflows signed and borrows unsigned.
                (0, 8, 8, true, true),
            ],
        );
    }

    #[test]
    fn test_mul() {
        let width = bw(4);
        let prod = APInt::from_u8(3, width).mul(&APInt::from_u8(4, width));
        assert_eq!(prod.to_u8(), 12);
        let prod = APInt::from_u8(3, width).mul(&APInt::from_u8(6, width));
        assert_eq!(prod.to_u8(), 2);
        let prod = APInt::from_i8(-2, width).mul(&APInt::from_i8(3, width));
        assert_eq!(prod.to_i8(), -6);
        assert_eq!(prod.to_u8(), 10);
        let prod = APInt::from_u8(7, width).mul(&APInt::zero(width));
        assert!(prod.is_zero());
    }

    #[test]
    fn test_mul_overflow() {
        check_overflow_op(
            APInt::mul_overflow,
            "mul_overflow",
            &[
                (2, 3, 6, false, false),
                // 12 fits unsigned but exceeds the signed maximum of 7.
                (3, 4, 12, false, true),
                (3, 6, 2, true, true),
                // (-1) * (-1) == 1 signed; 15 * 15 wraps unsigned.
                (15, 15, 1, true, false),
                // -2 * 3 == -6 fits signed; 14 * 3 wraps unsigned.
                (14, 3, 10, true, false),
                // -8 * -1 == 8 overflows signed.
                (8, 15, 8, true, true),
            ],
        );
    }

    #[test]
    fn test_shl() {
        let width = bw(4);
        let res = APInt::from_u8(1, width).shl(&APInt::from_u8(2, width));
        assert_eq!(res.to_u8(), 4);
        let res = APInt::from_u8(5, width).shl(&APInt::zero(width));
        assert_eq!(res.to_u8(), 5);
        let res = APInt::from_u8(3, width).shl(&APInt::from_u8(3, width));
        assert_eq!(res.to_u8(), 8);
        let res = APInt::from_u8(0xf, width).shl(&APInt::from_u8(4, width));
        assert!(res.is_zero());
        let res = APInt::from_u8(0xf, width).shl(&APInt::from_u8(7, width));
        assert!(res.is_zero());
    }

    #[test]
    fn test_shl_overflow() {
        check_overflow_op(
            APInt::shl_overflow,
            "shl_overflow",
            &[
                (1, 2, 4, false, false),
                (15, 0, 15, false, false),
                // 0011 << 2 == 1100 keeps every set bit but flips the sign.
                (3, 2, 12, false, true),
                // 0101 << 2 == 0100 loses a set bit.
                (5, 2, 4, true, true),
                // -1 << 1 == -2: a set bit is lost, but the signed value is exact.
                (15, 1, 14, true, false),
            ],
        );
    }

    #[test]
    #[should_panic(expected = "shift amount")]
    fn test_shl_overflow_out_of_range_panics() {
        let width = bw(4);
        let _ = APInt::from_u8(1, width).shl_overflow(&APInt::from_u8(4, width));
    }

    #[test]
    fn test_udiv() {
        let width = bw(4);
        let res = APInt::from_u8(12, width).udiv(&APInt::from_u8(4, width));
        assert_eq!(res.to_u8(), 3);
        let res = APInt::from_u8(13, width).udiv(&APInt::from_u8(4, width));
        assert_eq!(res.to_u8(), 3);
        let res = APInt::from_u8(0xf, width).udiv(&APInt::from_u8(2, width));
        assert_eq!(res.to_u8(), 7);
        let res = APInt::from_u8(9, width).udiv(&APInt::uone(width));
        assert_eq!(res.to_u8(), 9);
        let res = APInt::from_u8(3, width).udiv(&APInt::from_u8(5, width));
        assert!(res.is_zero());
    }

    #[test]
    #[should_panic(expected = "division by zero")]
    fn test_udiv_by_zero_panics() {
        let width = bw(4);
        let _ = APInt::from_u8(7, width).udiv(&APInt::zero(width));
    }

    #[test]
    fn test_ult() {
        let width = bw(4);
        assert!(APInt::from_u8(3, width).ult(&APInt::from_u8(5, width)));
        assert!(!APInt::from_u8(5, width).ult(&APInt::from_u8(3, width)));
        assert!(!APInt::from_u8(4, width).ult(&APInt::from_u8(4, width)));
        assert!(APInt::from_u8(1, width).ult(&APInt::from_u8(0xf, width)));
        assert!(!APInt::from_u8(0xf, width).ult(&APInt::from_u8(1, width)));
    }

    #[test]
    fn test_sdiv() {
        let width = bw(4);
        assert_eq!(
            APInt::from_i8(6, width)
                .sdiv(&APInt::from_i8(2, width))
                .to_i8(),
            3
        );
        assert_eq!(
            APInt::from_i8(-7, width)
                .sdiv(&APInt::from_i8(2, width))
                .to_i8(),
            -3
        );
        assert_eq!(
            APInt::from_i8(7, width)
                .sdiv(&APInt::from_i8(-2, width))
                .to_i8(),
            -3
        );
        assert_eq!(
            APInt::from_i8(-6, width)
                .sdiv(&APInt::from_i8(-3, width))
                .to_i8(),
            2
        );
        assert_eq!(
            APInt::from_i8(-5, width).sdiv(&APInt::uone(width)).to_i8(),
            -5
        );
    }

    #[test]
    #[should_panic(expected = "division by zero")]
    fn test_sdiv_by_zero_panics() {
        let width = bw(4);
        let _ = APInt::from_i8(7, width).sdiv(&APInt::zero(width));
    }

    #[test]
    fn test_urem() {
        let width = bw(4);
        assert_eq!(
            APInt::from_u8(13, width)
                .urem(&APInt::from_u8(4, width))
                .to_u8(),
            1
        );
        assert_eq!(
            APInt::from_u8(12, width)
                .urem(&APInt::from_u8(4, width))
                .to_u8(),
            0
        );
        assert_eq!(
            APInt::from_u8(3, width)
                .urem(&APInt::from_u8(5, width))
                .to_u8(),
            3
        );
        assert_eq!(
            APInt::from_u8(0xf, width)
                .urem(&APInt::from_u8(4, width))
                .to_u8(),
            3
        );
    }

    #[test]
    #[should_panic(expected = "division by zero")]
    fn test_urem_by_zero_panics() {
        let width = bw(4);
        let _ = APInt::from_u8(7, width).urem(&APInt::zero(width));
    }

    #[test]
    fn test_srem() {
        let width = bw(4);
        assert_eq!(
            APInt::from_i8(7, width)
                .srem(&APInt::from_i8(3, width))
                .to_i8(),
            1
        );
        assert_eq!(
            APInt::from_i8(-7, width)
                .srem(&APInt::from_i8(3, width))
                .to_i8(),
            -1
        );
        assert_eq!(
            APInt::from_i8(7, width)
                .srem(&APInt::from_i8(-3, width))
                .to_i8(),
            1
        );
        assert_eq!(
            APInt::from_i8(-7, width)
                .srem(&APInt::from_i8(-3, width))
                .to_i8(),
            -1
        );
        assert_eq!(
            APInt::from_i8(-6, width)
                .srem(&APInt::from_i8(3, width))
                .to_i8(),
            0
        );
    }

    #[test]
    #[should_panic(expected = "division by zero")]
    fn test_srem_by_zero_panics() {
        let width = bw(4);
        let _ = APInt::from_i8(7, width).srem(&APInt::zero(width));
    }

    #[test]
    fn test_to_string_radix() {
        let width = bw(8);
        assert_eq!(APInt::umax(width).to_string_unsigned(8), "377");
        assert_eq!(APInt::umax(width).to_string_unsigned(2), "11111111");
        assert_eq!(APInt::umax(width).to_string_signed_decimal(), "-1");
        assert_eq!(APInt::imin(width).to_string_signed_decimal(), "-128");
        assert_eq!(APInt::imin(width).to_string_unsigned_decimal(), "128");
        assert_eq!(APInt::imax(width).to_string_signed_decimal(), "127");
        assert_eq!(APInt::zero(width).to_string_signed_decimal(), "0");
        assert_eq!(APInt::from_i8(-5, width).to_string_signed(2), "-101");
        assert_eq!(APInt::from_i8(-5, width).to_string_unsigned_decimal(), "251");
        assert_eq!(APInt::from_i8(-5, width).to_string_decimal(true), "-5");
    }
}