use std::{
borrow::Borrow,
cmp::Ordering,
collections::HashMap,
fmt::{Debug, Display, Formatter, LowerHex},
hash::{Hash, Hasher},
};
use jingle_sleigh::PcodeOperation;
use crate::{
analysis::{
cpa::{
IntoState,
lattice::JoinSemiLattice,
state::{AbstractState, LocationState, MergeOutcome, Successor},
},
location::{basic::state::BasicLocationState, unwind::UnwindingAnalysis},
},
modeling::machine::cpu::concrete::ConcretePcodeAddress,
register_strengthen,
};
/// A `(source, target)` address pair for an edge whose target was already on the
/// current path when the edge was taken (see `UnwindingState::move_to`).
type BackEdge = (ConcretePcodeAddress, ConcretePcodeAddress);
/// Abstract state that bounds loop unwinding by counting how many times each
/// back edge has been taken.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnwindingState {
    /// Address this state is currently at.
    location: ConcretePcodeAddress,
    /// Stack of addresses visited on the current path, seeded with the initial
    /// location. Moving to an address that is already on this stack counts as a
    /// back edge.
    /// NOTE(review): despite the name, this is maintained as a path stack rather
    /// than a computed dominator set — confirm the intended semantics.
    dominators: Vec<ConcretePcodeAddress>,
    /// How many times each back edge `(from, to)` has been traversed.
    back_edge_counts: HashMap<BackEdge, usize>,
    /// Once any back-edge count reaches this value the state is terminated and
    /// `transfer` produces no successors.
    max_count: usize,
}
impl Hash for UnwindingState {
    /// Hashes the current location together with the back-edge counters.
    ///
    /// `HashMap` iteration order is unspecified, so the counters are sorted by
    /// edge before hashing to make the result independent of insertion order.
    /// Only fields compared by the derived `PartialEq` are hashed, so `a == b`
    /// still implies `hash(a) == hash(b)`. Including `location` keeps states at
    /// different addresses with identical (often empty) counters from colliding.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.location.hash(state);
        let mut edges: Vec<_> = self.back_edge_counts.iter().collect();
        // Map keys are unique, so ordering by the edge alone is a total order;
        // no tie-breaking on the count is needed.
        edges.sort_by_key(|(edge, _)| *edge);
        edges.hash(state);
    }
}
impl Display for UnwindingState {
    /// Renders the back-edge counters as `(src -> dst):count` entries.
    ///
    /// Entries are ordered by edge and separated by `", "`. The current location
    /// is not part of this rendering.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut sorted: Vec<_> = self.back_edge_counts.iter().collect();
        sorted.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        let mut separator = "";
        for ((src, dst), hits) in sorted {
            write!(f, "{}", separator)?;
            write!(f, "({:x} -> {:x}):{}", src, dst, hits)?;
            separator = ", ";
        }
        Ok(())
    }
}
impl LowerHex for UnwindingState {
    /// Verbose rendering with `0x`-prefixed addresses:
    /// `BackEdgeCount(loc: <loc>, edges: {(<src> -> <dst>): <n>, ...})`,
    /// with edges listed in sorted order.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut ordered: Vec<_> = self.back_edge_counts.iter().collect();
        ordered.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        write!(f, "BackEdgeCount(loc: ")?;
        write!(f, "{:#x}", self.location)?;
        write!(f, ", edges: {{")?;
        let mut first = true;
        for ((src, dst), hits) in ordered {
            if !first {
                write!(f, ", ")?;
            }
            first = false;
            write!(f, "({:#x} -> {:#x}): {}", src, dst, hits)?;
        }
        write!(f, "}})")
    }
}
impl UnwindingState {
    /// Creates a state at `location` with no back edges recorded.
    ///
    /// The path stack starts out containing only `location`, establishing the
    /// invariant that its last entry is the current location.
    fn with_location(location: ConcretePcodeAddress, max_count: usize) -> Self {
        Self {
            location,
            dominators: vec![location],
            back_edge_counts: HashMap::new(),
            max_count,
        }
    }

    /// Returns `true` once any back edge has been taken at least `max_count` times.
    fn terminated(&self) -> bool {
        self.back_edge_counts
            .values()
            .any(|&count| count >= self.max_count)
    }

