use crate::{cilassembly::CilAssembly, Result};
/// Fluent builder that collects native DLL imports (by function name or by
/// ordinal) and applies them to a [`CilAssembly`] when `build` is called.
#[derive(Debug, Clone)]
pub struct NativeImportsBuilder {
/// DLL names to import, in insertion order; `add_dll` skips exact duplicates.
dlls: Vec<String>,
/// `(dll_name, function_name)` pairs imported by name.
functions: Vec<(String, String)>,
/// `(dll_name, ordinal)` pairs imported by ordinal; `add_function_by_ordinal`
/// rejects ordinal 0, so entries here are always non-zero.
ordinal_functions: Vec<(String, u16)>,
}
impl NativeImportsBuilder {
    /// Creates an empty builder with no DLLs and no function imports.
    #[must_use]
    pub fn new() -> Self {
        Self {
            dlls: Vec::new(),
            functions: Vec::new(),
            ordinal_functions: Vec::new(),
        }
    }

    /// Checks that `name` is usable as a bare DLL file name.
    ///
    /// # Errors
    ///
    /// Returns an error built with `malformed_error!` if `name` is empty,
    /// contains a NUL character, or contains a `/` or `\` path separator.
    fn validate_dll_name(name: &str) -> Result<()> {
        if name.is_empty() {
            return Err(malformed_error!("DLL name cannot be empty"));
        }
        if name.contains('\0') {
            return Err(malformed_error!("DLL name contains null character"));
        }
        if name.contains('/') || name.contains('\\') {
            return Err(malformed_error!(
                "DLL name contains path separators - use filename only"
            ));
        }
        Ok(())
    }

    /// Checks that `name` is usable as an imported function name.
    ///
    /// # Errors
    ///
    /// Returns an error built with `malformed_error!` if `name` is empty or
    /// contains a NUL character.
    fn validate_function_name(name: &str) -> Result<()> {
        if name.is_empty() {
            return Err(malformed_error!("Function name cannot be empty"));
        }
        if name.contains('\0') {
            return Err(malformed_error!("Function name contains null character"));
        }
        Ok(())
    }

    /// Appends `dll_name` to the DLL list unless an identical entry already exists.
    ///
    /// Matching is exact (case-sensitive) and insertion order is preserved.
    fn ensure_dll(&mut self, dll_name: &str) {
        if !self.dlls.iter().any(|known| known == dll_name) {
            self.dlls.push(dll_name.to_owned());
        }
    }

    /// Registers a DLL to import. Adding the same name twice is a no-op.
    ///
    /// # Errors
    ///
    /// Fails if `dll_name` is empty, contains a NUL character, or contains a
    /// path separator; the builder is consumed in that case.
    pub fn add_dll(mut self, dll_name: impl Into<String>) -> Result<Self> {
        let dll_name = dll_name.into();
        Self::validate_dll_name(&dll_name)?;
        self.ensure_dll(&dll_name);
        Ok(self)
    }

    /// Registers a function imported by name from `dll_name`.
    ///
    /// The DLL is registered automatically if it has not been added yet.
    /// Adding the same `(dll_name, function_name)` pair more than once records
    /// it only once (mirroring [`add_dll`](Self::add_dll)), so
    /// [`build`](Self::build) never forwards a duplicate import to the assembly.
    ///
    /// # Errors
    ///
    /// Fails if the DLL name is invalid (see [`add_dll`](Self::add_dll)) or if
    /// `function_name` is empty or contains a NUL character.
    pub fn add_function(
        mut self,
        dll_name: impl Into<String>,
        function_name: impl Into<String>,
    ) -> Result<Self> {
        let dll_name = dll_name.into();
        let function_name = function_name.into();
        Self::validate_dll_name(&dll_name)?;
        Self::validate_function_name(&function_name)?;
        self.ensure_dll(&dll_name);
        let import = (dll_name, function_name);
        if !self.functions.contains(&import) {
            self.functions.push(import);
        }
        Ok(self)
    }

    /// Registers a function imported by `ordinal` from `dll_name`.
    ///
    /// The DLL is registered automatically if it has not been added yet.
    /// Adding the same `(dll_name, ordinal)` pair more than once records it
    /// only once, so [`build`](Self::build) never forwards it twice.
    ///
    /// # Errors
    ///
    /// Fails if the DLL name is invalid (see [`add_dll`](Self::add_dll)) or if
    /// `ordinal` is 0.
    pub fn add_function_by_ordinal(
        mut self,
        dll_name: impl Into<String>,
        ordinal: u16,
    ) -> Result<Self> {
        let dll_name = dll_name.into();
        Self::validate_dll_name(&dll_name)?;
        if ordinal == 0 {
            return Err(malformed_error!("Ordinal cannot be 0"));
        }
        self.ensure_dll(&dll_name);
        let import = (dll_name, ordinal);
        if !self.ordinal_functions.contains(&import) {
            self.ordinal_functions.push(import);
        }
        Ok(self)
    }

