use std::collections::HashSet;
use std::convert::TryFrom;
use std::net::IpAddr;
use typed_builder::TypedBuilder;
use windows::core::BSTR;
use windows::Win32::Foundation::VARIANT_BOOL;
use windows::Win32::NetworkManagement::WindowsFirewall::{INetFwRule, NetFwRule};
use windows::Win32::System::Com::CoCreateInstance;
use crate::constants::DWCLSCONTEXT;
use crate::errors::{SetRuleError, WindowsFirewallError};
use crate::firewall_enums::{
ActionFirewallWindows, DirectionFirewallWindows, ProfileFirewallWindows,
ProtocolFirewallWindows,
};
use crate::utils::{
bstr_to_hashset, hashset_to_bstr, hashset_to_variant, is_not_icmp, is_not_tcp_or_udp,
to_string_hashset, variant_to_hashset, with_com_initialized,
};
use crate::windows_firewall::{add_or_update, remove_rule, rule_exists, update_rule};
use crate::{add_rule, add_rule_if_not_exists, enable_rule, InterfaceTypes};
/// An owned description of a single Windows Firewall rule.
///
/// Build one with the `TypedBuilder`-generated `WindowsFirewallRule::builder()`.
/// `name`, `direction`, `enabled` and `action` are required. Every other field has
/// `#[builder(default)]` and therefore starts out as `None`. Conversions to and from
/// the COM `INetFwRule` interface are provided through `TryFrom`.
#[derive(Debug, Clone, TypedBuilder)]
pub struct WindowsFirewallRule {
    /// Rule name. It is the identifier passed to `remove_rule`, `rule_exists`,
    /// `update_rule` and `enable_rule`.
    #[builder(setter(into))]
    name: String,
    /// Traffic direction the rule applies to.
    #[builder(setter(into))]
    direction: DirectionFirewallWindows,
    /// Whether the rule is enabled.
    #[builder(setter(into))]
    enabled: bool,
    /// Action applied to traffic matching the rule.
    #[builder(setter(into))]
    action: ActionFirewallWindows,
    /// Optional free-text description. An empty description read from the
    /// firewall is stored as `None`.
    #[builder(default, setter(strip_option, into))]
    description: Option<String>,
    /// Application the rule is restricted to (presumably an executable path —
    /// confirm against the Windows API docs). Empty values read back become `None`.
    #[builder(default, setter(strip_option, into))]
    application_name: Option<String>,
    /// Service the rule is restricted to. Empty values read back become `None`.
    #[builder(default, setter(strip_option, into))]
    service_name: Option<String>,
    /// IP protocol. It is `None` when unset or when the firewall value could not be
    /// converted into `ProtocolFirewallWindows`.
    #[builder(default, setter(strip_option, into))]
    protocol: Option<ProtocolFirewallWindows>,
    /// Local ports. `update` clears them when the protocol is changed to one that
    /// is neither TCP nor UDP.
    #[builder(default, setter(strip_option, into))]
    local_ports: Option<HashSet<u16>>,
    /// Remote ports. `update` clears them when the protocol is changed to one that
    /// is neither TCP nor UDP.
    #[builder(default, setter(strip_option, into))]
    remote_ports: Option<HashSet<u16>>,
    /// Local IP addresses. The builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    local_addresses: Option<HashSet<IpAddr>>,
    /// Remote IP addresses. The builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    remote_addresses: Option<HashSet<IpAddr>>,
    /// ICMP types/codes, kept as the raw string handed to the firewall. `update`
    /// clears it when the protocol is changed to a non-ICMP one.
    #[builder(default, setter(strip_option, into))]
    icmp_types_and_codes: Option<String>,
    /// Network interface names. The builder setter accepts any iterator of
    /// string-like items.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    interfaces: Option<HashSet<String>>,
    /// Interface types the rule applies to. The builder setter accepts any iterator
    /// of `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    interface_types: Option<HashSet<InterfaceTypes>>,
    /// Group the rule belongs to. Empty values read back become `None`.
    #[builder(default, setter(strip_option, into))]
    grouping: Option<String>,
    /// Firewall profile(s) the rule applies to.
    #[builder(default, setter(strip_option, into))]
    profiles: Option<ProfileFirewallWindows>,
    /// Edge-traversal flag.
    #[builder(default, setter(strip_option, into))]
    edge_traversal: Option<bool>,
}
#[allow(clippy::must_use_candidate)]
impl WindowsFirewallRule {
pub fn add(&self) -> Result<(), WindowsFirewallError> {
add_rule(self)
}
pub fn add_if_not_exists(&self) -> Result<bool, WindowsFirewallError> {
add_rule_if_not_exists(self)
}
pub fn add_or_update(&self) -> Result<bool, WindowsFirewallError> {
add_or_update(self)
}
pub fn remove(self) -> Result<(), WindowsFirewallError> {
remove_rule(&self.name)?;
Ok(())
}
pub fn update(
&mut self,
settings: &WindowsFirewallRuleSettings,
) -> Result<(), WindowsFirewallError> {
update_rule(&self.name, settings)?;
if let Some(name) = &settings.name {
self.name = name.to_string();
}
if let Some(direction) = &settings.direction {
self.direction = *direction;
}
if let Some(enabled) = settings.enabled {
self.enabled = enabled;
}
if let Some(action) = &settings.action {
self.action = *action;
}
if let Some(description) = &settings.description {
self.description = Some(description.to_string());
}
if let Some(application_name) = &settings.application_name {
self.application_name = Some(application_name.to_string());
}
if let Some(service_name) = &settings.service_name {
self.service_name = Some(service_name.to_string());
}
if let Some(protocol) = &settings.protocol {
if is_not_tcp_or_udp(protocol) {
self.local_ports = None;
self.remote_ports = None;
}
if is_not_icmp(protocol) {
self.icmp_types_and_codes = None;
}
self.protocol = Some(*protocol);
}
if let Some(local_ports) = &settings.local_ports {
self.local_ports = Some(local_ports.clone());
}
if let Some(remote_ports) = &settings.remote_ports {
self.remote_ports = Some(remote_ports.clone());
}
if let Some(local_addresses) = &settings.local_addresses {
self.local_addresses = Some(local_addresses.clone());
}
if let Some(remote_addresses) = &settings.remote_addresses {
self.remote_addresses = Some(remote_addresses.clone());
}
if let Some(icmp_types_and_codes) = &settings.icmp_types_and_codes {
self.icmp_types_and_codes = Some(icmp_types_and_codes.to_string());
}
if let Some(interfaces) = &settings.interfaces {
self.interfaces = Some(interfaces.clone());
}
if let Some(interface_types) = &settings.interface_types {
self.interface_types = Some(interface_types.clone());
}
if let Some(grouping) = &settings.grouping {
self.grouping = Some(grouping.to_string());
}
if let Some(profiles) = &settings.profiles {
self.profiles = Some(*profiles);
}
if let Some(edge_traversal) = &settings.edge_traversal {
self.edge_traversal = Some(*edge_traversal);
}
Ok(())
}
pub fn enable(&mut self, enable: bool) -> Result<(), WindowsFirewallError> {
enable_rule(&self.name, enable)?;
self.enabled = enable;
Ok(())
}
pub fn exists(&self) -> Result<bool, WindowsFirewallError> {
rule_exists(&self.name)
}
pub fn name(&self) -> &str {
&self.name
}
pub fn direction(&self) -> &DirectionFirewallWindows {
&self.direction
}
pub fn enabled(&self) -> bool {
self.enabled
}
pub fn action(&self) -> &ActionFirewallWindows {
&self.action
}
pub fn description(&self) -> Option<&String> {
self.description.as_ref()
}
pub fn application_name(&self) -> Option<&String> {
self.application_name.as_ref()
}
pub fn service_name(&self) -> Option<&String> {
self.service_name.as_ref()
}
pub fn protocol(&self) -> Option<&ProtocolFirewallWindows> {
self.protocol.as_ref()
}
pub fn local_ports(&self) -> Option<&HashSet<u16>> {
self.local_ports.as_ref()
}
pub fn remote_ports(&self) -> Option<&HashSet<u16>> {
self.remote_ports.as_ref()
}
pub fn local_addresses(&self) -> Option<&HashSet<IpAddr>> {
self.local_addresses.as_ref()
}
pub fn remote_addresses(&self) -> Option<&HashSet<IpAddr>> {
self.remote_addresses.as_ref()
}
pub fn icmp_types_and_codes(&self) -> Option<&String> {
self.icmp_types_and_codes.as_ref()
}
pub fn interfaces(&self) -> Option<&HashSet<String>> {
self.interfaces.as_ref()
}
pub fn interface_types(&self) -> Option<&HashSet<InterfaceTypes>> {
self.interface_types.as_ref()
}
pub fn grouping(&self) -> Option<&String> {
self.grouping.as_ref()
}
pub fn profiles(&self) -> Option<&ProfileFirewallWindows> {
self.profiles.as_ref()
}
pub fn edge_traversal(&self) -> Option<bool> {
self.edge_traversal
}
pub fn set_name(&mut self, name: impl Into<String>) -> &mut Self {
self.name = name.into();
self
}
pub fn set_direction(&mut self, direction: DirectionFirewallWindows) -> &mut Self {
self.direction = direction;
self
}
pub fn set_enabled(&mut self, enabled: bool) -> &mut Self {
self.enabled = enabled;
self
}
pub fn set_action(&mut self, action: ActionFirewallWindows) -> &mut Self {
self.action = action;
self
}
pub fn set_description(&mut self, description: Option<impl Into<String>>) -> &mut Self {
self.description = description.map(|d| d.into());
self
}
pub fn set_application_name(
&mut self,
application_name: Option<impl Into<String>>,
) -> &mut Self {
self.application_name = application_name.map(|a| a.into());
self
}
pub fn set_service_name(&mut self, service_name: Option<impl Into<String>>) -> &mut Self {
self.service_name = service_name.map(|s| s.into());
self
}
pub fn set_protocol(&mut self, protocol: Option<ProtocolFirewallWindows>) -> &mut Self {
self.protocol = protocol;
self
}
pub fn set_local_ports(&mut self, local_ports: Option<HashSet<u16>>) -> &mut Self {
self.local_ports = local_ports;
self
}
pub fn set_remote_ports(&mut self, remote_ports: Option<HashSet<u16>>) -> &mut Self {
self.remote_ports = remote_ports;
self
}
pub fn set_local_addresses(&mut self, local_addresses: Option<HashSet<IpAddr>>) -> &mut Self {
self.local_addresses = local_addresses;
self
}
pub fn set_remote_addresses(&mut self, remote_addresses: Option<HashSet<IpAddr>>) -> &mut Self {
self.remote_addresses = remote_addresses;
self
}
pub fn set_icmp_types_and_codes(
&mut self,
icmp_types_and_codes: Option<impl Into<String>>,
) -> &mut Self {
self.icmp_types_and_codes = icmp_types_and_codes.map(|i| i.into());
self
}
pub fn set_interfaces(&mut self, interfaces: Option<HashSet<String>>) -> &mut Self {
self.interfaces = interfaces;
self
}
pub fn set_interface_types(
&mut self,
interface_types: Option<HashSet<InterfaceTypes>>,
) -> &mut Self {
self.interface_types = interface_types;
self
}
pub fn set_grouping(&mut self, grouping: Option<impl Into<String>>) -> &mut Self {
self.grouping = grouping.map(|g| g.into());
self
}
pub fn set_profiles(&mut self, profiles: Option<ProfileFirewallWindows>) -> &mut Self {
self.profiles = profiles;
self
}
pub fn set_edge_traversal(&mut self, edge_traversal: Option<bool>) -> &mut Self {
self.edge_traversal = edge_traversal;
self
}
}
impl TryFrom<INetFwRule> for WindowsFirewallRule {
type Error = WindowsFirewallError;
fn try_from(fw_rule: INetFwRule) -> Result<Self, WindowsFirewallError> {
unsafe {
Ok(Self {
name: fw_rule.Name().map(|bstr| bstr.to_string())?,
direction: fw_rule.Direction()?.try_into()?,
enabled: fw_rule.Enabled()?.into(),
action: fw_rule.Action()?.try_into()?,
description: fw_rule
.Description()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
application_name: fw_rule
.ApplicationName()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
service_name: fw_rule
.ServiceName()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
protocol: fw_rule.Protocol()?.try_into().ok(),
local_ports: bstr_to_hashset(fw_rule.LocalPorts()),
remote_ports: bstr_to_hashset(fw_rule.RemotePorts()),
local_addresses: bstr_to_hashset(fw_rule.LocalAddresses()),
remote_addresses: bstr_to_hashset(fw_rule.RemoteAddresses()),
icmp_types_and_codes: fw_rule
.IcmpTypesAndCodes()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
interfaces: Some(variant_to_hashset(&fw_rule.Interfaces()?)?),
interface_types: bstr_to_hashset(fw_rule.InterfaceTypes()),
grouping: fw_rule
.Grouping()
.ok()
.map(|bstr| bstr.to_string())
.filter(|s| !s.is_empty()),
profiles: fw_rule.Profiles()?.try_into().ok(),
edge_traversal: fw_rule.EdgeTraversal().ok().map(VARIANT_BOOL::as_bool),
})
}
}
}
impl TryFrom<&WindowsFirewallRule> for INetFwRule {
    type Error = WindowsFirewallError;

