use std::ffi::OsStr;
use std::io;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use windows_sys::Win32::Foundation::ERROR_SUCCESS;
use windows_sys::Win32::NetworkManagement::IpHelper::ConvertInterfaceAliasToLuid;
use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
use windows_sys::Win32::NetworkManagement::WindowsFilteringPlatform::{
FWP_BYTE_BLOB_TYPE, FWP_MATCH_EQUAL, FWP_MATCH_GREATER, FWP_MATCH_GREATER_OR_EQUAL,
FWP_MATCH_LESS, FWP_MATCH_LESS_OR_EQUAL, FWP_MATCH_RANGE, FWP_UINT8, FWP_UINT16, FWP_UINT32,
FWP_UINT64, FWP_UNICODE_STRING_TYPE, FWP_V4_ADDR_AND_MASK, FWP_V4_ADDR_MASK,
FWP_V6_ADDR_AND_MASK, FWP_V6_ADDR_MASK, FWPM_CONDITION_ALE_APP_ID,
FWPM_CONDITION_IP_LOCAL_ADDRESS, FWPM_CONDITION_IP_LOCAL_INTERFACE,
FWPM_CONDITION_IP_LOCAL_PORT, FWPM_CONDITION_IP_PROTOCOL, FWPM_CONDITION_IP_REMOTE_ADDRESS,
FWPM_CONDITION_IP_REMOTE_PORT, FWPM_FILTER_CONDITION0,
};
use windows_sys::core::GUID;
use crate::blob::{OwnedByteBlob, app_id_from_filename};
use crate::util::string_to_null_terminated_utf16;
// Field key used when matching on the ICMP type. It is defined here as an alias of the
// local-port field key, so both names refer to the same GUID.
const FWPM_CONDITION_ICMP_TYPE: GUID = FWPM_CONDITION_IP_LOCAL_PORT;
// Field key used when matching on the ICMP code. It is defined here as an alias of the
// remote-port field key, so both names refer to the same GUID.
const FWPM_CONDITION_ICMP_CODE: GUID = FWPM_CONDITION_IP_REMOTE_PORT;
/// Builder for a condition on the local or remote port.
///
/// `Value` is a marker type recording whether a port has been supplied yet; `build()` is
/// only implemented for [`PortConditionBuilderHasValue`]. Note that the derived `Clone`
/// impl is bounded on `Value: Clone`.
#[derive(Clone)]
pub struct PortConditionBuilder<Value> {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
// Zero-sized marker that ties the `Value` type parameter to this struct.
_pd: std::marker::PhantomData<Value>,
}
/// Marker for a [`PortConditionBuilder`] that has not been given a port yet.
///
/// Derives `Clone`/`Copy` because `#[derive(Clone)]` on the builder adds a `Value: Clone`
/// bound; without these derives the builder could never actually be cloned.
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct PortConditionBuilderMissingValue;
/// Marker for a [`PortConditionBuilder`] that has a port and can therefore be built.
///
/// Derives `Clone`/`Copy` for the same reason as [`PortConditionBuilderMissingValue`].
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct PortConditionBuilderHasValue;
impl PortConditionBuilder<PortConditionBuilderMissingValue> {
pub fn remote() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::RemotePort),
_pd: std::marker::PhantomData,
}
}
pub fn local() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::LocalPort),
_pd: std::marker::PhantomData,
}
}
}
impl<Value> PortConditionBuilder<Value> {
    /// Requires the selected port field to equal `port`.
    ///
    /// Overwrites any match type and value set earlier and moves the builder into the
    /// "has value" state, which makes `build()` available.
    pub fn equal(self, port: u16) -> PortConditionBuilder<PortConditionBuilderHasValue> {
        let Self { builder, .. } = self;
        let builder = builder.match_type(MatchType::Equal).value_u16(port);
        PortConditionBuilder {
            builder,
            _pd: std::marker::PhantomData,
        }
    }
}
impl PortConditionBuilder<PortConditionBuilderHasValue> {
    /// Finalizes the builder into a [`Condition`].
    ///
    /// The inner builder only yields `None` when field, match type or value is missing;
    /// the `expect` encodes that the "has value" state is supposed to rule that out.
    pub fn build(self) -> Condition {
        let Self { builder, .. } = self;
        builder.build().expect("condition has value")
    }
}
/// Builder for a condition on the IP protocol field.
///
/// Unlike the other builders in this file it carries no marker type parameter; the
/// protocol is chosen through the named constructors (`tcp()`, `udp()`, ...).
#[derive(Clone)]
pub struct ProtocolConditionBuilder {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
}
impl ProtocolConditionBuilder {
    // IANA-assigned IP protocol numbers used by the constructors below.
    const ICMP: u8 = 1;
    const TCP: u8 = 6;
    const UDP: u8 = 17;
    const ICMPV6: u8 = 58;

