use crate::killable::{Killable, KillableType};
use log::info;
use std::{
alloc::{alloc, dealloc, Layout},
collections::{HashMap, HashSet},
ffi::c_void,
io::{Error, Result},
ptr::addr_of,
slice,
};
use windows_sys::Win32::{
Foundation::{
CloseHandle, GetLastError, ERROR_INSUFFICIENT_BUFFER, FALSE, HANDLE, INVALID_HANDLE_VALUE,
NO_ERROR, WIN32_ERROR,
},
NetworkManagement::IpHelper::{
GetExtendedTcpTable, GetExtendedUdpTable, MIB_TCP6ROW_OWNER_MODULE,
MIB_TCP6TABLE_OWNER_MODULE, MIB_TCPROW_OWNER_MODULE, MIB_TCPTABLE_OWNER_MODULE,
MIB_UDP6ROW_OWNER_MODULE, MIB_UDP6TABLE_OWNER_MODULE, MIB_UDPROW_OWNER_MODULE,
MIB_UDPTABLE_OWNER_MODULE, TCP_TABLE_OWNER_MODULE_ALL, UDP_TABLE_OWNER_MODULE,
},
Networking::WinSock::{AF_INET, AF_INET6},
System::{
Diagnostics::ToolHelp::{
CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32,
TH32CS_SNAPPROCESS,
},
Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE},
},
};
/// A process identified by its PID together with a display name.
#[derive(Debug)]
pub struct WindowsProcess {
    /// Process identifier.
    pid: u32,
    /// Name reported for the process by `Killable::get_name`.
    name: String,
}

impl WindowsProcess {
    /// Builds a `WindowsProcess` from a PID and an already-resolved name.
    pub fn new(pid: u32, name: String) -> Self {
        WindowsProcess { pid, name }
    }
}
/// Collects every process that owns a TCP or UDP endpoint (both address
/// families) whose local port equals `port`.
///
/// Process names come from a snapshot taken before the tables are queried; a
/// PID that is missing from that snapshot is reported with the name
/// `"Unknown"`.
///
/// # Errors
/// Returns an error when the process snapshot cannot be created or when any of
/// the four table queries fails.
pub fn find_target_processes(port: u16) -> Result<Vec<WindowsProcess>> {
    let lookup_table = ProcessLookupTable::create()?;
    let mut owner_pids: HashSet<u32> = HashSet::new();

    // SAFETY: `use_extended_table` is declared `unsafe`; the only inputs handed
    // to it here are a plain port number and an exclusive borrow of a live set.
    unsafe {
        use_extended_table::<MIB_TCPTABLE_OWNER_MODULE>(port, &mut owner_pids)?;
        use_extended_table::<MIB_TCP6TABLE_OWNER_MODULE>(port, &mut owner_pids)?;
        use_extended_table::<MIB_UDPTABLE_OWNER_MODULE>(port, &mut owner_pids)?;
        use_extended_table::<MIB_UDP6TABLE_OWNER_MODULE>(port, &mut owner_pids)?;
    }

    let found: Vec<WindowsProcess> = owner_pids
        .into_iter()
        .map(|pid| {
            let name = match lookup_table.process_names.get(&pid) {
                Some(known) => known.clone(),
                None => "Unknown".to_string(),
            };
            WindowsProcess::new(pid, name)
        })
        .collect();

    Ok(found)
}
impl Killable for WindowsProcess {
    /// Terminates the process. The signal argument is ignored; on success the
    /// result is always `Ok(true)`.
    fn kill(&self, _signal: crate::signal::KillportSignal) -> Result<bool> {
        // SAFETY: `kill_process` is an `unsafe fn`; it only receives a shared
        // borrow of `self`, which is valid for the duration of the call.
        let outcome = unsafe { kill_process(self) };
        outcome.map(|()| true)
    }

    /// Always reports `KillableType::Process`.
    fn get_type(&self) -> KillableType {
        KillableType::Process
    }

