use std::convert::TryFrom;
use std::mem::ManuallyDrop;

use scopeguard::guard;
use tracing::error;
use windows::core::{Interface, BSTR};
use windows::Win32::NetworkManagement::WindowsFirewall::{
    INetFwPolicy2, INetFwRule, INetFwRules, NetFwPolicy2, NET_FW_PROFILE_TYPE2,
};
use windows::Win32::System::Com::CoCreateInstance;
use windows::Win32::System::Ole::IEnumVARIANT;
use windows::Win32::System::Variant::{VariantClear, VARIANT, VT_DISPATCH};

use crate::constants::DWCLSCONTEXT;
use crate::errors::WindowsFirewallError;
use crate::firewall_enums::ProfileFirewallWindows;
use crate::firewall_rule::{WindowsFirewallRule, WindowsFirewallRuleSettings};
use crate::utils::{
    convert_hashset_to_bstr, hashset_to_variant, is_not_icmp, is_not_tcp_or_udp,
    with_com_initialized,
};
use crate::DirectionFirewallWindows;
/// Reports whether a firewall rule named `name` can be looked up.
///
/// The check is a single `INetFwRules::Item` call. *Any* failure of that call, not
/// only "no such rule", is reported as `Ok(false)`.
///
/// # Errors
///
/// Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2`
/// COM object, and from `INetFwPolicy2::Rules`.
pub fn rule_exists(name: &str) -> Result<bool, WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let lookup = policy.Rules()?.Item(&BSTR::from(name));
        Ok(lookup.is_ok())
    })
}
pub fn get_rule(name: &str) -> Result<WindowsFirewallRule, WindowsFirewallError> {
with_com_initialized(|| unsafe {
let fw_policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
let fw_rules: INetFwRules = fw_policy.Rules()?;
let rule_name = BSTR::from(name);
let rule = fw_rules.Item(&rule_name);
WindowsFirewallRule::try_from(rule?)
})
}
/// Registers `rule` with the firewall unconditionally.
///
/// No existence check is performed. Use `add_rule_if_not_exists` or `add_or_update`
/// when a rule with the same name may already be present.
///
/// # Errors
///
/// - Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
///   object, and from `INetFwPolicy2::Rules`.
/// - Propagates errors from converting `rule` into an `INetFwRule`.
/// - Propagates errors from `INetFwRules::Add`.
pub fn add_rule(rule: &WindowsFirewallRule) -> Result<(), WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let collection: INetFwRules = policy.Rules()?;
        let com_rule: INetFwRule = rule.try_into()?;
        collection.Add(&com_rule)?;
        Ok(())
    })
}
/// Adds `rule` only when no rule with the same name can be looked up.
///
/// Returns `Ok(true)` if the rule was added. Returns `Ok(false)` if a rule with that
/// name already exists; that rule is left untouched. Any `INetFwRules::Item` failure
/// counts as "absent".
///
/// # Errors
///
/// - Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
///   object, and from `INetFwPolicy2::Rules`.
/// - Propagates errors from converting `rule` into an `INetFwRule`.
/// - Propagates errors from `INetFwRules::Add`.
pub fn add_rule_if_not_exists(rule: &WindowsFirewallRule) -> Result<bool, WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let collection: INetFwRules = policy.Rules()?;
        let already_present = collection.Item(&BSTR::from(rule.name())).is_ok();
        if !already_present {
            let com_rule: INetFwRule = rule.try_into()?;
            collection.Add(&com_rule)?;
        }
        Ok(!already_present)
    })
}
pub fn add_or_update(rule: &WindowsFirewallRule) -> Result<bool, WindowsFirewallError> {
with_com_initialized(|| unsafe {
let fw_policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
let fw_rules: INetFwRules = fw_policy.Rules()?;
let rule_name = BSTR::from(rule.name());
let fw_rule_result = fw_rules.Item(&rule_name);
if let Ok(existing_rule) = fw_rule_result {
let settings = rule.clone().into();
update_inetfw_rule(&existing_rule, &settings)?;
return Ok(false);
}
let new_rule: INetFwRule = rule.try_into()?;
fw_rules.Add(&new_rule)?;
Ok(true)
})
}
pub fn update_rule(
rule_name: &str,
settings: &WindowsFirewallRuleSettings,
) -> Result<(), WindowsFirewallError> {
with_com_initialized(|| unsafe {
let fw_policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
let fw_rules: INetFwRules = fw_policy.Rules()?;
let rule_name = BSTR::from(rule_name);
let rule = fw_rules.Item(&rule_name)?;
update_inetfw_rule(&rule, settings)?;
Ok(())
})
}
/// Writes every field that is `Some` in `settings` onto an existing COM firewall rule.
///
/// Fields that are `None` are left untouched. Present fields are written with the
/// matching `INetFwRule` setter, in the order they appear in this function. The first
/// failing setter aborts the update and its error is returned; fields written before it
/// stay written.
///
/// # Errors
///
/// - Returns the error of the first setter that fails.
/// - Returns the error from `hashset_to_variant` (converted with `?`) when building the
///   `interfaces` value.
/// - Failures while clearing ports / ICMP types on a protocol change are deliberately
///   ignored (see below).
///
/// # Safety
///
/// Every call here is an `unsafe` windows-rs COM method on `rule`.
/// NOTE(review): COM presumably must be initialized on the calling thread. All callers
/// in this file invoke this from inside `with_com_initialized`.
unsafe fn update_inetfw_rule(
    rule: &INetFwRule,
    settings: &WindowsFirewallRuleSettings,
) -> Result<(), windows::core::Error> {
    if let Some(name) = &settings.name {
        rule.SetName(&BSTR::from(name))?;
    }
    if let Some(direction) = settings.direction {
        rule.SetDirection(direction.into())?;
    }
    if let Some(enabled) = settings.enabled {
        rule.SetEnabled(enabled.into())?;
    }
    if let Some(action) = settings.action {
        rule.SetAction(action.into())?;
    }
    if let Some(description) = &settings.description {
        rule.SetDescription(&BSTR::from(description))?;
    }
    if let Some(application_name) = &settings.application_name {
        rule.SetApplicationName(&BSTR::from(application_name))?;
    }
    if let Some(service_name) = &settings.service_name {
        rule.SetServiceName(&BSTR::from(service_name))?;
    }
    // Protocol is handled before the port / ICMP fields further down. NOTE(review): the
    // INetFwRule documentation requires Protocol to be set before LocalPorts /
    // RemotePorts, so this statement order matters. Verify before reordering.
    if let Some(protocol) = settings.protocol {
        // Before switching protocol, blank out fields that (judging by the helper names)
        // only apply to the other protocol family. This is best effort: failures here
        // are ignored on purpose with `let _`.
        if is_not_tcp_or_udp(&protocol) {
            let _ = rule.SetLocalPorts(&BSTR::from(""));
            let _ = rule.SetRemotePorts(&BSTR::from(""));
        }
        if is_not_icmp(&protocol) {
            let _ = rule.SetIcmpTypesAndCodes(&BSTR::from(""));
        }
        rule.SetProtocol(protocol.into())?;
    }
    if let Some(local_ports) = &settings.local_ports {
        rule.SetLocalPorts(&convert_hashset_to_bstr(Some(local_ports)))?;
    }
    if let Some(remote_ports) = &settings.remote_ports {
        rule.SetRemotePorts(&convert_hashset_to_bstr(Some(remote_ports)))?;
    }
    if let Some(local_addresses) = &settings.local_addresses {
        rule.SetLocalAddresses(&convert_hashset_to_bstr(Some(local_addresses)))?;
    }
    if let Some(remote_addresses) = &settings.remote_addresses {
        rule.SetRemoteAddresses(&convert_hashset_to_bstr(Some(remote_addresses)))?;
    }
    if let Some(icmp_types_and_codes) = &settings.icmp_types_and_codes {
        rule.SetIcmpTypesAndCodes(&BSTR::from(icmp_types_and_codes))?;
    }
    if let Some(edge_traversal) = settings.edge_traversal {
        rule.SetEdgeTraversal(edge_traversal.into())?;
    }
    if let Some(grouping) = &settings.grouping {
        rule.SetGrouping(&BSTR::from(grouping))?;
    }
    // Interfaces are passed as a VARIANT (built by `hashset_to_variant`), unlike the
    // BSTR-valued fields above.
    if let Some(interfaces) = &settings.interfaces {
        rule.SetInterfaces(&hashset_to_variant(interfaces)?)?;
    }
    if let Some(interface_types) = &settings.interface_types {
        rule.SetInterfaceTypes(&convert_hashset_to_bstr(Some(interface_types)))?;
    }
    if let Some(profiles) = settings.profiles {
        rule.SetProfiles(profiles.into())?;
    }
    Ok(())
}
/// Turns off the rule named `rule_name` by setting its `Enabled` property to false.
///
/// # Errors
///
/// - Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
///   object, and from `INetFwPolicy2::Rules`.
/// - Propagates the `INetFwRules::Item` lookup error, e.g. when the rule does not exist.
/// - Propagates errors from `SetEnabled`.
pub fn disable_rule(rule_name: &str) -> Result<(), WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        policy
            .Rules()?
            .Item(&BSTR::from(rule_name))?
            .SetEnabled(false.into())?;
        Ok(())
    })
}
pub fn enable_rule(rule_name: &str) -> Result<(), WindowsFirewallError> {
with_com_initialized(|| unsafe {
let fw_policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
let fw_rules: INetFwRules = fw_policy.Rules()?;
let rule_name = BSTR::from(rule_name);
let rule = fw_rules.Item(&rule_name)?;
rule.SetEnabled(true.into())?;
Ok(())
})
}
/// Deletes the rule named `rule_name` through `INetFwRules::Remove`.
///
/// # Errors
///
/// - Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
///   object, and from `INetFwPolicy2::Rules`.
/// - Propagates errors from `INetFwRules::Remove`.
pub fn remove_rule(rule_name: &str) -> Result<(), WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        policy.Rules()?.Remove(&BSTR::from(rule_name))?;
        Ok(())
    })
}
pub fn list_rules() -> Result<Vec<WindowsFirewallRule>, WindowsFirewallError> {
let mut rules_list = Vec::new();
with_com_initialized(|| unsafe {
let fw_policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
let fw_rules: INetFwRules = fw_policy.Rules()?;
let rules_count = fw_rules.Count()?;
let enumerator = fw_rules._NewEnum()?.cast::<IEnumVARIANT>()?;
let mut variants: [VARIANT; 1] = Default::default();
let mut pceltfetch: u32 = 0;
for _ in 0..rules_count {
let fetched = enumerator.Next(&mut variants, &mut pceltfetch);
if fetched.is_err() {
error!("Error while fetching rules");
continue;
};
if let Some(variant) = variants.first() {
let dispatch = variant.Anonymous.Anonymous.Anonymous.pdispVal.clone();
let _dispatch_cleanup = guard(dispatch.clone(), |mut d| {
ManuallyDrop::drop(&mut d);
});
if let Some(dispatch) = dispatch.as_ref() {
let fw_rule = dispatch.cast::<INetFwRule>()?;
rules_list.push(fw_rule.try_into()?);
}
}
}
Ok(rules_list)
})
}
/// Returns only the rules whose direction is [`DirectionFirewallWindows::In`].
///
/// The rules come from `list_rules`, and their original order is preserved.
///
/// # Errors
///
/// Propagates any error from `list_rules`.
pub fn list_incoming_rules() -> Result<Vec<WindowsFirewallRule>, WindowsFirewallError> {
    let mut rules = list_rules()?;
    rules.retain(|candidate| *candidate.direction() == DirectionFirewallWindows::In);
    Ok(rules)
}
/// Returns only the rules whose direction is [`DirectionFirewallWindows::Out`].
///
/// The rules come from `list_rules`, and their original order is preserved.
///
/// # Errors
///
/// Propagates any error from `list_rules`.
pub fn list_outgoing_rules() -> Result<Vec<WindowsFirewallRule>, WindowsFirewallError> {
    let mut rules = list_rules()?;
    rules.retain(|candidate| *candidate.direction() == DirectionFirewallWindows::Out);
    Ok(rules)
}
/// Reads `INetFwPolicy2::CurrentProfileTypes` and converts it into a
/// [`ProfileFirewallWindows`].
///
/// # Errors
///
/// - Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
///   object, and from reading `CurrentProfileTypes`.
/// - Propagates the `TryFrom` conversion error, converted with `?`.
pub fn get_active_profile() -> Result<ProfileFirewallWindows, WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let profile_bits = policy.CurrentProfileTypes()?;
        Ok(ProfileFirewallWindows::try_from(profile_bits)?)
    })
}
/// Reports whether the firewall is enabled for `profile`.
///
/// # Errors
///
/// Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
/// object, and from `get_FirewallEnabled`.
pub fn get_firewall_state(profile: ProfileFirewallWindows) -> Result<bool, WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let profile_type = NET_FW_PROFILE_TYPE2(profile.into());
        let flag = policy.get_FirewallEnabled(profile_type)?;
        Ok(flag.as_bool())
    })
}
/// Enables (`state == true`) or disables the firewall for `profile`.
///
/// # Errors
///
/// Propagates errors from `with_com_initialized`, from creating the `NetFwPolicy2` COM
/// object, and from `put_FirewallEnabled`.
pub fn set_firewall_state(
    profile: ProfileFirewallWindows,
    state: bool,
) -> Result<(), WindowsFirewallError> {
    with_com_initialized(|| unsafe {
        let policy: INetFwPolicy2 = CoCreateInstance(&NetFwPolicy2, None, DWCLSCONTEXT)?;
        let profile_type = NET_FW_PROFILE_TYPE2(profile.into());
        policy.put_FirewallEnabled(profile_type, state.into())?;
        Ok(())
    })
}