    /// Applies every recorded import to `assembly`.
    ///
    /// DLLs are added first, then name-based imports, then ordinal-based
    /// imports, each group in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `assembly`. Imports applied before
    /// the failing call are not rolled back, so `assembly` may be left
    /// partially updated.
    pub fn build(self, assembly: &mut CilAssembly) -> Result<()> {
        for dll_name in &self.dlls {
            assembly.add_native_import_dll(dll_name)?;
        }
        for (dll_name, function_name) in &self.functions {
            assembly.add_native_import_function(dll_name, function_name)?;
        }
        for (dll_name, ordinal) in &self.ordinal_functions {
            assembly.add_native_import_function_by_ordinal(dll_name, *ordinal)?;
        }
        Ok(())
    }
}
impl Default for NativeImportsBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cilassembly::CilAssembly;
    use std::path::PathBuf;

    /// Opens the bundled `WindowsBase.dll` sample, or yields `None` if it cannot
    /// be loaded. Assembly-backed tests simply skip their body in that case.
    fn load_sample() -> Option<CilAssembly> {
        let sample =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        CilAssembly::from_path(&sample).ok()
    }

    #[test]
    fn test_native_imports_builder_basic() {
        if let Some(mut assembly) = load_sample() {
            let outcome = NativeImportsBuilder::new()
                .add_dll("kernel32.dll")
                .and_then(|builder| builder.add_function("kernel32.dll", "GetCurrentProcessId"))
                .and_then(|builder| builder.add_function("kernel32.dll", "ExitProcess"))
                .and_then(|builder| builder.build(&mut assembly));
            assert!(outcome.is_ok());
        }
    }

    #[test]
    fn test_native_imports_builder_with_ordinals() {
        if let Some(mut assembly) = load_sample() {
            let outcome = NativeImportsBuilder::new()
                .add_dll("user32.dll")
                .and_then(|builder| builder.add_function_by_ordinal("user32.dll", 120))
                .and_then(|builder| builder.add_function("user32.dll", "GetWindowTextW"))
                .and_then(|builder| builder.build(&mut assembly));
            assert!(outcome.is_ok());
        }
    }

    #[test]
    fn test_native_imports_builder_auto_dll_addition() {
        if let Some(mut assembly) = load_sample() {
            let outcome = NativeImportsBuilder::new()
                .add_function("kernel32.dll", "GetCurrentProcessId")
                .and_then(|builder| builder.add_function_by_ordinal("user32.dll", 120))
                .and_then(|builder| builder.build(&mut assembly));
            assert!(outcome.is_ok());
        }
    }

    #[test]
    fn test_native_imports_builder_empty() {
        if let Some(mut assembly) = load_sample() {
            assert!(NativeImportsBuilder::new().build(&mut assembly).is_ok());
        }
    }

    #[test]
    fn test_native_imports_builder_duplicate_dlls() {
        let builder = NativeImportsBuilder::new()
            .add_dll("kernel32.dll")
            .and_then(|b| b.add_dll("kernel32.dll"))
            .and_then(|b| b.add_dll("user32.dll"))
            .expect("Should not fail for valid DLL names");
        let has_dll = |name: &str| builder.dlls.iter().any(|dll| dll == name);
        assert_eq!(builder.dlls.len(), 2);
        assert!(has_dll("kernel32.dll"));
        assert!(has_dll("user32.dll"));
    }

    #[test]
    fn test_native_imports_builder_fluent_api() {
        let builder = NativeImportsBuilder::new()
            .add_dll("kernel32.dll")
            .and_then(|b| b.add_function("kernel32.dll", "GetCurrentProcessId"))
            .and_then(|b| b.add_function("kernel32.dll", "ExitProcess"))
            .and_then(|b| b.add_dll("user32.dll"))
            .and_then(|b| b.add_function_by_ordinal("user32.dll", 120))
            .expect("Should not fail for valid inputs");

        assert_eq!(builder.dlls.len(), 2);
        assert_eq!(builder.functions.len(), 2);
        assert_eq!(builder.ordinal_functions.len(), 1);

        let has_dll = |name: &str| builder.dlls.iter().any(|dll| dll == name);
        let has_function = |dll_name: &str, func_name: &str| {
            builder
                .functions
                .iter()
                .any(|(dll, func)| dll == dll_name && func == func_name)
        };
        let has_ordinal = |dll_name: &str, wanted: u16| {
            builder
                .ordinal_functions
                .iter()
                .any(|(dll, ordinal)| dll == dll_name && *ordinal == wanted)
        };

        assert!(has_dll("kernel32.dll"));
        assert!(has_dll("user32.dll"));
        assert!(has_function("kernel32.dll", "GetCurrentProcessId"));
        assert!(has_function("kernel32.dll", "ExitProcess"));
        assert!(has_ordinal("user32.dll", 120));
    }

    #[test]
    fn test_native_imports_builder_validation_empty_dll() {
        assert!(NativeImportsBuilder::new().add_dll("").is_err());
    }

    #[test]
    fn test_native_imports_builder_validation_empty_function() {
        let outcome = NativeImportsBuilder::new()
            .add_dll("kernel32.dll")
            .and_then(|b| b.add_function("kernel32.dll", ""));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_native_imports_builder_validation_ordinal_zero() {
        let outcome = NativeImportsBuilder::new()
            .add_dll("user32.dll")
            .and_then(|b| b.add_function_by_ordinal("user32.dll", 0));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_native_imports_builder_validation_dll_with_path() {
        assert!(NativeImportsBuilder::new()
            .add_dll("C:\\Windows\\kernel32.dll")
            .is_err());
    }
}