    /// Returns an owned copy of the stored process name.
    fn get_name(&self) -> String {
        self.name.clone()
    }
}
/// Reports whether a fresh process snapshot contains an entry with the given
/// PID.
///
/// # Errors
/// Returns an error when the snapshot cannot be created.
fn is_process_running(pid: u32) -> Result<bool> {
    for entry in WindowsProcessesSnapshot::create()? {
        if entry.th32ProcessID == pid {
            return Ok(true);
        }
    }
    Ok(false)
}
/// PID-to-name map built from a single process snapshot.
pub struct ProcessLookupTable {
    /// Executable name for each PID seen in the snapshot.
    process_names: HashMap<u32, String>,
}

impl ProcessLookupTable {
    /// Walks a new process snapshot and records the name of every entry.
    ///
    /// If the same PID appears more than once, the last entry wins.
    ///
    /// # Errors
    /// Returns an error when the snapshot cannot be created.
    pub fn create() -> Result<Self> {
        let snapshot = WindowsProcessesSnapshot::create()?;
        let process_names = snapshot
            .map(|entry| (entry.th32ProcessID, get_process_entry_name(&entry)))
            .collect::<HashMap<u32, String>>();
        Ok(Self { process_names })
    }
}
/// Extracts the executable name stored in `entry.szExeFile`.
///
/// The name is the run of elements before the first zero (or the whole array
/// when no zero is present), reinterpreted as bytes. When those bytes are not
/// valid UTF-8 the result is `"Unknown"`.
fn get_process_entry_name(entry: &PROCESSENTRY32) -> String {
    let raw = &entry.szExeFile;
    let len = raw.iter().position(|&unit| unit == 0).unwrap_or(raw.len());
    let bytes: Vec<u8> = raw[..len].iter().map(|&unit| unit as u8).collect();
    match String::from_utf8(bytes) {
        Ok(name) => name,
        Err(_) => "Unknown".to_string(),
    }
}
/// Iterator over the process entries of a snapshot handle obtained from
/// `CreateToolhelp32Snapshot`. The handle is closed when the value is dropped.
pub struct WindowsProcessesSnapshot {
/// Snapshot handle obtained in `create`.
handle: HANDLE,
/// Entry buffer that `Process32First` / `Process32Next` write into; a copy of
/// it is what the iterator yields.
entry: PROCESSENTRY32,
/// Selects what the next call to `Iterator::next` does.
state: SnapshotState,
}
/// Iteration state of a `WindowsProcessesSnapshot`.
pub enum SnapshotState {
/// Nothing has been read yet; the next step calls `Process32First`.
First,
/// At least one entry has been read; the next step calls `Process32Next`.
Next,
/// A previous call reported failure; the iterator yields `None` from here on.
End,
}
impl WindowsProcessesSnapshot {
pub fn create() -> Result<Self> {
let handle: HANDLE = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) };
if handle == INVALID_HANDLE_VALUE {
let error: WIN32_ERROR = unsafe { GetLastError() };
return Err(Error::other(format!(
"Failed to get handle to processes: {:#x}",
error
)));
}
let mut entry: PROCESSENTRY32 = unsafe { std::mem::zeroed() };
entry.dwSize = std::mem::size_of::<PROCESSENTRY32>() as u32;
Ok(Self {
handle,
entry,
state: SnapshotState::First,
})
}
}
impl Iterator for WindowsProcessesSnapshot {
    type Item = PROCESSENTRY32;

