use crate::{
cilassembly::{CilAssembly, Operation, TableModifications},
utils::calculate_table_row_size,
Error, Result,
};
pub fn calculate_table_stream_expansion(assembly: &CilAssembly) -> Result<u64> {
let changes = assembly.changes();
let view = assembly.view();
let tables = view.tables().ok_or_else(|| {
Error::LayoutFailed("No tables found in assembly for expansion calculation".to_string())
})?;
let mut total_expansion = 0u64;
let mut header_expansion = 0u64;
for table_id in changes.modified_tables() {
if let Some(table_mod) = changes.get_table_modifications(table_id) {
let row_size = calculate_table_row_size(table_id, &tables.info);
let original_count = tables.table_row_count(table_id);
let (new_count, additional_rows) = match table_mod {
TableModifications::Replaced(new_rows) => {
let new_count = u32::try_from(new_rows.len()).map_err(|_| {
Error::LayoutFailed(format!(
"Table {:?} row count {} exceeds u32::MAX",
table_id,
new_rows.len()
))
})?;
let additional = new_count.saturating_sub(original_count);
(new_count, additional)
}
TableModifications::Sparse { operations, .. } => {
let insert_count_raw = operations
.iter()
.filter(|op| matches!(op.operation, Operation::Insert(_, _)))
.count();
let insert_count = u32::try_from(insert_count_raw).map_err(|_| {
Error::LayoutFailed(format!(
"Table {table_id:?} insert count {insert_count_raw} exceeds u32::MAX"
))
})?;
let new_count = original_count + insert_count;
(new_count, insert_count)
}
};
let expansion_bytes = u64::from(additional_rows) * u64::from(row_size);
total_expansion += expansion_bytes;
if original_count == 0 && new_count > 0 {
header_expansion += 4;
}
}
}
Ok(total_expansion + header_expansion)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::*;
use std::path::PathBuf;
fn get_test_assembly() -> Result<CilAssembly> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
let view = CilAssemblyView::from_path(&path)?;
Ok(CilAssembly::new(view))
}
#[test]
fn test_no_modifications_returns_zero() -> Result<()> {
let assembly = get_test_assembly()?;
let expansion = calculate_table_stream_expansion(&assembly)?;
assert_eq!(expansion, 0);
Ok(())
}
#[test]
fn test_adding_string_to_heap_no_table_expansion() -> Result<()> {
let mut assembly = get_test_assembly()?;
assembly.userstring_add("TestString")?;
assembly.userstring_add("AnotherString")?;
let expansion = calculate_table_stream_expansion(&assembly)?;
assert_eq!(expansion, 0);
Ok(())
}
#[test]
fn test_adding_typedef_causes_expansion() -> Result<()> {
let mut assembly = get_test_assembly()?;
let _typedef = ClassBuilder::new("TestClass")
.public()
.namespace("Test")
.build(&mut assembly)?;
let expansion = calculate_table_stream_expansion(&assembly)?;
assert!(expansion > 0);
Ok(())
}
#[test]
fn test_adding_method_causes_expansion() -> Result<()> {
let mut assembly = get_test_assembly()?;
let _class = ClassBuilder::new("TestClass")
.public()
.namespace("Test")
.method(|m| {
m.public().implementation(|body| {
body.implementation(|asm| {
asm.ret()?;
Ok(())
})
})
})
.build(&mut assembly)?;
let expansion = calculate_table_stream_expansion(&assembly)?;
assert!(expansion > 0);
Ok(())
}
#[test]
fn test_multiple_classes_increases_expansion() -> Result<()> {
let mut assembly = get_test_assembly()?;
let _class1 = ClassBuilder::new("TestClass1")
.public()
.namespace("Test")
.build(&mut assembly)?;
let expansion1 = calculate_table_stream_expansion(&assembly)?;
let _class2 = ClassBuilder::new("TestClass2")
.public()
.namespace("Test")
.build(&mut assembly)?;
let expansion2 = calculate_table_stream_expansion(&assembly)?;
assert!(expansion2 > expansion1);
Ok(())
}
#[test]
fn test_class_with_properties_causes_larger_expansion() -> Result<()> {
let mut assembly = get_test_assembly()?;
let _class = ClassBuilder::new("TestClass")
.public()
.namespace("Test")
.auto_property("Name", TypeSignature::String)
.auto_property("Value", TypeSignature::I4)
.build(&mut assembly)?;
let expansion = calculate_table_stream_expansion(&assembly)?;
assert!(expansion > 0);
Ok(())
}
}