use crate::{
cilassembly::writer::{context::WriteContext, output::Output},
file::pe::{constants::COR20_HEADER_SIZE, SectionTable},
utils::align_to,
Result,
};
/// Names of the five standard metadata streams, in the order their stream
/// headers are rewritten by `fixup_metadata_stream_headers`.
const METADATA_STREAM_NAMES: [&str; 5] = [
    "#~",       // tables stream
    "#Strings", // string heap
    "#US",      // user-string heap
    "#GUID",    // GUID heap
    "#Blob",    // blob heap
];
pub fn apply_all_fixups(ctx: &mut WriteContext) -> Result<()> {
let pe_sig_offset = u32::try_from(ctx.pe_signature_offset).map_err(|_| {
crate::Error::LayoutFailed("PE signature offset exceeds u32 range".to_string())
})?;
ctx.write_u32_at(ctx.dos_header_offset + 0x3C, pe_sig_offset)?;
fixup_optional_header(ctx)?;
fixup_section_table(ctx)?;
fixup_cor20_header(ctx)?;
fixup_data_directories(ctx)?;
zero_stripped_data_regions(ctx)?;
fixup_coff_characteristics(ctx)?;
fixup_checksum(ctx)?;
Ok(())
}
/// Rewrites the layout-dependent fields of the optional header.
///
/// Fields written (offsets relative to `optional_header_offset`, identical for
/// PE32 and PE32+):
/// - `+4`  SizeOfCode: `.text` size rounded up to the file alignment
/// - `+16` AddressOfEntryPoint: only when `native_entry_rva` is set
/// - `+56` SizeOfImage: end of the highest live, laid-out section, rounded up
///   to the section alignment
/// - `+60` SizeOfHeaders: the file offset of the `.text` section
///
/// # Errors
///
/// Returns `LayoutFailed` when any computed value does not fit in a `u32`.
pub fn fixup_optional_header(ctx: &mut WriteContext) -> Result<()> {
    let header = ctx.optional_header_offset;

    let size_of_code = u32::try_from(align_to(
        ctx.text_section_size,
        u64::from(ctx.file_alignment),
    ))
    .map_err(|_| crate::Error::LayoutFailed("Text file size exceeds u32 range".to_string()))?;

    // Highest virtual end among sections that are kept and have a known layout.
    let highest_section_end = ctx
        .sections
        .iter()
        .filter(|section| !section.removed)
        .filter_map(|section| section.rva.zip(section.data_size))
        .map(|(rva, size)| rva.saturating_add(size))
        .max()
        .unwrap_or(0);
    let size_of_image = u32::try_from(align_to(
        u64::from(highest_section_end),
        u64::from(ctx.section_alignment),
    ))
    .map_err(|_| crate::Error::LayoutFailed("Image size exceeds u32 range".to_string()))?;

    ctx.write_u32_at(header + 4, size_of_code)?;
    if let Some(entry_rva) = ctx.native_entry_rva {
        ctx.write_u32_at(header + 16, entry_rva)?;
    }
    ctx.write_u32_at(header + 56, size_of_image)?;

    let size_of_headers = u32::try_from(ctx.text_section_offset)
        .map_err(|_| crate::Error::LayoutFailed("Headers size exceeds u32 range".to_string()))?;
    ctx.write_u32_at(header + 60, size_of_headers)?;
    Ok(())
}
pub fn fixup_section_table(ctx: &mut WriteContext) -> Result<()> {
let mut section_headers: Vec<[u8; SectionTable::SIZE]> = Vec::new();
for section in &ctx.sections {
if section.removed {
continue;
}
let (Some(data_offset), Some(rva), Some(data_size)) =
(section.data_offset, section.rva, section.data_size)
else {
continue; };
let file_size = u32::try_from(align_to(
u64::from(data_size),
u64::from(ctx.file_alignment),
))
.map_err(|_| {
crate::Error::LayoutFailed(format!(
"Section {} file size exceeds u32 range",
section.name
))
})?;
let offset_u32 = u32::try_from(data_offset).map_err(|_| {
crate::Error::LayoutFailed(format!("Section {} offset exceeds u32 range", section.name))
})?;
let mut header = [0u8; SectionTable::SIZE];
let name_bytes = section.name.as_bytes();
let copy_len = std::cmp::min(name_bytes.len(), 8);
header[..copy_len].copy_from_slice(&name_bytes[..copy_len]);
header[8..12].copy_from_slice(&data_size.to_le_bytes());
header[12..16].copy_from_slice(&rva.to_le_bytes());
header[16..20].copy_from_slice(&file_size.to_le_bytes());
header[20..24].copy_from_slice(&offset_u32.to_le_bytes());
header[24..28].copy_from_slice(&0u32.to_le_bytes());
header[28..32].copy_from_slice(&0u32.to_le_bytes());
header[32..34].copy_from_slice(&0u16.to_le_bytes());
header[34..36].copy_from_slice(&0u16.to_le_bytes());
header[36..40].copy_from_slice(§ion.characteristics.to_le_bytes());
section_headers.push(header);
}
let mut offset = ctx.section_table_offset;
for header in §ion_headers {
ctx.write_at(offset, header)?;
offset += SectionTable::SIZE as u64;
}
let original_table_size = ctx.sections.len() * SectionTable::SIZE;
let new_table_size = section_headers.len() * SectionTable::SIZE;
if new_table_size < original_table_size {
let zeros = vec![0u8; original_table_size - new_table_size];
ctx.write_at(offset, &zeros)?;
}
let new_count = u16::try_from(section_headers.len()).unwrap_or(0);
ctx.write_u16_at(ctx.coff_header_offset + 2, new_count)?;
Ok(())
}
/// Updates the CLI (COR20) header with the final metadata, entry-point and
/// managed-resource locations.
///
/// Offsets written, relative to `cor20_header_offset`:
/// - `+8` / `+12`: metadata RVA and size
/// - `+20`: entry-point token, only when it is non-zero and has an entry in
///   `token_remapping`
/// - `+24` / `+28`: managed resources RVA and size, only when
///   `resource_data_size > 0`
///
/// # Errors
///
/// Returns `LayoutFailed` if the metadata or resource size does not fit in a
/// `u32`.
pub fn fixup_cor20_header(ctx: &mut WriteContext) -> Result<()> {
    let cor20 = ctx.cor20_header_offset;

    let metadata_rva = ctx.offset_to_rva(ctx.metadata_offset);
    let metadata_size = u32::try_from(ctx.metadata_size)
        .map_err(|_| crate::Error::LayoutFailed("Metadata size exceeds u32 range".to_string()))?;
    ctx.write_u32_at(cor20 + 8, metadata_rva)?;
    ctx.write_u32_at(cor20 + 12, metadata_size)?;

    // Only rewrite the token when a remapping entry exists for it.
    let remapped_entry_point = (ctx.entry_point_token != 0 && !ctx.token_remapping.is_empty())
        .then(|| ctx.token_remapping.get(&ctx.entry_point_token).copied())
        .flatten();
    if let Some(token) = remapped_entry_point {
        ctx.write_u32_at(cor20 + 20, token)?;
    }

    if ctx.resource_data_size > 0 {
        let resources_rva = ctx.offset_to_rva(ctx.resource_data_offset);
        let resources_size = u32::try_from(ctx.resource_data_size).map_err(|_| {
            crate::Error::LayoutFailed("Resource size exceeds u32 range".to_string())
        })?;
        ctx.write_u32_at(cor20 + 24, resources_rva)?;
        ctx.write_u32_at(cor20 + 28, resources_size)?;
    }
    Ok(())
}
pub fn fixup_data_directories(ctx: &mut WriteContext) -> Result<()> {
let dd_offset = if ctx.is_pe32_plus { 112 } else { 96 };
let dd_base = ctx.optional_header_offset + dd_offset;
let has_iat = ctx.iat_size > 0;
if has_iat {
let iat_rva = ctx.text_section_rva;
let iat_size = u32::try_from(ctx.iat_size).unwrap_or(8);
ctx.write_u32_at(dd_base + 12 * 8, iat_rva)?;
ctx.write_u32_at(dd_base + 12 * 8 + 4, iat_size)?;
let clr_rva = ctx.text_section_rva + iat_size;
ctx.write_u32_at(dd_base + 14 * 8, clr_rva)?;
ctx.write_u32_at(dd_base + 14 * 8 + 4, COR20_HEADER_SIZE)?;
} else {
ctx.write_u32_at(dd_base + 12 * 8, 0)?;
ctx.write_u32_at(dd_base + 12 * 8 + 4, 0)?;
let clr_rva = ctx.text_section_rva;
ctx.write_u32_at(dd_base + 14 * 8, clr_rva)?;
ctx.write_u32_at(dd_base + 14 * 8 + 4, COR20_HEADER_SIZE)?;
}
if let (Some(rva), Some(size)) = (ctx.import_data_rva, ctx.import_data_size) {
ctx.write_u32_at(dd_base + 8, rva)?;
ctx.write_u32_at(dd_base + 8 + 4, size)?;
} else {
ctx.write_u32_at(dd_base + 8, 0)?;
ctx.write_u32_at(dd_base + 8 + 4, 0)?;
}
if let (Some(rva), Some(size)) = (ctx.export_data_rva, ctx.export_data_size) {
ctx.write_u32_at(dd_base, rva)?;
ctx.write_u32_at(dd_base + 4, size)?;
}
let rsrc_section = ctx
.sections
.iter()
.find(|s| s.name.starts_with(".rsrc") && !s.removed);
if let Some(section) = rsrc_section {
if let (Some(rva), Some(size)) = (section.rva, section.data_size) {
ctx.write_u32_at(dd_base + 2 * 8, rva)?;
ctx.write_u32_at(dd_base + 2 * 8 + 4, size)?;
}
} else if ctx.pe_resource_size > 0 {
let rva = ctx.offset_to_rva(ctx.pe_resource_offset);
ctx.write_u32_at(dd_base + 2 * 8, rva)?;
ctx.write_u32_at(dd_base + 2 * 8 + 4, ctx.pe_resource_size)?;
} else {
ctx.write_u32_at(dd_base + 2 * 8, 0)?;
ctx.write_u32_at(dd_base + 2 * 8 + 4, 0)?;
}
let reloc_section = ctx
.sections
.iter()
.find(|s| s.name.starts_with(".reloc") && !s.removed);
if let Some(section) = reloc_section {
if let (Some(rva), Some(size)) = (section.rva, section.data_size) {
ctx.write_u32_at(dd_base + 5 * 8, rva)?;
ctx.write_u32_at(dd_base + 5 * 8 + 4, size)?;
} else {
ctx.write_u32_at(dd_base + 5 * 8, 0)?;
ctx.write_u32_at(dd_base + 5 * 8 + 4, 0)?;
}
} else {
ctx.write_u32_at(dd_base + 5 * 8, 0)?;
ctx.write_u32_at(dd_base + 5 * 8 + 4, 0)?;
}
Ok(())
}
pub fn fixup_metadata_stream_headers(
ctx: &mut WriteContext,
metadata_root_offset: u64,
stream_headers_offset: u64,
) -> Result<()> {
let mut offset = stream_headers_offset;
let streams = [
(
ctx.tables_stream_offset,
ctx.tables_stream_size,
METADATA_STREAM_NAMES[0],
),
(
ctx.strings_heap_offset,
ctx.strings_heap_size,
METADATA_STREAM_NAMES[1],
),
(
ctx.us_heap_offset,
ctx.us_heap_size,
METADATA_STREAM_NAMES[2],
),
(
ctx.guid_heap_offset,
ctx.guid_heap_size,
METADATA_STREAM_NAMES[3],
),
(
ctx.blob_heap_offset,
ctx.blob_heap_size,
METADATA_STREAM_NAMES[4],
),
];
for (stream_offset, stream_size, name) in &streams {
let relative_offset =
u32::try_from(*stream_offset - metadata_root_offset).map_err(|_| {
crate::Error::LayoutFailed("Stream relative offset exceeds u32 range".to_string())
})?;
let aligned_size = u32::try_from(align_to(*stream_size, 4)).map_err(|_| {
crate::Error::LayoutFailed("Stream aligned size exceeds u32 range".to_string())
})?;
ctx.write_u32_at(offset, relative_offset)?;
ctx.write_u32_at(offset + 4, aligned_size)?;
let name_with_null = name.len() + 1;
let aligned_name = align_to(name_with_null as u64, 4);
offset += 8 + aligned_name;
}
Ok(())
}
/// Zeros data regions that were stripped from the output image.
///
/// Currently this only handles `original_certificate_dir`: its
/// `(offset, size)` range is overwritten with zeros, but only when the range
/// lies entirely within the first `bytes_written` bytes. The offset is used
/// directly as a file offset.
///
/// # Errors
///
/// Propagates errors from the underlying write.
pub fn zero_stripped_data_regions(ctx: &mut WriteContext) -> Result<()> {
    // NOTE(review): the debug directory is not zeroed here; this read appears
    // to exist only so the field counts as used. Confirm intent.
    let _ = ctx.original_debug_dir;

    let Some((start, len)) = ctx.original_certificate_dir else {
        return Ok(());
    };
    let region_start = u64::from(start);
    let region_end = region_start + u64::from(len);
    if region_end > ctx.bytes_written {
        // The region is not fully inside the written output; leave it alone.
        return Ok(());
    }
    ctx.write_at(region_start, &vec![0u8; len as usize])?;
    Ok(())
}
/// Sets `IMAGE_FILE_RELOCS_STRIPPED` in the COFF `Characteristics` field when
/// `ctx.relocs_stripped` is true.
///
/// The current flags are read back from the output buffer and OR-ed with the
/// bit, so every other characteristic is preserved. If the field (at
/// `coff_header_offset + 18`) lies outside the current output buffer, nothing
/// is written.
///
/// # Errors
///
/// Propagates errors from the underlying write.
pub fn fixup_coff_characteristics(ctx: &mut WriteContext) -> Result<()> {
    const IMAGE_FILE_RELOCS_STRIPPED: u16 = 0x0001;

    if !ctx.relocs_stripped {
        return Ok(());
    }

    let field_offset = ctx.coff_header_offset + 18;
    // Header offsets are small in practice; truncation is only theoretical.
    #[allow(clippy::cast_possible_truncation)]
    let start = field_offset as usize;
    let existing_flags = ctx
        .output
        .as_slice()
        .get(start..start + 2)
        .map(|bytes| u16::from_le_bytes([bytes[0], bytes[1]]));

    if let Some(flags) = existing_flags {
        ctx.write_u16_at(field_offset, flags | IMAGE_FILE_RELOCS_STRIPPED)?;
    }
    Ok(())
}
/// Recomputes the PE image checksum and stores it in the optional header's
/// `CheckSum` field (offset 64, the same for PE32 and PE32+).
///
/// At most the first `bytes_written` bytes of the output take part in the sum.
///
/// # Errors
///
/// Returns `LayoutFailed` if `bytes_written` does not fit in a `usize`, and
/// propagates errors from the write.
pub fn fixup_checksum(ctx: &mut WriteContext) -> Result<()> {
    let checksum_field = ctx.optional_header_offset + 64;
    let checksum = usize::try_from(ctx.bytes_written)
        .map(|file_len| calculate_pe_checksum(&ctx.output, checksum_field, file_len))
        .map_err(|_| crate::Error::LayoutFailed("File size exceeds usize range".to_string()))?;
    ctx.write_u32_at(checksum_field, checksum)?;
    Ok(())
}
/// Computes the PE image checksum, the value stored in the optional header's
/// `CheckSum` field.
///
/// The algorithm:
/// 1. Take the first `actual_size` bytes of `output` (clamped to the buffer
///    length).
/// 2. Sum them as little-endian 16-bit words, skipping the 4-byte checksum
///    field at `checksum_offset`. A trailing odd byte is added as a word with
///    a zero high byte.
/// 3. Fold the sum to 16 bits with end-around carry.
/// 4. Add the file size.
fn calculate_pe_checksum(output: &Output, checksum_offset: u64, actual_size: usize) -> u32 {
    let data = output.as_slice();
    let file_size = actual_size.min(data.len());
    let data = &data[..file_size];

    // Byte range of the CheckSum field. The end saturates so that the
    // `usize::MAX` fallback cannot overflow (which used to panic in debug).
    let field_start = usize::try_from(checksum_offset).unwrap_or(usize::MAX);
    let field_end = field_start.saturating_add(4);
    let in_field = |pos: usize| (field_start..field_end).contains(&pos);

    let words = data.chunks_exact(2);
    let tail = words.remainder();
    let mut sum: u64 = words
        .enumerate()
        .filter(|&(index, _)| !in_field(index * 2))
        .map(|(_, pair)| u64::from(u16::from_le_bytes([pair[0], pair[1]])))
        .sum();
    if let [last] = tail {
        if !in_field(file_size - 1) {
            sum += u64::from(*last);
        }
    }

    while sum > 0xFFFF {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }

    // `sum` now fits in 16 bits. The size term is truncated to 32 bits, and
    // wrapping_add matches release-build behaviour instead of panicking in
    // debug builds when the total exceeds u32::MAX.
    #[allow(clippy::cast_possible_truncation)]
    let folded = sum as u32;
    #[allow(clippy::cast_possible_truncation)]
    let size_term = file_size as u32;
    folded.wrapping_add(size_term)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cilassembly::writer::generator::PeGenerator;
    use crate::cilassembly::CilAssembly;
    use crate::CilAssemblyView;
    use std::path::Path;
    use tempfile::NamedTempFile;

    #[test]
    fn test_checksum_excludes_checksum_field() {
        let temp_file = NamedTempFile::new().expect("Failed to create temp file");
        let mut output = Output::create(temp_file.path(), 256).expect("Failed to create output");
        let pattern: Vec<u8> = (0..=255u8).collect();
        output.write_at(0, &pattern).expect("Failed to write data");
        let checksum1 = calculate_pe_checksum(&output, 64, 256);
        output
            .write_at(64, &[0xFF, 0xFF, 0xFF, 0xFF])
            .expect("Failed to modify checksum area");
        let checksum2 = calculate_pe_checksum(&output, 64, 256);
        assert_eq!(
            checksum1, checksum2,
            "Checksum should be the same regardless of checksum field content"
        );
    }

    /// Pins the exact value for a known buffer.
    ///
    /// Computation: the 128 LE words of bytes 0..=255 sum to 4_210_560. Minus
    /// the words at 64 and 66 (16_704 + 17_218) this is 4_176_638, which folds
    /// to 47_933. Adding the file size 256 gives 48_189.
    #[test]
    fn test_checksum_known_value() {
        let temp_file = NamedTempFile::new().expect("Failed to create temp file");
        let mut output = Output::create(temp_file.path(), 256).expect("Failed to create output");
        let pattern: Vec<u8> = (0..=255u8).collect();
        output.write_at(0, &pattern).expect("Failed to write data");
        assert_eq!(calculate_pe_checksum(&output, 64, 256), 48_189);
    }

    /// An odd trailing byte is summed as a word with a zero high byte:
    /// 0x0201 + 0x0403 + 0x05 = 1545, plus the size 5.
    #[test]
    fn test_checksum_odd_length_includes_trailing_byte() {
        let temp_file = NamedTempFile::new().expect("Failed to create temp file");
        let mut output = Output::create(temp_file.path(), 256).expect("Failed to create output");
        output
            .write_at(0, &[1, 2, 3, 4, 5])
            .expect("Failed to write data");
        assert_eq!(calculate_pe_checksum(&output, 1000, 5), 1550);
    }

    #[test]
    fn test_apply_fixups_integration() {
        let view = CilAssemblyView::from_path(Path::new("tests/samples/crafted_2.exe"))
            .expect("Failed to load test assembly");
        let assembly: CilAssembly = view.to_owned();
        let temp_file = NamedTempFile::new().expect("Failed to create temp file");
        let generator = PeGenerator::new(&assembly);
        generator
            .to_file(temp_file.path())
            .expect("PE generation should succeed");
        let reloaded = CilAssemblyView::from_path(temp_file.path())
            .expect("Should be able to reload generated PE");
        assert!(
            reloaded.tables().is_some(),
            "Reloaded PE should have tables"
        );
    }
}