use std::collections::HashSet;
use std::convert::TryFrom;
use std::net::IpAddr;
use typed_builder::TypedBuilder;
use windows::core::BSTR;
use windows::Win32::Foundation::VARIANT_BOOL;
use windows::Win32::NetworkManagement::WindowsFirewall::{INetFwRule, NetFwRule};
use windows::Win32::System::Com::CoCreateInstance;
use crate::constants::DWCLSCONTEXT;
use crate::errors::WindowsFirewallError;
use crate::firewall_enums::{
ActionFirewallWindows, DirectionFirewallWindows, ProfileFirewallWindows,
ProtocolFirewallWindows,
};
use crate::utils::{
convert_bstr_to_hashset, convert_hashset_to_bstr, hashset_to_variant, is_not_icmp,
is_not_tcp_or_udp, to_string_hashset, variant_to_hashset, with_com_initialized,
};
use crate::windows_firewall::{add_or_update, remove_rule, rule_exists, update_rule};
use crate::{add_rule, add_rule_if_not_exists, disable_rule, enable_rule, InterfaceTypes};
/// A Windows Firewall rule expressed with plain Rust types.
///
/// Build one with the generated `WindowsFirewallRule::builder()`:
/// - `name`, `direction`, `enabled` and `action` must be supplied.
/// - Every other field is optional and defaults to `None` when its setter is not called.
///
/// The rule's `name` is what the instance methods (`add`, `remove`, `update`, `disable`,
/// `exists`, ...) use to address the rule in the system firewall. The `TryFrom`
/// implementations in this module convert to and from the COM `INetFwRule` interface.
#[derive(Debug, Clone, TypedBuilder)]
pub struct WindowsFirewallRule {
    /// Rule name; passed to the crate's rule functions to identify the rule.
    #[builder(setter(into))]
    name: String,
    /// Traffic direction the rule applies to (written via `SetDirection`).
    #[builder(setter(into))]
    direction: DirectionFirewallWindows,
    /// Whether the rule is enabled (written via `SetEnabled`).
    #[builder(setter(into))]
    enabled: bool,
    /// Action applied to matching traffic (written via `SetAction`).
    #[builder(setter(into))]
    action: ActionFirewallWindows,
    /// Optional free-text description; an empty description read back from COM becomes `None`.
    #[builder(default, setter(strip_option, into))]
    description: Option<String>,
    /// Optional application the rule is scoped to (COM `ApplicationName` property).
    #[builder(default, setter(strip_option, into))]
    application_name: Option<String>,
    /// Optional service the rule is scoped to (COM `ServiceName` property).
    #[builder(default, setter(strip_option, into))]
    service_name: Option<String>,
    /// Optional protocol. When `update` switches the protocol, it clears the ports if the
    /// new protocol is not TCP/UDP and clears the ICMP settings if it is not ICMP.
    #[builder(default, setter(strip_option, into))]
    protocol: Option<ProtocolFirewallWindows>,
    /// Optional set of local port numbers.
    #[builder(default, setter(strip_option, into))]
    local_ports: Option<HashSet<u16>>,
    /// Optional set of remote port numbers.
    #[builder(default, setter(strip_option, into))]
    remote_ports: Option<HashSet<u16>>,
    /// Optional local addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    local_addresses: Option<HashSet<IpAddr>>,
    /// Optional remote addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    remote_addresses: Option<HashSet<IpAddr>>,
    /// Optional ICMP types/codes string, passed to COM `IcmpTypesAndCodes` unchanged.
    #[builder(default, setter(strip_option, into))]
    icmp_types_and_codes: Option<String>,
    /// Optional interface names; the builder setter accepts any iterator of items
    /// convertible into `String` (collected with `to_string_hashset`).
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    interfaces: Option<HashSet<String>>,
    /// Optional interface types; the builder setter accepts any iterator of `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    interface_types: Option<HashSet<InterfaceTypes>>,
    /// Optional group name (COM `Grouping` property).
    #[builder(default, setter(strip_option, into))]
    grouping: Option<String>,
    /// Optional profile value (COM `Profiles` property).
    /// NOTE(review): a single `ProfileFirewallWindows`; whether it can encode several
    /// profiles at once depends on that type — confirm.
    #[builder(default, setter(strip_option, into))]
    profiles: Option<ProfileFirewallWindows>,
    /// Optional edge-traversal flag (COM `EdgeTraversal` property).
    #[builder(default, setter(strip_option, into))]
    edge_traversal: Option<bool>,
}
#[allow(clippy::must_use_candidate)]
impl WindowsFirewallRule {
pub fn add(&self) -> Result<(), WindowsFirewallError> {
add_rule(self)
}
pub fn add_if_not_exists(&self) -> Result<bool, WindowsFirewallError> {
add_rule_if_not_exists(self)
}
pub fn add_or_update(&self) -> Result<bool, WindowsFirewallError> {
add_or_update(self)
}
pub fn remove(self) -> Result<(), WindowsFirewallError> {
remove_rule(&self.name)?;
Ok(())
}
pub fn update(
&mut self,
settings: &WindowsFirewallRuleSettings,
) -> Result<(), WindowsFirewallError> {
update_rule(&self.name, settings)?;
if let Some(name) = &settings.name {
self.name = name.to_string();
}
if let Some(direction) = &settings.direction {
self.direction = *direction;
}
if let Some(enabled) = settings.enabled {
self.enabled = enabled;
}
if let Some(action) = &settings.action {
self.action = *action;
}
if let Some(description) = &settings.description {
self.description = Some(description.to_string());
}
if let Some(application_name) = &settings.application_name {
self.application_name = Some(application_name.to_string());
}
if let Some(service_name) = &settings.service_name {
self.service_name = Some(service_name.to_string());
}
if let Some(protocol) = &settings.protocol {
if is_not_tcp_or_udp(protocol) {
self.local_ports = None;
self.remote_ports = None;
}
if is_not_icmp(protocol) {
self.icmp_types_and_codes = None;
}
self.protocol = Some(*protocol);
}
if let Some(local_ports) = &settings.local_ports {
self.local_ports = Some(local_ports.clone());
}
if let Some(remote_ports) = &settings.remote_ports {
self.remote_ports = Some(remote_ports.clone());
}
if let Some(local_addresses) = &settings.local_addresses {
self.local_addresses = Some(local_addresses.clone());
}
if let Some(remote_addresses) = &settings.remote_addresses {
self.remote_addresses = Some(remote_addresses.clone());
}
if let Some(icmp_types_and_codes) = &settings.icmp_types_and_codes {
self.icmp_types_and_codes = Some(icmp_types_and_codes.to_string());
}
if let Some(interfaces) = &settings.interfaces {
self.interfaces = Some(interfaces.clone());
}
if let Some(interface_types) = &settings.interface_types {
self.interface_types = Some(interface_types.clone());
}
if let Some(grouping) = &settings.grouping {
self.grouping = Some(grouping.to_string());
}
if let Some(profiles) = &settings.profiles {
self.profiles = Some(*profiles);
}
if let Some(edge_traversal) = &settings.edge_traversal {
self.edge_traversal = Some(*edge_traversal);
}
Ok(())
}
pub fn disable(&mut self, disable: bool) -> Result<(), WindowsFirewallError> {
let action = if disable { disable_rule } else { enable_rule };
action(&self.name)?;
self.enabled = !disable;
Ok(())
}
pub fn exists(&self) -> Result<bool, WindowsFirewallError> {
rule_exists(&self.name)
}
pub fn name(&self) -> &str {
&self.name
}
pub fn direction(&self) -> &DirectionFirewallWindows {
&self.direction
}
pub fn enabled(&self) -> bool {
self.enabled
}
pub fn action(&self) -> &ActionFirewallWindows {
&self.action
}
pub fn description(&self) -> Option<&String> {
self.description.as_ref()
}
pub fn application_name(&self) -> Option<&String> {
self.application_name.as_ref()
}
pub fn service_name(&self) -> Option<&String> {
self.service_name.as_ref()
}
pub fn protocol(&self) -> Option<&ProtocolFirewallWindows> {
self.protocol.as_ref()
}
pub fn local_ports(&self) -> Option<&HashSet<u16>> {
self.local_ports.as_ref()
}
pub fn remote_ports(&self) -> Option<&HashSet<u16>> {
self.remote_ports.as_ref()
}
pub fn local_addresses(&self) -> Option<&HashSet<IpAddr>> {
self.local_addresses.as_ref()
}
pub fn remote_addresses(&self) -> Option<&HashSet<IpAddr>> {
self.remote_addresses.as_ref()
}
pub fn icmp_types_and_codes(&self) -> Option<&String> {
self.icmp_types_and_codes.as_ref()
}
pub fn interfaces(&self) -> Option<&HashSet<String>> {
self.interfaces.as_ref()
}
pub fn interface_types(&self) -> Option<&HashSet<InterfaceTypes>> {
self.interface_types.as_ref()
}
pub fn grouping(&self) -> Option<&String> {
self.grouping.as_ref()
}
pub fn profiles(&self) -> Option<&ProfileFirewallWindows> {
self.profiles.as_ref()
}
pub fn edge_traversal(&self) -> Option<bool> {
self.edge_traversal
}
}
impl TryFrom<INetFwRule> for WindowsFirewallRule {
type Error = WindowsFirewallError;
fn try_from(fw_rule: INetFwRule) -> Result<Self, WindowsFirewallError> {
unsafe {
Ok(Self {
name: fw_rule.Name().map(|bstr| bstr.to_string())?,
direction: fw_rule.Direction()?.try_into()?,
enabled: fw_rule.Enabled()?.into(),
action: fw_rule.Action()?.try_into()?,
description: fw_rule.Description().ok().and_then(|bstr| {
let string = bstr.to_string();
if string.is_empty() {
None
} else {
Some(string)
}
}),
application_name: fw_rule.ApplicationName().ok().and_then(|bstr| {
let string = bstr.to_string();
if string.is_empty() {
None
} else {
Some(string)
}
}),
service_name: fw_rule.ServiceName().ok().and_then(|bstr| {
let string = bstr.to_string();
if string.is_empty() {
None
} else {
Some(string)
}
}),
protocol: fw_rule.Protocol()?.try_into().ok(),
local_ports: convert_bstr_to_hashset(fw_rule.LocalPorts()),
remote_ports: convert_bstr_to_hashset(fw_rule.RemotePorts()),
local_addresses: convert_bstr_to_hashset(fw_rule.LocalAddresses()),
remote_addresses: convert_bstr_to_hashset(fw_rule.RemoteAddresses()),
icmp_types_and_codes: fw_rule.IcmpTypesAndCodes().ok().and_then(|bstr| {
let string = bstr.to_string();
if string.is_empty() {
None
} else {
Some(string)
}
}),
interfaces: Some(variant_to_hashset(&fw_rule.Interfaces()?)?),
interface_types: convert_bstr_to_hashset(fw_rule.InterfaceTypes()),
grouping: fw_rule.Grouping().ok().and_then(|bstr| {
let string = bstr.to_string();
if string.is_empty() {
None
} else {
Some(string)
}
}),
profiles: fw_rule.Profiles()?.try_into().ok(),
edge_traversal: fw_rule.EdgeTraversal().ok().map(VARIANT_BOOL::as_bool),
})
}
}
}
impl TryFrom<&WindowsFirewallRule> for INetFwRule {
    type Error = WindowsFirewallError;

