use std::collections::{HashMap, HashSet};
use crate::{
cilassembly::{changes::AssemblyChanges, writer::context::WriteContext},
file::File,
metadata::{
signatures::{parse_field_signature, TypeSignature},
streams::TablesHeader,
tables::{ClassLayoutRaw, FieldRaw, FieldRvaRaw},
typesystem::PointerSize,
},
prelude::TableId,
CilAssemblyView, Error, Result,
};
type FieldDataEntries = Vec<(u32, Vec<u8>)>;
fn collect_field_data(
view: &CilAssemblyView,
file: &File,
changes: &AssemblyChanges,
ptr_size: PointerSize,
) -> Result<FieldDataEntries> {
let mut entries = Vec::new();
collect_original_fieldrva_data(view, file, changes, &mut entries, ptr_size)?;
collect_changes_field_data(changes, &mut entries);
Ok(entries)
}
fn calculate_field_size(
view: &CilAssemblyView,
field_index: u32,
ptr_size: PointerSize,
) -> Result<usize> {
let tables = view.tables().ok_or_else(|| {
Error::ModificationInvalid(format!(
"Cannot access tables for field {field_index} size calculation"
))
})?;
let field_table = tables.table::<FieldRaw>().ok_or_else(|| {
Error::ModificationInvalid(format!(
"Cannot access Field table for field {field_index} size calculation"
))
})?;
let field_row = field_table
.iter()
.find(|r| r.rid == field_index)
.ok_or_else(|| {
Error::ModificationInvalid(format!("Field {field_index} not found in Field table"))
})?;
let blobs = view.blobs().ok_or_else(|| {
Error::ModificationInvalid(format!(
"Cannot access blob heap for field {field_index} size calculation"
))
})?;
let sig_data = blobs.get(field_row.signature as usize).map_err(|_| {
Error::ModificationInvalid(format!(
"Cannot read signature blob for field {field_index}"
))
})?;
let field_sig = parse_field_signature(sig_data).map_err(|e| {
Error::ModificationInvalid(format!("Cannot parse field {field_index} signature: {e}"))
})?;
calculate_type_size(&field_sig.base, tables, field_index, ptr_size)
}
fn calculate_type_size(
type_sig: &TypeSignature,
tables: &TablesHeader,
field_index: u32,
ptr_size: PointerSize,
) -> Result<usize> {
if let Some(size) = type_sig.byte_size(ptr_size) {
return Ok(size);
}
match type_sig {
TypeSignature::ValueType(token) => {
let is_typedef = token.is_table(TableId::TypeDef);
let row = token.row();
if is_typedef {
if let Some(class_layout_table) = tables.table::<ClassLayoutRaw>() {
for layout_row in class_layout_table {
if layout_row.parent == row {
return Ok(layout_row.class_size as usize);
}
}
}
}
Err(Error::ModificationInvalid(format!(
"Field {} has ValueType (token 0x{:08x}) without ClassLayout - cannot determine size",
field_index, token.value()
)))
}
TypeSignature::Array(arr) => {
let element_size = calculate_type_size(&arr.base, tables, field_index, ptr_size)?;
let mut total_elements: usize = 1;
for dim in &arr.dimensions {
let dim_size = dim.size.ok_or_else(|| {
Error::ModificationInvalid(format!(
"Field {field_index} has array with unknown dimension size"
))
})? as usize;
total_elements = total_elements.checked_mul(dim_size).ok_or_else(|| {
Error::ModificationInvalid(format!("Field {field_index} array size overflow"))
})?;
}
element_size.checked_mul(total_elements).ok_or_else(|| {
Error::ModificationInvalid(format!("Field {field_index} array size overflow"))
})
}
TypeSignature::SzArray(_) => Err(Error::ModificationInvalid(format!(
"Field {field_index} is SzArray - size not determinable from type signature. \
Use FixedBufferAttribute or explicit ClassLayout for fixed-size arrays."
))),
_ => Err(Error::ModificationInvalid(format!(
"Field {field_index} has type {type_sig:?} - cannot determine static size"
))),
}
}
fn collect_original_fieldrva_data(
view: &CilAssemblyView,
file: &File,
changes: &AssemblyChanges,
entries: &mut FieldDataEntries,
ptr_size: PointerSize,
) -> Result<()> {
let Some(fieldrva_table) = view
.tables()
.and_then(|t: &TablesHeader<'_>| t.table::<FieldRvaRaw>())
else {
return Ok(());
};
let modified_rids: HashSet<u32> = changes
.get_table_modifications(TableId::FieldRVA)
.map(|mods| mods.change_refs().map(|(rid, _)| *rid).collect())
.unwrap_or_default();
let deleted_rids: HashSet<u32> = changes
.get_table_modifications(TableId::FieldRVA)
.map(|mods| mods.deleted_rids().collect())
.unwrap_or_default();
let deleted_field_rids: HashSet<u32> = changes
.get_table_modifications(TableId::Field)
.map(|mods| mods.deleted_rids().collect())
.unwrap_or_default();
let mut entries_to_process: Vec<(u32, u32, u32)> = Vec::new();
for row in fieldrva_table {
if deleted_rids.contains(&row.rid)
|| deleted_field_rids.contains(&row.field)
|| modified_rids.contains(&row.rid)
{
continue;
}
let rva = row.rva;
if rva == 0 {
continue;
}
entries_to_process.push((rva, row.rid, row.field));
}
if entries_to_process.is_empty() {
return Ok(());
}
entries_to_process.sort_by_key(|(rva, _, _)| *rva);
for (rva, _rid, field_index) in entries_to_process {
#[allow(clippy::cast_possible_truncation)]
let size = calculate_field_size(view, field_index, ptr_size)? as u32;
if size > 1024 * 1024 {
return Err(Error::ModificationInvalid(format!(
"Field {field_index} has unreasonable size {size} bytes"
)));
}
let offset = file.rva_to_offset(rva as usize).map_err(|_| {
Error::ModificationInvalid(format!(
"Cannot convert RVA 0x{rva:08x} to file offset for field {field_index}"
))
})?;
let data = file.data_slice(offset, size as usize).map_err(|_| {
Error::ModificationInvalid(format!(
"Cannot read {size} bytes at offset 0x{offset:08x} for field {field_index}"
))
})?;
entries.push((rva, data.to_vec()));
}
Ok(())
}
fn collect_changes_field_data(changes: &AssemblyChanges, entries: &mut FieldDataEntries) {
if !changes.has_field_data() {
return;
}
let mut field_data_entries: Vec<_> = changes.field_data_entries().collect();
field_data_entries.sort_by_key(|(placeholder, _)| *placeholder);
for (placeholder_rva, data) in field_data_entries {
entries.push((placeholder_rva, data.clone()));
}
}
pub fn write_field_data(ctx: &mut WriteContext) -> Result<()> {
let view = ctx.assembly.view();
let file = view.file();
let changes = ctx.changes;
let ptr_size = PointerSize::from_pe(file.pe().is_64bit);
let entries = collect_field_data(view, file, changes, ptr_size)?;
if entries.is_empty() {
return Ok(());
}
ctx.align_to_4();
for (source_rva, data) in &entries {
let actual_rva = ctx.current_rva();
ctx.field_data_rva_map.insert(*source_rva, actual_rva);
ctx.write(data)?;
ctx.align_to_4();
}
Ok(())
}
pub fn resolve_field_data_rva(buffer: &mut [u8], field_data_rva_map: &HashMap<u32, u32>) {
if buffer.len() < 4 {
return;
}
let rva = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
if rva == 0 {
return;
}
if let Some(&new_rva) = field_data_rva_map.get(&rva) {
let new_bytes = new_rva.to_le_bytes();
buffer[0] = new_bytes[0];
buffer[1] = new_bytes[1];
buffer[2] = new_bytes[2];
buffer[3] = new_bytes[3];
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_resolve_field_data_rva_placeholder() {
let mut buffer = [0x01, 0x00, 0x00, 0xE0, 0x00, 0x00]; let mut map = HashMap::new();
map.insert(0xE000_0001, 0x3000);
resolve_field_data_rva(&mut buffer, &map);
assert_eq!(
u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
0x3000
);
}
#[test]
fn test_resolve_field_data_rva_original() {
let mut buffer = [0x50, 0x20, 0x00, 0x00, 0x00, 0x00]; let mut map = HashMap::new();
map.insert(0x2050, 0x4000);
resolve_field_data_rva(&mut buffer, &map);
assert_eq!(
u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
0x4000
);
}
#[test]
fn test_resolve_field_data_rva_not_in_map() {
let mut buffer = [0x00, 0x30, 0x00, 0x00, 0x00, 0x00]; let map = HashMap::new();
resolve_field_data_rva(&mut buffer, &map);
assert_eq!(
u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
0x3000
);
}
#[test]
fn test_resolve_field_data_rva_zero() {
let mut buffer = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let mut map = HashMap::new();
map.insert(0, 0x1000);
resolve_field_data_rva(&mut buffer, &map);
assert_eq!(
u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
0
);
}
}