use crate::{
assembly::InstructionAssembler,
cilassembly::CilAssembly,
metadata::{
method::{encode_exception_handlers, ExceptionHandler, ExceptionHandlerFlags},
signatures::{
encode_local_var_signature, SignatureLocalVariable, SignatureLocalVariables,
TypeSignature,
},
tables::StandAloneSigBuilder,
token::Token,
typesystem::CilTypeRc,
},
Error, Result,
};
/// An exception clause whose region boundaries are given as assembler label names rather than
/// byte offsets.
///
/// `resolve_labeled_exception_handler` converts it into an offset-based `ExceptionHandler` by
/// looking each label up in the `InstructionAssembler` after the implementation closure has run.
#[derive(Clone)]
struct LabeledExceptionHandler {
    /// Clause kind; this file only ever uses `EXCEPTION`, `FAULT`, `FINALLY` or `FILTER`.
    flags: ExceptionHandlerFlags,
    /// Exception type for a typed catch clause; `None` for every other clause kind.
    handler_type: Option<CilTypeRc>,
    /// Label at the first instruction of the protected (try) region.
    try_start_label: String,
    /// Label just past the protected region (exclusive end: length = end - start).
    try_end_label: String,
    /// Label at the first instruction of the filter block; only set for filter clauses.
    filter_start_label: Option<String>,
    /// Label at the first instruction of the handler block.
    handler_start_label: String,
    /// Label just past the handler block (exclusive end: length = end - start).
    handler_end_label: String,
}
/// Deferred code generator for a method body.
///
/// `MethodBodyBuilder::build` calls it exactly once, passing a freshly created
/// [`InstructionAssembler`] into which the closure emits the method's instructions.
type ImplementationFn = Box<dyn FnOnce(&mut InstructionAssembler) -> Result<()> + 'static>;
use crate::metadata::method::encode_method_body_header;
fn resolve_labeled_exception_handler(
assembler: &InstructionAssembler,
labeled_handler: &LabeledExceptionHandler,
) -> Result<ExceptionHandler> {
let try_start_offset = assembler
.get_label_position(&labeled_handler.try_start_label)
.ok_or_else(|| Error::UndefinedLabel(labeled_handler.try_start_label.clone()))?;
let try_end_offset = assembler
.get_label_position(&labeled_handler.try_end_label)
.ok_or_else(|| Error::UndefinedLabel(labeled_handler.try_end_label.clone()))?;
let handler_start_offset = assembler
.get_label_position(&labeled_handler.handler_start_label)
.ok_or_else(|| Error::UndefinedLabel(labeled_handler.handler_start_label.clone()))?;
let handler_end_offset = assembler
.get_label_position(&labeled_handler.handler_end_label)
.ok_or_else(|| Error::UndefinedLabel(labeled_handler.handler_end_label.clone()))?;
if try_end_offset < try_start_offset {
return Err(Error::ModificationInvalid(format!(
"Try end label '{}' (at {}) is before try start label '{}' (at {})",
labeled_handler.try_end_label,
try_end_offset,
labeled_handler.try_start_label,
try_start_offset
)));
}
if handler_end_offset < handler_start_offset {
return Err(Error::ModificationInvalid(format!(
"Handler end label '{}' (at {}) is before handler start label '{}' (at {})",
labeled_handler.handler_end_label,
handler_end_offset,
labeled_handler.handler_start_label,
handler_start_offset
)));
}
let try_length = try_end_offset - try_start_offset;
let handler_length = handler_end_offset - handler_start_offset;
let filter_offset = if let Some(filter_label) = &labeled_handler.filter_start_label {
assembler
.get_label_position(filter_label)
.ok_or_else(|| Error::UndefinedLabel(filter_label.clone()))?
} else {
0
};
Ok(ExceptionHandler {
flags: labeled_handler.flags,
try_offset: try_start_offset,
try_length,
handler_offset: handler_start_offset,
handler_length,
handler: labeled_handler.handler_type.clone(),
filter_offset,
})
}
/// Verifies that every exception clause fits inside a code section of `code_size` bytes.
///
/// For each handler (reported by its index in `handlers`):
/// - the try region and the handler region must not overflow `u32`;
/// - neither region may extend past `code_size`;
/// - for `FILTER` clauses, the filter offset must be strictly less than `code_size`.
///
/// # Errors
///
/// Returns [`Error::ModificationInvalid`] describing the first violation found.
fn validate_exception_handler_ranges(handlers: &[ExceptionHandler], code_size: u32) -> Result<()> {
    handlers
        .iter()
        .enumerate()
        .try_for_each(|(idx, eh)| -> Result<()> {
            let try_end = eh.try_offset.checked_add(eh.try_length).ok_or_else(|| {
                Error::ModificationInvalid(format!(
                    "Exception handler {}: try block range overflow (offset {} + length {})",
                    idx, eh.try_offset, eh.try_length
                ))
            })?;
            if try_end > code_size {
                return Err(Error::ModificationInvalid(format!(
                    "Exception handler {}: try block exceeds code size (offset {} + length {} = {}, code size = {})",
                    idx, eh.try_offset, eh.try_length, try_end, code_size
                )));
            }

            let handler_end = eh
                .handler_offset
                .checked_add(eh.handler_length)
                .ok_or_else(|| {
                    Error::ModificationInvalid(format!(
                        "Exception handler {}: handler block range overflow (offset {} + length {})",
                        idx, eh.handler_offset, eh.handler_length
                    ))
                })?;
            if handler_end > code_size {
                return Err(Error::ModificationInvalid(format!(
                    "Exception handler {}: handler block exceeds code size (offset {} + length {} = {}, code size = {})",
                    idx, eh.handler_offset, eh.handler_length, handler_end, code_size
                )));
            }

            // Only filter clauses carry a meaningful filter offset.
            let is_filter = eh.flags.contains(ExceptionHandlerFlags::FILTER);
            if is_filter && eh.filter_offset >= code_size {
                return Err(Error::ModificationInvalid(format!(
                    "Exception handler {}: filter offset {} exceeds code size {}",
                    idx, eh.filter_offset, code_size
                )));
            }

            Ok(())
        })
}
/// Fluent builder producing a complete CIL method body — header, code and optional
/// exception-handler section — along with the token of its local-variable signature.
///
/// The code comes either from an [`InstructionAssembler`] closure supplied via `implementation`,
/// or from bytecode handed to `from_compilation`.
pub struct MethodBodyBuilder {
    /// Bytecode assembled elsewhere; when present, the `implementation` closure is not run.
    prebuilt_bytecode: Option<Vec<u8>>,
    /// Closure that emits the instructions when there is no prebuilt bytecode.
    implementation: Option<ImplementationFn>,
    /// Explicit max stack. When `None`, `build` uses the assembler-computed depth, or 8 for
    /// prebuilt bytecode.
    max_stack: Option<u16>,
    /// Forwarded to the method header encoder as its init-locals argument.
    init_locals: bool,
    /// Locals declared through `local`, as `(name, type)`; the names are not encoded.
    locals: Vec<(String, TypeSignature)>,
    /// Local signature supplied with prebuilt bytecode; takes precedence over `locals`.
    prebuilt_local_sig: Option<SignatureLocalVariables>,
    /// Exception clauses expressed as byte offsets.
    exception_handlers: Vec<ExceptionHandler>,
    /// Exception clauses expressed as labels, resolved against the assembler during `build`.
    labeled_exception_handlers: Vec<LabeledExceptionHandler>,
}
impl MethodBodyBuilder {
#[must_use]
pub fn new() -> Self {
Self {
max_stack: None,
init_locals: true,
locals: Vec::new(),
implementation: None,
exception_handlers: Vec::new(),
labeled_exception_handlers: Vec::new(),
prebuilt_bytecode: None,
prebuilt_local_sig: None,
}
}
#[must_use]
pub fn from_compilation(
bytecode: Vec<u8>,
max_stack: u16,
locals: Vec<SignatureLocalVariable>,
exception_handlers: Vec<ExceptionHandler>,
) -> Self {
let prebuilt_local_sig = if locals.is_empty() {
None
} else {
Some(SignatureLocalVariables { locals })
};
Self {
max_stack: Some(max_stack),
init_locals: true,
locals: Vec::new(),
implementation: None,
exception_handlers,
labeled_exception_handlers: Vec::new(),
prebuilt_bytecode: Some(bytecode),
prebuilt_local_sig,
}
}
#[must_use]
pub fn max_stack(mut self, stack_size: u16) -> Self {
self.max_stack = Some(stack_size);
self
}
#[must_use]
pub fn local(mut self, name: &str, local_type: TypeSignature) -> Self {
self.locals.push((name.to_string(), local_type));
self
}
#[must_use]
pub fn init_locals(mut self, init: bool) -> Self {
self.init_locals = init;
self
}
#[must_use]
pub fn exception_handler(mut self, handler: ExceptionHandler) -> Self {
self.exception_handlers.push(handler);
self
}
#[must_use]
pub fn catch_handler(
mut self,
try_offset: u32,
try_length: u32,
handler_offset: u32,
handler_length: u32,
exception_type: Option<CilTypeRc>,
) -> Self {
let handler = ExceptionHandler {
flags: if exception_type.is_some() {
ExceptionHandlerFlags::EXCEPTION
} else {
ExceptionHandlerFlags::FAULT
},
try_offset,
try_length,
handler_offset,
handler_length,
handler: exception_type,
filter_offset: 0,
};
self.exception_handlers.push(handler);
self
}
#[must_use]
pub fn finally_handler(
mut self,
try_offset: u32,
try_length: u32,
handler_offset: u32,
handler_length: u32,
) -> Self {
let handler = ExceptionHandler {
flags: ExceptionHandlerFlags::FINALLY,
try_offset,
try_length,
handler_offset,
handler_length,
handler: None,
filter_offset: 0,
};
self.exception_handlers.push(handler);
self
}
#[must_use]
pub fn finally_handler_with_labels(
mut self,
try_start_label: &str,
try_end_label: &str,
handler_start_label: &str,
handler_end_label: &str,
) -> Self {
let handler = LabeledExceptionHandler {
flags: ExceptionHandlerFlags::FINALLY,
try_start_label: try_start_label.to_string(),
try_end_label: try_end_label.to_string(),
handler_start_label: handler_start_label.to_string(),
handler_end_label: handler_end_label.to_string(),
handler_type: None,
filter_start_label: None,
};
self.labeled_exception_handlers.push(handler);
self
}
#[must_use]
pub fn catch_handler_with_labels(
mut self,
try_start_label: &str,
try_end_label: &str,
handler_start_label: &str,
handler_end_label: &str,
exception_type: Option<CilTypeRc>,
) -> Self {
let handler = LabeledExceptionHandler {
flags: if exception_type.is_some() {
ExceptionHandlerFlags::EXCEPTION
} else {
ExceptionHandlerFlags::FAULT
},
try_start_label: try_start_label.to_string(),
try_end_label: try_end_label.to_string(),
handler_start_label: handler_start_label.to_string(),
handler_end_label: handler_end_label.to_string(),
handler_type: exception_type,
filter_start_label: None,
};
self.labeled_exception_handlers.push(handler);
self
}
#[must_use]
pub fn filter_handler_with_labels(
mut self,
try_start_label: &str,
try_end_label: &str,
filter_start_label: &str,
handler_start_label: &str,
handler_end_label: &str,
) -> Self {
let handler = LabeledExceptionHandler {
flags: ExceptionHandlerFlags::FILTER,
try_start_label: try_start_label.to_string(),
try_end_label: try_end_label.to_string(),
handler_start_label: handler_start_label.to_string(),
handler_end_label: handler_end_label.to_string(),
handler_type: None,
filter_start_label: Some(filter_start_label.to_string()),
};
self.labeled_exception_handlers.push(handler);
self
}
#[must_use]
pub fn implementation<F>(mut self, f: F) -> Self
where
F: FnOnce(&mut InstructionAssembler) -> Result<()> + 'static,
{
self.implementation = Some(Box::new(f));
self
}
pub fn build(self, assembly: &mut CilAssembly) -> Result<(Vec<u8>, Token)> {
let MethodBodyBuilder {
max_stack,
init_locals,
locals,
implementation,
exception_handlers,
labeled_exception_handlers,
prebuilt_bytecode,
prebuilt_local_sig,
} = self;
let (code_bytes, all_exception_handlers, max_stack) =
if let Some(bytecode) = prebuilt_bytecode {
let max_stack = max_stack.unwrap_or(8);
(bytecode, exception_handlers, max_stack)
} else {
let implementation = implementation.ok_or_else(|| {
Error::ModificationInvalid("Method body implementation is required".to_string())
})?;
let mut assembler = InstructionAssembler::new();
implementation(&mut assembler)?;
let mut all_exception_handlers = exception_handlers;
for labeled_handler in labeled_exception_handlers {
let resolved_handler =
resolve_labeled_exception_handler(&assembler, &labeled_handler)?;
all_exception_handlers.push(resolved_handler);
}
let (code_bytes, calculated_max_stack, _) = assembler.finish()?;
let max_stack = max_stack.unwrap_or(calculated_max_stack);
(code_bytes, all_exception_handlers, max_stack)
};
let local_var_sig_token_value = if let Some(local_sig) = prebuilt_local_sig {
let sig_bytes = encode_local_var_signature(&local_sig)?;
let local_sig_ref = StandAloneSigBuilder::new()
.signature(&sig_bytes)
.build(assembly)?;
local_sig_ref.placeholder_token().map_or(0, |t| t.value())
} else if locals.is_empty() {
0u32
} else {
let signature_locals: Vec<SignatureLocalVariable> = locals
.iter()
.map(|(_, sig)| SignatureLocalVariable {
modifiers: Vec::new(),
is_byref: false,
is_pinned: false,
base: sig.clone(),
})
.collect();
let local_sig = SignatureLocalVariables {
locals: signature_locals,
};
let sig_bytes = encode_local_var_signature(&local_sig)?;
let local_sig_ref = StandAloneSigBuilder::new()
.signature(&sig_bytes)
.build(assembly)?;
local_sig_ref.placeholder_token().map_or(0, |t| t.value())
};
let has_exceptions = !all_exception_handlers.is_empty();
let code_size = u32::try_from(code_bytes.len())
.map_err(|_| malformed_error!("Method body size exceeds u32 range"))?;
let header = encode_method_body_header(
code_size,
max_stack,
local_var_sig_token_value,
has_exceptions,
init_locals,
)?;
let mut body = header;
body.extend_from_slice(&code_bytes);
if has_exceptions {
validate_exception_handler_ranges(&all_exception_handlers, code_size)?;
while body.len() % 4 != 0 {
body.push(0x00);
}
let eh_section = encode_exception_handlers(&all_exception_handlers)?;
body.extend_from_slice(&eh_section);
}
Ok((body, Token::new(local_var_sig_token_value)))
}
}
impl Default for MethodBodyBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cilassembly::CilAssembly;
    use crate::metadata::cilassemblyview::CilAssemblyView;
    use std::path::PathBuf;

    /// Opens the bundled `WindowsBase.dll` sample as a writable [`CilAssembly`].
    fn get_test_assembly() -> Result<CilAssembly> {
        let sample_path =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        let view = CilAssemblyView::from_path(&sample_path)?;
        Ok(CilAssembly::new(view))
    }

    #[test]
    fn test_method_body_builder_basic() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            .implementation(|asm| {
                asm.ldc_i4_1()?.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        // Header byte (2 << 2) | 0x02, then `ldc.i4.1` (0x17) and `ret` (0x2A).
        assert!(body.len() >= 3);
        assert_eq!(body[0], 0x0A);
        assert_eq!(body[1], 0x17);
        assert_eq!(body[2], 0x2A);
        Ok(())
    }

    #[test]
    fn test_method_body_builder_with_max_stack() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            .max_stack(10)
            .implementation(|asm| {
                asm.nop()?.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        // An explicit max stack must produce a header whose two low flag bits are both set.
        assert!(body.len() >= 14);
        let header_flags = u16::from_le_bytes([body[0], body[1]]);
        assert_eq!(header_flags & 0x0003, 0x0003);
        Ok(())
    }

    #[test]
    fn test_method_body_builder_with_locals() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, local_sig_token) = MethodBodyBuilder::new()
            .local("temp", TypeSignature::I4)
            .local("result", TypeSignature::String)
            .implementation(|asm| {
                asm.ldarg_0()?.stloc_0()?.ldloc_0()?.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        // Declaring locals must yield a non-zero local signature token.
        assert_ne!(local_sig_token.value(), 0);
        assert!(!body.is_empty());
        Ok(())
    }

    #[test]
    fn test_method_body_builder_complex_method() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            .local("counter", TypeSignature::I4)
            .implementation(|asm| {
                asm.ldc_i4_0()?
                    .stloc_0()?
                    .label("loop")?
                    .ldloc_0()?
                    .ldc_i4_const(10)?
                    .blt_s("continue")?
                    .ldloc_0()?
                    .ret()?
                    .label("continue")?
                    .ldloc_0()?
                    .ldc_i4_1()?
                    .add()?
                    .stloc_0()?
                    .br_s("loop")?;
                Ok(())
            })
            .build(&mut assembly)?;

        assert!(body.len() > 10);
        Ok(())
    }

    #[test]
    fn test_method_body_builder_no_implementation_fails() {
        let mut assembly = get_test_assembly().unwrap();
        let outcome = MethodBodyBuilder::new().build(&mut assembly);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_method_body_with_exception_handlers() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            // try: bytes 0..5, finally: bytes 5..8 (lengths 5 and 3).
            .finally_handler(0, 5, 5, 3)
            .implementation(|asm| {
                for _ in 0..7 {
                    asm.nop()?;
                }
                asm.endfinally()?;
                asm.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        assert!(!body.is_empty());
        assert!(body.len() >= 12);
        Ok(())
    }

    #[test]
    fn test_filter_handler_with_labels() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            .filter_handler_with_labels(
                "try_start",
                "try_end",
                "filter_start",
                "handler_start",
                "handler_end",
            )
            .implementation(|asm| {
                asm.label("try_start")?;
                asm.nop()?;
                asm.nop()?;
                asm.leave_s("try_end")?;

                asm.label("filter_start")?;
                asm.ldc_i4_1()?;
                asm.endfilter()?;

                asm.label("handler_start")?;
                asm.nop()?;
                asm.leave_s("handler_end")?;

                asm.label("handler_end")?;
                asm.label("try_end")?;
                asm.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        assert!(!body.is_empty());
        assert!(body.len() >= 12);
        Ok(())
    }

    #[test]
    fn test_accurate_stack_tracking() -> Result<()> {
        let mut assembly = get_test_assembly()?;
        let (body, _) = MethodBodyBuilder::new()
            .implementation(|asm| {
                asm.ldc_i4_1()?.ldc_i4_2()?.add()?.dup()?.ret()?;
                Ok(())
            })
            .build(&mut assembly)?;

        // Five one-byte instructions: the first header byte encodes a code size of 5.
        assert!(!body.is_empty());
        assert_eq!(body[0], (5 << 2) | 0x02);
        Ok(())
    }
}