use crate::engine::WfpEngine;
use crate::errors::{WfpError, WfpResult};
use std::ffi::{c_void, OsString};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::os::windows::ffi::OsStringExt;
use std::path::PathBuf;
use std::sync::mpsc;
use std::time::{Duration, SystemTime};
use windows::core::GUID;
use windows::Win32::Foundation::{ERROR_SUCCESS, FILETIME, HANDLE};
use windows::Win32::NetworkManagement::WindowsFilteringPlatform::{
FwpmNetEventSubscribe0, FwpmNetEventUnsubscribe0, FWPM_NET_EVENT1, FWPM_NET_EVENT_CALLBACK0,
FWPM_NET_EVENT_SUBSCRIPTION0,
};
/// One Windows Filtering Platform net event, decoded into owned Rust values.
///
/// Built by the subscription callback and delivered through the channel owned by
/// `WfpEventSubscription`; every field is an owned value, so an event holds no
/// pointers into WFP-owned memory and can outlive the callback that produced it.
#[derive(Debug, Clone)]
pub struct NetworkEvent {
    /// When the event occurred, converted from the header's `FILETIME`.
    /// Instants before the Unix epoch are reported as `SystemTime::UNIX_EPOCH`.
    pub timestamp: SystemTime,
    /// Kind of event, decoded from the raw WFP event type code.
    pub event_type: NetworkEventType,
    /// Application identity taken from the header's `appId` blob, if one could be read.
    /// NOTE(review): WFP app ids are usually NT device paths
    /// (`\device\harddiskvolumeN\...`) rather than drive-letter paths — confirm
    /// before comparing against DOS-style paths.
    pub app_path: Option<PathBuf>,
    /// IP protocol number from the event header (e.g. 6 = TCP, 17 = UDP).
    pub protocol: u8,
    /// Local IP address, or `None` when no address could be decoded from the event.
    pub local_addr: Option<IpAddr>,
    /// Remote IP address, or `None` when no address could be decoded from the event.
    pub remote_addr: Option<IpAddr>,
    /// Local port as reported by the event header.
    pub local_port: u16,
    /// Remote port as reported by the event header.
    pub remote_port: u16,
    /// Id of the filter responsible for the event; only filled for
    /// `ClassifyDrop` events that carry drop details.
    pub filter_id: Option<u64>,
    /// Id of the layer where the drop happened; only filled for `ClassifyDrop`
    /// events that carry drop details.
    pub layer_id: Option<u16>,
}
/// Kind of WFP net event, decoded from the raw event `type` code.
///
/// Codes 3, 6 and 7 have dedicated variants; every other code is preserved
/// verbatim in [`NetworkEventType::Other`].
///
/// The explicit discriminants only document the raw codes. Because `Other`
/// carries data, the enum cannot be converted back with `as u32`. Decoding goes
/// through `From<u32>`. Note that `Other(3)` is a different value from
/// `ClassifyDrop` under `==`, so build values via `From<u32>` to keep them
/// canonical.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum NetworkEventType {
    /// Raw code 3 (presumably `FWPM_NET_EVENT_TYPE_CLASSIFY_DROP` — confirm).
    ClassifyDrop = 3,
    /// Raw code 6 (presumably `FWPM_NET_EVENT_TYPE_CLASSIFY_ALLOW` — confirm).
    ClassifyAllow = 6,
    /// Raw code 7 (presumably `FWPM_NET_EVENT_TYPE_CAPABILITY_DROP` — confirm).
    CapabilityDrop = 7,
    /// Any other raw code, kept as-is.
    Other(u32),
}
impl From<u32> for NetworkEventType {
fn from(value: u32) -> Self {
match value {
3 => NetworkEventType::ClassifyDrop,
6 => NetworkEventType::ClassifyAllow,
7 => NetworkEventType::CapabilityDrop,
other => NetworkEventType::Other(other),
}
}
}
impl std::fmt::Display for NetworkEventType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NetworkEventType::ClassifyDrop => write!(f, "ClassifyDrop"),
NetworkEventType::ClassifyAllow => write!(f, "ClassifyAllow"),
NetworkEventType::CapabilityDrop => write!(f, "CapabilityDrop"),
NetworkEventType::Other(n) => write!(f, "Other({})", n),
}
}
}
/// A live subscription to Windows Filtering Platform net events.
///
/// Created by `WfpEventSubscription::new`. Decoded events arrive on an `mpsc`
/// channel and are read with `try_recv`, `recv` or `iter`. Dropping the value
/// tears the subscription down (see the `Drop` impl).
///
/// NOTE(review): `engine` is a raw pointer captured from a `&WfpEngine`, and no
/// lifetime ties this struct to that engine. Because `Drop` dereferences the pointer,
/// the engine must stay alive *and must not move* while the subscription exists.
/// Consider storing a lifetime-bound reference or the engine handle instead.
///
/// The raw-pointer fields make this type neither `Send` nor `Sync`.
pub struct WfpEventSubscription {
    // Address of the engine passed to `new`; dereferenced again on drop to unsubscribe.
    engine: *const WfpEngine,
    // Handle returned by `FwpmNetEventSubscribe0`.
    subscription_handle: HANDLE,
    // Copy of the callback pointer that was registered with WFP.
    _callback: Box<FWPM_NET_EVENT_CALLBACK0>,
    // Consuming end of the channel whose sender lives behind `sender_context`.
    receiver: mpsc::Receiver<NetworkEvent>,
    // `Box<mpsc::Sender<NetworkEvent>>` leaked with `Box::into_raw`; passed to WFP as
    // the callback context and reclaimed in `Drop`.
    sender_context: *mut c_void,
}
impl WfpEventSubscription {
    /// Registers a WFP net-event subscription on `engine`.
    ///
    /// A fresh `mpsc` channel is created. Its sender is boxed, leaked with
    /// `Box::into_raw`, and passed to `FwpmNetEventSubscribe0` as the context for
    /// `event_callback`. The returned value owns both that allocation and the
    /// receiving end of the channel.
    ///
    /// # Errors
    ///
    /// Returns [`WfpError::Other`] carrying the raw status code when
    /// `FwpmNetEventSubscribe0` does not report `ERROR_SUCCESS`. In that case the
    /// boxed sender is freed before returning.
    ///
    /// NOTE(review): only the address of `engine` is kept, and it is dereferenced
    /// again on drop. The engine must therefore outlive the subscription and must
    /// not move.
    pub fn new(engine: &WfpEngine) -> WfpResult<Self> {
        let (tx, rx) = mpsc::channel::<NetworkEvent>();
        let sender_context = Box::into_raw(Box::new(tx)) as *mut c_void;
        let callback: FWPM_NET_EVENT_CALLBACK0 = Some(event_callback);
        let template = FWPM_NET_EVENT_SUBSCRIPTION0 {
            enumTemplate: std::ptr::null_mut(),
            flags: 0,
            sessionKey: GUID::zeroed(),
        };
        let mut subscription_handle = HANDLE::default();

        // SAFETY: every pointer argument refers either to a live local or to the
        // leaked sender, which stays allocated until `Drop` (or the failure path below).
        let status = unsafe {
            FwpmNetEventSubscribe0(
                engine.handle(),
                &template,
                callback,
                Some(sender_context as *const c_void),
                &mut subscription_handle,
            )
        };

        if status != ERROR_SUCCESS.0 {
            // SAFETY: the subscription was not created, so nothing else refers to
            // `sender_context`; reclaim the box produced above exactly once.
            unsafe {
                drop(Box::from_raw(sender_context as *mut mpsc::Sender<NetworkEvent>));
            }
            return Err(WfpError::Other(format!(
                "Failed to subscribe to WFP events: error code {}",
                status
            )));
        }

        Ok(Self {
            engine: engine as *const WfpEngine,
            subscription_handle,
            _callback: Box::new(callback),
            receiver: rx,
            sender_context,
        })
    }

