use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
tables::{CodedIndex, CodedIndexType, MethodImplRaw, TableDataOwned, TableId},
token::Token,
},
Error, Result,
};
/// Selects the metadata table that a method reference points into.
///
/// Each variant carries the row index within its table. The builder's `build`
/// step converts the value into a `MethodDefOrRef` coded index.
#[derive(Clone, Copy, Debug)]
enum MethodRefTarget {
    /// Row index into the `MethodDef` table.
    MethodDef(u32),
    /// Row index into the `MemberRef` table.
    MemberRef(u32),
}
/// Fluent builder that gathers the column values of a new `MethodImpl` row.
///
/// All three fields start out unset; `build` rejects the builder with an
/// error if any of them is still `None`.
pub struct MethodImplBuilder {
    /// Row index stored in the `class` column of the new row
    /// (presumably a `TypeDef` row; confirm against the table schema).
    class: Option<u32>,
    /// Method supplying the implementation, from `MethodDef` or `MemberRef`.
    method_body: Option<MethodRefTarget>,
    /// Method declaration being implemented, from `MethodDef` or `MemberRef`.
    method_declaration: Option<MethodRefTarget>,
}
impl Default for MethodImplBuilder {
fn default() -> Self {
Self::new()
}
}
impl MethodImplBuilder {
#[must_use]
pub fn new() -> Self {
Self {
class: None,
method_body: None,
method_declaration: None,
}
}
#[must_use]
pub fn class(mut self, row: u32) -> Self {
self.class = Some(row);
self
}
#[must_use]
pub fn method_body_from_method_def(mut self, row: u32) -> Self {
self.method_body = Some(MethodRefTarget::MethodDef(row));
self
}
#[must_use]
pub fn method_body_from_member_ref(mut self, row: u32) -> Self {
self.method_body = Some(MethodRefTarget::MemberRef(row));
self
}
#[must_use]
pub fn method_declaration_from_method_def(mut self, row: u32) -> Self {
self.method_declaration = Some(MethodRefTarget::MethodDef(row));
self
}
#[must_use]
pub fn method_declaration_from_member_ref(mut self, row: u32) -> Self {
self.method_declaration = Some(MethodRefTarget::MemberRef(row));
self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
let class_rid = self.class.ok_or_else(|| {
Error::ModificationInvalid("MethodImplBuilder requires a class row index".to_string())
})?;
let method_body_target = self.method_body.ok_or_else(|| {
Error::ModificationInvalid("MethodImplBuilder requires a method body".to_string())
})?;
let method_declaration_target = self.method_declaration.ok_or_else(|| {
Error::ModificationInvalid(
"MethodImplBuilder requires a method declaration".to_string(),
)
})?;
let method_body = match method_body_target {
MethodRefTarget::MethodDef(row) => {
CodedIndex::new(TableId::MethodDef, row, CodedIndexType::MethodDefOrRef)
}
MethodRefTarget::MemberRef(row) => {
CodedIndex::new(TableId::MemberRef, row, CodedIndexType::MethodDefOrRef)
}
};
let method_declaration = match method_declaration_target {
MethodRefTarget::MethodDef(row) => {
CodedIndex::new(TableId::MethodDef, row, CodedIndexType::MethodDefOrRef)
}
MethodRefTarget::MemberRef(row) => {
CodedIndex::new(TableId::MemberRef, row, CodedIndexType::MethodDefOrRef)
}
};
let next_rid = assembly.next_rid(TableId::MethodImpl)?;
let token = Token::new(((TableId::MethodImpl as u32) << 24) | next_rid);
let method_impl_raw = MethodImplRaw {
rid: next_rid,
token,
offset: 0, class: class_rid,
method_body,
method_declaration,
};
assembly.table_row_add(
TableId::MethodImpl,
TableDataOwned::MethodImpl(method_impl_raw),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{ChangeRefKind, CilAssembly},
        metadata::cilassemblyview::CilAssemblyView,
    };
    use std::path::PathBuf;

    /// Loads the `WindowsBase.dll` sample into a mutable assembly.
    ///
    /// Returns `None` when the sample cannot be loaded, in which case the
    /// calling test finishes without asserting anything.
    fn sample_assembly() -> Option<CilAssembly> {
        let sample = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        let view = CilAssemblyView::from_path(&sample).ok()?;
        Some(CilAssembly::new(view))
    }

    /// A builder from `new` has nothing set.
    #[test]
    fn test_methodimpl_builder_creation() {
        let fresh = MethodImplBuilder::new();
        assert!(fresh.class.is_none());
        assert!(fresh.method_body.is_none());
        assert!(fresh.method_declaration.is_none());
    }

    /// A builder from `default` has nothing set.
    #[test]
    fn test_methodimpl_builder_default() {
        let fresh = MethodImplBuilder::default();
        assert!(fresh.class.is_none());
        assert!(fresh.method_body.is_none());
        assert!(fresh.method_declaration.is_none());
    }

    /// `MethodDef` body paired with a `MemberRef` declaration.
    #[test]
    fn test_interface_implementation() {
        if let Some(mut assembly) = sample_assembly() {
            let _next_rid = assembly.next_rid(TableId::MethodImpl).unwrap();

            let class_row = 1;
            let body_row = 1;
            let declaration_row = 1;

            let change = MethodImplBuilder::new()
                .class(class_row)
                .method_body_from_method_def(body_row)
                .method_declaration_from_member_ref(declaration_row)
                .build(&mut assembly)
                .expect("Should build MethodImpl");

            assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
        }
    }

    /// Body and declaration both come from `MethodDef`.
    #[test]
    fn test_virtual_method_override() {
        if let Some(mut assembly) = sample_assembly() {
            let _next_rid = assembly.next_rid(TableId::MethodImpl).unwrap();

            let class_row = 2;
            let body_row = 2;
            let declaration_row = 3;

            let change = MethodImplBuilder::new()
                .class(class_row)
                .method_body_from_method_def(body_row)
                .method_declaration_from_method_def(declaration_row)
                .build(&mut assembly)
                .expect("Should build virtual override MethodImpl");

            assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
        }
    }

    /// `MethodDef` body with a `MemberRef` declaration, using different rows.
    #[test]
    fn test_explicit_interface_implementation() {
        if let Some(mut assembly) = sample_assembly() {
            let _next_rid = assembly.next_rid(TableId::MethodImpl).unwrap();

            let class_row = 3;
            let body_row = 4;
            let declaration_row = 2;

            let change = MethodImplBuilder::new()
                .class(class_row)
                .method_body_from_method_def(body_row)
                .method_declaration_from_member_ref(declaration_row)
                .build(&mut assembly)
                .expect("Should build explicit interface MethodImpl");

            assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
        }
    }

    /// Body and declaration both come from `MemberRef`.
    #[test]
    fn test_external_method_body() {
        if let Some(mut assembly) = sample_assembly() {
            let _next_rid = assembly.next_rid(TableId::MethodImpl).unwrap();

            let class_row = 1;
            let body_row = 3;
            let declaration_row = 4;

            let change = MethodImplBuilder::new()
                .class(class_row)
                .method_body_from_member_ref(body_row)
                .method_declaration_from_member_ref(declaration_row)
                .build(&mut assembly)
                .expect("Should build external method MethodImpl");

            assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
        }
    }

    /// Mixed tables with literal row indices.
    #[test]
    fn test_mixed_method_refs() {
        if let Some(mut assembly) = sample_assembly() {
            let _next_rid = assembly.next_rid(TableId::MethodImpl).unwrap();

            let class_row = 1;

            let change = MethodImplBuilder::new()
                .class(class_row)
                .method_body_from_method_def(1)
                .method_declaration_from_member_ref(1)
                .build(&mut assembly)
                .expect("Should build mixed method ref MethodImpl");

            assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
        }
    }

    /// Omitting the class yields the class-specific error.
    #[test]
    fn test_build_without_class_fails() {
        if let Some(mut assembly) = sample_assembly() {
            let outcome = MethodImplBuilder::new()
                .method_body_from_method_def(1)
                .method_declaration_from_member_ref(1)
                .build(&mut assembly);

            assert!(outcome.is_err());
            let message = outcome.unwrap_err().to_string();
            assert!(message.contains("requires a class row index"));
        }
    }

    /// Omitting the method body yields the body-specific error.
    #[test]
    fn test_build_without_method_body_fails() {
        if let Some(mut assembly) = sample_assembly() {
            let outcome = MethodImplBuilder::new()
                .class(1)
                .method_declaration_from_member_ref(1)
                .build(&mut assembly);

            assert!(outcome.is_err());
            let message = outcome.unwrap_err().to_string();
            assert!(message.contains("requires a method body"));
        }
    }

    /// Omitting the method declaration yields the declaration-specific error.
    #[test]
    fn test_build_without_method_declaration_fails() {
        if let Some(mut assembly) = sample_assembly() {
            let outcome = MethodImplBuilder::new()
                .class(1)
                .method_body_from_method_def(1)
                .build(&mut assembly);

            assert!(outcome.is_err());
            let message = outcome.unwrap_err().to_string();
            assert!(message.contains("requires a method declaration"));
        }
    }

    /// Two builds against the same assembly return distinct change references.
    #[test]
    fn test_multiple_method_impls() {
        if let Some(mut assembly) = sample_assembly() {
            let first = MethodImplBuilder::new()
                .class(1)
                .method_body_from_method_def(1)
                .method_declaration_from_member_ref(1)
                .build(&mut assembly)
                .expect("Should build first MethodImpl");

            let second = MethodImplBuilder::new()
                .class(1)
                .method_body_from_method_def(2)
                .method_declaration_from_member_ref(2)
                .build(&mut assembly)
                .expect("Should build second MethodImpl");

            assert_eq!(first.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
            assert_eq!(second.kind(), ChangeRefKind::TableRow(TableId::MethodImpl));
            assert!(!std::sync::Arc::ptr_eq(&first, &second));
        }
    }
}