    /// Moves this state to the location reported by `other`.
    ///
    /// If the new location is already on the path stack, the move closes a cycle:
    /// the edge `(current, new)` is counted as a back edge, and the stack is popped
    /// back to the revisited location, which is kept on the stack. Otherwise the new
    /// location is pushed. If `other` reports no location, nothing changes.
    fn move_to<L: LocationState>(&mut self, other: &L) {
        if let Some(new_location) = other.get_location() {
            if let Some(idx) = self.dominators.iter().position(|p| p == &new_location) {
                let edge = (self.location, new_location);
                *self.back_edge_counts.entry(edge).or_insert(0) += 1;
                // Keep the revisited address itself (`idx + 1`) so the last stack
                // entry is still the current location. Dropping it (the previous
                // `truncate(idx)`) meant the next lap was not counted, and a later
                // forward edge out of the loop header could be counted as a back edge.
                self.dominators.truncate(idx + 1);
            } else {
                self.dominators.push(new_location);
            }
            self.location = new_location;
        }
    }
}
impl PartialOrd for UnwindingState {
    /// Two states are comparable only when they are at the same location with
    /// identical back-edge counters, in which case they are `Equal`. Every other
    /// pair is incomparable (`None`).
    ///
    /// NOTE(review): `dominators` and `max_count` are ignored here, while the
    /// derived `PartialEq` compares them. So `partial_cmp` can return
    /// `Some(Equal)` for states where `==` is false, which the `PartialOrd`
    /// contract forbids. Confirm whether `stop` relies on this looser notion of
    /// coverage before aligning the two.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let same = self.location == other.location
            && self.back_edge_counts == other.back_edge_counts;
        match same {
            true => Some(Ordering::Equal),
            false => None,
        }
    }
}
impl JoinSemiLattice for UnwindingState {
    /// Joins `other` into `self`.
    ///
    /// - `max_count` becomes the larger of the two limits.
    /// - The path stack is reduced to the longest common prefix of both stacks.
    ///   `self.location` is then re-appended if it is not already the last entry.
    /// - Each back-edge counter becomes the pointwise maximum of the two maps.
    ///
    /// `self.location` itself is left unchanged.
    fn join(&mut self, other: &Self) {
        self.max_count = self.max_count.max(other.max_count);

        // Length of the longest common prefix. `zip` stops at the shorter stack, so
        // when no mismatch is found the common prefix is the entire shorter stack.
        // `self` must still be cut to that length when it is the longer of the two.
        let common_len = self
            .dominators
            .iter()
            .zip(&other.dominators)
            .position(|(a, b)| a != b)
            .unwrap_or_else(|| self.dominators.len().min(other.dominators.len()));
        self.dominators.truncate(common_len);

        // Restore the invariant that the stack ends with the current location.
        if self.dominators.last() != Some(&self.location) {
            self.dominators.push(self.location);
        }

        for (&edge, &other_count) in &other.back_edge_counts {
            let count = self.back_edge_counts.entry(edge).or_insert(0);
            *count = (*count).max(other_count);
        }
    }
}
impl AbstractState for UnwindingState {
    /// Delegates to `merge_sep`.
    fn merge(&mut self, other: &Self) -> MergeOutcome {
        self.merge_sep(other)
    }

    /// Delegates to `stop_sep`.
    fn stop<'a, T: Iterator<Item = &'a Self>>(&'a self, states: T) -> bool {
        self.stop_sep(states)
    }

    /// Yields no successors once the unwinding bound is hit. Otherwise yields an
    /// unchanged copy of this state. The operation itself is ignored; location
    /// updates happen in `move_to`.
    fn transfer<'a, B: Borrow<PcodeOperation>>(&'a self, _opcode: B) -> Successor<'a, Self> {
        if self.terminated() {
            std::iter::empty().into()
        } else {
            std::iter::once(self.clone()).into()
        }
    }
}
impl IntoState<UnwindingAnalysis> for ConcretePcodeAddress {
    /// Builds the initial unwinding state at this address, using the analysis's
    /// `max_count` as the back-edge limit.
    fn into_state(self, analysis: &UnwindingAnalysis) -> UnwindingState {
        let limit = analysis.max_count;
        UnwindingState::with_location(self, limit)
    }
}
// Registers `UnwindingState::move_to` as the strengthening hook from
// `BasicLocationState` into `UnwindingState`.
// NOTE(review): presumably this hook runs whenever a paired `BasicLocationState`
// component reports a location — confirm against the `register_strengthen!` macro.
register_strengthen!(UnwindingState, BasicLocationState, UnwindingState::move_to);