use crate::analysis::{
ssa::{CmpKind, SsaVarId},
x86::types::X86Condition,
};
/// Classifies the operation stored in a [`FlagProducer::Arithmetic`] entry.
///
/// `FlagState::get_condition_operands` copies this value into
/// [`ConditionEval::OverflowFlag`] so the consumer knows which operation the
/// overflow question refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArithmeticKind {
/// Addition (by name; within this file it is only constructed by the tests).
Add,
/// Subtraction. Also reported as the overflow kind when the flags came from
/// a `Compare` producer.
Sub,
/// Reported as the overflow kind when the flags came from a `Test` producer.
LogicalOp,
/// The kind recorded by `FlagState::set_arithmetic_unary`.
Neg,
/// Not constructed anywhere in this file.
/// NOTE(review): presumably a catch-all for operations without a dedicated
/// kind — confirm against the callers.
Other,
}
/// Identifies the value that a [`ConditionEval::SignFlag`] or
/// [`ConditionEval::ParityFlag`] refers to.
///
/// The variant is chosen by `FlagState::get_condition_operands` according to
/// the active flag producer.
#[derive(Debug, Clone)]
pub enum FlagTestSource {
/// An existing SSA variable; used with the `result` of an `Arithmetic` producer.
Direct(SsaVarId),
/// The operand pair of a `Compare` producer.
/// NOTE(review): presumably the consumer has to examine `left - right`,
/// for which no SSA variable exists yet — confirm.
Subtract {
left: SsaVarId,
right: SsaVarId,
},
/// The operand pair of a `Test` producer.
/// NOTE(review): presumably the consumer has to examine `left & right` — confirm.
BitwiseAnd {
left: SsaVarId,
right: SsaVarId,
},
}
/// The operation that was last recorded in a `FlagState`, together with the
/// SSA variables it involved.
#[derive(Debug, Clone)]
pub enum FlagProducer {
/// Recorded by `FlagState::set_compare`; holds the two compared operands.
Compare {
left: SsaVarId,
right: SsaVarId,
},
/// Recorded by `FlagState::set_test`; holds the two tested operands.
Test {
left: SsaVarId,
right: SsaVarId,
},
/// Recorded by `FlagState::set_arithmetic` / `set_arithmetic_unary`.
///
/// `result` is the variable holding the operation's result, `left` and
/// `right` its inputs (all three are the same variable for the unary form),
/// and `kind` tells which operation it was.
Arithmetic {
result: SsaVarId,
left: SsaVarId,
right: SsaVarId,
kind: ArithmeticKind,
},
}
/// Tracks which operation last defined the flags and, separately, an SSA
/// variable standing for the carry flag.
///
/// The derived `Default` leaves both fields `None`.
#[derive(Debug, Clone, Default)]
pub struct FlagState {
// Last recorded flag-setting operation; `None` means the flags are unknown.
producer: Option<FlagProducer>,
// SSA variable currently representing the carry flag, if one was recorded.
carry: Option<SsaVarId>,
}
impl FlagState {
#[must_use]
pub fn new() -> Self {
Self {
producer: None,
carry: None,
}
}
pub fn set_compare(&mut self, left: SsaVarId, right: SsaVarId) {
self.producer = Some(FlagProducer::Compare { left, right });
}
pub fn set_test(&mut self, left: SsaVarId, right: SsaVarId) {
self.producer = Some(FlagProducer::Test { left, right });
}
pub fn set_arithmetic(
&mut self,
result: SsaVarId,
left: SsaVarId,
right: SsaVarId,
kind: ArithmeticKind,
) {
self.producer = Some(FlagProducer::Arithmetic {
result,
left,
right,
kind,
});
}
pub fn set_arithmetic_unary(&mut self, result: SsaVarId) {
self.producer = Some(FlagProducer::Arithmetic {
result,
left: result,
right: result,
kind: ArithmeticKind::Neg,
});
}
pub fn clear(&mut self) {
self.producer = None;
self.carry = None;
}
pub fn clear_producer(&mut self) {
self.producer = None;
}
pub fn set_carry(&mut self, carry: SsaVarId) {
self.carry = Some(carry);
}
pub fn clear_carry(&mut self, zero: SsaVarId) {
self.carry = Some(zero);
}
#[must_use]
pub fn carry(&self) -> Option<SsaVarId> {
self.carry
}
#[must_use]
pub fn producer(&self) -> Option<&FlagProducer> {
self.producer.as_ref()
}
#[must_use]
pub fn is_known(&self) -> bool {
self.producer.is_some()
}
#[must_use]
pub fn get_branch_operands(
&self,
condition: X86Condition,
) -> Option<(CmpKind, SsaVarId, SsaVarId, bool)> {
match &self.producer {
Some(FlagProducer::Compare { left, right }) => {
let (cmp, unsigned) = condition_to_cmp(condition)?;
Some((cmp, *left, *right, unsigned))
}
Some(FlagProducer::Test { left, right }) => {
match condition {
X86Condition::E => {
Some((CmpKind::Eq, *left, *right, false))
}
X86Condition::Ne => {
Some((CmpKind::Ne, *left, *right, false))
}
X86Condition::S | X86Condition::Ns => None,
_ => None,
}
}
Some(FlagProducer::Arithmetic {
result,
left,
right,
..
}) => {
match condition {
X86Condition::E => Some((CmpKind::Eq, *result, *result, false)),
X86Condition::Ne => Some((CmpKind::Ne, *result, *result, false)),
X86Condition::L | X86Condition::Ge | X86Condition::Le | X86Condition::G => {
let (cmp, unsigned) = condition_to_cmp(condition)?;
Some((cmp, *left, *right, unsigned))
}
X86Condition::B | X86Condition::Ae | X86Condition::Be | X86Condition::A => {
let (cmp, unsigned) = condition_to_cmp(condition)?;
Some((cmp, *left, *right, unsigned))
}
_ => None,
}
}
None => None,
}
}
#[must_use]
pub fn get_condition_operands(&self, condition: X86Condition) -> Option<ConditionEval> {
match &self.producer {
Some(FlagProducer::Compare { left, right }) => match condition {
X86Condition::S => Some(ConditionEval::SignFlag {
source: FlagTestSource::Subtract {
left: *left,
right: *right,
},
negated: false,
}),
X86Condition::Ns => Some(ConditionEval::SignFlag {
source: FlagTestSource::Subtract {
left: *left,
right: *right,
},
negated: true,
}),
X86Condition::O => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: None,
kind: ArithmeticKind::Sub,
negated: false,
}),
X86Condition::No => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: None,
kind: ArithmeticKind::Sub,
negated: true,
}),
X86Condition::P => Some(ConditionEval::ParityFlag {
source: FlagTestSource::Subtract {
left: *left,
right: *right,
},
negated: false,
}),
X86Condition::Np => Some(ConditionEval::ParityFlag {
source: FlagTestSource::Subtract {
left: *left,
right: *right,
},
negated: true,
}),
_ => {
let (cmp, unsigned) = condition_to_cmp(condition)?;
Some(ConditionEval::Compare {
cmp,
left: *left,
right: *right,
unsigned,
})
}
},
Some(FlagProducer::Test { left, right }) => match condition {
X86Condition::E => Some(ConditionEval::Test {
cmp: CmpKind::Eq,
left: *left,
right: *right,
}),
X86Condition::Ne => Some(ConditionEval::Test {
cmp: CmpKind::Ne,
left: *left,
right: *right,
}),
X86Condition::S => Some(ConditionEval::SignFlag {
source: FlagTestSource::BitwiseAnd {
left: *left,
right: *right,
},
negated: false,
}),
X86Condition::Ns => Some(ConditionEval::SignFlag {
source: FlagTestSource::BitwiseAnd {
left: *left,
right: *right,
},
negated: true,
}),
X86Condition::O => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: None,
kind: ArithmeticKind::LogicalOp,
negated: false,
}),
X86Condition::No => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: None,
kind: ArithmeticKind::LogicalOp,
negated: true,
}),
X86Condition::P => Some(ConditionEval::ParityFlag {
source: FlagTestSource::BitwiseAnd {
left: *left,
right: *right,
},
negated: false,
}),
X86Condition::Np => Some(ConditionEval::ParityFlag {
source: FlagTestSource::BitwiseAnd {
left: *left,
right: *right,
},
negated: true,
}),
_ => None,
},
Some(FlagProducer::Arithmetic {
result,
left,
right,
kind,
}) => match condition {
X86Condition::E => Some(ConditionEval::ZeroTest {
cmp: CmpKind::Eq,
result: *result,
}),
X86Condition::Ne => Some(ConditionEval::ZeroTest {
cmp: CmpKind::Ne,
result: *result,
}),
X86Condition::S => Some(ConditionEval::SignFlag {
source: FlagTestSource::Direct(*result),
negated: false,
}),
X86Condition::Ns => Some(ConditionEval::SignFlag {
source: FlagTestSource::Direct(*result),
negated: true,
}),
X86Condition::O => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: Some(*result),
kind: *kind,
negated: false,
}),
X86Condition::No => Some(ConditionEval::OverflowFlag {
left: *left,
right: *right,
result: Some(*result),
kind: *kind,
negated: true,
}),
X86Condition::P => Some(ConditionEval::ParityFlag {
source: FlagTestSource::Direct(*result),
negated: false,
}),
X86Condition::Np => Some(ConditionEval::ParityFlag {
source: FlagTestSource::Direct(*result),
negated: true,
}),
_ => {
let (cmp, unsigned) = condition_to_cmp(condition)?;
Some(ConditionEval::Compare {
cmp,
left: *left,
right: *right,
unsigned,
})
}
},
None => None,
}
}
#[must_use]
pub fn is_zero_test(&self) -> Option<SsaVarId> {
match &self.producer {
Some(FlagProducer::Test { left, right }) if left == right => Some(*left),
_ => None,
}
}
}
/// Describes how a condition code is to be evaluated against the recorded
/// flag producer.
///
/// Values of this type are built by `FlagState::get_condition_operands`.
#[derive(Debug, Clone)]
pub enum ConditionEval {
/// A comparison of `left` with `right`; `cmp` and `unsigned` are the pair
/// returned by `condition_to_cmp`.
Compare {
cmp: CmpKind,
left: SsaVarId,
right: SsaVarId,
unsigned: bool,
},
/// Produced for `E` / `Ne` after a `Test` producer; carries that producer's
/// two operands.
/// NOTE(review): presumably means "`left & right` compared with zero" —
/// confirm with the consumer.
Test {
cmp: CmpKind,
left: SsaVarId,
right: SsaVarId,
},
/// Produced for `E` / `Ne` after an `Arithmetic` producer; `result` is the
/// operation's result variable.
ZeroTest { cmp: CmpKind, result: SsaVarId },
/// Produced for `S` (`negated == false`) and `Ns` (`negated == true`).
SignFlag {
source: FlagTestSource,
negated: bool,
},
/// Produced for `O` (`negated == false`) and `No` (`negated == true`).
///
/// `result` is `Some` only when the producer was `Arithmetic`; `kind` is
/// `Sub` for a `Compare` producer, `LogicalOp` for a `Test` producer, and
/// the stored kind for an `Arithmetic` producer.
OverflowFlag {
left: SsaVarId,
right: SsaVarId,
result: Option<SsaVarId>,
kind: ArithmeticKind,
negated: bool,
},
/// Produced for `P` (`negated == false`) and `Np` (`negated == true`).
ParityFlag {
source: FlagTestSource,
negated: bool,
},
}
/// Translates a condition code into a comparison kind plus an `unsigned` marker.
///
/// Returns `Some((kind, unsigned))` for the equality conditions (`E`, `Ne`),
/// the signed relational conditions (`L`, `Ge`, `Le`, `G`; `unsigned == false`)
/// and the unsigned relational conditions (`B`, `Ae`, `Be`, `A`;
/// `unsigned == true`). Equality is reported with `unsigned == false`.
///
/// Returns `None` for the single-flag conditions `S`, `Ns`, `O`, `No`, `P`
/// and `Np`.
#[must_use]
pub fn condition_to_cmp(condition: X86Condition) -> Option<(CmpKind, bool)> {
    // Signedness is decided first and independently of the comparison kind.
    let unsigned = matches!(
        condition,
        X86Condition::B | X86Condition::Ae | X86Condition::Be | X86Condition::A
    );

    // Kept exhaustive on purpose: a new condition variant must be classified here.
    let kind = match condition {
        X86Condition::E => CmpKind::Eq,
        X86Condition::Ne => CmpKind::Ne,
        X86Condition::L | X86Condition::B => CmpKind::Lt,
        X86Condition::Ge | X86Condition::Ae => CmpKind::Ge,
        X86Condition::Le | X86Condition::Be => CmpKind::Le,
        X86Condition::G | X86Condition::A => CmpKind::Gt,
        X86Condition::S
        | X86Condition::Ns
        | X86Condition::O
        | X86Condition::No
        | X86Condition::P
        | X86Condition::Np => return None,
    };

    Some((kind, unsigned))
}
/// Returns `true` exactly for the equality conditions `E` and `Ne`.
#[must_use]
pub const fn is_zero_flag_only(condition: X86Condition) -> bool {
    match condition {
        X86Condition::E | X86Condition::Ne => true,
        _ => false,
    }
}
/// Returns `true` exactly for the signed relational conditions `L`, `Ge`,
/// `Le` and `G`.
#[must_use]
pub const fn is_signed_condition(condition: X86Condition) -> bool {
    match condition {
        X86Condition::L | X86Condition::Ge | X86Condition::Le | X86Condition::G => true,
        _ => false,
    }
}
/// Returns `true` exactly for the unsigned relational conditions `B`, `Ae`,
/// `Be` and `A`.
#[must_use]
pub const fn is_unsigned_condition(condition: X86Condition) -> bool {
    match condition {
        X86Condition::B | X86Condition::Ae | X86Condition::Be | X86Condition::A => true,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates two SSA variables, in order, for use as an operand pair.
    fn operand_pair() -> (SsaVarId, SsaVarId) {
        let first = SsaVarId::new();
        let second = SsaVarId::new();
        (first, second)
    }

    /// Allocates a result variable followed by an operand pair.
    fn arithmetic_vars() -> (SsaVarId, SsaVarId, SsaVarId) {
        let result = SsaVarId::new();
        let (left, right) = operand_pair();
        (result, left, right)
    }

    /// A compare producer yields signed and unsigned branch comparisons.
    #[test]
    fn test_flag_state_compare() {
        let mut state = FlagState::new();
        assert!(!state.is_known());

        let (lhs, rhs) = operand_pair();
        state.set_compare(lhs, rhs);
        assert!(state.is_known());

        assert_eq!(
            state.get_branch_operands(X86Condition::L),
            Some((CmpKind::Lt, lhs, rhs, false))
        );
        assert!(matches!(
            state.get_branch_operands(X86Condition::B),
            Some((CmpKind::Lt, _, _, true))
        ));
    }

    /// A test producer only supports the equality conditions for branches.
    #[test]
    fn test_flag_state_test() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_test(lhs, rhs);

        for equality in [X86Condition::E, X86Condition::Ne] {
            assert!(state.get_branch_operands(equality).is_some());
        }
        assert!(state.get_branch_operands(X86Condition::L).is_none());
    }

    /// `is_zero_test` fires only when both test operands are the same variable.
    #[test]
    fn test_zero_test_pattern() {
        let mut state = FlagState::new();

        let var = SsaVarId::new();
        state.set_test(var, var);
        assert_eq!(state.is_zero_test(), Some(var));

        let other = SsaVarId::new();
        state.set_test(var, other);
        assert_eq!(state.is_zero_test(), None);
    }

    /// Spot-checks the condition-to-comparison table.
    #[test]
    fn test_condition_to_cmp() {
        let translatable = [
            (X86Condition::E, CmpKind::Eq, false),
            (X86Condition::Ne, CmpKind::Ne, false),
            (X86Condition::L, CmpKind::Lt, false),
            (X86Condition::G, CmpKind::Gt, false),
            (X86Condition::B, CmpKind::Lt, true),
            (X86Condition::A, CmpKind::Gt, true),
        ];
        for (condition, cmp, unsigned) in translatable {
            assert_eq!(condition_to_cmp(condition), Some((cmp, unsigned)));
        }

        for single_flag in [X86Condition::P, X86Condition::O] {
            assert!(condition_to_cmp(single_flag).is_none());
        }
    }

    /// `clear` drops the recorded producer.
    #[test]
    fn test_flag_state_clear() {
        let mut state = FlagState::new();
        state.set_compare(SsaVarId::new(), SsaVarId::new());
        assert!(state.is_known());

        state.clear();
        assert!(!state.is_known());
    }

    /// The three classification helpers partition the conditions as expected.
    #[test]
    fn test_condition_classification() {
        // Equality conditions.
        for equality in [X86Condition::E, X86Condition::Ne] {
            assert!(is_zero_flag_only(equality));
        }
        assert!(!is_zero_flag_only(X86Condition::L));

        // Signed relational conditions.
        for signed in [X86Condition::L, X86Condition::G] {
            assert!(is_signed_condition(signed));
        }
        assert!(!is_signed_condition(X86Condition::B));

        // Unsigned relational conditions.
        for unsigned in [X86Condition::B, X86Condition::A] {
            assert!(is_unsigned_condition(unsigned));
        }
        assert!(!is_unsigned_condition(X86Condition::L));
    }

    /// Carry tracking: set, replace through `clear_carry`, then reset by `clear`.
    #[test]
    fn test_carry_flag_tracking() {
        let mut state = FlagState::new();
        assert_eq!(state.carry(), None);

        let carry_var = SsaVarId::new();
        state.set_carry(carry_var);
        assert_eq!(state.carry(), Some(carry_var));

        let zero_var = SsaVarId::new();
        state.clear_carry(zero_var);
        assert_eq!(state.carry(), Some(zero_var));

        state.clear();
        assert_eq!(state.carry(), None);
    }

    /// An arithmetic producer supports equality and signed branch conditions.
    #[test]
    fn test_arithmetic_with_signed_conditions() {
        let mut state = FlagState::new();
        let (sum, lhs, rhs) = arithmetic_vars();
        state.set_arithmetic(sum, lhs, rhs, ArithmeticKind::Add);

        assert!(state.get_branch_operands(X86Condition::E).is_some());
        assert_eq!(
            state.get_branch_operands(X86Condition::L),
            Some((CmpKind::Lt, lhs, rhs, false))
        );
    }

    /// Relational and equality conditions map to `Compare` / `Test` evaluations.
    #[test]
    fn test_condition_eval() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();

        state.set_compare(lhs, rhs);
        assert!(matches!(
            state.get_condition_operands(X86Condition::L),
            Some(ConditionEval::Compare { .. })
        ));

        state.set_test(lhs, rhs);
        assert!(matches!(
            state.get_condition_operands(X86Condition::E),
            Some(ConditionEval::Test { .. })
        ));
    }

    /// Sign conditions after a compare refer to the subtraction of the operands.
    #[test]
    fn test_sign_flag_from_compare() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_compare(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::S),
            Some(ConditionEval::SignFlag {
                source: FlagTestSource::Subtract { .. },
                negated: false,
            })
        ));
        assert!(matches!(
            state.get_condition_operands(X86Condition::Ns),
            Some(ConditionEval::SignFlag { negated: true, .. })
        ));
    }

    /// Sign conditions after a test refer to the bitwise AND of the operands.
    #[test]
    fn test_sign_flag_from_test() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_test(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::S),
            Some(ConditionEval::SignFlag {
                source: FlagTestSource::BitwiseAnd { .. },
                negated: false,
            })
        ));
    }

    /// Sign conditions after arithmetic refer directly to the result variable.
    #[test]
    fn test_sign_flag_from_arithmetic() {
        let mut state = FlagState::new();
        let (sum, lhs, rhs) = arithmetic_vars();
        state.set_arithmetic(sum, lhs, rhs, ArithmeticKind::Add);

        assert!(matches!(
            state.get_condition_operands(X86Condition::S),
            Some(ConditionEval::SignFlag {
                source: FlagTestSource::Direct(_),
                negated: false,
            })
        ));
    }

    /// Overflow after a compare is reported as a `Sub` without a result variable.
    #[test]
    fn test_overflow_flag_from_compare() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_compare(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::O),
            Some(ConditionEval::OverflowFlag {
                result: None,
                kind: ArithmeticKind::Sub,
                negated: false,
                ..
            })
        ));
        assert!(matches!(
            state.get_condition_operands(X86Condition::No),
            Some(ConditionEval::OverflowFlag { negated: true, .. })
        ));
    }

    /// Overflow after a test is reported with the `LogicalOp` kind.
    #[test]
    fn test_overflow_flag_from_test() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_test(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::O),
            Some(ConditionEval::OverflowFlag {
                kind: ArithmeticKind::LogicalOp,
                negated: false,
                ..
            })
        ));
    }

    /// Overflow after arithmetic carries the result variable and the stored kind.
    #[test]
    fn test_overflow_flag_from_arithmetic() {
        let mut state = FlagState::new();
        let (sum, lhs, rhs) = arithmetic_vars();
        state.set_arithmetic(sum, lhs, rhs, ArithmeticKind::Add);

        assert!(matches!(
            state.get_condition_operands(X86Condition::O),
            Some(ConditionEval::OverflowFlag {
                result: Some(_),
                kind: ArithmeticKind::Add,
                negated: false,
                ..
            })
        ));
    }

    /// Parity conditions after a compare refer to the subtraction of the operands.
    #[test]
    fn test_parity_flag_from_compare() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_compare(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::P),
            Some(ConditionEval::ParityFlag {
                source: FlagTestSource::Subtract { .. },
                negated: false,
            })
        ));
        assert!(matches!(
            state.get_condition_operands(X86Condition::Np),
            Some(ConditionEval::ParityFlag { negated: true, .. })
        ));
    }

    /// Parity conditions after a test refer to the bitwise AND of the operands.
    #[test]
    fn test_parity_flag_from_test() {
        let mut state = FlagState::new();
        let (lhs, rhs) = operand_pair();
        state.set_test(lhs, rhs);

        assert!(matches!(
            state.get_condition_operands(X86Condition::P),
            Some(ConditionEval::ParityFlag {
                source: FlagTestSource::BitwiseAnd { .. },
                negated: false,
            })
        ));
    }

    /// `get_branch_operands` keeps rejecting the single-flag conditions.
    #[test]
    fn test_branch_operands_still_none_for_flag_conditions() {
        let mut state = FlagState::new();
        state.set_compare(SsaVarId::new(), SsaVarId::new());

        let single_flag_conditions = [
            X86Condition::S,
            X86Condition::Ns,
            X86Condition::O,
            X86Condition::No,
            X86Condition::P,
            X86Condition::Np,
        ];
        for condition in single_flag_conditions {
            assert!(state.get_branch_operands(condition).is_none());
        }

        for condition in [X86Condition::E, X86Condition::L] {
            assert!(state.get_branch_operands(condition).is_some());
        }
    }

    /// The unary setter records the `Neg` kind.
    #[test]
    fn test_arithmetic_kind_neg() {
        let mut state = FlagState::new();
        let negated_value = SsaVarId::new();
        state.set_arithmetic_unary(negated_value);

        assert!(matches!(
            state.get_condition_operands(X86Condition::O),
            Some(ConditionEval::OverflowFlag {
                kind: ArithmeticKind::Neg,
                ..
            })
        ));
    }
}