use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
tables::{ClassLayoutRaw, TableDataOwned, TableId},
token::Token,
},
Error, Result,
};
/// Builder for a new row in the `ClassLayout` metadata table.
///
/// Collects the packing size, class size and parent type row, then validates
/// them and appends the row to a [`CilAssembly`] via
/// [`build`](Self::build). All three values are mandatory.
#[derive(Debug, Clone)]
pub struct ClassLayoutBuilder {
    /// Field alignment boundary in bytes (validated in `build`: `0` or a power
    /// of two no greater than `128`).
    packing_size: Option<u16>,
    /// Explicit class size in bytes (validated in `build` against an upper bound).
    class_size: Option<u32>,
    /// Row id of the type this layout describes.
    parent: Option<u32>,
}
impl Default for ClassLayoutBuilder {
    /// Produces a builder with nothing configured yet, matching `ClassLayoutBuilder::new()`.
    fn default() -> Self {
        Self {
            parent: None,
            class_size: None,
            packing_size: None,
        }
    }
}
impl ClassLayoutBuilder {
#[must_use]
pub fn new() -> Self {
Self {
packing_size: None,
class_size: None,
parent: None,
}
}
#[must_use]
pub fn packing_size(mut self, packing: u16) -> Self {
self.packing_size = Some(packing);
self
}
#[must_use]
pub fn class_size(mut self, size: u32) -> Self {
self.class_size = Some(size);
self
}
#[must_use]
pub fn parent(mut self, parent: u32) -> Self {
self.parent = Some(parent);
self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
const MAX_CLASS_SIZE: u32 = 0x1000_0000;
let packing_size = self
.packing_size
.ok_or_else(|| Error::ModificationInvalid("Packing size is required".to_string()))?;
let class_size = self
.class_size
.ok_or_else(|| Error::ModificationInvalid("Class size is required".to_string()))?;
let parent = self
.parent
.ok_or_else(|| Error::ModificationInvalid("Parent type is required".to_string()))?;
if packing_size != 0 && (packing_size & (packing_size - 1)) != 0 {
return Err(Error::ModificationInvalid(format!(
"Packing size must be 0 or a power of 2, got {packing_size}"
)));
}
if packing_size > 128 {
return Err(Error::ModificationInvalid(format!(
"Packing size cannot exceed 128 bytes, got {packing_size}"
)));
}
if class_size > MAX_CLASS_SIZE {
return Err(Error::ModificationInvalid(format!(
"Class size cannot exceed 256MB (0x{MAX_CLASS_SIZE:X}), got {class_size}"
)));
}
let rid = assembly.next_rid(TableId::ClassLayout)?;
let token = Token::from_parts(TableId::ClassLayout, rid);
let class_layout_raw = ClassLayoutRaw {
rid,
token,
offset: 0, packing_size,
class_size,
parent,
};
assembly.table_row_add(
TableId::ClassLayout,
TableDataOwned::ClassLayout(class_layout_raw),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{ChangeRefKind, CilAssembly},
        metadata::cilassemblyview::CilAssemblyView,
    };
    use std::path::PathBuf;

    /// Loads the bundled `WindowsBase.dll` sample and passes a mutable assembly to `body`.
    ///
    /// If the sample cannot be opened, `body` never runs and the calling test
    /// passes without checking anything.
    fn with_sample(body: impl FnOnce(&mut CilAssembly)) {
        let sample =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        if let Ok(view) = CilAssemblyView::from_path(&sample) {
            let mut assembly = CilAssembly::new(view);
            body(&mut assembly);
        }
    }

    /// Builds a fully specified `ClassLayout` row, panicking if the builder rejects it.
    fn add_layout(
        assembly: &mut CilAssembly,
        parent: u32,
        packing: u16,
        size: u32,
    ) -> ChangeRefRc {
        ClassLayoutBuilder::new()
            .parent(parent)
            .packing_size(packing)
            .class_size(size)
            .build(assembly)
            .expect("valid ClassLayout parameters should build")
    }

    /// Asserts that `change` refers to a row in the `ClassLayout` table.
    fn assert_class_layout_row(change: &ChangeRefRc) {
        assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::ClassLayout));
    }

    #[test]
    fn test_class_layout_builder_basic() {
        with_sample(|assembly| {
            assert_class_layout_row(&add_layout(assembly, 1, 4, 0));
        });
    }

    #[test]
    fn test_class_layout_builder_different_packings() {
        with_sample(|assembly| {
            let plan: [(u32, u16); 4] = [(1, 1), (2, 8), (3, 16), (4, 64)];
            let layouts: Vec<ChangeRefRc> = plan
                .iter()
                .map(|&(rid, packing)| add_layout(assembly, rid, packing, 0))
                .collect();

            layouts.iter().for_each(assert_class_layout_row);

            // Only the first layout is compared against the others.
            let (first, rest) = layouts.split_first().expect("four layouts were built");
            for other in rest {
                assert!(!std::sync::Arc::ptr_eq(first, other));
            }
        });
    }

    #[test]
    fn test_class_layout_builder_default_packing() {
        with_sample(|assembly| {
            // Packing 0 selects the runtime default alignment.
            assert_class_layout_row(&add_layout(assembly, 1, 0, 0));
        });
    }

    #[test]
    fn test_class_layout_builder_explicit_sizes() {
        with_sample(|assembly| {
            let plan: [(u32, u16, u32); 3] = [(1, 4, 16), (2, 8, 256), (3, 16, 65_536)];
            let layouts: Vec<ChangeRefRc> = plan
                .iter()
                .map(|&(rid, packing, size)| add_layout(assembly, rid, packing, size))
                .collect();

            layouts.iter().for_each(assert_class_layout_row);
        });
    }

    #[test]
    fn test_class_layout_builder_missing_packing_size() {
        with_sample(|assembly| {
            let outcome = ClassLayoutBuilder::new()
                .parent(1)
                .class_size(16)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_missing_class_size() {
        with_sample(|assembly| {
            let outcome = ClassLayoutBuilder::new()
                .parent(1)
                .packing_size(4)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_missing_parent() {
        with_sample(|assembly| {
            let outcome = ClassLayoutBuilder::new()
                .packing_size(4)
                .class_size(16)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_invalid_packing_size() {
        with_sample(|assembly| {
            // 3 is not a power of two.
            let outcome = ClassLayoutBuilder::new()
                .parent(1)
                .packing_size(3)
                .class_size(16)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_excessive_packing_size() {
        with_sample(|assembly| {
            // 256 is a power of two but above the 128-byte ceiling.
            let outcome = ClassLayoutBuilder::new()
                .parent(1)
                .packing_size(256)
                .class_size(16)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_excessive_class_size() {
        with_sample(|assembly| {
            let outcome = ClassLayoutBuilder::new()
                .parent(1)
                .packing_size(4)
                .class_size(0x2000_0000)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_class_layout_builder_maximum_valid_values() {
        with_sample(|assembly| {
            assert_class_layout_row(&add_layout(assembly, 1, 128, 0x1000_0000 - 1));
        });
    }

    #[test]
    fn test_class_layout_builder_all_valid_packing_sizes() {
        with_sample(|assembly| {
            const PACKINGS: [u16; 9] = [0, 1, 2, 4, 8, 16, 32, 64, 128];
            // Each packing gets its own parent row id, starting at 1.
            for (rid, &packing) in (1u32..).zip(PACKINGS.iter()) {
                assert_class_layout_row(&add_layout(assembly, rid, packing, 16));
            }
        });
    }

    #[test]
    fn test_class_layout_builder_realistic_scenarios() {
        with_sample(|assembly| {
            // (parent rid, packing, class size) for a P/Invoke-style struct,
            // a performance-oriented type, a SIMD-style type, and a default managed type.
            let plan: [(u32, u16, u32); 4] = [(1, 1, 32), (2, 64, 128), (3, 16, 64), (4, 0, 0)];
            let layouts: Vec<ChangeRefRc> = plan
                .iter()
                .map(|&(rid, packing, size)| add_layout(assembly, rid, packing, size))
                .collect();

            layouts.iter().for_each(assert_class_layout_row);

            // Every pair of layouts must be a distinct change reference.
            for (i, earlier) in layouts.iter().enumerate() {
                for later in &layouts[i + 1..] {
                    assert!(!std::sync::Arc::ptr_eq(earlier, later));
                }
            }
        });
    }
}