    /// Condition matching TCP (IP protocol number 6).
    pub fn tcp() -> Self {
        Self::new().equal(Self::TCP)
    }

    /// Condition matching UDP (IP protocol number 17).
    pub fn udp() -> Self {
        Self::new().equal(Self::UDP)
    }

    /// Condition matching ICMP (IP protocol number 1).
    pub fn icmp() -> Self {
        Self::new().equal(Self::ICMP)
    }

    /// Condition matching ICMPv6 (IP protocol number 58).
    pub fn icmpv6() -> Self {
        Self::new().equal(Self::ICMPV6)
    }

    /// Builder bound to the protocol field, with no protocol selected yet.
    fn new() -> Self {
        Self {
            builder: ConditionBuilder::default().field(ConditionField::Protocol),
        }
    }

    /// Requires the protocol field to equal `protocol`.
    fn equal(self, protocol: u8) -> Self {
        Self {
            builder: self.builder.match_type(MatchType::Equal).value_u8(protocol),
        }
    }

    /// Finalizes the builder into a [`Condition`].
    ///
    /// # Panics
    ///
    /// Panics if no protocol was selected. That is the case for a builder obtained
    /// through `Default::default()`, which only sets the field; use `tcp()`, `udp()`,
    /// `icmp()` or `icmpv6()` instead.
    pub fn build(self) -> Condition {
        // The previous message ("all values are set") was false in exactly the situation
        // where this panic fires, so state the violated precondition instead.
        self.builder.build().expect(
            "protocol condition has no protocol selected; construct it with tcp(), udp(), icmp() or icmpv6()",
        )
    }
}
/// Default is the same as the private `new()`: the protocol field is selected, but no
/// match type or value is set.
///
/// Calling `build()` on this value panics, because the underlying builder returns `None`
/// when the match type or value is missing.
impl Default for ProtocolConditionBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builder for a condition on the ICMP type or ICMP code.
///
/// `Value` is a marker type recording whether a comparison value has been supplied;
/// `build()` is only implemented for [`IcmpConditionBuilderHasValue`]. Note that the
/// derived `Clone` impl is bounded on `Value: Clone`.
#[derive(Clone)]
pub struct IcmpConditionBuilder<Value> {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
// Zero-sized marker that ties the `Value` type parameter to this struct.
_pd: std::marker::PhantomData<Value>,
}
/// Marker for an [`IcmpConditionBuilder`] that has not been given a value yet.
///
/// Derives `Clone`/`Copy` because `#[derive(Clone)]` on the builder adds a `Value: Clone`
/// bound; without these derives the builder could never actually be cloned.
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct IcmpConditionBuilderMissingValue;
/// Marker for an [`IcmpConditionBuilder`] that has a value and can therefore be built.
///
/// Derives `Clone`/`Copy` for the same reason as [`IcmpConditionBuilderMissingValue`].
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct IcmpConditionBuilderHasValue;
impl IcmpConditionBuilder<IcmpConditionBuilderMissingValue> {
pub fn r#type() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::IcmpType),
_pd: std::marker::PhantomData,
}
}
pub fn code() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::IcmpCode),
_pd: std::marker::PhantomData,
}
}
}
impl<Value> IcmpConditionBuilder<Value> {
pub fn equal(self, value: u8) -> IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
IcmpConditionBuilder {
builder: self
.builder
.match_type(MatchType::Equal)
.value_u16(value.into()),
_pd: std::marker::PhantomData,
}
}
pub fn greater(self, value: u8) -> IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
IcmpConditionBuilder {
builder: self
.builder
.match_type(MatchType::Greater)
.value_u16(value.into()),
_pd: std::marker::PhantomData,
}
}
pub fn less(self, value: u8) -> IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
IcmpConditionBuilder {
builder: self
.builder
.match_type(MatchType::Less)
.value_u16(value.into()),
_pd: std::marker::PhantomData,
}
}
pub fn greater_or_equal(self, value: u8) -> IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
IcmpConditionBuilder {
builder: self
.builder
.match_type(MatchType::GreaterOrEqual)
.value_u16(value.into()),
_pd: std::marker::PhantomData,
}
}
pub fn less_or_equal(self, value: u8) -> IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
IcmpConditionBuilder {
builder: self
.builder
.match_type(MatchType::LessOrEqual)
.value_u16(value.into()),
_pd: std::marker::PhantomData,
}
}
}
impl IcmpConditionBuilder<IcmpConditionBuilderHasValue> {
    /// Finalizes the builder into a [`Condition`].
    ///
    /// The inner builder only yields `None` when field, match type or value is missing;
    /// the `expect` encodes that the "has value" state is supposed to rule that out.
    pub fn build(self) -> Condition {
        let Self { builder, .. } = self;
        builder.build().expect("condition has value")
    }
}
/// Builder for a condition on the application id field.
///
/// `Value` is a marker type recording whether an application has been supplied;
/// `build()` is only implemented for [`AppIdConditionBuilderHasValue`].
pub struct AppIdConditionBuilder<Value> {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
// Zero-sized marker that ties the `Value` type parameter to this struct.
_pd: std::marker::PhantomData<Value>,
}
/// Marker for an [`AppIdConditionBuilder`] that has not been given an application yet.
#[doc(hidden)]
pub struct AppIdConditionBuilderMissingValue;
/// Marker for an [`AppIdConditionBuilder`] that has an application and can be built.
#[doc(hidden)]
pub struct AppIdConditionBuilderHasValue;
impl AppIdConditionBuilder<AppIdConditionBuilderMissingValue> {
    /// Begins a condition on the application id; the application is supplied via `equal`.
    pub fn new() -> Self {
        let builder = ConditionBuilder::default().field(ConditionField::AppId);
        Self {
            builder,
            _pd: std::marker::PhantomData,
        }
    }
}
impl<Value> AppIdConditionBuilder<Value> {
    /// Requires the application id to equal the one derived from `app_path`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `app_id_from_filename` for `app_path`.
    pub fn equal(
        self,
        app_path: impl AsRef<OsStr>,
    ) -> io::Result<AppIdConditionBuilder<AppIdConditionBuilderHasValue>> {
        // Resolve the path first so a failure returns before the builder is touched.
        let app_id = app_id_from_filename(app_path)?;
        let Self { builder, .. } = self;
        let builder = builder.match_type(MatchType::Equal).value_byte_blob(app_id);
        Ok(AppIdConditionBuilder {
            builder,
            _pd: std::marker::PhantomData,
        })
    }
}
impl AppIdConditionBuilder<AppIdConditionBuilderHasValue> {
    /// Finalizes the builder into a [`Condition`].
    ///
    /// The inner builder only yields `None` when field, match type or value is missing;
    /// the `expect` encodes that the "has value" state is supposed to rule that out.
    pub fn build(self) -> Condition {
        let Self { builder, .. } = self;
        builder.build().expect("condition has value")
    }
}
/// Default is the same as `new()`: the application id field is selected and no value is
/// set. Only the "missing value" state implements `Default`.
impl Default for AppIdConditionBuilder<AppIdConditionBuilderMissingValue> {
fn default() -> Self {
Self::new()
}
}
/// Builder for a condition on the local network interface.
///
/// `Value` is a marker type recording whether an interface has been supplied;
/// `build()` is only implemented for [`InterfaceConditionBuilderHasValue`].
pub struct InterfaceConditionBuilder<Value> {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
// Zero-sized marker that ties the `Value` type parameter to this struct.
_pd: std::marker::PhantomData<Value>,
}
/// Marker for an [`InterfaceConditionBuilder`] that has not been given an interface yet.
#[doc(hidden)]
pub struct InterfaceConditionBuilderMissingValue;
/// Marker for an [`InterfaceConditionBuilder`] that has an interface and can be built.
#[doc(hidden)]
pub struct InterfaceConditionBuilderHasValue;
impl InterfaceConditionBuilder<InterfaceConditionBuilderMissingValue> {
    /// Begins a condition that inspects the local interface.
    pub fn local() -> Self {
        let builder = ConditionBuilder::default().field(ConditionField::LocalInterface);
        Self {
            builder,
            _pd: std::marker::PhantomData,
        }
    }
}
impl<Value> InterfaceConditionBuilder<Value> {
    /// Selects the interface by its alias, resolving the alias to a LUID first.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` built from the status code when
    /// `ConvertInterfaceAliasToLuid` reports anything other than `ERROR_SUCCESS`.
    pub fn alias(
        self,
        alias: impl AsRef<OsStr>,
    ) -> io::Result<InterfaceConditionBuilder<InterfaceConditionBuilderHasValue>> {
        let alias_utf16: Vec<u16> = string_to_null_terminated_utf16(alias);
        let mut resolved = NET_LUID_LH::default();
        // SAFETY: `alias_utf16` stays alive for the whole call and (per the helper's name)
        // is NUL-terminated; `resolved` is a valid, exclusively borrowed output slot.
        let rc = unsafe { ConvertInterfaceAliasToLuid(alias_utf16.as_ptr(), &mut resolved) };
        if rc == ERROR_SUCCESS {
            // SAFETY: the call succeeded, so it filled in the union; `Value` reads it as
            // a plain integer.
            let luid_value = unsafe { resolved.Value };
            Ok(self.luid(luid_value))
        } else {
            Err(io::Error::from_raw_os_error(rc as i32))
        }
    }

