use crate::abstract_interp::IntervalFact;
use crate::state::lattice::{AbstractDomain, Lattice};
use serde::{Deserialize, Serialize};
/// Known-bits abstract value for a 64-bit integer.
///
/// A set bit in `known_zero` marks a position proven to be `0`; a set bit in `known_one`
/// marks a position proven to be `1`. A position present in neither mask is unknown
/// (both masks zero is top), and a position present in both masks is contradictory,
/// which makes the whole fact bottom.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BitFact {
    /// Bit positions proven to hold `0`.
    pub known_zero: u64,
    /// Bit positions proven to hold `1`.
    pub known_one: u64,
}
impl BitFact {
pub fn top() -> Self {
Self {
known_zero: 0,
known_one: 0,
}
}
pub fn bottom() -> Self {
Self {
known_zero: u64::MAX,
known_one: u64::MAX,
}
}
pub fn from_const(n: i64) -> Self {
let bits = n as u64;
Self {
known_zero: !bits,
known_one: bits,
}
}
pub fn is_top(&self) -> bool {
self.known_zero == 0 && self.known_one == 0
}
pub fn is_bottom(&self) -> bool {
self.known_zero & self.known_one != 0
}
pub fn is_non_negative(&self) -> bool {
self.known_zero & (1u64 << 63) != 0
}
pub fn bit_and(&self, other: &Self) -> Self {
if self.is_bottom() || other.is_bottom() {
return Self::bottom();
}
Self {
known_zero: self.known_zero | other.known_zero,
known_one: self.known_one & other.known_one,
}
}
pub fn bit_or(&self, other: &Self) -> Self {
if self.is_bottom() || other.is_bottom() {
return Self::bottom();
}
Self {
known_zero: self.known_zero & other.known_zero,
known_one: self.known_one | other.known_one,
}
}
pub fn bit_xor(&self, other: &Self) -> Self {
if self.is_bottom() || other.is_bottom() {
return Self::bottom();
}
Self {
known_zero: (self.known_zero & other.known_zero) | (self.known_one & other.known_one),
known_one: (self.known_one & other.known_zero) | (self.known_zero & other.known_one),
}
}
pub fn left_shift(&self, shift: &IntervalFact) -> Self {
if self.is_bottom() || shift.is_bottom() {
return Self::bottom();
}
match (shift.lo, shift.hi) {
(Some(lo), Some(hi)) if lo == hi && (0..=63).contains(&lo) => {
let k = lo as u32;
Self {
known_zero: (self.known_zero << k) | ((1u64 << k) - 1),
known_one: self.known_one << k,
}
}
_ => Self::top(),
}
}
pub fn right_shift(&self, shift: &IntervalFact) -> Self {
if self.is_bottom() || shift.is_bottom() {
return Self::bottom();
}
match (shift.lo, shift.hi) {
(Some(lo), Some(hi)) if lo == hi && (0..=63).contains(&lo) => {
let k = lo as u32;
let high_mask = if k == 0 { 0u64 } else { u64::MAX << (64 - k) };
if self.is_non_negative() {
Self {
known_zero: (self.known_zero >> k) | high_mask,
known_one: self.known_one >> k,
}
} else if self.known_one & (1u64 << 63) != 0 {
Self {
known_zero: self.known_zero >> k,
known_one: (self.known_one >> k) | high_mask,
}
} else {
Self {
known_zero: self.known_zero >> k,
known_one: self.known_one >> k,
}
}
}
_ => Self::top(),
}
}
pub fn upper_bound_hint(&self) -> Option<i64> {
if !self.is_non_negative() || self.is_bottom() {
return None;
}
let max_val = !self.known_zero & 0x7FFF_FFFF_FFFF_FFFFu64;
Some(max_val as i64)
}
}
impl Lattice for BitFact {
fn bot() -> Self {
Self::bottom()
}
fn join(&self, other: &Self) -> Self {
if self.is_bottom() {
return other.clone();
}
if other.is_bottom() {
return self.clone();
}
Self {
known_zero: self.known_zero & other.known_zero,
known_one: self.known_one & other.known_one,
}
}
fn leq(&self, other: &Self) -> bool {
if self.is_bottom() {
return true;
}
if other.is_bottom() {
return false;
}
(other.known_zero & !self.known_zero) == 0 && (other.known_one & !self.known_one) == 0
}
}
impl AbstractDomain for BitFact {
fn top() -> Self {
Self::top()
}
fn meet(&self, other: &Self) -> Self {
if self.is_bottom() || other.is_bottom() {
return Self::bottom();
}
let kz = self.known_zero | other.known_zero;
let ko = self.known_one | other.known_one;
if kz & ko != 0 {
return Self::bottom();
}
Self {
known_zero: kz,
known_one: ko,
}
}
fn widen(&self, other: &Self) -> Self {
self.join(other)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructors and predicates -------------------------------------------------

    #[test]
    fn from_const_positive() {
        let f = BitFact::from_const(0x0F);
        assert_eq!(f.known_one, 0x0F);
        assert_eq!(f.known_zero, !0x0Fu64);
        assert!(f.is_non_negative());
    }

    #[test]
    fn from_const_negative() {
        let f = BitFact::from_const(-1);
        assert_eq!(f.known_one, u64::MAX);
        assert_eq!(f.known_zero, 0);
        assert!(!f.is_non_negative());
    }

    #[test]
    fn from_const_zero() {
        let f = BitFact::from_const(0);
        assert_eq!(f.known_one, 0);
        assert_eq!(f.known_zero, u64::MAX);
        assert!(f.is_non_negative());
    }

    #[test]
    fn top_and_bottom() {
        assert!(BitFact::top().is_top());
        assert!(!BitFact::top().is_bottom());
        assert!(BitFact::bottom().is_bottom());
        assert!(!BitFact::bottom().is_top());
    }

    // --- Lattice operations on hand-picked values ------------------------------------

    #[test]
    fn join_commutative() {
        let a = BitFact::from_const(0xFF);
        let b = BitFact::from_const(0x0F);
        assert_eq!(a.join(&b), b.join(&a));
    }

    #[test]
    fn join_idempotent() {
        let a = BitFact::from_const(42);
        assert_eq!(a.join(&a), a);
    }

    #[test]
    fn join_relaxes_bits() {
        let a = BitFact::from_const(0xFF);
        let b = BitFact::from_const(0x0F);
        let j = a.join(&b);
        assert_eq!(j.known_one & 0xFF, 0x0F);
        assert_eq!(j.known_zero & 0xF0, 0);
        assert_eq!(j.known_one & 0xF0, 0);
    }

    #[test]
    fn meet_commutative() {
        let a = BitFact {
            known_zero: 0xF0,
            known_one: 0x0F,
        };
        let b = BitFact {
            known_zero: 0x0F00,
            known_one: 0,
        };
        assert_eq!(
            <BitFact as AbstractDomain>::meet(&a, &b),
            <BitFact as AbstractDomain>::meet(&b, &a)
        );
    }

    #[test]
    fn meet_contradiction_is_bottom() {
        let a = BitFact {
            known_zero: 0,
            known_one: 0x01,
        };
        let b = BitFact {
            known_zero: 0x01,
            known_one: 0,
        };
        assert!(<BitFact as AbstractDomain>::meet(&a, &b).is_bottom());
    }

    #[test]
    fn leq_reflexive() {
        let a = BitFact::from_const(42);
        assert!(a.leq(&a));
    }

    #[test]
    fn leq_bottom_is_least() {
        assert!(BitFact::bottom().leq(&BitFact::top()));
        assert!(BitFact::bottom().leq(&BitFact::from_const(0)));
    }

    #[test]
    fn leq_more_precise_is_lower() {
        let precise = BitFact::from_const(0xFF);
        let vague = BitFact::top();
        assert!(precise.leq(&vague));
        assert!(!vague.leq(&precise));
    }

    // --- Bitwise transfer functions --------------------------------------------------

    #[test]
    fn bit_and_transfer() {
        let a = BitFact::from_const(0xFF);
        let b = BitFact::from_const(0x0F);
        let result = a.bit_and(&b);
        assert_eq!(result.known_one, 0x0F);
        assert_eq!(result.known_zero, !0x0Fu64);
    }

    #[test]
    fn bit_and_with_mask_bounds() {
        let unknown = BitFact::top();
        let mask = BitFact::from_const(0x07);
        let result = unknown.bit_and(&mask);
        assert_eq!(result.known_zero & !0x07u64, !0x07u64);
        assert_eq!(result.known_one & 0x07, 0);
    }

    #[test]
    fn bit_or_transfer() {
        let a = BitFact::from_const(0xF0);
        let b = BitFact::from_const(0x0F);
        let result = a.bit_or(&b);
        assert_eq!(result.known_one, 0xFF);
        assert_eq!(result.known_zero, !0xFFu64);
    }

    #[test]
    fn bit_or_with_unknown() {
        let unknown = BitFact::top();
        let bits = BitFact::from_const(0x01);
        let result = unknown.bit_or(&bits);
        assert_ne!(result.known_one & 0x01, 0);
        assert_eq!(result.known_zero & 0x01, 0);
    }

    #[test]
    fn bit_xor_transfer() {
        let a = BitFact::from_const(0xFF);
        let b = BitFact::from_const(0x0F);
        let result = a.bit_xor(&b);
        assert_eq!(result.known_one, 0xF0);
        assert_eq!(result.known_zero, !0xF0u64);
    }

    #[test]
    fn bit_xor_self_is_zero() {
        let a = BitFact::from_const(42);
        let result = a.bit_xor(&a);
        assert_eq!(result.known_one, 0);
        assert_eq!(result.known_zero, u64::MAX);
    }

    #[test]
    fn bit_xor_with_zero_is_identity() {
        let a = BitFact::from_const(0xFF);
        let zero = BitFact::from_const(0);
        let result = a.bit_xor(&zero);
        assert_eq!(result, a);
    }

    // --- Shift transfer functions ----------------------------------------------------

    #[test]
    fn left_shift_known_bits() {
        let a = BitFact::from_const(0x0F);
        let shift = IntervalFact::exact(4);
        let result = a.left_shift(&shift);
        assert_eq!(result.known_one, 0xF0);
        assert_ne!(result.known_zero & 0x0F, 0);
    }

    #[test]
    fn left_shift_range_is_top() {
        let a = BitFact::from_const(0x0F);
        let shift = IntervalFact {
            lo: Some(1),
            hi: Some(3),
        };
        let result = a.left_shift(&shift);
        assert!(result.is_top());
    }

    #[test]
    fn left_shift_invalid_is_top() {
        let a = BitFact::from_const(0x0F);
        let shift = IntervalFact::exact(64);
        assert!(a.left_shift(&shift).is_top());
        let neg_shift = IntervalFact::exact(-1);
        assert!(a.left_shift(&neg_shift).is_top());
    }

    #[test]
    fn right_shift_known_bits_non_negative() {
        let a = BitFact::from_const(0xF0);
        let shift = IntervalFact::exact(4);
        let result = a.right_shift(&shift);
        assert_eq!(result.known_one, 0x0F);
        assert_ne!(result.known_zero & (0xFu64 << 60), 0);
    }

    #[test]
    fn right_shift_negative_fills_ones() {
        let a = BitFact::from_const(-16);
        let shift = IntervalFact::exact(4);
        let result = a.right_shift(&shift);
        assert_eq!(result.known_one, u64::MAX);
        assert_eq!(result.known_zero, 0);
    }

    #[test]
    fn right_shift_unknown_sign() {
        let a = BitFact {
            known_zero: 0x0F,
            known_one: 0,
        };
        let shift = IntervalFact::exact(4);
        let result = a.right_shift(&shift);
        let high_mask = 0xFu64 << 60;
        assert_eq!(result.known_zero & high_mask, 0);
        assert_eq!(result.known_one & high_mask, 0);
    }

    // --- upper_bound_hint ------------------------------------------------------------

    #[test]
    fn upper_bound_hint_constant() {
        let f = BitFact::from_const(7);
        assert_eq!(f.upper_bound_hint(), Some(7));
    }

    #[test]
    fn upper_bound_hint_masked() {
        let unknown = BitFact::top();
        let mask = BitFact::from_const(0x07);
        let result = unknown.bit_and(&mask);
        assert_eq!(result.upper_bound_hint(), Some(7));
    }

    #[test]
    fn upper_bound_hint_negative_is_none() {
        let f = BitFact::from_const(-1);
        assert_eq!(f.upper_bound_hint(), None);
    }

    #[test]
    fn upper_bound_hint_top_is_none() {
        assert_eq!(BitFact::top().upper_bound_hint(), None);
    }

    // --- is_non_negative -------------------------------------------------------------

    #[test]
    fn is_non_negative_positive() {
        assert!(BitFact::from_const(42).is_non_negative());
        assert!(BitFact::from_const(0).is_non_negative());
    }

    #[test]
    fn is_non_negative_negative() {
        assert!(!BitFact::from_const(-1).is_non_negative());
        assert!(!BitFact::from_const(i64::MIN).is_non_negative());
    }

    #[test]
    fn is_non_negative_unknown() {
        assert!(!BitFact::top().is_non_negative());
    }

    // --- Algebraic laws checked exhaustively over a small sample ---------------------

    /// Representative facts: both extremes, small constants, all-ones, and the sign-bit
    /// boundary values.
    fn sample_bits() -> Vec<BitFact> {
        vec![
            BitFact::bottom(),
            BitFact::top(),
            BitFact::from_const(0),
            BitFact::from_const(1),
            BitFact::from_const(-1),
            BitFact::from_const(0xFF),
            BitFact::from_const(i64::MIN),
            BitFact::from_const(i64::MAX),
        ]
    }

    #[test]
    fn join_associative_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                for c in &xs {
                    let lhs = a.join(b).join(c);
                    let rhs = a.join(&b.join(c));
                    assert_eq!(
                        lhs, rhs,
                        "join not associative for {:?}, {:?}, {:?}",
                        a, b, c
                    );
                }
            }
        }
    }

    #[test]
    fn meet_idempotent_bit() {
        for a in sample_bits() {
            assert_eq!(a.meet(&a), a, "meet not idempotent for {:?}", a);
        }
    }

    #[test]
    fn meet_associative_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                for c in &xs {
                    let lhs = a.meet(b).meet(c);
                    let rhs = a.meet(&b.meet(c));
                    assert_eq!(
                        lhs, rhs,
                        "meet not associative for {:?}, {:?}, {:?}",
                        a, b, c
                    );
                }
            }
        }
    }

    #[test]
    fn meet_top_identity_bit() {
        for a in sample_bits() {
            assert_eq!(a.meet(&BitFact::top()), a, "x ⊓ ⊤ failed for {:?}", a);
        }
    }

    #[test]
    fn meet_bottom_absorbing_bit() {
        for a in sample_bits() {
            assert_eq!(
                a.meet(&BitFact::bottom()),
                BitFact::bottom(),
                "x ⊓ ⊥ failed for {:?}",
                a
            );
        }
    }

    #[test]
    fn join_top_absorbing_bit() {
        for a in sample_bits() {
            assert_eq!(
                a.join(&BitFact::top()),
                BitFact::top(),
                "x ⊔ ⊤ failed for {:?}",
                a
            );
        }
    }

    #[test]
    fn widen_idempotent_bit() {
        for a in sample_bits() {
            assert_eq!(a.widen(&a), a, "widen(x, x) failed for {:?}", a);
        }
    }

    #[test]
    fn widen_over_approximates_join_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                let j = a.join(b);
                let w = a.widen(b);
                assert!(
                    j.leq(&w),
                    "widen({:?}, {:?}) = {:?} does not over-approx join = {:?}",
                    a,
                    b,
                    w,
                    j
                );
            }
        }
    }

    #[test]
    fn meet_is_lower_bound_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                let m = a.meet(b);
                assert!(m.leq(a), "a ⊓ b ⊑ a failed for {:?}, {:?}", a, b);
                assert!(m.leq(b), "a ⊓ b ⊑ b failed for {:?}, {:?}", a, b);
            }
        }
    }

    #[test]
    fn join_is_upper_bound_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                let j = a.join(b);
                assert!(a.leq(&j), "a ⊑ a ⊔ b failed for {:?}, {:?}", a, b);
                assert!(b.leq(&j), "b ⊑ a ⊔ b failed for {:?}, {:?}", a, b);
            }
        }
    }

    /// The order must agree with both lattice operations:
    /// `a ⊑ b  ⇔  a ⊔ b = b  ⇔  a ⊓ b = a`.
    #[test]
    fn leq_consistent_with_join_and_meet_bit() {
        let xs = sample_bits();
        for a in &xs {
            for b in &xs {
                let below = a.leq(b);
                assert_eq!(
                    below,
                    a.join(b) == *b,
                    "a ⊑ b ⇔ a ⊔ b = b failed for {:?}, {:?}",
                    a,
                    b
                );
                assert_eq!(
                    below,
                    a.meet(b) == *a,
                    "a ⊑ b ⇔ a ⊓ b = a failed for {:?}, {:?}",
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn join_min_max_signbit_safe() {
        // i64::MIN and i64::MAX disagree on every bit, including the sign bit, so
        // joining them loses all knowledge and meeting them is contradictory.
        let a = BitFact::from_const(i64::MIN);
        let b = BitFact::from_const(i64::MAX);
        assert!(a.join(&b).is_top());
        assert!(a.meet(&b).is_bottom());
        assert!(a.widen(&b).is_top());
    }
}