use std::convert::TryFrom;
use windows::Win32::NetworkManagement::WindowsFirewall::{INetFwRule, INetFwRules};
use windows::core::BSTR;
use crate::errors::{SetRuleError, WindowsFirewallError};
use crate::firewall_rule::{Direction, FirewallRule, FirewallRuleUpdate};
use crate::utils::{hashset_to_bstr, hashset_to_variant, with_policy};
/// Reports whether a firewall rule called `name` is present in the current policy.
///
/// The rule is looked up by name through `INetFwRules::Item`. Any failure of that
/// lookup is reported as `Ok(false)`. Only a failure to obtain the rule collection
/// itself is returned as an error.
pub fn rule_exists(name: &str) -> Result<bool, WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let found = unsafe { rules.Item(&BSTR::from(name)) }.is_ok();
        Ok(found)
    })
}
/// Fetches the firewall rule called `name` and converts it into a [`FirewallRule`].
///
/// # Errors
/// Returns an error if the rule collection cannot be obtained, if no rule with that
/// name can be retrieved, or if converting the COM rule into a `FirewallRule` fails.
pub fn get_rule(name: &str) -> Result<FirewallRule, WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let com_rule = unsafe { rules.Item(&BSTR::from(name)) }?;
        FirewallRule::try_from(com_rule)
    })
}
/// Turns the firewall rule called `rule_name` on (`enabled == true`) or off.
///
/// # Errors
/// Returns an error if the rule collection or the named rule cannot be retrieved.
/// A failure of the enable setter is reported as `SetRuleError::Enabled`.
pub fn enable_rule(rule_name: &str, enabled: bool) -> Result<(), WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let target = unsafe { rules.Item(&BSTR::from(rule_name)) }?;
        unsafe { target.SetEnabled(enabled.into()) }.map_err(SetRuleError::Enabled)?;
        Ok(())
    })
}
/// Deletes the firewall rule called `rule_name` from the current policy.
///
/// # Errors
/// Propagates any failure to obtain the rule collection or any error returned by
/// `INetFwRules::Remove`.
pub fn remove_rule(rule_name: &str) -> Result<(), WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let target_name = BSTR::from(rule_name);
        unsafe { rules.Remove(&target_name) }?;
        Ok(())
    })
}
/// Registers `rule` as a new firewall rule in the current policy.
///
/// The rule collection is fetched first. The rule is then converted into a COM
/// `INetFwRule` and added, with no check for an existing rule of the same name.
///
/// # Errors
/// Propagates failures from fetching the collection, from the conversion, or from
/// `INetFwRules::Add`.
pub fn add_rule(rule: &FirewallRule) -> Result<(), WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let com_rule: INetFwRule = rule.try_into()?;
        unsafe { rules.Add(&com_rule) }?;
        Ok(())
    })
}
/// Adds `rule` only when no rule with the same name can be found.
///
/// Returns `Ok(true)` when the rule was added. Returns `Ok(false)` when a lookup by
/// `rule.name()` succeeded, in which case nothing is changed. A failed lookup is
/// treated as "not present".
///
/// # Errors
/// Propagates failures from fetching the collection, converting the rule, or adding it.
pub fn add_rule_if_not_exists(rule: &FirewallRule) -> Result<bool, WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        if unsafe { rules.Item(&BSTR::from(rule.name())) }.is_ok() {
            Ok(false)
        } else {
            let com_rule: INetFwRule = rule.try_into()?;
            unsafe { rules.Add(&com_rule) }?;
            Ok(true)
        }
    })
}
pub fn add_rule_or_update(rule: &FirewallRule) -> Result<bool, WindowsFirewallError> {
with_policy(|fw_policy| {
let fw_rules: INetFwRules = unsafe { fw_policy.Rules() }?;
let rule_name = BSTR::from(rule.name());
let fw_rule_result = unsafe { fw_rules.Item(&rule_name) };
if let Ok(existing_rule) = fw_rule_result {
let settings = rule.clone().into();
update_inetfw_rule(&existing_rule, &settings)?;
return Ok(false);
}
let new_rule: INetFwRule = rule.try_into()?;
unsafe { fw_rules.Add(&new_rule) }?;
Ok(true)
})
}
/// Applies the fields set in `settings` to the existing rule called `rule_name`.
///
/// # Errors
/// Returns an error if the rule collection or the named rule cannot be retrieved, or
/// if any individual property update fails.
pub fn update_rule(
    rule_name: &str,
    settings: &FirewallRuleUpdate,
) -> Result<(), WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let target = unsafe { rules.Item(&BSTR::from(rule_name)) }?;
        update_inetfw_rule(&target, settings)?;
        Ok(())
    })
}
/// Writes every field that is `Some` in `settings` onto the COM rule `rule`.
///
/// Fields that are `None` are left untouched. Each property is written with its own
/// setter call, so the update is not atomic. If one setter fails, the properties
/// written before it stay changed and that error is returned.
///
/// Ordering rules followed here:
/// - A disable (`enabled == Some(false)`) is applied first and an enable
///   (`enabled == Some(true)`) last. A failed partial update therefore never leaves a
///   freshly enabled rule running with only part of its new settings.
/// - When the new direction is not inbound, edge traversal is switched off before the
///   direction changes. Direction is also set before `edge_traversal`.
/// - The protocol is set before ports and ICMP types. Ports and ICMP types that the new
///   protocol cannot carry are cleared first, best effort (errors ignored).
///
/// # Errors
/// Returns the first error from reading the current edge-traversal flag, converting
/// `interfaces`, or any setter. Setter errors are wrapped in the matching
/// `SetRuleError` variant.
fn update_inetfw_rule(
    rule: &INetFwRule,
    settings: &FirewallRuleUpdate,
) -> Result<(), WindowsFirewallError> {
    // Take the rule offline before anything else, so it is never active with a
    // half-applied configuration.
    if settings.enabled == Some(false) {
        unsafe {
            rule.SetEnabled(false.into())
                .map_err(SetRuleError::Enabled)
        }?;
    }
    if let Some(name) = &settings.name {
        unsafe { rule.SetName(&BSTR::from(name)).map_err(SetRuleError::Name) }?;
    }
    if let Some(direction) = settings.direction {
        // Clear edge traversal before leaving the inbound direction. This code treats
        // edge traversal as valid only on inbound rules.
        if direction != Direction::In && (unsafe { rule.EdgeTraversal() }?.as_bool()) {
            unsafe {
                rule.SetEdgeTraversal(false.into())
                    .map_err(SetRuleError::EdgeTraversal)
            }?;
        }
        unsafe {
            rule.SetDirection(direction.into())
                .map_err(SetRuleError::Direction)
        }?;
    }
    if let Some(action) = settings.action {
        unsafe { rule.SetAction(action.into()).map_err(SetRuleError::Action) }?;
    }
    if let Some(description) = &settings.description {
        unsafe {
            rule.SetDescription(&BSTR::from(description))
                .map_err(SetRuleError::Description)
        }?;
    }
    if let Some(application_name) = &settings.application_name {
        unsafe {
            rule.SetApplicationName(&BSTR::from(application_name))
                .map_err(SetRuleError::ApplicationName)
        }?;
    }
    if let Some(service_name) = &settings.service_name {
        unsafe {
            rule.SetServiceName(&BSTR::from(service_name))
                .map_err(SetRuleError::ServiceName)
        }?;
    }
    if let Some(protocol) = settings.protocol {
        // Best-effort cleanup of fields the new protocol cannot carry. Failures are
        // deliberately ignored; a real conflict surfaces from SetProtocol below.
        if !protocol.is_tcp_or_udp() {
            let _ = unsafe { rule.SetLocalPorts(&BSTR::from("")) };
            let _ = unsafe { rule.SetRemotePorts(&BSTR::from("")) };
        }
        if !protocol.is_icmp() {
            let _ = unsafe { rule.SetIcmpTypesAndCodes(&BSTR::from("")) };
        }
        unsafe {
            rule.SetProtocol(protocol.into())
                .map_err(SetRuleError::Protocol)
        }?;
    }
    if let Some(local_ports) = &settings.local_ports {
        unsafe {
            rule.SetLocalPorts(&hashset_to_bstr(Some(local_ports)))
                .map_err(SetRuleError::LocalPorts)
        }?;
    }
    if let Some(remote_ports) = &settings.remote_ports {
        unsafe {
            rule.SetRemotePorts(&hashset_to_bstr(Some(remote_ports)))
                .map_err(SetRuleError::RemotePorts)
        }?;
    }
    if let Some(local_addresses) = &settings.local_addresses {
        unsafe {
            rule.SetLocalAddresses(&hashset_to_bstr(Some(local_addresses)))
                .map_err(SetRuleError::LocalAddresses)
        }?;
    }
    if let Some(remote_addresses) = &settings.remote_addresses {
        unsafe {
            rule.SetRemoteAddresses(&hashset_to_bstr(Some(remote_addresses)))
                .map_err(SetRuleError::RemoteAddresses)
        }?;
    }
    if let Some(icmp_types_and_codes) = &settings.icmp_types_and_codes {
        unsafe {
            rule.SetIcmpTypesAndCodes(&BSTR::from(icmp_types_and_codes))
                .map_err(SetRuleError::IcmpTypesAndCodes)
        }?;
    }
    // Applied after direction so that switching to inbound and enabling edge traversal
    // in the same update succeeds.
    if let Some(edge_traversal) = settings.edge_traversal {
        unsafe {
            rule.SetEdgeTraversal(edge_traversal.into())
                .map_err(SetRuleError::EdgeTraversal)
        }?;
    }
    if let Some(grouping) = &settings.grouping {
        unsafe {
            rule.SetGrouping(&BSTR::from(grouping))
                .map_err(SetRuleError::Grouping)
        }?;
    }
    if let Some(interfaces) = &settings.interfaces {
        unsafe {
            rule.SetInterfaces(&hashset_to_variant(interfaces)?)
                .map_err(SetRuleError::Interfaces)
        }?;
    }
    if let Some(interface_types) = &settings.interface_types {
        unsafe {
            rule.SetInterfaceTypes(&hashset_to_bstr(Some(interface_types)))
                .map_err(SetRuleError::InterfaceType)
        }?;
    }
    if let Some(profiles) = settings.profiles {
        unsafe {
            rule.SetProfiles(profiles.into())
                .map_err(SetRuleError::Profiles)
        }?;
    }
    // Enable only after every other property was written successfully. A failed update
    // above then cannot leave a newly enabled rule carrying stale settings.
    if settings.enabled == Some(true) {
        unsafe {
            rule.SetEnabled(true.into())
                .map_err(SetRuleError::Enabled)
        }?;
    }
    Ok(())
}