use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
tables::{CodedIndex, CodedIndexType, ConstantRaw, TableDataOwned, TableId},
token::Token,
typesystem::ELEMENT_TYPE,
},
Error, Result,
};
/// Fluent builder that assembles a new row for the `Constant` metadata table.
///
/// Every field starts out unset; all three must be supplied (directly or through one
/// of the typed convenience setters) before `build` will succeed.
pub struct ConstantBuilder {
    /// Owner of the constant: a `HasConstant` coded index (Field, Param, or Property).
    parent: Option<CodedIndex>,
    /// `ELEMENT_TYPE` byte written to the row's type column.
    element_type: Option<u8>,
    /// Encoded value bytes that `build` places on the blob heap.
    value: Option<Vec<u8>>,
}
impl Default for ConstantBuilder {
fn default() -> Self {
Self::new()
}
}
impl ConstantBuilder {
#[must_use]
pub fn new() -> Self {
Self {
element_type: None,
parent: None,
value: None,
}
}
#[must_use]
pub fn element_type(mut self, element_type: u8) -> Self {
self.element_type = Some(element_type);
self
}
#[must_use]
pub fn parent(mut self, parent: CodedIndex) -> Self {
self.parent = Some(parent);
self
}
#[must_use]
pub fn value(mut self, value: &[u8]) -> Self {
self.value = Some(value.to_vec());
self
}
#[must_use]
pub fn string_value(mut self, string_value: &str) -> Self {
let utf16_bytes: Vec<u8> = string_value
.encode_utf16()
.flat_map(u16::to_le_bytes)
.collect();
self.element_type = Some(ELEMENT_TYPE::STRING);
self.value = Some(utf16_bytes);
self
}
#[must_use]
pub fn i4_value(mut self, int_value: i32) -> Self {
self.element_type = Some(ELEMENT_TYPE::I4);
self.value = Some(int_value.to_le_bytes().to_vec());
self
}
#[must_use]
pub fn boolean_value(mut self, bool_value: bool) -> Self {
self.element_type = Some(ELEMENT_TYPE::BOOLEAN);
self.value = Some(vec![u8::from(bool_value)]);
self
}
#[must_use]
pub fn null_reference_value(mut self) -> Self {
self.element_type = Some(ELEMENT_TYPE::CLASS);
self.value = Some(vec![0, 0, 0, 0]); self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
let element_type = self.element_type.ok_or_else(|| {
Error::ModificationInvalid("Constant element type is required".to_string())
})?;
let parent = self
.parent
.ok_or_else(|| Error::ModificationInvalid("Constant parent is required".to_string()))?;
let value = self
.value
.ok_or_else(|| Error::ModificationInvalid("Constant value is required".to_string()))?;
if value.is_empty() && element_type != ELEMENT_TYPE::CLASS {
return Err(Error::ModificationInvalid(
"Constant value cannot be empty (except for null references)".to_string(),
));
}
let valid_parent_tables = CodedIndexType::HasConstant.tables();
if !valid_parent_tables.contains(&parent.tag) {
return Err(Error::ModificationInvalid(format!(
"Parent must be a HasConstant coded index (Field/Param/Property), got {:?}",
parent.tag
)));
}
match element_type {
ELEMENT_TYPE::BOOLEAN
| ELEMENT_TYPE::CHAR
| ELEMENT_TYPE::I1
| ELEMENT_TYPE::U1
| ELEMENT_TYPE::I2
| ELEMENT_TYPE::U2
| ELEMENT_TYPE::I4
| ELEMENT_TYPE::U4
| ELEMENT_TYPE::I8
| ELEMENT_TYPE::U8
| ELEMENT_TYPE::R4
| ELEMENT_TYPE::R8
| ELEMENT_TYPE::STRING
| ELEMENT_TYPE::CLASS => {
}
_ => {
return Err(Error::ModificationInvalid(format!(
"Invalid element type for constant: 0x{element_type:02X}. Only primitive types, strings, and null references are allowed"
)));
}
}
let value_index = if value.is_empty() {
0 } else {
assembly.blob_add(&value)?.placeholder()
};
let rid = assembly.next_rid(TableId::Constant)?;
let token = Token::from_parts(TableId::Constant, rid);
let constant_raw = ConstantRaw {
rid,
token,
offset: 0, base: element_type,
parent,
value: value_index,
};
assembly.table_row_add(TableId::Constant, TableDataOwned::Constant(constant_raw))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{ChangeRefKind, CilAssembly},
        metadata::cilassemblyview::CilAssemblyView,
    };
    use std::path::PathBuf;

    /// Opens the bundled `WindowsBase.dll` sample as a mutable assembly.
    ///
    /// Yields `None` when the sample cannot be loaded; callers then return early, so
    /// each test is silently skipped in that case.
    fn load_sample_assembly() -> Option<CilAssembly> {
        let sample =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        CilAssemblyView::from_path(&sample)
            .ok()
            .map(CilAssembly::new)
    }

