use crate::{
cilassembly::{ChangeRefRc, CilAssembly},
metadata::{
method::{MethodAccessFlags, MethodImplCodeType, MethodModifiers},
signatures::{encode_method_signature, SignatureMethod, SignatureParameter, TypeSignature},
tables::{MethodDefBuilder, ParamAttributes, ParamBuilder, TableId},
token::Token,
},
Result,
};
use super::method_body::MethodBodyBuilder;
/// Fluent builder that collects everything needed to emit a method definition
/// (name, flags, signature shape, parameters and an optional body) and writes
/// it into a `CilAssembly` when `build` is called.
pub struct MethodBuilder {
/// Method name stored in the method definition row.
name: String,
/// Accessibility flags; each accessibility setter overwrites this value.
access_flags: MethodAccessFlags,
/// Additional modifier flags (static, virtual, abstract, ...); setters OR
/// their bit into this set.
modifiers: MethodModifiers,
/// Implementation flags; initialised to `MethodImplCodeType::IL` by `new`.
impl_flags: MethodImplCodeType,
/// Return type placed in the encoded signature.
return_type: TypeSignature,
/// Parameters as `(name, type)` pairs, in declaration order.
parameters: Vec<(String, TypeSignature)>,
/// Optional body builder; `None` means no body is emitted and the RVA is 0.
body_builder: Option<MethodBodyBuilder>,
/// Copied into the signature's `has_this`; cleared by `static_method`.
has_this: bool,
/// Copied into the signature's `explicit_this`.
explicit_this: bool,
// The calling-convention flags below are copied into the signature as-is.
// Every calling-convention setter clears all six before setting its own,
// so at most one is true when only the setters are used.
/// Default calling convention (the initial state).
default_calling_convention: bool,
/// Vararg calling convention.
vararg: bool,
/// `cdecl` calling convention.
cdecl: bool,
/// `stdcall` calling convention.
stdcall: bool,
/// `thiscall` calling convention.
thiscall: bool,
/// `fastcall` calling convention.
fastcall: bool,
}
impl MethodBuilder {
#[must_use]
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
access_flags: MethodAccessFlags::PRIVATE, modifiers: MethodModifiers::empty(),
impl_flags: MethodImplCodeType::IL,
return_type: TypeSignature::Void,
parameters: Vec::new(),
body_builder: None,
has_this: true, explicit_this: false,
default_calling_convention: true, vararg: false,
cdecl: false,
stdcall: false,
thiscall: false,
fastcall: false,
}
}
#[must_use]
pub fn constructor() -> Self {
Self::new(".ctor").public().special_name().rtspecial_name()
}
#[must_use]
pub fn static_constructor() -> Self {
Self::new(".cctor")
.private()
.static_method()
.special_name()
.rtspecial_name()
}
#[must_use]
pub fn property_getter(property_name: &str, return_type: TypeSignature) -> Self {
Self::new(&format!("get_{property_name}"))
.public()
.special_name()
.returns(return_type)
}
#[must_use]
pub fn property_setter(property_name: &str, value_type: TypeSignature) -> Self {
Self::new(&format!("set_{property_name}"))
.public()
.special_name()
.parameter("value", value_type)
}
#[must_use]
pub fn event_add(event_name: &str, delegate_type: TypeSignature) -> Self {
Self::new(&format!("add_{event_name}"))
.public()
.special_name()
.parameter("value", delegate_type)
}
#[must_use]
pub fn event_remove(event_name: &str, delegate_type: TypeSignature) -> Self {
Self::new(&format!("remove_{event_name}"))
.public()
.special_name()
.parameter("value", delegate_type)
}
#[must_use]
pub fn public(mut self) -> Self {
self.access_flags = MethodAccessFlags::PUBLIC;
self
}
#[must_use]
pub fn private(mut self) -> Self {
self.access_flags = MethodAccessFlags::PRIVATE;
self
}
#[must_use]
pub fn protected(mut self) -> Self {
self.access_flags = MethodAccessFlags::FAMILY;
self
}
#[must_use]
pub fn internal(mut self) -> Self {
self.access_flags = MethodAccessFlags::ASSEMBLY;
self
}
#[must_use]
pub fn static_method(mut self) -> Self {
self.modifiers |= MethodModifiers::STATIC;
self.has_this = false;
self
}
#[must_use]
pub fn virtual_method(mut self) -> Self {
self.modifiers |= MethodModifiers::VIRTUAL;
self
}
#[must_use]
pub fn abstract_method(mut self) -> Self {
self.modifiers |= MethodModifiers::ABSTRACT;
self
}
#[must_use]
pub fn sealed(mut self) -> Self {
self.modifiers |= MethodModifiers::FINAL;
self
}
#[must_use]
pub fn special_name(mut self) -> Self {
self.modifiers |= MethodModifiers::SPECIAL_NAME;
self
}
#[must_use]
pub fn rtspecial_name(mut self) -> Self {
self.modifiers |= MethodModifiers::RTSPECIAL_NAME;
self
}
#[must_use]
pub fn calling_convention_default(mut self) -> Self {
self.clear_calling_conventions();
self.default_calling_convention = true;
self
}
#[must_use]
pub fn calling_convention_vararg(mut self) -> Self {
self.clear_calling_conventions();
self.vararg = true;
self
}
#[must_use]
pub fn calling_convention_cdecl(mut self) -> Self {
self.clear_calling_conventions();
self.cdecl = true;
self
}
#[must_use]
pub fn calling_convention_stdcall(mut self) -> Self {
self.clear_calling_conventions();
self.stdcall = true;
self
}
#[must_use]
pub fn calling_convention_thiscall(mut self) -> Self {
self.clear_calling_conventions();
self.thiscall = true;
self
}
#[must_use]
pub fn calling_convention_fastcall(mut self) -> Self {
self.clear_calling_conventions();
self.fastcall = true;
self
}
#[must_use]
pub fn explicit_this(mut self) -> Self {
self.explicit_this = true;
self
}
#[must_use]
pub fn returns(mut self, return_type: TypeSignature) -> Self {
self.return_type = return_type;
self
}
#[must_use]
pub fn parameter(mut self, name: &str, param_type: TypeSignature) -> Self {
self.parameters.push((name.to_string(), param_type));
self
}
#[must_use]
pub fn implementation<F>(mut self, f: F) -> Self
where
F: FnOnce(MethodBodyBuilder) -> MethodBodyBuilder,
{
let body_builder = f(MethodBodyBuilder::new());
self.body_builder = Some(body_builder);
self
}
#[must_use]
pub fn extern_method(mut self) -> Self {
self.body_builder = None; self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<ChangeRefRc> {
let return_type = self.return_type.clone();
let parameters = self.parameters.clone();
let has_this = self.has_this;
let signature = SignatureMethod {
has_this,
explicit_this: self.explicit_this,
default: self.default_calling_convention,
vararg: self.vararg,
cdecl: self.cdecl,
stdcall: self.stdcall,
thiscall: self.thiscall,
fastcall: self.fastcall,
param_count_generic: 0,
param_count: u32::try_from(parameters.len())
.map_err(|_| malformed_error!("Method parameter count exceeds u32 range"))?,
return_type: SignatureParameter {
modifiers: Vec::new(),
by_ref: false,
base: return_type.clone(),
},
params: parameters
.iter()
.map(|(_, param_type)| SignatureParameter {
modifiers: Vec::new(),
by_ref: false,
base: param_type.clone(),
})
.collect(),
varargs: Vec::new(),
};
let signature_bytes = encode_method_signature(&signature)?;
let (rva, _local_sig_token) = if let Some(body_builder) = self.body_builder {
let (body_bytes, local_sig_token) = body_builder.build(assembly)?;
let placeholder_rva = assembly.store_method_body(body_bytes);
(placeholder_rva, local_sig_token)
} else {
(0u32, Token::new(0))
};
let combined_flags = self.access_flags.bits() | self.modifiers.bits();
let param_start_index = assembly.next_rid(TableId::Param)?;
for (sequence, (name, _param_type)) in parameters.iter().enumerate() {
let param_sequence = u32::try_from(sequence + 1)
.map_err(|_| malformed_error!("Parameter sequence exceeds u32 range"))?;
ParamBuilder::new()
.name(name)
.flags(ParamAttributes::IN) .sequence(param_sequence)
.build(assembly)?;
}
let method_token = MethodDefBuilder::new()
.name(&self.name)
.flags(combined_flags)
.impl_flags(self.impl_flags.bits())
.signature(&signature_bytes)
.rva(rva)
.param_list(param_start_index) .build(assembly)?;
Ok(method_token)
}
fn clear_calling_conventions(&mut self) {
self.default_calling_convention = false;
self.vararg = false;
self.cdecl = false;
self.stdcall = false;
self.thiscall = false;
self.fastcall = false;
}
}
impl Default for MethodBuilder {
fn default() -> Self {
Self::new("DefaultMethod")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::{ChangeRefKind, CilAssembly},
        metadata::{cilassemblyview::CilAssemblyView, signatures::TypeSignature, tables::TableId},
    };
    use std::path::PathBuf;