    /// Reads the next process entry into the internal buffer and returns a
    /// copy of it.
    ///
    /// The first call uses `Process32First`, later calls use `Process32Next`.
    /// Once either reports `FALSE` the iterator is finished and keeps
    /// returning `None` without calling the API again.
    fn next(&mut self) -> Option<Self::Item> {
        let advanced = match self.state {
            SnapshotState::End => return None,
            SnapshotState::First => unsafe { Process32First(self.handle, &mut self.entry) },
            SnapshotState::Next => unsafe { Process32Next(self.handle, &mut self.entry) },
        };

        if advanced == FALSE {
            self.state = SnapshotState::End;
            None
        } else {
            self.state = SnapshotState::Next;
            Some(self.entry)
        }
    }
}
impl Drop for WindowsProcessesSnapshot {
    /// Closes the snapshot handle. The result of `CloseHandle` is ignored.
    fn drop(&mut self) {
        // NOTE(review): within the visible code, values are only built by
        // `create`, which rejects `INVALID_HANDLE_VALUE` before constructing.
        let _ = unsafe { CloseHandle(self.handle) };
    }
}
unsafe fn kill_process(process: &WindowsProcess) -> Result<()> {
info!("Killing process {}:{}", process.get_name(), process.pid);
let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, process.pid);
if handle.is_null() {
if !is_process_running(process.pid)? {
return Ok(());
}
let error: WIN32_ERROR = GetLastError();
return Err(Error::other(format!(
"Failed to obtain handle to process {}:{}: {:#x}",
process.get_name(),
process.pid,
error
)));
}
let result = TerminateProcess(handle, 0);
CloseHandle(handle);
if result == FALSE {
let error: WIN32_ERROR = GetLastError();
return Err(Error::other(format!(
"Failed to terminate process {}:{}: {:#x}",
process.get_name(),
process.pid,
error
)));
}
Ok(())
}
/// Fetches the extended table described by `T` and inserts into `pids` the
/// owning PID of every row whose local port equals `port`.
///
/// The table function is called with a buffer that starts at the size of `T`
/// and is re-allocated to the size the function reports for as long as it
/// answers `ERROR_INSUFFICIENT_BUFFER`.
///
/// # Errors
/// * the table function returns anything other than `NO_ERROR` or
///   `ERROR_INSUFFICIENT_BUFFER`;
/// * the buffer cannot be allocated, or the reported size does not form a
///   valid allocation layout.
///
/// # Safety
/// Calls the FFI function `T::TABLE_FN` and `T::get_processes` on a raw
/// buffer; declared `unsafe` to match those calls.
unsafe fn use_extended_table<T>(port: u16, pids: &mut HashSet<u32>) -> Result<()>
where
    T: TableClass,
{
    let mut layout: Layout = Layout::new::<T>();
    let mut buffer: *mut u8 = alloc(layout);
    // `alloc` signals failure with a null pointer; it must not reach the FFI
    // call or `dealloc`.
    if buffer.is_null() {
        return Err(Error::new(
            std::io::ErrorKind::OutOfMemory,
            "Failed to allocate buffer for extended table",
        ));
    }
    let mut size: u32 = layout.size() as u32;

    loop {
        let result: WIN32_ERROR = (T::TABLE_FN)(
            buffer.cast(),
            &mut size,
            FALSE,
            T::FAMILY,
            T::TABLE_CLASS,
            0,
        );
        if result == NO_ERROR {
            break;
        }

        // Every non-success path releases the current buffer first.
        dealloc(buffer, layout);
        if result != ERROR_INSUFFICIENT_BUFFER {
            return Err(Error::other(format!(
                "Failed to get size estimate for extended table: {:#x}",
                result
            )));
        }

        // Never go below the size of `T` itself, so the table header is always
        // inside the allocation, and validate the layout instead of trusting
        // the externally reported size.
        let required = (size as usize).max(Layout::new::<T>().size());
        layout = Layout::from_size_align(required, layout.align()).map_err(Error::other)?;
        buffer = alloc(layout);
        if buffer.is_null() {
            return Err(Error::new(
                std::io::ErrorKind::OutOfMemory,
                "Failed to allocate buffer for extended table",
            ));
        }
    }

    let table: *const T = buffer.cast();
    T::get_processes(table, port, pids);
    dealloc(buffer, layout);
    Ok(())
}
/// Function-pointer type used for `TableClass::TABLE_FN`; in this file it is
/// assigned `GetExtendedTcpTable` or `GetExtendedUdpTable`.
type GetExtendedTable =
unsafe extern "system" fn(*mut c_void, *mut u32, i32, AddressFamily, i32, u32) -> WIN32_ERROR;
/// Address-family argument type of a `GetExtendedTable` function.
type AddressFamily = u32;
/// `AF_INET` converted to the argument type.
const INET: AddressFamily = AF_INET as u32;
/// `AF_INET6` converted to the argument type.
const INET6: AddressFamily = AF_INET6 as u32;
/// Table-class argument type of a `GetExtendedTable` function.
type TableClassType = i32;
/// Table class requested for the TCP tables.
const TCP_TYPE: TableClassType = TCP_TABLE_OWNER_MODULE_ALL;
/// Table class requested for the UDP tables.
const UDP_TYPE: TableClassType = UDP_TABLE_OWNER_MODULE;
/// Describes one extended-table type for `use_extended_table`: which function
/// fills it, which arguments that function receives, and how owning PIDs are
/// read out of the filled table.
trait TableClass {
/// Function that fills the buffer with the table.
const TABLE_FN: GetExtendedTable;
/// Address-family argument passed to `TABLE_FN`.
const FAMILY: AddressFamily;
/// Table-class argument passed to `TABLE_FN`.
const TABLE_CLASS: TableClassType;
/// Inserts into `pids` the owning PID of each row matching `port`.
///
/// # Safety
/// The implementations in this file dereference `table` and read
/// `dwNumEntries` rows starting at its `table` field, so the pointer must
/// refer to a buffer that actually holds that many rows.
unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>);
}
/// Generates the `TableClass::get_processes` body for a table whose rows have
/// type `$ty`.
///
/// The generated function views the `dwNumEntries` rows starting at the
/// table's `table` field as a slice and adds `dwOwningPid` to `pids` for every
/// row whose `dwLocalPort` (low 16 bits, byte-swapped via `to_be`) equals
/// `port`.
macro_rules! impl_get_processes {
    ($ty:ty) => {
        unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>) {
            let entry_count = addr_of!((*table).dwNumEntries).read_unaligned() as usize;
            let first_row: *const $ty = addr_of!((*table).table).cast();
            let rows = slice::from_raw_parts(first_row, entry_count);
            pids.extend(
                rows.iter()
                    .filter(|row| (row.dwLocalPort as u16).to_be() == port)
                    .map(|row| row.dwOwningPid),
            );
        }
    };
}
/// TCP table for `AF_INET`: filled by `GetExtendedTcpTable` with table class
/// `TCP_TABLE_OWNER_MODULE_ALL`; rows are `MIB_TCPROW_OWNER_MODULE`.
impl TableClass for MIB_TCPTABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedTcpTable;
const FAMILY: AddressFamily = INET;
const TABLE_CLASS: TableClassType = TCP_TYPE;
impl_get_processes!(MIB_TCPROW_OWNER_MODULE);
}
/// TCP table for `AF_INET6`: filled by `GetExtendedTcpTable` with table class
/// `TCP_TABLE_OWNER_MODULE_ALL`; rows are `MIB_TCP6ROW_OWNER_MODULE`.
impl TableClass for MIB_TCP6TABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedTcpTable;
const FAMILY: AddressFamily = INET6;
const TABLE_CLASS: TableClassType = TCP_TYPE;
impl_get_processes!(MIB_TCP6ROW_OWNER_MODULE);
}
/// UDP table for `AF_INET`: filled by `GetExtendedUdpTable` with table class
/// `UDP_TABLE_OWNER_MODULE`; rows are `MIB_UDPROW_OWNER_MODULE`.
impl TableClass for MIB_UDPTABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedUdpTable;
const FAMILY: AddressFamily = INET;
const TABLE_CLASS: TableClassType = UDP_TYPE;
impl_get_processes!(MIB_UDPROW_OWNER_MODULE);
}
/// UDP table for `AF_INET6`: filled by `GetExtendedUdpTable` with table class
/// `UDP_TABLE_OWNER_MODULE`; rows are `MIB_UDP6ROW_OWNER_MODULE`.
impl TableClass for MIB_UDP6TABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedUdpTable;
const FAMILY: AddressFamily = INET6;
const TABLE_CLASS: TableClassType = UDP_TYPE;
impl_get_processes!(MIB_UDP6ROW_OWNER_MODULE);
}