use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
tables::{CodedIndex, CodedIndexType, MethodSpecRaw, TableDataOwned, TableId},
token::Token,
},
Error, Result,
};
/// Fluent builder that assembles a new row for the `MethodSpec` metadata table.
///
/// Both parts start out unset; [`MethodSpecBuilder::build`] reports an error for
/// whichever one is still missing when it is called.
pub struct MethodSpecBuilder {
    /// Coded index of the generic method being instantiated (expected to be a
    /// `MethodDefOrRef` index; checked at build time).
    method: Option<CodedIndex>,
    /// Raw instantiation signature bytes, handed to `blob_add` at build time.
    instantiation: Option<Vec<u8>>,
}
impl Default for MethodSpecBuilder {
fn default() -> Self {
Self::new()
}
}
impl MethodSpecBuilder {
#[must_use]
pub fn new() -> Self {
Self {
method: None,
instantiation: None,
}
}
#[must_use]
pub fn method(mut self, method: CodedIndex) -> Self {
self.method = Some(method);
self
}
#[must_use]
pub fn instantiation(mut self, instantiation: &[u8]) -> Self {
self.instantiation = Some(instantiation.to_vec());
self
}
#[must_use]
pub fn simple_instantiation(mut self, element_type: u8) -> Self {
let signature = vec![
0x01, element_type, ];
self.instantiation = Some(signature);
self
}
#[must_use]
pub fn multiple_primitives(mut self, element_types: &[u8]) -> Self {
let mut signature = vec![u8::try_from(element_types.len()).unwrap_or(255)]; signature.extend_from_slice(element_types);
self.instantiation = Some(signature);
self
}
#[must_use]
pub fn array_instantiation(mut self, element_type: u8) -> Self {
let signature = vec![
0x01, 0x1D, element_type, ];
self.instantiation = Some(signature);
self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
let method = self
.method
.ok_or_else(|| Error::ModificationInvalid("Generic method is required".to_string()))?;
let instantiation = self.instantiation.ok_or_else(|| {
Error::ModificationInvalid("Instantiation signature is required".to_string())
})?;
if instantiation.is_empty() {
return Err(Error::ModificationInvalid(
"Instantiation signature cannot be empty".to_string(),
));
}
let valid_method_tables = CodedIndexType::MethodDefOrRef.tables();
if !valid_method_tables.contains(&method.tag) {
return Err(Error::ModificationInvalid(format!(
"Method must be a MethodDefOrRef coded index (MethodDef/MemberRef), got {:?}",
method.tag
)));
}
if instantiation.is_empty() {
return Err(Error::ModificationInvalid(
"Instantiation signature must contain at least the generic argument count"
.to_string(),
));
}
let arg_count = instantiation[0];
if arg_count == 0 {
return Err(Error::ModificationInvalid(
"Generic argument count cannot be zero".to_string(),
));
}
let instantiation_index = assembly.blob_add(&instantiation)?.placeholder();
let rid = assembly.next_rid(TableId::MethodSpec)?;
let token_value = ((TableId::MethodSpec as u32) << 24) | rid;
let token = Token::new(token_value);
let method_spec_raw = MethodSpecRaw {
rid,
token,
offset: 0, method,
instantiation: instantiation_index,
};
assembly.table_row_add(
TableId::MethodSpec,
TableDataOwned::MethodSpec(method_spec_raw),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{ChangeRefKind, CilAssembly},
        metadata::cilassemblyview::CilAssemblyView,
    };
    use std::path::PathBuf;

    /// Runs `test` against an editable copy of the `WindowsBase.dll` sample.
    ///
    /// If the sample cannot be loaded, `test` is not invoked and the calling test
    /// passes without asserting anything.
    fn with_sample(test: impl FnOnce(&mut CilAssembly)) {
        let sample = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        if let Ok(view) = CilAssemblyView::from_path(&sample) {
            let mut assembly = CilAssembly::new(view);
            test(&mut assembly);
        }
    }

    /// Asserts that `change` refers to a row of the `MethodSpec` table.
    fn assert_method_spec_row(change: &ChangeRefRc) {
        assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodSpec));
    }

    #[test]
    fn test_method_spec_builder_basic() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let blob: [u8; 2] = [0x01, 0x08];
            let change = MethodSpecBuilder::new()
                .method(target)
                .instantiation(&blob)
                .build(assembly)
                .unwrap();
            assert_method_spec_row(&change);
        });
    }

    #[test]
    fn test_method_spec_builder_different_methods() {
        with_sample(|assembly| {
            let blob: [u8; 2] = [0x01, 0x08];
            let def_target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let from_def = MethodSpecBuilder::new()
                .method(def_target)
                .instantiation(&blob)
                .build(assembly)
                .unwrap();
            let ref_target = CodedIndex::new(TableId::MemberRef, 1, CodedIndexType::MethodDefOrRef);
            let from_ref = MethodSpecBuilder::new()
                .method(ref_target)
                .instantiation(&blob)
                .build(assembly)
                .unwrap();
            assert_method_spec_row(&from_def);
            assert_method_spec_row(&from_ref);
            assert!(!std::sync::Arc::ptr_eq(&from_def, &from_ref));
        });
    }

    #[test]
    fn test_method_spec_builder_convenience_methods() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let single = MethodSpecBuilder::new()
                .method(target.clone())
                .simple_instantiation(0x08)
                .build(assembly)
                .unwrap();
            let pair = MethodSpecBuilder::new()
                .method(target.clone())
                .multiple_primitives(&[0x08, 0x0E])
                .build(assembly)
                .unwrap();
            let array = MethodSpecBuilder::new()
                .method(target)
                .array_instantiation(0x08)
                .build(assembly)
                .unwrap();
            for change in [&single, &pair, &array] {
                assert_method_spec_row(change);
            }
        });
    }

    #[test]
    fn test_method_spec_builder_complex_instantiations() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MemberRef, 1, CodedIndexType::MethodDefOrRef);
            let blob: [u8; 5] = [0x03, 0x08, 0x0E, 0x1D, 0x08];
            let change = MethodSpecBuilder::new()
                .method(target)
                .instantiation(&blob)
                .build(assembly)
                .unwrap();
            assert_method_spec_row(&change);
        });
    }

    #[test]
    fn test_method_spec_builder_missing_method() {
        with_sample(|assembly| {
            let blob: [u8; 2] = [0x01, 0x08];
            let outcome = MethodSpecBuilder::new().instantiation(&blob).build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_method_spec_builder_missing_instantiation() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let outcome = MethodSpecBuilder::new().method(target).build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_method_spec_builder_empty_instantiation() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let blob: [u8; 0] = [];
            let outcome = MethodSpecBuilder::new()
                .method(target)
                .instantiation(&blob)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_method_spec_builder_invalid_method_type() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::Field, 1, CodedIndexType::MethodDefOrRef);
            let blob: [u8; 2] = [0x01, 0x08];
            let outcome = MethodSpecBuilder::new()
                .method(target)
                .instantiation(&blob)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_method_spec_builder_zero_generic_args() {
        with_sample(|assembly| {
            let target = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let blob: [u8; 1] = [0x00];
            let outcome = MethodSpecBuilder::new()
                .method(target)
                .instantiation(&blob)
                .build(assembly);
            assert!(outcome.is_err());
        });
    }

    #[test]
    fn test_method_spec_builder_realistic_scenarios() {
        with_sample(|assembly| {
            let list_add = CodedIndex::new(TableId::MethodDef, 1, CodedIndexType::MethodDefOrRef);
            let list_spec = MethodSpecBuilder::new()
                .method(list_add)
                .simple_instantiation(0x08)
                .build(assembly)
                .unwrap();
            let dict_tryget =
                CodedIndex::new(TableId::MemberRef, 1, CodedIndexType::MethodDefOrRef);
            let dict_spec = MethodSpecBuilder::new()
                .method(dict_tryget)
                .multiple_primitives(&[0x0E, 0x08])
                .build(assembly)
                .unwrap();
            let array_method =
                CodedIndex::new(TableId::MethodDef, 2, CodedIndexType::MethodDefOrRef);
            let array_spec = MethodSpecBuilder::new()
                .method(array_method)
                .array_instantiation(0x0E)
                .build(assembly)
                .unwrap();
            for change in [&list_spec, &dict_spec, &array_spec] {
                assert_method_spec_row(change);
            }
            assert!(!std::sync::Arc::ptr_eq(&list_spec, &dict_spec));
            assert!(!std::sync::Arc::ptr_eq(&list_spec, &array_spec));
            assert!(!std::sync::Arc::ptr_eq(&dict_spec, &array_spec));
        });
    }
}