    /// Asserts that `change` refers to a freshly added `Constant` table row.
    fn assert_is_constant_row(change: &ChangeRefRc) {
        assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::Constant));
    }

    #[test]
    fn test_constant_builder_basic_integer() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        let created = ConstantBuilder::new()
            .element_type(ELEMENT_TYPE::I4)
            .parent(owner)
            .value(&42i32.to_le_bytes())
            .build(&mut assembly)
            .unwrap();
        assert_is_constant_row(&created);
    }

    #[test]
    fn test_constant_builder_i4_convenience() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        let created = ConstantBuilder::new()
            .parent(owner)
            .i4_value(42)
            .build(&mut assembly)
            .unwrap();
        assert_is_constant_row(&created);
    }

    #[test]
    fn test_constant_builder_boolean() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Param, 1, CodedIndexType::HasConstant);
        let created = ConstantBuilder::new()
            .parent(owner)
            .boolean_value(true)
            .build(&mut assembly)
            .unwrap();
        assert_is_constant_row(&created);
    }

    #[test]
    fn test_constant_builder_string() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Property, 1, CodedIndexType::HasConstant);
        let created = ConstantBuilder::new()
            .parent(owner)
            .string_value("Hello, World!")
            .build(&mut assembly)
            .unwrap();
        assert_is_constant_row(&created);
    }

    #[test]
    fn test_constant_builder_null_reference() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 2, CodedIndexType::HasConstant);
        let created = ConstantBuilder::new()
            .element_type(ELEMENT_TYPE::CLASS)
            .parent(owner)
            .value(&[0u8; 4])
            .build(&mut assembly)
            .unwrap();
        assert_is_constant_row(&created);
    }

    #[test]
    fn test_constant_builder_missing_element_type() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        let outcome = ConstantBuilder::new()
            .parent(owner)
            .value(&42i32.to_le_bytes())
            .build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_builder_missing_parent() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let outcome = ConstantBuilder::new()
            .element_type(ELEMENT_TYPE::I4)
            .value(&42i32.to_le_bytes())
            .build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_builder_missing_value() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        let outcome = ConstantBuilder::new()
            .element_type(ELEMENT_TYPE::I4)
            .parent(owner)
            .build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_builder_invalid_parent_type() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        // TypeDef is not a HasConstant target, so the builder must refuse it.
        let bogus_owner = CodedIndex::new(TableId::TypeDef, 1, CodedIndexType::HasConstant);
        let outcome = ConstantBuilder::new()
            .element_type(ELEMENT_TYPE::I4)
            .parent(bogus_owner)
            .value(&42i32.to_le_bytes())
            .build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_builder_invalid_element_type() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let owner = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        // 0xFF is not a permitted constant element type.
        let outcome = ConstantBuilder::new()
            .element_type(0xFF)
            .parent(owner)
            .value(&42i32.to_le_bytes())
            .build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_builder_multiple_constants() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        let first_field = CodedIndex::new(TableId::Field, 1, CodedIndexType::HasConstant);
        let second_field = CodedIndex::new(TableId::Field, 2, CodedIndexType::HasConstant);
        let first_param = CodedIndex::new(TableId::Param, 1, CodedIndexType::HasConstant);
        let first_property = CodedIndex::new(TableId::Property, 1, CodedIndexType::HasConstant);

        let created = [
            ConstantBuilder::new()
                .parent(first_field)
                .i4_value(42)
                .build(&mut assembly)
                .unwrap(),
            ConstantBuilder::new()
                .parent(second_field)
                .boolean_value(true)
                .build(&mut assembly)
                .unwrap(),
            ConstantBuilder::new()
                .parent(first_param)
                .string_value("default value")
                .build(&mut assembly)
                .unwrap(),
            ConstantBuilder::new()
                .element_type(ELEMENT_TYPE::R8)
                .parent(first_property)
                .value(&std::f64::consts::PI.to_le_bytes())
                .build(&mut assembly)
                .unwrap(),
        ];

        // Every build must hand back a distinct change reference.
        for (idx, left) in created.iter().enumerate() {
            for right in &created[idx + 1..] {
                assert!(!std::sync::Arc::ptr_eq(left, right));
            }
        }
        for change in &created {
            assert_is_constant_row(change);
        }
    }

    #[test]
    fn test_constant_builder_all_primitive_types() {
        let Some(mut assembly) = load_sample_assembly() else {
            return;
        };
        // One (element type, encoded bytes) pair per primitive, attached to fields 1..=12.
        let cases: [(u8, Vec<u8>); 12] = [
            (ELEMENT_TYPE::BOOLEAN, vec![1u8]),
            (ELEMENT_TYPE::CHAR, ('A' as u16).to_le_bytes().to_vec()),
            (ELEMENT_TYPE::I1, (-42i8).to_le_bytes().to_vec()),
            (ELEMENT_TYPE::I2, (-1000i16).to_le_bytes().to_vec()),
            (ELEMENT_TYPE::I4, (-100_000i32).to_le_bytes().to_vec()),
            (ELEMENT_TYPE::I8, (-1_000_000_000_000i64).to_le_bytes().to_vec()),
            (ELEMENT_TYPE::U1, u8::MAX.to_le_bytes().to_vec()),
            (ELEMENT_TYPE::U2, u16::MAX.to_le_bytes().to_vec()),
            (ELEMENT_TYPE::U4, u32::MAX.to_le_bytes().to_vec()),
            (ELEMENT_TYPE::U8, u64::MAX.to_le_bytes().to_vec()),
            (ELEMENT_TYPE::R4, std::f32::consts::PI.to_le_bytes().to_vec()),
            (ELEMENT_TYPE::R8, std::f64::consts::E.to_le_bytes().to_vec()),
        ];
        for (row, (kind, bytes)) in (1..=12).zip(cases) {
            ConstantBuilder::new()
                .element_type(kind)
                .parent(CodedIndex::new(TableId::Field, row, CodedIndexType::HasConstant))
                .value(&bytes)
                .build(&mut assembly)
                .unwrap();
        }
    }
}