    /// Selects the interface by its LUID, matched with `MatchType::Equal`.
    pub fn luid(self, luid: u64) -> InterfaceConditionBuilder<InterfaceConditionBuilderHasValue> {
        let Self { builder, .. } = self;
        let builder = builder.match_type(MatchType::Equal).value_u64(luid);
        InterfaceConditionBuilder {
            builder,
            _pd: std::marker::PhantomData,
        }
    }
}
impl InterfaceConditionBuilder<InterfaceConditionBuilderHasValue> {
    /// Finalizes the builder into a [`Condition`].
    ///
    /// The inner builder only yields `None` when field, match type or value is missing;
    /// the `expect` encodes that the "has value" state is supposed to rule that out.
    pub fn build(self) -> Condition {
        let Self { builder, .. } = self;
        builder.build().expect("condition should be valid")
    }
}
/// Builder for a condition on the local or remote IP address.
///
/// `Value` is a marker type recording whether a subnet has been supplied; `build()` is
/// only implemented for [`IpAddressConditionBuilderHasValue`]. Note that the derived
/// `Clone` impl is bounded on `Value: Clone`.
#[derive(Clone)]
pub struct IpAddressConditionBuilder<Value> {
// Untyped builder that accumulates the field, match type and value.
builder: ConditionBuilder,
// Zero-sized marker that ties the `Value` type parameter to this struct.
_pd: std::marker::PhantomData<Value>,
}
/// Marker for an [`IpAddressConditionBuilder`] that has not been given a subnet yet.
///
/// Derives `Clone`/`Copy` because `#[derive(Clone)]` on the builder adds a `Value: Clone`
/// bound; without these derives the builder could never actually be cloned.
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct IpAddressConditionBuilderMissingValue;
/// Marker for an [`IpAddressConditionBuilder`] that has a subnet and can be built.
///
/// Derives `Clone`/`Copy` for the same reason as [`IpAddressConditionBuilderMissingValue`].
#[doc(hidden)]
#[derive(Clone, Copy)]
pub struct IpAddressConditionBuilderHasValue;
impl IpAddressConditionBuilder<IpAddressConditionBuilderMissingValue> {
pub fn remote() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::RemoteAddress),
_pd: std::marker::PhantomData,
}
}
pub fn local() -> Self {
Self {
builder: ConditionBuilder::default().field(ConditionField::LocalAddress),
_pd: std::marker::PhantomData,
}
}
}
impl<V> IpAddressConditionBuilder<V> {
    /// Matches addresses in the IPv4 subnet `addr`/`prefix_len`.
    ///
    /// The address is stored as the `u32` produced by `u32::from(addr)` next to a mask
    /// derived from `prefix_len`.
    ///
    /// # Panics
    ///
    /// Panics if `prefix_len` is greater than 32.
    pub fn subnet_v4(
        self,
        addr: Ipv4Addr,
        prefix_len: u8,
    ) -> IpAddressConditionBuilder<IpAddressConditionBuilderHasValue> {
        // Compute the mask first so an invalid prefix panics before anything else happens.
        let netmask = v4_prefix_to_mask(prefix_len);
        let numeric_addr = u32::from(addr);
        let builder = self
            .builder
            .match_type(MatchType::Equal)
            .value_v4_addr_mask(numeric_addr, netmask);
        IpAddressConditionBuilder {
            builder,
            _pd: std::marker::PhantomData,
        }
    }