    /// Creates a new COM `NetFwRule` object and copies `rule` into it.
    ///
    /// The required fields are always written. Each optional field is written only when
    /// it is `Some`. The protocol is written before ports and ICMP settings.
    /// NOTE(review): that ordering presumably matters because Windows validates ports
    /// and ICMP codes against the protocol, so keep it.
    ///
    /// # Errors
    /// Fails if COM object creation, any setter, or the interfaces conversion fails.
    fn try_from(rule: &WindowsFirewallRule) -> Result<Self, WindowsFirewallError> {
        // SAFETY: the only raw COM calls are on `com_rule`, which is created inside this
        // closure and used only while the closure runs. NOTE(review): presumably
        // `with_com_initialized` keeps COM initialised for that duration.
        with_com_initialized(|| unsafe {
            let com_rule: Self = CoCreateInstance(&NetFwRule, None, DWCLSCONTEXT)?;

            com_rule.SetName(&BSTR::from(&rule.name))?;
            com_rule.SetDirection(rule.direction.into())?;
            com_rule.SetEnabled(rule.enabled.into())?;
            com_rule.SetAction(rule.action.into())?;

            if let Some(text) = &rule.description {
                com_rule.SetDescription(&BSTR::from(text))?;
            }
            if let Some(app) = &rule.application_name {
                com_rule.SetApplicationName(&BSTR::from(app))?;
            }
            if let Some(service) = &rule.service_name {
                com_rule.SetServiceName(&BSTR::from(service))?;
            }

            // Protocol first: the port and ICMP setters below depend on it.
            if let Some(proto) = rule.protocol {
                com_rule.SetProtocol(proto.into())?;
            }
            if let Some(ports) = &rule.local_ports {
                com_rule.SetLocalPorts(&convert_hashset_to_bstr(Some(ports)))?;
            }
            if let Some(ports) = &rule.remote_ports {
                com_rule.SetRemotePorts(&convert_hashset_to_bstr(Some(ports)))?;
            }
            if let Some(addrs) = &rule.local_addresses {
                com_rule.SetLocalAddresses(&convert_hashset_to_bstr(Some(addrs)))?;
            }
            if let Some(addrs) = &rule.remote_addresses {
                com_rule.SetRemoteAddresses(&convert_hashset_to_bstr(Some(addrs)))?;
            }
            if let Some(icmp) = &rule.icmp_types_and_codes {
                com_rule.SetIcmpTypesAndCodes(&BSTR::from(icmp))?;
            }

            if let Some(traversal) = rule.edge_traversal {
                com_rule.SetEdgeTraversal(traversal.into())?;
            }
            if let Some(group) = &rule.grouping {
                com_rule.SetGrouping(&BSTR::from(group))?;
            }
            if let Some(names) = &rule.interfaces {
                com_rule.SetInterfaces(&hashset_to_variant(names)?)?;
            }
            if let Some(kinds) = &rule.interface_types {
                com_rule.SetInterfaceTypes(&convert_hashset_to_bstr(Some(kinds)))?;
            }
            if let Some(profile) = rule.profiles {
                com_rule.SetProfiles(profile.into())?;
            }

            Ok(com_rule)
        })
    }
}
/// A partial set of rule properties used to modify an existing rule.
///
/// Every field is optional and defaults to `None` in the generated builder
/// (`WindowsFirewallRuleSettings::builder()`). `WindowsFirewallRule::update` passes the
/// whole value to `update_rule`, then copies each `Some` field into the local rule.
/// Fields are `pub(crate)` so the crate's update logic can read them directly.
#[derive(Debug, Clone, TypedBuilder)]
pub struct WindowsFirewallRuleSettings {
    /// New rule name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) name: Option<String>,
    /// New traffic direction.
    #[builder(default, setter(strip_option, into))]
    pub(crate) direction: Option<DirectionFirewallWindows>,
    /// New enabled state.
    #[builder(default, setter(strip_option, into))]
    pub(crate) enabled: Option<bool>,
    /// New action.
    #[builder(default, setter(strip_option, into))]
    pub(crate) action: Option<ActionFirewallWindows>,
    /// New description.
    #[builder(default, setter(strip_option, into))]
    pub(crate) description: Option<String>,
    /// New application name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) application_name: Option<String>,
    /// New service name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) service_name: Option<String>,
    /// New protocol. In the local copy, `update` clears ports when this is not TCP/UDP
    /// and clears ICMP types/codes when it is not ICMP.
    #[builder(default, setter(strip_option, into))]
    pub(crate) protocol: Option<ProtocolFirewallWindows>,
    /// New local ports.
    #[builder(default, setter(strip_option, into))]
    pub(crate) local_ports: Option<HashSet<u16>>,
    /// New remote ports.
    #[builder(default, setter(strip_option, into))]
    pub(crate) remote_ports: Option<HashSet<u16>>,
    /// New local addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) local_addresses: Option<HashSet<IpAddr>>,
    /// New remote addresses; the builder setter accepts any iterator of `IpAddr`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = IpAddr>| Some(items.into_iter().collect())))]
    pub(crate) remote_addresses: Option<HashSet<IpAddr>>,
    /// New ICMP types/codes string.
    #[builder(default, setter(strip_option, into))]
    pub(crate) icmp_types_and_codes: Option<String>,
    /// New interface names; the builder setter accepts any iterator of items convertible
    /// into `String`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = impl Into<String>>| Some(to_string_hashset(items))))]
    pub(crate) interfaces: Option<HashSet<String>>,
    /// New interface types; the builder setter accepts any iterator of `InterfaceTypes`.
    #[builder(default, setter(transform = |items: impl IntoIterator<Item = InterfaceTypes>| Some(items.into_iter().collect())))]
    pub(crate) interface_types: Option<HashSet<InterfaceTypes>>,
    /// New group name.
    #[builder(default, setter(strip_option, into))]
    pub(crate) grouping: Option<String>,
    /// New profile value.
    #[builder(default, setter(strip_option, into))]
    pub(crate) profiles: Option<ProfileFirewallWindows>,
    /// New edge-traversal flag.
    #[builder(default, setter(strip_option, into))]
    pub(crate) edge_traversal: Option<bool>,
}
impl From<WindowsFirewallRule> for WindowsFirewallRuleSettings {
fn from(rule: WindowsFirewallRule) -> Self {
Self {
name: Some(rule.name),
direction: Some(rule.direction),
enabled: Some(rule.enabled),
action: Some(rule.action),
description: rule.description,
application_name: rule.application_name,
service_name: rule.service_name,
protocol: rule.protocol,
local_ports: rule.local_ports,
remote_ports: rule.remote_ports,
local_addresses: rule.local_addresses,
remote_addresses: rule.remote_addresses,
icmp_types_and_codes: rule.icmp_types_and_codes,
interfaces: rule.interfaces,
interface_types: rule.interface_types,
grouping: rule.grouping,
profiles: rule.profiles,
edge_traversal: rule.edge_traversal,
}
}
}