pub mod build;
pub mod module;
pub mod thread;
use crate::processes::nt::{NtIteratorError, NtProcessState};
use crate::WindowsString;
use crate::{processes, safe_handle::*};
use module::{Module, ModuleIterator};
use ntapi::ntpsapi::{NtResumeProcess, NtSuspendProcess};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::PathBuf;
use thiserror::Error;
use thread::{NtThread, NtThreadIterator, ThreadError, ThreadIterator};
use windows::core::s;
use windows::core::{Error, PCSTR};
use windows::Win32::System::Diagnostics::Debug::{ReadProcessMemory, WriteProcessMemory};
use windows::Win32::System::Memory::{
VirtualAllocEx, VirtualProtectEx, VirtualQueryEx, MEMORY_BASIC_INFORMATION, MEM_COMMIT,
MEM_RESERVE, PAGE_PROTECTION_FLAGS,
};
use windows::Win32::System::ProcessStatus::GetModuleFileNameExW;
use windows::Win32::System::Threading::{
CreateProcessA, CreateRemoteThread, GetCurrentProcessId, OpenProcess, TerminateProcess,
CREATE_SUSPENDED, PROCESS_ACCESS_RIGHTS, PROCESS_CREATION_FLAGS, PROCESS_INFORMATION,
STARTUPINFOA,
};
/// Every failure that [`Process`] and its constructors can report.
#[derive(Debug, Error)]
pub enum ProcessError {
    /// No process with the requested name or PID was found during enumeration.
    #[error("Could not find the process searched for.")]
    NoProcessFound,
    /// A `windows` crate call returned an error.
    ///
    /// Because of the `#[from]` conversion, `?` routes *every* `windows::core::Error` here
    /// (e.g. from `OpenProcess`, `CreateProcessA`, `ReadProcessMemory`), so the wrapped error —
    /// available through `source()` — is the authoritative cause, not the variant name.
    #[error("Permission denied.")]
    PermissionDenied(#[from] windows::core::Error),
    /// Enumerating processes through the NT API failed.
    #[error(transparent)]
    NtIteratorError(#[from] NtIteratorError),
    /// The stored process handle could not be used.
    #[error(transparent)]
    HandleError(#[from] HandleError),
    /// Free-form failure carrying its own message.
    #[error("{0}")]
    ProcessError(String),
    /// A thread-level operation failed.
    #[error(transparent)]
    ThreadError(#[from] ThreadError),
}
/// Module-wide shorthand: a `Result` whose error type is [`ProcessError`].
type Result<T> = core::result::Result<T, ProcessError>;
/// Type-state tag: the `Process` was located through `processes::get_from_snapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Snapshot;
/// Type-state tag: the `Process` was located through the NT enumeration (`processes::get_from_nt`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct NT;
/// Type-state tag: the `Process` was spawned by this code via `CreateProcessA`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Created;
/// An opened Windows process, tagged at the type level with how it was obtained.
///
/// `Method` is intended to be one of the marker types `Snapshot`, `NT` or `Created`; each
/// has its own constructor `impl`, while the shared operations live on `impl<T> Process<T>`.
pub struct Process<Method> {
    /// Wrapped process handle, from `OpenProcess` or `CreateProcessA`.
    pub handle: SafeHandle,
    /// Process name as reported by the enumeration backend, or the executable's file name.
    pub name: String,
    /// Process identifier.
    pub process_id: u32,
    /// Backend-specific data keyed by concrete type; read it back with `get::<U>()`.
    extensions: HashMap<TypeId, Box<dyn Any>>,
    /// Zero-sized field that carries the `Method` type parameter.
    _marker: core::marker::PhantomData<Method>,
}
impl Process<Snapshot> {
    /// Opens the first process from `processes::get_from_snapshot` whose name equals `name`.
    ///
    /// The comparison is plain `==`, i.e. exact and case-sensitive.
    ///
    /// # Errors
    /// - [`ProcessError::NoProcessFound`] if no snapshot entry has that name.
    /// - [`ProcessError::PermissionDenied`] if `OpenProcess` fails for the requested `access`.
    pub fn from_name(name: &str, access: PROCESS_ACCESS_RIGHTS) -> Result<Self> {
        let target = processes::get_from_snapshot()
            .find(|p| p.get_name() == name)
            .ok_or(ProcessError::NoProcessFound)?;
        Self::open_target(target, access)
    }

    /// Opens the process from `processes::get_from_snapshot` whose PID is `pid`.
    ///
    /// # Errors
    /// - [`ProcessError::NoProcessFound`] if no snapshot entry has that PID.
    /// - [`ProcessError::PermissionDenied`] if `OpenProcess` fails for the requested `access`.
    pub fn from_pid(pid: u32, access: PROCESS_ACCESS_RIGHTS) -> Result<Self> {
        let target = processes::get_from_snapshot()
            .find(|p| p.process_id == pid)
            .ok_or(ProcessError::NoProcessFound)?;
        Self::open_target(target, access)
    }

    /// Returns a `ThreadIterator` constructed from this process's PID.
    pub fn get_threads(&self) -> ThreadIterator {
        ThreadIterator::new(self.process_id)
    }

    /// Shared tail of the constructors: opens `target`'s process with `access`, then stores
    /// `target` as an extension so callers can read it back via `get::<SnapshotState>()`.
    fn open_target(
        target: processes::snapshot::SnapshotState,
        access: PROCESS_ACCESS_RIGHTS,
    ) -> Result<Self> {
        // SAFETY: OpenProcess only takes plain values; the returned handle is owned by SafeHandle.
        let handle = unsafe { OpenProcess(access, false, target.process_id) }
            .map_err(ProcessError::PermissionDenied)?;
        let mut process = Self {
            handle: SafeHandle::new(handle),
            name: target.get_name(),
            process_id: target.process_id,
            extensions: HashMap::new(),
            _marker: std::marker::PhantomData,
        };
        process.insert(target);
        Ok(process)
    }
}
impl Process<NT> {
    /// Opens the first process from `processes::get_from_nt` whose name equals `name`.
    ///
    /// The comparison is plain `==`, i.e. exact and case-sensitive.
    ///
    /// # Errors
    /// - [`ProcessError::NtIteratorError`] if the NT enumeration itself fails.
    /// - [`ProcessError::NoProcessFound`] if no entry has that name.
    /// - [`ProcessError::PermissionDenied`] if `OpenProcess` fails for the requested `access`.
    pub fn from_name(name: &str, access: PROCESS_ACCESS_RIGHTS) -> Result<Self> {
        let target = processes::get_from_nt()?
            .find(|p| p.get_name() == name)
            .ok_or(ProcessError::NoProcessFound)?;
        Self::open_target(target, access)
    }

    /// Opens the process from `processes::get_from_nt` whose `UniqueProcessId` equals `pid`.
    ///
    /// # Errors
    /// - [`ProcessError::NtIteratorError`] if the NT enumeration itself fails.
    /// - [`ProcessError::NoProcessFound`] if no entry has that PID.
    /// - [`ProcessError::PermissionDenied`] if `OpenProcess` fails for the requested `access`.
    pub fn from_pid(pid: u32, access: PROCESS_ACCESS_RIGHTS) -> Result<Self> {
        let target = processes::get_from_nt()?
            .find(|p| p.raw.UniqueProcessId as u32 == pid)
            .ok_or(ProcessError::NoProcessFound)?;
        Self::open_target(target, access)
    }

    /// Returns an iterator over the threads recorded in the `NtProcessState` captured when this
    /// process was opened (the data is not re-queried here).
    ///
    /// # Panics
    /// Only if the `NtProcessState` extension is missing. Both constructors insert it, and the
    /// private fields prevent building a `Process<NT>` any other way, so this signals a bug.
    pub fn get_threads(&self) -> NtThreadIterator {
        let state = self
            .get::<NtProcessState>()
            .expect("couldn't find process state using NT type.");
        NtThreadIterator::new(state)
    }

    /// Shared tail of the constructors: opens `target`'s process with `access`, then stores
    /// `target` as an extension (this is what `get_threads` reads back).
    fn open_target(target: NtProcessState, access: PROCESS_ACCESS_RIGHTS) -> Result<Self> {
        // Narrow the NT process id to the 32-bit PID used by OpenProcess and `process_id`.
        let process_id = target.raw.UniqueProcessId as u32;
        // SAFETY: OpenProcess only takes plain values; the returned handle is owned by SafeHandle.
        let handle = unsafe { OpenProcess(access, false, process_id) }
            .map_err(ProcessError::PermissionDenied)?;
        let mut process = Self {
            handle: SafeHandle::new(handle),
            name: target.get_name(),
            process_id,
            extensions: HashMap::new(),
            _marker: std::marker::PhantomData,
        };
        process.insert(target);
        Ok(process)
    }
}
impl Process<Created> {
pub fn from_path(
path: PathBuf,
args: &str,
creation_flags: PROCESS_CREATION_FLAGS,
) -> Result<Self> {
let mut startup_info: STARTUPINFOA = STARTUPINFOA::default();
startup_info.cb = std::mem::size_of::<STARTUPINFOA>() as _;
let mut process_info: PROCESS_INFORMATION = PROCESS_INFORMATION::default();
unsafe {
let path_cstring = std::ffi::CString::new(path.to_string_lossy().to_string()).unwrap();
let args = std::ffi::CString::new(args).unwrap();
CreateProcessA(
PCSTR::from_raw(path_cstring.as_ptr() as _),
Some(windows::core::PSTR::from_raw(args.as_ptr() as _)),
None,
None,
false,
creation_flags,
None,
s!("C:\\"),
&startup_info,
&mut process_info,
)?;
}
Ok(Self {
handle: SafeHandle::new(process_info.hProcess),
name: path.file_name().unwrap().to_string_lossy().to_string(),
process_id: process_info.dwProcessId,
extensions: HashMap::new(),
_marker: std::marker::PhantomData,
})
}
}
impl<T> Process<T> {
    /// Stores `value` in the extension map, keyed by its concrete type.
    ///
    /// A later insert of the same type replaces the earlier value.
    fn insert<U: 'static>(&mut self, value: U) {
        self.extensions.insert(TypeId::of::<U>(), Box::new(value));
    }

    /// Returns the extension value of type `U` previously stored with `insert`, if any.
    pub fn get<U: 'static>(&self) -> Option<&U> {
        self.extensions
            .get(&TypeId::of::<U>())
            .and_then(|boxed| boxed.downcast_ref::<U>())
    }

    /// Returns the PID of the calling process.
    fn get_current_process_id() -> u32 {
        // SAFETY: GetCurrentProcessId takes no arguments.
        unsafe { GetCurrentProcessId() }
    }

    /// Returns the full path of the process's main executable module.
    ///
    /// # Errors
    /// - Whatever `SafeHandle::get` reports, converted into [`ProcessError`].
    /// - [`ProcessError::ProcessError`] if `GetModuleFileNameExW` fails.
    pub fn get_full_path(&self) -> Result<PathBuf> {
        let handle = self.handle.get()?;
        let mut module_path_buf: [u16; 4096] = [0; 4096];
        // SAFETY: the buffer is live and writable; the slice conveys its length to the API.
        let written = unsafe { GetModuleFileNameExW(Some(handle), None, &mut module_path_buf) };
        // 0 signals failure; without this check an empty path would be returned as success.
        if written == 0 {
            return Err(ProcessError::ProcessError(format!(
                "GetModuleFileNameExW failed: {}",
                std::io::Error::last_os_error()
            )));
        }
        Ok(PathBuf::from(module_path_buf.to_string_null()))
    }

    /// Terminates the process with exit code 0 and consumes this wrapper.
    ///
    /// `self` is dropped normally afterwards. The handle's `Drop` (defined in `safe_handle`,
    /// presumably closing it) and the extension map are released rather than leaked.
    pub fn kill(self) -> Result<()> {
        // SAFETY: the handle is owned by `self` and is still open at this point.
        unsafe { TerminateProcess(self.handle.get()?, 0) }?;
        Ok(())
    }

    /// Suspends every thread of the process via `NtSuspendProcess`.
    ///
    /// # Errors
    /// [`ProcessError::ProcessError`] carrying the NTSTATUS if the call does not succeed.
    pub fn suspend_process(&self) -> Result<()> {
        let handle = self.handle.get()?;
        // SAFETY: `handle` is the open process handle owned by this wrapper.
        let status = unsafe { NtSuspendProcess(handle.0 as _) };
        Self::check_nt_status("NtSuspendProcess", status)
    }

    /// Resumes the process's threads via `NtResumeProcess`.
    ///
    /// # Errors
    /// [`ProcessError::ProcessError`] carrying the NTSTATUS if the call does not succeed.
    pub fn resume_process(&self) -> Result<()> {
        let handle = self.handle.get()?;
        // SAFETY: `handle` is the open process handle owned by this wrapper.
        let status = unsafe { NtResumeProcess(handle.0 as _) };
        Self::check_nt_status("NtResumeProcess", status)
    }

    /// Maps an NTSTATUS to `Ok(())` for success/informational codes (non-negative values,
    /// the `NT_SUCCESS` rule) and to an error naming `api` otherwise.
    fn check_nt_status(api: &str, status: i32) -> Result<()> {
        if status >= 0 {
            Ok(())
        } else {
            Err(ProcessError::ProcessError(format!(
                "{api} failed with NTSTATUS {status:#010x}"
            )))
        }
    }

    /// Reserves and commits `size` bytes in the target process with `protection`.
    ///
    /// `addr` is an optional preferred base address. Returns the base address of the allocation.
    ///
    /// # Errors
    /// [`ProcessError::ProcessError`] (including the OS error) if `VirtualAllocEx` returns null.
    pub fn virtual_alloc(
        &self,
        addr: Option<usize>,
        size: usize,
        protection: PAGE_PROTECTION_FLAGS,
    ) -> Result<u64> {
        let handle = self.handle.get()?;
        // SAFETY: only plain values are passed; the memory is in another address space.
        let address = unsafe {
            VirtualAllocEx(
                handle,
                addr.map(|v| v as _),
                size,
                MEM_RESERVE | MEM_COMMIT,
                protection,
            )
        };
        if address.is_null() {
            return Err(ProcessError::ProcessError(format!(
                "Failed to allocate memory: {}",
                std::io::Error::last_os_error()
            )));
        }
        Ok(address as u64)
    }

    /// Changes the protection of `size` bytes at `address` and returns the previous protection.
    pub fn set_protection(
        &self,
        address: u64,
        size: usize,
        protection: PAGE_PROTECTION_FLAGS,
    ) -> Result<PAGE_PROTECTION_FLAGS> {
        let handle = self.handle.get()?;
        let mut old_protection = PAGE_PROTECTION_FLAGS(0);
        // SAFETY: `old_protection` is a valid out-pointer for the duration of the call.
        unsafe { VirtualProtectEx(handle, address as _, size, protection, &mut old_protection) }?;
        Ok(old_protection)
    }

    /// Walks the target's address space from address 0 and returns every region whose
    /// `Protect` shares at least one bit with `mask`.
    ///
    /// Note: the Win32 docs state `Protect` is undefined for free and reserved regions; callers
    /// that care should inspect each region's `State`.
    pub fn get_memory_regions(
        &self,
        mask: PAGE_PROTECTION_FLAGS,
    ) -> Result<Vec<MEMORY_BASIC_INFORMATION>> {
        let handle = self.handle.get()?;
        let mut info = MEMORY_BASIC_INFORMATION::default();
        let mut regions = Vec::new();
        // Addresses belong to another process, so step with integers: `<*mut T>::add` would
        // require the result to stay inside one of *our* allocations.
        let mut cursor: usize = 0;
        loop {
            // SAFETY: `info` is a valid, writable MEMORY_BASIC_INFORMATION of the size passed.
            let bytes_written = unsafe {
                VirtualQueryEx(
                    handle,
                    Some(cursor as *const c_void),
                    &mut info,
                    std::mem::size_of::<MEMORY_BASIC_INFORMATION>(),
                )
            };
            // 0 means the query failed (e.g. past the last queryable address); stop the walk.
            if bytes_written == 0 {
                break;
            }
            if (info.Protect & mask).0 != 0 {
                regions.push(info);
            }
            // Advance to the end of this region; stop rather than wrap or fail to make progress.
            match (info.BaseAddress as usize).checked_add(info.RegionSize) {
                Some(next) if next > cursor => cursor = next,
                _ => break,
            }
        }
        Ok(regions)
    }

    /// Copies the raw in-memory bytes of `value` to `addr` in the target process.
    ///
    /// `value` is dropped locally afterwards. Types owning pointers or heap data will not be
    /// meaningful in the other process.
    // NOTE(review): this attribute predates this edit; the reason for it is not visible here.
    #[allow(unused)]
    pub fn write<U>(&self, addr: u64, value: U) -> Result<()> {
        let handle = self.handle.get()?;
        // SAFETY: the source pointer refers to `value`, which is live for the whole call, and
        // exactly `size_of::<U>()` bytes are read from it.
        unsafe {
            WriteProcessMemory(
                handle,
                addr as _,
                (&value as *const U).cast::<c_void>(),
                std::mem::size_of::<U>(),
                None,
            )
        }?;
        Ok(())
    }

    /// Writes `value` verbatim to `addr` in the target process.
    pub fn write_bytes(&self, addr: u64, value: &[u8]) -> Result<()> {
        let handle = self.handle.get()?;
        // SAFETY: the source slice is valid for `value.len()` bytes for the whole call.
        unsafe {
            WriteProcessMemory(
                handle,
                addr as _,
                value.as_ptr().cast::<c_void>(),
                value.len(),
                None,
            )
        }?;
        Ok(())
    }

    /// Reads `len` bytes starting at `addr` in the target process.
    pub fn read_bytes(&self, addr: u64, len: u64) -> Result<Vec<u8>> {
        let handle = self.handle.get()?;
        let len = len as usize;
        let mut buffer = vec![0u8; len];
        // SAFETY: `buffer` is writable for exactly `len` bytes.
        unsafe {
            ReadProcessMemory(
                handle,
                addr as _,
                buffer.as_mut_ptr().cast::<c_void>(),
                len,
                None,
            )
        }?;
        Ok(buffer)
    }

    /// Reads a `U` from `addr` in the target process by reinterpreting its raw bytes.
    ///
    /// The caller must ensure the remote bytes form a valid `U`. This is not checked, so types
    /// with invalid bit patterns (`bool`, enums, references) must not be used with untrusted memory.
    pub fn read<U>(&self, addr: u64) -> Result<U> {
        let bytes = self.read_bytes(addr, std::mem::size_of::<U>() as u64)?;
        // SAFETY: `bytes` holds exactly `size_of::<U>()` bytes. `read_unaligned` is required
        // because a `Vec<u8>` buffer only guarantees 1-byte alignment.
        Ok(unsafe { std::ptr::read_unaligned(bytes.as_ptr().cast::<U>()) })
    }

    /// Starts a thread in the target process at `address`, passing `param` as its argument.
    ///
    /// `creation_flags` defaults to 0 (start immediately). Returns the new thread's handle.
    ///
    /// # Errors
    /// [`ProcessError::ProcessError`] if `address` is 0. Windows errors from
    /// `CreateRemoteThread` are converted via `?`.
    pub fn create_remote_thread(
        &self,
        address: usize,
        creation_flags: Option<u32>,
        param: Option<*const c_void>,
    ) -> Result<SafeHandle> {
        // Function pointers are non-nullable in Rust, so transmuting 0 below would be UB.
        if address == 0 {
            return Err(ProcessError::ProcessError(
                "Remote thread start address must not be null.".to_string(),
            ));
        }
        let process = self.handle.get()?;
        // SAFETY: `address` is non-zero, hence a valid function-pointer bit pattern. Whether it
        // points at suitable code in the target process is the caller's responsibility.
        let thread = unsafe {
            CreateRemoteThread(
                process,
                None,
                0,
                Some(std::mem::transmute(address)),
                param,
                creation_flags.unwrap_or(0),
                None,
            )
        }?;
        Ok(SafeHandle::new(thread))
    }

    /// Returns a `ModuleIterator` constructed from this process's PID.
    pub fn get_modules(&self) -> ModuleIterator {
        ModuleIterator::new(self.process_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use processes::snapshot::SnapshotState;
    use thread::ThreadOperations;
    use windows::Win32::System::{
        Memory::{
            PAGE_EXECUTE, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, PAGE_READONLY, PAGE_READWRITE,
        },
        Threading::PROCESS_ALL_ACCESS,
    };

    /// Opens the test process itself through the snapshot backend with full access.
    fn open_self_snapshot() -> Result<Process<Snapshot>> {
        let process_id = Process::<Snapshot>::get_current_process_id();
        Process::<Snapshot>::from_pid(process_id, PROCESS_ALL_ACCESS)
    }

    #[test]
    fn process_handle_check() -> Result<()> {
        let process = open_self_snapshot()?;
        assert!(process.handle.is_valid());
        Ok(())
    }

    #[test]
    fn process_created_handle_check() -> Result<()> {
        // Exercise Process<Created> itself: spawn the command interpreter suspended so it never
        // runs, check the handle, then terminate it.
        let shell = std::env::var_os("ComSpec")
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from(r"C:\Windows\System32\cmd.exe"));
        let process = Process::<Created>::from_path(shell, "cmd.exe /c exit", CREATE_SUSPENDED)?;
        assert!(process.handle.is_valid());
        assert_ne!(process.process_id, 0);
        process.kill()
    }

    #[test]
    fn process_attach_bespoke_data() -> Result<()> {
        let process = open_self_snapshot()?;
        let bespoke = process
            .get::<SnapshotState>()
            .expect("snapshot constructors store a SnapshotState");
        assert!(bespoke.thread_count > 0);
        Ok(())
    }

    #[test]
    fn process_get_memory_regions() -> Result<()> {
        let process = open_self_snapshot()?;
        let regions = process
            .get_memory_regions(PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE | PAGE_EXECUTE)?;
        assert!(!regions.is_empty());
        Ok(())
    }

    #[test]
    fn process_get_process_path() -> Result<()> {
        let process = open_self_snapshot()?;
        let path = process.get_full_path()?;
        assert!(path.to_string_lossy().contains(".exe"));
        Ok(())
    }

    #[test]
    fn process_memory_methods_test() -> Result<()> {
        let process = open_self_snapshot()?;
        let size = std::mem::size_of::<usize>();
        let addr = process.virtual_alloc(None, size, PAGE_READONLY)?;
        process.set_protection(addr, size, PAGE_READWRITE)?;
        process.write(addr, 1337_usize)?;
        assert_eq!(process.read::<usize>(addr)?, 1337_usize);
        Ok(())
    }

    #[test]
    fn process_find_module() -> Result<()> {
        let process = open_self_snapshot()?;
        let found = process
            .get_modules()
            .any(|m| m.get_name().to_lowercase() == "kernel32.dll");
        assert!(found);
        Ok(())
    }

    #[test]
    fn process_snapshot_threads() -> Result<()> {
        let process = open_self_snapshot()?;
        assert!(process
            .get_threads()
            .all(|t| t.owner_process_id == process.process_id));
        Ok(())
    }

    #[test]
    fn process_nt_threads() -> Result<()> {
        let process_id = Process::<NT>::get_current_process_id();
        let process = Process::<NT>::from_pid(process_id, PROCESS_ALL_ACCESS)?;
        assert!(process
            .get_threads()
            .all(|t| t.owner_process_id == process.process_id));
        Ok(())
    }

    #[test]
    #[ignore = "reads a live thread's register context; run with `cargo test -- --ignored`"]
    fn process_nt_thread_context() -> Result<()> {
        let process_id = Process::<NT>::get_current_process_id();
        let process = Process::<NT>::from_pid(process_id, PROCESS_ALL_ACCESS)?;
        let thread = process
            .get_threads()
            .next()
            .expect("the current process has at least one thread");
        let context = thread.get_context()?;
        assert!(context.0.Rsp > 0);
        Ok(())
    }
}