use getset::{Getters, Setters};
use std::collections::HashSet;
use std::convert::TryFrom;
use std::net::IpAddr;
use typed_builder::TypedBuilder;
use windows::Win32::Foundation::VARIANT_BOOL;
use windows::Win32::NetworkManagement::WindowsFirewall::{INetFwRule, NetFwRule};
use windows::Win32::System::Com::CoCreateInstance;
use windows::core::BSTR;
use crate::constants::DWCLSCONTEXT;
use crate::errors::{SetRuleError, WindowsFirewallError};
use crate::firewall_enums::{
ActionFirewallWindows, DirectionFirewallWindows, ProfileFirewallWindows,
ProtocolFirewallWindows,
};
use crate::utils::{
bstr_to_hashset, hashset_to_bstr, hashset_to_variant, is_not_icmp, is_not_tcp_or_udp,
to_string_hashset, variant_to_hashset, with_com_initialized,
};
use crate::windows_firewall::{add_rule_or_update, remove_rule, rule_exists, update_rule};
use crate::{InterfaceTypes, add_rule, add_rule_if_not_exists, enable_rule};
/// An owned, in-memory description of a single Windows Firewall rule.
///
/// Build one with [`WindowsFirewallRule::builder`]: `name`, `direction`, `enabled`
/// and `action` are required, every other property is optional and defaults to
/// `None`. `getset` generates a public getter (e.g. `name()`) and setter
/// (e.g. `set_name()`) for every field. The rule converts to a COM `INetFwRule`
/// via `TryFrom<&WindowsFirewallRule>` and back via `TryFrom<INetFwRule>`.
#[derive(Debug, Clone, Getters, Setters, TypedBuilder)]
pub struct WindowsFirewallRule {
    /// Rule name. `remove`, `update`, `enable` and `exists` pass this value to the
    /// underlying firewall functions to identify the rule.
    #[builder(setter(into))]
    #[getset(get = "pub", set = "pub")]
    name: String,
    /// Traffic direction the rule applies to (e.g. `In` / `Out`).
    #[builder(setter(into))]
    #[getset(get = "pub", set = "pub")]
    direction: DirectionFirewallWindows,
    /// Whether the rule is enabled.
    #[builder(setter(into))]
    #[getset(get = "pub", set = "pub")]
    enabled: bool,
    /// Action taken for matching traffic (e.g. `Allow` / `Block`).
    #[builder(setter(into))]
    #[getset(get = "pub", set = "pub")]
    action: ActionFirewallWindows,
    /// Optional free-text description; an empty description read back from the
    /// system is stored as `None`.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    description: Option<String>,
    /// Optional application the rule is scoped to (e.g. `app.exe`).
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    application_name: Option<String>,
    /// Optional service name the rule is scoped to.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    service_name: Option<String>,
    /// Optional IP protocol. When changed through `update`, the port sets are
    /// cleared unless the new protocol is TCP/UDP, and `icmp_types_and_codes` is
    /// cleared unless it is ICMP.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    protocol: Option<ProtocolFirewallWindows>,
    /// Optional set of local ports (only kept by `update` for TCP/UDP protocols).
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    local_ports: Option<HashSet<u16>>,
    /// Optional set of remote ports (only kept by `update` for TCP/UDP protocols).
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    remote_ports: Option<HashSet<u16>>,
    /// Optional set of local IP addresses; the builder setter accepts any
    /// iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    #[getset(get = "pub", set = "pub")]
    local_addresses: Option<HashSet<IpAddr>>,
    /// Optional set of remote IP addresses; the builder setter accepts any
    /// iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    #[getset(get = "pub", set = "pub")]
    remote_addresses: Option<HashSet<IpAddr>>,
    /// Optional ICMP type/code specification as a string (e.g. `"8:0"`).
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    icmp_types_and_codes: Option<String>,
    /// Optional set of interface names (e.g. `"Wi-Fi"`); the builder setter
    /// accepts any iterator of items convertible into `String`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    #[getset(get = "pub", set = "pub")]
    interfaces: Option<HashSet<String>>,
    /// Optional set of interface types (e.g. `InterfaceTypes::Lan`); the builder
    /// setter accepts any iterator of `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    #[getset(get = "pub", set = "pub")]
    interface_types: Option<HashSet<InterfaceTypes>>,
    /// Optional rule group name.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    grouping: Option<String>,
    /// Optional firewall profile(s) the rule applies to.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    profiles: Option<ProfileFirewallWindows>,
    /// Optional edge-traversal flag.
    #[builder(default, setter(strip_option, into))]
    #[getset(get = "pub", set = "pub")]
    edge_traversal: Option<bool>,
}
impl WindowsFirewallRule {
    /// Adds this rule to the firewall; a thin wrapper over [`add_rule`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`add_rule`].
    pub fn add(&self) -> Result<(), WindowsFirewallError> {
        add_rule(self)
    }