    /// Creates a new COM `NetFwRule` object and writes every populated field of
    /// `rule` onto it.
    ///
    /// The four required properties are always written. Optional properties are
    /// written only when they are `Some`.
    ///
    /// # Errors
    /// Fails if `CoCreateInstance` fails, if the interface set cannot be converted
    /// to a `VARIANT`, or if any setter fails. Setter failures are wrapped in the
    /// matching [`SetRuleError`] variant.
    fn try_from(rule: &WindowsFirewallRule) -> Result<Self, WindowsFirewallError> {
        with_com_initialized(|| unsafe {
            let fw_rule: Self = CoCreateInstance(&NetFwRule, None, DWCLSCONTEXT)?;

            // Required properties.
            fw_rule.SetName(&BSTR::from(&rule.name)).map_err(SetRuleError::Name)?;
            fw_rule.SetDirection(rule.direction.into()).map_err(SetRuleError::Direction)?;
            fw_rule.SetEnabled(rule.enabled.into()).map_err(SetRuleError::Enabled)?;
            fw_rule.SetAction(rule.action.into()).map_err(SetRuleError::Action)?;

            // Optional free-text properties.
            if let Some(description) = &rule.description {
                let text = BSTR::from(description);
                fw_rule.SetDescription(&text).map_err(SetRuleError::Description)?;
            }
            if let Some(app_name) = &rule.application_name {
                let text = BSTR::from(app_name);
                fw_rule.SetApplicationName(&text).map_err(SetRuleError::ApplicationName)?;
            }
            if let Some(service_name) = &rule.service_name {
                let text = BSTR::from(service_name);
                fw_rule.SetServiceName(&text).map_err(SetRuleError::ServiceName)?;
            }

            // The protocol is written before ports and ICMP codes. Presumably the
            // firewall validates those against the protocol already set (ports only
            // for TCP/UDP, ICMP codes only for ICMP), so keep this ordering.
            if let Some(protocol) = rule.protocol {
                fw_rule.SetProtocol(protocol.into()).map_err(SetRuleError::Protocol)?;
            }
            if let Some(ports) = &rule.local_ports {
                let joined = hashset_to_bstr(Some(ports));
                fw_rule.SetLocalPorts(&joined).map_err(SetRuleError::LocalPorts)?;
            }
            if let Some(ports) = &rule.remote_ports {
                let joined = hashset_to_bstr(Some(ports));
                fw_rule.SetRemotePorts(&joined).map_err(SetRuleError::RemotePorts)?;
            }
            if let Some(addresses) = &rule.local_addresses {
                let joined = hashset_to_bstr(Some(addresses));
                fw_rule.SetLocalAddresses(&joined).map_err(SetRuleError::LocalAddresses)?;
            }
            if let Some(addresses) = &rule.remote_addresses {
                let joined = hashset_to_bstr(Some(addresses));
                fw_rule.SetRemoteAddresses(&joined).map_err(SetRuleError::RemoteAddresses)?;
            }
            if let Some(icmp) = &rule.icmp_types_and_codes {
                let text = BSTR::from(icmp);
                fw_rule.SetIcmpTypesAndCodes(&text).map_err(SetRuleError::IcmpTypesAndCodes)?;
            }

            // Remaining optional properties.
            if let Some(edge_traversal) = rule.edge_traversal {
                fw_rule
                    .SetEdgeTraversal(edge_traversal.into())
                    .map_err(SetRuleError::EdgeTraversal)?;
            }
            if let Some(grouping) = &rule.grouping {
                let text = BSTR::from(grouping);
                fw_rule.SetGrouping(&text).map_err(SetRuleError::Grouping)?;
            }
            if let Some(names) = &rule.interfaces {
                let variant = hashset_to_variant(names)?;
                fw_rule.SetInterfaces(&variant).map_err(SetRuleError::Interfaces)?;
            }
            if let Some(types) = &rule.interface_types {
                let joined = hashset_to_bstr(Some(types));
                fw_rule.SetInterfaceTypes(&joined).map_err(SetRuleError::InterfaceTypes)?;
            }
            if let Some(profiles) = rule.profiles {
                fw_rule.SetProfiles(profiles.into()).map_err(SetRuleError::Profiles)?;
            }

            Ok(fw_rule)
        })
    }
}
/// A partial set of rule properties used to modify an existing rule.
///
/// Every field is optional and defaults to `None` in the `TypedBuilder`-generated
/// builder. `WindowsFirewallRule::update` passes the whole value to `update_rule`,
/// then copies only the `Some` fields onto the local rule. `None` fields are not
/// applied locally. They are presumably also left unchanged by `update_rule`; confirm
/// there.
#[derive(Debug, Clone, TypedBuilder)]
pub struct WindowsFirewallRuleSettings {
    /// New rule name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) name: Option<String>,
    /// New traffic direction.
    #[builder(default, setter(strip_option, into))]
    pub(crate) direction: Option<DirectionFirewallWindows>,
    /// New enabled flag.
    #[builder(default, setter(strip_option, into))]
    pub(crate) enabled: Option<bool>,
    /// New action.
    #[builder(default, setter(strip_option, into))]
    pub(crate) action: Option<ActionFirewallWindows>,
    /// New description.
    #[builder(default, setter(strip_option, into))]
    pub(crate) description: Option<String>,
    /// New application restriction.
    #[builder(default, setter(strip_option, into))]
    pub(crate) application_name: Option<String>,
    /// New service restriction.
    #[builder(default, setter(strip_option, into))]
    pub(crate) service_name: Option<String>,
    /// New protocol. When applied by `WindowsFirewallRule::update`, it also clears
    /// the local ports if it is neither TCP nor UDP, and clears the ICMP codes if it
    /// is not ICMP.
    #[builder(default, setter(strip_option, into))]
    pub(crate) protocol: Option<ProtocolFirewallWindows>,
    /// New local ports.
    #[builder(default, setter(strip_option, into))]
    pub(crate) local_ports: Option<HashSet<u16>>,
    /// New remote ports.
    #[builder(default, setter(strip_option, into))]
    pub(crate) remote_ports: Option<HashSet<u16>>,
    /// New local addresses. The builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) local_addresses: Option<HashSet<IpAddr>>,
    /// New remote addresses. The builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) remote_addresses: Option<HashSet<IpAddr>>,
    /// New ICMP types/codes string.
    #[builder(default, setter(strip_option, into))]
    pub(crate) icmp_types_and_codes: Option<String>,
    /// New interface names. The builder setter accepts any iterator of string-like
    /// items.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    pub(crate) interfaces: Option<HashSet<String>>,
    /// New interface types. The builder setter accepts any iterator of
    /// `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    pub(crate) interface_types: Option<HashSet<InterfaceTypes>>,
    /// New group name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) grouping: Option<String>,
    /// New profile selection.
    #[builder(default, setter(strip_option, into))]
    pub(crate) profiles: Option<ProfileFirewallWindows>,
    /// New edge-traversal flag.
    #[builder(default, setter(strip_option, into))]
    pub(crate) edge_traversal: Option<bool>,
}
impl From<WindowsFirewallRule> for WindowsFirewallRuleSettings {
fn from(rule: WindowsFirewallRule) -> Self {
Self {
name: Some(rule.name),
direction: Some(rule.direction),
enabled: Some(rule.enabled),
action: Some(rule.action),
description: rule.description,
application_name: rule.application_name,
service_name: rule.service_name,
protocol: rule.protocol,
local_ports: rule.local_ports,
remote_ports: rule.remote_ports,
local_addresses: rule.local_addresses,
remote_addresses: rule.remote_addresses,
icmp_types_and_codes: rule.icmp_types_and_codes,
interfaces: rule.interfaces,
interface_types: rule.interface_types,
grouping: rule.grouping,
profiles: rule.profiles,
edge_traversal: rule.edge_traversal,
}
}
}