    /// Matches addresses in the IPv6 subnet `addr`/`prefix_len`.
    ///
    /// # Panics
    ///
    /// Panics if `prefix_len` is greater than 128.
    pub fn subnet_v6(
        self,
        addr: Ipv6Addr,
        prefix_len: u8,
    ) -> IpAddressConditionBuilder<IpAddressConditionBuilderHasValue> {
        assert!(prefix_len <= 128, "IPv6 prefix length must be <= 128");
        let octets: [u8; 16] = addr.octets();
        let builder = self
            .builder
            .match_type(MatchType::Equal)
            .value_v6_addr_mask(octets, prefix_len);
        IpAddressConditionBuilder {
            builder,
            _pd: std::marker::PhantomData,
        }
    }
}
impl IpAddressConditionBuilder<IpAddressConditionBuilderHasValue> {
    /// Finalizes the builder into a [`Condition`].
    ///
    /// The inner builder only yields `None` when field, match type or value is missing;
    /// the `expect` encodes that the "has value" state is supposed to rule that out.
    pub fn build(self) -> Condition {
        let Self { builder, .. } = self;
        builder.build().expect("condition should be valid")
    }
}
/// Converts an IPv4 prefix length (0..=32) into a netmask with that many leading one bits.
///
/// # Panics
///
/// Panics if `prefix_len` is greater than 32.
fn v4_prefix_to_mask(prefix_len: u8) -> u32 {
    assert!(prefix_len <= 32, "IPv4 prefix length must be <= 32");
    match prefix_len {
        // A /0 would need a shift by the full 32 bits, which overflows; it is all zeroes.
        0 => 0,
        bits => u32::MAX << (32 - bits),
    }
}
/// Comparison applied between a field and the condition value.
///
/// Each variant's discriminant is the corresponding `FWP_MATCH_*` constant, so a variant
/// can be converted with `as i32` and stored directly in the raw condition.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MatchType {
// Discriminant is `FWP_MATCH_EQUAL`.
Equal = FWP_MATCH_EQUAL,
// Discriminant is `FWP_MATCH_GREATER`.
Greater = FWP_MATCH_GREATER,
// Discriminant is `FWP_MATCH_LESS`.
Less = FWP_MATCH_LESS,
// Discriminant is `FWP_MATCH_GREATER_OR_EQUAL`.
GreaterOrEqual = FWP_MATCH_GREATER_OR_EQUAL,
// Discriminant is `FWP_MATCH_LESS_OR_EQUAL`.
LessOrEqual = FWP_MATCH_LESS_OR_EQUAL,
// Discriminant is `FWP_MATCH_RANGE`; none of the typed builders in this file use it.
Range = FWP_MATCH_RANGE,
}
/// The fields a condition can inspect. Each variant maps to a field-key GUID through
/// `ConditionField::guid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConditionField {
// Maps to `FWPM_CONDITION_IP_REMOTE_ADDRESS`.
RemoteAddress,
// Maps to `FWPM_CONDITION_IP_LOCAL_ADDRESS`.
LocalAddress,
// Maps to `FWPM_CONDITION_IP_REMOTE_PORT`.
RemotePort,
// Maps to `FWPM_CONDITION_IP_LOCAL_PORT`.
LocalPort,
// Maps to `FWPM_CONDITION_IP_PROTOCOL`.
Protocol,
// Maps to this file's `FWPM_CONDITION_ICMP_TYPE` alias.
IcmpType,
// Maps to this file's `FWPM_CONDITION_ICMP_CODE` alias.
IcmpCode,
// Maps to `FWPM_CONDITION_ALE_APP_ID`.
AppId,
// Maps to `FWPM_CONDITION_IP_LOCAL_INTERFACE`.
LocalInterface,
}
impl ConditionField {
pub fn guid(&self) -> &GUID {
match self {
Self::RemoteAddress => &FWPM_CONDITION_IP_REMOTE_ADDRESS,
Self::LocalAddress => &FWPM_CONDITION_IP_LOCAL_ADDRESS,
Self::RemotePort => &FWPM_CONDITION_IP_REMOTE_PORT,
Self::LocalPort => &FWPM_CONDITION_IP_LOCAL_PORT,
Self::Protocol => &FWPM_CONDITION_IP_PROTOCOL,
Self::IcmpType => &FWPM_CONDITION_ICMP_TYPE,
Self::IcmpCode => &FWPM_CONDITION_ICMP_CODE,
Self::AppId => &FWPM_CONDITION_ALE_APP_ID,
Self::LocalInterface => &FWPM_CONDITION_IP_LOCAL_INTERFACE,
}
}
}
/// Untyped builder shared by all the public condition builders.
///
/// All three parts start as `None`; `build()` returns `None` unless every one is set.
#[derive(Default, Clone)]
struct ConditionBuilder {
// Field the condition inspects.
field: Option<ConditionField>,
// Comparison applied to the field.
match_type: Option<MatchType>,
// Value to compare against; reference-counted so clones share the same storage.
value: Option<Arc<ConditionValue>>,
}
/// Owned storage for a condition's value.
///
/// `ConditionBuilder::build` copies the 8/16/32-bit integers into the raw condition by
/// value and stores raw pointers to the other variants, so this value has to stay alive
/// for as long as the raw condition is used.
enum ConditionValue {
// Referenced by pointer from the raw condition.
UInt64(u64),
// Copied by value into the raw condition.
UInt32(u32),
// Copied by value into the raw condition.
UInt16(u16),
// Copied by value into the raw condition.
UInt8(u8),
// UTF-16 buffer; the raw condition points at its first element.
String(Vec<u16>),
// Byte blob; the raw condition stores the pointer returned by `blob.as_ptr()`.
ByteBlob { blob: OwnedByteBlob },
// Referenced by pointer from the raw condition.
V4AddrMask(FWP_V4_ADDR_AND_MASK),
// Referenced by pointer from the raw condition.
V6AddrMask(FWP_V6_ADDR_AND_MASK),
}
impl ConditionBuilder {
pub fn field(mut self, field: ConditionField) -> Self {
self.field = Some(field);
self
}
pub fn match_type(mut self, match_type: MatchType) -> Self {
self.match_type = Some(match_type);
self
}
pub fn value_u64(mut self, value: u64) -> Self {
self.value = Some(ConditionValue::UInt64(value).into());
self
}
#[allow(dead_code)]
pub fn value_u32(mut self, value: u32) -> Self {
self.value = Some(ConditionValue::UInt32(value).into());
self
}
pub fn value_u16(mut self, value: u16) -> Self {
self.value = Some(ConditionValue::UInt16(value).into());
self
}
pub fn value_u8(mut self, value: u8) -> Self {
self.value = Some(ConditionValue::UInt8(value).into());
self
}
#[allow(dead_code)]
pub fn value_string(mut self, value: impl AsRef<OsStr>) -> Self {
let wide_string = string_to_null_terminated_utf16(value);
self.value = Some(ConditionValue::String(wide_string).into());
self
}
pub fn value_byte_blob(mut self, blob: impl Into<OwnedByteBlob>) -> Self {
self.value = Some(ConditionValue::ByteBlob { blob: blob.into() }.into());
self
}
pub fn value_v4_addr_mask(mut self, addr: u32, mask: u32) -> Self {
self.value = Some(ConditionValue::V4AddrMask(FWP_V4_ADDR_AND_MASK { addr, mask }).into());
self
}
pub fn value_v6_addr_mask(mut self, addr: [u8; 16], prefix_length: u8) -> Self {
self.value = Some(
ConditionValue::V6AddrMask(FWP_V6_ADDR_AND_MASK {
addr,
prefixLength: prefix_length,
})
.into(),
);
self
}
pub fn build(self) -> Option<Condition> {
let field = self.field?;
let match_type = self.match_type?;
let value = self.value?;
let mut raw_condition: FWPM_FILTER_CONDITION0 = unsafe { std::mem::zeroed() };
raw_condition.fieldKey = *field.guid();
raw_condition.matchType = match_type as i32;
match &*value {
ConditionValue::UInt64(val) => {
raw_condition.conditionValue.r#type = FWP_UINT64;
raw_condition.conditionValue.Anonymous.uint64 = val as *const u64 as *mut u64;
}
ConditionValue::UInt32(val) => {
raw_condition.conditionValue.r#type = FWP_UINT32;
raw_condition.conditionValue.Anonymous.uint32 = *val;
}
ConditionValue::UInt16(val) => {
raw_condition.conditionValue.r#type = FWP_UINT16;
raw_condition.conditionValue.Anonymous.uint16 = *val;
}
ConditionValue::UInt8(val) => {
raw_condition.conditionValue.r#type = FWP_UINT8;
raw_condition.conditionValue.Anonymous.uint8 = *val;
}
ConditionValue::String(wide_str) => {
raw_condition.conditionValue.r#type = FWP_UNICODE_STRING_TYPE;
raw_condition.conditionValue.Anonymous.unicodeString = wide_str.as_ptr() as *mut _;
}
ConditionValue::ByteBlob { blob } => {
raw_condition.conditionValue.r#type = FWP_BYTE_BLOB_TYPE;
raw_condition.conditionValue.Anonymous.byteBlob = blob.as_ptr() as _;
}
ConditionValue::V4AddrMask(addr_and_mask) => {
raw_condition.conditionValue.r#type = FWP_V4_ADDR_MASK;
raw_condition.conditionValue.Anonymous.v4AddrMask =
addr_and_mask as *const _ as *mut _;
}
ConditionValue::V6AddrMask(addr_and_mask) => {
raw_condition.conditionValue.r#type = FWP_V6_ADDR_MASK;
raw_condition.conditionValue.Anonymous.v6AddrMask =
addr_and_mask as *const _ as *mut _;
}
}
Some(Condition {
raw_condition,
_value: value,
})
}
}
/// A fully built filter condition.
///
/// Cloning copies the raw struct and bumps the `Arc` reference count, so the pointers in
/// the copy refer to the same shared value storage as the original.
#[derive(Clone)]
pub struct Condition {
// Raw condition struct; may hold pointers into `_value`.
raw_condition: FWPM_FILTER_CONDITION0,
// Never read directly; held only to keep the pointed-to value alive.
_value: Arc<ConditionValue>,
}
impl Condition {
    /// Borrows the raw `FWPM_FILTER_CONDITION0`.
    ///
    /// The borrow is tied to `self`, which also owns the storage the raw condition may
    /// point into.
    pub(crate) fn raw_condition(&self) -> &FWPM_FILTER_CONDITION0 {
        let Self { raw_condition, .. } = self;
        raw_condition
    }
}
// Unit tests for the condition builders. They inspect the raw condition struct directly
// and never call into the filtering platform itself.
#[cfg(test)]
mod test {
use std::str::FromStr;
use super::*;
// Compares two GUIDs field by field.
fn assert_field_key_eq(actual: &GUID, expected: &GUID) {
assert_eq!(actual.data1, expected.data1);
assert_eq!(actual.data2, expected.data2);
assert_eq!(actual.data3, expected.data3);
assert_eq!(actual.data4, expected.data4);
}
// A LUID-based interface condition stores the LUID behind a non-null pointer.
#[test]
fn test_condition_local_interface_luid() {
let luid: u64 = 0xDEAD_BEEF_1234_5678;
let condition = InterfaceConditionBuilder::local().luid(luid).build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_LOCAL_INTERFACE,
);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(condition.raw_condition.conditionValue.r#type, FWP_UINT64);
let ptr = unsafe { condition.raw_condition.conditionValue.Anonymous.uint64 };
assert!(!ptr.is_null());
assert_eq!(unsafe { *ptr }, luid);
}
// Both the original and its clone must still read the same LUID through their pointers.
#[test]
fn test_condition_local_interface_pointer_stable_after_clone() {
let luid: u64 = 0xCAFE_BABE_DEAD_F00D;
let original = InterfaceConditionBuilder::local().luid(luid).build();
let cloned = original.clone();
let original_value = unsafe { *original.raw_condition.conditionValue.Anonymous.uint64 };
let cloned_value = unsafe { *cloned.raw_condition.conditionValue.Anonymous.uint64 };
assert_eq!(original_value, luid);
assert_eq!(cloned_value, luid);
}
// A remote-port condition stores the port inline as a 16-bit value.
#[test]
fn test_condition_port_remote() {
let condition = PortConditionBuilder::remote().equal(80).build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_REMOTE_PORT,
);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(
unsafe { condition.raw_condition.conditionValue.Anonymous.uint16 },
80
);
}
// The ICMP type is widened to a 16-bit value under the ICMP type field key.
#[test]
fn test_icmp_type_condition_equal() {
let condition = IcmpConditionBuilder::r#type().equal(135).build();
assert_field_key_eq(&condition.raw_condition.fieldKey, &FWPM_CONDITION_ICMP_TYPE);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(condition.raw_condition.conditionValue.r#type, FWP_UINT16);
assert_eq!(
unsafe { condition.raw_condition.conditionValue.Anonymous.uint16 },
135
);
}
// The ICMP code is widened to a 16-bit value under the ICMP code field key.
#[test]
fn test_icmp_code_condition_equal() {
let condition = IcmpConditionBuilder::code().equal(0).build();
assert_field_key_eq(&condition.raw_condition.fieldKey, &FWPM_CONDITION_ICMP_CODE);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(condition.raw_condition.conditionValue.r#type, FWP_UINT16);
assert_eq!(
unsafe { condition.raw_condition.conditionValue.Anonymous.uint16 },
0
);
}
// `icmpv6()` produces an 8-bit protocol value of 58.
#[test]
fn test_icmpv6_protocol_condition() {
let condition = ProtocolConditionBuilder::icmpv6().build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_PROTOCOL,
);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(condition.raw_condition.conditionValue.r#type, FWP_UINT8);
assert_eq!(
unsafe { condition.raw_condition.conditionValue.Anonymous.uint8 },
58
);
}
// Mask conversion at the byte boundaries, including the /0 and /32 edge cases.
#[test]
fn test_v4_prefix_to_mask() {
assert_eq!(v4_prefix_to_mask(0), 0x00000000);
assert_eq!(v4_prefix_to_mask(8), 0xFF000000);
assert_eq!(v4_prefix_to_mask(16), 0xFFFF0000);
assert_eq!(v4_prefix_to_mask(24), 0xFFFFFF00);
assert_eq!(v4_prefix_to_mask(32), 0xFFFFFFFF);
}
// A prefix longer than 32 bits is rejected with a panic.
#[test]
#[should_panic(expected = "IPv4 prefix length must be <= 32")]
fn test_v4_prefix_to_mask_too_large() {
let _ = v4_prefix_to_mask(33);
}
// A remote IPv4 subnet is stored as an address/mask pair behind a pointer.
#[test]
fn test_subnet_v4_remote() {
let condition = IpAddressConditionBuilder::remote()
.subnet_v4(Ipv4Addr::new(192, 168, 0, 0), 16)
.build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_REMOTE_ADDRESS,
);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(
condition.raw_condition.conditionValue.r#type,
FWP_V4_ADDR_MASK
);
let v4 = unsafe { &*condition.raw_condition.conditionValue.Anonymous.v4AddrMask };
assert_eq!(v4.addr, 0xC0A80000);
assert_eq!(v4.mask, 0xFFFF0000);
}
// Same as above, but for the local address field.
#[test]
fn test_subnet_v4_local() {
let condition = IpAddressConditionBuilder::local()
.subnet_v4(Ipv4Addr::new(127, 0, 0, 0), 8)
.build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_LOCAL_ADDRESS,
);
let v4 = unsafe { &*condition.raw_condition.conditionValue.Anonymous.v4AddrMask };
assert_eq!(v4.addr, 0x7F000000);
assert_eq!(v4.mask, 0xFF000000);
}
// A remote IPv6 subnet keeps the address octets and the prefix length as given.
#[test]
fn test_subnet_v6_remote() {
let condition = IpAddressConditionBuilder::remote()
.subnet_v6(Ipv6Addr::from_str("fe80::").unwrap(), 10)
.build();
assert_field_key_eq(
&condition.raw_condition.fieldKey,
&FWPM_CONDITION_IP_REMOTE_ADDRESS,
);
assert_eq!(condition.raw_condition.matchType, FWP_MATCH_EQUAL);
assert_eq!(
condition.raw_condition.conditionValue.r#type,
FWP_V6_ADDR_MASK
);
let v6 = unsafe { &*condition.raw_condition.conditionValue.Anonymous.v6AddrMask };
assert_eq!(v6.addr[0], 0xfe);
assert_eq!(v6.addr[1], 0x80);
for byte in &v6.addr[2..] {
assert_eq!(*byte, 0);
}
assert_eq!(v6.prefixLength, 10);
}
// An IPv6 prefix longer than 128 bits is rejected with a panic.
#[test]
#[should_panic(expected = "IPv6 prefix length must be <= 128")]
fn test_subnet_v6_prefix_too_large() {
let _ = IpAddressConditionBuilder::remote()
.subnet_v6(Ipv6Addr::UNSPECIFIED, 129)
.build();
}
}