    /// Returns the next queued event without blocking.
    ///
    /// # Errors
    ///
    /// `TryRecvError::Empty` when nothing is queued. `Disconnected` is not expected
    /// while `self` is alive, because the sender is only released when the
    /// subscription is dropped.
    pub fn try_recv(&self) -> Result<NetworkEvent, mpsc::TryRecvError> {
        mpsc::Receiver::try_recv(&self.receiver)
    }

    /// Blocks the calling thread until the next event arrives.
    pub fn recv(&self) -> Result<NetworkEvent, mpsc::RecvError> {
        mpsc::Receiver::recv(&self.receiver)
    }

    /// Blocking iterator over incoming events.
    ///
    /// The sender is owned by this subscription, so the iterator does not end on
    /// its own while `self` is alive. Break out of the loop explicitly.
    pub fn iter(&self) -> mpsc::Iter<'_, NetworkEvent> {
        mpsc::Receiver::iter(&self.receiver)
    }
}
impl Drop for WfpEventSubscription {
fn drop(&mut self) {
if !self.subscription_handle.is_invalid() && !self.engine.is_null() {
unsafe {
let _ = FwpmNetEventUnsubscribe0((*self.engine).handle(), self.subscription_handle);
}
}
if !self.sender_context.is_null() {
unsafe {
drop(Box::from_raw(
self.sender_context as *mut mpsc::Sender<NetworkEvent>,
));
}
}
}
}
/// WFP net-event callback registered by `WfpEventSubscription::new`.
///
/// `context` is the leaked `Box<mpsc::Sender<NetworkEvent>>` owned by the
/// subscription, and `event_ptr` is the event being reported. If either pointer
/// is null, the call is ignored. Otherwise the event is decoded with
/// `parse_network_event` and forwarded over the channel. A failed send (the
/// receiver is already gone) is ignored.
///
/// Decoding and sending run under `catch_unwind`, because a panic must not
/// unwind across this `extern "system"` boundary back into WFP. An event whose
/// processing panics is dropped instead.
///
/// # Safety
///
/// Must only be invoked by WFP, with the context pointer registered in
/// `WfpEventSubscription::new`, while that subscription's sender is still allocated.
unsafe extern "system" fn event_callback(context: *mut c_void, event_ptr: *const FWPM_NET_EVENT1) {
    if context.is_null() || event_ptr.is_null() {
        return;
    }

    // SAFETY: both pointers are non-null. `context` is the boxed sender owned by
    // the live subscription. `event_ptr` is assumed valid for the duration of this
    // call (WFP-provided — TODO confirm lifetime guarantees).
    let sender = unsafe { &*(context as *const mpsc::Sender<NetworkEvent>) };
    let event = unsafe { &*event_ptr };

    // No state is observed after a caught panic, so asserting unwind safety is sound.
    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // SAFETY: `event` is a valid reference to the WFP-provided event.
        if let Some(network_event) = unsafe { parse_network_event(event) } {
            let _ = sender.send(network_event);
        }
    }));
}
/// Decodes a WFP net event into an owned [`NetworkEvent`].
///
/// A header field is only trusted when its `FWPM_NET_EVENT_FLAG_*_SET` bit is
/// present in `header.flags`:
/// - unset optional fields (app path, addresses) become `None`;
/// - unset numeric fields (protocol, ports) become `0`.
///
/// In particular, the `appId` blob pointer is only dereferenced when `APP_ID_SET`
/// is present, and the read is bounded by the blob's declared `size`. A blob with
/// no NUL terminator inside that size is treated as malformed and ignored.
///
/// `filter_id` and `layer_id` come from the classify-drop payload and are only
/// filled for [`NetworkEventType::ClassifyDrop`] events.
///
/// Currently always returns `Some`.
///
/// # Safety
///
/// `event` must be a WFP-provided event. When `APP_ID_SET` is present,
/// `appId.data` must be readable and `u16`-aligned for `appId.size` bytes. For
/// classify-drop events, `Anonymous.classifyDrop` must be null or valid.
unsafe fn parse_network_event(event: &FWPM_NET_EVENT1) -> Option<NetworkEvent> {
    // FWPM_NET_EVENT_FLAG_* bit values (fwpmtypes.h).
    const FLAG_IP_PROTOCOL_SET: u32 = 0x0000_0001;
    const FLAG_LOCAL_ADDR_SET: u32 = 0x0000_0002;
    const FLAG_REMOTE_ADDR_SET: u32 = 0x0000_0004;
    const FLAG_LOCAL_PORT_SET: u32 = 0x0000_0008;
    const FLAG_REMOTE_PORT_SET: u32 = 0x0000_0010;
    const FLAG_APP_ID_SET: u32 = 0x0000_0020;
    const FLAG_IP_VERSION_SET: u32 = 0x0000_0100;

    let header = &event.header;
    let flags = header.flags;
    let has = |flag: u32| (flags & flag) != 0;

    let event_type = NetworkEventType::from(event.r#type.0 as u32);
    let timestamp = filetime_to_systemtime(&header.timeStamp);

    let app_path = if has(FLAG_APP_ID_SET) && !header.appId.data.is_null() {
        // SAFETY: the flag says `appId` is populated. `data` is non-null and
        // readable for `size` bytes per the function contract.
        let units = unsafe {
            std::slice::from_raw_parts(
                header.appId.data as *const u16,
                header.appId.size as usize / 2,
            )
        };
        // Requiring the terminator inside the blob keeps the NUL scan in
        // `parse_wide_string` from running past the declared size.
        if units.contains(&0) {
            unsafe { parse_wide_string(units.as_ptr()) }.map(PathBuf::from)
        } else {
            None
        }
    } else {
        None
    };

    // 0 = IPv4, 1 = IPv6 (FWP_IP_VERSION). The union member read below always
    // matches the reported version, which is what makes each `unsafe` read sound.
    let ip_version = has(FLAG_IP_VERSION_SET).then(|| header.ipVersion.0);
    let local_addr = match ip_version {
        Some(0) if has(FLAG_LOCAL_ADDR_SET) => unsafe { parse_ipv4_union(&header.Anonymous1) },
        Some(1) if has(FLAG_LOCAL_ADDR_SET) => unsafe { parse_ipv6_union(&header.Anonymous1) },
        _ => None,
    };
    let remote_addr = match ip_version {
        Some(0) if has(FLAG_REMOTE_ADDR_SET) => unsafe {
            parse_ipv4_union_remote(&header.Anonymous2)
        },
        Some(1) if has(FLAG_REMOTE_ADDR_SET) => unsafe {
            parse_ipv6_union_remote(&header.Anonymous2)
        },
        _ => None,
    };

    let (filter_id, layer_id) = if event_type == NetworkEventType::ClassifyDrop {
        // SAFETY: for classify-drop events the union holds `classifyDrop`; the
        // pointer is null-checked before it is dereferenced.
        let drop_info = unsafe { event.Anonymous.classifyDrop };
        if drop_info.is_null() {
            (None, None)
        } else {
            let drop_info = unsafe { &*drop_info };
            (Some(drop_info.filterId), Some(drop_info.layerId))
        }
    } else {
        (None, None)
    };

    Some(NetworkEvent {
        timestamp,
        event_type,
        app_path,
        protocol: if has(FLAG_IP_PROTOCOL_SET) { header.ipProtocol } else { 0 },
        local_addr,
        remote_addr,
        local_port: if has(FLAG_LOCAL_PORT_SET) { header.localPort } else { 0 },
        remote_port: if has(FLAG_REMOTE_PORT_SET) { header.remotePort } else { 0 },
        filter_id,
        layer_id,
    })
}
/// Converts a Windows `FILETIME` (100 ns intervals since 1601-01-01 UTC) into a
/// [`SystemTime`].
///
/// Instants before the Unix epoch are clamped to `SystemTime::UNIX_EPOCH`.
/// Instants too far in the future for the platform's `SystemTime` to represent
/// are clamped the same way rather than panicking. This function runs on the
/// WFP event-callback path, where a panic must not escape.
fn filetime_to_systemtime(ft: &FILETIME) -> SystemTime {
    // 100 ns intervals between the FILETIME epoch and the Unix epoch.
    const WINDOWS_TO_UNIX_EPOCH: u64 = 116_444_736_000_000_000;
    const INTERVALS_PER_SEC: u64 = 10_000_000;
    const NANOS_PER_INTERVAL: u64 = 100;

    let intervals = (u64::from(ft.dwHighDateTime) << 32) | u64::from(ft.dwLowDateTime);
    let Some(since_unix) = intervals.checked_sub(WINDOWS_TO_UNIX_EPOCH) else {
        return SystemTime::UNIX_EPOCH;
    };

    // The remainder is < 10^7, so nanoseconds stay below 10^9 and fit in u32.
    let offset = Duration::new(
        since_unix / INTERVALS_PER_SEC,
        ((since_unix % INTERVALS_PER_SEC) * NANOS_PER_INTERVAL) as u32,
    );
    SystemTime::UNIX_EPOCH
        .checked_add(offset)
        .unwrap_or(SystemTime::UNIX_EPOCH)
}
/// Reads a NUL-terminated UTF-16 string starting at `ptr`.
///
/// Returns `None` for a null pointer or an empty string. Otherwise returns the
/// code units before the first NUL, converted with `OsString::from_wide`, which
/// keeps ill-formed UTF-16 intact.
///
/// # Safety
///
/// `ptr` must be null, or must point to a readable, `u16`-aligned sequence that
/// contains a NUL terminator. Every unit up to and including the terminator is read.
unsafe fn parse_wide_string(ptr: *const u16) -> Option<OsString> {
    if ptr.is_null() {
        return None;
    }

    // Number of code units before the terminator.
    let units = (0..)
        .take_while(|&offset| unsafe { *ptr.add(offset) } != 0)
        .count();

    (units > 0).then(|| OsString::from_wide(unsafe { std::slice::from_raw_parts(ptr, units) }))
}
/// Decodes the IPv4 local address from the header's local-address union.
///
/// WFP reports IPv4 addresses here as a host-byte-order `u32` (`0x7F00_0001`
/// means `127.0.0.1`). The value is therefore converted numerically with
/// `Ipv4Addr::from(u32)`. Splitting it with `to_ne_bytes()` would reverse the
/// octets on little-endian Windows targets.
///
/// # Safety
///
/// `localAddrV4` must be the active union member, i.e. the event's IP version
/// is IPv4.
unsafe fn parse_ipv4_union(
    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
) -> Option<IpAddr> {
    // SAFETY: the active-member requirement is upheld by the caller.
    let host_order = unsafe { addr_union.localAddrV4 };
    Some(IpAddr::V4(Ipv4Addr::from(host_order)))
}
/// Decodes the IPv6 local address from the header's local-address union.
///
/// The sixteen bytes go to `Ipv6Addr::from([u8; 16])` in array order (the first
/// byte is the most significant), so decoding always succeeds.
///
/// # Safety
///
/// `localAddrV6` must be the active union member, i.e. the event's IP version
/// is IPv6.
unsafe fn parse_ipv6_union(
    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_0,
) -> Option<IpAddr> {
    // SAFETY: the active-member requirement is upheld by the caller.
    let octets: [u8; 16] = unsafe { addr_union.localAddrV6.byteArray16 };
    let address = Ipv6Addr::from(octets);
    Some(address.into())
}
/// Decodes the IPv4 remote address from the header's remote-address union.
///
/// WFP reports IPv4 addresses here as a host-byte-order `u32` (`0x0808_0808`
/// means `8.8.8.8`). The value is therefore converted numerically with
/// `Ipv4Addr::from(u32)`. Splitting it with `to_ne_bytes()` would reverse the
/// octets on little-endian Windows targets.
///
/// # Safety
///
/// `remoteAddrV4` must be the active union member, i.e. the event's IP version
/// is IPv4.
unsafe fn parse_ipv4_union_remote(
    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
) -> Option<IpAddr> {
    // SAFETY: the active-member requirement is upheld by the caller.
    let host_order = unsafe { addr_union.remoteAddrV4 };
    Some(IpAddr::V4(Ipv4Addr::from(host_order)))
}
/// Decodes the IPv6 remote address from the header's remote-address union.
///
/// The sixteen bytes go to `Ipv6Addr::from([u8; 16])` in array order (the first
/// byte is the most significant), so decoding always succeeds.
///
/// # Safety
///
/// `remoteAddrV6` must be the active union member, i.e. the event's IP version
/// is IPv6.
unsafe fn parse_ipv6_union_remote(
    addr_union: &windows::Win32::NetworkManagement::WindowsFilteringPlatform::FWPM_NET_EVENT_HEADER1_1,
) -> Option<IpAddr> {
    // SAFETY: the active-member requirement is upheld by the caller.
    let octets: [u8; 16] = unsafe { addr_union.remoteAddrV6.byteArray16 };
    let address = Ipv6Addr::from(octets);
    Some(address.into())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a 64-bit count of 100 ns intervals into a `FILETIME`.
    fn filetime(intervals: u64) -> FILETIME {
        FILETIME {
            dwLowDateTime: (intervals & 0xFFFF_FFFF) as u32,
            dwHighDateTime: (intervals >> 32) as u32,
        }
    }

    #[test]
    fn test_event_type_conversion() {
        let known: [(u32, NetworkEventType); 4] = [
            (3, NetworkEventType::ClassifyDrop),
            (6, NetworkEventType::ClassifyAllow),
            (7, NetworkEventType::CapabilityDrop),
            (99, NetworkEventType::Other(99)),
        ];
        for (raw, expected) in known {
            assert_eq!(NetworkEventType::from(raw), expected);
        }
    }

    #[test]
    fn test_event_type_boundaries() {
        // Neighbours of the recognised codes must fall through to `Other`.
        for raw in [0u32, 2, 4, 5, 8] {
            assert_eq!(NetworkEventType::from(raw), NetworkEventType::Other(raw));
        }
    }

    #[test]
    fn test_filetime_to_systemtime_unix_epoch() {
        let converted = filetime_to_systemtime(&filetime(116_444_736_000_000_000));
        assert_eq!(converted, SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_before_unix_epoch() {
        assert_eq!(filetime_to_systemtime(&filetime(0)), SystemTime::UNIX_EPOCH);
    }

    #[test]
    fn test_filetime_to_systemtime_known_date() {
        // 2000-01-01T00:00:00Z expressed in FILETIME intervals.
        let since_epoch = filetime_to_systemtime(&filetime(125_911_584_000_000_000))
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();
        assert_eq!(since_epoch.as_secs(), 946_684_800);
    }

    #[test]
    fn test_network_event_struct_creation() {
        let event = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyDrop,
            app_path: Some(PathBuf::from(r"C:\test.exe")),
            protocol: 6,
            local_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            remote_addr: Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            local_port: 12345,
            remote_port: 443,
            filter_id: Some(42),
            layer_id: Some(1),
        };
        assert_eq!(event.event_type, NetworkEventType::ClassifyDrop);
        assert_eq!(event.protocol, 6);
        assert_eq!(event.local_port, 12345);
        assert_eq!(event.remote_port, 443);
        assert!(event.app_path.is_some());
    }

    #[test]
    fn test_network_event_clone() {
        let original = NetworkEvent {
            timestamp: SystemTime::UNIX_EPOCH,
            event_type: NetworkEventType::ClassifyAllow,
            app_path: None,
            protocol: 17,
            local_addr: None,
            remote_addr: None,
            local_port: 0,
            remote_port: 0,
            filter_id: None,
            layer_id: None,
        };
        let copy = original.clone();
        assert_eq!(copy.event_type, NetworkEventType::ClassifyAllow);
        assert_eq!(copy.protocol, 17);
        assert!(copy.app_path.is_none());
    }

    #[test]
    fn test_network_event_type_display() {
        let cases = [
            (NetworkEventType::ClassifyDrop, "ClassifyDrop"),
            (NetworkEventType::ClassifyAllow, "ClassifyAllow"),
            (NetworkEventType::CapabilityDrop, "CapabilityDrop"),
            (NetworkEventType::Other(42), "Other(42)"),
        ];
        for (kind, rendered) in cases {
            assert_eq!(kind.to_string(), rendered);
        }
    }
}