use rustc_hash::FxHashSet;
use crate::{
analysis::{
x86_decode_all, x86_detect_prologue, X86Function, X86PrologueKind, X86ToSsaTranslator,
},
cilassembly::{CilAssembly, MethodBodyBuilder},
compiler::SsaCodeGenerator,
file::File,
metadata::{
method::MethodImplCodeType,
tables::{MethodDefRaw, TableDataOwned, TableId},
token::Token,
},
Error, Result,
};
/// Outcome summary of one `NativeMethodConversionPass::run` invocation.
///
/// `run` starts from `ConversionStats::default()` (all counters zero, both
/// vectors empty) and updates it once per registered target.
#[derive(Debug, Clone, Default)]
pub struct ConversionStats {
/// Number of targets whose conversion returned `Ok`.
pub converted: usize,
/// Number of targets whose conversion returned an error.
pub failed: usize,
/// Tokens of the failed targets, pushed in the order the failures occurred.
pub failed_tokens: Vec<Token>,
/// One message per failure, formatted by `run` as `0x<token, 8 hex digits>: <error>`.
/// Entries are pushed together with `failed_tokens`, so index `i` in both
/// vectors describes the same failure.
pub errors: Vec<String>,
}
/// Pass that rewrites registered native (x86) method bodies as CIL.
///
/// Targets are collected with `register_target` / `register_targets`, the
/// behaviour is tuned with the `with_*` builders, and the work happens in `run`.
#[derive(Debug, Clone)]
pub struct NativeMethodConversionPass {
    /// Tokens of the methods to convert. Stored in a set, so registering the
    /// same token twice has no additional effect.
    targets: FxHashSet<Token>,
    /// When `true`, a detected `X86PrologueKind::DynCipher` prologue is cut off
    /// before the remaining bytes are decoded.
    skip_prologue: bool,
    /// Explicit bitness handed to the decoder; `None` means `run` derives it
    /// from the PE header of the input file (64 if `is_64bit`, otherwise 32).
    bitness: Option<u32>,
}
impl Default for NativeMethodConversionPass {
fn default() -> Self {
Self::new()
}
}
impl NativeMethodConversionPass {
#[must_use]
pub fn new() -> Self {
Self {
targets: FxHashSet::default(),
skip_prologue: true,
bitness: None,
}
}
pub fn register_target(&mut self, token: Token) {
self.targets.insert(token);
}
pub fn register_targets(&mut self, tokens: impl IntoIterator<Item = Token>) {
self.targets.extend(tokens);
}
#[must_use]
pub fn with_skip_prologue(mut self, skip: bool) -> Self {
self.skip_prologue = skip;
self
}
#[must_use]
pub fn with_bitness(mut self, bitness: u32) -> Self {
self.bitness = Some(bitness);
self
}
#[must_use]
pub fn target_count(&self) -> usize {
self.targets.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.targets.is_empty()
}
pub fn run(&self, assembly: &mut CilAssembly, file: &File) -> Result<ConversionStats> {
let mut stats = ConversionStats::default();
if self.targets.is_empty() {
return Ok(stats);
}
let bitness = self
.bitness
.unwrap_or_else(|| if file.pe().is_64bit { 64 } else { 32 });
for &token in &self.targets {
match self.convert_method(assembly, file, token, bitness) {
Ok(()) => {
stats.converted += 1;
}
Err(e) => {
stats.failed += 1;
stats.failed_tokens.push(token);
stats.errors.push(format!("0x{:08x}: {}", token.value(), e));
}
}
}
Ok(stats)
}
fn convert_method(
&self,
assembly: &mut CilAssembly,
file: &File,
token: Token,
bitness: u32,
) -> Result<()> {
let rid = token.row();
#[allow(clippy::redundant_closure_for_method_calls)]
let method_row = assembly
.view()
.tables()
.and_then(|t| t.table::<MethodDefRaw>())
.and_then(|table| table.get(rid))
.ok_or_else(|| Error::X86Error(format!("MethodDef row {rid} not found for token")))?;
let impl_code_type = MethodImplCodeType::from_impl_flags(method_row.impl_flags);
if !impl_code_type.contains(MethodImplCodeType::NATIVE) {
return Err(Error::X86Error(format!(
"Method 0x{:08x} is not a native method",
token.value()
)));
}
if method_row.rva == 0 {
return Err(Error::X86Error(format!(
"Method 0x{:08x} has no RVA",
token.value()
)));
}
let offset = file.rva_to_offset(method_row.rva as usize)?;
let x86_bytes = &file.data()[offset..];
let (decode_bytes, base_offset) = if self.skip_prologue {
let prologue = x86_detect_prologue(x86_bytes, bitness);
if prologue.kind == X86PrologueKind::DynCipher {
(&x86_bytes[prologue.size..], prologue.size as u64)
} else {
(x86_bytes, 0u64)
}
} else {
(x86_bytes, 0u64)
};
let instructions = x86_decode_all(decode_bytes, bitness, base_offset)?;
if instructions.is_empty() {
return Err(Error::X86Error("No instructions decoded".to_string()));
}
let cfg = X86Function::new(&instructions, bitness, base_offset);
let translator = X86ToSsaTranslator::new(&cfg);
let ssa_function = translator.translate()?;
let mut codegen = SsaCodeGenerator::new();
let result = codegen.compile(&ssa_function, assembly)?;
let (method_body, _) = MethodBodyBuilder::from_compilation(
result.bytecode,
result.max_stack,
result.locals,
result.exception_handlers,
)
.init_locals(false)
.build(assembly)?;
let new_rva = assembly.store_method_body(method_body);
let updated_row = MethodDefRaw {
rid: method_row.rid,
token: method_row.token,
offset: method_row.offset,
rva: new_rva,
impl_flags: (method_row.impl_flags & !0x0087) | MethodImplCodeType::IL.bits(),
flags: method_row.flags & !0x2000,
name: method_row.name,
signature: method_row.signature,
param_list: method_row.param_list,
};
assembly.table_row_update(
TableId::MethodDef,
rid,
TableDataOwned::MethodDef(updated_row),
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed pass has nothing registered.
    #[test]
    fn test_conversion_pass_new() {
        let fresh = NativeMethodConversionPass::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.target_count(), 0);
    }

    /// Registering tokens one by one grows the set; a duplicate does not.
    #[test]
    fn test_register_targets() {
        let first = Token::new(0x06000001);
        let second = Token::new(0x06000002);
        let mut pass = NativeMethodConversionPass::new();

        // (token to register, expected count afterwards); the last step
        // re-registers `first` and must leave the count untouched.
        let steps = [(first, 1), (second, 2), (first, 2)];
        for &(token, expected) in &steps {
            pass.register_target(token);
            assert_eq!(pass.target_count(), expected);
        }
    }

    /// Bulk registration accepts any iterator of tokens.
    #[test]
    fn test_register_multiple_targets() {
        let mut pass = NativeMethodConversionPass::new();
        let tokens: Vec<Token> = (0x06000001..=0x06000003).map(Token::new).collect();
        pass.register_targets(tokens);
        assert_eq!(pass.target_count(), 3);
    }

    /// The `with_*` builders store their arguments in the pass.
    #[test]
    fn test_builder_pattern() {
        let NativeMethodConversionPass {
            skip_prologue,
            bitness,
            ..
        } = NativeMethodConversionPass::new()
            .with_skip_prologue(false)
            .with_bitness(64);
        assert!(!skip_prologue);
        assert_eq!(bitness, Some(64));
    }
}