    /// Adds this rule through [`add_rule_if_not_exists`] and returns its flag.
    ///
    /// NOTE(review): presumably `true` means the rule was newly added — confirm
    /// against `add_rule_if_not_exists`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`add_rule_if_not_exists`].
    pub fn add_if_not_exists(&self) -> Result<bool, WindowsFirewallError> {
        add_rule_if_not_exists(self)
    }

    /// Adds or updates this rule through [`add_rule_or_update`] and returns its flag.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`add_rule_or_update`].
    pub fn add_or_update(&self) -> Result<bool, WindowsFirewallError> {
        add_rule_or_update(self)
    }

    /// Removes the rule named `self.name` from the firewall, consuming `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`remove_rule`].
    pub fn remove(self) -> Result<(), WindowsFirewallError> {
        remove_rule(&self.name)?;
        Ok(())
    }

    /// Applies `settings` to the system rule, then mirrors them onto `self`.
    ///
    /// Every `Some` field in `settings` overwrites the matching field here; `None`
    /// fields are left alone. When the protocol changes, port sets are dropped
    /// unless it is TCP/UDP, and ICMP types/codes are dropped unless it is ICMP.
    /// Explicit port/ICMP values in `settings` are applied after that clearing.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`update_rule`]; in that case `self` is unchanged.
    pub fn update(
        &mut self,
        settings: &WindowsFirewallRuleSettings,
    ) -> Result<(), WindowsFirewallError> {
        // Replaces `slot` with a copy of `incoming` only when `incoming` carries a value.
        fn overwrite<T: Clone>(slot: &mut Option<T>, incoming: Option<&T>) {
            if let Some(value) = incoming {
                *slot = Some(value.clone());
            }
        }

        // System rule first; `?` returns early so local fields stay untouched on failure.
        update_rule(&self.name, settings)?;

        if let Some(new_name) = &settings.name {
            self.name.clone_from(new_name);
        }
        if let Some(direction) = settings.direction {
            self.direction = direction;
        }
        if let Some(enabled) = settings.enabled {
            self.enabled = enabled;
        }
        if let Some(action) = settings.action {
            self.action = action;
        }
        overwrite(&mut self.description, settings.description.as_ref());
        overwrite(&mut self.application_name, settings.application_name.as_ref());
        overwrite(&mut self.service_name, settings.service_name.as_ref());

        // Must run before the port/ICMP overwrites below so explicit values win.
        if let Some(protocol) = settings.protocol {
            if is_not_tcp_or_udp(protocol) {
                self.local_ports = None;
                self.remote_ports = None;
            }
            if is_not_icmp(protocol) {
                self.icmp_types_and_codes = None;
            }
            self.protocol = Some(protocol);
        }

        overwrite(&mut self.local_ports, settings.local_ports.as_ref());
        overwrite(&mut self.remote_ports, settings.remote_ports.as_ref());
        overwrite(&mut self.local_addresses, settings.local_addresses.as_ref());
        overwrite(&mut self.remote_addresses, settings.remote_addresses.as_ref());
        overwrite(
            &mut self.icmp_types_and_codes,
            settings.icmp_types_and_codes.as_ref(),
        );
        overwrite(&mut self.interfaces, settings.interfaces.as_ref());
        overwrite(&mut self.interface_types, settings.interface_types.as_ref());
        overwrite(&mut self.grouping, settings.grouping.as_ref());
        overwrite(&mut self.profiles, settings.profiles.as_ref());
        overwrite(&mut self.edge_traversal, settings.edge_traversal.as_ref());

        Ok(())
    }