    /// Loads the `WindowsBase.dll` sample and wraps it in a mutable assembly.
    fn get_test_assembly() -> Result<CilAssembly> {
        let sample =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        let view = CilAssemblyView::from_path(&sample)?;
        Ok(CilAssembly::new(view))
    }

    /// Asserts that `change` refers to a row of the `MethodDef` table.
    fn assert_method_def_row(change: &ChangeRefRc) {
        assert_eq!(change.kind(), ChangeRefKind::TableRow(TableId::MethodDef));
    }

    /// A public static `void` method with a trivial body builds successfully.
    #[test]
    fn test_method_builder_basic() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("TestMethod")
            .public()
            .static_method()
            .returns(TypeSignature::Void)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.nop()?;
                    il.ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// A static method with two `I4` parameters and an `I4` result builds.
    #[test]
    fn test_method_builder_with_parameters() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("Add")
            .public()
            .static_method()
            .parameter("a", TypeSignature::I4)
            .parameter("b", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?.ldarg_1()?.add()?.ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The instance-constructor preset accepts a parameter and a body.
    #[test]
    fn test_constructor_builder() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::constructor()
            .parameter("name", TypeSignature::String)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?.call(Token::new(0x0A000001))?.ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The property-getter preset builds with a field-loading body.
    #[test]
    fn test_property_getter() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::property_getter("Name", TypeSignature::String)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?.ldfld(Token::new(0x04000001))?.ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The property-setter preset builds with a field-storing body.
    #[test]
    fn test_property_setter() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::property_setter("Name", TypeSignature::String)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?
                        .ldarg_1()?
                        .stfld(Token::new(0x04000001))?
                        .ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// A public abstract virtual method without a body builds.
    #[test]
    fn test_abstract_method() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("AbstractMethod")
            .public()
            .abstract_method()
            .virtual_method()
            .returns(TypeSignature::I4)
            .extern_method()
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The static-constructor preset builds with a static-field-storing body.
    #[test]
    fn test_static_constructor() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::static_constructor()
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldc_i4_const(42)?
                        .stsfld(Token::new(0x04000001))?
                        .ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// A method whose body declares two locals builds.
    #[test]
    fn test_method_with_locals() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("ComplexMethod")
            .public()
            .static_method()
            .parameter("input", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .implementation(|method_body| {
                method_body
                    .local("temp", TypeSignature::I4)
                    .local("result", TypeSignature::I4)
                    .implementation(|il| {
                        il.ldarg_0()?
                            .stloc_0()?
                            .ldloc_0()?
                            .ldc_i4_1()?
                            .add()?
                            .stloc_1()?
                            .ldloc_1()?
                            .ret()?;
                        Ok(())
                    })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// `cdecl`, `stdcall` and default conventions all build on one assembly.
    #[test]
    fn test_method_builder_calling_conventions() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let with_cdecl = MethodBuilder::new("CdeclMethod")
            .public()
            .static_method()
            .calling_convention_cdecl()
            .parameter("x", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .extern_method()
            .build(&mut assembly)?;
        assert_method_def_row(&with_cdecl);

        let with_stdcall = MethodBuilder::new("StdcallMethod")
            .public()
            .static_method()
            .calling_convention_stdcall()
            .parameter("x", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .extern_method()
            .build(&mut assembly)?;
        assert_method_def_row(&with_stdcall);

        let with_default = MethodBuilder::new("DefaultMethod")
            .public()
            .static_method()
            .calling_convention_default()
            .parameter("x", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?.ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;
        assert_method_def_row(&with_default);

        Ok(())
    }

    /// A bodiless method using the vararg convention builds.
    #[test]
    fn test_method_builder_vararg_calling_convention() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("VarargMethod")
            .public()
            .static_method()
            .calling_convention_vararg()
            .parameter("format", TypeSignature::String)
            .returns(TypeSignature::Void)
            .extern_method()
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// An instance method flagged `explicit_this` builds.
    #[test]
    fn test_method_builder_explicit_this() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("ExplicitThisMethod")
            .public()
            .explicit_this()
            .parameter("value", TypeSignature::I4)
            .returns(TypeSignature::Void)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?
                        .ldarg_1()?
                        .stfld(Token::new(0x04000001))?
                        .ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// Choosing one calling convention after another still builds.
    #[test]
    fn test_method_builder_calling_convention_switching() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::new("SwitchingMethod")
            .public()
            .static_method()
            .calling_convention_cdecl()
            .calling_convention_stdcall()
            .parameter("x", TypeSignature::I4)
            .returns(TypeSignature::I4)
            .extern_method()
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The event-add preset builds with a body.
    #[test]
    fn test_event_add_method() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::event_add("OnClick", TypeSignature::Object)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?
                        .ldfld(Token::new(0x04000001))?
                        .ldarg_1()?
                        .call(Token::new(0x0A000001))?
                        .stfld(Token::new(0x04000001))?
                        .ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }

    /// The event-remove preset builds with a body.
    #[test]
    fn test_event_remove_method() -> Result<()> {
        let mut assembly = get_test_assembly()?;

        let built = MethodBuilder::event_remove("OnClick", TypeSignature::Object)
            .implementation(|method_body| {
                method_body.implementation(|il| {
                    il.ldarg_0()?
                        .ldfld(Token::new(0x04000001))?
                        .ldarg_1()?
                        .call(Token::new(0x0A000002))?
                        .stfld(Token::new(0x04000001))?
                        .ret()?;
                    Ok(())
                })
            })
            .build(&mut assembly)?;

        assert_method_def_row(&built);
        Ok(())
    }
}