use rayon::prelude::*;
use rustc_hash::FxHashSet as HashSet;
use crate::{get_reg_family, semantics, Gadget};
pub fn filter_stack_pivot<'a, P>(gadgets: P) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
let regs_overwritten = g.analysis().regs_overwritten(true);
if regs_overwritten.contains(&iced_x86::Register::RSP)
|| regs_overwritten.contains(&iced_x86::Register::ESP)
|| regs_overwritten.contains(&iced_x86::Register::SP)
{
return true;
}
false
})
.collect()
}
/// Parallel filter keeping only JOP gadgets (per `semantics::is_jop_gadget_tail`) where the
/// tail instruction's first register operand (`op0_register`) is also read or written
/// (per `semantics::is_reg_rw`) by one of the instructions preceding the tail.
pub fn filter_dispatcher<'a, P>(gadgets: P) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| match g.instrs.split_last() {
            Some((tail, body)) if semantics::is_jop_gadget_tail(tail) => {
                // Sanity check (debug builds only): the analysis should agree on the type.
                debug_assert_eq!(g.analysis().ty(), crate::GadgetType::Jop);
                let dispatch_reg = tail.op0_register();
                body.iter()
                    .any(|instr| semantics::is_reg_rw(instr, &dispatch_reg))
            }
            _ => false,
        })
        .collect()
}
pub fn filter_reg_pop_only<'a, P>(gadgets: P) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
if let Some((tail_instr, mut preceding_instrs)) = g.instrs.split_last() {
if semantics::is_gadget_tail(tail_instr) && (!preceding_instrs.is_empty()) {
if let Some((second_to_last, remaining)) = preceding_instrs.split_last() {
if second_to_last.mnemonic() == iced_x86::Mnemonic::Leave {
preceding_instrs = remaining;
}
}
let pop_chain = preceding_instrs.iter().all(|instr| {
matches!(
instr.mnemonic(),
iced_x86::Mnemonic::Pop | iced_x86::Mnemonic::Popa
)
});
if pop_chain && !preceding_instrs.is_empty() {
return true;
}
}
}
false
})
.collect()
}
/// Parallel filter keeping only gadgets where:
/// - every entry reported by `used_mem()` has a displacement of zero, and
/// - every instruction either has no operands or satisfies `semantics::is_reg_ops_only`.
pub fn filter_reg_only<'a, P>(gadgets: P) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let zero_disp_only = !g
                .analysis()
                .used_mem()
                .any(|um| um.displacement() != 0x0);

            zero_disp_only
                && g.instrs
                    .iter()
                    .all(|instr| instr.op_count() == 0 || semantics::is_reg_ops_only(instr))
        })
        .collect()
}
pub fn filter_set_params<'a, P>(gadgets: P, param_regs: &[iced_x86::Register]) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
for instr in &g.instrs {
if matches!(
instr.mnemonic(),
iced_x86::Mnemonic::Pusha | iced_x86::Mnemonic::Pushad
) {
return true;
}
if instr.mnemonic() == iced_x86::Mnemonic::Push {
if let Ok(op_kind) = instr.try_op_kind(0) {
if op_kind == iced_x86::OpKind::Register {
return true;
}
}
}
for reg in param_regs {
if semantics::is_reg_set(instr, reg) {
return true;
}
}
}
false
})
.collect()
}
/// Parallel filter on `regs_overwritten(false)`:
/// - with `opt_regs = Some(regs)`, keeps gadgets that overwrite every register in `regs`;
/// - with `None`, keeps gadgets that overwrite at least one register.
pub fn filter_regs_overwritten<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let overwritten = g.analysis().regs_overwritten(false);
            opt_regs.map_or(!overwritten.is_empty(), |wanted| {
                wanted.iter().all(|r| overwritten.contains(r))
            })
        })
        .collect()
}
/// Parallel filter on the union of `regs_overwritten(true)` and `regs_updated()`:
/// - with `opt_regs = Some(regs)`, keeps gadgets that write every register in `regs`;
/// - with `None`, keeps gadgets that write at least one register.
pub fn filter_regs_written<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let analysis = g.analysis();
            let mut written: HashSet<iced_x86::Register> =
                analysis.regs_overwritten(true).into_iter().collect();
            written.extend(analysis.regs_updated());

            if let Some(regs) = opt_regs {
                regs.iter().all(|r| written.contains(r))
            } else {
                !written.is_empty()
            }
        })
        .collect()
}
/// Parallel filter on the union of `regs_overwritten(true)` and `regs_updated()`:
/// - with `opt_regs = Some(regs)`, keeps gadgets that write no register from the family
///   (`get_reg_family`) of any register in `regs`;
/// - with `None`, keeps gadgets that write no register at all.
pub fn filter_regs_not_written<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let analysis = g.analysis();
            let mut written: HashSet<iced_x86::Register> =
                analysis.regs_overwritten(true).into_iter().collect();
            written.extend(analysis.regs_updated());

            match opt_regs {
                None => written.is_empty(),
                // Expand each requested register to its whole family, so a write to any
                // alias of it also disqualifies the gadget.
                Some(regs) => !regs
                    .iter()
                    .flat_map(get_reg_family)
                    .any(|r| written.contains(&r)),
            }
        })
        .collect()
}
pub fn filter_regs_deref_write<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
let mut regs_derefed_write = g.analysis().regs_dereferenced_mem_write();
match opt_regs {
Some(regs) => regs.iter().all(|r| regs_derefed_write.contains(r)),
None => {
regs_derefed_write.retain(|r| r != &iced_x86::Register::RSP);
regs_derefed_write.retain(|r| r != &iced_x86::Register::ESP);
regs_derefed_write.retain(|r| r != &iced_x86::Register::SP);
!regs_derefed_write.is_empty()
}
}
})
.collect()
}
/// Parallel filter on `regs_read()`:
/// - with `opt_regs = Some(regs)`, keeps gadgets that read every register in `regs`;
/// - with `None`, keeps gadgets that read at least one register.
pub fn filter_regs_read<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let read = g.analysis().regs_read();
            if let Some(wanted) = opt_regs {
                wanted.iter().all(|r| read.contains(r))
            } else {
                !read.is_empty()
            }
        })
        .collect()
}
/// Parallel filter on `regs_read()`:
/// - with `opt_regs = Some(regs)`, keeps gadgets that read no register from the family
///   (`get_reg_family`) of any register in `regs`;
/// - with `None`, keeps gadgets that read no register at all.
pub fn filter_regs_not_read<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    gadgets
        .into_par_iter()
        .filter(|g| {
            let read = g.analysis().regs_read();
            match opt_regs {
                None => read.is_empty(),
                // Expand each requested register to its family so reading any alias of it
                // also disqualifies the gadget.
                Some(regs) => regs
                    .iter()
                    .flat_map(get_reg_family)
                    .all(|r| !read.contains(&r)),
            }
        })
        .collect()
}
pub fn filter_regs_deref_read<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
let mut regs_derefed_read = g.analysis().regs_dereferenced_mem_read();
match opt_regs {
Some(regs) => regs.iter().all(|r| regs_derefed_read.contains(r)),
None => {
regs_derefed_read.retain(|r| r != &iced_x86::Register::RSP);
regs_derefed_read.retain(|r| r != &iced_x86::Register::ESP);
regs_derefed_read.retain(|r| r != &iced_x86::Register::SP);
!regs_derefed_read.is_empty()
}
}
})
.collect()
}
pub fn filter_reg_no_deref<'a, P>(gadgets: P, opt_regs: Option<&[iced_x86::Register]>) -> P
where
P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
gadgets
.into_par_iter()
.filter(|g| {
let mut regs_derefed = g.analysis().regs_dereferenced();
match opt_regs {
Some(regs) => regs.iter().all(|r| !regs_derefed.contains(r)),
None => {
regs_derefed.retain(|r| r != &iced_x86::Register::RSP);
regs_derefed.retain(|r| r != &iced_x86::Register::ESP);
regs_derefed.retain(|r| r != &iced_x86::Register::SP);
regs_derefed.is_empty()
}
}
})
.collect()
}
/// Parallel filter that drops every match address whose little-endian byte encoding
/// contains any byte from `bad_bytes`. This covers both `full_matches` and
/// `partial_matches`.
///
/// Gadgets left with neither full nor partial matches are removed entirely.
pub fn filter_bad_addr_bytes<'a, P>(gadgets: P, bad_bytes: &[u8]) -> P
where
    P: IntoParallelIterator<Item = Gadget<'a>> + FromParallelIterator<Gadget<'a>>,
{
    // True when none of `bytes` appears in the bad-byte list.
    let is_clean = |bytes: &[u8]| !bytes.iter().any(|b| bad_bytes.contains(b));

    gadgets
        .into_par_iter()
        .map(|mut g| {
            g.full_matches
                .retain(|addr| is_clean(&addr.to_le_bytes()[..]));
            g.partial_matches = g
                .partial_matches
                .iter()
                .filter(|(addr, _)| is_clean(&addr.to_le_bytes()[..]))
                .map(|(addr, bins)| (*addr, bins.clone()))
                .collect();
            g
        })
        .filter(|g| !(g.full_matches.is_empty() && g.partial_matches.is_empty()))
        .collect()
}