    /// Enables or disables the system rule, then records the new state locally.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`enable_rule`]; `self.enabled` is unchanged then.
    pub fn enable(&mut self, enable: bool) -> Result<(), WindowsFirewallError> {
        enable_rule(&self.name, enable)?;
        self.enabled = enable;
        Ok(())
    }

    /// Asks [`rule_exists`] whether a rule named `self.name` exists.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`rule_exists`].
    pub fn exists(&self) -> Result<bool, WindowsFirewallError> {
        rule_exists(&self.name)
    }
}
impl TryFrom<INetFwRule> for WindowsFirewallRule {
type Error = WindowsFirewallError;
fn try_from(fw_rule: INetFwRule) -> Result<Self, WindowsFirewallError> {
unsafe {
Ok(Self {
name: fw_rule.Name().map(|bstr| bstr.to_string())?,
direction: fw_rule.Direction()?.try_into()?,
enabled: fw_rule.Enabled()?.into(),
action: fw_rule.Action()?.try_into()?,
description: fw_rule
.Description()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
application_name: fw_rule
.ApplicationName()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
service_name: fw_rule
.ServiceName()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
protocol: fw_rule.Protocol()?.try_into().ok(),
local_ports: bstr_to_hashset(fw_rule.LocalPorts()),
remote_ports: bstr_to_hashset(fw_rule.RemotePorts()),
local_addresses: bstr_to_hashset(fw_rule.LocalAddresses()),
remote_addresses: bstr_to_hashset(fw_rule.RemoteAddresses()),
icmp_types_and_codes: fw_rule
.IcmpTypesAndCodes()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
interfaces: Some(variant_to_hashset(&fw_rule.Interfaces()?)?),
interface_types: bstr_to_hashset(fw_rule.InterfaceTypes()),
grouping: fw_rule
.Grouping()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
profiles: fw_rule.Profiles()?.try_into().ok(),
edge_traversal: fw_rule.EdgeTraversal().ok().map(VARIANT_BOOL::as_bool),
})
}
}
}
impl TryFrom<&WindowsFirewallRule> for INetFwRule {
type Error = WindowsFirewallError;
fn try_from(rule: &WindowsFirewallRule) -> Result<Self, WindowsFirewallError> {
with_com_initialized(|| unsafe {
let fw_rule: Self = CoCreateInstance(&NetFwRule, None, DWCLSCONTEXT)?;
fw_rule
.SetName(&BSTR::from(&rule.name))
.map_err(SetRuleError::Name)?;
fw_rule
.SetDirection(rule.direction.into())
.map_err(SetRuleError::Direction)?;
fw_rule
.SetEnabled(rule.enabled.into())
.map_err(SetRuleError::Enabled)?;
fw_rule
.SetAction(rule.action.into())
.map_err(SetRuleError::Action)?;
if let Some(ref description) = rule.description {
fw_rule
.SetDescription(&BSTR::from(description))
.map_err(SetRuleError::Description)?;
}
if let Some(ref app_name) = rule.application_name {
fw_rule
.SetApplicationName(&BSTR::from(app_name))
.map_err(SetRuleError::ApplicationName)?;
}
if let Some(ref service_name) = rule.service_name {
fw_rule
.SetServiceName(&BSTR::from(service_name))
.map_err(SetRuleError::ServiceName)?;
}
if let Some(protocol) = rule.protocol {
fw_rule
.SetProtocol(protocol.into())
.map_err(SetRuleError::Protocol)?;
}
if let Some(ref local_ports) = rule.local_ports {
fw_rule
.SetLocalPorts(&hashset_to_bstr(Some(local_ports)))
.map_err(SetRuleError::LocalPorts)?;
}
if let Some(ref remote_ports) = rule.remote_ports {
fw_rule
.SetRemotePorts(&hashset_to_bstr(Some(remote_ports)))
.map_err(SetRuleError::RemotePorts)?;
}
if let Some(ref local_addresses) = rule.local_addresses {
fw_rule
.SetLocalAddresses(&hashset_to_bstr(Some(local_addresses)))
.map_err(SetRuleError::LocalAddresses)?;
}
if let Some(ref remote_addresses) = rule.remote_addresses {
fw_rule
.SetRemoteAddresses(&hashset_to_bstr(Some(remote_addresses)))
.map_err(SetRuleError::RemoteAddresses)?;
}
if let Some(ref icmp_types_and_codes) = rule.icmp_types_and_codes {
fw_rule
.SetIcmpTypesAndCodes(&BSTR::from(icmp_types_and_codes))
.map_err(SetRuleError::IcmpTypesAndCodes)?;
}
if let Some(edge_traversal) = rule.edge_traversal {
fw_rule
.SetEdgeTraversal(edge_traversal.into())
.map_err(SetRuleError::EdgeTraversal)?;
}
if let Some(ref grouping) = rule.grouping {
fw_rule
.SetGrouping(&BSTR::from(grouping))
.map_err(SetRuleError::Grouping)?;
}
if let Some(ref interface) = rule.interfaces {
fw_rule
.SetInterfaces(&hashset_to_variant(interface)?)
.map_err(SetRuleError::Interfaces)?;
}
if let Some(ref interface_types) = rule.interface_types {
fw_rule
.SetInterfaceTypes(&hashset_to_bstr(Some(interface_types)))
.map_err(SetRuleError::InterfaceTypes)?;
}
if let Some(profiles) = rule.profiles {
fw_rule
.SetProfiles(profiles.into())
.map_err(SetRuleError::Profiles)?;
}
Ok(fw_rule)
})
}
}
/// A partial set of rule properties used to modify an existing rule.
///
/// Every field is optional and defaults to `None` in the builder
/// ([`WindowsFirewallRuleSettings::builder`]). [`WindowsFirewallRule::update`] passes
/// these settings to `update_rule`, then copies each `Some` field onto the local
/// rule and leaves `None` fields unchanged.
/// NOTE(review): presumably `update_rule` applies the same "`None` means unchanged"
/// semantics to the system rule — confirm in `windows_firewall`.
#[derive(Debug, Clone, TypedBuilder)]
pub struct WindowsFirewallRuleSettings {
    /// Replacement rule name; `update` also renames the local rule.
    #[builder(default, setter(strip_option, into))]
    pub(crate) name: Option<String>,
    /// Replacement traffic direction.
    #[builder(default, setter(strip_option, into))]
    pub(crate) direction: Option<DirectionFirewallWindows>,
    /// Replacement enabled state.
    #[builder(default, setter(strip_option, into))]
    pub(crate) enabled: Option<bool>,
    /// Replacement action.
    #[builder(default, setter(strip_option, into))]
    pub(crate) action: Option<ActionFirewallWindows>,
    /// Replacement description.
    #[builder(default, setter(strip_option, into))]
    pub(crate) description: Option<String>,
    /// Replacement application name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) application_name: Option<String>,
    /// Replacement service name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) service_name: Option<String>,
    /// Replacement protocol. In `update`, it clears the local port sets for
    /// non-TCP/UDP protocols and the ICMP types/codes for non-ICMP ones.
    #[builder(default, setter(strip_option, into))]
    pub(crate) protocol: Option<ProtocolFirewallWindows>,
    /// Replacement local port set (applied after any protocol-driven clearing).
    #[builder(default, setter(strip_option, into))]
    pub(crate) local_ports: Option<HashSet<u16>>,
    /// Replacement remote port set (applied after any protocol-driven clearing).
    #[builder(default, setter(strip_option, into))]
    pub(crate) remote_ports: Option<HashSet<u16>>,
    /// Replacement local addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) local_addresses: Option<HashSet<IpAddr>>,
    /// Replacement remote addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) remote_addresses: Option<HashSet<IpAddr>>,
    /// Replacement ICMP type/code specification (applied after protocol-driven clearing).
    #[builder(default, setter(strip_option, into))]
    pub(crate) icmp_types_and_codes: Option<String>,
    /// Replacement interface names; the builder setter accepts any iterator of items
    /// convertible into `String`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    pub(crate) interfaces: Option<HashSet<String>>,
    /// Replacement interface types; the builder setter accepts any iterator of
    /// `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    pub(crate) interface_types: Option<HashSet<InterfaceTypes>>,
    /// Replacement group name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) grouping: Option<String>,
    /// Replacement profile(s).
    #[builder(default, setter(strip_option, into))]
    pub(crate) profiles: Option<ProfileFirewallWindows>,
    /// Replacement edge-traversal flag.
    #[builder(default, setter(strip_option, into))]
    pub(crate) edge_traversal: Option<bool>,
}
impl From<WindowsFirewallRule> for WindowsFirewallRuleSettings {
fn from(rule: WindowsFirewallRule) -> Self {
Self {
name: Some(rule.name),
direction: Some(rule.direction),
enabled: Some(rule.enabled),
action: Some(rule.action),
description: rule.description,
application_name: rule.application_name,
service_name: rule.service_name,
protocol: rule.protocol,
local_ports: rule.local_ports,
remote_ports: rule.remote_ports,
local_addresses: rule.local_addresses,
remote_addresses: rule.remote_addresses,
icmp_types_and_codes: rule.icmp_types_and_codes,
interfaces: rule.interfaces,
interface_types: rule.interface_types,
grouping: rule.grouping,
profiles: rule.profiles,
edge_traversal: rule.edge_traversal,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::str::FromStr;

    // Sets an optional field through its setter and checks the getter sees the value.
    // Then clears the field back to `None` and checks again.
    macro_rules! assert_optional_roundtrip {
        ($rule:ident, $setter:ident, $getter:ident, $value:expr) => {{
            let value = $value;
            $rule.$setter(Some(value.clone()));
            assert_eq!(*$rule.$getter(), Some(value));
            $rule.$setter(None);
            assert_eq!(*$rule.$getter(), None);
        }};
    }

    /// Exercises every generated getter/setter pair on [`WindowsFirewallRule`].
    #[test]
    fn test_windows_firewall_rule_setters() {
        let mut rule = WindowsFirewallRule::builder()
            .name("test")
            .action(ActionFirewallWindows::Block)
            .direction(DirectionFirewallWindows::Out)
            .enabled(false)
            .build();

        // Required fields: set, then read back.
        rule.set_name(String::from("new_name"));
        assert_eq!(rule.name(), "new_name");
        rule.set_direction(DirectionFirewallWindows::In);
        assert_eq!(rule.direction(), &DirectionFirewallWindows::In);
        rule.set_enabled(true);
        assert!(rule.enabled());
        rule.set_action(ActionFirewallWindows::Allow);
        assert_eq!(rule.action(), &ActionFirewallWindows::Allow);

        // Optional fields: set, read back, clear, read back.
        assert_optional_roundtrip!(rule, set_description, description, "desc".to_string());
        assert_optional_roundtrip!(
            rule,
            set_application_name,
            application_name,
            "app.exe".to_string()
        );
        assert_optional_roundtrip!(rule, set_service_name, service_name, "svc".to_string());
        assert_optional_roundtrip!(rule, set_protocol, protocol, ProtocolFirewallWindows::Tcp);
        assert_optional_roundtrip!(rule, set_local_ports, local_ports, HashSet::from([80]));
        assert_optional_roundtrip!(rule, set_remote_ports, remote_ports, HashSet::from([443]));
        assert_optional_roundtrip!(
            rule,
            set_local_addresses,
            local_addresses,
            HashSet::from([IpAddr::from_str("127.0.0.1").unwrap()])
        );
        assert_optional_roundtrip!(
            rule,
            set_remote_addresses,
            remote_addresses,
            HashSet::from([IpAddr::from_str("8.8.8.8").unwrap()])
        );
        assert_optional_roundtrip!(
            rule,
            set_icmp_types_and_codes,
            icmp_types_and_codes,
            "8:0".to_string()
        );
        assert_optional_roundtrip!(
            rule,
            set_interfaces,
            interfaces,
            HashSet::from(["Wi-Fi".to_string()])
        );
        assert_optional_roundtrip!(
            rule,
            set_interface_types,
            interface_types,
            HashSet::from([InterfaceTypes::Lan])
        );
        assert_optional_roundtrip!(rule, set_grouping, grouping, "group".to_string());
        assert_optional_roundtrip!(rule, set_profiles, profiles, ProfileFirewallWindows::Private);
        assert_optional_roundtrip!(rule, set_edge_traversal, edge_traversal, true);
    }
}